本篇內容介紹了“PyTorch frozen怎么使用”的有關知識,在實際案例的操作過程中,不少人都會遇到這樣的困境,接下來就讓小編帶領大家學習一下如何處理這些情況吧!希望大家仔細閱讀,能夠學有所成!
創新互聯建站專注于企業營銷型網站建設、網站重做改版、眉縣網站定制設計、自適應品牌網站建設、H5響應式網站、購物商城網站建設、集團公司官網建設、成都外貿網站制作、高端網站制作、響應式網頁設計等建站業務,價格優惠性價比高,為眉縣等各大城市提供網站開發制作服務。
# ============================ step 2/5  model ============================
# 1/3: instantiate the ResNet-18 architecture (weights randomly initialised).
resnet18_ft = models.resnet18()

# 2/3: optionally load pretrained weights from a local checkpoint file.
# use_pretrained = 0
use_pretrained = 1
if use_pretrained:
    ckpt_path = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    resnet18_ft.load_state_dict(torch.load(ckpt_path))

# 3/3: swap the original fc head for a fresh Linear layer sized to our
# number of target classes, then move the whole model to the device.
in_features = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(in_features, classes)

resnet18_ft.to(device)
# ============================ step 2/5  model ============================
# 1/3: instantiate the ResNet-18 architecture.
resnet18_ft = models.resnet18()

# 2/3: optionally load pretrained weights from a local checkpoint file.
# use_pretrained = 0
use_pretrained = 1
if use_pretrained:
    ckpt_path = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    resnet18_ft.load_state_dict(torch.load(ckpt_path))

# Method 1: freeze the convolutional backbone by switching off gradient
# tracking on every parameter (disabled here; flip the flags to enable).
freeze_backbone = 0
# freeze_backbone = 1
if freeze_backbone:
    for p in resnet18_ft.parameters():
        p.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))

# 3/3: replace the fc head — the new Linear is created AFTER the freeze,
# so it defaults to requires_grad=True and remains trainable.
in_features = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(in_features, classes)

resnet18_ft.to(device)
# -*- coding: utf-8 -*-
"""
Model finetune demo: load a pretrained ResNet-18, optionally freeze the
convolutional backbone (or give it a zero learning rate), replace the fc
head, and train/validate on the ants/bees binary classification dataset.
"""
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from tools.my_dataset import AntsDataset      # project-local dataset wrapper
from tools.common_tools2 import set_seed      # project-local seeding helper
import torchvision.models as models
import torchvision

BASEDIR = os.path.dirname(os.path.abspath(__file__))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use device :{}".format(device))

set_seed(1)  # fix the random seed for reproducibility
label_name = {"ants": 0, "bees": 1}

# ---- hyper-parameters ----
MAX_EPOCH = 25
BATCH_SIZE = 16
LR = 0.001
log_interval = 10   # print training stats every N iterations
val_interval = 1    # run validation every N epochs
classes = 2         # binary classification: ants vs bees
start_epoch = -1
lr_decay_step = 7   # StepLR: decay the learning rate every 7 epochs

# ============================ step 1/5 data ============================
data_dir = os.path.join(BASEDIR, "..", "..", "data/hymenoptera_data")
train_dir = os.path.join(data_dir, "train")
valid_dir = os.path.join(data_dir, "val")

# ImageNet normalization statistics — used because the backbone weights
# being loaded below were pretrained on ImageNet
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# build Dataset instances
train_data = AntsDataset(data_dir=train_dir, transform=train_transform)
valid_data = AntsDataset(data_dir=valid_dir, transform=valid_transform)

# build DataLoaders
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 model ============================
# 1/3 build the network
resnet18_ft = models.resnet18()

# 2/3 load pretrained parameters from a local checkpoint
# flag = 0
flag = 1
if flag:
    path_pretrained_model = os.path.join(BASEDIR, "..", "..", "data/resnet18-5c106cde.pth")
    state_dict_load = torch.load(path_pretrained_model)
    resnet18_ft.load_state_dict(state_dict_load)

# Method 1: freeze the convolutional layers (disabled by default)
flag_m1 = 0
# flag_m1 = 1
if flag_m1:
    for param in resnet18_ft.parameters():
        param.requires_grad = False
    print("conv1.weights[0, 0, ...]:\n {}".format(resnet18_ft.conv1.weight[0, 0, ...]))

# 3/3 replace the fc layer; the new Linear is created after the freeze,
# so it has requires_grad=True and the head stays trainable either way
num_ftrs = resnet18_ft.fc.in_features
resnet18_ft.fc = nn.Linear(num_ftrs, classes)

resnet18_ft.to(device)

# ============================ step 3/5 loss function ============================
criterion = nn.CrossEntropyLoss()  # choose the loss function

# ============================ step 4/5 optimizer ============================
# Method 2: give the conv backbone a small (here: zero) learning rate
# flag = 0
flag = 1
if flag:
    fc_params_id = list(map(id, resnet18_ft.fc.parameters()))  # ids of the fc params, used to exclude them from base_params
    base_params = filter(lambda p: id(p) not in fc_params_id, resnet18_ft.parameters())
    optimizer = optim.SGD([
        {'params': base_params, 'lr': LR * 0},  # 0
        {'params': resnet18_ft.fc.parameters(), 'lr': LR}], momentum=0.9)
else:
    optimizer = optim.SGD(resnet18_ft.parameters(), lr=LR, momentum=0.9)  # choose the optimizer

scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=lr_decay_step, gamma=0.1)  # LR decay schedule

# ============================ step 5/5 training ============================
train_curve = list()
valid_curve = list()

for epoch in range(start_epoch + 1, MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    resnet18_ft.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = resnet18_ft(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # accumulate classification statistics
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().cpu().sum().numpy()

        # print training info
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i + 1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i + 1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

            # NOTE(review): prints conv1 weights so one can watch whether the
            # backbone changes during training — placement reconstructed from
            # the flattened source; confirm against the original layout
            # if flag_m1:
            print("epoch:{} conv1.weights[0, 0, ...] :\n {}".format(epoch, resnet18_ft.conv1.weight[0, 0, ...]))

    scheduler.step()  # update the learning rate

    # validate the model
    if (epoch + 1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        resnet18_ft.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                outputs = resnet18_ft(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().cpu().sum().numpy()

                loss_val += loss.item()

            loss_val_mean = loss_val / len(valid_loader)
            valid_curve.append(loss_val_mean)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j + 1, len(valid_loader), loss_val_mean, correct_val / total_val))
        resnet18_ft.train()

train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
# valid records one loss per epoch; convert its record points to iterations
valid_x = np.arange(1, len(valid_curve) + 1) * train_iters * val_interval
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()
“PyTorch frozen怎么使用”的內容就介紹到這里了,感謝大家的閱讀。如果想了解更多行業相關的知識可以關注創新互聯網站,小編將為大家輸出更多高質量的實用文章!