NumPy version of “Exponential weighted moving average”, equivalent to pandas.ewm().mean()

一生所求 2020-11-27 12:30

How do I get the exponential weighted moving average in NumPy just like the following in pandas?

import pandas as pd
import pandas_datareader as pdr
from dat         


        
12 answers
  •  刺人心 (original poster)
    2020-11-27 12:53

    Here is an implementation using NumPy that is equivalent to df.ewm(alpha=alpha).mean(). According to the pandas documentation, the calculation reduces to a few matrix operations; the trick is constructing the right matrices.
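
For reference, the weighting this relies on can be written out explicitly. The block below restates the adjusted form described in the pandas documentation for the default `adjust=True`; the symbols `W` and `\mathbf{1}` (a vector of ones) are notation introduced here for illustration and do not appear in the original answer.

```latex
y_t = \frac{\sum_{i=0}^{t} (1-\alpha)^{i}\, x_{t-i}}{\sum_{i=0}^{t} (1-\alpha)^{i}},
\qquad
W_{tj} =
\begin{cases}
(1-\alpha)^{\,t-j} & j \le t \\
0 & j > t
\end{cases},
\qquad
y_t = \frac{(W x)_t}{(W \mathbf{1})_t}
```

Each row of `W` holds the weights for one output value, so the whole series is one matrix-vector product divided element-wise by the row sums.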

    Note that this builds n × n float matrices, so memory use grows quadratically with the input length and a large input array can quickly exhaust memory.

    import pandas as pd
    import numpy as np
    
    def ewma(x, alpha):
        '''
        Returns the exponentially weighted moving average of x.
    
        Parameters:
        -----------
        x : array-like
        alpha : float, 0 < alpha <= 1 (the range pandas accepts)
    
        Returns:
        --------
        ewma: numpy array
              the exponentially weighted moving average
        '''
        # Coerce x to an array
        x = np.array(x)
        n = x.size
    
        # Create an initial weight matrix of (1-alpha), and a matrix of powers
        # to raise the weights by
        w0 = np.ones(shape=(n,n)) * (1-alpha)
        p = np.vstack([np.arange(i,i-n,-1) for i in range(n)])
    
        # Create the weight matrix
        w = np.tril(w0**p,0)
    
        # Calculate the ewma
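        # Weighted sums divided by the sum of weights in each row
        # (x[::np.newaxis] is just x[::None], i.e. a no-op slice, so use x directly)
        return np.dot(w, x) / w.sum(axis=1)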
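
To make the matrix construction concrete, here is a small sketch for n = 3. It builds the same exponent matrix with broadcasting instead of `np.vstack`, which is a stylistic choice made here and not part of the original answer; the `np.clip` call is likewise an assumption added to avoid raising to negative powers in the upper triangle.

```python
import numpy as np

alpha = 0.55
n = 3

# Exponent matrix: entry (i, j) is i - j, i.e. how many steps x[j] lies behind x[i]
p = np.arange(n)[:, None] - np.arange(n)[None, :]

# Clip negative exponents to 0 before exponentiating, then zero the upper triangle
w = np.tril((1 - alpha) ** np.clip(p, 0, None))

print(p)  # rows: [0, -1, -2], [1, 0, -1], [2, 1, 0]
print(w)  # approximately rows: [1, 0, 0], [0.45, 1, 0], [0.2025, 0.45, 1]
```

Row i of `w` contains the weights applied to x[0..i], with the most recent observation weighted 1.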
    

    Let's test it:

    alpha = 0.55
    x = np.random.randint(0,30,15)
    df = pd.DataFrame(x, columns=['A'])
    df.ewm(alpha=alpha).mean()
    
    # returns:
    #             A
    # 0   13.000000
    # 1   22.655172
    # 2   20.443268
    # 3   12.159796
    # 4   14.871955
    # 5   15.497575
    # 6   20.743511
    # 7   20.884818
    # 8   24.250715
    # 9   18.610901
    # 10  17.174686
    # 11  16.528564
    # 12  17.337879
    # 13   7.801912
    # 14  12.310889
    
    ewma(x=x, alpha=alpha)
    
    # returns:
    # array([ 13.        ,  22.65517241,  20.44326778,  12.1597964 ,
    #        14.87195534,  15.4975749 ,  20.74351117,  20.88481763,
    #        24.25071484,  18.61090129,  17.17468551,  16.52856393,
    #        17.33787888,   7.80191235,  12.31088889])
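
If the n × n weight matrix is too large to hold in memory, the same adjusted average can be computed with a running numerator and denominator. This is a minimal sketch, not part of the original answer: `ewma_recursive` is a hypothetical name, and it assumes a 1-D input without NaNs and 0 < alpha <= 1.

```python
import numpy as np

def ewma_recursive(x, alpha):
    """Adjusted EWMA using O(n) memory (hypothetical helper, for illustration).

    Assumes a 1-D input with no NaNs and 0 < alpha <= 1.
    """
    x = np.asarray(x, dtype=float)
    out = np.empty_like(x)
    num = 0.0  # running weighted sum of observations
    den = 0.0  # running sum of weights
    for t, value in enumerate(x):
        # Older terms decay by (1 - alpha); the newest observation gets weight 1
        num = value + (1.0 - alpha) * num
        den = 1.0 + (1.0 - alpha) * den
        out[t] = num / den
    return out
```

This trades the vectorized matrix product for a Python loop, so it is likely slower on short inputs, but its memory use grows linearly with the input length.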
    
