keras: how to save the training history attribute of the history object

Submitted by 让人想犯罪 on 2019-12-18 10:35:06

Question


In Keras, we can capture the output of model.fit in a history variable as follows:

    history = model.fit(X_train, y_train,
                        batch_size=batch_size,
                        epochs=nb_epoch,  # Keras 2+ uses 'epochs'; 'nb_epoch' is the old Keras 1 name
                        validation_data=(X_test, y_test))

Now, how can I save the history attribute of the history object to a file for later use (e.g., plotting accuracy or loss against epochs)?


Answer 1:


What I use is the following:

    import pickle

    # Relative path; '/trainHistoryDict' would try to write to the filesystem root
    with open('trainHistoryDict', 'wb') as file_pi:
        pickle.dump(history.history, file_pi)

In this way I save the history as a dictionary in case I want to plot the loss or accuracy later on.
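Getting the dictionary back later is the mirror image: open the same file in binary read mode and pass it to pickle.load. A minimal sketch, where the helper names save_history/load_history and the metric values are hypothetical stand-ins for a real history.history:

```python
import os
import pickle
import tempfile

def save_history(path, history_dict):
    # Write the history.history dict in binary mode
    with open(path, 'wb') as file_pi:
        pickle.dump(history_dict, file_pi)

def load_history(path):
    # Read it back; pickle restores the same dict of lists
    with open(path, 'rb') as file_pi:
        return pickle.load(file_pi)

# Hypothetical metrics standing in for history.history
fake_history = {'loss': [0.9, 0.6, 0.4], 'accuracy': [0.55, 0.7, 0.8]}

path = os.path.join(tempfile.mkdtemp(), 'trainHistoryDict')
save_history(path, fake_history)
restored = load_history(path)
print(restored['loss'])  # [0.9, 0.6, 0.4]
```

Keep in mind that pickle.load should only be used on files you trust, since unpickling can execute arbitrary code.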




Answer 2:


The model history can be saved to a file as follows:

import json
hist = model.fit(X_train, y_train, epochs=5, batch_size=batch_size,validation_split=0.1)
with open('file.json', 'w') as f:
    json.dump(hist.history, f)
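Reading the file back works the same way in reverse: json.load returns a plain dict of lists, which is all you need to plot a metric per epoch. A sketch with hypothetical values standing in for hist.history:

```python
import json
import os
import tempfile

# Hypothetical values standing in for hist.history
hist_history = {'loss': [0.8, 0.5], 'val_loss': [0.9, 0.7]}

path = os.path.join(tempfile.mkdtemp(), 'file.json')
with open(path, 'w') as f:
    json.dump(hist_history, f)

with open(path) as f:
    loaded = json.load(f)

# Pair each (1-based) epoch number with its validation loss, e.g. for plotting
epochs = list(range(1, len(loaded['val_loss']) + 1))
print(list(zip(epochs, loaded['val_loss'])))  # [(1, 0.9), (2, 0.7)]
```

If json.dump raises a TypeError saying an object of type float32 is not JSON serializable, the values are numpy scalars rather than Python floats and need converting first, which is the problem Answer 5 addresses.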



Answer 3:


A History object has a history attribute, which is a dictionary holding the training metrics recorded at every epoch. For example, history.history['loss'][99] returns your model's loss in the 100th epoch of training. To save it, you can pickle this dictionary or simply write the individual lists from it to appropriate files.
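Both points can be illustrated with plain Python. The sketch below uses a hypothetical 100-epoch history dict, shows that index 99 is the 100th epoch, and writes each metric list to its own text file with one value per line (the `<metric>.txt` naming scheme is an assumption, not something Keras prescribes):

```python
import os
import tempfile

# Hypothetical history.history with 100 epochs of loss and accuracy
history_dict = {
    'loss': [1.0 / (epoch + 1) for epoch in range(100)],
    'accuracy': [epoch / 100 for epoch in range(100)],
}

# Index 99 is the 100th epoch because Python lists are 0-based
loss_100th_epoch = history_dict['loss'][99]

out_dir = tempfile.mkdtemp()
for metric, values in history_dict.items():
    # One file per metric (loss.txt, accuracy.txt), one value per line;
    # repr() keeps enough digits for floats to round-trip exactly
    with open(os.path.join(out_dir, metric + '.txt'), 'w') as f:
        f.write('\n'.join(repr(v) for v in values))

# Read one metric back as a list of floats
with open(os.path.join(out_dir, 'loss.txt')) as f:
    loss_back = [float(line) for line in f]
```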




Answer 4:


Another way to do this:

Since history.history is a dict, you can also convert it to a pandas DataFrame, which can then be saved in whatever format suits your needs.

Step by step:

import pandas as pd

# assuming you stored your model.fit results in a 'history' variable:
history = model.fit(x_train, y_train, epochs=10)

# convert the history.history dict to a pandas DataFrame:     
hist_df = pd.DataFrame(history.history) 

# save to json:  
hist_json_file = 'history.json' 
with open(hist_json_file, mode='w') as f:
    hist_df.to_json(f)

# or save to csv: 
hist_csv_file = 'history.csv'
with open(hist_csv_file, mode='w') as f:
    hist_df.to_csv(f)
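If you would rather not depend on pandas, the standard library's csv module can produce a similar per-epoch table. This is a sketch assuming the dict has at least one metric and every metric list has the same length (one value per epoch); the helper names and the `epoch` column name are hypothetical, with the `epoch` column playing the role of the DataFrame index that to_csv writes by default:

```python
import csv
import os
import tempfile

def history_to_csv(path, history_dict):
    # One row per epoch, one column per metric, like pd.DataFrame(history.history)
    metrics = sorted(history_dict)
    n_epochs = len(history_dict[metrics[0]])
    with open(path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['epoch'] + metrics)
        for i in range(n_epochs):
            writer.writerow([i] + [history_dict[m][i] for m in metrics])

def csv_to_history(path):
    # Rebuild the dict of lists, dropping the epoch column
    with open(path, newline='') as f:
        rows = list(csv.DictReader(f))
    metrics = [name for name in rows[0] if name != 'epoch'] if rows else []
    return {m: [float(row[m]) for row in rows] for m in metrics}

# Hypothetical values standing in for history.history
hist = {'loss': [0.7, 0.4, 0.3], 'accuracy': [0.6, 0.8, 0.9]}
path = os.path.join(tempfile.mkdtemp(), 'history.csv')
history_to_csv(path, hist)
restored = csv_to_history(path)
```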



Answer 5:


I ran into the problem that the values inside the lists in the Keras history are not JSON serializable. Therefore I wrote these two handy functions for my use case.

import json
import numpy as np

def saveHist(path, history):
    # Convert numpy arrays and numpy scalars (e.g. np.float32) to plain
    # Python types, because json cannot serialize them directly.
    # Converting every list (not only np.float64 ones) avoids silently dropping keys.
    new_hist = {}
    for key, values in history.history.items():
        if isinstance(values, np.ndarray):
            new_hist[key] = values.tolist()
        else:
            new_hist[key] = [float(v) for v in values]

    with open(path, 'w', encoding='utf-8') as file:
        json.dump(new_hist, file, separators=(',', ':'), sort_keys=True, indent=4)

def loadHist(path):
    with open(path, 'r', encoding='utf-8') as file:
        return json.load(file)

Here saveHist only needs the path where the JSON file should be saved and the History object returned by the Keras fit (or fit_generator) method.
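An alternative that avoids checking each key's type is json.dump's `default` argument, which the encoder calls for any object it cannot serialize natively. Numpy arrays and numpy scalars both provide a `.tolist()` method that returns plain Python values, so a hook can delegate to it. A sketch with hypothetical helper names, assuming the problematic values are numpy types (or anything else exposing `.tolist()`):

```python
import json
import os
import tempfile

def to_builtin(obj):
    # Called by json only for objects it cannot serialize itself.
    # numpy arrays and scalars (e.g. np.float32) both expose .tolist().
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f'Object of type {type(obj).__name__} is not JSON serializable')

def save_hist(path, history_dict):
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(history_dict, f, default=to_builtin, indent=4, sort_keys=True)

def load_hist(path):
    with open(path, encoding='utf-8') as f:
        return json.load(f)

# Plain floats stand in for history.history here; with Keras you would
# pass history.history, whose entries may be numpy scalars.
path = os.path.join(tempfile.mkdtemp(), 'history.json')
save_hist(path, {'loss': [0.5, 0.25]})
```

Unlike the explicit loop, this leaves the dict untouched and only converts values at the moment json meets one it cannot handle.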




Answer 6:


I'm sure there are many ways to do this, but I fiddled around and came up with a version of my own.

First, a custom callback grabs and updates the history at the end of every epoch. Alongside it I also use a callback that saves the model. Both are handy because if training crashes or you shut down, you can pick up training at the last completed epoch.

from keras.callbacks import Callback, ModelCheckpoint

class LossHistory(Callback):

    # https://stackoverflow.com/a/53653154/852795
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        new_history = {}
        for k, v in logs.items():  # compile new history from logs
            new_history[k] = [float(v)]  # wrap each value in a list; float() makes numpy scalars JSON-serializable
        current_history = loadHist(history_filename)  # load history from current training
        current_history = appendHist(current_history, new_history)  # append the logs
        saveHist(history_filename, current_history)  # save history from current training

model_checkpoint = ModelCheckpoint(model_filename, verbose=0)  # saves the model at the end of every epoch by default
history_checkpoint = LossHistory()
callbacks_list = [model_checkpoint, history_checkpoint]

Second, here are some helper functions that do exactly what their names say. All of them are called from the LossHistory callback.

# https://stackoverflow.com/a/54092401/852795
import codecs
import json
import os

def saveHist(path, history):
    with codecs.open(path, 'w', encoding='utf-8') as f:
        json.dump(history, f, separators=(',', ':'), sort_keys=True, indent=4)

def loadHist(path):
    n = {}  # start with an empty history
    if os.path.exists(path):  # reload history if it exists
        with codecs.open(path, 'r', encoding='utf-8') as f:
            n = json.loads(f.read())
    return n

def appendHist(h1, h2):
    if not h1:
        return h2
    dest = {}
    for key in set(h1) | set(h2):  # keep metrics that appear in either history
        dest[key] = h1.get(key, []) + h2.get(key, [])
    return dest

After that, all you need is to set history_filename to something like data/model-history.json and model_filename to something like data/model.h5. One final tweak, which keeps your history intact at the end of training when you stop and restart, is to pass in the callbacks and merge the returned history like this:

history = loadHist(history_filename)  # history from earlier runs (empty dict on the first run)

new_history = model.fit(X_train, y_train,
                        batch_size=batch_size,
                        epochs=nb_epoch,
                        validation_data=(X_test, y_test),
                        callbacks=callbacks_list)

history = appendHist(history, new_history.history)

Whenever you want, history = loadHist(history_filename) gets your history back.
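The resume logic does not depend on Keras itself, so it can be checked without a model. The sketch below replays what LossHistory.on_epoch_end does for each epoch: load the JSON file, append one value per metric, and save it again. The function name record_epoch, the sample logs, and the two simulated runs are hypothetical; the merge mirrors appendHist, extended to keys present in only one of the two dicts.

```python
import json
import os
import tempfile

def load_hist(path):
    # Empty history on the first run, otherwise whatever was saved so far
    if os.path.exists(path):
        with open(path, encoding='utf-8') as f:
            return json.load(f)
    return {}

def append_hist(h1, h2):
    # Concatenate per-metric lists; keep metrics that appear in only one dict
    return {k: h1.get(k, []) + h2.get(k, []) for k in set(h1) | set(h2)}

def record_epoch(path, logs):
    # What the callback does at the end of each epoch
    new_history = {k: [float(v)] for k, v in logs.items()}
    merged = append_hist(load_hist(path), new_history)
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(merged, f, sort_keys=True, indent=4)

history_filename = os.path.join(tempfile.mkdtemp(), 'model-history.json')

# First run: two epochs, then training is interrupted
record_epoch(history_filename, {'loss': 0.9, 'val_loss': 1.0})
record_epoch(history_filename, {'loss': 0.7, 'val_loss': 0.8})

# Second run resumes and adds one more epoch
record_epoch(history_filename, {'loss': 0.5, 'val_loss': 0.6})

print(load_hist(history_filename)['loss'])  # [0.9, 0.7, 0.5]
```

Because each epoch is written to disk as soon as it finishes, an interrupted run loses at most the epoch that was in progress.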

The awkwardness comes from JSON and the lists; I wasn't able to get it to work without converting the values by iterating over them. Anyway, I know this works because I've been using it for days now. The pickle.dump answer at https://stackoverflow.com/a/44674337/852795 might be better, but I'm not familiar with it. If I missed anything here or you can't get it to work, let me know.



Source: https://stackoverflow.com/questions/41061457/keras-how-to-save-the-training-history-attribute-of-the-history-object
