1. torch.utils.data.Dataset
datasets這是一個pytorch定義的dataset的源碼集合。下面是一個自定義Datasets的基本框架,初始化放在__init__()中,其中__getitem__()和__len__()兩個方法是必須重寫的。
__getitem__()返回訓練數據,如圖片和label,而__len__()返回數據長度。
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
|
class CustomDataset(data.Dataset): #需要繼承data.Dataset def __init__( self ): # TODO # 1. Initialize file path or list of file names. pass def __getitem__( self , index): # TODO # 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open). # 2. Preprocess the data (e.g. torchvision.Transform). # 3. Return a data pair (e.g. image and label). #這里需要注意的是,第一步:read one data,是一個data pass def __len__( self ): # You should change 0 to the total size of your dataset. return 0 |
2. torch.utils.data.DataLoader
DataLoader(object)
可用參數:
dataset(Dataset)
傳入的數據集
batch_size(int, optional)
每個batch有多少個樣本
shuffle(bool, optional)
在每個epoch開始的時候,對數據進行重新排序
sampler(Sampler, optional)
自定義從數據集中取樣本的策略,如果指定這個參數,那么shuffle必須為False
batch_sampler(Sampler, optional)
與sampler類似,但是一次只返回一個batch的indices(索引),需要注意的是,一旦指定了這個參數,那么batch_size,shuffle,sampler,drop_last就不能再制定了(互斥——Mutually exclusive)
num_workers (int, optional)
這個參數決定了有幾個進程來處理data loading。0意味著所有的數據都會被load進主進程。(默認為0)
collate_fn (callable, optional)
將一個list的sample組成一個mini-batch的函數
pin_memory (bool, optional)
如果設置為True,那么data loader將會在返回它們之前,將tensors拷貝到CUDA中的固定內存(CUDA pinned memory)中.
drop_last (bool, optional)
如果設置為True:這個是對最后的未完成的batch來說的,比如你的batch_size設置為64,而一個epoch只有100個樣本,那么訓練的時候后面的36個就被扔掉了。 如果為False(默認),那么會繼續正常執行,只是最后的batch_size會小一點。
timeout(numeric, optional)
如果是正數,表明等待從worker進程中收集一個batch等待的時間,若超出設定的時間還沒有收集到,那就不收集這個內容了。這個numeric應總是大于等于0。默認為0
worker_init_fn (callable, optional)
每個worker初始化函數 If not None, this will be called on eachworker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
3. 使用Dataset, DataLoader產生自定義訓練數據
假設TXT文件保存了數據的圖片和label,格式如下:第一列是圖片的名字,第二列是label
1
2
3
4
5
6
7
8
9
10
|
0.jpg 0 1.jpg 1 2.jpg 2 3.jpg 3 4.jpg 4 5.jpg 5 6.jpg 6 7.jpg 7 8.jpg 8 9.jpg 9 |
也可以是多標簽的數據,如:
1
2
3
4
5
6
7
8
9
10
|
0.jpg 0 10 1.jpg 1 11 2.jpg 2 12 3.jpg 3 13 4.jpg 4 14 5.jpg 5 15 6.jpg 6 16 7.jpg 7 17 8.jpg 8 18 9.jpg 9 19 |
圖庫十張原始圖片放在./dataset/images目錄下,然后我們就可以自定義一個Dataset解析這些數據并讀取圖片,再使用DataLoader類產生batch的訓練數據
3.1 自定義Dataset
首先先自定義一個TorchDataset類,用于讀取圖片數據,產生標簽:
注意初始化函數:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
|
import torch from torch.autograd import Variable from torchvision import transforms from torch.utils.data import Dataset, DataLoader import numpy as np from utils import image_processing import os class TorchDataset(Dataset): def __init__( self , filename, image_dir, resize_height = 256 , resize_width = 256 , repeat = 1 ): ''' :param filename: 數據文件TXT:格式:imge_name.jpg label1_id labe2_id :param image_dir: 圖片路徑:image_dir+imge_name.jpg構成圖片的完整路徑 :param resize_height 為None時,不進行縮放 :param resize_width 為None時,不進行縮放, PS:當參數resize_height或resize_width其中一個為None時,可實現等比例縮放 :param repeat: 所有樣本數據重復次數,默認循環一次,當repeat為None時,表示無限循環<sys.maxsize ''' self .image_label_list = self .read_file(filename) self .image_dir = image_dir self . len = len ( self .image_label_list) self .repeat = repeat self .resize_height = resize_height self .resize_width = resize_width # 相關預處理的初始化 '''class torchvision.transforms.ToTensor''' # 把shape=(H,W,C)的像素值范圍為[0, 255]的PIL.Image或者numpy.ndarray數據 # 轉換成shape=(C,H,W)的像素數據,并且被歸一化到[0.0, 1.0]的torch.FloatTensor類型。 self .toTensor = transforms.ToTensor() '''class torchvision.transforms.Normalize(mean, std) 此轉換類作用于torch. * Tensor,給定均值(R, G, B) 和標準差(R, G, B), 用公式channel = (channel - mean) / std進行規范化。 ''' # self.normalize=transforms.Normalize() def __getitem__( self , i): index = i % self . len # print("i={},index={}".format(i, index)) image_name, label = self .image_label_list[index] image_path = os.path.join( self .image_dir, image_name) img = self .load_data(image_path, self .resize_height, self .resize_width, normalization = False ) img = self .data_preproccess(img) label = np.array(label) return img, label def __len__( self ): if self .repeat = = None : data_len = 10000000 else : data_len = len ( self .image_label_list) * self .repeat return data_len def read_file( self , filename): image_label_list = [] with open (filename, 'r' ) as f: lines = f.readlines() for line in lines: # rstrip:用來去除結尾字符、空白符(包括\n、\r、\t、' ',即:換行、回車、制表符、空格) content = line.rstrip().split( ' ' ) name = content[ 0 ] labels = [] for value in content[ 1 :]: labels.append( int (value)) image_label_list.append((name, labels)) return image_label_list def load_data( self , path, resize_height, resize_width, normalization): ''' 加載數據 :param path: :param resize_height: :param resize_width: :param normalization: 是否歸一化 :return: ''' image = image_processing.read_image(path, resize_height, resize_width, normalization) return image def data_preproccess( self , data): ''' 數據預處理 :param data: :return: ''' data = self .toTensor(data) return data |
3.2 DataLoader產生批訓練數據
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
if __name__ = = '__main__' : train_filename = "../dataset/train.txt" # test_filename="../dataset/test.txt" image_dir = '../dataset/images' epoch_num = 2 #總樣本循環次數 batch_size = 7 #訓練時的一組數據的大小 train_data_nums = 10 max_iterate = int ((train_data_nums + batch_size - 1 ) / batch_size * epoch_num) #總迭代次數 train_data = TorchDataset(filename = train_filename, image_dir = image_dir,repeat = 1 ) # test_data = TorchDataset(filename=test_filename, image_dir=image_dir,repeat=1) train_loader = DataLoader(dataset = train_data, batch_size = batch_size, shuffle = False ) # test_loader = DataLoader(dataset=test_data, batch_size=batch_size,shuffle=False) # [1]使用epoch方法迭代,TorchDataset的參數repeat=1 for epoch in range (epoch_num): for batch_image, batch_label in train_loader: image = batch_image[ 0 ,:] image = image.numpy() #image=np.array(image) image = image.transpose( 1 , 2 , 0 ) # 通道由[c,h,w]->[h,w,c] image_processing.cv_show_image( "image" ,image) print ( "batch_image.shape:{},batch_label:{}" . format (batch_image.shape,batch_label)) # batch_x, batch_y = Variable(batch_x), Variable(batch_y) |
上面的迭代代碼是通過兩個for實現,其中參數epoch_num表示總樣本循環次數,比如epoch_num=2,那就是所有樣本循環迭代2次。
但這會出現一個問題,當樣本總數train_data_nums與batch_size不能整取時,最后一個batch會少于規定batch_size的大小,比如這里樣本總數train_data_nums=10,batch_size=7,第一次迭代會產生7個樣本,第二次迭代會因為樣本不足,只能產生3個樣本。
我們希望,每次迭代都會產生相同大小的batch數據,因此可以如下迭代:注意本人在構造TorchDataset類時,就已經考慮循環迭代的方法,因此,你現在只需修改repeat為None時,就表示無限循環了,調用方法如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
|
''' 下面兩種方式,TorchDataset設置repeat=None可以實現無限循環,退出循環由max_iterate設定 ''' train_data = TorchDataset(filename = train_filename, image_dir = image_dir,repeat = None ) train_loader = DataLoader(dataset = train_data, batch_size = batch_size, shuffle = False ) # [2]第2種迭代方法 for step, (batch_image, batch_label) in enumerate (train_loader): image = batch_image[ 0 ,:] image = image.numpy() #image=np.array(image) image = image.transpose( 1 , 2 , 0 ) # 通道由[c,h,w]->[h,w,c] image_processing.cv_show_image( "image" ,image) print ( "step:{},batch_image.shape:{},batch_label:{}" . format (step,batch_image.shape,batch_label)) # batch_x, batch_y = Variable(batch_x), Variable(batch_y) if step> = max_iterate: break # [3]第3種迭代方法 # for step in range(max_iterate): # batch_image, batch_label=train_loader.__iter__().__next__() # image=batch_image[0,:] # image=image.numpy()#image=np.array(image) # image = image.transpose(1, 2, 0) # 通道由[c,h,w]->[h,w,c] # image_processing.cv_show_image("image",image) # print("batch_image.shape:{},batch_label:{}".format(batch_image.shape,batch_label)) # # batch_x, batch_y = Variable(batch_x), Variable(batch_y) |
3.3 附件:image_processing.py
上面代碼,用到image_processing,這是本人封裝好的圖像處理包,包含讀取圖片,畫圖等基本方法:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
|
# -*-coding: utf-8 -*- """ @Project: IntelligentManufacture @File : image_processing.py @Author : panjq @E-mail : pan_jinquan@163.com @Date : 2019-02-14 15:34:50 """ import os import glob import cv2 import numpy as np import matplotlib.pyplot as plt def show_image(title, image): ''' 調用matplotlib顯示RGB圖片 :param title: 圖像標題 :param image: 圖像的數據 :return: ''' # plt.figure("show_image") # print(image.dtype) plt.imshow(image) plt.axis( 'on' ) # 關掉坐標軸為 off plt.title(title) # 圖像題目 plt.show() def cv_show_image(title, image): ''' 調用OpenCV顯示RGB圖片 :param title: 圖像標題 :param image: 輸入RGB圖像 :return: ''' channels = image.shape[ - 1 ] if channels = = 3 : image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # 將BGR轉為RGB cv2.imshow(title,image) cv2.waitKey( 0 ) def read_image(filename, resize_height = None , resize_width = None , normalization = False ): ''' 讀取圖片數據,默認返回的是uint8,[0,255] :param filename: :param resize_height: :param resize_width: :param normalization:是否歸一化到[0.,1.0] :return: 返回的RGB圖片數據 ''' bgr_image = cv2.imread(filename) # bgr_image = cv2.imread(filename,cv2.IMREAD_IGNORE_ORIENTATION|cv2.IMREAD_COLOR) if bgr_image is None : print ( "Warning:不存在:{}" , filename) return None if len (bgr_image.shape) = = 2 : # 若是灰度圖則轉為三通道 print ( "Warning:gray image" , filename) bgr_image = cv2.cvtColor(bgr_image, cv2.COLOR_GRAY2BGR) rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) # 將BGR轉為RGB # show_image(filename,rgb_image) # rgb_image=Image.open(filename) rgb_image = resize_image(rgb_image,resize_height,resize_width) rgb_image = np.asanyarray(rgb_image) if normalization: # 不能寫成:rgb_image=rgb_image/255 rgb_image = rgb_image / 255.0 # show_image("src resize image",image) return rgb_image def fast_read_image_roi(filename, orig_rect, ImreadModes = cv2.IMREAD_COLOR, normalization = False ): ''' 快速讀取圖片的方法 :param filename: 圖片路徑 :param orig_rect:原始圖片的感興趣區域rect :param ImreadModes: IMREAD_UNCHANGED IMREAD_GRAYSCALE IMREAD_COLOR IMREAD_ANYDEPTH IMREAD_ANYCOLOR IMREAD_LOAD_GDAL IMREAD_REDUCED_GRAYSCALE_2 IMREAD_REDUCED_COLOR_2 IMREAD_REDUCED_GRAYSCALE_4 IMREAD_REDUCED_COLOR_4 IMREAD_REDUCED_GRAYSCALE_8 IMREAD_REDUCED_COLOR_8 IMREAD_IGNORE_ORIENTATION :param normalization: 是否歸一化 :return: 返回感興趣區域ROI ''' # 當采用IMREAD_REDUCED模式時,對應rect也需要縮放 scale = 1 if ImreadModes = = cv2.IMREAD_REDUCED_COLOR_2 or ImreadModes = = cv2.IMREAD_REDUCED_COLOR_2: scale = 1 / 2 elif ImreadModes = = cv2.IMREAD_REDUCED_GRAYSCALE_4 or ImreadModes = = cv2.IMREAD_REDUCED_COLOR_4: scale = 1 / 4 elif ImreadModes = = cv2.IMREAD_REDUCED_GRAYSCALE_8 or ImreadModes = = cv2.IMREAD_REDUCED_COLOR_8: scale = 1 / 8 rect = np.array(orig_rect) * scale rect = rect.astype( int ).tolist() bgr_image = cv2.imread(filename,flags = ImreadModes) if bgr_image is None : print ( "Warning:不存在:{}" , filename) return None if len (bgr_image.shape) = = 3 : # rgb_image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB) # 將BGR轉為RGB else : rgb_image = bgr_image #若是灰度圖 rgb_image = np.asanyarray(rgb_image) if normalization: # 不能寫成:rgb_image=rgb_image/255 rgb_image = rgb_image / 255.0 roi_image = get_rect_image(rgb_image , rect) # show_image_rect("src resize image",rgb_image,rect) # cv_show_image("reROI",roi_image) return roi_image def resize_image(image,resize_height, resize_width): ''' :param image: :param resize_height: :param resize_width: :return: ''' image_shape = np.shape(image) height = image_shape[ 0 ] width = image_shape[ 1 ] if (resize_height is None ) and (resize_width is None ): #錯誤寫法:resize_height and resize_width is None return image if resize_height is None : resize_height = int (height * resize_width / width) elif resize_width is None : resize_width = int (width * resize_height / height) image = cv2.resize(image, dsize = (resize_width, resize_height)) return image def scale_image(image,scale): ''' :param image: :param scale: (scale_w,scale_h) :return: ''' image = cv2.resize(image,dsize = None , fx = scale[ 0 ],fy = scale[ 1 ]) return image def get_rect_image(image,rect): ''' :param image: :param rect: [x,y,w,h] :return: ''' x, y, w, h = rect cut_img = image[y:(y + h),x:(x + w)] return cut_img def scale_rect(orig_rect,orig_shape,dest_shape): ''' 對圖像進行縮放時,對應的rectangle也要進行縮放 :param orig_rect: 原始圖像的rect=[x,y,w,h] :param orig_shape: 原始圖像的維度shape=[h,w] :param dest_shape: 縮放后圖像的維度shape=[h,w] :return: 經過縮放后的rectangle ''' new_x = int (orig_rect[ 0 ] * dest_shape[ 1 ] / orig_shape[ 1 ]) new_y = int (orig_rect[ 1 ] * dest_shape[ 0 ] / orig_shape[ 0 ]) new_w = int (orig_rect[ 2 ] * dest_shape[ 1 ] / orig_shape[ 1 ]) new_h = int (orig_rect[ 3 ] * dest_shape[ 0 ] / orig_shape[ 0 ]) dest_rect = [new_x,new_y,new_w,new_h] return dest_rect def show_image_rect(win_name,image,rect): ''' :param win_name: :param image: :param rect: :return: ''' x, y, w, h = rect point1 = (x,y) point2 = (x + w,y + h) cv2.rectangle(image, point1, point2, ( 0 , 0 , 255 ), thickness = 2 ) cv_show_image(win_name, image) def rgb_to_gray(image): image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY) return image def save_image(image_path, rgb_image,toUINT8 = True ): if toUINT8: rgb_image = np.asanyarray(rgb_image * 255 , dtype = np.uint8) if len (rgb_image.shape) = = 2 : # 若是灰度圖則轉為三通道 bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_GRAY2BGR) else : bgr_image = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR) cv2.imwrite(image_path, bgr_image) def combime_save_image(orig_image, dest_image, out_dir,name,prefix): ''' 命名標準:out_dir/name_prefix.jpg :param orig_image: :param dest_image: :param image_path: :param out_dir: :param prefix: :return: ''' dest_path = os.path.join(out_dir, name + "_" + prefix + ".jpg" ) save_image(dest_path, dest_image) dest_image = np.hstack((orig_image, dest_image)) save_image(os.path.join(out_dir, "{}_src_{}.jpg" . format (name,prefix)), dest_image) |
3.4 完整的代碼
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
|
# -*-coding: utf-8 -*- """ @Project: pytorch-learning-tutorials @File : dataset.py @Author : panjq @E-mail : pan_jinquan@163.com @Date : 2019-03-07 18:45:06 """ import torch from torch.autograd import Variable from torchvision import transforms from torch.utils.data import Dataset, DataLoader import numpy as np from utils import image_processing import os class TorchDataset(Dataset): def __init__( self , filename, image_dir, resize_height = 256 , resize_width = 256 , repeat = 1 ): ''' :param filename: 數據文件TXT:格式:imge_name.jpg label1_id labe2_id :param image_dir: 圖片路徑:image_dir+imge_name.jpg構成圖片的完整路徑 :param resize_height 為None時,不進行縮放 :param resize_width 為None時,不進行縮放, PS:當參數resize_height或resize_width其中一個為None時,可實現等比例縮放 :param repeat: 所有樣本數據重復次數,默認循環一次,當repeat為None時,表示無限循環<sys.maxsize ''' self .image_label_list = self .read_file(filename) self .image_dir = image_dir self . len = len ( self .image_label_list) self .repeat = repeat self .resize_height = resize_height self .resize_width = resize_width # 相關預處理的初始化 '''class torchvision.transforms.ToTensor''' # 把shape=(H,W,C)的像素值范圍為[0, 255]的PIL.Image或者numpy.ndarray數據 # 轉換成shape=(C,H,W)的像素數據,并且被歸一化到[0.0, 1.0]的torch.FloatTensor類型。 self .toTensor = transforms.ToTensor() '''class torchvision.transforms.Normalize(mean, std) 此轉換類作用于torch. * Tensor,給定均值(R, G, B) 和標準差(R, G, B), 用公式channel = (channel - mean) / std進行規范化。 ''' # self.normalize=transforms.Normalize() def __getitem__( self , i): index = i % self . len # print("i={},index={}".format(i, index)) image_name, label = self .image_label_list[index] image_path = os.path.join( self .image_dir, image_name) img = self .load_data(image_path, self .resize_height, self .resize_width, normalization = False ) img = self .data_preproccess(img) label = np.array(label) return img, label def __len__( self ): if self .repeat = = None : data_len = 10000000 else : data_len = len ( self .image_label_list) * self .repeat return data_len def read_file( self , filename): image_label_list = [] with open (filename, 'r' ) as f: lines = f.readlines() for line in lines: # rstrip:用來去除結尾字符、空白符(包括\n、\r、\t、' ',即:換行、回車、制表符、空格) content = line.rstrip().split( ' ' ) name = content[ 0 ] labels = [] for value in content[ 1 :]: labels.append( int (value)) image_label_list.append((name, labels)) return image_label_list def load_data( self , path, resize_height, resize_width, normalization): ''' 加載數據 :param path: :param resize_height: :param resize_width: :param normalization: 是否歸一化 :return: ''' image = image_processing.read_image(path, resize_height, resize_width, normalization) return image def data_preproccess( self , data): ''' 數據預處理 :param data: :return: ''' data = self .toTensor(data) return data if __name__ = = '__main__' : train_filename = "../dataset/train.txt" # test_filename="../dataset/test.txt" image_dir = '../dataset/images' epoch_num = 2 #總樣本循環次數 batch_size = 7 #訓練時的一組數據的大小 train_data_nums = 10 max_iterate = int ((train_data_nums + batch_size - 1 ) / batch_size * epoch_num) #總迭代次數 train_data = TorchDataset(filename = train_filename, image_dir = image_dir,repeat = 1 ) # test_data = TorchDataset(filename=test_filename, image_dir=image_dir,repeat=1) train_loader = DataLoader(dataset = train_data, batch_size = batch_size, shuffle = False ) # test_loader = DataLoader(dataset=test_data, batch_size=batch_size,shuffle=False) # [1]使用epoch方法迭代,TorchDataset的參數repeat=1 for epoch in range (epoch_num): for batch_image, batch_label in train_loader: image = batch_image[ 0 ,:] image = image.numpy() #image=np.array(image) image = image.transpose( 1 , 2 , 0 ) # 通道由[c,h,w]->[h,w,c] image_processing.cv_show_image( "image" ,image) print ( "batch_image.shape:{},batch_label:{}" . format (batch_image.shape,batch_label)) # batch_x, batch_y = Variable(batch_x), Variable(batch_y) ''' 下面兩種方式,TorchDataset設置repeat=None可以實現無限循環,退出循環由max_iterate設定 ''' train_data = TorchDataset(filename = train_filename, image_dir = image_dir,repeat = None ) train_loader = DataLoader(dataset = train_data, batch_size = batch_size, shuffle = False ) # [2]第2種迭代方法 for step, (batch_image, batch_label) in enumerate (train_loader): image = batch_image[ 0 ,:] image = image.numpy() #image=np.array(image) image = image.transpose( 1 , 2 , 0 ) # 通道由[c,h,w]->[h,w,c] image_processing.cv_show_image( "image" ,image) print ( "step:{},batch_image.shape:{},batch_label:{}" . format (step,batch_image.shape,batch_label)) # batch_x, batch_y = Variable(batch_x), Variable(batch_y) if step> = max_iterate: break # [3]第3種迭代方法 # for step in range(max_iterate): # batch_image, batch_label=train_loader.__iter__().__next__() # image=batch_image[0,:] # image=image.numpy()#image=np.array(image) # image = image.transpose(1, 2, 0) # 通道由[c,h,w]->[h,w,c] # image_processing.cv_show_image("image",image) # print("batch_image.shape:{},batch_label:{}".format(batch_image.shape,batch_label)) # # batch_x, batch_y = Variable(batch_x), Variable(batch_y) |
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持服務器之家。如有錯誤或未考慮完全的地方,望不吝賜教。
原文鏈接:https://blog.csdn.net/guyuealian/article/details/88343924