PyTorch中還單獨(dú)提供了一個sampler模塊,用來對數(shù)據(jù)進(jìn)行采樣。常用的有隨機(jī)采樣器:RandomSampler,當(dāng)dataloader的shuffle參數(shù)為True時,系統(tǒng)會自動調(diào)用這個采樣器,實(shí)現(xiàn)打亂數(shù)據(jù)。默認(rèn)的是采用SequentialSampler,它會按順序一個一個進(jìn)行采樣。這里介紹另外一個很有用的采樣方法: WeightedRandomSampler,它會根據(jù)每個樣本的權(quán)重選取數(shù)據(jù),在樣本比例不均衡的問題中,可用它來進(jìn)行重采樣。
成都創(chuàng)新互聯(lián)公司是一家專注于網(wǎng)站設(shè)計(jì)、網(wǎng)站制作與策劃設(shè)計(jì),廣漢網(wǎng)站建設(shè)哪家好?成都創(chuàng)新互聯(lián)公司做網(wǎng)站,專注于網(wǎng)站建設(shè)十年,網(wǎng)設(shè)計(jì)領(lǐng)域的專業(yè)建站公司;建站業(yè)務(wù)涵蓋:廣漢等地區(qū)。廣漢做網(wǎng)站價格咨詢:028-86922220構(gòu)建WeightedRandomSampler時需提供兩個參數(shù):每個樣本的權(quán)重weights、共選取的樣本總數(shù)num_samples,以及一個可選參數(shù)replacement。權(quán)重越大的樣本被選中的概率越大,待選取的樣本數(shù)目一般小于全部的樣本數(shù)目。replacement用于指定是否可以重復(fù)選取某一個樣本,默認(rèn)為True,即允許在一個epoch中重復(fù)采樣某一個數(shù)據(jù)。如果設(shè)為False,則當(dāng)某一類的樣本被全部選取完,但其樣本數(shù)目仍未達(dá)到num_samples時,sampler將不會再從該類中選擇數(shù)據(jù),此時可能導(dǎo)致weights參數(shù)失效。
下面舉例說明。
from dataSet import * dataset = DogCat('data/dogcat/', transform=transform) from torch.utils.data import DataLoader # 狗的圖片被取出的概率是貓的概率的兩倍 # 兩類圖片被取出的概率與weights的絕對大小無關(guān),只和比值有關(guān) weights = [2 if label == 1 else 1 for data, label in dataset] print(weights) from torch.utils.data.sampler import WeightedRandomSampler sampler = WeightedRandomSampler(weights,\ num_samples=9,\ replacement=True) dataloader = DataLoader(dataset, batch_size=3, sampler=sampler) for datas, labels in dataloader: print(labels.tolist())