Keras model.fit() with tf.data API + validation_data

北荒 2020-12-14 03:24

So I have got my Keras model working with a tf.data.Dataset using the following code:

# Initialize batch generators (returns tf.Dataset)
batch_train = build_feat         


        
2 Answers
  •  粉色の甜心
    2020-12-14 03:50

    One way to connect a reinitializable iterator to a Keras model is to pass in an iterator that yields both the x and y values together, as an (x, y) tuple:

    import numpy as np
    import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.data.Iterator)
    from tensorflow import keras

    sess = tf.Session()
    keras.backend.set_session(sess)

    x = np.random.random((5, 2)).astype(np.float32)
    y = np.array([0, 1] * 3 + [1, 0] * 2, dtype=np.float32).reshape(5, 2)  # One-hot encoded
    # Batch so Keras receives (batch, features) tensors; repeat so the iterator never runs dry
    input_dataset = tf.data.Dataset.from_tensor_slices((x, y)).batch(5).repeat()

    # Create your reinitializable iterator and its initializer
    reinitializable_iterator = tf.data.Iterator.from_structure(
        input_dataset.output_types, input_dataset.output_shapes)
    init_op = reinitializable_iterator.make_initializer(input_dataset)

    # Run the initializer
    sess.run(init_op)  # add a feed_dict if you're using placeholders as input

    # Build the Keras model and plug in the iterator
    model = keras.Model(...)
    model.compile(...)
    model.fit(reinitializable_iterator, steps_per_epoch=1, ...)  # 5 samples / batch of 5
    

    If you also have a validation dataset, the easiest approach is to create a separate iterator for it and pass that iterator as the validation_data argument. Make sure to set steps_per_epoch and validation_steps, since Keras cannot infer them from an iterator.
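    In TensorFlow 2.x, tf.Session and tf.data.Iterator.from_structure only survive under tf.compat.v1, and model.fit() accepts tf.data.Dataset objects directly. The sketch below is a swapped-in TF 2.x adaptation of the same idea, not the answer's TF 1.x code: a separate validation dataset goes into validation_data, and because both datasets are made infinite with repeat(), steps_per_epoch and validation_steps must be given explicitly. The data, the tiny model, and the make_dataset helper are hypothetical.

```python
import numpy as np
import tensorflow as tf


def make_dataset(x, y, batch_size):
    # Hypothetical helper: each element is an (x, y) tuple, which is what
    # Keras fit() unpacks into inputs and targets.
    ds = tf.data.Dataset.from_tensor_slices((x, y))
    # repeat() makes the dataset infinite, so Keras cannot infer epoch length.
    return ds.batch(batch_size).repeat()


rng = np.random.default_rng(0)
x_train = rng.random((20, 2)).astype(np.float32)
y_train = np.eye(2, dtype=np.float32)[rng.integers(0, 2, size=20)]  # one-hot
x_val = rng.random((10, 2)).astype(np.float32)
y_val = np.eye(2, dtype=np.float32)[rng.integers(0, 2, size=10)]  # one-hot

batch_size = 5
train_ds = make_dataset(x_train, y_train, batch_size)
val_ds = make_dataset(x_val, y_val, batch_size)

model = tf.keras.Sequential([
    tf.keras.Input(shape=(2,)),
    tf.keras.layers.Dense(2, activation="softmax"),
])
model.compile(optimizer="adam",
              loss="categorical_crossentropy",
              metrics=["accuracy"])

history = model.fit(
    train_ds,
    epochs=2,
    steps_per_epoch=len(x_train) // batch_size,  # 4 batches per epoch
    validation_data=val_ds,                      # separate validation dataset
    validation_steps=len(x_val) // batch_size,   # 2 batches per validation pass
    verbose=0,
)
```

    Dropping repeat() would give finite datasets, in which case Keras can usually work out the epoch length on its own; keeping it here mirrors the iterator situation in the answer, where the step counts have to be supplied.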
