How does TensorFlow handle categorical features with multiple inputs within one column?

Submitted by 半世苍凉 on 2019-12-01 01:03:45

OK, it looks like writing a custom feature column worked for me on the same task.

I took HashedCategoricalColumn as a base and cleaned it up to work with strings only. It should still get type checks, though.

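import collections

import tensorflow as tf  # TensorFlow 1.x; the feature-column internals below are private
from tensorflow.python.feature_column.feature_column import _CategoricalColumn
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops


class _SparseArrayCategoricalColumn(
    _CategoricalColumn,
    collections.namedtuple('_SparseArrayCategoricalColumn',
                           ['key', 'num_buckets', 'category_delimiter'])):
  """Categorical column whose cells hold delimiter-separated strings, e.g. 'a|b|c'."""

  @property
  def name(self):
    return self.key

  @property
  def _parse_example_spec(self):
    return {self.key: parsing_ops.VarLenFeature(dtypes.string)}

  def _transform_feature(self, inputs):
    input_tensor = inputs.get(self.key)
    # Parsing with VarLenFeature yields a SparseTensor, which cannot be
    # reshaped directly; densify it so each example contributes one string
    # ('' if the feature is missing).
    if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
      input_tensor = sparse_ops.sparse_tensor_to_dense(
          input_tensor, default_value='')
    flat_input = array_ops.reshape(input_tensor, (-1,))
    # Split every string on the delimiter. The result is always a
    # SparseTensor whose row i holds the categories of example i.
    split_tensor = tf.string_split(flat_input, self.category_delimiter)

    # tf.summary.text(self.key, flat_input)
    # Hash each category string into [0, num_buckets).
    sparse_id_values = string_ops.string_to_hash_bucket_fast(
        split_tensor.values, self.num_buckets, name='lookup')

    return sparse_tensor_lib.SparseTensor(
        split_tensor.indices, sparse_id_values, split_tensor.dense_shape)

  @property
  def _variable_shape(self):
    return tensor_shape.vector(self.num_buckets)

  @property
  def _num_buckets(self):
    """Returns number of buckets in this sparse feature."""
    return self.num_buckets

  def _get_sparse_tensors(self, inputs, weight_collections=None,
                          trainable=None):
    # No per-category weights: every id counts with weight 1.
    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)


def categorical_column_with_array_input(key, num_buckets,
                                        category_delimiter="|"):
  """Creates a hashed categorical column from 'a|b|c'-style string cells."""
  if (num_buckets is None) or (num_buckets < 1):
    raise ValueError('Invalid num_buckets {}.'.format(num_buckets))

  return _SparseArrayCategoricalColumn(key, num_buckets, category_delimiter)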
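To see what the split-then-hash step in `_transform_feature` produces without running TensorFlow, here is a minimal pure-Python sketch of the same idea, returning the `(indices, values, dense_shape)` triple of the resulting sparse tensor. It assumes a single-character delimiter (`tf.string_split` treats each character of the delimiter as a separate delimiter) and uses MD5 purely as a deterministic stand-in hash, so the bucket ids will differ from the ones TensorFlow computes. `split_and_hash` and `hash_bucket` are hypothetical helper names.

```python
import hashlib
from typing import List, Tuple


def hash_bucket(token: str, num_buckets: int) -> int:
    # Stand-in for string_to_hash_bucket_fast; MD5 is used only because it
    # is deterministic across runs (Python's built-in hash() is not).
    digest = hashlib.md5(token.encode("utf-8")).hexdigest()
    return int(digest, 16) % num_buckets


def split_and_hash(batch: List[str], num_buckets: int, delimiter: str = "|"
                   ) -> Tuple[List[Tuple[int, int]], List[int], Tuple[int, int]]:
    """Mimics the transform: returns (indices, values, dense_shape)."""
    indices: List[Tuple[int, int]] = []
    values: List[int] = []
    max_len = 0
    for row, cell in enumerate(batch):
        # Like tf.string_split's default skip_empty=True, drop empty tokens.
        tokens = [t for t in cell.split(delimiter) if t]
        for col, token in enumerate(tokens):
            indices.append((row, col))
            values.append(hash_bucket(token, num_buckets))
        max_len = max(max_len, len(tokens))
    return indices, values, (len(batch), max_len)


if __name__ == "__main__":
    idx, vals, shape = split_and_hash(["red|green", "blue", ""], num_buckets=10)
    print(idx)    # [(0, 0), (0, 1), (1, 0)]
    print(shape)  # (3, 2)
    print(vals)   # three bucket ids, each in [0, 10)
```

Each example keeps its own row of categories, and identical strings always land in the same bucket, which is what lets a wrapping column aggregate them per example.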

You can then wrap the custom column in an embedding or indicator column, which seems to be what you need. This was only the first step for me: I also need to handle a column with values like "str:float|str:float...".
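An indicator column turns each example's ids into a dense vector of length `num_buckets`; as far as I know, TF 1.x builds it by one-hot encoding the ids and summing over the row, so a repeated category is counted twice. The sketch below mimics that in pure Python (`multi_hot` without weights) and also parses the "str:float|str:float..." format into parallel token and weight lists. All function names are hypothetical, and MD5 is again only a stand-in hash. On the TensorFlow side, `tf.feature_column.weighted_categorical_column` may be worth trying for the weighted case, but the weighted sum here is an illustration, not a claim about that column's exact output.

```python
import hashlib
from typing import List, Optional, Tuple


def hash_bucket(token: str, num_buckets: int) -> int:
    # Deterministic stand-in for TensorFlow's hash bucketing.
    digest = hashlib.md5(token.encode("utf-8")).hexdigest()
    return int(digest, 16) % num_buckets


def parse_weighted(cell: str, delimiter: str = "|", sep: str = ":"
                   ) -> Tuple[List[str], List[float]]:
    """Splits 'str:float|str:float' into parallel token and weight lists."""
    tokens: List[str] = []
    weights: List[float] = []
    for item in cell.split(delimiter):
        if not item:
            continue  # skip empty entries, e.g. from an empty cell
        # rsplit keeps any ':' inside the category name itself.
        token, weight = item.rsplit(sep, 1)
        tokens.append(token)
        weights.append(float(weight))
    return tokens, weights


def multi_hot(ids: List[int], num_buckets: int,
              weights: Optional[List[float]] = None) -> List[float]:
    """Adds each id's weight (1.0 if no weights are given) into its bucket."""
    vec = [0.0] * num_buckets
    for i, bucket in enumerate(ids):
        vec[bucket] += 1.0 if weights is None else weights[i]
    return vec


if __name__ == "__main__":
    print(multi_hot([1, 3, 1], 4))  # [0.0, 2.0, 0.0, 1.0]
    tokens, weights = parse_weighted("cat:0.5|dog:2.0")
    ids = [hash_bucket(t, 8) for t in tokens]
    print(multi_hot(ids, 8, weights))
```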
