keras: how to save the training history attribute of the history object

小鲜肉 2020-12-12 23:43

In Keras, we can return the output of model.fit to a history as follows:

 history = model.fit(X_train, y_train, 
                     batch_size         


        
  • 2020-12-13 00:12

    I ran into the problem that the values inside the lists in Keras's history are not JSON serializable, so I wrote these two handy functions for my use case.

    import json
    import codecs
    import numpy as np
    
    def saveHist(path, history):
        """Save history.history (a dict of per-epoch lists) as a JSON file."""
        new_hist = {}
        for key, values in history.history.items():
            if isinstance(values, np.ndarray):
                new_hist[key] = values.tolist()
            elif isinstance(values, list):
                # NumPy scalars such as np.float32 are not JSON serializable,
                # so cast every element to a plain Python float.
                new_hist[key] = [float(v) for v in values]
            else:
                new_hist[key] = values
        with codecs.open(path, 'w', encoding='utf-8') as file:
            json.dump(new_hist, file, separators=(',', ':'), sort_keys=True, indent=4)
    
    def loadHist(path):
        """Load a history dict previously written by saveHist."""
        with codecs.open(path, 'r', encoding='utf-8') as file:
            n = json.loads(file.read())
        return n
    

    saveHist just needs the path where the JSON file should be saved and the history object returned by the Keras fit or fit_generator method; loadHist reads the file back into a plain dict.
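
    If you'd rather not walk the dict by hand, json.dump also accepts a `default` function that it calls for any object it can't serialize. Below is a minimal sketch of that alternative, assuming the only offenders are NumPy scalars and arrays (both expose `.tolist()`). The names `to_builtin`, `save_hist_default` and `load_hist`, plus the `SimpleNamespace` stand-in for a Keras History object, are hypothetical.

```python
import json
from types import SimpleNamespace


def to_builtin(obj):
    """Fallback for json.dump: convert NumPy scalars/arrays via .tolist()."""
    # NumPy scalars (e.g. np.float32) and ndarrays both have .tolist(),
    # which returns plain Python floats/lists that json can encode.
    if hasattr(obj, "tolist"):
        return obj.tolist()
    # json expects a TypeError for anything the hook cannot handle.
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def save_hist_default(path, history):
    """Save history.history to JSON, converting NumPy values on the fly."""
    with open(path, "w", encoding="utf-8") as f:
        json.dump(history.history, f, default=to_builtin, sort_keys=True, indent=4)


def load_hist(path):
    """Read the saved history back into a plain dict."""
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


if __name__ == "__main__":
    import os
    import tempfile

    # Stand-in for the object returned by model.fit (made-up values).
    fake = SimpleNamespace(history={"loss": [0.9, 0.5], "val_loss": [1.0, 0.7]})
    path = os.path.join(tempfile.mkdtemp(), "history.json")
    save_hist_default(path, fake)
    print(load_hist(path))
```

    np.float64 subclasses Python's float, so json already handles it; the hook mainly matters for np.float32 values and arrays.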

  • 2020-12-13 00:12

    I'm sure there are many ways to do this, but I fiddled around and came up with a version of my own.

    First, a custom callback grabs and updates the history at the end of every epoch. Alongside it I also use a ModelCheckpoint callback to save the model. Both are handy because if training crashes or you shut down, you can pick up at the last completed epoch.

    from keras.callbacks import Callback, ModelCheckpoint
    
    class LossHistory(Callback):
    
        # https://stackoverflow.com/a/53653154/852795
        def on_epoch_end(self, epoch, logs=None):
            logs = logs or {}
            new_history = {}
            for k, v in logs.items(): # compile new history from logs
                new_history[k] = [float(v)] # wrap each value in a list; float() keeps it JSON serializable
            current_history = loadHist(history_filename) # load history from current training
            current_history = appendHist(current_history, new_history) # append the logs
            saveHist(history_filename, current_history) # save history from current training
    
    # save_freq='epoch' saves after every epoch (older Keras versions used period=1)
    model_checkpoint = ModelCheckpoint(model_filename, verbose=0, save_freq='epoch')
    history_checkpoint = LossHistory()
    callbacks_list = [model_checkpoint, history_checkpoint]
    

    Second, here are some helper functions that do exactly what their names say. They are all called from the LossHistory() callback.

    # https://stackoverflow.com/a/54092401/852795
    import codecs
    import json
    import os
    
    def saveHist(path, history):
        with codecs.open(path, 'w', encoding='utf-8') as f:
            json.dump(history, f, separators=(',', ':'), sort_keys=True, indent=4) 
    
    def loadHist(path):
        n = {} # set history to empty
        if os.path.exists(path): # reload history if it exists
            with codecs.open(path, 'r', encoding='utf-8') as f:
                n = json.loads(f.read())
        return n
    
    def appendHist(h1, h2):
        if h1 == {}:
            return h2
        else:
            dest = {}
            for key, value in h1.items():
                dest[key] = value + h2[key]
            return dest
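
    To see what the LossHistory callback does each epoch without starting a Keras run, the sketch below replays the same load, append, save cycle against a temporary file, calling the helpers directly with made-up `logs` dicts. The `on_epoch_end` function here is a hypothetical stand-in for the callback method, and the helpers are copied so the block runs on its own.

```python
import json
import os
import tempfile


def saveHist(path, history):
    with open(path, "w", encoding="utf-8") as f:
        json.dump(history, f, separators=(",", ":"), sort_keys=True, indent=4)


def loadHist(path):
    # Empty dict when no history file exists yet (first epoch of the first run).
    if not os.path.exists(path):
        return {}
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


def appendHist(h1, h2):
    # Concatenate per-metric lists; assumes both dicts share the same keys.
    if h1 == {}:
        return h2
    return {key: value + h2[key] for key, value in h1.items()}


def on_epoch_end(history_filename, logs):
    """Same steps as LossHistory.on_epoch_end, minus the Keras Callback."""
    new_history = {k: [float(v)] for k, v in logs.items()}
    current = loadHist(history_filename)
    saveHist(history_filename, appendHist(current, new_history))


if __name__ == "__main__":
    history_filename = os.path.join(tempfile.mkdtemp(), "model-history.json")
    # Made-up logs for three epochs, shaped like the dict Keras passes in.
    for logs in ({"loss": 0.9}, {"loss": 0.6}, {"loss": 0.4}):
        on_epoch_end(history_filename, logs)
    print(loadHist(history_filename))  # one list entry per completed epoch
```

    Because the file is rewritten after every epoch, a crash loses at most the epoch that was in progress.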
    

    After that, all you need is to set history_filename to something like data/model-history.json and model_filename to something like data/model.h5. One final tweak, so that stopping and restarting training (with the callbacks in place) doesn't mess up your history at the end of training, is this:

    history = loadHist(history_filename) # history from earlier runs ({} if none)
    new_history = model.fit(X_train, y_train,
                            batch_size=batch_size,
                            epochs=epochs,
                            validation_data=(X_test, y_test),
                            callbacks=callbacks_list)
    
    history = appendHist(history, new_history.history)
    

    Whenever you want, history = loadHist(history_filename) gets your history back.
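
    Since the JSON file holds one entry per completed epoch, its length tells you where to pick up after a crash. A sketch, assuming every metric list has the same length and the saved model checkpoint matches the last recorded epoch; `resume_epoch` is a hypothetical helper, while `initial_epoch` is the standard `model.fit` argument for the epoch to start from.

```python
import json
import os


def resume_epoch(history_filename):
    """Number of epochs already recorded, i.e. the 0-based epoch to resume at."""
    if not os.path.exists(history_filename):
        return 0
    with open(history_filename, "r", encoding="utf-8") as f:
        history = json.load(f)
    # Every metric gets one value per completed epoch; any key's list will do.
    return len(next(iter(history.values()), []))


# Hypothetical usage (needs a model, data and the callbacks from above):
# start = resume_epoch(history_filename)
# model = keras.models.load_model(model_filename)  # reload weights if start > 0
# model.fit(X_train, y_train, epochs=epochs, initial_epoch=start,
#           validation_data=(X_test, y_test), callbacks=callbacks_list)
```

    With `initial_epoch` set, `epochs` is the final epoch index, not the number of additional epochs to run.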

    The awkward part comes from JSON and the lists: I couldn't get it to work without converting the values by iterating over them. Anyway, I know this works because I've been cranking on it for days now. The pickle.dump answer at https://stackoverflow.com/a/44674337/852795 might be better, but I'm not familiar with it. If I missed anything here or you can't get it to work, let me know.
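
    For comparison, here is a minimal sketch of the general pickle route: pickle writes the `history.history` dict, NumPy scalars included, with no conversion step. It illustrates the technique rather than reproducing the linked answer; the function names and sample values are hypothetical.

```python
import os
import pickle
import tempfile


def save_hist_pickle(path, history_dict):
    """Write a history dict (e.g. model.fit(...).history) as a binary pickle."""
    with open(path, "wb") as f:
        pickle.dump(history_dict, f)


def load_hist_pickle(path):
    """Read a pickled history dict back."""
    with open(path, "rb") as f:
        return pickle.load(f)


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "history.pkl")
    save_hist_pickle(path, {"loss": [0.9, 0.5], "val_loss": [1.0, 0.7]})
    print(load_hist_pickle(path))
```

    The trade-off is that a .pkl file is binary and Python-specific, whereas the JSON file can be opened in any editor or tool.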
