How to create an ensemble in TensorFlow?

前端 未结 2 1020
一整个雨季
一整个雨季 2020-12-14 02:58

I am trying to create an ensemble of many trained models. All models have the same graph and differ only in their weights. I am creating the model graph using tf.get_variable …

2条回答
  •  独厮守ぢ
    2020-12-14 03:50

    This requires a few hacks. First, let us save a few simple models:

    #! /usr/bin/env python
    # -*- coding: utf-8 -*-
    
    import argparse
    import tensorflow as tf
    
    
    def build_graph(init_val=0.0):
        """Build the toy model ``y = x + w``.

        Args:
            init_val: value used to initialize the variable ``w``.

        Returns:
            Tuple ``(x, y)``: the float32 input placeholder and the output tensor.
        """
        inp = tf.placeholder(tf.float32)
        weight = tf.get_variable('w', initializer=init_val)
        return inp, inp + weight
    
    
    if __name__ == '__main__':
        # Each run saves one checkpoint whose only variable is `w`, initialized
        # to --init, so the saved model computes y = x + init.
        parser = argparse.ArgumentParser()
        # Both options are mandatory: without --init, build_graph() would get
        # initializer=None, and without --path there is nowhere to save.
        parser.add_argument('--init', help='initial value of the weight w', type=float, required=True)
        parser.add_argument('--path', help='checkpoint path the model is saved to', type=str, required=True)
        args = parser.parse_args()

        x1, y1 = build_graph(args.init)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            print(sess.run(y1, {x1: 10}))  # outputs: 10 + init

            save_path = saver.save(sess, args.path)
            print("Model saved in path: %s" % save_path)

    # python ensemble.py --init 1 --path ./models/model1.chpt
    # python ensemble.py --init 2 --path ./models/model2.chpt
    # python ensemble.py --init 3 --path ./models/model3.chpt
    

    These models produce outputs of "10 + i", where i = 1, 2, 3. Note that running this script several times creates, runs and saves the same graph structure each time; only the value of `w` differs. Loading these values and restoring each graph individually is standard practice and can be done with:

    #! /usr/bin/env python
    # -*- coding: utf-8 -*-
    
    import argparse
    import tensorflow as tf
    
    
    def build_graph(init_val=0.0):
        """Recreate the toy model ``y = x + w`` with the same variable name ``w``.

        Args:
            init_val: initializer value for ``w``.

        Returns:
            ``(x, y)`` -- the float32 input placeholder and the sum ``x + w``.
        """
        placeholder = tf.placeholder(tf.float32)
        var_w = tf.get_variable('w', initializer=init_val)
        total = placeholder + var_w
        return placeholder, total
    
    
    if __name__ == '__main__':
        parser = argparse.ArgumentParser()
        # Mandatory: saver.restore() below cannot work without a checkpoint path.
        parser.add_argument('--path', help='checkpoint path the model is restored from', type=str, required=True)
        args = parser.parse_args()

        # -5. is only a throwaway initial value; saver.restore() below
        # overwrites `w` with the value stored in the checkpoint.
        x1, y1 = build_graph(-5.)

        saver = tf.train.Saver()
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            saver.restore(sess, args.path)
            print("Model loaded from path: %s" % args.path)

            print(sess.run(y1, {x1: 10}))  # 10 + restored w

    # python ensemble_load.py --path ./models/model1.chpt  # gives 11
    # python ensemble_load.py --path ./models/model2.chpt  # gives 12
    # python ensemble_load.py --path ./models/model3.chpt  # gives 13
    

    Again, these produce the outputs 11, 12 and 13, as expected. Now the trick is to give each model of the ensemble its own variable scope, like this:

    def build_graph(x, init_val=0.0):
        """Attach ``y = x + w`` to an existing input tensor ``x``.

        ``w`` is created with ``tf.get_variable`` in the current variable scope,
        so calling this under different scopes yields distinct variables.

        Returns:
            ``(x, y)`` -- the input passed in, unchanged, and the output tensor.
        """
        bias = tf.get_variable('w', initializer=init_val)
        return x, x + bias
    
    
    if __name__ == '__main__':
        import numpy as np  # used for the random initial values below

        models = ['./models/model1.chpt', './models/model2.chpt', './models/model3.chpt']
        # A single input placeholder shared by every ensemble member.
        x = tf.placeholder(tf.float32)
        outputs = []
        for k, _ in enumerate(models):
            # THE VARIABLE SCOPE IS IMPORTANT: it makes each model's `w`
            # a distinct variable (model_001/w, model_002/w, ...).
            with tf.variable_scope('model_%03i' % (k + 1)):
                outputs.append(build_graph(x, -100 * np.random.rand())[1])
    

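    Hence each model lives under a different variable scope, i.e. we have the variables `model_001/w:0`, `model_002/w:0` and `model_003/w:0`. Although they belong to similar (not identical) sub-graphs, these variables are indeed different objects. Each checkpoint, however, stores its variable simply as `w`, so a plain `Saver.restore` cannot match the names. The trick is to manage two sets of variables, those of the graph under the current scope and those from the checkpoint, and copy each checkpoint value into the variable of the same name under the target scope: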

    def restore_collection(path, scopename, sess):
        """Load every variable of checkpoint *path* into variable scope *scopename*.

        The checkpoint value stored under ``name`` is copied into the graph
        variable ``<scopename>/<name>:0``. This lets several checkpoints that
        use the same variable names be restored side by side into copies of
        one sub-graph built under different variable scopes.

        Args:
            path: checkpoint to read, e.g. ``'./models/model1.chpt'``.
            scopename: variable scope the target variables were created in.
            sess: session in which the values are written.

        Raises:
            KeyError: if the checkpoint contains a variable that has no
                counterpart under *scopename* in the current graph.
        """
        # All global variables under the scope, keyed by full name ('scope/w:0').
        variables = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scopename)}
        # Open the checkpoint once instead of once per variable.
        reader = tf.contrib.framework.load_checkpoint(path)
        for var_name, _ in tf.contrib.framework.list_variables(path):
            # Expected name of the same variable under the new scope.
            target_var_name = '%s/%s:0' % (scopename, var_name)
            if target_var_name not in variables:
                raise KeyError('checkpoint variable %r from %r has no counterpart %r in the graph'
                               % (var_name, path, target_var_name))
            # Variable.load writes the value without adding a new assign op to
            # the graph on every call (unlike sess.run(v.assign(value))).
            variables[target_var_name].load(reader.get_tensor(var_name), sess)
    

    The full solution is:

    #! /usr/bin/env python
    # -*- coding: utf-8 -*-
    
    import numpy as np
    import tensorflow as tf
    
    
    def restore_collection(path, scopename, sess):
        """Load every variable of checkpoint *path* into variable scope *scopename*.

        The checkpoint value stored under ``name`` is copied into the graph
        variable ``<scopename>/<name>:0``. This lets several checkpoints that
        use the same variable names be restored side by side into copies of
        one sub-graph built under different variable scopes.

        Args:
            path: checkpoint to read, e.g. ``'./models/model1.chpt'``.
            scopename: variable scope the target variables were created in.
            sess: session in which the values are written.

        Raises:
            KeyError: if the checkpoint contains a variable that has no
                counterpart under *scopename* in the current graph.
        """
        # All global variables under the scope, keyed by full name ('scope/w:0').
        variables = {v.name: v for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scopename)}
        # Open the checkpoint once instead of once per variable.
        reader = tf.contrib.framework.load_checkpoint(path)
        for var_name, _ in tf.contrib.framework.list_variables(path):
            # Expected name of the same variable under the new scope.
            target_var_name = '%s/%s:0' % (scopename, var_name)
            if target_var_name not in variables:
                raise KeyError('checkpoint variable %r from %r has no counterpart %r in the graph'
                               % (var_name, path, target_var_name))
            # Variable.load writes the value without adding a new assign op to
            # the graph on every call (unlike sess.run(v.assign(value))).
            variables[target_var_name].load(reader.get_tensor(var_name), sess)
    
    
    def build_graph(x, init_val=0.0):
        """Add one ensemble member ``y = x + w`` on top of the shared input ``x``.

        The variable is named ``w`` inside whatever variable scope is active,
        so each call under a new scope creates a separate weight.

        Returns:
            ``(x, y)`` -- the given input and this member's output tensor.
        """
        weight = tf.get_variable('w', initializer=init_val)
        member_out = x + weight
        return x, member_out
    
    
    if __name__ == '__main__':
        models = ['./models/model1.chpt', './models/model2.chpt', './models/model3.chpt']
        # A single input placeholder shared by every ensemble member.
        x = tf.placeholder(tf.float32)
        outputs = []  # per-model output tensors, y_k = x + w_k
        W = []  # per-model variables w_k, collected so they can be inspected below
        for k, _ in enumerate(models):
            # Separate variable scope per model: model_001/w, model_002/w, ...
            with tf.variable_scope('model_%03i' % (k + 1)) as scope:
                outputs.append(build_graph(x, -100 * np.random.rand())[1])
                W.extend(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope.name))

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            print(sess.run(outputs[0], {x: 10}))  # random output -82.4929
            print(sess.run(outputs[1], {x: 10}))  # random output -63.65792
            print(sess.run(outputs[2], {x: 10}))  # random output -19.888203

            print(sess.run(W[0]))  # randomly initialize value -92.4929
            print(sess.run(W[1]))  # randomly initialize value -73.65792
            print(sess.run(W[2]))  # randomly initialize value -29.888203

            restore_collection(models[0], 'model_001', sess)  # restore all variables from different checkpoints
            restore_collection(models[1], 'model_002', sess)  # restore all variables from different checkpoints
            restore_collection(models[2], 'model_003', sess)  # restore all variables from different checkpoints

            print(sess.run(W[0]))  # old values from different checkpoints: 1.0
            print(sess.run(W[1]))  # old values from different checkpoints: 2.0
            print(sess.run(W[2]))  # old values from different checkpoints: 3.0

            print(sess.run(outputs[0], {x: 10}))  # what we expect: 11.0
            print(sess.run(outputs[1], {x: 10}))  # what we expect: 12.0
            print(sess.run(outputs[2], {x: 10}))  # what we expect: 13.0

    # python ensemble_load_all.py
    

    Now that you have a list of outputs, you can average them within TensorFlow (e.g. `tf.reduce_mean(tf.stack(outputs), axis=0)`) or combine them into some other ensemble prediction.

    edit:

    • It is much easier to store the model weights as a NumPy dictionary (`.npz`) and load those values, as in my answer here: https://stackoverflow.com/a/50181741/7443104
    • The code above only illustrates a solution. It does not include sanity checks (e.g. whether the variable really exists). A try/except might help.

提交回复
热议问题