Update values of a matrix variable in TensorFlow, advanced indexing

温柔的废话 2021-01-26 11:45

I would like to create a function that, for every row of a given data matrix X, applies the softmax function only to some sampled classes, let's say 2, out of K total classes. In s…
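For reference, here is a plain NumPy sketch of the per-row computation, inferred from the TensorFlow code in the answer below. For each row x_n it picks num_samps classes, computes the logits W[idx] @ x_n, applies softmax over just those logits, and writes the results into an otherwise-zero N x K matrix. The question's own NumPy version is not shown, so the function name `sampled_softmax_rows`, the example sizes, and sampling classes without replacement are assumptions made for illustration.

```python
import numpy as np

def sampled_softmax_rows(X, W, num_samps, rng):
    """Hypothetical helper: per-row softmax over a random subset of classes.

    X: (N, D) data, W: (K, D) class weights. Returns S of shape (N, K),
    zero everywhere except at the num_samps sampled classes of each row.
    """
    N, K = X.shape[0], W.shape[0]
    S = np.zeros((N, K))
    for n in range(N):
        # Assumption: classes are drawn without replacement, so no index repeats
        idx = rng.choice(K, size=num_samps, replace=False)
        logits = W[idx] @ X[n]                  # shape (num_samps,)
        e = np.exp(logits - logits.max())       # subtract the max for numerical stability
        S[n, idx] = e / e.sum()                 # softmax over the sampled classes only
    return S

rng = np.random.default_rng(0)
X = rng.random((4, 3))     # hypothetical N=4, D=3
W = rng.random((10, 3))    # hypothetical K=10
S = sampled_softmax_rows(X, W, num_samps=2, rng=rng)
print(S.sum(axis=1))       # each row sums to 1 (up to rounding)
```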

1 answer
  •  没有蜡笔的小新
    2021-01-26 12:27

    I found an answer to my problem in a comment on a solution to this problem: it suggests reshaping my matrix S into a 1-D vector. With that change the code works; it looks like this:

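    import tensorflow as tf  # TensorFlow 1.x graph API (tf.placeholder, tf.Session)
    # N, K, D, num_samps and the N x D data array X are defined as in the question

    S = tf.Variable(tf.zeros(shape=(N*K)))   # the N x K scores stored as one flat vector
    W = tf.Variable(tf.random_uniform((K,D)))
    tfx = tf.placeholder(tf.float32,shape=(None,D))
    # maxval is exclusive for integer dtypes, so maxval=K allows every class 0..K-1;
    # sampling is with replacement, so the same class can be drawn twice
    sampled_ind = tf.random_uniform(dtype=tf.int32, minval=0, maxval=K, shape=[num_samps])
    ar_to_sof = tf.matmul(tfx,tf.gather(W,sampled_ind),transpose_b=True)  # logits of the sampled classes
    updates = tf.reshape(tf.nn.softmax(ar_to_sof),shape=(num_samps,))      # assumes one row is fed at a time
    init = tf.global_variables_initializer()
    sess = tf.Session()
    sess.run(init)
    for line in range(N):
        # row `line` occupies flat positions line*K .. line*K + K-1
        inds_new = sampled_ind + line*K
        # note: this builds a new scatter_update op on every iteration, so the graph grows with N
        sess.run(tf.scatter_update(S,inds_new,updates), feed_dict={tfx: X[line:line+1]})

    S = tf.reshape(S,shape=(N,K))
    result = sess.run(S)   # evaluate to get the N x K values as a NumPy array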
    

    That returns the result I was expecting. The problem now is that this implementation is too slow, much slower than the NumPy version. Maybe it is the for loop. Any suggestions?
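One possible way to remove the Python loop, sketched in NumPy (the function `sampled_softmax_vectorized` and the sizes are hypothetical, and classes are again assumed to be sampled without replacement): sample the class indices for all rows at once, gather the matching rows of W with advanced indexing, compute every logit with a single `np.einsum`, and scatter the softmax values back with one fancy-indexed assignment.

```python
import numpy as np

def sampled_softmax_vectorized(X, W, num_samps, rng):
    """Hypothetical helper: sampled softmax for all rows at once, no Python loop.

    X: (N, D) data, W: (K, D) class weights. Returns (S, idx) where S is (N, K)
    and idx is the (N, num_samps) array of sampled class indices.
    """
    N, K = X.shape[0], W.shape[0]
    # argsort of uniform noise gives an independent random permutation per row;
    # keeping the first num_samps columns samples distinct classes per row
    idx = np.argsort(rng.random((N, K)), axis=1)[:, :num_samps]
    # Gather: W[idx] has shape (N, num_samps, D); dot each row of X with its sampled weights
    logits = np.einsum('nd,nsd->ns', X, W[idx])
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    probs = e / e.sum(axis=1, keepdims=True)
    # Scatter with advanced indexing: row n gets probs[n] at columns idx[n]
    S = np.zeros((N, K))
    S[np.arange(N)[:, None], idx] = probs
    return S, idx

rng = np.random.default_rng(0)
X = rng.random((1000, 16))    # hypothetical N=1000, D=16
W = rng.random((50, 16))      # hypothetical K=50
S, idx = sampled_softmax_vectorized(X, W, num_samps=2, rng=rng)
```

Design note: the argsort trick spends O(K log K) per row to guarantee distinct classes, whereas `tf.random_uniform` in the answer samples with replacement, which is cheaper but can pick the same class twice. The same structure could plausibly be expressed in TensorFlow with `tf.gather`, `tf.einsum` and `tf.scatter_nd`, though that translation is untested here. Independently of vectorizing, the TF1 loop above calls `tf.scatter_update` inside the loop, which adds a new op to the graph on every iteration; creating that op once before the loop (feeding the row offset through a placeholder) is a common fix for that part of the slowdown.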
