真实的国产乱ⅩXXX66竹夫人,五月香六月婷婷激情综合,亚洲日本VA一区二区三区,亚洲精品一区二区三区麻豆

成都創(chuàng)新互聯(lián)網(wǎng)站制作重慶分公司

pytorch的hook函數(shù)怎么使用

本篇內(nèi)容主要講解“pytorch的hook函數(shù)怎么使用”,感興趣的朋友不妨來看看。本文介紹的方法操作簡單快捷,實用性強。下面就讓小編來帶大家學習“pytorch的hook函數(shù)怎么使用”吧!

在隴西等地區(qū),都構(gòu)建了全面的區(qū)域性戰(zhàn)略布局,加強發(fā)展的系統(tǒng)性、市場前瞻性、產(chǎn)品創(chuàng)新能力,以專注、極致的服務理念,為客戶提供網(wǎng)站設(shè)計、成都做網(wǎng)站 網(wǎng)站設(shè)計制作按需求定制制作,公司網(wǎng)站建設(shè),企業(yè)網(wǎng)站建設(shè),成都品牌網(wǎng)站建設(shè),成都全網(wǎng)營銷,外貿(mào)營銷網(wǎng)站建設(shè),隴西網(wǎng)站建設(shè)費用合理。

課程代碼

"""
@brief      : pytorch的hook函數(shù)
"""

import torch
import torch.nn as nn
from tools.common_tools2 import set_seed

set_seed(1)

# ----------------------------------- 1 tensor hook 1
flag = 0
# flag = 1
if flag:
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()


    def grad_hook(grad):
        a_grad.append(grad)


    handle = a.register_hook(grad_hook)

    y.backward()

    # 查看梯度
    print("gradient:", w.grad, x.grad, a.grad, b.grad, y.grad)
    print("a_grad[0]:", a_grad[0])

    handle.remove()

# ----------------------------------- 2 tensor hook 2
flag = 0
# flag = 1
if flag:
    w = torch.tensor([1.], requires_grad=True)
    x = torch.tensor([2.], requires_grad=True)
    a = torch.add(w, x)
    b = torch.add(w, 1)
    y = torch.mul(a, b)

    a_grad = list()


    def grad_hook(grad):
        grad *= 2
        return grad * 3


    handle = w.register_hook(grad_hook)

    y.backward()

    print("w.grad:", w.grad)
    handle.remove()

# --------------------------- 3 Module.register_forward_hook and pre hook
# flag = 0
flag = 1
if flag:
    class Net(nn.Module):
        def __init__(self):
            super(Net, self).__init__()
            self.conv1 = nn.Conv2d(1, 2, 3)
            self.pool1 = nn.MaxPool2d(2, 2)

        def forward(self, x):
            x = self.conv1(x)
            x = self.pool1(x)
            return x


    def forward_hook(module, data_input, data_output):
        fmap_block.append(data_output)
        input_block.append(data_input)


    def forward_pre_hook(module, data_input):
        print("forward_pre_hook input:{}".format(data_input))


    def backward_hook(module, grad_input, grad_output):
        print("backward hook input:{}".format(grad_input))
        print("backward hook output:{}".format(grad_output))


    # 初始化網(wǎng)絡(luò)
    net = Net()
    net.conv1.weight[0].detach().fill_(1)
    net.conv1.weight[1].detach().fill_(2)
    net.conv1.bias.data.detach().zero_()

    # 注冊hook
    fmap_block = list()
    input_block = list()
    net.conv1.register_forward_hook(forward_hook)
    net.conv1.register_forward_pre_hook(forward_pre_hook)
    net.conv1.register_backward_hook(backward_hook)

    # inference
    fake_img = torch.ones((1, 1, 4, 4))  # batch size * channel * H * W
    output = net(fake_img)  # 前向傳播

    loss_fnc = nn.L1Loss()
    target = torch.randn_like(output)
    loss = loss_fnc(target, output)
    loss.backward()

    # 觀察
    print("output shape: {}\noutput value: {}\n".format(output.shape, output))
    print("feature maps shape: {}\noutput value: {}\n".format(fmap_block[0].shape, fmap_block[0]))
    print("input shape: {}\ninput value: {}".format(input_block[0][0].shape, input_block[0]))

作業(yè)

1.    采用torch.nn.Module.register_forward_hook機制實現(xiàn)AlexNet第一個卷積層輸出特征圖的可視化,并將/torchvision/models/alexnet.py中第28行改為:nn.ReLU(inplace=False),觀察

inplace=True與inplace=False的差異

1. hook畫特征圖

# -*- coding:utf-8 -*-
"""
@brief      : 采用hook函數(shù)可視化特征圖
"""

import torch.nn as nn
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
from tools.common_tools2 import set_seed
import torchvision.models as models

set_seed(1)  # 設(shè)置隨機種子

# ----------------------------------- feature map visualization
# flag = 0
flag = 1
if flag:
    writer = SummaryWriter(comment='test_your_comment', filename_suffix="_test_your_filename_suffix")

    # 數(shù)據(jù)
    path_img = "./lena.png"  # your path to image
    normMean = [0.49139968, 0.48215827, 0.44653124]
    normStd = [0.24703233, 0.24348505, 0.26158768]

    norm_transform = transforms.Normalize(normMean, normStd)
    img_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        norm_transform
    ])

    img_pil = Image.open(path_img).convert('RGB')
    if img_transforms is not None:
        img_tensor = img_transforms(img_pil)
    img_tensor.unsqueeze_(0)  # chw --> bchw

    # 模型
    alexnet = models.alexnet(pretrained=True)

    # 注冊hook
    fmap_dict = dict()
    for name, sub_module in alexnet.named_modules():

        if isinstance(sub_module, nn.Conv2d):
            key_name = str(sub_module.weight.shape)
            fmap_dict.setdefault(key_name, list())

            n1, n2 = name.split(".")


            def hook_func(m, i, o):
                key_name = str(m.weight.shape)
                fmap_dict[key_name].append(o)


            alexnet._modules[n1]._modules[n2].register_forward_hook(hook_func)

    # forward
    output = alexnet(img_tensor)

    # add image
    for layer_name, fmap_list in fmap_dict.items():
        fmap = fmap_list[0]
        fmap.transpose_(0, 1)

        nrow = int(np.sqrt(fmap.shape[0]))
        fmap_grid = vutils.make_grid(fmap, normalize=True, scale_each=True, nrow=nrow)
        writer.add_image('feature map in {}'.format(layer_name), fmap_grid, global_step=322)

到此,相信大家對“pytorch的hook函數(shù)怎么使用”有了更深的了解,不妨來實際操作一番吧!這里是創(chuàng)新互聯(lián)網(wǎng)站,更多相關(guān)內(nèi)容可以進入相關(guān)頻道進行查詢,關(guān)注我們,繼續(xù)學習!


網(wǎng)站名稱:pytorch的hook函數(shù)怎么使用
文章轉(zhuǎn)載:http://weahome.cn/article/gijdsh.html

其他資訊

在線咨詢

微信咨詢

電話咨詢

028-86922220(工作日)

18980820575(7×24)

提交需求

返回頂部