Question
Is there a more efficient alternative to the code below that avoids the "for" loop in the 4th line?
import torch
n, d = 37700, 7842
k = 4
sample = torch.cat([torch.randperm(d)[:k] for _ in range(n)]).view(n, k)
mask = torch.zeros(n, d, dtype=torch.bool)
mask.scatter_(dim=1, index=sample, value=True)
Basically, I am trying to create an n-by-d mask tensor in which exactly k randomly chosen elements of each row are True.
Answer 1:
Here's a way to do this without a loop. Start with a random matrix whose elements are drawn iid, in this case uniformly on [0, 1). Then take the k-th smallest value in each row and set every element less than or equal to it to True and the rest to False. Barring ties between random values, which are very unlikely, each row ends up with exactly k True entries:
rand_mat = torch.rand(n, d)
k_th_quant = torch.topk(rand_mat, k, largest=False)[0][:, -1:]
mask = rand_mat <= k_th_quant
No loop needed :) On my CPU this ran roughly 2.16x faster than the code you attached.
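A variant of the same idea, sketched here for the case where exact row counts must hold even if two random values happen to tie: take the column indices that torch.topk returns and scatter them into the mask, as the question's last line does. The helper name random_row_mask and the small sizes in the usage line are made up for illustration and are not part of the original answer.

```python
import torch

def random_row_mask(n: int, d: int, k: int) -> torch.Tensor:
    """Return an (n, d) bool mask with exactly k True entries per row.

    Hypothetical helper for illustration, not part of PyTorch.
    """
    # One iid uniform score per cell; the k smallest scores in each row
    # pick out a random set of k columns for that row.
    scores = torch.rand(n, d)
    idx = scores.topk(k, dim=1, largest=False).indices  # shape (n, k)

    mask = torch.zeros(n, d, dtype=torch.bool)
    # topk returns k distinct column indices per row, so every row gets
    # exactly k True entries even if some scores are equal.
    mask.scatter_(dim=1, index=idx, value=True)
    return mask

# Small sizes so the example runs quickly; the question uses n=37700, d=7842.
mask = random_row_mask(n=1000, d=50, k=4)
```

This still allocates a full n-by-d float matrix for the scores, so memory use is similar to the threshold version above; the only difference is that the mask is built from indices rather than from a comparison.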
Source: https://stackoverflow.com/questions/64162672/how-to-randomly-set-a-fixed-number-of-elements-in-each-row-of-a-tensor-in-pytorc