How to do product of matrices in PyTorch

眼角桃花 2020-12-12 18:02

In numpy I can do a simple matrix multiplication like this:

import numpy

a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))


        
4 Answers
  •  挽巷 (OP)
    2020-12-12 18:19

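    Use torch.mm(a, b) or torch.matmul(a, b).
    For 2-D matrices both give the same result. torch.mm accepts only 2-D tensors and does not broadcast, while torch.matmul also handles 1-D, batched, and broadcast inputs.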

    >>> torch.mm
    
    >>> torch.matmul
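
The two functions are interchangeable only for 2-D inputs. Here is a minimal sketch of where they diverge, assuming a recent PyTorch release; the shapes, seed, and variable names are illustrative and not taken from the original answer:

```python
import torch

torch.manual_seed(0)  # illustrative seed so the run is repeatable

# 2-D inputs: torch.mm and torch.matmul compute the same matrix product.
a = torch.randn(3, 2)
b = torch.randn(2, 1)
same_for_2d = torch.allclose(torch.mm(a, b), torch.matmul(a, b))

# 3-D (batched) input: torch.matmul broadcasts b across the batch dimension.
batch = torch.randn(5, 3, 2)
batched_shape = tuple(torch.matmul(batch, b).shape)

# torch.mm only accepts 2-D tensors, so the same call raises RuntimeError.
try:
    torch.mm(batch, b)
    mm_accepts_batch = True
except RuntimeError:
    mm_accepts_batch = False

print(same_for_2d, batched_shape, mm_accepts_batch)
```

If the inputs are always plain matrices, torch.mm works as an implicit shape check; torch.matmul is the more general choice when batch dimensions may appear.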
    
    

    One more option worth knowing is the @ operator (credit to @Simon H.). For tensors, a @ b is equivalent to torch.matmul(a, b).

    >>> a = torch.randn(2, 3)
    >>> b = torch.randn(3, 4)
    >>> a@b
    tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
            [ 0.8699, -0.3445,  1.4122, -0.5826]])
    >>> a.mm(b)
    tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
            [ 0.8699, -0.3445,  1.4122, -0.5826]])
    >>> a.matmul(b)
    tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
            [ 0.8699, -0.3445,  1.4122, -0.5826]])    
    

    All three give the same result.
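
To tie this back to the question, the sketch below rebuilds the question's NumPy arrays as PyTorch tensors and checks the three spellings against each other. The float32 dtype and the variable names are assumptions made for illustration:

```python
import torch

# The question's arrays, rebuilt as tensors (float32 is an illustrative choice).
a = torch.arange(2 * 3, dtype=torch.float32).reshape(3, 2)  # shape (3, 2)
b = torch.arange(2, dtype=torch.float32).reshape(2, 1)      # shape (2, 1)

via_operator = a @ b        # infix operator from PEP 465
via_mm = a.mm(b)            # 2-D only
via_matmul = a.matmul(b)    # general, broadcasting form

# All three agree; the product has shape (3, 1) with values 1, 3, 5.
assert torch.equal(via_operator, via_mm)
assert torch.equal(via_mm, via_matmul)
print(via_operator)
```

This is the PyTorch counterpart of a.dot(b) in the NumPy snippet from the question.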

    Related links:
    Matrix multiplication operator
    PEP 465 -- A dedicated infix operator for matrix multiplication
