在使用TensorFlow訓練神經網絡時,首先面臨的問題是:網絡的輸入
創(chuàng)新互聯(lián)建站是一家專注于網站制作、成都做網站與策劃設計,東坡網站建設哪家好?創(chuàng)新互聯(lián)建站做網站,專注于網站建設10年,網設計領域的專業(yè)建站公司;建站業(yè)務涵蓋:東坡等地區(qū)。東坡做網站價格咨詢:18982081108此篇文章,教大家將自己的數據集制作成TFRecord格式,feed進網絡,除了TFRecord格式,TensorFlow也支持其他格
式的數據,此處就不再介紹了。建議大家使用TFRecord格式,在后面可以通過api進行多線程的讀取文件隊列。
1. 原本的數據集
此時,我有兩類圖片,分別是xiansu100,xiansu60,每一類中有10張圖片。
2.制作成TFRecord格式
tfrecord會根據你選擇輸入文件的類,自動給每一類打上同樣的標簽。如在本例中,只有0,1 兩類,想知道文件夾名與label關系的,可以自己保存起來。
#生成整數型的屬性 def _int64_feature(value): return tf.train.Feature(int64_list = tf.train.Int64List(value = [value])) #生成字符串類型的屬性 def _bytes_feature(value): return tf.train.Feature(bytes_list = tf.train.BytesList(value = [value])) #制作TFRecord格式 def createTFRecord(filename,mapfile): class_map = {} data_dir = '/home/wc/DataSet/traffic/testTFRecord/' classes = {'xiansu60','xiansu100'} #輸出TFRecord文件的地址 writer = tf.python_io.TFRecordWriter(filename) for index,name in enumerate(classes): class_path=data_dir+name+'/' class_map[index] = name for img_name in os.listdir(class_path): img_path = class_path + img_name #每個圖片的地址 img = Image.open(img_path) img= img.resize((224,224)) img_raw = img.tobytes() #將圖片轉化成二進制格式 example = tf.train.Example(features = tf.train.Features(feature = { 'label':_int64_feature(index), 'image_raw': _bytes_feature(img_raw) })) writer.write(example.SerializeToString()) writer.close() txtfile = open(mapfile,'w+') for key in class_map.keys(): txtfile.writelines(str(key)+":"+class_map[key]+"\n") txtfile.close()