前言
眾所周知,Dataset和Dataloder是pytorch中進行數據載入的部件。必須將數據載入后,再進行深度學習模型的訓練。在pytorch的一些案例教學中,常使用torchvision.datasets
自帶的MNIST、CIFAR-10數據集,一般流程為:
1
2
3
4
|
# 下載并存放數據集 train_dataset = torchvision.datasets.CIFAR10(root = "數據集存放位置" ,download = True ) # load數據 train_loader = torch.utils.data.DataLoader(dataset = train_dataset) |
但是,在我們自己的模型訓練中,需要使用非官方自制的數據集。這時應該怎么辦呢?
我們可以通過改寫torch.utils.data.Dataset
中的__getitem__
和__len__
來載入我們自己的數據集。
__getitem__
獲取數據集中的數據,__len__
獲取整個數據集的長度(即個數)。
改寫
采用pytorch官網案例中提供的一個臉部landmark數據集。數據集中含有存放landmark的csv文件,但是我們在這篇文章中不使用(其實也可以隨便下載一些圖片作數據集來實驗)。
1
2
3
4
5
6
7
8
9
|
import os import torch from skimage import io, transform import numpy as np import matplotlib.pyplot as plt from torch.utils.data import Dataset, DataLoader from torchvision import transforms, utils plt.ion() # interactive mode |
torch.utils.data.Dataset
是一個抽象類,我們自己的數據集需要繼承Dataset
,然后改寫上述兩個函數:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
|
class ImageLoader(Dataset): def __init__( self , file_path, transform = None ): super (ImageLoader, self ).__init__() self .file_path = file_path self .transform = transform # 對輸入圖像進行預處理,這里并沒有做,預設為None self .image_names = os.listdir( self .file_path) # 文件名的列表 def __getitem__( self ,idx): image = self .image_names[idx] image = io.imread(os.path.join( self .file_path,image)) # if self.transform: # image= self.transform(image) return image def __len__( self ): return len ( self .image_names) # 設置自己存放的數據集位置,并plot展示 imageloader = ImageLoader(file_path = "D:\\Projects\\datasets\\faces\\" ) # imageloader.__len__() # 輸出數據集長度(個數),應為71 # print(imageloader.__getitem__(0)) # 以數據形式展示 plt.imshow(imageloader.__getitem__( 0 )) # 以圖像形式展示 plt.show() |
得到的圖片輸出:
得到的數據輸出,:
1
2
3
4
5
6
7
8
9
10
11
|
array([[[ 66 , 59 , 53 ], [ 66 , 59 , 53 ], [ 66 , 59 , 53 ], ..., [ 59 , 54 , 48 ], [ 59 , 54 , 48 ], [ 59 , 54 , 48 ]], ..., [ 153 , 141 , 129 ], [ 158 , 146 , 134 ], [ 158 , 146 , 134 ]]], dtype = uint8) |
上面看到dytpe=uint8
,實際進行訓練的時候,常常需要更改成float
的數據類型。可以使用:
1
2
3
|
# 直接改成pytorch中的tensor下的float格式 # 也可以用numpy的改成普通的float格式 to_float = torch.from_numpy(imageloader.__getitem__( 0 )). float () |
改寫完成后,直接使用train_loader =torch.utils.data.DataLoader(dataset=imageloader)
載入到Dataloader
中,就可以使用了。
下面的代碼可以試著運行一下,產生的是一模一樣的圖片結果。
1
2
3
4
|
train_loader = torch.utils.data.DataLoader(dataset = imageloader) train_loader.dataset[ 0 ] plt.imshow(train_loader.dataset[ 0 ]) plt.show() |
到此這篇關于PyTorch實現重寫/改寫Dataset并載入Dataloader的文章就介紹到這了,更多相關PyTorch重寫/改寫Dataset 內容請搜索服務器之家以前的文章或繼續瀏覽下面的相關文章希望大家以后多多支持服務器之家!
原文鏈接:https://blog.csdn.net/qq_38372240/article/details/107322677