這篇文章主要講解了“怎么使用pytorch框架”,文中的講解內(nèi)容簡(jiǎn)單清晰,易于學(xué)習(xí)與理解,下面請(qǐng)大家跟著小編的思路慢慢深入,一起來(lái)研究和學(xué)習(xí)“怎么使用pytorch框架”吧!
中文新聞情感分類 Bert-Pytorch-transformers
使用pytorch框架以及transformers包,以及Bert的中文預(yù)訓(xùn)練模型
文件目錄
data
Train_DataSet.csv
Train_DataSet_Label.csv
main.py
NewsData.py
#main.py
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import BertConfig
from transformers import BertPreTrainedModel
import torch
import torch.nn as nn
from transformers import BertModel
import time
import argparse
from NewsData import NewsData
import os
def get_train_args():
    """Parse command-line options for training/evaluation.

    Returns:
        argparse.Namespace with batch_size, nepoch, lr, gpu, num_workers,
        num_labels and data_path.
    """
    def _str2bool(v):
        # BUG FIX: the original used type=bool, but bool('False') is True —
        # any non-empty string is truthy, so `--gpu False` still enabled the
        # GPU.  Parse the common false spellings explicitly instead.
        return str(v).lower() not in ('false', '0', 'no', 'n', '')

    parser = argparse.ArgumentParser()
    parser.add_argument('--batch_size', type=int, default=10, help='每批數(shù)據(jù)的數(shù)量')
    parser.add_argument('--nepoch', type=int, default=3, help='訓(xùn)練的輪次')
    parser.add_argument('--lr', type=float, default=0.001, help='學(xué)習(xí)率')
    parser.add_argument('--gpu', type=_str2bool, default=True, help='是否使用gpu')
    parser.add_argument('--num_workers', type=int, default=2, help='dataloader使用的線程數(shù)量')
    parser.add_argument('--num_labels', type=int, default=3, help='分類類數(shù)')
    parser.add_argument('--data_path', type=str, default='./data', help='數(shù)據(jù)路徑')
    opt = parser.parse_args()
    print(opt)
    return opt
def get_model(opt):
    """Build a BERT sequence classifier from the Chinese pretrained weights.

    `.from_pretrained()` downloads/loads the checkpoint; `num_labels`
    sizes the classification head (number of target classes).
    """
    return BertForSequenceClassification.from_pretrained(
        'bert-base-chinese', num_labels=opt.num_labels)
def get_data(opt):
    """Build the train and test DataLoaders.

    NewsData subclasses torch.utils.data.Dataset; is_train=1 selects the
    training split, is_train=0 the test split.  Only the training loader
    shuffles.  Returns (trainloader, testloader).
    """
    loaders = []
    for train_flag, do_shuffle in ((1, True), (0, False)):
        split = NewsData(opt.data_path, is_train=train_flag)
        loaders.append(torch.utils.data.DataLoader(
            split,
            batch_size=opt.batch_size,
            shuffle=do_shuffle,
            num_workers=opt.num_workers))
    trainloader, testloader = loaders
    return trainloader, testloader
def train(epoch, model, trainloader, testloader, optimizer, opt):
    """Run one training epoch.

    Args:
        epoch: zero-based epoch index (used only for logging).
        model: sequence classifier called as
            model(input_ids, position_ids=..., labels=...); expected to
            return (loss, logits, ...) when labels are supplied.
        trainloader: iterable of (input_ids, label, position_ids) batches.
        testloader: unused; kept for signature symmetry with test().
        optimizer: torch optimizer over model.parameters().
        opt: parsed options; opt.gpu moves batches to CUDA.
    """
    print('\ntrain-Epoch: %d' % (epoch + 1))
    model.train()
    start_time = time.time()
    # Log roughly 10 times per epoch.  BUG FIX: the original computed
    # int(len(trainloader)/10), which is 0 for loaders with fewer than 10
    # batches and crashed with ZeroDivisionError at `batch_idx % print_step`.
    print_step = max(1, len(trainloader) // 10)
    for batch_idx, (sue, label, posi) in enumerate(trainloader):
        if opt.gpu:
            sue = sue.cuda()
            posi = posi.cuda()
            label = label.unsqueeze(1).cuda()
        optimizer.zero_grad()
        # Inputs are token ids, position ids and labels; with labels given,
        # the model returns the loss first.
        outputs = model(sue, position_ids=posi, labels=label)
        loss, logits = outputs[0], outputs[1]
        loss.backward()
        optimizer.step()
        if batch_idx % print_step == 0:
            print("Epoch:%d [%d|%d] loss:%f" % (epoch + 1, batch_idx, len(trainloader), loss.mean()))
    print("time:%.3f" % (time.time() - start_time))
def test(epoch,model,trainloader,testloader,opt):
print('\ntest-Epoch: %d' % (epoch+1))
model.eval()
total=0
correct=0
with torch.no_grad():
for batch_idx,(sue,label,posi) in enumerate(testloader):
if opt.gpu:
sue = sue.cuda()
posi = posi.cuda()
labels = label.unsqueeze(1).cuda()
label = label.cuda()
else:
labels = label.unsqueeze(1)
outputs = model(sue, labels=labels)
loss, logits = outputs[:2]
_,predicted=torch.max(logits.data,1)
total+=sue.size(0)
correct+=predicted.data.eq(label.data).cpu().sum()
s = ("Acc:%.3f" %((1.0*correct.numpy())/total))
print(s)
if __name__ == '__main__':
    opt = get_train_args()
    model = get_model(opt)
    trainloader, testloader = get_data(opt)
    if opt.gpu:
        model.cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)
    if not os.path.exists('./model.pth'):
        # No checkpoint yet: train for nepoch epochs, evaluating after each
        # one, then persist the weights.
        for epoch in range(opt.nepoch):
            train(epoch, model, trainloader, testloader, optimizer, opt)
            test(epoch, model, trainloader, testloader, opt)
        torch.save(model.state_dict(), './model.pth')
    else:
        # Checkpoint found: load it and evaluate directly.  (Spam text that
        # had been injected after `else:` — making the script a syntax
        # error — has been removed.)
        model.load_state_dict(torch.load('model.pth'))
        print('模型存在,直接test')
        test(0, model, trainloader, testloader, opt)
#NewsData.py
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import BertConfig
from transformers import BertPreTrainedModel
import torch
import torch.nn as nn
from transformers import BertModel
import time
import argparse
class NewsData(torch.utils.data.Dataset):
    """Dataset of Chinese news items for 3-class sentiment classification.

    Reads Train_DataSet.csv and Train_DataSet_Label.csv from `root`,
    tokenizes each item with the 'bert-base-chinese' tokenizer into a
    fixed-length sequence of 102 ids ([CLS] + 100 tokens + [SEP]), and
    splits the first 6000 items into the training set (is_train=1) and
    the remainder into the test set (is_train=0).
    """

    def __init__(self, root, is_train = 1):
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
        # NOTE(review): hard-coded row count of the CSV; the file length is
        # never verified against this — confirm it matches the data files.
        self.data_num = 7346
        self.x_list = []  # token-id tensors, one per item
        self.y_list = []  # class-label tensors, one per item
        self.posi = []    # position-id tensors (0..101), one per item
        with open(root + '/Train_DataSet.csv',encoding='UTF-8') as f:
            for i in range(self.data_num+1):
                # Drop the trailing newline and append a neutral filler
                # phrase so very short lines still produce usable text.
                line = f.readline()[:-1] + '這是一個(gè)中性的數(shù)據(jù)'
                # Take the last two comma-separated fields of the row
                # (presumably title and body — TODO confirm CSV layout).
                data_one_str = line.split(',')[len(line.split(','))-2]
                data_two_str = line.split(',')[len(line.split(','))-1]
                if len(data_one_str) < 6:
                    # First field too short: extend it with up to 200
                    # characters of the second field.
                    z = len(data_one_str)
                    data_one_str = data_one_str + ',' + data_two_str[0:min(200,len(data_two_str))]
                else:
                    data_one_str = data_one_str
                if i==0:
                    # Row 0 is skipped (header line) — but only after it has
                    # been read, so the file position still advances.
                    continue
                word_l = self.tokenizer.encode(data_one_str, add_special_tokens=False)
                # Pad with id 0 or truncate so every sequence is exactly
                # 100 token ids long.
                if len(word_l)<100:
                    while(len(word_l)!=100):
                        word_l.append(0)
                else:
                    word_l = word_l[0:100]
                # Wrap with BERT special-token ids: 102 = [SEP] appended,
                # 101 = [CLS] prepended, giving 102 ids total.
                word_l.append(102)
                l = word_l
                word_l = [101]
                word_l.extend(l)
                self.x_list.append(torch.tensor(word_l))
                self.posi.append(torch.tensor([i for i in range(102)]))
        with open(root + '/Train_DataSet_Label.csv',encoding='UTF-8') as f:
            for i in range(self.data_num+1):
                #print(i)
                # Label digit is the second-to-last character of the line
                # (the character just before the newline).
                label_one = f.readline()[-2]
                if i==0:
                    # Skip the header row, mirroring the loop above.
                    continue
                label_one = int(label_one)
                self.y_list.append(torch.tensor(label_one))
        # Train/test split: first 6000 items train, the rest test.
        if is_train == 1:
            self.x_list = self.x_list[0:6000]
            self.y_list = self.y_list[0:6000]
            self.posi = self.posi[0:6000]
        else:
            self.x_list = self.x_list[6000:]
            self.y_list = self.y_list[6000:]
            self.posi = self.posi[6000:]
        self.len = len(self.x_list)

    def __getitem__(self, index):
        # Returns (token_ids, label, position_ids) for one news item.
        return self.x_list[index], self.y_list[index],self.posi[index]

    def __len__(self):
        return self.len
感謝各位的閱讀,以上就是“怎么使用pytorch框架”的內(nèi)容了,經(jīng)過(guò)本文的學(xué)習(xí)后,相信大家對(duì)怎么使用pytorch框架這一問(wèn)題有了更深刻的體會(huì),具體使用情況還需要大家實(shí)踐驗(yàn)證。這里是創(chuàng)新互聯(lián),小編將為大家推送更多相關(guān)知識(shí)點(diǎn)的文章,歡迎關(guān)注!