This article looks at how PyTorch handles data loading and data preprocessing. It is quite practical, so it is shared here as a reference.
Data loading falls into two cases: loading a dataset that ships with torchvision.datasets, and loading a dataset of your own.
Datasets in torchvision.datasets
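torchvision.datasets ships with datasets such as MNIST, Imagenet-12, and CIFAR. Every one of them is a subclass of torch.utils.data.Dataset and provides two methods: __len__ (returns the size of the dataset) and __getitem__ (returns the item at a given index).
In the Dataset source code these two methods are left unimplemented. Every Dataset subclass inherits from this class and supplies its own implementation of both methods to suit its data.
So when we need to load our own dataset we can follow the same approach: subclass torch.utils.data.Dataset and override three methods, __init__, __len__, and __getitem__. A class organized this way can be passed directly as an argument to torch.utils.data.DataLoader.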
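The subclassing pattern described above can be sketched as follows. SquaresDataset is a hypothetical toy class invented for this illustration (it is not part of PyTorch); it only shows the three methods a custom dataset usually overrides and how the result is handed to DataLoader.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class SquaresDataset(Dataset):
    """Hypothetical toy dataset: sample i is the tensor [i], its target is i * i."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples in the dataset.
        return self.n

    def __getitem__(self, index):
        # Return one (sample, target) pair for the given index.
        x = torch.tensor([float(index)])
        y = torch.tensor(float(index * index))
        return x, y


dataset = SquaresDataset(10)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
first_x, first_y = next(iter(loader))
print(len(dataset), len(loader))
print(first_x.shape, first_y.shape)
```

With 10 samples and batch_size=4 the loader yields three batches, the last one holding the remaining two samples.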
Take CIFAR10 as an example. Its signature in the source is:
class torchvision.datasets.CIFAR10(root, train=True, transform=None, target_transform=None, download=False)
root (string) – Root directory of dataset where directory cifar-10-batches-py exists or will be saved to if download is set to True.
train (bool, optional) – If True, creates dataset from training set, otherwise creates from test set.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g. transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
download (bool, optional) – If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
Loading your own dataset
torchvision.datasets provides two classes for this, DatasetFolder and ImageFolder. ImageFolder inherits from DatasetFolder.
Let's go through the source of the folder module to see what DatasetFolder and ImageFolder each do.
    import torch.utils.data as data
    from PIL import Image
    import os
    import os.path


    # Check whether the file name ends with one of the allowed extensions.
    def has_file_allowed_extension(filename, extensions):
        """Checks if a file is an allowed extension.

        Args:
            filename (string): path to a file

        Returns:
            bool: True if the filename ends with a known image extension
        """
        filename_lower = filename.lower()
        return any(filename_lower.endswith(ext) for ext in extensions)


    def find_classes(dir):
        # Collect the names of all sub-folders directly under root.
        classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))]
        classes.sort()
        # Build the dictionary that maps each class name to its class id.
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        return classes, class_to_idx


    def make_dataset(dir, class_to_idx, extensions):
        images = []
        # Expand ~ and ~user in the argument to the user's home directory.
        dir = os.path.expanduser(dir)
        for target in sorted(os.listdir(dir)):
            d = os.path.join(dir, target)
            if not os.path.isdir(d):
                continue

            # os.walk yields three parts: root is the current directory path,
            # _ is the list of sub-folder names in it, fnames is the list of file names.
            for root, _, fnames in sorted(os.walk(d)):
                for fname in sorted(fnames):
                    if has_file_allowed_extension(fname, extensions):
                        path = os.path.join(root, fname)
                        # Build a (sample image path, class index) tuple.
                        item = (path, class_to_idx[target])
                        images.append(item)

        # Return the list of those tuples.
        return images


    class DatasetFolder(data.Dataset):
        """A generic data loader where the samples are arranged in this way: ::

            root/class_x/xxx.ext
            root/class_x/xxy.ext
            root/class_x/xxz.ext

            root/class_y/123.ext
            root/class_y/nsdf3.ext
            root/class_y/asd932_.ext

        Args:
            root (string): Root directory path.
            loader (callable): A function to load a sample given its path.
            extensions (list[string]): A list of allowed extensions.
            transform (callable, optional): A function/transform that takes in
                a sample and returns a transformed version.
                E.g, ``transforms.RandomCrop`` for images.
            target_transform (callable, optional): A function/transform that takes
                in the target and transforms it.

        Attributes:
            classes (list): List of the class names.
            class_to_idx (dict): Dict with items (class_name, class_index).
            samples (list): List of (sample path, class_index) tuples
        """

        def __init__(self, root, loader, extensions, transform=None, target_transform=None):
            classes, class_to_idx = find_classes(root)
            samples = make_dataset(root, class_to_idx, extensions)
            if len(samples) == 0:
                raise(RuntimeError("Found 0 files in subfolders of: " + root + "\n"
                                   "Supported extensions are: " + ",".join(extensions)))

            self.root = root
            self.loader = loader
            self.extensions = extensions

            self.classes = classes
            self.class_to_idx = class_to_idx
            self.samples = samples

            self.transform = transform
            self.target_transform = target_transform

        def __getitem__(self, index):
            """
            Fetch the sample at the given index and return a (sample, target) tuple.
            If transform and target_transform (torchvision.transforms objects) were
            passed to the constructor, they are applied to the sample and the target
            respectively before returning.

            Args:
                index (int): Index

            Returns:
                tuple: (sample, target) where target is class_index of the target class.
            """
            path, target = self.samples[index]
            sample = self.loader(path)
            if self.transform is not None:
                sample = self.transform(sample)
            if self.target_transform is not None:
                target = self.target_transform(target)

            return sample, target

        def __len__(self):
            return len(self.samples)

        # __repr__ defines how the object is displayed. It is used when the object
        # is echoed directly and, because no __str__ is defined here, by print() too.
        # A __str__ method would only affect print().
        def __repr__(self):
            fmt_str = 'Dataset ' + self.__class__.__name__ + '\n'
            fmt_str += ' Number of datapoints: {}\n'.format(self.__len__())
            fmt_str += ' Root Location: {}\n'.format(self.root)
            tmp = ' Transforms (if any): '
            fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            tmp = ' Target Transforms (if any): '
            fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp)))
            return fmt_str


    IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif']


    def pil_loader(path):
        # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
        with open(path, 'rb') as f:
            img = Image.open(f)
            return img.convert('RGB')


    def accimage_loader(path):
        import accimage
        try:
            return accimage.Image(path)
        except IOError:
            # Potentially a decoding problem, fall back to PIL.Image
            return pil_loader(path)


    def default_loader(path):
        from torchvision import get_image_backend
        if get_image_backend() == 'accimage':
            return accimage_loader(path)
        else:
            return pil_loader(path)


    class ImageFolder(DatasetFolder):
        """A generic data loader where the images are arranged in this way: ::

            root/dog/xxx.png
            root/dog/xxy.png
            root/dog/xxz.png

            root/cat/123.png
            root/cat/nsdf3.png
            root/cat/asd932_.png

        Args:
            root (string): Root directory path.
            transform (callable, optional): A function/transform that takes in an PIL image
                and returns a transformed version. E.g, ``transforms.RandomCrop``
            target_transform (callable, optional): A function/transform that takes in the
                target and transforms it.
            loader (callable, optional): A function to load an image given its path.

        Attributes:
            classes (list): List of the class names.
            class_to_idx (dict): Dict with items (class_name, class_index).
            imgs (list): List of (image path, class_index) tuples
        """

        def __init__(self, root, transform=None, target_transform=None,
                     loader=default_loader):
            super(ImageFolder, self).__init__(root, loader, IMG_EXTENSIONS,
                                              transform=transform,
                                              target_transform=target_transform)
            self.imgs = self.samples
Suppose the data you want to load is organized as follows:
    root/dog/xxx.png
    root/dog/xxy.png
    root/dog/xxz.png
    root/cat/123.png
    root/cat/nsdf3.png
    root/cat/asd932_.png
That is, the training data for each class is stored in its own folder, and all of these folders sit under root (a path such as D:/animals or /usr/animals).
class torchvision.datasets.ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
The parameters are:
root (string) – Root directory path.
transform (callable, optional) – A function/transform that takes in a PIL image and returns a transformed version. E.g. transforms.RandomCrop
target_transform (callable, optional) – A function/transform that takes in the target and transforms it.
loader – A function to load an image given its path.
The __getitem__(index) method is the one shown in the source above:
Parameters: index (int) – Index
Returns: (sample, target) where target is class_index of the target class.
Return type: tuple
Data organized this way can be loaded with torchvision.datasets.ImageFolder:
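    import torch
    import torchvision
    from torchvision import transforms

    img_data = torchvision.datasets.ImageFolder(
        'D:/bnu/database/flower',
        transform=transforms.Compose([
            transforms.Resize(256),  # Resize replaces the deprecated transforms.Scale
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]))
    print(len(img_data))  # number of images found

    data_loader = torch.utils.data.DataLoader(img_data, batch_size=20, shuffle=True)
    print(len(data_loader))  # number of batches per epoch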
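When all training samples sit in a single folder and a companion txt file lists, one per line, the path of each image and the class it belongs to, you can write a loading class modeled on the classes above: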
    from PIL import Image
    from torch.utils.data import Dataset, DataLoader
    from torchvision import transforms


    def default_loader(path):
        return Image.open(path).convert('RGB')


    class MyDataset(Dataset):
        def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
            imgs = []
            # Each line of the txt file holds "<image path> <class index>".
            with open(txt, 'r') as fh:
                for line in fh:
                    words = line.strip().split()
                    if not words:
                        continue  # skip blank lines
                    imgs.append((words[0], int(words[1])))
            self.imgs = imgs
            self.transform = transform
            self.target_transform = target_transform
            self.loader = loader

        def __getitem__(self, index):
            fn, label = self.imgs[index]
            img = self.loader(fn)
            if self.transform is not None:
                img = self.transform(img)
            if self.target_transform is not None:
                label = self.target_transform(label)
            return img, label

        def __len__(self):
            return len(self.imgs)


    train_data = MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
    data_loader = DataLoader(train_data, batch_size=100, shuffle=True)
    print(len(data_loader))
DataLoader explained
DataLoader lives in torch.utils.data (see its source code).
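The main purpose of this interface is to take an existing PyTorch data interface, such as torchvision.datasets.ImageFolder, or a custom data-reading interface, and pack its output into Tensors of batch_size samples each. In other words, it adds one leading dimension of size batch_size on top of what the built-in or custom dataset returns. The resulting data can then be fed to the model as input (older PyTorch versions first wrapped it in a Variable; since PyTorch 0.4 tensors are passed directly).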
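The extra batch dimension is easy to observe on synthetic data. In this small sketch the tensor sizes (10 samples of shape 3x8x8) and batch_size=4 are arbitrary choices for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Ten fake "images" of shape (3, 8, 8) with integer labels 0..9.
images = torch.zeros(10, 3, 8, 8)
labels = torch.arange(10)
dataset = TensorDataset(images, labels)

# A single item has no batch dimension.
single_image, single_label = dataset[0]
print(single_image.shape)

# Each batch gains a leading dimension of size batch_size
# (the last batch is smaller because 10 is not a multiple of 4).
loader = DataLoader(dataset, batch_size=4)
batch_shapes = [tuple(x.shape) for x, _ in loader]
print(batch_shapes)
```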
__init__ takes the following parameters:
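1. dataset: an instance of a torch.utils.data.Dataset subclass. It can be a built-in class such as torchvision.datasets.ImageFolder or a custom class that inherits from torch.utils.data.Dataset.
2. batch_size: the number of samples in each batch. The default is 1.
3. shuffle: normally used for the training set. The default is False; when set to True, the training samples are reshuffled at every epoch.
4. sampler: the strategy used to pick samples. It is mutually exclusive with shuffle: if shuffle is True, this argument must be None.
5. batch_sampler: a BatchSampler yields the indices of one whole batch at a time. It is mutually exclusive with sampler and shuffle, and the default is usually fine. The source code for these Sampler classes is here: [source code](https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py)
6. num_workers: the number of worker subprocesses used for data loading. The default is 0, meaning the data is loaded in the main process only.
7. collate_fn: merges a list of samples into a mini-batch.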
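Two of the parameters above are easier to grasp with a concrete case. The sketch below assumes a toy dataset of variable-length sequences and a hypothetical pad_collate function (both invented for this example) to show what collate_fn is for, and then checks that passing shuffle=True together with a sampler is rejected.

```python
import torch
from torch.utils.data import DataLoader, SequentialSampler

# A plain list works as a map-style dataset: it has __len__ and __getitem__.
sequences = [torch.ones(n) for n in (1, 3, 2, 4)]


def pad_collate(batch):
    """Hypothetical collate_fn: zero-pad sequences to the longest one in the batch."""
    max_len = max(seq.size(0) for seq in batch)
    padded = torch.zeros(len(batch), max_len)
    for row, seq in enumerate(batch):
        padded[row, :seq.size(0)] = seq
    return padded


loader = DataLoader(sequences, batch_size=2, collate_fn=pad_collate)
shapes = [tuple(batch.shape) for batch in loader]
print(shapes)

# shuffle and sampler are mutually exclusive, so this combination is refused.
try:
    DataLoader(sequences, shuffle=True, sampler=SequentialSampler(sequences))
    conflict_detected = False
except ValueError:
    conflict_detected = True
print(conflict_detected)
```

The default collate function stacks samples and therefore needs them to have equal shapes, which is why a custom collate_fn is the usual answer for variable-length data.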
Typical usage looks like this:
    train_data = torch.utils.data.DataLoader(...)
    for i, (input, target) in enumerate(train_data):
        ...
Looping over a DataLoader triggers the class's __iter__ method. The source shows that it simply does return _DataLoaderIter(self), so we need to look at the internal DataLoaderIter class:
    class DataLoaderIter(object):
        "Iterates once over the DataLoader's dataset, as specified by the sampler"

        def __init__(self, loader):
            self.dataset = loader.dataset
            self.collate_fn = loader.collate_fn
            self.batch_sampler = loader.batch_sampler
            self.num_workers = loader.num_workers
            self.pin_memory = loader.pin_memory and torch.cuda.is_available()
            self.timeout = loader.timeout
            self.done_event = threading.Event()

            self.sample_iter = iter(self.batch_sampler)

            if self.num_workers > 0:
                self.worker_init_fn = loader.worker_init_fn
                self.index_queue = multiprocessing.SimpleQueue()
                self.worker_result_queue = multiprocessing.SimpleQueue()
                self.batches_outstanding = 0
                self.worker_pids_set = False
                self.shutdown = False
                self.send_idx = 0
                self.rcvd_idx = 0
                self.reorder_dict = {}

                base_seed = torch.LongTensor(1).random_()[0]
                self.workers = [
                    multiprocessing.Process(
                        target=_worker_loop,
                        args=(self.dataset, self.index_queue, self.worker_result_queue,
                              self.collate_fn, base_seed + i, self.worker_init_fn, i))
                    for i in range(self.num_workers)]

                if self.pin_memory or self.timeout > 0:
                    self.data_queue = queue.Queue()
                    self.worker_manager_thread = threading.Thread(
                        target=_worker_manager_loop,
                        args=(self.worker_result_queue, self.data_queue, self.done_event,
                              self.pin_memory, torch.cuda.current_device()))
                    self.worker_manager_thread.daemon = True
                    self.worker_manager_thread.start()
                else:
                    self.data_queue = self.worker_result_queue

                for w in self.workers:
                    w.daemon = True  # ensure that the worker exits on process exit
                    w.start()

                _update_worker_pids(id(self), tuple(w.pid for w in self.workers))
                _set_SIGCHLD_handler()
                self.worker_pids_set = True

                # prime the prefetch loop
                for _ in range(2 * self.num_workers):
                    self._put_indices()
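Leaving the worker processes aside, the core of one pass over the data can be approximated in a few lines. This is a simplified sketch of the num_workers == 0 path, not the library's actual implementation: iterate_batches is a hypothetical helper written for this example, and it only combines the same three pieces the iterator stores above (dataset, batch_sampler, collate_fn).

```python
import torch
from torch.utils.data import BatchSampler, SequentialSampler, TensorDataset
from torch.utils.data.dataloader import default_collate

# Six samples: features of shape (1,) and integer targets 0..5.
dataset = TensorDataset(torch.arange(6).float().unsqueeze(1), torch.arange(6))
batch_sampler = BatchSampler(SequentialSampler(dataset), batch_size=4, drop_last=False)


def iterate_batches(dataset, batch_sampler, collate_fn=default_collate):
    """Hypothetical single-process loop: sample indices, fetch items, collate."""
    for indices in batch_sampler:                 # one list of indices per batch
        samples = [dataset[i] for i in indices]   # calls dataset.__getitem__
        yield collate_fn(samples)                 # stack samples into batch tensors


batches = list(iterate_batches(dataset, batch_sampler))
print(len(batches))
print(batches[0][0].shape, batches[1][1].tolist())
```

In the multi-process path shown in the source, the same index lists are instead pushed onto index_queue, and the collated batches come back through the result queue.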
Thanks for reading! That wraps up this look at how PyTorch implements data loading and data preprocessing. Hopefully it is of some help.