How to slice a batch and apply an operation on each slice in TensorFlow

|▌冷眼眸甩不掉的悲伤 提交于 2019-12-05 19:58:40

The TensorFlow function tf.map_fn(fn, elems) allows you to apply a function (fn) to each slice of a tensor (elems). For example, you could express your program as follows:

def model(x):
    W_1 = tf.Variable(tf.random_normal([6, 1]), name="W_1")

    def fn(x_slice):
        return tf.reduce_sum(x_slice, W_1)

    return tf.map_fn(fn, x)

It may also be possible to implement your operation more concisely using broadcasting on the tf.mul() operator, which uses NumPy broadcasting semantics, and the axis argument to tf.reduce_sum().

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!