真实的国产乱ⅩXXX66竹夫人,五月香六月婷婷激情综合,亚洲日本VA一区二区三区,亚洲精品一区二区三区麻豆

成都創(chuàng)新互聯(lián)網(wǎng)站制作重慶分公司

pytorch下使用LSTM神經(jīng)網(wǎng)絡(luò)寫詩實例-創(chuàng)新互聯(lián)

在pytorch下,以數(shù)萬首唐詩為素材,訓(xùn)練雙層LSTM神經(jīng)網(wǎng)絡(luò),使其能夠以唐詩的方式寫詩。

創(chuàng)新互聯(lián)公司服務(wù)項目包括安定網(wǎng)站建設(shè)、安定網(wǎng)站制作、安定網(wǎng)頁制作以及安定網(wǎng)絡(luò)營銷策劃等。多年來,我們專注于互聯(lián)網(wǎng)行業(yè),利用自身積累的技術(shù)優(yōu)勢、行業(yè)經(jīng)驗、深度合作伙伴關(guān)系等,向廣大中小型企業(yè)、政府機(jī)構(gòu)等提供互聯(lián)網(wǎng)行業(yè)的解決方案,安定網(wǎng)站推廣取得了明顯的社會效益與經(jīng)濟(jì)效益。目前,我們服務(wù)的客戶以成都為中心已經(jīng)輻射到安定省份的部分城市,未來相信會繼續(xù)擴(kuò)大服務(wù)區(qū)域并繼續(xù)獲得客戶的支持與信任!

代碼結(jié)構(gòu)分為四部分,分別為

1.model.py,定義了雙層LSTM模型

2.data.py,定義了從網(wǎng)上得到的唐詩數(shù)據(jù)的處理方法

3.utlis.py 定義了損失可視化的函數(shù)

4.main.py定義了模型參數(shù),以及訓(xùn)練、唐詩生成函數(shù)。

參考:電子工業(yè)出版社的《深度學(xué)習(xí)框架PyTorch:入門與實踐》第九章

main代碼及注釋如下

import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb
 
class Config(object):
	data_path = 'data/'
	pickle_path = 'tang.npz'
	author = None
	constrain = None
	category = 'poet.tang' #or poet.song
	lr = 1e-3
	weight_decay = 1e-4
	use_gpu = True
	epoch = 20
	batch_size = 128
	maxlen = 125
	plot_every = 20
	#use_env = True #是否使用visodm
	env = 'poety' 
	#visdom env
	max_gen_len = 200
	debug_file = '/tmp/debugp'
	model_path = None
	prefix_words = '細(xì)雨魚兒出,微風(fēng)燕子斜。' 
	#不是詩歌組成部分,是意境
	start_words = '閑云潭影日悠悠' 
	#詩歌開始
	acrostic = False 
	#是否藏頭
	model_prefix = 'checkpoints/tang' 
	#模型保存路徑
opt = Config()
 
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
	'''
	給定幾個詞,根據(jù)這幾個詞接著生成一首完整的詩歌
	'''
	results = list(start_words)
	start_word_len = len(start_words)
	# 手動設(shè)置第一個詞為
	# 這個地方有問題,最后需要再看一下
	input = Variable(t.Tensor([word2ix['']]).view(1,1).long())
	if opt.use_gpu:input=input.cuda()
	hidden = None
	
	if prefix_words:
		for word in prefix_words:
			output,hidden = model(input,hidden)
			# 下邊這句話是為了把input變成1*1?
			input = Variable(input.data.new([word2ix[word]])).view(1,1)
	for i in range(opt.max_gen_len):
		output,hidden = model(input,hidden)
		
		if i':
			del results[-1] #-1的意思是倒數(shù)第一個
			break
	return results
 
def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
 '''
 生成藏頭詩
 start_words : u'深度學(xué)習(xí)'
 生成:
 深木通中岳,青苔半日脂。
 度山分地險,逆浪到南巴。
 學(xué)道兵猶毒,當(dāng)時燕不移。
 習(xí)根通古岸,開鏡出清羸。
 '''
 results = []
 start_word_len = len(start_words)
 input = Variable(t.Tensor([word2ix['']]).view(1,1).long())
 if opt.use_gpu:input=input.cuda()
 hidden = None
 
 index=0 # 用來指示已經(jīng)生成了多少句藏頭詩
 # 上一個詞
 pre_word=''
 
 if prefix_words:
  for word in prefix_words:
   output,hidden = model(input,hidden)
   input = Variable(input.data.new([word2ix[word]])).view(1,1)
 
 for i in range(opt.max_gen_len):
  output,hidden = model(input,hidden)
  top_index = output.data[0].topk(1)[1][0]
  w = ix2word[top_index]
 
  if (pre_word in {u'。',u'!',''} ):
   # 如果遇到句號,藏頭的詞送進(jìn)去生成
 
   if index==start_word_len:
    # 如果生成的詩歌已經(jīng)包含全部藏頭的詞,則結(jié)束
    break
   else: 
    # 把藏頭的詞作為輸入送入模型
    w = start_words[index]
    index+=1
    input = Variable(input.data.new([word2ix[w]])).view(1,1) 
  else:
   # 否則的話,把上一次預(yù)測是詞作為下一個詞輸入
   input = Variable(input.data.new([word2ix[w]])).view(1,1)
  results.append(w)
  pre_word = w
 return results
 
def train(**kwargs):
	
	for k,v in kwargs.items():
		setattr(opt,k,v) #設(shè)置apt里屬性的值
	vis = Visualizer(env=opt.env)
	
	#獲取數(shù)據(jù)
	data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函數(shù)
	data = t.from_numpy(data)
	#這個地方出錯了,是大寫的L
	dataloader = t.utils.data.DataLoader(data, 
					batch_size = opt.batch_size,
					shuffle = True,
					num_workers = 1) #在python里,這樣寫程序可以嗎?
 #模型定義
	model = PoetryModel(len(word2ix), 128, 256)
	optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
	criterion = nn.CrossEntropyLoss()
 
	if opt.model_path:
		model.load_state_dict(t.load(opt.model_path))
	if opt.use_gpu:
		model.cuda()
		criterion.cuda()
		
	#The tnt.AverageValueMeter measures and returns the average value 
	#and the standard deviation of any collection of numbers that are 
	#added to it. It is useful, for instance, to measure the average 
	#loss over a collection of examples.
 
 #The add() function expects as input a Lua number value, which 
 #is the value that needs to be added to the list of values to 
 #average. It also takes as input an optional parameter n that 
 #assigns a weight to value in the average, in order to facilitate 
 #computing weighted averages (default = 1).
 
 #The tnt.AverageValueMeter has no parameters to be set at initialization time. 
	loss_meter = meter.AverageValueMeter()
	
	for epoch in range(opt.epoch):
		loss_meter.reset()
		for ii,data_ in tqdm.tqdm(enumerate(dataloader)):
			#tqdm是python中的進(jìn)度條
			#訓(xùn)練
			data_ = data_.long().transpose(1,0).contiguous()
			#上邊一句話,把data_變成long類型,把1維和0維轉(zhuǎn)置,把內(nèi)存調(diào)成連續(xù)的
			if opt.use_gpu: data_ = data_.cuda()
			optimizer.zero_grad()
			input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])
			#上邊一句,將輸入的詩句錯開一個字,形成訓(xùn)練和目標(biāo)
			output,_ = model(input_)
			loss = criterion(output, target.view(-1))
			loss.backward()
			optimizer.step()
			
			loss_meter.add(loss.data[0]) #為什么是data[0]?
			
			#可視化用到的是utlis.py里的函數(shù)
			if (1+ii)%opt.plot_every ==0:
				
				if os.path.exists(opt.debug_file):
					ipdb.set_trace()
				vis.plot('loss',loss_meter.value()[0])
				
				# 下面是對目前模型情況的測試,詩歌原文
				poetrys = [[ix2word[_word] for _word in data_[:,_iii]] 
									for _iii in range(data_.size(1))][:16]
				#上面句子嵌套了兩個循環(huán),主要是將詩歌索引的前十六個字變成原文
				vis.text('
'.join([''.join(poetry) for poetry in poetrys]),win = u'origin_poem') gen_poetries = [] #分別以以下幾個字作為詩歌的第一個字,生成8首詩 for word in list(u'春江花月夜涼如水'): gen_poetry = ''.join(generate(model,word,ix2word,word2ix)) gen_poetries.append(gen_poetry) vis.text('
'.join([''.join(poetry) for poetry in gen_poetries]), win = u'gen_poem') t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch)) def gen(**kwargs): ''' 提供命令行接口,用以生成相應(yīng)的詩 ''' for k,v in kwargs.items(): setattr(opt,k,v) data, word2ix, ix2word = get_data(opt) model = PoetryModel(len(word2ix), 128, 256) map_location = lambda s,l:s # 上邊句子里的map_location是在load里用的,用以加載到指定的CPU或GPU, # 上邊句子的意思是將模型加載到默認(rèn)的GPU上 state_dict = t.load(opt.model_path, map_location = map_location) model.load_state_dict(state_dict) if opt.use_gpu: model.cuda() if sys.version_info.major == 3: if opt.start_words.insprintable(): start_words = opt.start_words prefix_words = opt.prefix_words if opt.prefix_words else None else: start_words = opt.start_words.encode('ascii',\ 'surrogateescape').decode('utf8') prefix_words = opt.prefix_words.encode('ascii',\ 'surrogateescape').decode('utf8') if opt.prefix_words else None start_words = start_words.replace(',',u',')\ .replace('.',u'。')\ .replace('?',u'?') gen_poetry = gen_acrostic if opt.acrostic else generate result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words) print(''.join(result)) if __name__ == '__main__': import fire fire.Fire()

另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無理由+7*72小時售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、高防服務(wù)器、香港服務(wù)器、美國服務(wù)器、虛擬主機(jī)、免備案服務(wù)器”等云主機(jī)租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡單易用、服務(wù)可用性高、性價比高”等特點與優(yōu)勢,專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場景需求。


文章標(biāo)題:pytorch下使用LSTM神經(jīng)網(wǎng)絡(luò)寫詩實例-創(chuàng)新互聯(lián)
文章轉(zhuǎn)載:http://weahome.cn/article/dcssdj.html

其他資訊

在線咨詢

微信咨詢

電話咨詢

028-86922220(工作日)

18980820575(7×24)

提交需求

返回頂部