Is it logical to loop on model.fit in Keras?

Submitted by 放肆的年华 on 2020-05-13 14:49:13

Question


Is it logical to do as below in Keras in order not to run out of memory?

for path in ['xaa', 'xab', 'xac', 'xad']:
    x_train, y_train = prepare_data(path)
    model.fit(x_train, y_train, batch_size=50, epochs=20, shuffle=True)

model.save('model')

Answer 1:


It is, but if each iteration produces only a single batch, prefer model.train_on_batch: it avoids some of the per-call overhead that comes with fit (such as setting up the epoch loop and callbacks).
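A minimal sketch of the train_on_batch approach, assuming `model` is a compiled Keras model (anything with a Keras-style `train_on_batch(x, y)` method) and that the question's own `prepare_data(path)` returns sliceable arrays such as NumPy arrays. The helper name `train_by_batches` is hypothetical. Epochs are the outer loop so that every epoch sees every file. Unlike `fit(shuffle=True)`, this sketch does not shuffle samples within a file.

```python
def train_by_batches(model, paths, prepare_data, batch_size=50, epochs=20):
    """Hypothetical helper: one train_on_batch call per mini-batch.

    Assumes prepare_data(path) returns (x, y) supporting len() and slicing.
    Returns the number of train_on_batch calls made.
    """
    calls = 0
    for _ in range(epochs):
        for path in paths:
            # Only one file's data is held in memory at a time.
            x, y = prepare_data(path)
            for start in range(0, len(x), batch_size):
                stop = start + batch_size  # the last slice may be shorter
                model.train_on_batch(x[start:stop], y[start:stop])
                calls += 1
    return calls
```

Usage would be `train_by_batches(model, ['xaa', 'xab', 'xac', 'xad'], prepare_data)`, followed by `model.save('model')` as in the question.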

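You can also write a generator and pass it to model.fit_generator() (in TensorFlow 2.1 and later, fit_generator is deprecated and model.fit accepts a generator directly):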

def dataGenerator(paths, batch_size):

    while True:  # Keras generators must loop indefinitely; steps_per_epoch marks the epoch boundaries
        for path in paths:
            # only one file is loaded into memory at a time
            x_train, y_train = prepare_data(path)

            totalSamps = x_train.shape[0]
            batches = totalSamps // batch_size

            if totalSamps % batch_size > 0:
                batches += 1  # one extra, smaller batch for the leftover samples

            for batch in range(batches):
                section = slice(batch*batch_size, (batch+1)*batch_size)
                yield (x_train[section], y_train[section])

Create and use:

gen = dataGenerator(['xaa', 'xab', 'xac', 'xad'], 50)
model.fit_generator(gen,
                    steps_per_epoch=expectedTotalNumberOfYieldsForOneEpoch,  # placeholder: batches per full pass over all files
                    epochs=epochs)
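`expectedTotalNumberOfYieldsForOneEpoch` is a placeholder. With the generator above, one full pass over the files yields, for each file, its sample count divided by the batch size, rounded up, which is what the `batches` computation inside the generator does per file. A minimal sketch, assuming you know (or can cheaply record, e.g. from `x_train.shape[0]`) how many samples each file holds; the helper name `steps_for_one_epoch` is hypothetical.

```python
def steps_for_one_epoch(sample_counts, batch_size):
    """Number of batches dataGenerator yields in one pass over all files.

    Mirrors the generator's per-file logic: full batches, plus one
    partial batch when there is a remainder (ceiling division).
    """
    total = 0
    for n in sample_counts:
        batches = n // batch_size
        if n % batch_size > 0:
            batches += 1
        total += batches
    return total
```

For example, with four files of 1000, 1000, 1000 and 730 samples (hypothetical counts) and a batch size of 50, this gives 20 + 20 + 20 + 15 = 75, so you would pass steps_per_epoch=75. Because the generator never stops, a value that is too small or too large makes Keras epochs drift out of line with full passes over the files.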



Answer 2:


I would suggest having a look at this thread on GitHub.

You could indeed use model.fit() this way, but training would be more stable if you structured the loops like this:

for epoch in range(20):
    for path in ['xaa', 'xab', 'xac', 'xad']:
        x_train, y_train = prepare_data(path)
        model.fit(x_train, y_train, batch_size=50, epochs=epoch+1, initial_epoch=epoch, shuffle=True)

This way you iterate over all of your data once per epoch, instead of running 20 epochs on one part of your data before switching to the next.
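To make the difference concrete, the sketch below lists the (epoch, file) pairs each loop order trains on. It only models the schedule, not the training itself; `visit_order` and its `epochs_outer` flag are hypothetical names.

```python
def visit_order(paths, epochs, epochs_outer):
    """Return the (epoch, path) pairs in the order training visits them."""
    if epochs_outer:
        # This answer: for epoch in range(epochs): for path in paths: fit(one epoch)
        return [(epoch, path) for epoch in range(epochs) for path in paths]
    # The question: for path in paths: fit(..., epochs=epochs)
    return [(epoch, path) for path in paths for epoch in range(epochs)]
```

Both orders do the same total amount of training and differ only in sequence: the question's loop spends all of its epochs on 'xaa' before the model sees 'xab', so it can fit one file's data before moving on. The trade-off of interleaving is that prepare_data is called once per file per epoch (80 calls for 4 files and 20 epochs, instead of 4).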

As discussed in the thread, another solution would be to develop your own data generator and use it with model.fit_generator().



Source: https://stackoverflow.com/questions/50448743/is-it-logical-to-loop-on-model-fit-in-keras
