pytorch 官網(wǎng)給出的例子中都是使用了已經(jīng)定義好的特殊數(shù)據(jù)集接口來(lái)加載數(shù)據(jù),而且其使用的數(shù)據(jù)都是官方給出的數(shù)據(jù)。如果我們有自己收集的數(shù)據(jù)集,如何用來(lái)訓(xùn)練網(wǎng)絡(luò)呢?此時(shí)需要我們自己定義好數(shù)據(jù)處理接口。幸運(yùn)的是pytroch給出了一個(gè)數(shù)據(jù)集接口類(torch.utils.data.Dataset),可以方便我們繼承并實(shí)現(xiàn)自己的數(shù)據(jù)集接口。
創(chuàng)新互聯(lián)公司成立以來(lái)不斷整合自身及行業(yè)資源、不斷突破觀念以使企業(yè)策略得到完善和成熟,建立了一套“以技術(shù)為基點(diǎn),以客戶需求中心、市場(chǎng)為導(dǎo)向”的快速反應(yīng)體系。對(duì)公司的主營(yíng)項(xiàng)目,如中高端企業(yè)網(wǎng)站企劃 / 設(shè)計(jì)、行業(yè) / 企業(yè)門戶設(shè)計(jì)推廣、行業(yè)門戶平臺(tái)運(yùn)營(yíng)、App定制開發(fā)、手機(jī)網(wǎng)站制作設(shè)計(jì)、微信網(wǎng)站制作、軟件開發(fā)、資陽(yáng)主機(jī)托管等實(shí)行標(biāo)準(zhǔn)化操作,讓客戶可以直觀的預(yù)知到從創(chuàng)新互聯(lián)公司可以獲得的服務(wù)效果。torch.utils.data
torch的這個(gè)文件包含了一些關(guān)于數(shù)據(jù)集處理的類。
class torch.utils.data.Dataset: 一個(gè)抽象類, 所有其他類的數(shù)據(jù)集類都應(yīng)該是它的子類。而且其子類必須重載兩個(gè)重要的函數(shù):len(提供數(shù)據(jù)集的大?。etitem(支持整數(shù)索引)。
class torch.utils.data.TensorDataset: 封裝成tensor的數(shù)據(jù)集,每一個(gè)樣本都通過索引張量來(lái)獲得。
class torch.utils.data.ConcatDataset: 連接不同的數(shù)據(jù)集以構(gòu)成更大的新數(shù)據(jù)集。
class torch.utils.data.Subset(dataset, indices): 獲取指定一個(gè)索引序列對(duì)應(yīng)的子數(shù)據(jù)集。
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=
torch.utils.data.random_split(dataset, lengths): 按照給定的長(zhǎng)度將數(shù)據(jù)集劃分成沒有重疊的新數(shù)據(jù)集組合。
class torch.utils.data.Sampler(data_source):所有采樣的器的基類。每個(gè)采樣器子類都需要提供 __iter__ 方法以方便迭代器進(jìn)行索引 和一個(gè) len方法 以方便返回迭代器的長(zhǎng)度。
class torch.utils.data.SequentialSampler(data_source):順序采樣樣本,始終按照同一個(gè)順序。
class torch.utils.data.RandomSampler(data_source):無(wú)放回地隨機(jī)采樣樣本元素。
class torch.utils.data.SubsetRandomSampler(indices):無(wú)放回地按照給定的索引列表采樣樣本元素。
class torch.utils.data.WeightedRandomSampler(weights, num_samples, replacement=True): 按照給定的概率來(lái)采樣樣本。
class torch.utils.data.BatchSampler(sampler, batch_size, drop_last): 在一個(gè)batch中封裝一個(gè)其他的采樣器。
class torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=None, rank=None):采樣器可以約束數(shù)據(jù)加載進(jìn)數(shù)據(jù)集的子集。
自定義數(shù)據(jù)集
自己定義的數(shù)據(jù)集需要繼承抽象類class torch.utils.data.Dataset,并且需要重載兩個(gè)重要的函數(shù):__len__ 和__getitem__。
整個(gè)代碼僅供參考。在__init__中是初始化了該類的一些基本參數(shù);__getitem__中是真正讀取數(shù)據(jù)的地方,迭代器通過索引來(lái)讀取數(shù)據(jù)集中數(shù)據(jù),因此只需要這一個(gè)方法中加入讀取數(shù)據(jù)的相關(guān)功能即可;__len__給出了整個(gè)數(shù)據(jù)集的尺寸大小,迭代器的索引范圍是根據(jù)這個(gè)函數(shù)得來(lái)的。
import torch class myDataset(torch.nn.data.Dataset): def __init__(self, dataSource) self.dataSource = dataSource def __getitem__(self, index): element = self.dataSource[index] return element def __len__(self): return len(self.dataSource) train_data = myDataset(dataSource)