今天就跟大家聊聊有關index_select()函數(shù)怎么在pytorch中使用,可能很多人都不太了解,為了讓大家更加了解,小編給大家總結了以下內容,希望大家根據(jù)這篇文章可以有所收獲。
成都創(chuàng)新互聯(lián)公司長期為上千家客戶提供的網(wǎng)站建設服務,團隊從業(yè)經驗10年,關注不同地域、不同群體,并針對不同對象提供差異化的產品和服務;打造開放共贏平臺,與合作伙伴共同營造健康的互聯(lián)網(wǎng)生態(tài)環(huán)境。為景寧畬族自治企業(yè)提供專業(yè)的網(wǎng)站建設、成都網(wǎng)站建設,景寧畬族自治網(wǎng)站改版等技術服務。擁有10年豐富建站經驗和眾多成功案例,為您定制開發(fā)。pytorch中index_select()的用法
index_select(input, dim, index)
功能:在指定的維度dim上選取數(shù)據(jù),不如選取某些行,列
參數(shù)介紹
第一個參數(shù)input是要索引查找的對象
第二個參數(shù)dim是要查找的維度,因為通常情況下我們使用的都是二維張量,所以可以簡單的記憶: 0代表行,1代表列
第三個參數(shù)index是你要索引的序列,它是一個tensor對象
剛開始學習pytorch,遇到了index_select(),一開始不太明白幾個參數(shù)的意思,后來查了一下資料,算是明白了一點。
a = torch.linspace(1, 12, steps=12).view(3, 4) print(a) b = torch.index_select(a, 0, torch.tensor([0, 2])) print(b) print(a.index_select(0, torch.tensor([0, 2]))) c = torch.index_select(a, 1, torch.tensor([1, 3])) print(c)
先定義了一個tensor,這里用到了linspace和view方法。
第一個參數(shù)是索引的對象,第二個參數(shù)0表示按行索引,1表示按列進行索引,第三個參數(shù)是一個tensor,就是索引的序號,比如b里面tensor[0, 2]表示第0行和第2行,c里面tensor[1, 3]表示第1列和第3列。
輸出結果如下:
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.]])
tensor([[ 1., 2., 3., 4.],
[ 9., 10., 11., 12.]])
tensor([[ 1., 2., 3., 4.],
[ 9., 10., 11., 12.]])
tensor([[ 2., 4.],
[ 6., 8.],
[10., 12.]])
示例2
import torch x = torch.Tensor([[[1, 2, 3], [4, 5, 6]], [[9, 8, 7], [6, 5, 4]]]) print(x) print(x.size()) index = torch.LongTensor([0, 0, 1]) print(torch.index_select(x, 0, index)) print(torch.index_select(x, 0, index).size()) print(torch.index_select(x, 1, index)) print(torch.index_select(x, 1, index).size()) print(torch.index_select(x, 2, index)) print(torch.index_select(x, 2, index).size())
input的張量形狀為2×2×3,index為[0, 0, 1]的向量
分別從0、1、2三個維度來使用index_select()函數(shù),并輸出結果和形狀,維度大于2就會報錯因為input較大只有三個維度
輸出:
tensor([[[1., 2., 3.],
[4., 5., 6.]],
[[9., 8., 7.],
[6., 5., 4.]]])
torch.Size([2, 2, 3])
tensor([[[1., 2., 3.],
[4., 5., 6.]],
[[1., 2., 3.],
[4., 5., 6.]],
[[9., 8., 7.],
[6., 5., 4.]]])
torch.Size([3, 2, 3])
tensor([[[1., 2., 3.],
[1., 2., 3.],
[4., 5., 6.]],
[[9., 8., 7.],
[9., 8., 7.],
[6., 5., 4.]]])
torch.Size([2, 3, 3])
tensor([[[1., 1., 2.],
[4., 4., 5.]],
[[9., 9., 8.],
[6., 6., 5.]]])
torch.Size([2, 2, 3])
對結果進行分析:
index是大小為3的向量,輸入的張量形狀為2×2×3
dim = 0時,輸出的張量形狀為3×2×3
dim = 1時,輸出的張量形狀為2×3×3
dim = 2時,輸出的張量形狀為2×2×3
注意輸出張量維度的變化與index大小的關系,結合輸出的張量與原始張量來分析index_select()函數(shù)的作用
看完上述內容,你們對index_select()函數(shù)怎么在pytorch中使用有進一步的了解嗎?如果還想了解更多知識或者相關內容,請關注創(chuàng)新互聯(lián)行業(yè)資訊頻道,感謝大家的支持。