NLP transformers - 文本分类

在这里插入图片描述

Text classification

文章目录

  • Text classification
    • 加载 IMDb 数据集
    • Preprocess 预处理
    • Evaluate
    • Train
    • Inference


本文翻译自:Text classification
https://huggingface.co/docs/transformers/tasks/sequence_classification
notebook : https://colab.research.google.com/github/huggingface/notebooks/blob/main/transformers_doc/en/pytorch/sequence_classification.ipynb


文本分类是一种常见的 NLP 任务,它为文本分配标签或类别。一些大公司在生产中运行文本分类,以实现广泛的实际应用。最流行的文本分类形式之一是 情感分析,它为文本序列分配 🙂 积极、🙁 消极或 😐 中性等标签。

本指南将向您展示:

  1. 在IMDb数据集上微调DistilBERT,以确定电影评论是正面还是负面。
  2. 使用您的微调模型进行推理。

本教程中演示的任务由以下模型架构支持:

ALBERT, BART, BERT, BigBird, BigBird-Pegasus, BioGpt, BLOOM, CamemBERT, CANINE, CodeLlama, ConvBERT, CTRL, Data2VecText, DeBERTa, DeBERTa-v2, DistilBERT, ELECTRA, ERNIE, ErnieM, ESM, Falcon, FlauBERT, FNet, Funnel Transformer, Gemma, GPT-Sw3, OpenAI GPT-2, GPTBigCode, GPT Neo, GPT NeoX, GPT-J, I-BERT, Jamba, LayoutLM, LayoutLMv2, LayoutLMv3, LED, LiLT, LLaMA, Longformer, LUKE, MarkupLM, mBART, MEGA, Megatron-BERT, Mistral, Mixtral, MobileBERT, MPNet, MPT, MRA, MT5, MVP, Nezha, Nyströmformer, OpenLlama, OpenAI GPT, OPT, Perceiver, Persimmon, Phi, PLBart, QDQBert, Qwen2, Qwen2MoE, Reformer, RemBERT, RoBERTa, RoBERTa-PreLayerNorm, RoCBert, RoFormer, SqueezeBERT, StableLm, Starcoder2, T5, TAPAS, Transformer-XL, UMT5, XLM, XLM-RoBERTa, XLM-RoBERTa-XL, XLNet, X-MOD, YOSO


在开始之前,请确保已安装所有必需的库:

pip install transformers datasets evaluate accelerate

我们鼓励您登录 Hugging Face 帐户,以便您可以上传模型并与社区分享。出现提示时,输入您的令牌进行登录:

from huggingface_hub import notebook_login

notebook_login()

加载 IMDb 数据集

首先从 🤗 数据集库加载 IMDb 数据集:

from  datasets import load_dataset

imdb = load_dataset("imdb")

然后看一个数据样例:

IMDB[ “测试” ][ 0 ]
{
    "label" : 0 ,
     "text" : "我喜欢科幻小说,并且愿意忍受很多。... 一切又来了。” ,
}

该数据集中有两个字段:

  • text: 影评文字。
  • label: 0:表示负面评论或1正面评论的值。

Preprocess 预处理

下一步是加载 DistilBERT 分词器来预处理该text字段:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from _pretrained( "distilbert/distilbert-base-uncased" )

创建一个预处理函数来对text序列进行标记和截断,使其长度不超过 DistilBERT 的最大输入长度:

def  preprocess_function ( Examples ):
    return tokenizer(examples[ "text" ], truncation= True )

要将预处理函数应用于整个数据集,请使用 🤗 数据集 map 函数。
您可以map通过设置 batched=True 一次处理数据集的多个元素来加快速度:

tokenized_imdb = imdb.map(preprocess_function, batched=True)

现在使用 DataCollatorWithPadding 创建一批示例。在整理过程中 动态地将句子填充 到批次中的最长长度,比将整个数据集填充到最大长度更有效。

from transformers import DataCollatorWithPadding

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

Evaluate

在训练期间包含指标通常有助于评估模型的性能。您可以使用 🤗 Evaluate库快速加载评估方法。对于此任务,加载准确性指标(请参阅 🤗 评估快速浏览以了解有关如何加载和计算指标的更多信息):

import evaluate

accuracy = evaluate.load("accuracy")

然后创建一个传递预测和标签的函数来compute计算准确性:

import numpy as np

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels) 

您的compute_metrics函数现在已准备就绪,您将在设置训练时返回该函数。


Train

在开始训练模型之前,请使用id2labellabel2id ,创建预期 id 到其标签的映射:

id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {"NEGATIVE": 0, "POSITIVE": 1}

如果您不熟悉使用 Trainer 微调模型,
请查看基本教程:<(https://huggingface.co/docs/transformers/training#train-with-pytorch-trainer>

您现在就可以开始训练您的模型了!使用 AutoModelForSequenceClassification 加载 DistilBERT以及预期标签的数量和标签映射:

from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased", num_labels=2, id2label=id2label, label2id=label2id
)

此时,只剩下三步:

  1. 在TrainingArguments中定义训练超参数。
    唯一必需的参数是output_dir指定保存模型的位置。您可以通过设置将此模型推送到 Hub push_to_hub=True(您需要登录 Hugging Face 才能上传模型)。
    在每个 epoch 结束时,Trainer 将评估准确性并保存训练检查点。
  2. 将训练参数以及模型、数据集、分词器、数据整理器和compute_metrics函数传递给Trainer 。
  3. 调用 train() 来微调您的模型。
training_args = TrainingArguments(
    output_dir="my_awesome_model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_imdb["train"],
    eval_dataset=tokenized_imdb["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

当您传递 token 给Trainer时, 它默认应用动态填充tokenizer。在这种情况下,您不需要显式指定数据整理器。

训练完成后,使用 push_to_hub()方法将您的模型共享到 Hub,以便每个人都可以使用您的模型:

trainer.push_to_hub()

有关如何微调文本分类模型的更深入示例,请查看相应的 PyTorch 笔记本 或 TensorFlow 笔记本。


Inference

太好了,现在您已经微调了模型,您可以使用它进行推理!

获取一些您想要进行推理的文本:

text = “这是一部杰作。并不完全忠实于原著,但从头到尾都令人着迷。可能是三本书中我最喜欢的。”

尝试微调模型进行推理的最简单方法是在 pipeline() 中使用它。使用您的模型实例化pipeline情感分析,并将文本传递给它:

from transformers import pipeline

classifier = pipeline("sentiment-analysis", model="stevhliu/my_awesome_model")
classifier(text)

如果您愿意,您还可以手动复制 pipeline 的结果:


对文本进行分词并返回 PyTorch 张量:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("stevhliu/my_awesome_model")
inputs = tokenizer(text, return_tensors="pt")

将您的输入传递给模型并返回logits

from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("stevhliu/my_awesome_model")

with torch.no_grad():
    logits = model(**inputs).logits

获取概率最高的类,并使用模型的id2label映射将其转换为文本标签:

predicted_class_id = logits.argmax().item()
model.config.id2label[predicted_class_id]
# -> 'POSITIVE'

2024-04-28(日)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/581639.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

FPGA高端项目:FPGA帧差算法多目标图像识别+目标跟踪,提供11套工程源码和技术支持

目录 1、前言免责声明 2、相关方案推荐FPGA帧差算法单个目标图像识别目标跟踪 3、详细设计方案设计原理框图运动目标检测原理OV5640摄像头配置与采集OV7725摄像头配置与采集RGB视频流转AXI4-StreamVDMA图像缓存多目标帧差算法图像识别目标跟踪模块视频输出Xilinx系列FPGA工程源…

STM32之HAL开发——ADC入门介绍

ADC简介 模数转换&#xff0c;即Analog-to-Digital Converter&#xff0c;常称ADC&#xff0c;是指将连续变量的模拟信号转换为离散的数字信号的器件&#xff0c;比如将模温度感器产生的电信号转为控制芯片能处理的数字信号0101&#xff0c;这样ADC就建立了模拟世界的传感器和…

机器学习每周挑战——百思买数据

最近由于比赛&#xff0c;断更了好久&#xff0c;从五一开始不会再断更了。这个每周挑战我分析的较为简单&#xff0c;有兴趣的可以将数据集下载下来试着分析一下&#xff0c;又不会的我们可以讨论一下。 这是数据集&#xff1a; import pandas as pd import numpy as np impo…

leetcode_38.外观数列

38. 外观数列 题目描述&#xff1a;给定一个正整数 n &#xff0c;输出外观数列的第 n 项。 「外观数列」是一个整数序列&#xff0c;从数字 1 开始&#xff0c;序列中的每一项都是对前一项的描述。 你可以将其视作是由递归公式定义的数字字符串序列&#xff1a; countAndSay(1…

bugku-ok

打开文件发现有很多ok的字符 转在线地址解码

基于3D机器视觉的注塑缺陷检测解决方案

注塑检测是对注塑生产过程中的产品缺陷进行识别和检测的过程。这些缺陷可能包括色差、料流痕、黑点&#xff08;包括杂质&#xff09;等&#xff0c;它们可能是由多种因素引起&#xff0c;如原料未搅拌均匀、烘料时间过长、工业温度局部偏高、模具等问题造成的。不仅影响产品的…

Stable Diffusion教程:文生图

最近几天AI绘画没有什么大动作&#xff0c;正好有时间总结下Stable Diffusion的一些基础知识&#xff0c;今天就给大家再唠叨一下文生图这个功能&#xff0c;会详细说明其中的各个参数。 文生图是Stable Diffusion的核心功能&#xff0c;它的核心能力就是根据提示词生成相应的…

【喜报】科大睿智为武汉博睿英特科技高质量通过CMMI3级评估咨询工作

武汉博睿英特科技有限公司是信息通信技术产品、建筑智慧工程服务提供商。其拥有专注于航空、政府、教育、金融等多行业领域的资深团队&#xff0c;及时掌握最新信息通信应用技术&#xff0c;深刻理解行业业务流程&#xff0c;擅于整合市场优质资源&#xff0c;积极保持与高校产…

redis ZRANGE 使用最详细文档

环境&#xff1a; redis_version:7.2.2 本文参考 redis 官方文档1 语法 ZRANGE key start stop [BYSCORE | BYLEX] [REV] [LIMIT offset count] [WITHSCORES]参数含义key是有序集合的键名start stop在不同语境下&#xff0c;可用值不一样BYSCORE | BYLEX按照分数查询 | 相…

【汇编】#6 80x86指令系统其二(串处理与控制转移与子函数)

文章目录 一、串处理指令1. 与 REP 协作的 MOVS / STOS / LODS的指令1.1 重复前缀指令REP1.2 字符串传送指令&#xff08;Move String Instruction&#xff09;1.2 存串指令&#xff08;Store String Instruction&#xff09;1.3 取字符串指令&#xff08;Load String Instruct…

[华为OD]给定一个 N*M 矩阵,请先找出 M 个该矩阵中每列元素的最大值 100

题目&#xff1a; 给定一个 N*M 矩阵&#xff0c;请先找出 M 个该矩阵中每列元素的最大值&#xff0c;然后输出这 M 个值中的 最小值 补充说明&#xff1a; N 和 M 的取值范围均为&#xff1a;[0, 100] 示例 1 输入&#xff1a; [[1,2],[3,4]] 输出&#xff1a; 3 说…

【UE5】数字人基础

这里主要记录一下自己在实现数字人得过程中涉及导XSens惯性动捕&#xff0c;视频动捕&#xff0c;LiveLinkFace表捕&#xff0c;GRoom物理头发等。 一、导入骨骼网格体 骨骼网格体即模型要在模型雕刻阶段就要雕刻好表捕所需的表情体(blendshape)&#xff0c;后面表捕的效果直…

机器学习:基于Sklearn框架,使用逻辑回归对由心脏病引发的死亡进行预测分析

前言 系列专栏&#xff1a;机器学习&#xff1a;高级应用与实践【项目实战100】【2024】✨︎ 在本专栏中不仅包含一些适合初学者的最新机器学习项目&#xff0c;每个项目都处理一组不同的问题&#xff0c;包括监督和无监督学习、分类、回归和聚类&#xff0c;而且涉及创建深度学…

数据分析-----方法论

什么是数据分析方法 数据分析方法&#xff1a;将零散的想法和经验整理成有条理的、系统的思路&#xff0c;从而快速地解决问题。 案例&#xff1a; 用户活跃度下降 想法&#xff1a; APP出现问题&#xff1f;去年也下降了吗&#xff1f;是所有的人群都在下降吗&#xff1f…

vscode中新建vue项目

vscode中新建vue项目 进入项目文件夹&#xff0c;打开终端 输入命令vue create 项目名 如vue create test 选择y 选择vue3 进入项目&#xff0c;运行vue项目 输入命令cd test和npm run serve

Spark RDD

Spark RDD操作 Spark执行流程 在上一讲中&#xff0c;我们知道了什么是Spark&#xff0c;什么是RDD、Spark的核心构成组件&#xff0c;以及Spark案例程序。在这一讲中&#xff0c;我们将继续需要Spark作业的执行过程&#xff0c;以及编程模型RDD的各种花式操作&#xff0c;首…

蓝桥杯ctf2024 部分wp

数据分析 1. packet 密码破解 1. cc 逆向分析 1. 欢乐时光 XXTEA #include<stdio.h> #include<stdint.h> #define DELTA 0x9e3779b9 #define MX (((z>>5^y<<2)(y>>3^z<<4))^((sum^y)(key[(p&3)^e]^z))) void btea(unsigned int* v…

【Python 对接QQ的接口】简单用接口查询【等级/昵称/头像/Q龄/当天在线时长/下一个等级升级需多少天】

文章日期&#xff1a;2024.04.28 使用工具&#xff1a;Python 类型&#xff1a;QQ接口 文章全程已做去敏处理&#xff01;&#xff01;&#xff01; 【需要做的可联系我】 AES解密处理&#xff08;直接解密即可&#xff09;&#xff08;crypto-js.js 标准算法&#xff09;&…

纯血鸿蒙APP实战开发——监听HiLog日志实现测试用例验证

介绍 日常中在进行测试用例验证时&#xff0c;会出现部分场景无法通过判断UI的变化来确认用例是否正常运行&#xff0c;我们可以通过监听日志的方式来巧妙的实现这种场景。本示例通过监听hilog日志的回调&#xff0c;判断指定日志是否打印&#xff0c;来确定测试用例的执行结果…

Linux 第十三章

&#x1f436;博主主页&#xff1a;ᰔᩚ. 一怀明月ꦿ ❤️‍&#x1f525;专栏系列&#xff1a;线性代数&#xff0c;C初学者入门训练&#xff0c;题解C&#xff0c;C的使用文章&#xff0c;「初学」C&#xff0c;linux &#x1f525;座右铭&#xff1a;“不要等到什么都没有了…