Replacing Queue-based input pipelines with tf.data

Submitted by 旧时模样 on 2019-12-06 05:26:12

I ended up finding my answer in someone else's code, posted in a question about the poor performance of TextLineDataset and decode_csv.
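As background for the code below: decode_csv splits one CSV line into one value per column, and record_defaults supplies one entry per column. A default such as [-1.0] sets the column type (float32) and makes the column optional, so an empty field becomes -1.0; an empty default [] marks the column as required. The following is a minimal pure-Python model of that rule, assuming float columns only. The function name decode_csv_row is hypothetical, it ignores quoting and other dtypes, and it is not TensorFlow's implementation.

```python
from typing import List, Sequence


def decode_csv_row(line: str, record_defaults: Sequence[List[float]]) -> List[float]:
    """Hypothetical pure-Python model of decode_csv for one line of float columns.

    record_defaults has one entry per column:
      [value] -> optional column; `value` is used when the field is empty
      []      -> required column; an empty field is an error
    """
    fields = line.split(",")  # this sketch does not handle quoted fields
    if len(fields) != len(record_defaults):
        raise ValueError(f"expected {len(record_defaults)} fields, got {len(fields)}")
    row = []
    for field, default in zip(fields, record_defaults):
        if field == "":
            if not default:
                raise ValueError("empty field in a required column")
            row.append(float(default[0]))  # optional column: fall back to its default
        else:
            row.append(float(field))
    return row


# Optional columns: the empty middle field falls back to -1.0
print(decode_csv_row("0.5,,2", [[-1.0]] * 3))  # [0.5, -1.0, 2.0]
```

In the TensorFlow code, parse_csv then passes the per-column values to tf.stack, which turns them into a single vector of length 10 for each line.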

Here's my code that uses tf.data to do something similar to the code in Ganegedara's book:

# Uses the TensorFlow 1.x API (tf.InteractiveSession, tf.decode_csv, one-shot iterators)
import tensorflow as tf

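graph = tf.Graph()
# Passing graph= makes the InteractiveSession install both itself and `graph`
# as defaults, so the ops created below are added to `graph`
session = tf.InteractiveSession(graph=graph)
# test1.txt, test2.txt, test3.txt: each line holds 10 comma-separated floats
filenames = ['test%d.txt'%i for i in range(1,4)]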

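# One default per column: [-1.0] makes each column a float32 that falls back
# to -1.0 when a field is empty (an empty default [] would make it required)
record_defaults = [[-1.0]] * 10

# Each element of this dataset is one line of text, read file by file
features = tf.data.TextLineDataset(filenames=filenames)

def parse_csv(line):
    # Split one line into 10 float32 scalars and stack them into a [10] vector
    columns = tf.decode_csv(line, record_defaults=record_defaults)
    return tf.stack(columns)

# Parse every line, shuffle individual rows, then group them into batches of 3
# (shuffling after batch() would only reorder whole batches)
features = features.map(parse_csv).shuffle(buffer_size=5).batch(batch_size=3)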

# get_next() already returns a float32 tensor of shape [None, 10] (one batch of rows)
x = features.make_one_shot_iterator().get_next()
W = tf.Variable(tf.random_uniform(shape=[10,5], minval=-0.1,maxval=0.1, dtype=tf.float32),name='W') 
b = tf.Variable(tf.zeros(shape=[5],dtype=tf.float32),name='b')
h = tf.nn.sigmoid(tf.matmul(x,W) + b) # Operation to be performed

tf.global_variables_initializer().run() # Initialize the variables

# Calculate h with x and print the results for 5 steps
for step in range(5):
    # Each run pulls one batch; x and h are both computed from that same batch
    x_eval, h_eval = session.run([x,h])
    print('========== Step %d =========='%step)
    print('Evaluated data (x)')
    print(x_eval)
    print('Evaluated data (h)')
    print(h_eval)
    print('')
session.close()
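The order of shuffle and batch in a tf.data pipeline matters: shuffle(...).batch(...) mixes individual rows, while batch(...).shuffle(...) only reorders whole batches, so rows that were adjacent in the files always stay together. The following is a minimal pure-Python model of the two orderings, assuming the shuffle-buffer behaviour described in the tf.data documentation (fill a buffer, emit a random element, refill from the input). The functions shuffle and batch here are hypothetical stand-ins, not TensorFlow's implementation.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def shuffle(items: Iterable[T], buffer_size: int, rng: random.Random) -> Iterator[T]:
    """Hypothetical model of a shuffle buffer: hold up to buffer_size elements
    and emit a uniformly chosen one each time the buffer is full."""
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Swap a randomly chosen element to the end and emit it
            i = rng.randrange(len(buffer))
            buffer[i], buffer[-1] = buffer[-1], buffer[i]
            yield buffer.pop()
    # Input exhausted: emit whatever is left, in random order
    rng.shuffle(buffer)
    yield from buffer


def batch(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Hypothetical model of batching that keeps a final partial batch."""
    current: List[T] = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current


rows = list(range(15))  # stand-ins for 15 parsed CSV rows
# Shuffle, then batch: rows from different parts of the input end up in the same batch
print(list(batch(shuffle(rows, 5, random.Random(0)), 3)))
# Batch, then shuffle: only the batch order changes; [0, 1, 2] always stays together
print(list(shuffle(batch(rows, 3), 5, random.Random(0))))
```

In this model a row can be emitted at most buffer_size - 1 positions earlier than its position in the input, so a small buffer only gives a local shuffle; the TensorFlow documentation notes that a perfect shuffle needs a buffer at least as large as the dataset.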