N-D version of itertools.combinations in numpy

臣服心动 2020-12-01 12:18

I would like to implement itertools.combinations for numpy. Based on this discussion, I have a function that works for 1D input:

def combs(a, r):
    """
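
For illustration, a 1-D function of this kind can be sketched with a structured dtype, so that `np.fromiter` writes every r-tuple straight into one packed record and the result is then reinterpreted as a 2-D array. This is an assumption about the approach, not the asker's (truncated) code; `combs_1d` is a hypothetical name.

```python
from itertools import combinations

import numpy as np


def combs_1d(a, r):
    """Hypothetical 1-D helper: same rows as np.array(list(combinations(a, r)))."""
    a = np.asarray(a)
    # One unnamed field per element of a combination, each with a's own dtype.
    dt = np.dtype([("", a.dtype)] * r)
    # fromiter writes each r-tuple directly into one structured record.
    records = np.fromiter(combinations(a, r), dt)
    # Reinterpret the packed records as a plain (n_combinations, r) array.
    return records.view(a.dtype).reshape(-1, r)


print(combs_1d([1, 2, 3], 2))
```

The `view` step relies on the records being tightly packed, which holds here because every field shares `a.dtype` and the list form of `np.dtype` adds no alignment padding.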


        
4 answers
  •  青春惊慌失措
    2020-12-01 12:53

    Not sure how it will work out performance-wise, but you can generate the combinations of an index array along the chosen axis, then extract the corresponding slices of the actual array with np.take:

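    import numpy as np
    from itertools import combinations

    def combs_nd(a, r, axis=0):
        a = np.asarray(a)
        if axis < 0:
            axis += a.ndim  # normalize negative axis
        # Combine the positions along `axis`, not the values themselves.
        indices = np.arange(a.shape[axis])
        # One intp field per element lets np.fromiter consume each r-tuple.
        dt = np.dtype([('', np.intp)]*r)
        indices = np.fromiter(combinations(indices, r), dt)
        # Reinterpret the packed records as an (n_combinations, r) int array.
        indices = indices.view(np.intp).reshape(-1, r)
        # Gather along `axis`; it is replaced by two axes (n_combinations, r).
        return np.take(a, indices, axis=axis)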
    
    >>> combs_nd([1,2,3], 2)
    array([[1, 2],
           [1, 3],
           [2, 3]])
    >>> combs_nd([[1,2,3],[4,5,6]], 2, axis=1)
    array([[[1, 2],
            [1, 3],
            [2, 3]],
    
           [[4, 5],
            [4, 6],
            [5, 6]]])
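
If the structured-dtype step feels opaque, the index array can also be built with a plain `np.array(list(...))` call, which is likely slower for very large inputs. The sketch below (hypothetical name `combs_nd_simple`) swaps in that construction, shapes the index array with `math.comb` so that `r == 0` and `r > n` still give a 2-D result, and checks the shape rule: `np.take` replaces the chosen axis of length `n` with two axes of shape `(comb(n, r), r)`.

```python
import math
from itertools import combinations

import numpy as np


def combs_nd_simple(a, r, axis=0):
    """Hypothetical variant of combs_nd: same np.take step, plainer index build."""
    a = np.asarray(a)
    if axis < 0:
        axis += a.ndim  # normalize negative axes, as combs_nd does
    n = a.shape[axis]
    # All r-combinations of positions 0..n-1, as a list of tuples.
    idx = np.array(list(combinations(range(n), r)), dtype=np.intp)
    # Force shape (comb(n, r), r); this also covers r == 0 and r > n.
    idx = idx.reshape(math.comb(n, r), r)
    # np.take swaps `axis` (length n) for the two axes of idx.
    return np.take(a, idx, axis=axis)


a = np.arange(24).reshape(2, 3, 4)
out = combs_nd_simple(a, 2, axis=-1)
print(out.shape)  # (2, 3, 6, 2), since comb(4, 2) == 6
# One slice should match itertools applied to the raw values.
print(np.array_equal(out[1, 2], np.array(list(combinations(a[1, 2], 2)))))  # True
```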
    
