fit_generator 是 keras 提供的用來進(jìn)行批次訓(xùn)練的函數(shù),使用方法如下:
1
2
3
4
|
model.fit_generator(generator, steps_per_epoch = None , epochs = 1 , verbose = 1 , callbacks = None , validation_data = None , validation_steps = None , class_weight = None , max_queue_size = 10 , workers = 1 , use_multiprocessing = False , shuffle = True , initial_epoch = 0 ) |
參數(shù)說明:
generator: 一個(gè)生成器,或者一個(gè) Sequence (keras.utils.Sequence) 對(duì)象的實(shí)例, 以在使用多進(jìn)程時(shí)避免數(shù)據(jù)的重復(fù)。 生成器的輸出應(yīng)該為以下之一:
一個(gè)(inputs, targets) 元組
一個(gè) (inputs, targets, sample_weights) 元組。
這個(gè)元組(生成器的單個(gè)輸出)組成了單個(gè)的 batch。 因此,這個(gè)元組中的所有數(shù)組長度必須相同(與這一個(gè) batch 的大小相等)。 不同的 batch 可能大小不同。 例如,一個(gè) epoch 的最后一個(gè) batch 往往比其他 batch 要小, 如果數(shù)據(jù)集的尺寸不能被 batch size 整除。 生成器將無限地在數(shù)據(jù)集上循環(huán)。當(dāng)運(yùn)行到第steps_per_epoch 時(shí),記一個(gè) epoch 結(jié)束。
steps_per_epoch: 在聲明一個(gè) epoch 完成并開始下一個(gè) epoch 之前從 generator產(chǎn)生的總步數(shù)(批次樣本)。 它通常應(yīng)該等于你的數(shù)據(jù)集的樣本數(shù)量除以批量大小。 對(duì)于Sequence,它是可選的:如果未指定,將使用len(generator)作為步數(shù)。
epochs: 整數(shù)。訓(xùn)練模型的迭代總輪數(shù)。一個(gè) epoch 是對(duì)所提供的整個(gè)數(shù)據(jù)的一輪迭代,如 steps_per_epoch 所定義。注意,與 initial_epoch 一起使用,epoch 應(yīng)被理解為「最后一輪」。模型沒有經(jīng)歷由 epochs 給出的多次迭代的訓(xùn)練,而僅僅是直到達(dá)到索引 epoch 的輪次。
verbose: 0, 1 或 2。日志顯示模式。 0 = 安靜模式, 1 = 進(jìn)度條, 2 = 每輪一行。
callbacks: keras.callbacks.Callback 實(shí)例的列表。在訓(xùn)練時(shí)調(diào)用的一系列回調(diào)函數(shù)。
validation_data: 它可以是以下之一:
驗(yàn)證數(shù)據(jù)的生成器或Sequence實(shí)例
一個(gè)(inputs, targets) 元組
一個(gè)(inputs, targets, sample_weights) 元組。
在每個(gè) epoch 結(jié)束時(shí)評(píng)估損失和任何模型指標(biāo)。該模型不會(huì)對(duì)此數(shù)據(jù)進(jìn)行訓(xùn)練。
validation_steps: 僅當(dāng) validation_data 是一個(gè)生成器時(shí)才可用。 在停止前 generator 生成的總步數(shù)(樣本批數(shù))。 對(duì)于 Sequence,它是可選的:如果未指定,將使用 len(generator) 作為步數(shù)。
class_weight: 可選的將類索引(整數(shù))映射到權(quán)重(浮點(diǎn))值的字典,用于加權(quán)損失函數(shù)(僅在訓(xùn)練期間)。 這可以用來告訴模型「更多地關(guān)注」來自代表性不足的類的樣本。
max_queue_size: 整數(shù)。生成器隊(duì)列的最大尺寸。 如未指定,max_queue_size 將默認(rèn)為 10。
workers: 整數(shù)。使用的最大進(jìn)程數(shù)量,如果使用基于進(jìn)程的多線程。 如未指定,workers 將默認(rèn)為 1。如果為 0,將在主線程上執(zhí)行生成器。
use_multiprocessing: 布爾值。如果 True,則使用基于進(jìn)程的多線程。 如未指定, use_multiprocessing 將默認(rèn)為 False。 請(qǐng)注意,由于此實(shí)現(xiàn)依賴于多進(jìn)程,所以不應(yīng)將不可傳遞的參數(shù)傳遞給生成器,因?yàn)樗鼈儾荒鼙惠p易地傳遞給子進(jìn)程。
shuffle: 是否在每輪迭代之前打亂 batch 的順序。 只能與 Sequence (keras.utils.Sequence) 實(shí)例同用。
initial_epoch: 開始訓(xùn)練的輪次(有助于恢復(fù)之前的訓(xùn)練)。
補(bǔ)充知識(shí):Keras中fit_generator 的多個(gè)分支輸入時(shí),需注意generator的格式 以及 輸入序列的順序
需要注意迭代器 yeild返回不能是[x1,x2],y 這樣,而是要完整的字典格式的:
yield ({'input_1': x1, 'input_2': x2}, {'output': y})
這也不算坑 追進(jìn)去 fit_generator也能看到示例
1
2
3
4
5
6
7
8
9
10
11
12
13
14
|
def generate_batch(x_train,y_train,batch_size,x_train2,randomFlag = True ): ylen = len (y_train) loopcount = ylen / / batch_size i = - 1 while True : if randomFlag: i = random.randint( 0 ,loopcount - 1 ) else : i = i + 1 i = i % loopcount yield ({ 'lstmInput' : x_train[i * batch_size:(i + 1 ) * batch_size], 'bgInput' : x_train2[i * batch_size:(i + 1 ) * batch_size]}, { 'prediction' : y_train[i * batch_size:(i + 1 ) * batch_size]}) |
ps: 因?yàn)橐莟uple yield后的括號(hào)不能省
需注意的坑1是,validation data中如果用【】組成數(shù)組進(jìn)行輸入,是要按順序的,按編譯model前的設(shè)置model = Model(inputs=[simInput,lstmInput,bgInput], outputs=predictions),中數(shù)組的順序來編譯
需注意的坑2是,多輸入input時(shí),以后都用 inputs1=Input(batch_shape=(batchSize,TPeriod,dimIn,),name='input1LSTM')指定batchSize,不然跟stateful lstm結(jié)合時(shí),會(huì)提示不匹配。
1
2
3
4
5
|
history = model.fit_generator(generate_batch(trainX,trainY,batchSize,trainX2), steps_per_epoch = len (trainX) / / batchSize, validation_data = ([testX,testX2],testY), epochs = epochs, callbacks = [tensorboard,checkpoint],initial_epoch = 0 ,verbose = 1 ) # Fit the LSTM network/擬合LSTM網(wǎng)絡(luò) |
以上這篇keras和tensorflow使用fit_generator 批次訓(xùn)練操作就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持服務(wù)器之家。
原文鏈接:https://blog.csdn.net/zhangpeterx/article/details/90900118