真实的国产乱ⅩXXX66竹夫人,五月香六月婷婷激情综合,亚洲日本VA一区二区三区,亚洲精品一区二区三区麻豆

成都創(chuàng)新互聯(lián)網(wǎng)站制作重慶分公司

PyTorch如何實現(xiàn)手寫數(shù)字識別功能-創(chuàng)新互聯(lián)

這篇文章主要介紹了PyTorch如何實現(xiàn)手寫數(shù)字識別功能,具有一定借鑒價值,感興趣的朋友可以參考下,希望大家閱讀完這篇文章之后大有收獲,下面讓小編帶著大家一起了解一下。

成都創(chuàng)新互聯(lián)公司長期為千余家客戶提供的網(wǎng)站建設(shè)服務(wù),團隊從業(yè)經(jīng)驗10年,關(guān)注不同地域、不同群體,并針對不同對象提供差異化的產(chǎn)品和服務(wù);打造開放共贏平臺,與合作伙伴共同營造健康的互聯(lián)網(wǎng)生態(tài)環(huán)境。為欒城企業(yè)提供專業(yè)的成都網(wǎng)站制作、成都網(wǎng)站建設(shè),欒城網(wǎng)站改版等技術(shù)服務(wù)。擁有十年豐富建站經(jīng)驗和眾多成功案例,為您定制開發(fā)。

MNIST 手寫數(shù)字識別是一個比較簡單的入門項目,相當(dāng)于深度學(xué)習(xí)中的 Hello World,可以讓我們快速了解構(gòu)建神經(jīng)網(wǎng)絡(luò)的大致過程。雖然網(wǎng)上的案例比較多,但還是要自己實現(xiàn)一遍。代碼采用 PyTorch 1.0 編寫并運行。

導(dǎo)入相關(guān)庫

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
import cv2

torchvision 用于下載并導(dǎo)入數(shù)據(jù)集

cv2 用于展示數(shù)據(jù)的圖像

獲取訓(xùn)練集和測試集

# 下載訓(xùn)練集
train_dataset = datasets.MNIST(root='./num/',
                train=True,
                transform=transforms.ToTensor(),
                download=True)
# 下載測試集
test_dataset = datasets.MNIST(root='./num/',
               train=False,
               transform=transforms.ToTensor(),
               download=True)

root 用于指定數(shù)據(jù)集在下載之后的存放路徑

transform 用于指定導(dǎo)入數(shù)據(jù)集需要對數(shù)據(jù)進行那種變化操作

train是指定在數(shù)據(jù)集下載完成后需要載入的那部分?jǐn)?shù)據(jù),設(shè)置為 True 則說明載入的是該數(shù)據(jù)集的訓(xùn)練集部分,設(shè)置為 False 則說明載入的是該數(shù)據(jù)集的測試集部分

download 為 True 表示數(shù)據(jù)集需要程序自動幫你下載

這樣設(shè)置并運行后,就會在指定路徑中下載 MNIST 數(shù)據(jù)集,之后就可以使用了。

數(shù)據(jù)裝載和預(yù)覽

# dataset 參數(shù)用于指定我們載入的數(shù)據(jù)集名稱
# batch_size參數(shù)設(shè)置了每個包中的圖片數(shù)據(jù)個數(shù)
# 在裝載的過程會將數(shù)據(jù)隨機打亂順序并進打包

# 裝載訓(xùn)練集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                      batch_size=batch_size,
                      shuffle=True)
# 裝載測試集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                     batch_size=batch_size,
                     shuffle=True)

在裝載完成后,可以選取其中一個批次的數(shù)據(jù)進行預(yù)覽:

images, labels = next(iter(data_loader_train))
img = torchvision.utils.make_grid(images)

img = img.numpy().transpose(1, 2, 0)
std = [0.5, 0.5, 0.5]
mean = [0.5, 0.5, 0.5]
img = img * std + mean
print(labels)
cv2.imshow('win', img)
key_pressed = cv2.waitKey(0)

在以上代碼中使用了 iter 和 next 來獲取取一個批次的圖片數(shù)據(jù)和其對應(yīng)的圖片標(biāo)簽,然后使用 torchvision.utils 中的 make_grid 類方法將一個批次的圖片構(gòu)造成網(wǎng)格模式。

預(yù)覽圖片如下:

PyTorch如何實現(xiàn)手寫數(shù)字識別功能

并且打印出了圖片相對應(yīng)的數(shù)字:

PyTorch如何實現(xiàn)手寫數(shù)字識別功能

搭建神經(jīng)網(wǎng)絡(luò)

# 卷積層使用 torch.nn.Conv2d
# 激活層使用 torch.nn.ReLU
# 池化層使用 torch.nn.MaxPool2d
# 全連接層使用 torch.nn.Linear

class LeNet(nn.Module):
  def __init__(self):
    super(LeNet, self).__init__()
    self.conv1 = nn.Sequential(nn.Conv2d(1, 6, 3, 1, 2), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.conv2 = nn.Sequential(nn.Conv2d(6, 16, 5), nn.ReLU(),
                  nn.MaxPool2d(2, 2))

    self.fc1 = nn.Sequential(nn.Linear(16 * 5 * 5, 120),
                 nn.BatchNorm1d(120), nn.ReLU())

    self.fc2 = nn.Sequential(
      nn.Linear(120, 84),
      nn.BatchNorm1d(84),
      nn.ReLU(),
      nn.Linear(84, 10))
    	# 最后的結(jié)果一定要變?yōu)?nbsp;10,因為數(shù)字的選項是 0 ~ 9

  def forward(self, x):
    x = self.conv1(x)
    x = self.conv2(x)
    x = x.view(x.size()[0], -1)
    x = self.fc1(x)
    x = self.fc2(x)
    x = self.fc3(x)
    return x

前向傳播內(nèi)容:

首先經(jīng)過 self.conv1() 和 self.conv1() 進行卷積處理

然后進行 x = x.view(x.size()[0], -1),對參數(shù)實現(xiàn)扁平化(便于后面全連接層輸入)

最后通過 self.fc1() 和 self.fc2() 定義的全連接層進行最后的分類

訓(xùn)練模型

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 64
LR = 0.001

net = LeNet().to(device)
# 損失函數(shù)使用交叉熵
criterion = nn.CrossEntropyLoss()
# 優(yōu)化函數(shù)使用 Adam 自適應(yīng)優(yōu)化算法
optimizer = optim.Adam(
  net.parameters(),
  lr=LR,
)

epoch = 1
if __name__ == '__main__':
  for epoch in range(epoch):
    sum_loss = 0.0
    for i, data in enumerate(train_loader):
      inputs, labels = data
      inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
      optimizer.zero_grad() #將梯度歸零
      outputs = net(inputs) #將數(shù)據(jù)傳入網(wǎng)絡(luò)進行前向運算
      loss = criterion(outputs, labels) #得到損失函數(shù)
      loss.backward() #反向傳播
      optimizer.step() #通過梯度做一步參數(shù)更新

      # print(loss)
      sum_loss += loss.item()
      if i % 100 == 99:
        print('[%d,%d] loss:%.03f' %
           (epoch + 1, i + 1, sum_loss / 100))
        sum_loss = 0.0

測試模型

net.eval() #將模型變換為測試模式
  correct = 0
  total = 0
  for data_test in test_loader:
    images, labels = data_test
    images, labels = Variable(images).cuda(), Variable(labels).cuda()
    output_test = net(images)
    _, predicted = torch.max(output_test, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
  print("correct1: ", correct)
  print("Test acc: {0}".format(correct.item() /
                 len(test_dataset)))

訓(xùn)練及測試的情況:

PyTorch如何實現(xiàn)手寫數(shù)字識別功能

98% 以上的成功率,效果還不錯。

感謝你能夠認(rèn)真閱讀完這篇文章,希望小編分享的“PyTorch如何實現(xiàn)手寫數(shù)字識別功能”這篇文章對大家有幫助,同時也希望大家多多支持創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計公司,關(guān)注創(chuàng)新互聯(lián)成都網(wǎng)站設(shè)計公司行業(yè)資訊頻道,更多相關(guān)知識等著你來學(xué)習(xí)!

另外有需要云服務(wù)器可以了解下創(chuàng)新互聯(lián)scvps.cn,海內(nèi)外云服務(wù)器15元起步,三天無理由+7*72小時售后在線,公司持有idc許可證,提供“云服務(wù)器、裸金屬服務(wù)器、網(wǎng)站設(shè)計器、香港服務(wù)器、美國服務(wù)器、虛擬主機、免備案服務(wù)器”等云主機租用服務(wù)以及企業(yè)上云的綜合解決方案,具有“安全穩(wěn)定、簡單易用、服務(wù)可用性高、性價比高”等特點與優(yōu)勢,專為企業(yè)上云打造定制,能夠滿足用戶豐富、多元化的應(yīng)用場景需求。


分享文章:PyTorch如何實現(xiàn)手寫數(shù)字識別功能-創(chuàng)新互聯(lián)
轉(zhuǎn)載注明:http://weahome.cn/article/dccigh.html

其他資訊

在線咨詢

微信咨詢

電話咨詢

028-86922220(工作日)

18980820575(7×24)

提交需求

返回頂部