TensorFlow 2: How to use *args in tf.function?

Submitted by 笑着哭i on 2021-02-11 13:40:56

Question


Update:

I did a bit more testing and can't reproduce the behaviour with the following code:

import tensorflow as tf
import numpy as np

@tf.function
def tf_being_unpythonic(an_input, another_input):
    return an_input + another_input

@tf.function
def example(*inputs, other_args=True):
    return tf_being_unpythonic(*inputs)

class TestClass(tf.keras.Model):
    def __init__(self, a, b):
        super().__init__()
        self.a = a
        self.b = b

    @tf.function
    def call(self, *inps, some_kwarg=False):
        if some_kwarg:
            return self.a(*inps)
        return self.b(*inps)

class Model(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.inps = tf.keras.layers.Flatten()
        self.hl1 = tf.keras.layers.Dense(5)
        self.hl2 = tf.keras.layers.Dense(4)
        self.out = tf.keras.layers.Dense(1)

    @tf.function
    def call(self, observation):
        x = self.inps(observation)
        x = self.hl1(x)
        x = self.hl2(x)
        return self.out(x)


class Model2(Model):
    def __init__(self):
        super().__init__()
        self.prein = tf.keras.layers.Concatenate()

    @tf.function
    def call(self, b, c):
        x = self.prein([b, c])
        return super().call(x)

am = Model()
pm = Model2()
test = TestClass(am, pm)

a = np.random.normal(size=(1, 2, 3))
b = np.random.normal(size=(1, 2, 4))

test(a, some_kwarg=True)  # one input, routed to self.a (Model)
test(a, b)                # two inputs, routed to self.b (Model2)

So it's probably a bug somewhere else in my code.

For reference, the original method that raised the error was:

@tf.function
def call(self, *inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    if target:
        return self._target_network(*inp, training)
    return self._network(*inp, training)

I get:

ValueError: Input 0 of layer flatten is incompatible with the layer: : expected min_ndim=1, found ndim=0. Full shape received: []

But print(inp) gives:

(<tf.Tensor 'inp_0:0' shape=(1, 3) dtype=float32>,) 

I've since changed the code (it was uncommitted toy code), so I can't investigate further. I'll leave the question here in case anyone else runs into the same issue.
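One possible explanation, which the thread never confirms: the call method above forwards training positionally (self._network(*inp, training)). If the wrapped network itself gathers its inputs with *args, that flag arrives as one more input, and a Python bool converted to a tensor has shape [], which would be consistent with the "found ndim=0" error from Flatten. The sketch below shows only the Python argument routing. StubNetwork, call and call_fixed are hypothetical stand-ins (not Keras classes), so it runs without TensorFlow.

```python
# Hypothetical stand-in for a network whose call() collects its inputs
# with *args; this is plain Python, not a Keras model.
class StubNetwork:
    def __call__(self, *inputs, training=False):
        # Report where each argument landed.
        return {"inputs": inputs, "training": training}


def call(network, *inp, training=False):
    # Mirrors the question: `training` is forwarded positionally,
    # so it is appended to the callee's *inputs.
    return network(*inp, training)


def call_fixed(network, *inp, training=False):
    # Forwarding `training` by keyword keeps it out of *inputs.
    return network(*inp, training=training)


net = StubNetwork()
print(call(net, "obs", training=True))
# {'inputs': ('obs', True), 'training': False}
print(call_fixed(net, "obs", training=True))
# {'inputs': ('obs',), 'training': True}
```

If this is the cause, passing training=training by keyword in both branches of the question's call method would avoid it.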


Answer 1:


I don't think using a *args construct is good practice for a tf.function. Most TF functions that accept a variable number of inputs take them as a single list or tuple instead, e.g. tf.concat(values, axis) or tf.add_n(inputs).

So, you can rewrite your function signature as:

def call(self, inputs, target=False, training=False):
    ...

and call it with:

instance.call((i1, i2, i3), [...])
# instead of instance.call(i1, i2, i3, [...])
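A minimal sketch of this tuple-based signature, written as a standalone function (so there is no self), assuming TensorFlow 2.x. The summation body is a hypothetical stand-in for the real networks; the point is that all inputs travel as one tuple argument, and tf.function traces on that tuple's structure plus each tensor's shape and dtype.

```python
import tensorflow as tf


@tf.function
def call(inputs, target=False, training=False):
    # `inputs` is a single tuple argument instead of *args.
    if not inputs:
        raise ValueError("call requires some input")
    total = inputs[0]
    for x in inputs[1:]:
        total = total + x  # hypothetical stand-in for the real computation
    # `target` is a Python bool, so this choice is made at trace time.
    return -total if target else total


i1, i2, i3 = tf.constant(1.0), tf.constant(2.0), tf.constant(3.0)
print(call((i1, i2, i3)))               # tf.Tensor(6.0, shape=(), dtype=float32)
print(call((i1, i2, i3), target=True))  # tf.Tensor(-6.0, shape=(), dtype=float32)
```

Because target and training are Python bools here, each combination gets its own trace; that is usually fine for flags that only take a couple of values.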

Edit

By the way, I don't get any error when using tf.function with *args:

import tensorflow as tf

@tf.function
def call(*inp, target=False, training=False):
    if not len(inp):
        raise ValueError("Call requires some input")
    return inp[0]

def main():
    print(call(1))
    print(call(2, 2))
    print(call(3, 3, 3))


if __name__ == '__main__':
    main()

Output:

tf.Tensor(1, shape=(), dtype=int32)
tf.Tensor(2, shape=(), dtype=int32)
tf.Tensor(3, shape=(), dtype=int32)

So please provide more information about what you are trying to do and where exactly the error occurs.
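The example above also hints at how tf.function treats *args: it traces a new graph for each distinct argument structure, and Python scalars such as 1 or 2 become part of that cache key by value, while tensors are keyed by shape and dtype. A minimal sketch, assuming TensorFlow 2.x, that counts traces with a Python side effect (which runs only while tracing); the traces list and the first function are hypothetical names.

```python
import tensorflow as tf

traces = []  # appended to only while tf.function is tracing


@tf.function
def first(*inp):
    traces.append(len(inp))  # Python side effect: runs at trace time, not per call
    return inp[0]


first(tf.constant(1.0))                    # traces for one float32 scalar
first(tf.constant(5.0))                    # same signature: reuses that trace
first(tf.constant(1.0), tf.constant(2.0))  # two tensors: new trace
print(traces)  # [1, 2]

first(1)  # Python ints are keyed by value...
first(2)  # ...so each distinct value triggers another trace
print(traces)  # [1, 2, 1, 1]
```

So *args itself is not a problem, but passing many different Python scalars through it can cause frequent retracing; wrapping them in tensors avoids that.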




Answer 2:


This may have been a bug that was fixed recently. *args and **kwargs should work fine in a tf.function.
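A minimal sketch of *args and **kwargs together in a tf.function, assuming a recent TensorFlow 2.x and no input_signature; scaled_sum and its scale option are hypothetical names chosen for illustration.

```python
import tensorflow as tf


@tf.function
def scaled_sum(*inputs, **options):
    # *inputs collects the positional tensors (same shape and dtype);
    # **options collects keyword settings such as a hypothetical "scale".
    total = tf.add_n(list(inputs))
    return total * options.get("scale", 1.0)


x0, x1 = tf.constant(1.0), tf.constant(2.0)
print(scaled_sum(x0, x1))             # value 3.0
print(scaled_sum(x0, x1, scale=2.0))  # value 6.0
```

Since scale arrives as a Python float, each distinct value triggers a new trace; if it changes often, passing it as a tensor is the cheaper choice.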



Source: https://stackoverflow.com/questions/59167107/tensorflow-2-how-to-use-args-in-tf-function
