這篇文章給大家介紹如何利用PyTorch中的Moco-V2減少計(jì)算約束,內(nèi)容非常詳細(xì),感興趣的小伙伴們可以參考借鑒,希望對(duì)大家能有所幫助。
站在用戶的角度思考問題,與客戶深入溝通,找到商城網(wǎng)站設(shè)計(jì)與商城網(wǎng)站推廣的解決方案,憑借多年的經(jīng)驗(yàn),讓設(shè)計(jì)與互聯(lián)網(wǎng)技術(shù)結(jié)合,創(chuàng)造個(gè)性化、用戶體驗(yàn)好的作品,建站類型包括:成都網(wǎng)站制作、做網(wǎng)站、企業(yè)官網(wǎng)、英文網(wǎng)站、手機(jī)端網(wǎng)站、網(wǎng)站推廣、域名注冊(cè)、網(wǎng)站空間、企業(yè)郵箱。業(yè)務(wù)覆蓋商城地區(qū)。
SimCLR論文(http://cse.iitkgp.ac.in/~arastogi/papers/simclr.pdf)解釋了這個(gè)框架如何從更大的模型和更大的批處理中獲益,并且如果有足夠的計(jì)算能力,可以產(chǎn)生與監(jiān)督模型類似的結(jié)果。
但是這些需求使得框架的計(jì)算量相當(dāng)大。如果我們可以擁有這個(gè)框架的簡(jiǎn)單性和強(qiáng)大功能,并且有更少的計(jì)算需求,這樣每個(gè)人都可以訪問它,這不是很好嗎?Moco-v2前來救援。
這次我們將在Pytorch中在更大的數(shù)據(jù)集上實(shí)現(xiàn)Moco-v2,并在Google Colab上訓(xùn)練我們的模型。這次我們將使用Imagenette和Imagewoof數(shù)據(jù)集
來自Imagenette數(shù)據(jù)集的一些圖像
這些數(shù)據(jù)集的快速摘要(更多信息在這里:https://github.com/fastai/imagenette):
Imagenette由Imagenet的10個(gè)容易分類的類組成,總共有9479個(gè)訓(xùn)練圖像和3935個(gè)驗(yàn)證集圖像。
Imagewoof是一個(gè)由Imagenet提供的10個(gè)難分類組成的數(shù)據(jù)集,因?yàn)樗械念惗际枪返钠贩N。總共有9035個(gè)訓(xùn)練圖像,3939個(gè)驗(yàn)證集圖像。
對(duì)比學(xué)習(xí)在自我監(jiān)督學(xué)習(xí)中的作用是基于這樣一個(gè)理念:我們希望同一類別中不同的圖像觀具有相似的表征。但是,由于我們不知道哪些圖像屬于同一類別,通常所做的是將同一圖像的不同外觀的表示拉近。我們把這些不同的外觀稱為正對(duì)(positive pairs)。
另外,我們希望不同類別的圖像有不同的外觀,使它們的表征彼此遠(yuǎn)離。不同圖像的不同外觀的呈現(xiàn)與類別無關(guān),會(huì)被彼此推開。我們把這些不同的外觀稱為負(fù)對(duì)(negative pairs)。
在這種情況下,一個(gè)圖像的前景是什么?前景可以被認(rèn)為是以一種經(jīng)過修改的方式看待圖像的某些部分,它本質(zhì)上是圖像的一種變換。
根據(jù)手頭的任務(wù),有些轉(zhuǎn)換可以比其他轉(zhuǎn)換工作得更好。SimCLR表明,應(yīng)用隨機(jī)裁剪和顏色抖動(dòng)可以很好地完成各種任務(wù),包括圖像分類。這本質(zhì)上來自于網(wǎng)格搜索,從旋轉(zhuǎn)、裁剪、剪切、噪聲、模糊、Sobel濾波等選項(xiàng)中選擇一對(duì)變換。
從外觀到表示空間的映射是通過神經(jīng)網(wǎng)絡(luò)完成的,通常,resnet用于此目的。下面是從圖像到表示的管道
在同一幅圖像中,由于隨機(jī)裁剪,我們可以得到多個(gè)表示。這樣,我們就可以產(chǎn)生正對(duì)。
但是如何生成負(fù)對(duì)呢?負(fù)對(duì)是來自不同圖像的表示。SimCLR論文在同一批中創(chuàng)建了這些。如果一個(gè)批包含N個(gè)圖像,那么對(duì)于每個(gè)圖像,我們將得到2個(gè)表示,這總共占2*N個(gè)表示。對(duì)于一個(gè)特定的表示x,有一個(gè)表示與x形成正對(duì)(與x來自同一個(gè)圖像的表示),其余所有表示(正好是2*N–2)與x形成負(fù)對(duì)。
如果我們手頭有大量的負(fù)樣本,這些表示就會(huì)得到改善。但是,在SimCLR中,只有當(dāng)批量較大時(shí),才能實(shí)現(xiàn)大量的負(fù)樣本,這導(dǎo)致了對(duì)計(jì)算能力的更高要求。MoCo-v2提供了生成負(fù)樣本的另一種方法。讓我們?cè)敿?xì)了解一下。
我們可以用一種稍微不同的方式來看待對(duì)比學(xué)習(xí)方法,即將查詢與鍵進(jìn)行匹配。我們現(xiàn)在有兩個(gè)編碼器,一個(gè)用于查詢,另一個(gè)用于鍵。此外,為了得到大量的負(fù)樣本,我們需要一個(gè)大的鍵編碼字典。
此上下文中的正對(duì)表示查詢與鍵匹配。如果查詢和鍵都來自同一個(gè)圖像,則它們匹配。編碼的查詢應(yīng)該與其匹配的鍵相似,而與其他查詢不同。
對(duì)于負(fù)對(duì),我們維護(hù)一個(gè)大字典,其中包含以前批處理的編碼鍵。它們作為查詢的負(fù)樣本。我們以隊(duì)列的形式維護(hù)字典。新的batch被入隊(duì),較早的batch被出列。通過更改此隊(duì)列的大小,可以更改負(fù)采樣數(shù)。
隨著鍵編碼器的更改,在稍后時(shí)間點(diǎn)排隊(duì)的鍵可能與較早排隊(duì)的鍵不一致。為了使用對(duì)比學(xué)習(xí)方法,與查詢進(jìn)行比較的所有鍵必須來自相同或相似的編碼器,這樣比較才會(huì)有意義且一致。
另一個(gè)挑戰(zhàn)是,使用反向傳播學(xué)習(xí)編碼器參數(shù)是不可行的,因?yàn)檫@將需要計(jì)算隊(duì)列中所有樣本的梯度(這將導(dǎo)致大的計(jì)算圖)。
為了解決這兩個(gè)問題,MoCo將鍵編碼器實(shí)現(xiàn)為基于動(dòng)量的查詢編碼器的移動(dòng)平均值[1]。這意味著它以這種方式更新關(guān)鍵編碼器參數(shù):
其中m非常接近于1(例如,典型值為0.999),這確保我們?cè)诓煌臅r(shí)間從相似的編碼器獲得編碼鍵。
我們希望查詢接近其所有正樣本,遠(yuǎn)離所有負(fù)樣本。InfoNC函數(shù)E會(huì)捕獲它。它代表信息噪聲對(duì)比估計(jì)。對(duì)于查詢q和鍵k,InfoNCE損失函數(shù)是:
我們可以重寫為:
當(dāng)q和k的相似性增大,q與負(fù)樣本的相似性減小時(shí),損失值減小
以下是損失函數(shù)的代碼:
τ = 0.05 def loss_function(q, k, queue): # N是批量大小 N = q.shape[0] # C是表示的維數(shù) C = q.shape[1] # bmm代表批處理矩陣乘法 # 如果mat1是b×n×m張量,那么mat2是b×m×p張量, # 然后輸出一個(gè)b×n×p張量。 pos = torch.exp(torch.div(torch.bmm(q.view(N,1,C), k.view(N,C,1)).view(N, 1),τ)) # 在查詢和隊(duì)列張量之間執(zhí)行矩陣乘法 neg = torch.sum(torch.exp(torch.div(torch.mm(q.view(N,C), torch.t(queue)),τ)), dim=1) # 求和 denominator = neg + pos return torch.mean(-torch.log(torch.div(pos,denominator)))
讓我們?cè)倏纯催@個(gè)損失函數(shù),并將它與分類交叉熵?fù)p失函數(shù)進(jìn)行比較。
這里pred?是數(shù)據(jù)點(diǎn)在第i類中的概率值預(yù)測(cè),true?是該點(diǎn)屬于第i類的實(shí)際概率值(可以是模糊的,但大多數(shù)情況下是一個(gè)one-hot)。
如果你不熟悉這個(gè)話題,你可以看這個(gè)視頻來更好地理解交叉熵。另外,請(qǐng)注意,我們經(jīng)常通過softmax這樣的函數(shù)將分?jǐn)?shù)轉(zhuǎn)換為概率值:https://www.youtube.com/watch?v=ErfnhcEV1O8
我們可以把信息損失函數(shù)看作交叉熵?fù)p失。數(shù)據(jù)樣本“q”的正確類是第r類,底層分類器基于softmax,它試圖在K+1類之間進(jìn)行分類。
Info-NCE還與編碼表示之間的相互信息有關(guān);關(guān)于這一點(diǎn)的更多細(xì)節(jié)見[4]。
現(xiàn)在,讓我們把所有的東西放在一起,看看整個(gè)Moco-v2算法是什么樣子的。
我們必須得到查詢和鍵編碼器。最初,鍵編碼器具有與查詢編碼器相同的參數(shù)。它們是彼此的復(fù)制品。隨著訓(xùn)練的進(jìn)行,鍵編碼器將成為查詢編碼器的移動(dòng)平均值(在這一點(diǎn)上進(jìn)展緩慢)。
由于計(jì)算能力的限制,我們使用Resnet-18體系結(jié)構(gòu)來實(shí)現(xiàn)。在通常的resnet架構(gòu)之上,我們添加了一些密集的層,以使表示的維數(shù)降到25。這些層中的某些層稍后將充當(dāng)投影。
# 定義我們的深度學(xué)習(xí)架構(gòu) resnetq = resnet18(pretrained=False) classifier = nn.Sequential(OrderedDict([ ('fc1', nn.Linear(resnetq.fc.in_features, 100)), ('added_relu1', nn.ReLU(inplace=True)), ('fc2', nn.Linear(100, 50)), ('added_relu2', nn.ReLU(inplace=True)), ('fc3', nn.Linear(50, 25)) ])) resnetq.fc = classifier resnetk = copy.deepcopy(resnetq) # 將resnet架構(gòu)遷移到設(shè)備 resnetq.to(device) resnetk.to(device)
現(xiàn)在,我們已經(jīng)有了編碼器,并且假設(shè)我們已經(jīng)設(shè)置了其他重要的數(shù)據(jù)結(jié)構(gòu),現(xiàn)在是時(shí)候開始訓(xùn)練循環(huán)并理解管道了。
這一步是從訓(xùn)練批中獲取編碼查詢和鍵。我們用L2范數(shù)對(duì)表示進(jìn)行規(guī)范化。
只是一個(gè)約定警告,所有后續(xù)步驟中的代碼都將位于批處理和epoch循環(huán)中。我們還將張量“k”從它的梯度中分離出來,因?yàn)槲覀儾恍枰?jì)算圖中的鍵編碼器部分,因?yàn)閯?dòng)量更新方程會(huì)更新鍵編碼器。
# 梯度零化 optimizer.zero_grad() # 檢索xq和xk這兩個(gè)圖像batch xq = sample_batched['image1'] xk = sample_batched['image2'] # 把它們移到設(shè)備上 xq = xq.to(device) xk = xk.to(device) # 獲取他們的輸出 q = resnetq(xq) k = resnetk(xk) k = k.detach() # 將輸出規(guī)范化,使它們成為單位向量 q = torch.div(q,torch.norm(q,dim=1).reshape(-1,1)) k = torch.div(k,torch.norm(k,dim=1).reshape(-1,1))
現(xiàn)在,我們將查詢、鍵和隊(duì)列傳遞給前面定義的loss函數(shù),并將值存儲(chǔ)在一個(gè)列表中。然后,像往常一樣,對(duì)損失值調(diào)用backward函數(shù)并運(yùn)行優(yōu)化器。
# 獲得損失值 loss = loss_function(q, k, queue) # 把這個(gè)損失值放到epoch損失列表中 epoch_losses_train.append(loss.cpu().data.item()) # 反向傳播 loss.backward() # 運(yùn)行優(yōu)化器 optimizer.step()
我們將最新的batch加入我們的隊(duì)列。如果我們的隊(duì)列大小大于我們定義的最大隊(duì)列大小(K),那么我們就從其中取出最老的batch。可以使用torch.cat進(jìn)行隊(duì)列操作。
# 更新隊(duì)列 queue = torch.cat((queue, k), 0) # 如果隊(duì)列大于最大隊(duì)列大小(k),則出列 # batch大小是256,可以用變量替換 if queue.shape[0] > K: queue = queue[256:,:]
現(xiàn)在我們進(jìn)入訓(xùn)練循環(huán)的最后一步,即更新鍵編碼器。我們使用下面的for循環(huán)來實(shí)現(xiàn)這一點(diǎn)。
# 更新resnet for θ_k, θ_q in zip(resnetk.parameters(), resnetq.parameters()): θ_k.data.copy_(momentum*θ_k.data + θ_q.data*(1.0 - momentum))
訓(xùn)練resnet-18模型的Imagenette和Imagewoof數(shù)據(jù)集的GPU時(shí)間接近18小時(shí)。為此,我們使用了googlecolab的GPU(16GB)。我們使用的batch大小為256,tau值為0.05,學(xué)習(xí)率為0.001,最終降低到1e-5,權(quán)重衰減為1e-6。我們的隊(duì)列大小為8192,鍵編碼器的動(dòng)量值為0.999。
前3層(將relu視為一層)定義了投影頭,我們將其移除用于圖像分類的下游任務(wù)。在剩下的網(wǎng)絡(luò)上,我們訓(xùn)練了一個(gè)線性分類器。
我們得到了64.2%的正確率,而使用10%的標(biāo)記訓(xùn)練數(shù)據(jù),使用MoCo-v2。相比之下,使用最先進(jìn)的監(jiān)督學(xué)習(xí)方法,其準(zhǔn)確率接近95%。
對(duì)于Imagewoof,我們對(duì)10%的標(biāo)記數(shù)據(jù)得到了38.6%的準(zhǔn)確率。在這個(gè)數(shù)據(jù)集上進(jìn)行對(duì)比學(xué)習(xí)的效果低于我們的預(yù)期。我們懷疑這是因?yàn)槭紫?,?shù)據(jù)集非常困難,因?yàn)樗蓄惗际枪奉悺?/p>
其次,我們認(rèn)為顏色是這些類的一個(gè)重要的區(qū)別特征。應(yīng)用顏色抖動(dòng)可能會(huì)導(dǎo)致來自不同類的多個(gè)圖像彼此混合表示。相比之下,監(jiān)督方法的準(zhǔn)確率接近90%。
能夠彌合自監(jiān)督模型和監(jiān)督模型之間差距的設(shè)計(jì)變更:
使用更大更寬的模型。
通過使用更大的批量和字典大小。
使用更多的數(shù)據(jù),如果可以的話。同時(shí)引入所有未標(biāo)記的數(shù)據(jù)。
在大量數(shù)據(jù)上訓(xùn)練大型模型,然后提取它們。
關(guān)于如何利用PyTorch中的Moco-V2減少計(jì)算約束就分享到這里了,希望以上內(nèi)容可以對(duì)大家有一定的幫助,可以學(xué)到更多知識(shí)。如果覺得文章不錯(cuò),可以把它分享出去讓更多的人看到。