An official tutorial on @tf.function says:
"To get peak performance and to make your model deployable anywhere, use tf.function to make graphs out of your programs."
TLDR: It depends on your function and whether you are in production or development. Don't use tf.function
if you want to be able to debug your function easily, or if it falls under the limitations of AutoGraph or tf.v1 code compatibility.
I would highly recommend watching the Inside TensorFlow talks about AutoGraph and Functions, not Sessions.
In the following I'll break down the reasons, which are all taken from information made available online by Google.
In general, the tf.function decorator causes a function to be compiled as a callable that executes a TensorFlow graph. This entails converting the Python code through AutoGraph where required (including any functions called from the decorated function), and tracing and executing the generated graph code.
There is detailed information available on the design ideas behind this.
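To make the basic mechanics concrete, here is a minimal sketch (the function and values are illustrative, not from the tutorial): the first call traces the Python function and builds a graph, and subsequent calls with the same input signature execute the cached graph:

import tensorflow as tf

@tf.function
def scaled_sum(x, y):
    # Traced on the first call; later calls with the same
    # input signature reuse the cached graph.
    return tf.reduce_sum(x) * y

print(scaled_sum(tf.constant([1.0, 2.0]), tf.constant(3.0)))
# => tf.Tensor(9.0, shape=(), dtype=float32)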
Benefits of decorating a function with tf.function

Using AutoGraph via tf.function decoration

If you want to use AutoGraph, using tf.function is highly recommended over calling AutoGraph directly. Reasons for this include: automatic control dependencies, it is required for some APIs, more caching, and exception helpers (Source).
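As an illustrative sketch of what AutoGraph does under tf.function (the function below is an assumption for demonstration, not from the sources): data-dependent Python control flow is rewritten into graph control flow, so a plain while loop over tensors becomes a tf.while_loop and a tensor-dependent if becomes a tf.cond:

import tensorflow as tf

@tf.function
def collatz_steps(n):
    # AutoGraph rewrites the tensor-dependent while/if below
    # into tf.while_loop/tf.cond in the generated graph.
    steps = tf.constant(0)
    while n > 1:
        if n % 2 == 0:
            n = n // 2
        else:
            n = 3 * n + 1
        steps += 1
    return steps

print(collatz_steps(tf.constant(6)))
# => tf.Tensor(8, shape=(), dtype=int32)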
Drawbacks of decorating a function with tf.function

Using AutoGraph via tf.function decoration

Detailed information on AutoGraph limitations is available.
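One practical mitigation for the debugging drawback mentioned in the TLDR (a sketch, assuming TensorFlow 2.3 or later, where tf.config.run_functions_eagerly is available): you can temporarily force decorated functions to run eagerly, debug with breakpoints or prints as usual, and then switch graph execution back on:

import tensorflow as tf

tf.config.run_functions_eagerly(True)  # disable graph compilation for debugging

@tf.function
def f(x):
    # Runs eagerly now, so Python-level debugging works as usual.
    return x * 2

print(f(tf.constant(3)))
# => tf.Tensor(6, shape=(), dtype=int32)

tf.config.run_functions_eagerly(False)  # restore normal tf.function behavior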
It is not allowed to create variables more than once in tf.function, but this is subject to change as tf.v1 code is phased out (Source). An example is the variable v in the following code:
@tf.function
def f(x):
    v = tf.Variable(1)
    return tf.add(x, v)

f(tf.constant(2))
# => ValueError: tf.function-decorated function tried to create variables on non-first call.
In the following code, this is mitigated by making sure that self.v
is only created once:
class C(object):
    def __init__(self):
        self.v = None

    @tf.function
    def f(self, x):
        if self.v is None:
            self.v = tf.Variable(1)
        return tf.add(x, self.v)

c = C()
print(c.f(tf.constant(2)))
# => tf.Tensor(3, shape=(), dtype=int32)
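This works because variable creation is only permitted during the first trace; by caching the variable on the object, later traces and calls reuse self.v instead of trying to create it again.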
Changes such as the mutation of self.a in this example can't be hidden inside a called function, which leads to an error, since cross-function analysis is not done (yet) (Source):
class C(object):
    def change_state(self):
        self.a += 1

    @tf.function
    def f(self):
        self.a = tf.constant(0)
        if tf.constant(True):
            self.change_state()  # Mutation of self.a is hidden
        tf.print(self.a)

x = C()
x.f()
# => InaccessibleTensorError: The tensor 'Tensor("add:0", shape=(), dtype=int32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=cond_true_5, id=5477800528); accessed from: FuncGraph(name=f, id=5476093776).
Changes in plain sight are no problem:
class C(object):
    @tf.function
    def f(self):
        self.a = tf.constant(0)
        if tf.constant(True):
            self.a += 1  # Mutation of self.a is in plain sight
        tf.print(self.a)

x = C()
x.f()
# => 1
The following if statement leads to an error, because a value for the else branch also needs to be defined for TF control flow:
@tf.function
def f(a, b):
    if tf.greater(a, b):
        return tf.constant(1)
    # if a <= b, the function would implicitly return None

x = f(tf.constant(3), tf.constant(2))
# => ValueError: A value must also be returned from the else branch. If a value is returned from one branch of a conditional a value must be returned from all branches.
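A version that defines a value for the else branch as well works as expected (a minimal sketch of one possible fix):

@tf.function
def f(a, b):
    if tf.greater(a, b):
        return tf.constant(1)
    else:
        return tf.constant(0)  # a value is now defined for the else branch

x = f(tf.constant(3), tf.constant(2))
print(x)
# => tf.Tensor(1, shape=(), dtype=int32)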