Extending numpy.digitize to multi-dimensional data

Happy的楠姐 2020-12-31 11:15

I have a set of large arrays (about 6 million elements each) on which I want to perform what is essentially np.digitize, but over multiple axes. I am looking for some suggestions on bot

1 answer
  • 2020-12-31 12:01

    How about using groupby in Pandas? First, a few problems in your code need fixing:

    import itertools
    import numpy as np
    
    np.random.seed(42)
    
    # the size argument must be an integer; a float such as 1e4 is rejected by current NumPy
    A = np.random.random_sample(10000)
    B = (np.random.random_sample(10000) + 10)*20
    C = (np.random.random_sample(10000) + 20)*40
    D = (np.random.random_sample(10000) + 80)*80
    
    # make the edges of the bins
    Bbins = np.linspace(B.min(), B.max(), 10)
    Cbins = np.linspace(C.min(), C.max(), 12) # note different number
    Dbins = np.linspace(D.min(), D.max(), 24) # note different number
    
    B_Bidx = np.digitize(B, Bbins)
    C_Cidx = np.digitize(C, Cbins)
    D_Didx = np.digitize(D, Dbins)
    
    a_bins = []
    for bb, cc, dd in itertools.product(np.unique(B_Bidx), 
                                        np.unique(C_Cidx), 
                                        np.unique(D_Didx)):
        a_bins.append([(bb, cc, dd), A[(B_Bidx==bb) & (C_Cidx==cc) & (D_Didx==dd)]])
    
    a_bins[1000]
    

    output:

    [(4, 6, 17), array([ 0.70723863,  0.907611  ,  0.46214047])]
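
    The loop above rescans all of A once per bin combination. As a rough alternative sketch (not part of the original answer), the per-axis indices from np.digitize can be merged into one flat bin number with np.ravel_multi_index, after which np.bincount gives per-bin sums and counts in a single pass. The two-axis setup, the default_rng seed, and the names flat, sums, counts and means are illustrative assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)  # assumed seed, for illustration only
n = 10_000
A = rng.random(n)
B = (rng.random(n) + 10) * 20
C = (rng.random(n) + 20) * 40

Bbins = np.linspace(B.min(), B.max(), 10)
Cbins = np.linspace(C.min(), C.max(), 12)

# np.digitize yields indices in 1..len(bins); the maximum value gets len(bins)
b_idx = np.digitize(B, Bbins)
c_idx = np.digitize(C, Cbins)

# one slot per possible index (0..len(bins)) along each axis
shape = (len(Bbins) + 1, len(Cbins) + 1)
flat = np.ravel_multi_index((b_idx, c_idx), shape)

# per-bin sum of A and per-bin element count, each computed in one pass
n_cells = shape[0] * shape[1]
sums = np.bincount(flat, weights=A, minlength=n_cells)
counts = np.bincount(flat, minlength=n_cells)

# empty bins divide 0 by 0 and come out as NaN
with np.errstate(invalid="ignore", divide="ignore"):
    means = (sums / counts).reshape(shape)
```

    means[bb, cc] then holds the mean of A over the elements whose B index is bb and whose C index is cc. Adding a third axis only means passing one more index array and extending shape.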
    

    Here is Pandas code that returns the same result:

    import pandas as pd
    
    cB = pd.cut(B, 9)
    cC = pd.cut(C, 11)
    cD = pd.cut(D, 23)
    
    sA = pd.Series(A)
    # .codes holds each element's integer bin index (older pandas versions called it .labels)
    g = sA.groupby([cB.codes, cC.codes, cD.codes])
    g.get_group((3, 5, 16))
    

    output:

    800     0.707239
    2320    0.907611
    9388    0.462140
    dtype: float64
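
    A variation that may be easier to read keeps everything in one DataFrame and asks pd.cut for integer bin numbers directly with labels=False. This is only a sketch: the lowercase column names b, c, d, the default_rng seed, and the choice of the first row's bin as the lookup key are assumptions made for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # assumed seed, for illustration only
n = 10_000
df = pd.DataFrame({
    "A": rng.random(n),
    "B": (rng.random(n) + 10) * 20,
    "C": (rng.random(n) + 20) * 40,
    "D": (rng.random(n) + 80) * 80,
})

# labels=False makes pd.cut return the integer bin number of each value
df["b"] = pd.cut(df["B"], 9, labels=False)
df["c"] = pd.cut(df["C"], 11, labels=False)
df["d"] = pd.cut(df["D"], 23, labels=False)

g = df.groupby(["b", "c", "d"])["A"]

# look up the group that the first row belongs to
key = tuple(int(v) for v in df.loc[0, ["b", "c", "d"]])
group = g.get_group(key)
```

    group is a Series of the A values that share the first row's bin in all three dimensions, indexed by their original row positions.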
    

    If you want to calculate statistics for every group, you can call the corresponding method on g, for example:

    g.mean()
    

    returns:

    0  0  0     0.343566
          1     0.410979
          2     0.700007
          3     0.189936
          4     0.452566
          5     0.565330
          6     0.539565
          7     0.530867
          8     0.568120
          9     0.587762
          11    0.352453
          12    0.484903
          13    0.477969
          14    0.484328
          15    0.467357
    ...
    8  10  8     0.559859
           9     0.570652
           10    0.656718
           11    0.353938
           12    0.628980
           13    0.372350
           14    0.404543
           15    0.387920
           16    0.742292
           17    0.530866
           18    0.389236
           19    0.628461
           20    0.387384
           21    0.541831
           22    0.573023
    Length: 2250, dtype: float64
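
    Several statistics can be collected in one call with agg. The sketch below is illustrative only: the data, the seed, the column names b, c, d and the chosen statistics are assumptions, not values from the question.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # assumed seed, for illustration only
n = 10_000
df = pd.DataFrame({
    "A": rng.random(n),
    "b": pd.cut((rng.random(n) + 10) * 20, 9, labels=False),
    "c": pd.cut((rng.random(n) + 20) * 40, 11, labels=False),
    "d": pd.cut((rng.random(n) + 80) * 80, 23, labels=False),
})

# one row per non-empty (b, c, d) bin, one column per statistic
stats = df.groupby(["b", "c", "d"])["A"].agg(["mean", "std", "count"])

# read one statistic for the bin that contains the first row
key = tuple(int(v) for v in df.loc[0, ["b", "c", "d"]])
first_bin_mean = stats.loc[key, "mean"]
```

    Only bins that actually contain data appear in the result, which is why the number of rows can be smaller than the product of the bin counts (9 * 11 * 23 here).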
    