Finding the non-intersection of two PyTorch tensors

穿精又带淫゛_ submitted on 2021-01-15 18:09:26

Question


Thanks in advance for your help, everyone! What I'm trying to do in PyTorch is something like NumPy's setdiff1d. For example, given the two tensors below:

t1 = torch.tensor([1, 9, 12, 5, 24]).to('cuda:0')
t2 = torch.tensor([1, 24]).to('cuda:0')

The expected output should be (sorted or unsorted):

torch.tensor([9, 12, 5])

Ideally, the operations should run entirely on the GPU, with no back-and-forth transfers between GPU and CPU. Much appreciated!
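For reference, PyTorch 1.10 and later ship `torch.isin`, which is the closest built-in match to NumPy's `isin`; with `invert=True` it flags the elements of `t1` that do not appear in `t2`, and the resulting mask stays on `t1`'s device. A minimal sketch, assuming a PyTorch version that has `torch.isin` (it falls back to the CPU when no GPU is present):

```python
import torch

# Use the GPU when available; the logic is identical on CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
t1 = torch.tensor([1, 9, 12, 5, 24], device=device)
t2 = torch.tensor([1, 24], device=device)

# invert=True marks elements of t1 that are NOT present in t2
mask = torch.isin(t1, t2, invert=True)
difference = t1[mask]  # keeps t1's original order; no CPU round trip needed

# .tolist() copies to the CPU and is used here only for display
print(difference.tolist())  # [9, 12, 5]
```

Unlike `np.setdiff1d`, this keeps duplicates and the original order of `t1`; call `.unique()` on the result if sorted unique values are wanted.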


Answer 1:


If you don't want to leave CUDA, a workaround could be:

import torch

t1 = torch.tensor([1, 9, 12, 5, 24], device='cuda')
t2 = torch.tensor([1, 24], device='cuda')
# Boolean mask (uint8 masks are deprecated for indexing): True while t1's element matches nothing in t2
mask = torch.ones_like(t1, dtype=torch.bool)
for elem in t2:
    mask &= t1 != elem
difference = t1[mask]  # elements of t1 not in t2: 9, 12, 5



Answer 2:


I came across the same problem, but the proposed solutions were far too slow for larger arrays. The following simple solution works on both CPU and GPU and is significantly faster than the other proposed solutions:

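import torch

# Assumes neither t1 nor t2 contains duplicate values
combined = torch.cat((t1, t2))
uniques, counts = combined.unique(return_counts=True)
difference = uniques[counts == 1]   # symmetric difference: values found in exactly one tensor
intersection = uniques[counts > 1]  # values found in both tensors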
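The counting trick above returns the symmetric difference and relies on each tensor being free of duplicates. If you need the one-sided, `np.setdiff1d`-style result (sorted unique values of `t1` that are not in `t2`) and the inputs may contain duplicates, one variation is to deduplicate each tensor first and count `t2` twice, so that values only in `t1` are the only ones with a count of 1. A minimal sketch under those assumptions; `setdiff_unique` is a hypothetical helper name:

```python
import torch

def setdiff_unique(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: sorted unique values of t1 that are not in t2."""
    # Deduplicate so counts reflect membership rather than multiplicity
    a = t1.unique()
    b = t2.unique()
    # Only-in-a values appear once, only-in-b twice, in-both three times
    combined = torch.cat((a, b, b))
    uniques, counts = combined.unique(return_counts=True)
    return uniques[counts == 1]

t1 = torch.tensor([1, 9, 9, 12, 5, 24])
t2 = torch.tensor([1, 24, 7])
result = setdiff_unique(t1, t2)
print(result.tolist())  # [5, 9, 12]
```

Because `unique` sorts its output by default, the result is sorted, which matches `np.setdiff1d` rather than the original order of `t1`.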



Answer 3:


If you don't want a for loop, this approach compares all values in one go.

You can also get the non-intersection just as easily:

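import torch

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])

# Create a tensor to compare all values at once:
# row i repeats t2[i], giving shape (len(t2), len(t1))
compareview = t2.repeat(t1.shape[0], 1).T

# Intersection: elements of t1 equal to at least one element of t2
# (> 0 rather than == 1, so duplicates in t2 are handled)
print(t1[(compareview == t1).T.sum(1) > 0])
# Non-intersection: elements of t1 that differ from every element of t2
print(t1[(compareview != t1).T.prod(1) == 1])

Output:

tensor([ 1, 24])
tensor([ 9, 12,  5])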
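The same all-pairs comparison can be written with broadcasting instead of `repeat`: `t1.unsqueeze(1)` has shape `(n, 1)`, so comparing it with `t2` (shape `(m,)`) yields an `(n, m)` boolean matrix directly, and `.any(dim=1)` replaces the sum/prod reductions. A minimal sketch; `setdiff_broadcast` is a hypothetical helper name. It avoids the repeated copy of `t2`, but the `n × m` boolean matrix is still allocated, so memory grows with `len(t1) * len(t2)`:

```python
import torch

def setdiff_broadcast(t1: torch.Tensor, t2: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: elements of t1 not in t2, in t1's original order."""
    # (n, 1) == (m,) broadcasts to an (n, m) matrix of pairwise equality
    matches = t1.unsqueeze(1) == t2
    # Keep rows (elements of t1) that match no element of t2
    return t1[~matches.any(dim=1)]

t1 = torch.tensor([1, 9, 12, 5, 24])
t2 = torch.tensor([1, 24])
difference = setdiff_broadcast(t1, t2)
print(difference.tolist())  # [9, 12, 5]
```

The intersection is the same mask without the `~`: `t1[(t1.unsqueeze(1) == t2).any(dim=1)]`.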


Source: https://stackoverflow.com/questions/55110047/finding-non-intersection-of-two-pytorch-tensors
