How do you send arguments to a generator function using tf.data.Dataset.from_generator()?

忘掉有多难 2021-02-20 04:40

I would like to create several tf.data.Dataset objects using the from_generator() function. I would like to send an argument to the generator function (…

2 Answers
  • 2021-02-20 05:26

    from_generator() calls the generator with no arguments (unless you use its args parameter), so you need to define a new function based on raw_data_gen that takes no arguments. The lambda keyword is a concise way to do this:

    training_dataset = tf.data.Dataset.from_generator(
        lambda: raw_data_gen(train_val_or_test=1),  # zero-argument wrapper around raw_data_gen
        (tf.float32, tf.uint8), ([None, 1], [None]))
    ...
    

    Now from_generator receives a function that takes no arguments but simply calls raw_data_gen with the argument set to 1. Use the same pattern for the validation and test sets, passing 2 and 3 respectively.
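A minimal runnable sketch of the lambda approach for all three splits. The question's own generator is not shown, so raw_data_gen below is a hypothetical stand-in that yields (features, labels) pairs with shapes [n, 1] and [n]; the datasets dict is also just for illustration. One caveat when building the datasets in a loop: a bare lambda: raw_data_gen(split) captures the variable split, not its current value, and from_generator only calls the lambda when the dataset is iterated, so every dataset would see the last split. Binding the value as a default argument (split=split) avoids this, and from_generator can still call the lambda with no arguments.

```python
import numpy as np
import tensorflow as tf


def raw_data_gen(train_val_or_test):
    """Hypothetical stand-in for the question's generator.

    Yields (features, labels) batches; the split id is written into the
    data so it is easy to see which split a dataset was built from.
    """
    for n in (2, 3):  # two variable-length batches
        features = np.full((n, 1), float(train_val_or_test), dtype=np.float32)
        labels = np.full((n,), train_val_or_test, dtype=np.uint8)
        yield features, labels


datasets = {}
for split in (1, 2, 3):  # 1 = train, 2 = validation, 3 = test
    datasets[split] = tf.data.Dataset.from_generator(
        # split=split freezes the current loop value inside the lambda;
        # from_generator still calls it with no arguments.
        lambda split=split: raw_data_gen(train_val_or_test=split),
        (tf.float32, tf.uint8),
        ([None, 1], [None]),
    )

for split, ds in datasets.items():
    for features, labels in ds:
        print(split, features.shape, labels.numpy())
```

TensorFlow 2.4 and later deprecate output_types and output_shapes in favor of a single output_signature argument built from tf.TensorSpec objects; the lambda technique works the same way with either form.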

  • 2021-02-20 05:26

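    For TensorFlow 2.4, pass the argument through the args parameter instead. args must be a tuple, so a single argument needs a trailing comma: (1,) is a one-element tuple, while (1) is just the integer 1.

    training_dataset = tf.data.Dataset.from_generator(
         raw_data_gen,
         args=(1,),  # each tuple element is passed to raw_data_gen positionally
         output_types=(tf.float32, tf.uint8),
         output_shapes=([None, 1], [None]))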
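A minimal sketch of the args route, again with a hypothetical raw_data_gen, plus a hypothetical make_dataset helper for the three splits. Per the from_generator documentation, the elements of args are converted to tensors and passed to the generator as NumPy values rather than the original Python objects (a string argument, for example, arrives as bytes), so the sketch converts the split id explicitly with int() inside the generator.

```python
import numpy as np
import tensorflow as tf


def raw_data_gen(train_val_or_test):
    """Hypothetical generator; the split id arrives as a NumPy value."""
    split = int(train_val_or_test)  # convert the NumPy value to a Python int
    for n in (2, 3):  # two variable-length batches
        yield (np.full((n, 1), split, dtype=np.float32),
               np.full((n,), split, dtype=np.uint8))


def make_dataset(split):
    """Hypothetical helper: one dataset per split id (1, 2 or 3)."""
    return tf.data.Dataset.from_generator(
        raw_data_gen,
        args=(split,),  # one-element tuple -> one positional argument
        output_types=(tf.float32, tf.uint8),
        output_shapes=([None, 1], [None]),
    )


train_ds, val_ds, test_ds = (make_dataset(s) for s in (1, 2, 3))

for features, labels in val_ds:
    print(features.shape, labels.numpy())
```

Because the value in args is captured when from_generator is called, building the datasets in a loop or comprehension this way binds each split correctly.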
    