How to calculate MobileNet FLOPs in Keras

借酒劲吻你 2020-12-16 18:14

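    import tensorflow as tf  # TF 1.x API (tf.RunMetadata, tf.Session)
    from keras import backend as K  # assumed: standalone Keras backend

    run_meta = tf.RunMetadata()
    with tf.Session(graph=tf.Graph()) as sess:
        K.set_session(sess)

        with tf.dev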


        
4 answers
  •  清歌不尽
    2020-12-16 18:33

    This works for me in TF 2.1. It loads the .h5 model into a graph and runs the tf.compat.v1 profiler on that graph:

        import tensorflow as tf


        def get_flops(model_h5_path):
            # Build the model in a fresh graph so that repeated calls
            # do not also count the ops of previously loaded models.
            graph = tf.Graph()
            session = tf.compat.v1.Session(graph=graph)

            with graph.as_default():
                with session.as_default():
                    model = tf.keras.models.load_model(model_h5_path)

                    run_meta = tf.compat.v1.RunMetadata()
                    # Profile floating-point operations only.
                    opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()

                    # Optional: save printed results to file (needs `import os, tempfile`)
                    # flops_log_path = os.path.join(tempfile.gettempdir(), 'tf_flops_log.txt')
                    # opts['output'] = 'file:outfile={}'.format(flops_log_path)

                    # We use the Keras session graph in the call to the profiler.
                    flops = tf.compat.v1.profiler.profile(graph=graph,
                                                          run_meta=run_meta, cmd='op', options=opts)

                    return flops.total_float_ops
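As a rough sanity check on whatever the profiler reports, the multiply-adds of MobileNet v1 (width multiplier 1.0, 224x224 input) can also be counted by hand. Following the MobileNet paper, a k x k standard convolution with M input channels, N output channels and a D_F x D_F output costs k·k·M·N·D_F·D_F multiply-adds, while a depthwise separable convolution costs k·k·M·D_F·D_F + M·N·D_F·D_F. The sketch below applies those formulas to the layer list from the paper's architecture table, with the last depthwise block at stride 1 as in the Keras implementation. The function names and the DW_BLOCKS encoding are hypothetical, not part of Keras or TensorFlow. It ignores batch norm, ReLU6, pooling and biases and counts one multiply-add as one operation, so the profiler's total_float_ops will not match it exactly (the profiler typically counts the multiply and the add separately, i.e. roughly twice as many).

```python
# Hand count of multiply-adds (MACs) for MobileNet v1, alpha = 1.0, 224x224 input.
# Layer list follows the MobileNet paper's architecture table;
# BN, ReLU6, pooling and biases are ignored.


def conv_macs(k, c_in, c_out, out_hw):
    """Standard k x k convolution: k*k*M*N*D_F*D_F multiply-adds."""
    return k * k * c_in * c_out * out_hw * out_hw


def dw_separable_macs(k, c_in, c_out, out_hw):
    """Depthwise k x k (k*k*M*D_F*D_F) plus pointwise 1x1 (M*N*D_F*D_F)."""
    depthwise = k * k * c_in * out_hw * out_hw
    pointwise = c_in * c_out * out_hw * out_hw
    return depthwise + pointwise


# Hypothetical encoding: (stride, output channels) of the 13 depthwise separable blocks.
DW_BLOCKS = [
    (1, 64), (2, 128), (1, 128), (2, 256), (1, 256), (2, 512),
    (1, 512), (1, 512), (1, 512), (1, 512), (1, 512),
    (2, 1024), (1, 1024),
]


def mobilenet_v1_macs(input_hw=224, num_classes=1000):
    """Hypothetical helper: total multiply-adds of MobileNet v1 (alpha = 1.0)."""
    # First layer: 3x3 standard conv, stride 2, 3 -> 32 channels.
    hw = -(-input_hw // 2)  # ceil division, matching "same"-padding output size
    total = conv_macs(3, 3, 32, hw)
    c_in = 32
    for stride, c_out in DW_BLOCKS:
        hw = -(-hw // stride)
        total += dw_separable_macs(3, c_in, c_out, hw)
        c_in = c_out
    # Global average pooling (not counted), then a fully connected classifier.
    total += c_in * num_classes
    return total


if __name__ == "__main__":
    macs = mobilenet_v1_macs()
    print(f"{macs:,} multiply-adds (~{macs / 1e6:.0f}M)")
```

Running it prints 568,740,352 multiply-adds (~569M), in line with the 569 million Mult-Adds the paper reports for 1.0 MobileNet-224; about 95% of that comes from the 1x1 pointwise convolutions.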
    
