Conditional assignment of tensor values in TensorFlow

栀梦 2020-11-29 07:57

I want to replicate the following NumPy code in TensorFlow. For example, I want to assign 0 to every tensor element that previously had a value of 1.
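The NumPy snippet the question refers to did not survive. As a hedged illustration only, this is the boolean-mask assignment idiom the sentence describes. It assumes the example array `[1, 2, 3, 1]` used in the answer below and is not the asker's original code.

```python
import numpy as np

# Hypothetical example array; the values mirror init_a in the answer
a = np.array([1, 2, 3, 1])

# Boolean-mask assignment: every element equal to 1 is set to 0, in place
a[a == 1] = 0

print(a)  # [0 2 3 0]
```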

2 Answers
  •  天涯浪人
    2020-11-29 08:05

    I'm also just starting to use TensorFlow. Maybe someone will find my approach more intuitive.

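    import tensorflow as tf
    
    # The original snippet used the pre-1.0 graph API; run it in TF1-style graph mode
    tf.compat.v1.disable_eager_execution()
    
    conditionVal = 1  # value to be replaced
    init_a = tf.constant([1, 2, 3, 1], dtype=tf.int32, name='init_a')
    a = tf.Variable(init_a, dtype=tf.int32, name='a')
    target = tf.fill(a.get_shape(), conditionVal, name='target')
    
    # initialize_all_variables() is deprecated; use global_variables_initializer()
    init = tf.compat.v1.global_variables_initializer()
    # True where the element is kept, i.e. where it differs from conditionVal
    condition = tf.not_equal(a, target)
    defaultValues = tf.zeros(a.get_shape(), dtype=a.dtype)
    # tf.select was renamed tf.where in TF 1.0: take a where True, 0 where False
    calculate = tf.where(condition, a, defaultValues)
    
    with tf.compat.v1.Session() as session:
        session.run(init)
        print(session.run(calculate))  # [0 2 3 0]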
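In TensorFlow 2.x, eager execution is the default and `tf.Session` no longer exists, so the same approach can be sketched without a session. This is a minimal sketch assuming TF 2.x; it also writes the result back with `Variable.assign`, which is the closest analogue of NumPy's in-place `a[a == 1] = 0`. The names `condition_val` and `new_values` are illustrative.

```python
import tensorflow as tf

condition_val = 1  # value to overwrite
a = tf.Variable([1, 2, 3, 1], dtype=tf.int32, name='a')

# tf.where(cond, x, y) takes x[i] where cond[i] is True and y[i] otherwise:
# keep elements that differ from condition_val, use 0 for the rest
new_values = tf.where(tf.not_equal(a, condition_val), a, tf.zeros_like(a))

# Write the result back into the variable (eager mode, no session needed)
a.assign(new_values)

print(a.numpy())  # [0 2 3 0]
```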
    

    The main trouble is that it is difficult to implement "custom logic". If you cannot express your logic in terms of linear math operations, you need to write a "custom op" library for TensorFlow (more details here).
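One possible reading of the "linear math" remark, for this particular problem: the condition can be turned into a 0/1 mask and applied with element-wise arithmetic, so no custom op is needed. This is a sketch assuming TF 2.x eager execution; `replacement` is a hypothetical parameter for the value written where an element equals `condition_val`.

```python
import tensorflow as tf

a = tf.constant([1, 2, 3, 1], dtype=tf.int32)
condition_val = 1  # value to overwrite
replacement = 0    # hypothetical: value written in its place

# mask is 1 where the element is kept (a != condition_val) and 0 where it is replaced
mask = tf.cast(tf.not_equal(a, condition_val), a.dtype)

# Keep a where mask == 1, use replacement where mask == 0
result = a * mask + replacement * (1 - mask)

print(result.numpy())  # [0 2 3 0]
```

With `replacement = -1` the same expression gives `[-1 2 3 -1]`. `tf.where` is usually clearer to read; the mask form mainly shows that this kind of conditional assignment reduces to ordinary arithmetic.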
