This article shows how to use TensorFlow to recognize digits that you have written by hand yourself.
TensorFlow, Google's open-source project, has overtaken Caffe and appears to have become the most popular deep learning framework. Writing a model in it feels more tangible because the network is defined directly in code, whereas in Caffe the network is generated from configuration files.
The environment used here is TensorFlow 0.10. Note that some statements will fail on other versions because of compatibility differences between TensorFlow releases.
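You also need to install PIL: pip install Pillow
Image format (as used by MNIST):
– Each digit is size-normalized to fit in a 20×20 pixel box while preserving its aspect ratio.
– The digit is then centered in a 28×28 image.
– Pixels are stored in row-major order. Pixel values range from 0 to 255, where 0 is the background (white) and 255 is the foreground (black).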
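MNIST stores each 28×28 image as a flat vector of 784 values in row-major order. The sketch below illustrates what that means for indexing; the array `flat` is a stand-in filled with placeholder values, not real MNIST data.

```python
import numpy as np

# Stand-in for one flattened image: 784 placeholder values (not real MNIST pixels).
flat = np.arange(784, dtype=np.float32)

# Row-major layout: reshape(28, 28) recovers the 2-D image directly.
image = flat.reshape(28, 28)

# The pixel at (row, col) sits at index row * 28 + col in the flat vector.
row, col = 3, 5
print(image[row, col] == flat[row * 28 + col])  # True
```

This is the same layout that `tf.reshape(x, [-1, 28, 28, 1])` relies on when the network turns the 784-element input back into an image.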
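Create a .png file with a white background and the handwritten digit in black.
Below is the training code: a convolutional neural network with two convolutional layers, whose trained parameters are then saved with tf.train.Saver.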
# coding: UTF-8
import tensorflow as tf
import input_data

# Load the MNIST data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
print("MNIST loaded")

# Use an interactive session
sess = tf.InteractiveSession()

# Placeholders for the flattened input images and the one-hot labels
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
# W and b are not used by the network. They are kept so that the
# automatically generated variable names in the checkpoint match the
# ones created by the prediction script.
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))


def weight_variable(shape):
    """Create a weight variable of the given shape."""
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)


def bias_variable(shape):
    """Create a bias variable of the given shape."""
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)


def conv2d(x, W):
    return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')


def max_pool_2x2(x):
    return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                          strides=[1, 2, 2, 1], padding='SAME')


# First convolutional layer
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 28, 28, 1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

# Second convolutional layer
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# Fully connected layer
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# Dropout
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# Output layer
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

cross_entropy = -tf.reduce_sum(y_ * tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# Saver for the trained network parameters
saver = tf.train.Saver()
sess.run(tf.initialize_all_variables())

for i in range(8000):
    batch = mnist.train.next_batch(50)
    if i % 100 == 0:
        train_accuracy = accuracy.eval(feed_dict={
            x: batch[0], y_: batch[1], keep_prob: 1.0})
        print("step %d, training accuracy %g" % (i, train_accuracy))
    train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

save_path = saver.save(sess, "model_mnist.ckpt")
print("Model saved in file: %s" % save_path)

print("test accuracy %g" % accuracy.eval(feed_dict={
    x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
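The fully connected layer in the training code expects an input of size 7 * 7 * 64. The sketch below retraces where that number comes from, assuming the rule TensorFlow uses for padding='SAME' (output size = ceil(input size / stride)). The helper names `same_padding_output` and `feature_map_side` are made up for this illustration and are not part of TensorFlow.

```python
import math


def same_padding_output(size, stride):
    """Output length along one axis for padding='SAME': ceil(size / stride)."""
    return int(math.ceil(size / float(stride)))


def feature_map_side(input_side=28, num_blocks=2):
    """Side length after num_blocks of (stride-1 conv + 2x2 stride-2 max pool)."""
    side = input_side
    for _ in range(num_blocks):
        side = same_padding_output(side, 1)  # convolution keeps the size
        side = same_padding_output(side, 2)  # pooling halves it
    return side


side = feature_map_side()
flat_size = side * side * 64  # 64 output channels in the second conv layer
print(side, flat_size)  # 7 3136
```

So a 28×28 input shrinks to 14×14 after the first pooling layer and to 7×7 after the second, which is why `h_pool2` is reshaped to `[-1, 7 * 7 * 64]`.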
The input_data.py module used above is listed below. It downloads and reads the MNIST dataset, and it is the official version distributed with the TensorFlow MNIST tutorial.
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip
import os

import tensorflow.python.platform

import numpy
from six.moves import urllib
from six.moves import xrange  # pylint: disable=redefined-builtin
import tensorflow as tf

SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'


def maybe_download(filename, work_directory):
  """Download the data from Yann's website, unless it's already here."""
  if not os.path.exists(work_directory):
    os.mkdir(work_directory)
  filepath = os.path.join(work_directory, filename)
  if not os.path.exists(filepath):
    filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  return filepath


def _read32(bytestream):
  dt = numpy.dtype(numpy.uint32).newbyteorder('>')
  return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]


def extract_images(filename):
  """Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2051:
      raise ValueError(
          'Invalid magic number %d in MNIST image file: %s' %
          (magic, filename))
    num_images = _read32(bytestream)
    rows = _read32(bytestream)
    cols = _read32(bytestream)
    buf = bytestream.read(rows * cols * num_images)
    data = numpy.frombuffer(buf, dtype=numpy.uint8)
    data = data.reshape(num_images, rows, cols, 1)
    return data


def dense_to_one_hot(labels_dense, num_classes=10):
  """Convert class labels from scalars to one-hot vectors."""
  num_labels = labels_dense.shape[0]
  index_offset = numpy.arange(num_labels) * num_classes
  labels_one_hot = numpy.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot


def extract_labels(filename, one_hot=False):
  """Extract the labels into a 1D uint8 numpy array [index]."""
  print('Extracting', filename)
  with gzip.open(filename) as bytestream:
    magic = _read32(bytestream)
    if magic != 2049:
      raise ValueError(
          'Invalid magic number %d in MNIST label file: %s' %
          (magic, filename))
    num_items = _read32(bytestream)
    buf = bytestream.read(num_items)
    labels = numpy.frombuffer(buf, dtype=numpy.uint8)
    if one_hot:
      return dense_to_one_hot(labels)
    return labels


class DataSet(object):

  def __init__(self, images, labels, fake_data=False, one_hot=False,
               dtype=tf.float32):
    """Construct a DataSet.

    one_hot arg is used only if fake_data is true.  `dtype` can be either
    `uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
    `[0, 1]`.
    """
    dtype = tf.as_dtype(dtype).base_dtype
    if dtype not in (tf.uint8, tf.float32):
      raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
                      dtype)
    if fake_data:
      self._num_examples = 10000
      self.one_hot = one_hot
    else:
      assert images.shape[0] == labels.shape[0], (
          'images.shape: %s labels.shape: %s' % (images.shape,
                                                 labels.shape))
      self._num_examples = images.shape[0]

      # Convert shape from [num examples, rows, columns, depth]
      # to [num examples, rows*columns] (assuming depth == 1)
      assert images.shape[3] == 1
      images = images.reshape(images.shape[0],
                              images.shape[1] * images.shape[2])
      if dtype == tf.float32:
        # Convert from [0, 255] -> [0.0, 1.0].
        images = images.astype(numpy.float32)
        images = numpy.multiply(images, 1.0 / 255.0)
    self._images = images
    self._labels = labels
    self._epochs_completed = 0
    self._index_in_epoch = 0

  @property
  def images(self):
    return self._images

  @property
  def labels(self):
    return self._labels

  @property
  def num_examples(self):
    return self._num_examples

  @property
  def epochs_completed(self):
    return self._epochs_completed

  def next_batch(self, batch_size, fake_data=False):
    """Return the next `batch_size` examples from this data set."""
    if fake_data:
      fake_image = [1] * 784
      if self.one_hot:
        fake_label = [1] + [0] * 9
      else:
        fake_label = 0
      return [fake_image for _ in xrange(batch_size)], [
          fake_label for _ in xrange(batch_size)]
    start = self._index_in_epoch
    self._index_in_epoch += batch_size
    if self._index_in_epoch > self._num_examples:
      # Finished epoch
      self._epochs_completed += 1
      # Shuffle the data
      perm = numpy.arange(self._num_examples)
      numpy.random.shuffle(perm)
      self._images = self._images[perm]
      self._labels = self._labels[perm]
      # Start next epoch
      start = 0
      self._index_in_epoch = batch_size
      assert batch_size <= self._num_examples
    end = self._index_in_epoch
    return self._images[start:end], self._labels[start:end]


def read_data_sets(train_dir, fake_data=False, one_hot=False,
                   dtype=tf.float32):
  class DataSets(object):
    pass
  data_sets = DataSets()

  if fake_data:
    def fake():
      return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
    data_sets.train = fake()
    data_sets.validation = fake()
    data_sets.test = fake()
    return data_sets

  TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
  TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
  TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
  TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
  VALIDATION_SIZE = 5000

  local_file = maybe_download(TRAIN_IMAGES, train_dir)
  train_images = extract_images(local_file)

  local_file = maybe_download(TRAIN_LABELS, train_dir)
  train_labels = extract_labels(local_file, one_hot=one_hot)

  local_file = maybe_download(TEST_IMAGES, train_dir)
  test_images = extract_images(local_file)

  local_file = maybe_download(TEST_LABELS, train_dir)
  test_labels = extract_labels(local_file, one_hot=one_hot)

  validation_images = train_images[:VALIDATION_SIZE]
  validation_labels = train_labels[:VALIDATION_SIZE]
  train_images = train_images[VALIDATION_SIZE:]
  train_labels = train_labels[VALIDATION_SIZE:]

  data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
  data_sets.validation = DataSet(validation_images, validation_labels,
                                 dtype=dtype)
  data_sets.test = DataSet(test_images, test_labels, dtype=dtype)

  return data_sets
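The training script calls read_data_sets with one_hot=True, so the labels pass through dense_to_one_hot. The sketch below copies that function into a standalone snippet and runs it on three made-up labels to show what the encoding looks like. The label values are illustrative only.

```python
import numpy as np


def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors."""
    num_labels = labels_dense.shape[0]
    # Offset of each row's first element in the flattened output array.
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    # Set exactly one entry per row: the one at column `label`.
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot


labels = np.array([3, 0, 9], dtype=np.uint8)  # made-up example labels
one_hot = dense_to_one_hot(labels)
print(one_hot[0])  # a 1 at index 3, zeros elsewhere
```

Taking the argmax along axis 1 recovers the original labels, which is what the training script does with `tf.argmax(y_, 1)` when it computes accuracy.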
Next, the prediction script, which restores the saved model and classifies an image of your own:
# import modules
import sys
import tensorflow as tf
from PIL import Image, ImageFilter


def predictint(imvalue):
    """
    Return the predicted digit.
    The input is the list of pixel values from the imageprepare() function.
    """
    # Define the model (same as when creating the model file)
    x = tf.placeholder(tf.float32, [None, 784])
    # W and b are not used by the network. They are kept so that the
    # automatically generated variable names match those in the checkpoint.
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))

    def weight_variable(shape):
        initial = tf.truncated_normal(shape, stddev=0.1)
        return tf.Variable(initial)

    def bias_variable(shape):
        initial = tf.constant(0.1, shape=shape)
        return tf.Variable(initial)

    def conv2d(x, W):
        return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')

    def max_pool_2x2(x):
        return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                              strides=[1, 2, 2, 1], padding='SAME')

    W_conv1 = weight_variable([5, 5, 1, 32])
    b_conv1 = bias_variable([32])
    x_image = tf.reshape(x, [-1, 28, 28, 1])
    h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
    h_pool1 = max_pool_2x2(h_conv1)

    W_conv2 = weight_variable([5, 5, 32, 64])
    b_conv2 = bias_variable([64])
    h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
    h_pool2 = max_pool_2x2(h_conv2)

    W_fc1 = weight_variable([7 * 7 * 64, 1024])
    b_fc1 = bias_variable([1024])
    h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
    h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

    keep_prob = tf.placeholder(tf.float32)
    h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

    W_fc2 = weight_variable([1024, 10])
    b_fc2 = bias_variable([10])
    y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

    init_op = tf.initialize_all_variables()
    saver = tf.train.Saver()

    # Load the model_mnist.ckpt file, which is expected in the directory
    # from which this script is started, and use the model to predict the
    # digit. The result is returned as a one-element array.
    # Based on the documentation at
    # https://www.tensorflow.org/versions/master/how_tos/variables/index.html
    with tf.Session() as sess:
        sess.run(init_op)
        saver.restore(sess, "model_mnist.ckpt")
        # print("Model restored.")
        prediction = tf.argmax(y_conv, 1)
        return prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0},
                               session=sess)


def imageprepare(argv):
    """
    Return the normalized pixel values.
    The input is the location of a png file.
    """
    im = Image.open(argv).convert('L')
    width = float(im.size[0])
    height = float(im.size[1])
    # Create a white canvas of 28x28 pixels
    newImage = Image.new('L', (28, 28), (255))

    # Note: Image.ANTIALIAS was removed in Pillow 10; use Image.LANCZOS there.
    if width > height:  # check which dimension is bigger
        # Width is bigger. Width becomes 20 pixels.
        # Scale the height by the same ratio as the width.
        nheight = int(round((20.0 / width * height), 0))
        if nheight == 0:  # rare case, but the minimum is 1 pixel
            nheight = 1
        # Resize and sharpen
        img = im.resize((20, nheight), Image.ANTIALIAS).filter(
            ImageFilter.SHARPEN)
        wtop = int(round(((28 - nheight) / 2), 0))  # vertical position
        newImage.paste(img, (4, wtop))  # paste resized image on white canvas
    else:
        # Height is bigger. Height becomes 20 pixels.
        # Scale the width by the same ratio as the height.
        nwidth = int(round((20.0 / height * width), 0))
        if nwidth == 0:  # rare case, but the minimum is 1 pixel
            nwidth = 1
        # Resize and sharpen
        img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(
            ImageFilter.SHARPEN)
        wleft = int(round(((28 - nwidth) / 2), 0))  # horizontal position
        newImage.paste(img, (wleft, 4))  # paste resized image on white canvas

    # newImage.save("sample.png")

    tv = list(newImage.getdata())  # get pixel values

    # Normalize pixels to between 0 and 1. 0 is pure white, 1 is pure black.
    tva = [(255 - x) * 1.0 / 255.0 for x in tv]
    return tva
    # print(tva)


def main(argv):
    """
    Main function.
    """
    imvalue = imageprepare(argv)
    predint = predictint(imvalue)
    print(predint[0])  # first value in the result


if __name__ == "__main__":
    main('2.png')
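To test it, save your own image in the script's directory (the script above reads 2.png) and run the script. The preprocessing in imageprepare() performs the following steps: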
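(1) Load the image of the handwritten digit.
(2) Convert the image to grayscale (mode "L").
(3) Determine which dimension of the original image is larger.
(4) Resize the image so that the larger dimension (either the height or the width) is 20 pixels, and scale the smaller dimension by the same ratio.
(5) Sharpen the image. This greatly improves the result.
(6) Paste the image onto a 28×28 pixel white canvas. Along the larger dimension, place the image 4 pixels from the top or the side: the larger dimension is always 20 pixels, and 4 + 20 + 4 = 28. Along the smaller dimension, the offset is half of the difference between 28 and the new size of the scaled image, which centers it.
(7) Get the pixel values of the new image (canvas plus centered image).
(8) Normalize the pixel values to between 0 and 1 (the TensorFlow MNIST tutorial does this too), where 0 is white and 1 is pure black. The pixel values from step 7 are the other way around, with 255 white and 0 black, so they must be inverted. The following formula both inverts and normalizes: (255 - x) * 1.0 / 255.0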
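The geometry in steps 3, 4 and 6 and the normalization in step 8 can be checked without opening an image. The following is a minimal sketch in plain Python; the helper names `fit_box` and `normalize` are hypothetical and only mirror the arithmetic described above, not the Pillow calls.

```python
def fit_box(width, height, box=20, canvas=28):
    """Return (new_width, new_height, left, top) for pasting onto the canvas.

    Hypothetical helper: it reproduces only the size and offset arithmetic.
    """
    if width > height:
        new_w = box
        new_h = max(1, int(round(float(box) / width * height)))
        left = (canvas - box) // 2                 # 4 pixels from the side
        top = int(round((canvas - new_h) / 2.0))   # center the smaller side
    else:
        new_h = box
        new_w = max(1, int(round(float(box) / height * width)))
        top = (canvas - box) // 2                  # 4 pixels from the top
        left = int(round((canvas - new_w) / 2.0))  # center the smaller side
    return new_w, new_h, left, top


def normalize(pixels):
    """Map 8-bit grayscale values (255 = white) to [0, 1] with 1 = black."""
    return [(255 - p) * 1.0 / 255.0 for p in pixels]


print(fit_box(100, 50))     # (20, 10, 4, 9)
print(normalize([255, 0]))  # [0.0, 1.0]
```

For a 100×50 image, the width becomes 20 and the height 10, so the digit is pasted 4 pixels from the left and 9 pixels from the top, leaving 9 pixels below it as well.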