Passing a numpy array to a tensorflow Queue

广开言路 2020-12-17 04:21

I have a NumPy array and would like to read it in TensorFlow's code using a Queue. I would like the queue to return the whole dataset, shuffled, for some specified number of epochs.

1 Answer
  • 2020-12-17 05:14

    You could create another queue, enqueue your data onto it num_epochs times, close it, and then hook it up to your batch. To save memory, you can make this queue small and enqueue items onto it in parallel. There will be a bit of mixing between epochs. To fully prevent mixing, you could take the code below with num_epochs=1 and call it num_epochs times.

    import threading

    import numpy as np
    import tensorflow as tf

    tf.reset_default_graph()
    data = np.array([1, 2, 3, 4])
    num_epochs = 5
    queue1_input = tf.placeholder(tf.int32)
    queue1 = tf.FIFOQueue(capacity=10, dtypes=[tf.int32], shapes=[()])
    
    def create_session():
        config = tf.ConfigProto()
        config.operation_timeout_in_ms=20000
        return tf.InteractiveSession(config=config)
    
    enqueue_op = queue1.enqueue_many(queue1_input)
    close_op = queue1.close()
    dequeue_op = queue1.dequeue()
    batch = tf.train.shuffle_batch([dequeue_op], batch_size=4, capacity=5, min_after_dequeue=4)
    
    sess = create_session()
    
    def fill_queue():
        # Enqueue the whole dataset once per epoch, then close the queue so
        # downstream readers eventually see an out-of-range signal.
        for i in range(num_epochs):
            sess.run(enqueue_op, feed_dict={queue1_input: data})
        sess.run(close_op)
    
    fill_thread = threading.Thread(target=fill_queue, args=())
    fill_thread.start()
    
    # read the shuffled data back out of the batching pipeline
    tf.train.start_queue_runners()
    try:
        while True:
            print(batch.eval())
    except tf.errors.OutOfRangeError:
        print("Done")
    

    BTW, the enqueue_many pattern above will hang when the queue is not large enough to hold the entire NumPy dataset. You can give yourself the flexibility of a smaller queue by loading the data in chunks, as below.

    # Reuses the imports and the create_session helper from the snippet above.
    tf.reset_default_graph()
    data = np.array([1, 2, 3, 4])
    queue1_capacity = 2
    num_epochs = 2
    queue1_input = tf.placeholder(tf.int32)
    queue1 = tf.FIFOQueue(capacity=queue1_capacity, dtypes=[tf.int32], shapes=[()])
    
    enqueue_op = queue1.enqueue_many(queue1_input)
    close_op = queue1.close()
    dequeue_op = queue1.dequeue()
    
    def dequeue():
        # Keep pulling items until the queue is closed and drained.
        try:
            while True:
                print(sess.run(dequeue_op))
        except tf.errors.OutOfRangeError:
            return
    
    def enqueue():
        # Feed the data in queue1_capacity-sized chunks so each enqueue_many
        # call fits into the small staging queue; close after the last epoch.
        for i in range(num_epochs):
            start_pos = 0
            while start_pos < len(data):
                end_pos = start_pos+queue1_capacity
                data_chunk = data[start_pos: end_pos]
                sess.run(enqueue_op, feed_dict={queue1_input: data_chunk})
                start_pos += queue1_capacity
        sess.run(close_op)
    
    sess = create_session()
    
    enqueue_thread = threading.Thread(target=enqueue, args=())
    enqueue_thread.start()
    
    dequeue_thread = threading.Thread(target=dequeue, args=())
    dequeue_thread.start()
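
    As a follow-up sketch (again my addition, reusing the imports and the create_session helper from the first snippet), the chunked enqueue can feed tf.train.shuffle_batch directly, so the staging queue stays small while batches still come out shuffled:

    tf.reset_default_graph()
    data = np.array([1, 2, 3, 4])
    queue1_capacity = 2
    num_epochs = 2

    queue1_input = tf.placeholder(tf.int32)
    queue1 = tf.FIFOQueue(capacity=queue1_capacity, dtypes=[tf.int32], shapes=[()])
    enqueue_op = queue1.enqueue_many(queue1_input)
    close_op = queue1.close()
    batch = tf.train.shuffle_batch([queue1.dequeue()], batch_size=4,
                                   capacity=5, min_after_dequeue=4)

    sess = create_session()

    def enqueue_chunks():
        # Push queue1_capacity-sized chunks so enqueue_many never needs more
        # space than the small staging queue provides.
        for _ in range(num_epochs):
            for start_pos in range(0, len(data), queue1_capacity):
                chunk = data[start_pos:start_pos + queue1_capacity]
                sess.run(enqueue_op, feed_dict={queue1_input: chunk})
        sess.run(close_op)

    threading.Thread(target=enqueue_chunks).start()
    tf.train.start_queue_runners()
    try:
        while True:
            print(batch.eval())
    except tf.errors.OutOfRangeError:
        print("Done")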
    