The complete parameter list of DataLoader is as follows:
```python
class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
                                  batch_sampler=None, num_workers=0,
                                  collate_fn=<function default_collate>,
                                  pin_memory=False, drop_last=False, timeout=0,
                                  worker_init_fn=None)
```
DataLoader provides a single-process or multi-process iterator over a dataset.
A few key parameters:
- shuffle: when set to True, the dataset is reshuffled at every epoch
- collate_fn: controls how individual samples are assembled into a batch; we can define our own function to get exactly the behavior we want
- drop_last: tells the loader what to do with the samples left over when the dataset length is not divisible by batch_size. True drops them, False keeps them (see the sketch right after this list)
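As a quick illustration of drop_last, here is a minimal sketch (not from the original article; the toy dataset of ten integers is an assumption for demonstration):

```python
import torch
import torch.utils.data as Data

# 10 samples with batch_size=3 leaves 1 sample over
dataset = Data.TensorDataset(torch.arange(10))
loader = Data.DataLoader(dataset, batch_size=3, drop_last=True)

for (batch,) in loader:
    print(batch)
# prints tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6, 7, 8]);
# the leftover tensor([9]) is dropped. With drop_last=False (the
# default), a final batch of size 1 would be yielded as well.
```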
A test example:
```python
import torch
import torch.utils.data as Data
import numpy as np

test = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])

inputing = torch.tensor(np.array([test[i:i + 3] for i in range(10)]))
target = torch.tensor(np.array([test[i:i + 1] for i in range(10)]))

torch_dataset = Data.TensorDataset(inputing, target)
batch = 3

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,  # batch size
    # if the number of samples in the dataset is not divisible by
    # batch_size, the last batch just uses whatever is left over
    collate_fn=lambda x: (
        torch.cat(
            [x[i][j].unsqueeze(0) for i in range(len(x))], 0
        ).unsqueeze(0)
        for j in range(len(x[0]))
    )
)

for (i, j) in loader:
    print(i)
    print(j)
```
Output:
```
tensor([[[ 0,  1,  2],
         [ 1,  2,  3],
         [ 2,  3,  4]]], dtype=torch.int32)
tensor([[[0],
         [1],
         [2]]], dtype=torch.int32)
tensor([[[3, 4, 5],
         [4, 5, 6],
         [5, 6, 7]]], dtype=torch.int32)
tensor([[[3],
         [4],
         [5]]], dtype=torch.int32)
tensor([[[ 6,  7,  8],
         [ 7,  8,  9],
         [ 8,  9, 10]]], dtype=torch.int32)
tensor([[[6],
         [7],
         [8]]], dtype=torch.int32)
tensor([[[ 9, 10, 11]]], dtype=torch.int32)
tensor([[[9]]], dtype=torch.int32)
```
If we drop the collate_fn argument, the output becomes:
```
tensor([[0, 1, 2],
        [1, 2, 3],
        [2, 3, 4]], dtype=torch.int32)
tensor([[0],
        [1],
        [2]], dtype=torch.int32)
tensor([[3, 4, 5],
        [4, 5, 6],
        [5, 6, 7]], dtype=torch.int32)
tensor([[3],
        [4],
        [5]], dtype=torch.int32)
tensor([[ 6,  7,  8],
        [ 7,  8,  9],
        [ 8,  9, 10]], dtype=torch.int32)
tensor([[6],
        [7],
        [8]], dtype=torch.int32)
tensor([[ 9, 10, 11]], dtype=torch.int32)
tensor([[9]], dtype=torch.int32)
```
So compared with the default, this collate_fn simply adds one extra leading dimension to each result.
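The same effect can also be had by reusing PyTorch's default collation and unsqueezing afterwards. A sketch, assuming the import path `torch.utils.data.dataloader.default_collate` (used by older PyTorch versions; newer releases also expose `torch.utils.data.default_collate`):

```python
from torch.utils.data.dataloader import default_collate

loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=batch,
    # stack the samples the default way, then add a leading dimension
    # to each field, reproducing the 1x3x3 / 1x3x1 shapes shown above
    collate_fn=lambda x: [t.unsqueeze(0) for t in default_collate(x)],
)
```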
To see what the collate_fn value actually receives, change it to the following:
```python
collate_fn=lambda x: x
```
and print each batch:
```python
for i in loader:
    print(i)
```
which gives:
```
[(tensor([0, 1, 2], dtype=torch.int32), tensor([0], dtype=torch.int32)),
 (tensor([1, 2, 3], dtype=torch.int32), tensor([1], dtype=torch.int32)),
 (tensor([2, 3, 4], dtype=torch.int32), tensor([2], dtype=torch.int32))]
[(tensor([3, 4, 5], dtype=torch.int32), tensor([3], dtype=torch.int32)),
 (tensor([4, 5, 6], dtype=torch.int32), tensor([4], dtype=torch.int32)),
 (tensor([5, 6, 7], dtype=torch.int32), tensor([5], dtype=torch.int32))]
[(tensor([6, 7, 8], dtype=torch.int32), tensor([6], dtype=torch.int32)),
 (tensor([7, 8, 9], dtype=torch.int32), tensor([7], dtype=torch.int32)),
 (tensor([ 8,  9, 10], dtype=torch.int32), tensor([8], dtype=torch.int32))]
[(tensor([ 9, 10, 11], dtype=torch.int32), tensor([9], dtype=torch.int32))]
```
So each i is a list containing batch_size tuples, and each tuple holds one individual sample from the TensorDataset. To reassemble each batch into a 1×3×3 input and a 1×3×1 target, we have to unpack this list and repack the pieces ourselves. Now look again at our collate_fn:
```python
collate_fn=lambda x: (
    torch.cat(
        [x[i][j].unsqueeze(0) for i in range(len(x))], 0
    ).unsqueeze(0)
    for j in range(len(x[0]))
)
```
Here j indexes the two fields of each sample, input and target, while i runs over the batch_size samples. For each field, unsqueeze(0) first adds a leading dimension to every sample, torch.cat(..., 0) concatenates them along that dimension, and a final unsqueeze(0) adds one more leading dimension. Done.
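For readability, the same lambda can be written as a named function. This `repack` helper is hypothetical (not from the original article) and behaves equivalently, except that it returns a tuple rather than a generator:

```python
def repack(samples):
    # samples is a list of batch_size tuples from the TensorDataset:
    # [(input_0, target_0), (input_1, target_1), ...]
    num_fields = len(samples[0])  # 2 here: input and target
    return tuple(
        # stack the j-th field of every sample, then add a leading dim
        torch.cat([s[j].unsqueeze(0) for s in samples], 0).unsqueeze(0)
        for j in range(num_fields)
    )

loader = Data.DataLoader(dataset=torch_dataset, batch_size=batch, collate_fn=repack)
```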
That wraps up this walkthrough of the collate_fn parameter of PyTorch's DataLoader; hopefully it serves as a useful reference.
Original article: https://blog.csdn.net/weixin_42028364/article/details/81675021