The asterisk in tf.gather_nd raises a syntax error in Python 2.7

半城伤御伤魂 submitted on 2020-11-29 08:46:19

Question


I am using Python 2.7 and cannot upgrade it. The last line of the code below raises a syntax error at the asterisk, and I don't know why or how to fix it.

inp = tf.random.uniform(shape=[4, 6, 2], maxval=20, dtype=tf.int32)

out = tf.math.reduce_max(inp, axis=2)
am = tf.math.argmax(out, axis=1)
o = tf.gather_nd(inp, [*enumerate(am)])

The code is meant to get a 2D tensor from a 3D tensor with TensorFlow 1.14: for each matrix in the batch, keep the row that contains that matrix's maximum value, as illustrated by an image in the original post.
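To pin down the intended result, here is a minimal pure-Python sketch on nested lists, with no TensorFlow involved. The function `rows_with_max` and the sample data are hypothetical. It follows the same three steps as the question's code: per-row maximum (`reduce_max` over axis 2), index of the best row (`argmax` over axis 1), then selecting that row (`gather_nd`). This sketch breaks ties by taking the first matching row.

```python
def rows_with_max(batch):
    """For each 2D matrix in `batch`, return the row holding its largest value."""
    result = []
    for matrix in batch:
        row_maxes = [max(row) for row in matrix]   # like reduce_max over the last axis
        best = row_maxes.index(max(row_maxes))     # like argmax over rows (first row wins on ties)
        result.append(matrix[best])                # like gather_nd picking that row
    return result


example = [
    [[22, 78], [75, 70], [31, 10]],
    [[5, 1], [2, 9], [9, 0]],
]
print(rows_with_max(example))  # [[22, 78], [2, 9]]
```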


Answer 1:


The syntax error in your question has been explained by BoarGules. With respect to the problem that you are trying to solve, you can get the result you want with something like this:

# Needed on Python 2.7 for print(*..., sep=...) below
from __future__ import print_function
import tensorflow as tf

with tf.Graph().as_default(), tf.Session() as sess:
    # In TF 2.x: tf.random.set_seed
    tf.random.set_random_seed(0)
    # Input data
    inp = tf.random.uniform(shape=[4, 6, 2], maxval=100, dtype=tf.int32)

    # Find index of greatest value in last two dimensions
    s = tf.shape(inp)
    inp_res = tf.reshape(inp, [s[0], -1])
    max_idx = tf.math.argmax(inp_res, axis=1, output_type=s.dtype)
    # Get row index dividing by number of columns
    max_row_idx = max_idx // s[2]
    # Get rows with max values
    res = tf.gather_nd(inp, tf.expand_dims(max_row_idx, axis=1), batch_dims=1)
    # Print input and result
    print(*sess.run((inp, res)), sep='\n')

Output:

[[[22 78]
  [75 70]
  [31 10]
  [67  9]
  [70 45]
  [ 5 33]]

 [[82 83]
  [82 81]
  [73 58]
  [18 18]
  [57 11]
  [50 71]]

 [[84 55]
  [80 72]
  [93  1]
  [98 27]
  [36  6]
  [10 95]]

 [[83 24]
  [19  9]
  [46 48]
  [90 87]
  [50 26]
  [55 62]]]
[[22 78]
 [82 83]
 [98 27]
 [90 87]]
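The index arithmetic in the code above can be checked by hand. The sketch below redoes each step on plain Python lists (the function name `gather_max_rows` is hypothetical, and lists stand in for tensors): flatten each matrix, take the flat index of the maximum, recover the row with integer division by the number of columns (`s[2]`), and pick that row per batch item, which is what `gather_nd` with `batch_dims=1` and indices of shape `[batch, 1]` does. Run on the printed input, it reproduces the printed result.

```python
def gather_max_rows(batch):
    """Pure-Python mirror of the reshape/argmax/floor-division/gather_nd steps."""
    n_cols = len(batch[0][0])                               # plays the role of s[2]
    flat = [[v for row in m for v in row] for m in batch]   # like tf.reshape(inp, [s[0], -1])
    max_idx = [f.index(max(f)) for f in flat]               # like argmax along axis 1
    max_row_idx = [i // n_cols for i in max_idx]            # flat index // columns = row index
    indices = [[r] for r in max_row_idx]                    # like tf.expand_dims(max_row_idx, axis=1)
    # With batch_dims=1, output[b] = batch[b][indices[b][0]]
    return [m[idx[0]] for m, idx in zip(batch, indices)]


batch = [
    [[22, 78], [75, 70], [31, 10], [67, 9], [70, 45], [5, 33]],
    [[82, 83], [82, 81], [73, 58], [18, 18], [57, 11], [50, 71]],
    [[84, 55], [80, 72], [93, 1], [98, 27], [36, 6], [10, 95]],
    [[83, 24], [19, 9], [46, 48], [90, 87], [50, 26], [55, 62]],
]
print(gather_max_rows(batch))  # [[22, 78], [82, 83], [98, 27], [90, 87]]
```

For example, in the third matrix the maximum 98 sits at flat index 6, and 6 // 2 = 3 selects the row `[98 27]`.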



Answer 2:


That asterisk syntax is not available in Python 2. Unpacking inside a list display was added in Python 3.5 (PEP 448, "Additional Unpacking Generalizations"), released in 2015.

The Python 2 equivalent is:

o = tf.gather_nd(inp, [(i,j) for (i,j) in enumerate(am)])
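To see that the forms really are equivalent, the sketch below builds the index list three ways. As an assumption for illustration, a plain list stands in for the tensor `am` so the snippet runs without TensorFlow. Run under Python 3, all three produce the same list of `(batch, row)` pairs; only the last two lines parse under Python 2.7, and `list(enumerate(am))` is a shorter alternative to the comprehension.

```python
# Hypothetical stand-in for the argmax result `am` (a plain list, not a tf.Tensor)
am = [3, 0, 5, 1]

py3_only = [*enumerate(am)]                            # PEP 448 unpacking: SyntaxError on Python 2.7
comprehension = [(i, j) for (i, j) in enumerate(am)]   # Answer 2's Python 2 form
with_list = list(enumerate(am))                        # also valid on Python 2 and 3

print(with_list)  # [(0, 3), (1, 0), (2, 5), (3, 1)]
```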

But you really should not be using Python 2 or investing time in learning it. You don't have to "update" your existing Python 2 installation if you need it to run legacy code: you can run Python 3.8 side by side with Python 2. For work reasons I have 3.8, 3.7, 3.6 and 2.7 installed side by side on my machine without problems.



Source: https://stackoverflow.com/questions/63537430/the-asterisk-in-tf-gather_nd-in-python2-7-rise-syntax-error
