关于np.expand_dims的使用,网上好多举了一些实例,自己在平时也常见,但总是有点迷糊,我知道它的作用是扩展一个张量的维度,但结果是如何变化得到的,想来想去不是太明了,所以去函数源码看了一下,算是明白了,np.expand_dims的源码如下:
def expand_dims(a, axis):
"""
Expand the shape of an array.
Insert a new axis that will appear at the `axis` position in the expanded
array shape.
.. note:: Previous to NumPy 1.13.0, neither ``axis < -a.ndim - 1`` nor
``axis > a.ndim`` raised errors or put the new axis where documented.
Those axis values are now deprecated and will raise an AxisError in the
future.
Parameters
----------
a : array_like
Input array.
axis : int
Position in the expanded axes where the new axis is placed.
Returns
-------
res : ndarray
Output array. The number of dimensions is one greater than that of
the input array.
See Also
--------
squeeze : The inverse operation, removing singleton dimensions
reshape : Insert, remove, and combine dimensions, and resize existing ones
doc.indexing, atleast_1d, atleast_2d, atleast_3d
Examples
--------
>>> x = np.array([1,2])
>>> x.shape
(2,)
The following is equivalent to ``x[np.newaxis,:]`` or ``x[np.newaxis]``:
>>> y = np.expand_dims(x, axis=0)
>>> y
array([[1, 2]])
>>> y.shape
(1, 2)
>>> y = np.expand_dims(x, axis=1) # Equivalent to x[:,np.newaxis]
>>> y
array([[1],
[2]])
>>> y.shape
(2, 1)
Note that some examples may use ``None`` instead of ``np.newaxis``. These
are the same objects:
>>> np.newaxis is None
True
"""
if isinstance(a, matrix):
a = asarray(a)
else:
a = asanyarray(a)
shape = a.shape
if axis > a.ndim or axis < -a.ndim - 1:
# 2017-05-17, 1.13.0
warnings.warn("Both axis > a.ndim and axis < -a.ndim - 1 are "
"deprecated and will raise an AxisError in the future.",
DeprecationWarning, stacklevel=2)
# When the deprecation period expires, delete this if block,
if axis < 0:
axis = axis + a.ndim + 1
# and uncomment the following line.
# axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])
上述是官方原始文件,最重要的地方就最后三行代码,下面对其稍微注解一下:
if axis < 0:
axis = axis + a.ndim + 1#当采用倒数的方式指定维度位置时需要转化为正常顺序的位置
# and uncomment the following line.
# axis = normalize_axis_index(axis, a.ndim + 1)
return a.reshape(shape[:axis] + (1,) + shape[axis:])#expand_dims最重要的地方就这里了,在axis位置把新维度插入原始维度中,然后reshape一下。上面的+(1,)可能不好理解(python中两个tuple类型相加,不是求和,而是拼接),举个例子:
例1:
shape=(2,3,4)
axis=1
newshape=shape[:axis] + (1,) + shape[axis:]
print(newshape)
#输出:(2, 1, 3, 4)
例2:
shape=(2,3,4)
axis=1
newshape=shape[:axis] + (1,11) + shape[axis:]
print(newshape)
#输出:(2, 1, 11, 3, 4)
来源:CSDN
作者:粼粼淇
链接:https://blog.csdn.net/lingyunxianhe/article/details/103836807