N-D version of itertools.combinations in numpy

臣服心动 2020-12-01 12:18

I would like to implement itertools.combinations for numpy. Based on this discussion, I have a function that works for 1D input:

def combs(a, r):
    """

4 answers
  •  我在风中等你
    2020-12-01 13:18

    You can use itertools.combinations() to create the index array, and then use NumPy's fancy indexing:

    import numpy as np
    from itertools import combinations, chain
    from scipy.special import comb
    
    def comb_index(n, k):
        count = comb(n, k, exact=True)
        index = np.fromiter(chain.from_iterable(combinations(range(n), k)), 
                            int, count=count*k)
        return index.reshape(-1, k)
    
    data = np.array([[1,2,3,4,5],[10,11,12,13,14]])
    
    idx = comb_index(5, 3)
    print(data[:, idx])
    

    output:

    [[[ 1  2  3]
      [ 1  2  4]
      [ 1  2  5]
      [ 1  3  4]
      [ 1  3  5]
      [ 1  4  5]
      [ 2  3  4]
      [ 2  3  5]
      [ 2  4  5]
      [ 3  4  5]]
    
     [[10 11 12]
      [10 11 13]
      [10 11 14]
      [10 12 13]
      [10 12 14]
      [10 13 14]
      [11 12 13]
      [11 12 14]
      [11 13 14]
      [12 13 14]]]
    
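The same index-array idea carries over to combinations along any axis of an N-D array, which is what the question title asks for. This is a minimal sketch, not part of the original answer. `combs_along_axis` is a hypothetical helper name. `math.comb` (Python 3.8+) stands in for `scipy.special.comb`, and `np.take` does the fancy indexing on the chosen axis.

```python
import math
from itertools import combinations, chain

import numpy as np


def comb_index(n, r):
    """Return a (C(n, r), r) integer array of the r-combinations of range(n)."""
    count = math.comb(n, r)
    flat = np.fromiter(chain.from_iterable(combinations(range(n), r)),
                       dtype=np.intp, count=count * r)
    return flat.reshape(count, r)


def combs_along_axis(a, r, axis=-1):
    """Hypothetical helper: r-length combinations of `a` taken along `axis`.

    The result has the shape of `a` with `axis` replaced by (C(n, r), r).
    """
    a = np.asarray(a)
    idx = comb_index(a.shape[axis], r)
    # np.take inserts idx.shape in place of the indexed axis
    return np.take(a, idx, axis=axis)


data = np.array([[1, 2, 3, 4, 5], [10, 11, 12, 13, 14]])
print(combs_along_axis(data, 3).shape)          # (2, 10, 3)
print(combs_along_axis(data, 2, axis=0).shape)  # (1, 2, 5)
```

Because `np.take` replaces the indexed axis with the shape of the index array, taking `r = 3` along the last axis of `data` gives the same `(2, 10, 3)` result as `data[:, idx]` above. The `axis` argument lets you pick any other dimension without rewriting the indexing expression.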
