TensorFlow: Is there a way to measure FLOPS for a model?

温柔的废话 2020-12-07 21:05

The closest example I could find is in this issue: https://github.com/tensorflow/tensorflow/issues/899

With this minimal reproducible code:

import

3 Answers
  •  慢半拍i (original poster)
    2020-12-07 22:07

    A little late, but maybe it will help future visitors. For your example, I successfully tested the following snippet:

    import tensorflow as tf  # TensorFlow 1.x API; under 2.x these live in tf.compat.v1
    
    g = tf.Graph()
    run_meta = tf.RunMetadata()
    with g.as_default():
        A = tf.Variable(tf.random_normal([25, 16]))
        B = tf.Variable(tf.random_normal([16, 9]))
        C = tf.matmul(A, B)  # shape=[25, 9]
    
        # Count the floating-point operations of every op in the graph, grouped by op type
        opts = tf.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.profiler.profile(g, run_meta=run_meta, cmd='op', options=opts)
        if flops is not None:
            print('Flops should be ~', 2 * 25 * 16 * 9)
            print('25 x 25 x 9 would be', 2 * 25 * 25 * 9)  # uses the first dim in place of the inner one
            print('TF stats gives', flops.total_float_ops)
    
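    As a sanity check, the expected count can be worked out by hand. The sketch below is plain Python (no TensorFlow needed) and assumes the convention the snippet above already uses: multiplying an m x k matrix by a k x n matrix costs 2*m*k*n FLOPs (one multiply and one add per inner-product term). The helper name `matmul_flops` is hypothetical. The initializer line is an assumption, not something the profiler documents: `tf.random_normal` scales and shifts its samples element-wise, and the profiler may count those ops in its total as well.

```python
def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs for an (m x k) @ (k x n) product, counting multiply and add separately."""
    return 2 * m * k * n


# A is [25, 16] and B is [16, 9], so C = A @ B is [25, 9].
expected = matmul_flops(25, 16, 9)          # 7200
wrong_inner_dim = matmul_flops(25, 25, 9)   # 11250: first dim used instead of the inner one

# Assumption: tf.random_normal does one Mul and one Add per element of each
# initializer, and the profiler counts those ops alongside the MatMul.
init_flops = 2 * (25 * 16) + 2 * (16 * 9)   # 800 + 288 = 1088

print("matmul only:", expected)
print("25 x 25 x 9 would be:", wrong_inner_dim)
print("matmul + initializers:", expected + init_flops)
```

    If the profiler reports a total closer to the third number than the first, the difference is most likely the initializer ops rather than an error in the MatMul count.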

    The profiler can also be combined with Keras, as in the following snippet:

    import tensorflow as tf
    import keras.backend as K
    from keras.applications.mobilenet import MobileNet
    
    run_meta = tf.RunMetadata()
    with tf.Session(graph=tf.Graph()) as sess:
        K.set_session(sess)  # tell Keras to use this session (and its graph)
        net = MobileNet(alpha=.75, input_tensor=tf.placeholder('float32', shape=(1, 32, 32, 3)))
    
        # Total floating-point operations in the graph
        opts = tf.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts)
    
        # Total number of trainable parameters
        opts = tf.profiler.ProfileOptionBuilder.trainable_variables_parameter()
        params = tf.profiler.profile(sess.graph, run_meta=run_meta, cmd='op', options=opts)
    
        print("{:,} --- {:,}".format(flops.total_float_ops, params.total_parameters))
    
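    To get a feel for what `total_float_ops` and `total_parameters` measure for a network like MobileNet, here is a rough hand count for one depthwise-separable convolution, the building block MobileNet stacks. This is a back-of-the-envelope sketch, not what the profiler does internally: it assumes stride 1, 'same' padding, no biases, and the 2-FLOPs-per-multiply-add convention from the first snippet, and it ignores batch normalization and activations. The function names and the example layer shape are hypothetical.

```python
def conv2d_cost(h, w, k, c_in, c_out):
    """Params and FLOPs of a standard k x k conv (stride 1, 'same' padding, no bias)."""
    params = k * k * c_in * c_out
    flops = 2 * h * w * params  # every weight is used once per output position
    return params, flops


def depthwise_separable_cost(h, w, k, c_in, c_out):
    """Params and FLOPs of a k x k depthwise conv followed by a 1 x 1 pointwise conv."""
    dw_params = k * k * c_in   # one k x k filter per input channel
    pw_params = c_in * c_out   # 1 x 1 conv that mixes channels
    params = dw_params + pw_params
    flops = 2 * h * w * params  # same stride-1 / 'same'-padding assumption as above
    return params, flops


# Hypothetical layer: 16 x 16 feature map, 3 x 3 kernel, 48 -> 96 channels.
std = conv2d_cost(16, 16, 3, 48, 96)
sep = depthwise_separable_cost(16, 16, 3, 48, 96)
print("standard conv:       params={:,} flops={:,}".format(*std))
print("depthwise separable: params={:,} flops={:,}".format(*sep))
```

    Under these assumptions the separable block needs 1/c_out + 1/k^2 of the standard conv's FLOPs (about 0.12 here), which is why MobileNet's FLOP count stays low for its depth.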

    I hope this helps!
