Tensorflow 2.0 model using tf.function very slow and is recompiling every time the train count changes. Eager runs about 4x faster

Submitted on 2019-12-03 07:37:47

Question


I have models built from uncompiled keras code and am trying to run them through a custom training loop.

The TF 2.0 eager (by default) code runs about 30s on a CPU (laptop). When I create a keras model with wrapped tf.function call methods, it is running much, much slower and appears to take a very long time to start, particularly the "first" time.

For example, in the tf.function code the initial train on 10 samples takes 40 s, and the follow-up on 10 samples takes 2 s.

On 20 samples, the initial takes 50 s and the follow-up takes 4 s.

The first train on 1 sample takes 2 s and the follow-up takes 200 ms.

So it looks like each call of train is creating a new graph whose complexity scales with the train count!?

I am just doing something like this:

@tf.function
def train(n=10):
    step = 0
    loss = 0.0
    accuracy = 0.0
    for i in range(n):
        step += 1
        d, dd, l = train_one_step(model, opt, data)
        tf.print(dd)
        with tf.name_scope('train'):
            for k in dd:
                tf.summary.scalar(k, dd[k], step=step)
        if tf.equal(step % 10, 0):
            tf.print(dd)
    d.update(dd)
    return d

Where the model is a keras.model.Model with a @tf.function-decorated call method, as per the examples.


Answer 1:


I analyzed this behavior of @tf.function here: "Using a Python native type".

In short: by design, tf.function does not automatically box Python native types into tf.Tensor objects with a well-defined dtype.

If your function accepts a tf.Tensor object, the function is analyzed on the first call and the graph is built and associated with it. On every subsequent call, if the dtype of the tf.Tensor argument matches, the graph is reused.

But when a Python native type is used, a new graph is built every time the function is invoked with a different value.
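
This retracing is easy to observe with a toy function (a minimal sketch; f is just an illustrative name). A plain Python print statement executes only while the function is being traced, so it reveals every graph rebuild:

import tensorflow as tf

@tf.function
def f(x):
    print("tracing f")  # plain Python print: runs only during tracing
    return x * 2

f(tf.constant(1))  # first call: traces the graph, prints "tracing f"
f(tf.constant(2))  # same dtype and shape: graph reused, nothing printed
f(10)              # Python int: new trace, prints "tracing f"
f(20)              # different Python value: traces yet again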

In short: design your code to use tf.Tensor everywhere instead of Python variables if you plan to use @tf.function.
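
Applied to the code in the question, that means passing n as a tf.Tensor and looping with tf.range, so the loop is compiled once as a single tf.while_loop instead of being retraced for every value of n. A minimal sketch, reusing the question's train_one_step, model, opt, and data, and assuming for brevity that train_one_step returns just the loss:

@tf.function
def train(n):
    # n is a tf.Tensor, so the trace is keyed on its dtype, not its value
    step = tf.constant(0)
    loss = tf.constant(0.0)
    for _ in tf.range(n):  # tensor loop: becomes one tf.while_loop in the graph
        step += 1
        loss = train_one_step(model, opt, data)
    return loss

train(tf.constant(10))  # first call traces and builds the graph
train(tf.constant(20))  # same dtype: the existing graph is reused, no retrace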

tf.function is not a wrapper that magically accelerates a function that works well in eager mode; it is a wrapper that requires you to design the eager function (body, input parameters, dtypes) with an understanding of what will happen once the graph is created, in order to get real speed-ups.



Source: https://stackoverflow.com/questions/55711115/tensorflow-2-0-model-using-tf-function-very-slow-and-is-recompiling-every-time-t
