tensorflow学习之tensor操作函数

纵然是瞬间 提交于 2020-03-08 20:20:04

1.定义各种数组

a = np.array([[1,2,3],[4,5,6]])
# 数组转tensor:数组a,  tensor_a=tf.convert_to_tensor(a)
# tensor转数组:tensor b, array_b=b.eval()
b = tf.convert_to_tensor(a) # 将np定义数组转化为tensor

c = tf.constant([[1,2,3],[4,5,6]])

d_list = [[1,2,3],[4,5,6]]

2.tf.shape(tensor)     tensor.tf.get_shape()

# tf.shape()很显然这个是获取张量的大小的, tf.shape()返回的是一个tensor,要想知道是多少,必须通过sess.run()
# x.get_shape(),只有tensor才可以使用这种方法,返回的是一个元组,不能放到sess.run()里面,这个里面只能放operation和tensor
# c.get_shape().as_list()返回一个list
import tensorflow as tf
import numpy as np

a = np.array([[1,2,3],[4,5,6]])
b = tf.convert_to_tensor(a) # 将np定义数组转化为tensor

c = tf.constant([[1,2,3],[4,5,6]])

d_list = [[1,2,3],[4,5,6]]

# x.get_shape(),只有tensor才可以使用这种方法,返回的是一个元组,不能放到sess.run()里面,这个里面只能放operation和tensor
# c.get_shape().as_list()返回一个list
print(c.get_shape())
print(c.get_shape().as_list())

with tf.Session() as sess:
    print(sess.run(tf.shape(a)))  # tf.shape()很显然这个是获取张量的大小的
    print(sess.run(tf.shape(c)))  # tf.shape()返回的是一个tensor,要想知道是多少,必须通过sess.run()
    print(sess.run(tf.shape(d_list)))

3.tf.transpose()   tf.reshape()

x = tf.transpose(x, [0, 3, 1, 2])#交换维度(b,2c,h,w)
x = tf.reshape(x, (-1, int(x_shape[1]), int(x_shape[2]), 2))#(bc,h,w,2)

4.tf.expand_dims(input, dim, name=None)

将tensor增加维度,例如(对图像维度降到二维做特定操作后,要还原成四维[batch, height, width, channels],前后各增加一维)

# 't' is a tensor of shape [2]
shape(expand_dims(t, 0)) ==> [1, 2]
shape(expand_dims(t, 1)) ==> [2, 1]
shape(expand_dims(t, -1)) ==> [2, 1]

# 't2' is a tensor of shape [2, 3, 5]
shape(expand_dims(t2, 0)) ==> [1, 2, 3, 5]
shape(expand_dims(t2, 2)) ==> [2, 3, 1, 5]
shape(expand_dims(t2, 3)) ==> [2, 3, 5, 1]

one_img = tf.expand_dims(one_img, 0)
one_img = tf.expand_dims(one_img, -1) #-1表示最后一维

5. tf.tile()函数

tensorflow中的tile()函数是用来对张量(Tensor)进行扩展的,其特点是对当前张量内的数据进行一定规则的复制。最终的输出张量维度不变

tf.tile(
    input,
    multiples,
    name=None
)

with tf.Graph().as_default():
    a = tf.constant([1,2],name='a') 
    b = tf.tile(a,[3])
    sess = tf.Session()
    print(sess.run(b))
# 对[1,2]的同一维度上复制3次,multiples参数维度与input维度应一致,结果如下:

# [1 2 1 2 1 2]


with tf.Graph().as_default():
    a = tf.constant([[1,2],[3,4]],name='a')   
    b = tf.tile(a,[2,3])
    sess = tf.Session()
    print(sess.run(b))

# 输出:
[[1 2 1 2 1 2]
 [3 4 3 4 3 4]
 [1 2 1 2 1 2]
 [3 4 3 4 3 4]]

6. tf.cast()

"""tf.cast()函数的作用是执行 tensorflow 中张量数据类型转换,比如读入的图片如果是int8类型的,一般在要在训练前把图像的数据格式转换为float32。

cast定义:

cast(x, dtype, name=None)
第一个参数 x:   待转换的数据(张量)
第二个参数 dtype: 目标数据类型
第三个参数 name: 可选参数,定义操作的名称

int32转换为float32:
"""

import tensorflow as tf
 
t1 = tf.Variable([1,2,3,4,5])
t2 = tf.cast(t1,dtype=tf.float32)
 
print 't1: {}'.format(t1)
print 't2: {}'.format(t2)
 
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(t2)
    print t2.eval()
    # print(sess.run(t2))
输出:

t1: <tf.Variable 'Variable:0' shape=(5,) dtype=int32_ref>
t2: Tensor("Cast:0", shape=(5,), dtype=float32)
[ 1.  2.  3.  4.  5.]

7. tf.concat与tf.stack

tf.concat是沿某一维度拼接shape相同的张量,拼接生成的新张量维度不会增加。而tf.stack是在新的维度上拼接,拼接后维度加1

import tensorflow as tf
a = tf.constant([[1,2,3],[4,5,6]])
b = tf.constant([[7,8,9],[10,11,12]])
ab1 = tf.concat([a,b],axis=0)
ab2 = tf.stack([a,b], axis=0)
sess = tf.Session()
print(sess.run(ab1))
print(sess.run(ab2))
print ab1
print ab2
"""
结果:

[[ 1 2 3]
[ 4 5 6]
[ 7 8 9]
[10 11 12]]#ab1的值

[[[ 1 2 3]
[ 4 5 6]]

[[ 7 8 9]
[10 11 12]]]#ab2的值
Tensor(“concat:0”, shape=(4, 3), dtype=int32)#ab1的属性
Tensor(“stack:0”, shape=(2, 2, 3), dtype=int32)#ab2的属性
由上可知:由于axis=0,表示在第一维连接。对tf.concat来说连接后张量的第一维是在外层中括号内,将两个张量原来所有第一维内的元素(如[1 2 3],[7 8 9 ]等)连接。而tf.stack需要先将原二维张量扩展成三维张量,在扩展后张量的第一维上将原两个二维张量连接。

若axis=1,结果为:

[[ 1 2 3 7 8 9]
[ 4 5 6 10 11 12]]

[[[ 1 2 3]
[ 7 8 9]]

[[ 4 5 6]
[10 11 12]]]
Tensor(“concat:0”, shape=(2, 6), dtype=int32)
Tensor(“stack:0”, shape=(2, 2, 3), dtype=int32)
同上:axis=1,表示在第二维连接。对tf.concat来说连接后张量的第二维是在内层中括号内,将两个张量原来所有第二维内的元素(如1 ,2, 3,7,8,9等)连接。tf.stack先将原二维张量扩展成三维张量,在扩展后张量的第二维上将两个原张量的相对应的元素(如[1 2 3]和[ 7 8 9])连接。

若axis=2,tf.concat会报错(因为tf.concat不增加维数)。tf.stack结果如下:

[[[ 1 7]
[ 2 8]
[ 3 9]]

[[ 4 10]
[ 5 11]
[ 6 12]]]
Tensor(“stack:0”, shape=(2, 3, 2), dtype=int32)
同上:axis=2,表示在第三维连接。tf.stack先将原二维张量扩展成三维张量,在扩展后张量的第三维上将两个原张量的相对应的元素(如1和7,2和8,3和9等 )连接。
"""

8.tf.gather_nd()

tf.gather_nd(
    params,
    indices,
    name=None
)

按照indices的格式从params中抽取切片(合并为一个Tensor)
indices是一个K维整数Tensor

import tensorflow as tf

a = tf.Variable([[1, 2, 3, 4, 5],
                 [6, 7, 8, 9, 10],
                 [11, 12, 13, 14, 15]])
index_a1 = tf.Variable([[0, 2], [0, 4], [2, 2]])  # 随便选几个
index_a2 = tf.Variable([0, 1])  # 0行1列的元素——2
index_a3 = tf.Variable([[0], [1]])  # [第0行,第1行]
index_a4 = tf.Variable([0])  # 第0行


with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.gather_nd(a, index_a1)))
    print(sess.run(tf.gather_nd(a, index_a2)))
    print(sess.run(tf.gather_nd(a, index_a3)))
    print(sess.run(tf.gather_nd(a, index_a4)))

"""
输出:
[ 3  5 13]
2
[[ 1  2  3  4  5]
 [ 6  7  8  9 10]]
[1 2 3 4 5]
"""

9.tf.argmax()

tf.argmax(input,axis)根据axis取值的不同返回每行或者每列最大值的索引

  • axis = 0: 
      axis=0时比较每一列的元素,将每一列最大元素所在的索引记录下来,最后输出每一列最大元素所在的索引数组。

axis = 1: 
  axis=1的时候,将每一行最大元素所在的索引记录下来,最后返回每一行最大元素所在的索引数组

test = np.array([[1, 2, 3], [2, 3, 4], [5, 4, 3], [8, 7, 2]])
np.argmax(test, 0)   #输出:array([3, 3, 1]
np.argmax(test, 1)   #输出:array([2, 2, 0, 0]123


test[0] = array([1, 2, 3])
test[1] = array([2, 3, 4])
test[2] = array([5, 4, 3])
test[3] = array([8, 7, 2])
# output   :    [3, 3, 1]   

test[0] = array([1, 2, 3])  #2
test[1] = array([2, 3, 4])  #2
test[2] = array([5, 4, 3])  #0
test[3] = array([8, 7, 2])  #0

这是里面都是数组长度一致的情况,如果不一致,axis最大值为最小的数组长度-1,超过则报错。 
当不一致的时候,axis=0的比较也就变成了每个数组的和的比较。

10. tf.clip_by_value(V, min, max)

tf.clip_by_value(V, min, max), 截取V使之在min和max之间

import tensorflow as tf

import numpy as np

v = tf.constant([[1.0, 2.0, 4.0],[4.0, 5.0, 6.0]])
result = tf.clip_by_value(v, 2.5, 4.5)


with tf.Session() as sess:
    print(sess.run(result))

'''
输出

[[ 2.5  2.5  4. ]

 [ 4.   4.5  4.5]]
'''

 

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!