Getting free text features into Tensorflow Canned Estimators with Dataset API via feature_columns


You could use tf.string_split(), then tf.slice() to take a fixed number of words, taking care to tf.pad() the strings with zeros first. Look at the title preprocessing operations in: https://towardsdatascience.com/how-to-do-text-classification-using-tensorflow-word-embeddings-and-cnn-edae13b3e575
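
For illustration, here is a minimal sketch of that split/pad/slice sequence (TF 1.x style; PADWORD and MAX_DOCUMENT_LENGTH stand in for your own padding token and fixed document length):

# Minimal sketch: split titles into words, densify, then pad/slice
# so every row ends up with exactly MAX_DOCUMENT_LENGTH words.
# PADWORD and MAX_DOCUMENT_LENGTH are assumed to be defined by you.
titles = tf.constant(['hello world', 'free text features in estimators'])
words = tf.string_split(titles)                    # SparseTensor of words
words = tf.sparse_tensor_to_dense(words, default_value=PADWORD)
padding = tf.constant([[0, 0], [0, MAX_DOCUMENT_LENGTH]])
padded = tf.pad(words, padding)                    # guarantee >= MAX_DOCUMENT_LENGTH columns
sliced = tf.slice(padded, [0, 0], [-1, MAX_DOCUMENT_LENGTH])  # truncate to fixed length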

Once you have the words, you can then create ten feature columns.

Adding an answer following the approach from @Lak's post, but adapted a little for the Dataset API.

# Create an input function reading a file using the Dataset API
# Then provide the results to the Estimator API
def read_dataset(prefix, mode, batch_size):

    def _input_fn():

        def decode_csv(value_column):

            # Parse one pipe-delimited CSV line into a feature dict
            columns = tf.decode_csv(value_column, field_delim='|', record_defaults=DEFAULTS)
            features = dict(zip(CSV_COLUMNS, columns))

            # Split the comment into words and densify, filling short
            # rows with PADWORD
            words = tf.string_split([features['comment']])
            words = tf.sparse_tensor_to_dense(words, default_value=PADWORD)

            # Pad out to at least MAX_DOCUMENT_LENGTH words, then slice so
            # every comment is exactly MAX_DOCUMENT_LENGTH words long
            padding = tf.constant([[0, 0], [0, MAX_DOCUMENT_LENGTH]])
            padded = tf.pad(words, padding)
            features['comment_words'] = tf.slice(padded, [0, 0], [-1, MAX_DOCUMENT_LENGTH])

            label = features.pop(LABEL_COLUMN)

            return features, label

        # Use prefix to create file path
        file_path = '{}/{}*{}*'.format(INPUT_DIR, prefix, PATTERN)

        # Create list of files that match pattern
        file_list = tf.gfile.Glob(file_path)

        # Create dataset from file list
        dataset = (tf.data.TextLineDataset(file_list)  # Read text file
                    .map(decode_csv))  # Transform each elem by applying decode_csv fn

        tf.logging.info("...dataset.output_types={}".format(dataset.output_types))
        tf.logging.info("...dataset.output_shapes={}".format(dataset.output_shapes))

        if mode == tf.estimator.ModeKeys.TRAIN:

            num_epochs = None # indefinitely
            dataset = dataset.shuffle(buffer_size = 10 * batch_size)

        else:

            num_epochs = 1 # end-of-input after this

        dataset = dataset.repeat(num_epochs).batch(batch_size)

        return dataset.make_one_shot_iterator().get_next()

    return _input_fn
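
For completeness, a hedged sketch of wiring this input function into training and evaluation (the 'train'/'eval' prefixes, batch size of 32, and max_steps value are assumptions, not part of the original answer; the estimator is built from get_wide_deep() below):

train_spec = tf.estimator.TrainSpec(
    input_fn = read_dataset('train', tf.estimator.ModeKeys.TRAIN, 32),
    max_steps = 1000)  # assumed step budget
eval_spec = tf.estimator.EvalSpec(
    input_fn = read_dataset('eval', tf.estimator.ModeKeys.EVAL, 32))
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)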

Then, in the function below, we can reference the field we created in decode_csv():

# Define feature columns
def get_wide_deep():

    EMBEDDING_SIZE = 10

    # Define column types
    subreddit = tf.feature_column.categorical_column_with_vocabulary_list('subreddit', ['news', 'ireland', 'pics'])

    comment_embeds = tf.feature_column.embedding_column(
        categorical_column = tf.feature_column.categorical_column_with_vocabulary_file(
            key='comment_words',
            vocabulary_file='{}/vocab.csv-00000-of-00001'.format(INPUT_DIR),
            vocabulary_size=100
            ),
        dimension = EMBEDDING_SIZE
        )

    # Sparse columns are wide, have a linear relationship with the output
    wide = [ subreddit ]

    # Continuous columns are deep, have a complex relationship with the output
    deep = [ comment_embeds ]

    return wide, deep
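
These wide and deep lists can then feed a canned wide-and-deep model. A minimal sketch, assuming an OUTDIR checkpoint directory and hidden-unit sizes that are placeholders rather than values from the original answer:

wide, deep = get_wide_deep()
estimator = tf.estimator.DNNLinearCombinedClassifier(
    model_dir = OUTDIR,             # assumed checkpoint directory
    linear_feature_columns = wide,  # sparse columns on the linear side
    dnn_feature_columns = deep,     # embedding columns on the DNN side
    dnn_hidden_units = [64, 16])    # assumed layer sizes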