這篇文章將為大家詳細(xì)講解有關(guān)如何使用TensorFlow 載入模型,文章內(nèi)容質(zhì)量較高,因此小編分享給大家做個(gè)參考,希望大家閱讀完這篇文章后對相關(guān)知識有一定的了解。
成都創(chuàng)新互聯(lián)公司是一家專業(yè)提供五華企業(yè)網(wǎng)站建設(shè),專注與網(wǎng)站建設(shè)、網(wǎng)站設(shè)計(jì)、H5頁面制作、小程序制作等業(yè)務(wù)。10年已為五華眾多企業(yè)、政府機(jī)構(gòu)等服務(wù)。創(chuàng)新互聯(lián)專業(yè)網(wǎng)站設(shè)計(jì)公司優(yōu)惠進(jìn)行中。一、TensorFlow常規(guī)模型加載方法
保存模型
tf.train.Saver()類,.save(sess, ckpt文件目錄)方法
參數(shù)名稱 | 功能說明 | 默認(rèn)值 |
var_list | Saver中存儲變量集合 | 全局變量集合 |
reshape | 加載時(shí)是否恢復(fù)變量形狀 | True |
sharded | 是否將變量輪循放在所有設(shè)備上 | True |
max_to_keep | 保留最近檢查點(diǎn)個(gè)數(shù) | 5 |
restore_sequentially | 是否按順序恢復(fù)變量,模型較大時(shí)順序恢復(fù)內(nèi)存消耗小 | True |
var_list是字典形式{變量名字符串: 變量符號},相對應(yīng)的restore也根據(jù)同樣形式的字典將ckpt中的字符串對應(yīng)的變量加載給程序中的符號。
如果Saver給定了字典作為加載方式,則按照字典來,如:saver = tf.train.Saver({"v/ExponentialMovingAverage":v}),否則每個(gè)變量尋找自己的name屬性在ckpt中的對應(yīng)值進(jìn)行加載。
加載模型
當(dāng)我們基于checkpoint文件(ckpt)加載參數(shù)時(shí),實(shí)際上我們使用Saver.restore取代了initializer的初始化
checkpoint文件會記錄保存信息,通過它可以定位最新保存的模型:
ckpt = tf.train.get_checkpoint_state('./model/') print(ckpt.model_checkpoint_path)
.meta文件保存了當(dāng)前圖結(jié)構(gòu)
.index文件保存了當(dāng)前參數(shù)名
.data文件保存了當(dāng)前參數(shù)值
tf.train.import_meta_graph函數(shù)給出model.ckpt-n.meta的路徑后會加載圖結(jié)構(gòu),并返回saver對象
ckpt = tf.train.get_checkpoint_state('./model/')
tf.train.Saver函數(shù)會返回加載默認(rèn)圖的saver對象,saver對象初始化時(shí)可以指定變量映射方式,根據(jù)名字映射變量(『TensorFlow』滑動平均)
saver = tf.train.Saver({"v/ExponentialMovingAverage":v})
saver.restore函數(shù)給出model.ckpt-n的路徑后會自動尋找參數(shù)名-值文件進(jìn)行加載
saver.restore(sess,'./model/model.ckpt-0') saver.restore(sess,ckpt.model_checkpoint_path)
1.不加載圖結(jié)構(gòu),只加載參數(shù)
由于實(shí)際上我們參數(shù)保存的都是Variable變量的值,所以其他的參數(shù)值(例如batch_size)等,我們在restore時(shí)可能希望修改,但是圖結(jié)構(gòu)在train時(shí)一般就已經(jīng)確定了,所以我們可以使用tf.Graph().as_default()新建一個(gè)默認(rèn)圖(建議使用上下文環(huán)境),利用這個(gè)新圖修改和變量無關(guān)的參值大小,從而達(dá)到目的。
''' 使用原網(wǎng)絡(luò)保存的模型加載到自己重新定義的圖上 可以使用python變量名加載模型,也可以使用節(jié)點(diǎn)名 ''' import AlexNet as Net import AlexNet_train as train import random import tensorflow as tf IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg' with tf.Graph().as_default() as g: x = tf.placeholder(tf.float32, [1, train.INPUT_SIZE[0], train.INPUT_SIZE[1], 3]) y = Net.inference_1(x, N_CLASS=5, train=False) with tf.Session() as sess: # 程序前面得有 Variable 供 save or restore 才不報(bào)錯(cuò) # 否則會提示沒有可保存的變量 saver = tf.train.Saver() ckpt = tf.train.get_checkpoint_state('./model/') img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read() img = sess.run(tf.expand_dims(tf.image.resize_images( tf.image.decode_jpeg(img_raw),[224,224],method=random.randint(0,3)),0)) if ckpt and ckpt.model_checkpoint_path: print(ckpt.model_checkpoint_path) saver.restore(sess,'./model/model.ckpt-0') global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] res = sess.run(y, feed_dict={x: img}) print(global_step,sess.run(tf.argmax(res,1)))
2.加載圖結(jié)構(gòu)和參數(shù)
''' 直接使用使用保存好的圖 無需加載python定義的結(jié)構(gòu),直接使用節(jié)點(diǎn)名稱加載模型 由于節(jié)點(diǎn)形狀已經(jīng)定下來了,所以有不便之處,placeholder定義batch后單張傳會報(bào)錯(cuò) 現(xiàn)階段不推薦使用,以后如果理解深入了可能會找到使用方法 ''' import AlexNet_train as train import random import tensorflow as tf IMAGE_PATH = './flower_photos/daisy/5673728_71b8cb57eb.jpg' ckpt = tf.train.get_checkpoint_state('./model/') # 通過檢查點(diǎn)文件鎖定最新的模型 saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') # 載入圖結(jié)構(gòu),保存在.meta文件中 with tf.Session() as sess: saver.restore(sess,ckpt.model_checkpoint_path) # 載入?yún)?shù),參數(shù)保存在兩個(gè)文件中,不過restore會自己尋找 img_raw = tf.gfile.FastGFile(IMAGE_PATH, 'rb').read() img = sess.run(tf.image.resize_images( tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3))) imgs = [] for i in range(128): imgs.append(img) print(sess.run(tf.get_default_graph().get_tensor_by_name('fc3:0'),feed_dict={'Placeholder:0': imgs})) ''' img = sess.run(tf.expand_dims(tf.image.resize_images( tf.image.decode_jpeg(img_raw), train.INPUT_SIZE, method=random.randint(0, 3)), 0)) print(img) imgs = [] for i in range(128): imgs.append(img) print(sess.run(tf.get_default_graph().get_tensor_by_name('conv1:0'), feed_dict={'Placeholder:0':img}))
注意,在所有兩種方式中都可以通過調(diào)用節(jié)點(diǎn)名稱使用節(jié)點(diǎn)輸出張量,節(jié)點(diǎn).name屬性返回節(jié)點(diǎn)名稱。
3.簡化版本
# 連同圖結(jié)構(gòu)一同加載 ckpt = tf.train.get_checkpoint_state('./model/') saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path +'.meta') with tf.Session() as sess: saver.restore(sess,ckpt.model_checkpoint_path) # 只加載數(shù)據(jù),不加載圖結(jié)構(gòu),可以在新圖中改變batch_size等的值 # 不過需要注意,Saver對象實(shí)例化之前需要定義好新的圖結(jié)構(gòu),否則會報(bào)錯(cuò) saver = tf.train.Saver() with tf.Session() as sess: ckpt = tf.train.get_checkpoint_state('./model/') saver.restore(sess,ckpt.model_checkpoint_path)
二、TensorFlow二進(jìn)制模型加載方法
這種加載方法一般是對應(yīng)網(wǎng)上各大公司已經(jīng)訓(xùn)練好的網(wǎng)絡(luò)模型進(jìn)行修改的工作
# 新建空白圖 self.graph = tf.Graph() # 空白圖列為默認(rèn)圖 with self.graph.as_default(): # 二進(jìn)制讀取模型文件 with tf.gfile.FastGFile(os.path.join(model_dir,model_name),'rb') as f: # 新建GraphDef文件,用于臨時(shí)載入模型中的圖 graph_def = tf.GraphDef() # GraphDef加載模型中的圖 graph_def.ParseFromString(f.read()) # 在空白圖中加載GraphDef中的圖 tf.import_graph_def(graph_def,name='') # 在圖中獲取張量需要使用graph.get_tensor_by_name加張量名 # 這里的張量可以直接用于session的run方法求值了 # 補(bǔ)充一個(gè)基礎(chǔ)知識,形如'conv1'是節(jié)點(diǎn)名稱,而'conv1:0'是張量名稱,表示節(jié)點(diǎn)的第一個(gè)輸出張量 self.input_tensor = self.graph.get_tensor_by_name(self.input_tensor_name) self.layer_tensors = [self.graph.get_tensor_by_name(name + ':0') for name in self.layer_operation_names]
關(guān)于如何使用TensorFlow 載入模型就分享到這里了,希望以上內(nèi)容可以對大家有一定的幫助,可以學(xué)到更多知識。如果覺得文章不錯(cuò),可以把它分享出去讓更多的人看到。