真实的国产乱ⅩXXX66竹夫人,五月香六月婷婷激情综合,亚洲日本VA一区二区三区,亚洲精品一区二区三区麻豆

成都創(chuàng)新互聯(lián)網(wǎng)站制作重慶分公司

怎么使用pytorch框架-創(chuàng)新互聯(lián)

這篇文章主要講解了“怎么使用pytorch框架”,文中的講解內(nèi)容簡(jiǎn)單清晰,易于學(xué)習(xí)與理解,下面請(qǐng)大家跟著小編的思路慢慢深入,一起來(lái)研究和學(xué)習(xí)“怎么使用pytorch框架”吧!

我們一直強(qiáng)調(diào)成都網(wǎng)站設(shè)計(jì)、成都網(wǎng)站建設(shè)對(duì)于企業(yè)的重要性,如果您也覺得重要,那么就需要我們慎重對(duì)待,選擇一個(gè)安全靠譜的網(wǎng)站建設(shè)公司,企業(yè)網(wǎng)站我們建議是要么不做,要么就做好,讓網(wǎng)站能真正成為企業(yè)發(fā)展過(guò)程中的有力推手。專業(yè)網(wǎng)站制作公司不一定是大公司,創(chuàng)新互聯(lián)作為專業(yè)的網(wǎng)絡(luò)公司選擇我們就是放心。

  中文新聞情感分類 Bert-Pytorch-transformers

  使用pytorch框架以及transformers包,以及Bert的中文預(yù)訓(xùn)練模型

  文件目錄

  data

  Train_DataSet.csv

  Train_DataSet_Label.csv

  main.py

  NewsData.py

  #main.py

  from transformers import BertTokenizer

  from transformers import BertForSequenceClassification

  from transformers import BertConfig

  from transformers import BertPreTrainedModel

  import torch

  import torch.nn as nn

  from transformers import BertModel

  import time

  import argparse

  from NewsData import NewsData

  import os

  def get_train_args():

  parser=argparse.ArgumentParser()

  parser.add_argument('--batch_size',type=int,default=10,help = '每批數(shù)據(jù)的數(shù)量')

  parser.add_argument('--nepoch',type=int,default=3,help = '訓(xùn)練的輪次')

  parser.add_argument('--lr',type=float,default=0.001,help = '學(xué)習(xí)率')

  parser.add_argument('--gpu',type=bool,default=True,help = '是否使用gpu')

  parser.add_argument('--num_workers',type=int,default=2,help='dataloader使用的線程數(shù)量')

  parser.add_argument('--num_labels',type=int,default=3,help='分類類數(shù)')

  parser.add_argument('--data_path',type=str,default='./data',help='數(shù)據(jù)路徑')

  opt=parser.parse_args()

  print(opt)

  return opt

  def get_model(opt):

  #類方法.from_pretrained()獲取預(yù)訓(xùn)練模型,num_labels是分類的類數(shù)

  model = BertForSequenceClassification.from_pretrained('bert-base-chinese',num_labels=opt.num_labels)

  return model

  def get_data(opt):

  #NewsData繼承于pytorch的Dataset類

  trainset = NewsData(opt.data_path,is_train = 1)

  trainloader=torch.utils.data.DataLoader(trainset,batch_size=opt.batch_size,shuffle=True,num_workers=opt.num_workers)

  testset = NewsData(opt.data_path,is_train = 0)

  testloader=torch.utils.data.DataLoader(testset,batch_size=opt.batch_size,shuffle=False,num_workers=opt.num_workers)

  return trainloader,testloader

  def train(epoch,model,trainloader,testloader,optimizer,opt):

  print('\ntrain-Epoch: %d' % (epoch+1))

  model.train()

  start_time = time.time()

  print_step = int(len(trainloader)/10)

  for batch_idx,(sue,label,posi) in enumerate(trainloader):

  if opt.gpu:

  sue = sue.cuda()

  posi = posi.cuda()

  label = label.unsqueeze(1).cuda()

  optimizer.zero_grad()

  #輸入?yún)?shù)為詞列表、位置列表、標(biāo)簽

  outputs = model(sue, position_ids=posi,labels = label)

  loss, logits = outputs[0],outputs[1]

  loss.backward()

  optimizer.step()

  if batch_idx % print_step == 0:

  print("Epoch:%d [%d|%d] loss:%f" %(epoch+1,batch_idx,len(trainloader),loss.mean()))

  print("time:%.3f" % (time.time() - start_time))

  def test(epoch,model,trainloader,testloader,opt):

  print('\ntest-Epoch: %d' % (epoch+1))

  model.eval()

  total=0

  correct=0

  with torch.no_grad():

  for batch_idx,(sue,label,posi) in enumerate(testloader):

  if opt.gpu:

  sue = sue.cuda()

  posi = posi.cuda()

  labels = label.unsqueeze(1).cuda()

  label = label.cuda()

  else:

  labels = label.unsqueeze(1)

  outputs = model(sue, labels=labels)

  loss, logits = outputs[:2]

  _,predicted=torch.max(logits.data,1)

  total+=sue.size(0)

  correct+=predicted.data.eq(label.data).cpu().sum()

  s = ("Acc:%.3f" %((1.0*correct.numpy())/total))

  print(s)

  if __name__=='__main__':

  opt = get_train_args()

  model = get_model(opt)

  trainloader,testloader = get_data(opt)

  if opt.gpu:

  model.cuda()

  optimizer=torch.optim.SGD(model.parameters(),lr=opt.lr,momentum=0.9)

  if not os.path.exists('./model.pth'):

  for epoch in range(opt.nepoch):

  train(epoch,model,trainloader,testloader,optimizer,opt)

  test(epoch,model,trainloader,testloader,opt)

  torch.save(model.state_dict(),'./model.pth')

  else:鄭州治療婦科哪個(gè)醫(yī)院好 http://www.120kdfk.com/

  model.load_state_dict(torch.load('model.pth'))

  print('模型存在,直接test')

  test(0,model,trainloader,testloader,opt)

  #NewsData.py

  from transformers import BertTokenizer

  from transformers import BertForSequenceClassification

  from transformers import BertConfig

  from transformers import BertPreTrainedModel

  import torch

  import torch.nn as nn

  from transformers import BertModel

  import time

  import argparse

  class NewsData(torch.utils.data.Dataset):

  def __init__(self,root,is_train = 1):

  self.tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

  self.data_num = 7346

  self.x_list = []

  self.y_list = []

  self.posi = []

  with open(root + '/Train_DataSet.csv',encoding='UTF-8') as f:

  for i in range(self.data_num+1):

  line = f.readline()[:-1] + '這是一個(gè)中性的數(shù)據(jù)'

  data_one_str = line.split(',')[len(line.split(','))-2]

  data_two_str = line.split(',')[len(line.split(','))-1]

  if len(data_one_str) < 6:

  z = len(data_one_str)

  data_one_str = data_one_str + ',' + data_two_str[0:min(200,len(data_two_str))]

  else:

  data_one_str = data_one_str

  if i==0:

  continue

  word_l = self.tokenizer.encode(data_one_str, add_special_tokens=False)

  if len(word_l)<100:

  while(len(word_l)!=100):

  word_l.append(0)

  else:

  word_l = word_l[0:100]

  word_l.append(102)

  l = word_l

  word_l = [101]

  word_l.extend(l)

  self.x_list.append(torch.tensor(word_l))

  self.posi.append(torch.tensor([i for i in range(102)]))

  with open(root + '/Train_DataSet_Label.csv',encoding='UTF-8') as f:

  for i in range(self.data_num+1):

  #print(i)

  label_one = f.readline()[-2]

  if i==0:

  continue

  label_one = int(label_one)

  self.y_list.append(torch.tensor(label_one))

  #訓(xùn)練集或者是測(cè)試集

  if is_train == 1:

  self.x_list = self.x_list[0:6000]

  self.y_list = self.y_list[0:6000]

  self.posi = self.posi[0:6000]

  else:

  self.x_list = self.x_list[6000:]

  self.y_list = self.y_list[6000:]

  self.posi = self.posi[6000:]

  self.len = len(self.x_list)

  def __getitem__(self, index):

  return self.x_list[index], self.y_list[index],self.posi[index]

  def __len__(self):

  return self.len

感謝各位的閱讀,以上就是“怎么使用pytorch框架”的內(nèi)容了,經(jīng)過(guò)本文的學(xué)習(xí)后,相信大家對(duì)怎么使用pytorch框架這一問(wèn)題有了更深刻的體會(huì),具體使用情況還需要大家實(shí)踐驗(yàn)證。這里是創(chuàng)新互聯(lián),小編將為大家推送更多相關(guān)知識(shí)點(diǎn)的文章,歡迎關(guān)注!


新聞名稱:怎么使用pytorch框架-創(chuàng)新互聯(lián)
鏈接分享:http://weahome.cn/article/djisdp.html

其他資訊

在線咨詢

微信咨詢

電話咨詢

028-86922220(工作日)

18980820575(7×24)

提交需求

返回頂部