torch.max() with two tensor inputs
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
While reading some source code recently, I ran into a trick I couldn't follow:
```python
import torch


def find_intersection(set_1, set_2):
    """
    Find the intersection of every box combination between two sets of boxes
    that are in boundary coordinates.

    :param set_1: set 1, a tensor of dimensions (n1, 4)
    :param set_2: set 2, a tensor of dimensions (n2, 4)
    :return: intersection of each of the boxes in set 1 with respect to each
             of the boxes in set 2, a tensor of dimensions (n1, n2)
    """
    # PyTorch auto-broadcasts singleton dimensions
    lower_bounds = torch.max(set_1[:, :2].unsqueeze(1), set_2[:, :2].unsqueeze(0))  # (n1, n2, 2)
    upper_bounds = torch.min(set_1[:, 2:].unsqueeze(1), set_2[:, 2:].unsqueeze(0))  # (n1, n2, 2)
    intersection_dims = torch.clamp(upper_bounds - lower_bounds, min=0)  # (n1, n2, 2)
    return intersection_dims[:, :, 0] * intersection_dims[:, :, 1]  # (n1, n2)
```
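In principle, the intersection area should simply be the overlap along X (worked out from the two boxes' X boundaries and widths) multiplied by the overlap along Y (worked out from their Y boundaries and heights).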
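For a single pair of boxes, that rule can be written out directly. The sketch below is plain Python and is not from the original source; the helper names `pair_intersection` and `all_intersections` are hypothetical, and boxes are assumed to be in boundary coordinates `(x_min, y_min, x_max, y_max)`, as in `find_intersection`. The double loop builds the same n1 × n2 table that `find_intersection` computes in one shot with broadcasting.

```python
def pair_intersection(box_a, box_b):
    """Intersection area of two boxes given as (x_min, y_min, x_max, y_max)."""
    # The overlap starts at the larger of the two lower bounds...
    x_low = max(box_a[0], box_b[0])
    y_low = max(box_a[1], box_b[1])
    # ...and ends at the smaller of the two upper bounds.
    x_high = min(box_a[2], box_b[2])
    y_high = min(box_a[3], box_b[3])
    # Clamp at 0 so non-overlapping (or merely touching) boxes give area 0.
    width = max(x_high - x_low, 0)
    height = max(y_high - y_low, 0)
    return width * height


def all_intersections(set_1, set_2):
    """Loop version of find_intersection: an n1 x n2 nested list."""
    return [[pair_intersection(a, b) for b in set_2] for a in set_1]


boxes_1 = [(0, 0, 2, 2), (1, 1, 4, 3)]
boxes_2 = [(1, 1, 3, 3), (5, 5, 6, 6), (0, 0, 1, 1)]
print(all_intersections(boxes_1, boxes_2))  # [[1, 0, 1], [4, 0, 0]]
```

The `max` of the lower bounds and `min` of the upper bounds here are exactly what `torch.max` and `torch.min` compute in `find_intersection`, just for every (i, j) pair at once.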
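It turns out the key is torch.max(). I won't go over its basic usage here; let's look only at the last overload, which the official PyTorch docs cover only briefly: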
torch.max(input, other, out=None) → Tensor

Each element of the tensor input is compared with the corresponding element of the tensor other and an element-wise maximum is taken. The shapes of input and other don't need to match, but they must be broadcastable.

$$\text{out}_i = \max(\text{input}_i, \text{other}_i)$$

NOTE: When the shapes do not match, the shape of the returned output tensor follows the broadcasting rules.

Parameters:
input (Tensor): the input tensor.
other (Tensor): the second input tensor.
out (Tensor, optional): the output tensor.

Example:

```python
>>> a = torch.randn(4)
>>> a
tensor([ 0.2942, -0.7416,  0.2653, -0.1584])
>>> b = torch.randn(4)
>>> b
tensor([ 0.8722, -1.7421, -0.4141, -0.5055])
>>> torch.max(a, b)
tensor([ 0.8722, -0.7416,  0.2653, -0.1584])
```
Normally, if you pass in two tensors of the same shape, they are simply compared element by element.
For tensors of different shapes, however, the docs above only say the shapes must be broadcastable and give no example.
Here's an example, with aaa of shape 2 × 2 and bbb of shape 3 × 2:
```python
import torch

aaa = torch.randn(2, 2)
bbb = torch.randn(3, 2)
ccc = torch.max(aaa, bbb)
# RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 0
```
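This raises the error above, so let's analyze it first. A 2 × 2 tensor and a 3 × 2 tensor cannot be compared directly. Going by the docs' "element-wise" wording, I had expected every row of one to be compared with every row of the other, giving an output of shape 2 × 3 × 2. Let's test further: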
```python
import torch

aaa = torch.randn(1, 2)
bbb = torch.randn(3, 2)
ccc = torch.max(aaa, bbb)
print(ccc)
# Example output (values are random):
# tensor([[1.0350, 0.2532],
#         [0.2203, 0.2532],
#         [0.2912, 0.2532]])
```
This runs and prints a result without any error.
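So this is PyTorch's rule: shapes are aligned starting from the last dimension, and wherever two sizes differ, one of them must be 1 (a singleton dimension), which is then stretched to match the other. That is why 1 × 2 against 3 × 2 works, while 2 × 2 against 3 × 2 fails: in the first dimension, 2 and 3 differ and neither is 1.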
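The rule above can be sketched as a small shape calculator. This is a plain-Python illustration of the broadcasting rule, not PyTorch's own code; the function name `broadcast_shape` is hypothetical. In recent PyTorch versions, `torch.broadcast_shapes` performs this computation for real.

```python
def broadcast_shape(shape_a, shape_b):
    """Return the broadcast result shape, or raise ValueError if incompatible."""
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so both are aligned from the right.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    result = []
    for dim, (size_a, size_b) in enumerate(zip(a, b)):
        if size_a == size_b or size_b == 1:
            result.append(size_a)  # equal sizes, or b is stretched
        elif size_a == 1:
            result.append(size_b)  # a is stretched
        else:
            raise ValueError(
                f"size {size_a} vs {size_b} mismatch at non-singleton dimension {dim}"
            )
    return tuple(result)


print(broadcast_shape((1, 2), (3, 2)))        # (3, 2)
print(broadcast_shape((2, 1, 2), (1, 3, 2)))  # (2, 3, 2)
try:
    broadcast_shape((2, 2), (3, 2))
except ValueError as err:
    print(err)  # size 2 vs 3 mismatch at non-singleton dimension 0
```

Note that the mismatch for 2 × 2 against 3 × 2 is in dimension 0; the trailing dimension (2 against 2) is fine.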
So let's test further: use unsqueeze to turn the 2 × 2 and 3 × 2 inputs into 2 × 1 × 2 and 1 × 3 × 2:
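```python
import torch

aaa = torch.randn(2, 2).unsqueeze(1)  # (2, 1, 2)
bbb = torch.randn(3, 2).unsqueeze(0)  # (1, 3, 2)
ccc = torch.max(aaa, bbb)             # broadcasts to (2, 3, 2)
print(ccc.shape)  # torch.Size([2, 3, 2])
```

This time there is no error: every size-1 dimension is stretched, and the output has shape 2 × 3 × 2. This is exactly the unsqueeze(1) / unsqueeze(0) trick find_intersection uses to get its (n1, n2, 2) bounds.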
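To confirm that this really is a pairwise comparison, here is a small check. It uses fixed example values instead of `torch.randn` so the result is reproducible (the values themselves are arbitrary). Entry `[i, j]` of the broadcast result should equal the element-wise max of row `i` of `aaa` and row `j` of `bbb`.

```python
import torch

# Arbitrary example inputs: 2 rows in aaa, 3 rows in bbb, each of length 2.
aaa = torch.tensor([[0.0, 5.0], [3.0, 1.0]])              # (2, 2)
bbb = torch.tensor([[1.0, 1.0], [4.0, 0.0], [2.0, 6.0]])  # (3, 2)

# (2, 1, 2) vs (1, 3, 2) broadcasts to (2, 3, 2).
pairwise = torch.max(aaa.unsqueeze(1), bbb.unsqueeze(0))
print(pairwise.shape)   # torch.Size([2, 3, 2])
print(pairwise[1, 2])   # tensor([3., 6.])

# Every [i, j] entry matches a direct comparison of row i with row j.
for i in range(aaa.shape[0]):
    for j in range(bbb.shape[0]):
        assert torch.equal(pairwise[i, j], torch.max(aaa[i], bbb[j]))
```

In find_intersection, the same pattern runs on the box corners: `set_1[:, :2].unsqueeze(1)` plays the role of `aaa.unsqueeze(1)` and `set_2[:, :2].unsqueeze(0)` plays the role of `bbb.unsqueeze(0)`.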
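And with that, the problem is solved! When I have time I'll look at how the source implements this. Why can't it be a bit smarter about it...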