Is sparse tensor multiplication implemented in TensorFlow?

夕颜
夕颜 2020-12-13 14:11

Multiplication of sparse tensors with themselves or with dense tensors does not seem to work in TensorFlow. The following example



        
5 Answers
  •  孤街浪徒
    2020-12-13 14:43

    General-purpose multiplication for tf.SparseTensor is not currently implemented in TensorFlow. However, there are several partial solutions, and the right one to choose will depend on the characteristics of your data:

    • If you have a tf.SparseTensor and a tf.Tensor, you can use tf.sparse_tensor_dense_matmul() (named tf.sparse.sparse_dense_matmul() in TensorFlow 2.x) to multiply them. This is more efficient than the next approach if one of the tensors is too large to fit in memory when densified; the documentation has more guidance on how to decide between these two methods. Note that it accepts a tf.SparseTensor only as the first argument, so to solve your exact problem you will need to use the adjoint_a and adjoint_b arguments, and transpose the result.

    • If you have two sparse tensors and need to multiply them, the simplest (if not the most performant) way is to convert them to dense and use tf.matmul:

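      import tensorflow as tf

      a = tf.SparseTensor(...)
      b = tf.SparseTensor(...)

      # Densify both operands (missing entries become 0.0), then multiply.
      # In TensorFlow 2.x, tf.sparse_tensor_to_dense is named tf.sparse.to_dense.
      c = tf.matmul(tf.sparse_tensor_to_dense(a, 0.0),
                    tf.sparse_tensor_to_dense(b, 0.0),
                    a_is_sparse=True, b_is_sparse=True)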

      Note that the optional a_is_sparse and b_is_sparse arguments mean that "a (or b) has a dense representation but a large number of its entries are zero", which triggers the use of a different multiplication algorithm.

    • For the special case of multiplying a sparse vector by a (potentially large and sharded) dense matrix, where the values in the vector are 0 or 1, the tf.nn.embedding_lookup operator may be more appropriate. This tutorial discusses when you might use embeddings and how to invoke the operator in more detail.

    • For the special case of multiplying a sparse matrix by a (potentially large and sharded) dense matrix, tf.nn.embedding_lookup_sparse() may be appropriate. This function accepts one or two tf.SparseTensor objects, with sp_ids identifying the non-zero entries (as ids of rows in the dense matrix), and the optional sp_weights giving their values (which otherwise default to one).
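
As a rough illustration of the first option, here is a minimal sketch of multiplying a sparse tensor by a dense one, including the dense-on-the-left case handled through the adjoint arguments. It assumes TensorFlow 2.x in eager mode (where the function is tf.sparse.sparse_dense_matmul) and NumPy; the matrices and the variable names are made up for the example.

```python
import numpy as np
import tensorflow as tf

# Hypothetical example data: a 3x4 sparse matrix with three non-zero entries.
sp = tf.SparseTensor(
    indices=[[0, 1], [1, 0], [2, 3]],  # (row, column), in row-major order
    values=[2.0, 3.0, 4.0],
    dense_shape=[3, 4],
)
dense_right = tf.constant(np.arange(8, dtype=np.float32).reshape(4, 2))  # 4x2
dense_left = tf.constant(np.arange(6, dtype=np.float32).reshape(2, 3))   # 2x3

# sparse @ dense: the sparse operand is already the first argument.
sparse_times_dense = tf.sparse.sparse_dense_matmul(sp, dense_right)  # 3x2

# dense @ sparse: the sparse operand must still come first, so compute
# (D S)^T = S^T D^T with the adjoint flags and transpose the result back.
dense_times_sparse = tf.transpose(
    tf.sparse.sparse_dense_matmul(sp, dense_left, adjoint_a=True, adjoint_b=True)
)  # 2x4
```

Neither call densifies the sparse operand, which is the point of choosing this route when the densified tensor would not fit in memory.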

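To see why the two embedding-lookup options count as sparse multiplication, it may help to spell out what a sum-combined lookup computes. The plain-Python sketch below is only an illustration of the arithmetic, not TensorFlow's implementation; the helper sparse_lookup_sum and its list-of-lists inputs are hypothetical. In TensorFlow itself, the corresponding call would be tf.nn.embedding_lookup_sparse with combiner="sum".

```python
def sparse_lookup_sum(params, ids_per_row, weights_per_row=None):
    """Return one weighted sum of rows of `params` per sparse input row.

    ids_per_row[i] lists the column indices of the non-zero entries in row i
    of the sparse matrix; weights_per_row[i] lists their values, and every
    value defaults to 1.0 when no weights are given.
    """
    width = len(params[0])
    out = []
    for i, ids in enumerate(ids_per_row):
        if weights_per_row is not None:
            weights = weights_per_row[i]
        else:
            weights = [1.0] * len(ids)
        row = [0.0] * width
        for col, w in zip(ids, weights):
            # Each non-zero entry selects one row of the dense matrix.
            for j in range(width):
                row[j] += w * params[col][j]
        out.append(row)
    return out


# Hypothetical data: params is a 4x2 dense matrix, the sparse matrix is 2x4.
params = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
ids = [[0, 2], [3]]             # non-zero columns in each sparse row
weights = [[0.5, 2.0], [1.0]]   # the corresponding non-zero values

# Same result as multiplying [[0.5, 0, 2, 0], [0, 0, 0, 1]] by params.
result = sparse_lookup_sum(params, ids, weights)

# Without weights, every non-zero entry counts as 1 (the 0-or-1 case).
unweighted = sparse_lookup_sum(params, [[1], [0, 3]])
```

Only the rows of the dense matrix that are actually referenced get read, which is why this formulation suits a large or sharded dense operand.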