Keras model.fit() with tf.dataset API + validation_data

北荒 2020-12-14 03:24

So I have got my Keras model to work with a tf.data.Dataset through the following code:

# Initialize batch generators (returns tf.Dataset)
batch_train = build_feat         


        
2 Answers
  • 2020-12-14 03:50

    To connect a reinitializable iterator to a Keras model, plug in an iterator that returns the x and y values together, as (x, y) pairs:

    import numpy as np
    import tensorflow as tf  # TF 1.x API (tf.Session, tf.data.Iterator)
    from tensorflow import keras

    sess = tf.Session()
    keras.backend.set_session(sess)

    x = np.random.random((5, 2))
    y = np.array([0, 1] * 3 + [1, 0] * 2).reshape(5, 2)  # one-hot encoded labels
    # Keras expects batched elements (batch size 2 chosen for illustration);
    # repeat() keeps the iterator from running out of data after the first epoch
    input_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(2).repeat()

    # Create the reinitializable iterator and its initializer
    reinitializable_iterator = tf.data.Iterator.from_structure(input_dataset.output_types, input_dataset.output_shapes)
    init_op = reinitializable_iterator.make_initializer(input_dataset)

    # Run the initializer
    sess.run(init_op)  # pass a feed_dict if you're using placeholders as input

    # Build the Keras model and plug in the iterator
    model = keras.Model(...)
    model.compile(...)
    model.fit(reinitializable_iterator, steps_per_epoch=..., ...)
    

    If you also have a validation dataset, the easiest approach is to create a separate iterator and plug it into the validation_data parameter. Make sure to set steps_per_epoch and validation_steps, since Keras cannot infer them from an iterator.
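    Because Keras cannot read the length of an iterator, steps_per_epoch and validation_steps are normally derived from the number of examples and the batch size. A minimal sketch, assuming both counts are known up front; num_steps is a hypothetical helper, not a Keras function, and the sizes in the example are illustrative only:

```python
def num_steps(num_examples, batch_size, drop_remainder=False):
    """Return how many batches are needed to cover num_examples exactly once.

    Hypothetical helper for choosing steps_per_epoch / validation_steps.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if drop_remainder:
        # Matches batching with drop_remainder=True: the short final batch is discarded.
        return num_examples // batch_size
    # Ceiling division: a final, smaller batch still counts as one step.
    return (num_examples + batch_size - 1) // batch_size


# Illustrative sizes: 1000 training and 250 validation examples, batch size 32.
steps_per_epoch = num_steps(1000, 32)   # 32
validation_steps = num_steps(250, 32)   # 8
```

    The two values would then be passed as model.fit(..., steps_per_epoch=steps_per_epoch, validation_data=..., validation_steps=validation_steps), with the validation iterator as validation_data.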

  • 2020-12-14 03:59

    I solved the problem by using fit_generator. I found the solution here and applied @Dat-Nguyen's solution.

    You simply need to create two iterators, one for training and one for validation, and then write your own generator that extracts batches from the dataset and yields them in the form (batch_data, batch_labels). Finally, pass train_generator and validation_generator to model.fit_generator.
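    A minimal sketch of such a generator, assuming the data fits in memory as NumPy arrays. It swaps in plain array slicing for the tf.data iterators described in this answer (in TF 1.x each batch would instead come from sess.run on the iterator's get_next() op); batch_generator is a hypothetical name:

```python
import numpy as np


def batch_generator(x, y, batch_size, shuffle=True, seed=None):
    """Yield (batch_data, batch_labels) tuples indefinitely, as fit_generator expects.

    Hypothetical helper: slices in-memory NumPy arrays rather than
    pulling batches from a tf.data iterator.
    """
    if len(x) != len(y):
        raise ValueError("x and y must contain the same number of examples")
    rng = np.random.default_rng(seed)
    num_examples = len(x)
    while True:  # Keras expects the generator to keep yielding across epochs
        # New random order on each pass when shuffling, otherwise the stored order.
        order = rng.permutation(num_examples) if shuffle else np.arange(num_examples)
        for start in range(0, num_examples, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]


# Toy data with the same shapes as the first answer's example.
x = np.random.random((5, 2))
y = np.array([0, 1] * 3 + [1, 0] * 2).reshape(5, 2)
train_generator = batch_generator(x, y, batch_size=2)
batch_data, batch_labels = next(train_generator)
print(batch_data.shape, batch_labels.shape)  # (2, 2) (2, 2)
```

    A validation_generator can be built the same way from the validation arrays, and both are handed to model.fit_generator(train_generator, steps_per_epoch=..., validation_data=validation_generator, validation_steps=...). In TF 2.x, fit_generator is deprecated and model.fit accepts generators directly.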
