Use a generator for Keras model.fit_generator

Backend · Open · 4 answers · 2198 views
情话喂你 2020-12-01 11:08

I originally tried to use generator syntax when writing a custom generator for training a Keras model. So I yielded from __next__. How

4 Answers
  •  余生分开走
    2020-12-01 11:42

    I would like to upgrade Vaasha's code to TensorFlow 2.x for better training efficiency and easier data processing. This is particularly useful for image processing.

    Process the data either with a generator function, as in Vaasha's example above, or with the tf.data.Dataset API. The latter approach is especially useful for datasets that ship with metadata; for example, MNIST can be loaded and processed in a few statements.

    import tensorflow as tf  # Ensure that TensorFlow 2.x is used
    import tensorflow_datasets as tfds  # Needed for bundled datasets such as MNIST or CIFAR-10

    # Eager execution is on by default in TensorFlow 2.x; this call is only
    # required when running under TF1 compatibility mode.
    # tf.compat.v1.enable_eager_execution()

    mnist_train = tfds.load(name="mnist", split="train")
    

    Use tfds.load to load the dataset. Once the data is loaded, process it as needed (for example, converting categorical variables, resizing images, etc.).
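A minimal sketch of the tf.data.Dataset processing step described above. To keep it self-contained (no download), it uses randomly generated stand-in images rather than the tfds-loaded MNIST split; the image shape and the cast/scale/shuffle/batch pipeline are the assumptions here.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the tfds-loaded images: random uint8 "images".
images = np.random.randint(0, 256, size=(100, 28, 28, 1), dtype=np.uint8)
labels = np.random.randint(0, 10, size=(100,))

def preprocess(image, label):
    # Cast to float and scale to [0, 1]; resizing would go here as well.
    image = tf.cast(image, tf.float32) / 255.0
    return image, label

dataset = (tf.data.Dataset.from_tensor_slices((images, labels))
           .map(preprocess)
           .shuffle(buffer_size=100)
           .batch(32)
           .prefetch(tf.data.AUTOTUNE))

for batch_images, batch_labels in dataset.take(1):
    print(batch_images.shape)  # (32, 28, 28, 1)
```

The same map/shuffle/batch/prefetch chain applies unchanged to the dictionary elements returned by tfds.load.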

    Now upgrade the Keras model to TensorFlow 2.x:

     model = tf.keras.Sequential()  # TensorFlow 2.x upgrade
     model.add(tf.keras.layers.Dense(12, activation='relu', input_dim=dataFrame.shape[1]))
     model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

     model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])

     # Train the model using a generator instead of the full batch
     batch_size = 8

     model.fit_generator(generator(dataFrameTrain, expectedFrameTrain, batch_size),
                         epochs=3,
                         steps_per_epoch=dataFrame.shape[0] // batch_size,
                         validation_data=generator(dataFrameTest, expectedFrameTest, batch_size * 2),
                         validation_steps=dataFrame.shape[0] // (batch_size * 2))
    

    This upgrades the model to run in TensorFlow 2.x. Note that fit_generator is deprecated in recent TensorFlow 2 releases; model.fit accepts Python generators and tf.data datasets directly.
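Since the original dataFrameTrain/expectedFrameTrain variables are not defined in this answer, the runnable sketch below substitutes random dummy arrays and a simple yield-based generator to show the same training pattern with model.fit in TensorFlow 2.x; the generator name, data shapes, and epoch count are all assumptions for illustration.

```python
import numpy as np
import tensorflow as tf

def batch_generator(features, targets, batch_size):
    # Simple yield-based generator, looping indefinitely as Keras expects.
    n = features.shape[0]
    while True:
        for start in range(0, n, batch_size):
            yield features[start:start + batch_size], targets[start:start + batch_size]

# Hypothetical dummy data standing in for dataFrameTrain / expectedFrameTrain.
X = np.random.rand(64, 10).astype("float32")
y = np.random.randint(0, 2, size=(64, 1)).astype("float32")

model = tf.keras.Sequential([
    tf.keras.layers.Dense(12, activation="relu", input_dim=X.shape[1]),
    tf.keras.layers.Dense(1, activation="sigmoid"),
])
model.compile(loss="binary_crossentropy", optimizer="adadelta", metrics=["accuracy"])

batch_size = 8
history = model.fit(batch_generator(X, y, batch_size),
                    epochs=1,
                    steps_per_epoch=X.shape[0] // batch_size,
                    verbose=0)
```

With an infinite generator, steps_per_epoch must be set so Keras knows where each epoch ends; a finite tf.data.Dataset can omit it.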
