tf.contrib.data.Dataset does not seem to support SparseTensor

生来就可爱ヽ(ⅴ<●) submitted on 2019-11-30 15:13:07

EDIT (2018/01/25): tf.SparseTensor support was added to tf.data in TensorFlow 1.5. The code in the question should work in TensorFlow 1.5 or later.
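
A minimal sketch of what TensorFlow 1.5+ allows. It assumes TensorFlow 2.x eager execution so the results can be inspected directly, and it uses a made-up 3x4 tensor. tf.data.Dataset.from_tensor_slices() splits a SparseTensor into per-row SparseTensor elements, map() can return SparseTensors, and batch() restacks them into one SparseTensor per batch.

```python
import tensorflow as tf

# A toy 3x4 SparseTensor; from_tensor_slices() splits it into three
# 1-D SparseTensor rows, each with dense shape [4].
st = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2], [2, 1]],
                            values=[1, 2, 3],
                            dense_shape=[3, 4])
dataset = tf.data.Dataset.from_tensor_slices(st)

# map() may return a SparseTensor directly (here: double every value).
dataset = dataset.map(lambda row: tf.sparse.SparseTensor(
    row.indices, row.values * 2, row.dense_shape))

# batch() restacks the rows into a 2-D SparseTensor per batch.
dataset = dataset.batch(2)
batches = [tf.sparse.to_dense(b).numpy().tolist() for b in dataset]
```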


Up to TF 1.4, the tf.contrib.data API did not support tf.SparseTensor objects in dataset elements. There are a couple of workarounds:

  1. (Harder, but probably faster.) If a tf.SparseTensor object st represents a variable-length list of features, you may be able to return st.values instead of st from the map() function. Note that you would probably then need to pad the results using Dataset.padded_batch() instead of Dataset.batch().

  2. (Easier, but probably slower.) In the _parse_function_train() function, iterate over tensor_dict and produce a new version where any tf.SparseTensor objects have been converted to a tf.Tensor using tf.serialize_sparse(), recording which keys held sparse values:

    # NOTE: You could probably infer these from `keys`.
    sparse_keys = set()
    
    def _parse_function_train(example):
      serialized_example = tf.reshape(example, shape=[])
      tensors = decoder.decode(serialized_example, items=keys)
      tensor_dict = dict(zip(keys, tensors))
      tensor_dict['image'].set_shape([None, None, 3])
    
      rewritten_tensor_dict = {}
      for key, value in tensor_dict.items():
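        if isinstance(value, tf.SparseTensor):
          # Encode the SparseTensor as a string tensor of shape [3]
          # (indices, values, dense_shape) so that map() can return it.
          rewritten_tensor_dict[key] = tf.serialize_sparse(value)
          sparse_keys.add(key)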
        else:
          rewritten_tensor_dict[key] = value
      return rewritten_tensor_dict
    

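    Then, after you get the next_element dictionary from iterator.get_next(), you can reverse this conversion using tf.deserialize_many_sparse(). This assumes the dataset was batched (for example with Dataset.batch()), because tf.deserialize_many_sparse() expects a [batch_size, 3] string tensor: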

    next_element = iterator.get_next()
    
    for key in sparse_keys:
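      # Pass the serialized tensor (not the key) and the original SparseTensor's
      # dtype; tf.int64 is only an assumption here, so use your feature's dtype.
      next_element[key] = tf.deserialize_many_sparse(next_element[key], dtype=tf.int64)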
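
A self-contained version of workaround 2 follows. It assumes TensorFlow 2.x eager execution, where the same ops are available as tf.io.serialize_sparse and tf.io.deserialize_many_sparse. A hypothetical make_features() stands in for the question's decoder-based parsing. The sketch records each sparse key's dtype rather than just the key, because the deserializing op needs it. It also batches before deserializing, because that op expects a [batch_size, 3] string tensor.

```python
import tensorflow as tf

# key -> dtype of each SparseTensor feature; filled in when map() traces the
# function, and needed later by tf.io.deserialize_many_sparse().
sparse_dtypes = {}

def make_features(i):
    # Hypothetical stand-in for decoder.decode(): one dense and one sparse feature.
    tags = tf.sparse.SparseTensor(indices=[[i % 3]], values=[i + 1], dense_shape=[3])
    return {"id": i, "tags": tags}

def parse_fn(i):
    rewritten = {}
    for key, value in make_features(i).items():
        if isinstance(value, tf.sparse.SparseTensor):
            rewritten[key] = tf.io.serialize_sparse(value)  # tf.string, shape [3]
            sparse_dtypes[key] = value.dtype
        else:
            rewritten[key] = value
    return rewritten

dataset = tf.data.Dataset.range(4).map(parse_fn).batch(2)

results = []
for batch in dataset:
    for key, dtype in sparse_dtypes.items():
        # [batch_size, 3] strings -> one SparseTensor with a leading batch dimension.
        batch[key] = tf.io.deserialize_many_sparse(batch[key], dtype=dtype)
    results.append(tf.sparse.to_dense(batch["tags"]).numpy().tolist())
```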
    
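Workaround 1 (returning st.values and padding) can be sketched as follows. The sketch assumes TensorFlow 2.x eager execution and hypothetical tf.train.Example records with one variable-length int64 feature named "tokens". Parsing that feature with tf.io.VarLenFeature yields a 1-D tf.SparseTensor whose values are the feature list in order. Returning st.values therefore gives a dense tensor of variable length, and Dataset.padded_batch() pads each batch to its longest row with zeros.

```python
import tensorflow as tf

def make_example(tokens):
    # Hypothetical record with one variable-length int64 feature, "tokens".
    feature = {"tokens": tf.train.Feature(int64_list=tf.train.Int64List(value=tokens))}
    return tf.train.Example(
        features=tf.train.Features(feature=feature)).SerializeToString()

records = [make_example([1, 2, 3]), make_example([4]), make_example([5, 6])]

def parse(serialized):
    parsed = tf.io.parse_single_example(
        serialized, {"tokens": tf.io.VarLenFeature(tf.int64)})
    st = parsed["tokens"]  # a 1-D tf.SparseTensor
    return st.values       # a dense 1-D tf.Tensor of variable length

dataset = tf.data.Dataset.from_tensor_slices(records).map(parse)
# batch() would fail on rows of different lengths; padded_batch() pads with 0.
dataset = dataset.padded_batch(3, padded_shapes=[None])
batch = next(iter(dataset)).numpy().tolist()
```

Padding discards the distinction between a real 0 and padding. If that matters, you can return the length (tf.size(st.values)) alongside the values.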

As an alternative to tf.SparseTensor, tf.scatter_nd builds the equivalent dense tensor directly. The trade-off is that the result no longer stores the indices, so you may have to recover them later.

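The code below builds both forms from the same indices and values. The two hold the same data provided no two rows of indices point to the same cell, because tf.scatter_nd sums duplicate updates: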

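indices = tf.concat([xy_indices, z_indices], axis=-1)       # [N, 3] cell indices
values  = tf.concat([bboxes, objectness, one_hot], axis=1)  # [N, 4 + 1 + num_classes]
num_channels = 4 + 1 + num_classes
dense_shape = [output_size, output_size, len(anchors), num_channels]

# tf.SparseTensor needs one int64 coordinate per dimension and 1-D values, so
# each [N, 3] cell index is expanded into num_channels full 4-D indices.
cell_indices = tf.reshape(
    tf.tile(tf.cast(indices, tf.int64)[:, None, :], [1, num_channels, 1]), [-1, 3])
channel_indices = tf.tile(tf.range(num_channels, dtype=tf.int64),
                          [tf.shape(indices)[0]])
sparse_tensor = tf.SparseTensor(
    indices=tf.concat([cell_indices, channel_indices[:, None]], axis=1),
    values=tf.reshape(values, [-1]),
    dense_shape=dense_shape)
sparse_tensor = tf.sparse.reorder(sparse_tensor)  # many sparse ops expect ordered indices

# tf.scatter_nd takes the [N, 3] indices directly and writes each row of
# `values` into the last dimension of the output.
dense_tensor = tf.scatter_nd(indices=indices, updates=values, shape=dense_shape)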
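
A dense result no longer carries its indices explicitly. The sketch below recovers them with tf.where, treating a cell as occupied when any of its channels is nonzero. It assumes TensorFlow 2.x eager execution and uses made-up toy sizes rather than the shapes above. A cell whose written values are all zero would be missed by this test.

```python
import tensorflow as tf

# Toy sizes (hypothetical): a 2x2 grid, 1 anchor, 3 channels per cell.
indices = tf.constant([[0, 1, 0], [1, 0, 0]], dtype=tf.int64)
values = tf.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
dense = tf.scatter_nd(indices, values, shape=tf.constant([2, 2, 1, 3], dtype=tf.int64))

# A cell is treated as occupied if any of its channels is nonzero.
occupied = tf.reduce_any(tf.not_equal(dense, 0.0), axis=-1)  # [2, 2, 1] bool
recovered_indices = tf.where(occupied)                       # [N, 3] int64, row-major order
recovered_values = tf.gather_nd(dense, recovered_indices)    # [N, 3]
```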