More efficient weighted Gini coefficient in Python

夕颜 2020-12-18 00:27

Per https://stackoverflow.com/a/48981834/1840471, this is an implementation of the weighted Gini coefficient in Python:

import numpy as np
def gini(x, weight         


        
2 Answers
    南方客 (OP)
    2020-12-18 01:10

    Here is a version that is much faster than the one in the question. It also uses a simplified formula for the unweighted case, which makes that case faster still.

    import numpy as np

    def gini(x, w=None):
        """Gini coefficient of x, optionally weighted by w (same length as x)."""
        # The rest of the code requires numpy arrays.
        x = np.asarray(x)
        if w is not None:
            w = np.asarray(w)
            sorted_indices = np.argsort(x)
            sorted_x = x[sorted_indices]
            sorted_w = w[sorted_indices]
            # Force float dtype to avoid overflows
            cumw = np.cumsum(sorted_w, dtype=float)
            cumxw = np.cumsum(sorted_x * sorted_w, dtype=float)
            return (np.sum(cumxw[1:] * cumw[:-1] - cumxw[:-1] * cumw[1:]) / 
                    (cumxw[-1] * cumw[-1]))
        else:
            sorted_x = np.sort(x)
            n = len(x)
            cumx = np.cumsum(sorted_x, dtype=float)
            # With all weights equal to 1, the formula above simplifies to:
            return (n + 1 - 2 * np.sum(cumx) / cumx[-1]) / n
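Both branches rest on an identity that is not spelled out above. A sketch of it, assuming the usual pairwise definition of the weighted Gini (weighted mean absolute difference divided by twice the weighted mean); the symbols W_k, X_k and C_k are notation introduced here for cumw, cumxw and cumx:

```latex
% Sort so that x_1 \le x_2 \le \dots \le x_n, carrying the weights w_i along.
% W_k = \sum_{i \le k} w_i  (cumw),  X_k = \sum_{i \le k} w_i x_i  (cumxw),  W_0 = X_0 = 0.
G = \frac{\sum_{i=1}^{n} \sum_{j=1}^{n} w_i w_j \lvert x_i - x_j \rvert}{2\, W_n X_n}

% Each cross term collects the pairs whose larger element is x_k:
X_k W_{k-1} - X_{k-1} W_k = w_k x_k W_{k-1} - w_k X_{k-1} = w_k \sum_{j<k} w_j (x_k - x_j)

% Summing over k counts every unordered pair once (the k = 1 term is 0), so
\sum_{k=2}^{n} \left( X_k W_{k-1} - X_{k-1} W_k \right)
  = \tfrac{1}{2} \sum_{i=1}^{n} \sum_{j=1}^{n} w_i w_j \lvert x_i - x_j \rvert
\quad\Longrightarrow\quad
G = \frac{\sum_{k=2}^{n} \left( X_k W_{k-1} - X_{k-1} W_k \right)}{X_n W_n}

% Unweighted case, w_i = 1: W_k = k and X_k = C_k = \sum_{i \le k} x_i.
% The pair sum becomes \sum_k (2k - n - 1) x_k = (n + 1) C_n - 2 \sum_k C_k, hence
G = \frac{(n + 1) C_n - 2 \sum_{k=1}^{n} C_k}{n\, C_n}
  = \frac{1}{n} \left( n + 1 - \frac{2 \sum_{k=1}^{n} C_k}{C_n} \right)
```

The last line is term for term the expression returned by the unweighted branch, with np.sum(cumx) as the sum of the C_k and cumx[-1] as C_n.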
    

    Here is some test code to check that both versions give the same result (up to floating-point rounding):

    >>> x = np.random.rand(1000000)
    >>> w = np.random.rand(1000000)
    >>> gini_max_ghenis(x, w)
    0.33376310938610521
    >>> gini(x, w)
    0.33376310938610382
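The comparison above needs gini_max_ghenis, which is not reproduced here. A self-contained alternative, sketched under the assumption that the reference is the usual pairwise definition (weighted mean absolute difference over twice the weighted mean), checks the cumulative-sum formula against an O(n²) brute force on a small random input. The names gini_weighted_cumsum and gini_weighted_pairwise are hypothetical; the first restates the weighted branch of gini.

```python
import numpy as np


def gini_weighted_cumsum(x, w):
    """Weighted Gini via cumulative sums (restates the weighted branch of gini)."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    order = np.argsort(x)
    cumw = np.cumsum(w[order])
    cumxw = np.cumsum(x[order] * w[order])
    return (np.sum(cumxw[1:] * cumw[:-1] - cumxw[:-1] * cumw[1:])
            / (cumxw[-1] * cumw[-1]))


def gini_weighted_pairwise(x, w):
    """Hypothetical O(n^2) reference using the assumed pairwise definition."""
    x = np.asarray(x, dtype=float)
    w = np.asarray(w, dtype=float)
    pair_w = np.outer(w, w)                     # w_i * w_j for every pair (i, j)
    abs_diff = np.abs(np.subtract.outer(x, x))  # |x_i - x_j| for every pair (i, j)
    return np.sum(pair_w * abs_diff) / (2 * w.sum() * np.dot(w, x))


# Small input: the brute force allocates two n x n arrays.
rng = np.random.default_rng(0)
x = rng.random(500)
w = rng.random(500)
print(gini_weighted_cumsum(x, w), gini_weighted_pairwise(x, w))
```

The brute force is only practical for small n, since its memory grows with n²; sorting once and taking cumulative sums avoids materializing the pairs.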
    

    But the speed is very different:

    %timeit gini(x, w)
    203 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    
    %timeit gini_max_ghenis(x, w)
    55.6 s ± 3.35 s per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    If you remove the pandas operations from that function, it is already much faster:

    %timeit gini_max_ghenis_no_pandas_ops(x, w)
    1.62 s ± 75 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
    

    To squeeze out the last drop of performance you could use Numba or Cython, but that would gain only a few percent, because most of the time is spent sorting:

    %timeit ind = np.argsort(x); sx = x[ind]; sw = w[ind]
    180 ms ± 4.82 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
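The same split can be measured outside IPython with the standard-library timeit module. A minimal sketch, assuming random uniform data like the test above; time_phases and its default arguments are hypothetical choices for illustration. It times the argsort-and-reorder step and the cumulative-sum step of the weighted branch separately:

```python
import timeit

import numpy as np


def time_phases(n=1_000_000, number=5, seed=0):
    """Hypothetical helper: average seconds per call for each phase of the weighted branch."""
    rng = np.random.default_rng(seed)
    x = rng.random(n)
    w = rng.random(n)

    # Phase 1: sort x and reorder w to match.
    def sort_phase():
        ind = np.argsort(x)
        return x[ind], w[ind]

    sx, sw = sort_phase()

    # Phase 2: cumulative sums and the final reduction on already-sorted data.
    def cumsum_phase():
        cumw = np.cumsum(sw, dtype=float)
        cumxw = np.cumsum(sx * sw, dtype=float)
        return (np.sum(cumxw[1:] * cumw[:-1] - cumxw[:-1] * cumw[1:])
                / (cumxw[-1] * cumw[-1]))

    t_sort = timeit.timeit(sort_phase, number=number) / number
    t_rest = timeit.timeit(cumsum_phase, number=number) / number
    return t_sort, t_rest


if __name__ == "__main__":
    t_sort, t_rest = time_phases()
    print(f"sort: {t_sort * 1e3:.1f} ms, cumsum + reduction: {t_rest * 1e3:.1f} ms")
```

Absolute numbers depend on the machine and the NumPy build. If the sort phase is the larger of the two, a compiled rewrite of the cumulative-sum part can only shave off the smaller share.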
    

    Edit: gini_max_ghenis is the code from Max Ghenis' answer.
