How to *actually* read CSV data in TensorFlow?

感动是毒 2020-11-28 22:32

I'm relatively new to the world of TensorFlow, and pretty perplexed by how you'd actually read CSV data into usable example/label tensors in TensorFlow.

5 Answers
  •  忘掉有多难 2020-11-28 23:20

    If anyone came here looking for a simple way to read large, sharded CSV files with the tf.estimator API, the code below may help:

    import tensorflow as tf

    CSV_COLUMNS = ['ID', 'text', 'class']
    LABEL_COLUMN = 'class'
    DEFAULTS = [['x'], ['no'], [0]]  # default value (and dtype) for each column

    def read_dataset(filename, mode, batch_size=512):
        def _input_fn(v_test=False):
            # Alternative: read raw lines and parse them yourself with tf.decode_csv
            # def decode_csv(value_column):
            #     columns = tf.decode_csv(value_column, record_defaults=DEFAULTS)
            #     features = dict(zip(CSV_COLUMNS, columns))
            #     label = features.pop(LABEL_COLUMN)
            #     return add_engineered(features), label
            # dataset = tf.data.TextLineDataset(file_list).map(decode_csv)

            # Create the list of (possibly sharded) files that match the pattern
            file_list = tf.gfile.Glob(filename)

            if mode == tf.estimator.ModeKeys.TRAIN:
                num_epochs = None  # repeat indefinitely
            else:
                num_epochs = 1     # end-of-input after one pass

            # Build a batched, parsed dataset directly from the CSV files
            dataset = tf.contrib.data.make_csv_dataset(file_list,
                                                       batch_size=batch_size,
                                                       column_names=CSV_COLUMNS,
                                                       column_defaults=DEFAULTS,
                                                       label_name=LABEL_COLUMN,
                                                       num_epochs=num_epochs)

            if mode == tf.estimator.ModeKeys.TRAIN:
                dataset = dataset.shuffle(buffer_size=10 * batch_size)

            batch_features, batch_labels = dataset.make_one_shot_iterator().get_next()

            # For ad-hoc testing only: print one batch of parsed features
            if v_test:
                with tf.Session() as sess:
                    print(sess.run(batch_features))

            # add_engineered() is the caller's feature-engineering function
            return add_engineered(batch_features), batch_labels
        return _input_fn
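
    The snippet calls add_engineered(), which is not defined in the answer; it stands for whatever feature engineering you want to apply to the parsed columns. A minimal pass-through placeholder, purely for illustration, could be:

    def add_engineered(features):
        # Hypothetical placeholder: apply feature engineering here;
        # this version just returns the features unchanged.
        return features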
    

    Example usage with tf.estimator:

    train_spec = tf.estimator.TrainSpec(input_fn=read_dataset(
                                            filename=train_file,
                                            mode=tf.estimator.ModeKeys.TRAIN,
                                            batch_size=128),
                                        max_steps=num_train_steps)
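
    The answer only shows the training side. A minimal sketch of the matching evaluation side, assuming a hypothetical eval_file pattern and an already-constructed estimator, might look like:

    eval_spec = tf.estimator.EvalSpec(input_fn=read_dataset(
                                          filename=eval_file,  # hypothetical eval file pattern
                                          mode=tf.estimator.ModeKeys.EVAL,
                                          batch_size=128),
                                      steps=None,        # evaluate over the full eval set
                                      throttle_secs=60)  # at most one evaluation per minute

    # `estimator` is whatever tf.estimator.Estimator you have built elsewhere
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)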
    
