How do you make TensorFlow + Keras fast with a TFRecord dataset?

Asked by 死守一世寂寞 on 2020-12-23 12:39

What is an example of how to use a TensorFlow TFRecord with a Keras Model and tf.session.run() while keeping the dataset in tensors w/ queue runners?

2 Answers

  •  轻奢々 (OP), answered 2020-12-23 13:08

    I don't use the TFRecord dataset format, so I won't argue about its pros and cons, but I got interested in extending Keras to support it.

    The repository is github.com/indraforyou/keras_tfrecord. I will briefly explain the main changes.

    Dataset creation and loading

    data_to_tfrecord and read_and_decode here take care of creating the TFRecord dataset and loading it back. Special care must be taken when implementing read_and_decode, otherwise you will face cryptic errors during training.
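    The actual implementation lives in the repository; below is only a rough sketch of what a queue-based read_and_decode for MNIST-style records might look like, with the feature keys, shapes and dtypes as assumptions:

    import tensorflow as tf

    def read_and_decode(filename, one_hot=True, n_class=10, is_train=None):
        # Queue up the tfrecord file and read serialized examples from it
        filename_queue = tf.train.string_input_producer([filename])
        reader = tf.TFRecordReader()
        _, serialized_example = reader.read(filename_queue)
        features = tf.parse_single_example(
            serialized_example,
            features={
                'image_raw': tf.FixedLenFeature([], tf.string),
                'label': tf.FixedLenFeature([], tf.int64),
            })
        image = tf.decode_raw(features['image_raw'], tf.float32)
        # Setting the static shape here is the "special care": without it
        # shuffle_batch (and Keras) cannot infer shapes and fails cryptically
        image.set_shape([28 * 28])
        image = tf.reshape(image, [28, 28, 1])
        label = tf.cast(features['label'], tf.int32)
        if one_hot:
            label = tf.one_hot(label, n_class)
        return image, label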

    Initialization and Keras model

    Now both tf.train.shuffle_batch and the Keras Input layer return tensors, but the one returned by tf.train.shuffle_batch doesn't carry the metadata Keras needs internally. As it turns out, any tensor can easily be turned into a tensor with Keras metadata by calling the Input layer with its tensor parameter.

    So this takes care of initialization:

    from keras import backend as K
    from keras.layers import Input
    import keras_tfrecord as ktfr  # the helper module from the repository
    
    # batch_size and nb_classes are assumed to be defined
    x_train_, y_train_ = ktfr.read_and_decode('train.mnist.tfrecord', one_hot=True, n_class=nb_classes, is_train=True)
    
    x_train_batch, y_train_batch = K.tf.train.shuffle_batch([x_train_, y_train_],
                                                    batch_size=batch_size,
                                                    capacity=2000,
                                                    min_after_dequeue=1000,
                                                    num_threads=32) # set the number of threads here
    
    x_train_inp = Input(tensor=x_train_batch)
    

    Now, with x_train_inp, any Keras model can be developed.
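    For example, a small hypothetical model built on the tensor-backed input (the layer choices are illustrative only; train_out is the output tensor referred to in the next section):

    from keras.layers import Dense, Flatten

    x = Flatten()(x_train_inp)
    x = Dense(128, activation='relu')(x)
    train_out = Dense(nb_classes, activation='softmax')(x)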

    Training (simple)

    Let's say train_out is the output tensor of your Keras model. You can easily write a custom training loop along the lines of:

    import time
    import tensorflow as tf
    from keras import backend as K
    from keras.objectives import categorical_crossentropy  # keras.losses in newer Keras versions
    
    loss = tf.reduce_mean(categorical_crossentropy(y_train_batch, train_out))
    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
    
    sess = tf.Session()
    K.set_session(sess)
    sess.run(tf.global_variables_initializer())  # initialize_all_variables() is deprecated
    
    with sess.as_default():
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    
        try:
          step = 0
          while not coord.should_stop():
            start_time = time.time()
    
            _, loss_value = sess.run([train_op, loss], feed_dict={K.learning_phase(): 1})  # 1 = training phase
    
            duration = time.time() - start_time
    
            if step % 100 == 0:
              print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
                                                         duration))
            step += 1
        except tf.errors.OutOfRangeError:
          print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))  # FLAGS.num_epochs is assumed to be defined elsewhere
        finally:
          coord.request_stop()
    
        coord.join(threads)
        sess.close()
    

    Training (keras style)

    One of the features that makes Keras so attractive is its generalized training mechanism with callback functions.

    But to support TFRecord-style training, several changes are needed in the fit function:

    • running the queue threads
    • not feeding in batch data through feed_dict
    • supporting validation becomes tricky, since the validation data also comes in through another tensor; a different model with shared upper layers has to be created internally, with the validation tensor fed in by another TFRecord reader

    All of this can be supported by another flag parameter. What makes things messy are the Keras features sample_weight and class_weight, which are used to weigh each sample and each class. For these, compile() creates placeholders (here), and placeholders are also implicitly created for the targets (here), which is not needed in our case since the labels are already fed in by the TFRecord readers. These placeholders would need to be fed in during the session run, which is unnecessary in our case.

    Taking these changes into account, compile_tfrecord (here) and fit_tfrecord (here) are extensions of compile and fit and share roughly 95% of the code.

    They can be used in the following way:

    import keras_tfrecord as ktfr
    
    train_model = Model(input=x_train_inp, output=train_out)
    ktfr.compile_tfrecord(train_model, optimizer='rmsprop', loss='categorical_crossentropy', out_tensor_lst=[y_train_batch], metrics=['accuracy'])
    
    train_model.summary()
    
    ktfr.fit_tfrecord(train_model, X_train.shape[0], batch_size, nb_epoch=3)
    train_model.save_weights('saved_wt.h5')
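
    As a follow-up (my own addition, not part of the repository): since the weights are saved with save_weights, you can rebuild the same architecture on a regular placeholder-backed Input and load the weights to predict on NumPy arrays. The layer stack below matches the hypothetical model sketched earlier; X_test is assumed to be your NumPy test images.

    from keras.layers import Input, Dense, Flatten
    from keras.models import Model

    x_inp = Input(shape=(28, 28, 1))
    x = Flatten()(x_inp)
    x = Dense(128, activation='relu')(x)
    out = Dense(nb_classes, activation='softmax')(x)

    test_model = Model(input=x_inp, output=out)
    test_model.load_weights('saved_wt.h5')          # weights match because the architectures match
    predictions = test_model.predict(X_test, batch_size=batch_size)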
    

    You are welcome to improve the code and send pull requests.
