mask-rcnn的解读(三):batch_slice()

折月煮酒 提交于 2019-12-06 05:53:40
我已用随机生产函数取模拟5张图片各有8个box的坐标值,而后验证batch_slice()函数的意义。由于inputs_slice = [x[i] for x in inputs]    output_slice = graph_fn(*inputs_slice)代码一时蒙蔽,故而对其深入理解,如下:代码如下:
import tensorflow as tfimport randomimport numpy as npsess=tf.Session()input=np.array([random.randint(0,150) for i in range(5*8*4)]).reshape((5,8,4))# print('show input=',input)ax=np.array([random.randint(0,7) for i in range(5*6)]).reshape((5,6))inputs=[input,ax]print('true_inputs=',inputs)def batch_slice(inputs, graph_fn, batch_size, names=None):    """Splits inputs into slices and feeds each slice to a copy of the given    computation graph and then combines the results. It allows you to run a    graph on a batch of inputs even if the graph is written to support one    instance only.    inputs: list of tensors. All must have the same first dimension length    graph_fn: A function that returns a TF tensor that's part of a graph.    batch_size: number of slices to divide the data into.    names: If provided, assigns names to the resulting tensors.    """    if not isinstance(inputs, list):  # 判断inputs是否为list类型        inputs = [inputs]    outputs = []    for i in range(batch_size):        inputs_slice = [x[i] for x in inputs]  # 是一个二维矩阵(去掉了图片张数的维度)# 表示切batch_size,即原来有5个图片,现在截取batch_size=3个图片        output_slice = graph_fn(*inputs_slice)  # 根据ax值取值        if not isinstance(output_slice, (tuple, list)):            output_slice = [output_slice]        outputs.append(output_slice)    # Change outputs from a list of slices where each is    # a list of outputs to a list of outputs and each has    # a list of slices    outputs = list(zip(*outputs))    if names is None:        names = [None] * len(outputs)    result = [tf.stack(o, axis=0, name=n) for o, n in zip(outputs, names)]    if len(result) == 1:        result = result[0]    return resultd=pre_nms_anchors = batch_slice(inputs, lambda a, x: tf.gather(a, x),   3,   names=["pre_nms_anchors"])d=sess.run(d)print('result',d) # 最终结果print('show value=',[x for x in inputs]) # 与下面代码比较,理解inputs_slice = [x[i] for x in inputs]的意义for i in range(2):    inputs_slice = [x[i] for x in inputs]     print('%id='%(i),inputs_slice)print('show inputs_slice=',inputs_slice)结果如下:

true_inputs= [array([[[102, 7, 45, 34],
[ 19, 105, 82, 83],
[ 84, 89, 70, 8],
[ 57, 81, 138, 122],
[ 69, 54, 61, 116],
[108, 120, 46, 122],
[102, 29, 39, 97],
[ 49, 92, 117, 52]],

[[ 52, 124, 86, 86],
[ 54, 9, 70, 104],
[102, 27, 29, 119],
[124, 82, 17, 4],
[ 53, 87, 69, 98],
[127, 106, 80, 40],
[ 78, 121, 84, 28],
[ 86, 111, 129, 149]],

[[112, 98, 89, 142],
[ 20, 134, 40, 50],
[139, 101, 99, 99],
[140, 60, 148, 49],
[ 49, 113, 26, 58],
[143, 85, 96, 142],
[ 42, 70, 16, 123],
[ 12, 92, 77, 143]],

[[136, 137, 31, 31],
[ 78, 28, 32, 87],
[ 39, 12, 124, 47],
[100, 96, 131, 12],
[111, 27, 28, 118],
[ 14, 130, 16, 43],
[ 77, 127, 69, 60],
[ 62, 53, 85, 95]],

[[ 17, 112, 122, 149],
[ 5, 89, 40, 105],
[ 49, 128, 128, 121],
[ 25, 1, 31, 52],
[127, 149, 9, 115],
[ 37, 103, 114, 119],
[130, 23, 29, 86],
[ 46, 111, 101, 69]]]), array([[3, 2, 6, 7, 2, 6],
[1, 1, 0, 6, 1, 7],
[1, 7, 0, 6, 6, 6],
[6, 3, 7, 7, 6, 0],
[0, 7, 4, 6, 3, 0]])]
result [[[ 57 81 138 122]
[ 84 89 70 8]
[102 29 39 97]
[ 49 92 117 52]
[ 84 89 70 8]
[102 29 39 97]]

[[ 54 9 70 104]
[ 54 9 70 104]
[ 52 124 86 86]
[ 78 121 84 28]
[ 54 9 70 104]
[ 86 111 129 149]]

[[ 20 134 40 50]
[ 12 92 77 143]
[112 98 89 142]
[ 42 70 16 123]
[ 42 70 16 123]
[ 42 70 16 123]]]
show value= [array([[[102, 7, 45, 34],
[ 19, 105, 82, 83],
[ 84, 89, 70, 8],
[ 57, 81, 138, 122],
[ 69, 54, 61, 116],
[108, 120, 46, 122],
[102, 29, 39, 97],
[ 49, 92, 117, 52]],

[[ 52, 124, 86, 86],
[ 54, 9, 70, 104],
[102, 27, 29, 119],
[124, 82, 17, 4],
[ 53, 87, 69, 98],
[127, 106, 80, 40],
[ 78, 121, 84, 28],
[ 86, 111, 129, 149]],

[[112, 98, 89, 142],
[ 20, 134, 40, 50],
[139, 101, 99, 99],
[140, 60, 148, 49],
[ 49, 113, 26, 58],
[143, 85, 96, 142],
[ 42, 70, 16, 123],
[ 12, 92, 77, 143]],

[[136, 137, 31, 31],
[ 78, 28, 32, 87],
[ 39, 12, 124, 47],
[100, 96, 131, 12],
[111, 27, 28, 118],
[ 14, 130, 16, 43],
[ 77, 127, 69, 60],
[ 62, 53, 85, 95]],

[[ 17, 112, 122, 149],
[ 5, 89, 40, 105],
[ 49, 128, 128, 121],
[ 25, 1, 31, 52],
[127, 149, 9, 115],
[ 37, 103, 114, 119],
[130, 23, 29, 86],
[ 46, 111, 101, 69]]]), array([[3, 2, 6, 7, 2, 6],
[1, 1, 0, 6, 1, 7],
[1, 7, 0, 6, 6, 6],
[6, 3, 7, 7, 6, 0],
[0, 7, 4, 6, 3, 0]])]
0d= [array([[102, 7, 45, 34],
[ 19, 105, 82, 83],
[ 84, 89, 70, 8],
[ 57, 81, 138, 122],
[ 69, 54, 61, 116],
[108, 120, 46, 122],
[102, 29, 39, 97],
[ 49, 92, 117, 52]]), array([3, 2, 6, 7, 2, 6])]
1d= [array([[ 52, 124, 86, 86],
[ 54, 9, 70, 104],
[102, 27, 29, 119],
[124, 82, 17, 4],
[ 53, 87, 69, 98],
[127, 106, 80, 40],
[ 78, 121, 84, 28],
[ 86, 111, 129, 149]]), array([1, 1, 0, 6, 1, 7])]
show inputs_slice= [array([[ 52, 124, 86, 86],
[ 54, 9, 70, 104],
[102, 27, 29, 119],
[124, 82, 17, 4],
[ 53, 87, 69, 98],
[127, 106, 80, 40],
[ 78, 121, 84, 28],
[ 86, 111, 129, 149]]), array([1, 1, 0, 6, 1, 7])]

Process finished with exit code 0


标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!