Fast punctuation removal with pandas

予麋鹿 2020-11-22 06:21

This is a self-answered post. Below I outline a common problem in the NLP domain and propose a few performant methods to solve it.

Oftentimes the need arises to remove punctuation from text data as part of pre-processing.

3 Answers
  •  小蘑菇 (OP)
     2020-11-22 06:40

    Using numpy we can gain a healthy speedup over the best methods posted so far. The basic strategy is similar: build one big super string. But the processing seems much faster in numpy, presumably because we fully exploit the simplicity of the nothing-for-something replacement op.

    For smaller problems (fewer than 0x110000 characters total) we automatically find a separator; for larger problems we fall back to a slower method that does not rely on str.split.

    Note that I have moved all precomputables out of the functions. Also note that translate and pd_translate get to know the only possible separator for the three largest problems for free, whereas np_multi_strat has to compute it or fall back to the separator-less strategy. Finally, note that for the last three data points I switch to a more "interesting" problem; pd_replace and re_sub had to be excluded there because they are not equivalent to the other methods.

    On the algorithm:

    The basic strategy is actually quite simple. There are only 0x110000 different unicode characters. As OP frames the challenge in terms of huge data sets, it is perfectly worthwhile making a lookup table that has True at the character ids we want to keep and False at the ones that have to go, i.e. the punctuation in our example.
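    As a minimal sketch of that lookup table (variable names here are illustrative, not the ones used in the full code below):

    ```python
    import string
    import numpy as np

    # One boolean per possible Unicode code point (0x110000 of them):
    # True = keep this character, False = drop it (punctuation).
    keep = np.ones(0x110000, dtype=bool)
    punct_codes = np.array([string.punctuation]).view(np.int32)
    keep[punct_codes] = False
    ```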

    Such a lookup table can be used for bulk lookup via numpy's advanced indexing. As the lookup is fully vectorized and essentially amounts to dereferencing an array of pointers, it is much faster than, for example, dictionary lookup. Here we make use of numpy view casting, which lets us reinterpret unicode characters as integers essentially for free.

    Indexing the lookup table with the data array, i.e. the single monster string reinterpreted as a sequence of numbers, yields a boolean mask. This mask can then be used to filter out the unwanted characters; with boolean indexing this, too, is a single line of code.
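    On a tiny input, the view cast plus boolean indexing looks like this (a sketch; the real code below operates on the joined monster string):

    ```python
    import string
    import numpy as np

    keep = np.ones(0x110000, dtype=bool)
    keep[np.array([string.punctuation]).view(np.int32)] = False

    # Reinterpret the string's UCS-4 buffer as int32 code points (free).
    codes = np.array(['ab,c!']).view(np.int32)
    mask = keep[codes]            # vectorized lookup: one bool per character
    kept = codes[mask]            # boolean indexing drops the punctuation
    out = kept.view(f'U{kept.size}').item(0)
    ```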

    So far so simple. The tricky bit is chopping the monster string back up into its parts. If we have a separator, i.e. one character that occurs neither in the data nor in the punctuation list, then it is still easy: use this character to join and resplit. Automatically finding such a separator, however, is challenging and indeed accounts for half the lines of code in the implementation below.
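    With a known-safe separator, the join/filter/resplit round trip is short. A sketch, where '\x01' is simply assumed to occur neither in the data nor in the punctuation list:

    ```python
    import string
    import numpy as np

    keep = np.ones(0x110000, dtype=bool)
    keep[np.array([string.punctuation]).view(np.int32)] = False

    parts = ['a..b', '%h&1', 'ok!']
    SEP = '\x01'                      # assumed absent from data and punctuation
    codes = np.array([SEP.join(parts)]).view(np.int32)
    kept = codes[keep[codes]]         # separator survives, punctuation does not
    out = kept.view(f'U{kept.size}').item(0).split(SEP)
    ```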

    Alternatively, we can keep the split points in a separate data structure, track how they move as a consequence of deleting unwanted characters, and then use them to slice the processed monster string. As chopping up into parts of uneven length is not numpy's strongest suit, this method is slower than str.split and is only used as a fallback when a separator would be too expensive to calculate if it existed in the first place.
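    The separator-less fallback can be sketched as follows: count the surviving characters per original segment with np.add.reduceat, then slice the filtered string at the adjusted offsets:

    ```python
    import string
    import numpy as np

    keep = np.ones(0x110000, dtype=bool)
    keep[np.array([string.punctuation]).view(np.int32)] = False

    parts = ['a..b', '%h&1', 'ok!']
    starts = np.array([0, *map(len, parts)]).cumsum()   # original split points
    codes = np.array([''.join(parts)]).view(np.int32)
    mask = keep[codes]
    # kept characters per segment -> new split points after deletion
    new = np.concatenate([[0], np.add.reduceat(mask, starts[:-1]).cumsum()])
    kept = codes[mask]
    s = kept.view(f'U{kept.size}').item(0)
    out = [s[new[i]:new[i + 1]] for i in range(len(parts))]
    ```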

    Code (timing/plotting heavily based on @COLDSPEED's post):

    import numpy as np
    import pandas as pd
    import string
    import re
    
    
    spct = np.array([string.punctuation]).view(np.int32)
    lookup = np.zeros((0x110000,), dtype=bool)
    lookup[spct] = True
    invlookup = ~lookup
    OSEP = spct[0]
    SEP = chr(OSEP)
    while SEP in string.punctuation:
        OSEP = np.random.randint(0, 0x110000)
        SEP = chr(OSEP)
    
    
    def find_sep_2(letters):
        letters = np.array([letters]).view(np.int32)
        msk = invlookup.copy()
        msk[letters] = False
        sep = msk.argmax()
        if not msk[sep]:
            return None
        return sep
    
    def find_sep(letters, sep=0x88000):
        letters = np.array([letters]).view(np.int32)
        cmp = np.sign(sep-letters)
        cmpf = np.sign(sep-spct)
        if cmp.sum() + cmpf.sum() >= 1:
            left, right, gs = sep+1, 0x110000, -1
        else:
            left, right, gs = 0, sep, 1
        idx, = np.where(cmp == gs)
        idxf, = np.where(cmpf == gs)
        sep = (left + right) // 2
        while True:
            cmp = np.sign(sep-letters[idx])
            cmpf = np.sign(sep-spct[idxf])
            if cmp.all() and cmpf.all():
                return sep
            if cmp.sum() + cmpf.sum() >= (left & 1 == right & 1):
                left, sep, gs = sep+1, (right + sep) // 2, -1
            else:
                right, sep, gs = sep, (left + sep) // 2, 1
            idx = idx[cmp == gs]
            idxf = idxf[cmpf == gs]
    
    def np_multi_strat(df):
        L = df['text'].tolist()
        all_ = ''.join(L)
        sep = 0x088000
        if chr(sep) in all_: # very unlikely ...
            if len(all_) >= 0x110000: # fall back to separator-less method
                                      # (finding separator too expensive)
                LL = np.array((0, *map(len, L)))
                LLL = LL.cumsum()
                all_ = np.array([all_]).view(np.int32)
                pnct = invlookup[all_]
                NL = np.add.reduceat(pnct, LLL[:-1])
                NLL = np.concatenate([[0], NL.cumsum()]).tolist()
                all_ = all_[pnct]
                all_ = all_.view(f'U{all_.size}').item(0)
                return df.assign(text=[all_[NLL[i]:NLL[i+1]]
                                       for i in range(len(NLL)-1)])
            elif len(all_) >= 0x22000: # use mask
                sep = find_sep_2(all_)
            else: # use bisection
                sep = find_sep(all_)
        all_ = np.array([chr(sep).join(L)]).view(np.int32)
        pnct = invlookup[all_]
        all_ = all_[pnct]
        all_ = all_.view(f'U{all_.size}').item(0)
        return df.assign(text=all_.split(chr(sep)))
    
    def pd_replace(df):
        return df.assign(text=df['text'].str.replace(r'[^\w\s]+', ''))
    
    
    p = re.compile(r'[^\w\s]+')
    
    def re_sub(df):
        return df.assign(text=[p.sub('', x) for x in df['text'].tolist()])
    
    punct = string.punctuation.replace(SEP, '')
    transtab = str.maketrans(dict.fromkeys(punct, ''))
    
    def translate(df):
        return df.assign(
            text=SEP.join(df['text'].tolist()).translate(transtab).split(SEP)
        )
    
    # MaxU's version (https://stackoverflow.com/a/50444659/4909087)
    def pd_translate(df):
        return df.assign(text=df['text'].str.translate(transtab))
    
    from timeit import timeit

    import matplotlib.pyplot as plt
    
    res = pd.DataFrame(
           index=['translate', 'pd_replace', 're_sub', 'pd_translate', 'np_multi_strat'],
           columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000, 100000, 500000,
                    1000000],
           dtype=float
    )
    
    for c in res.columns:
        if c >= 100000: # stress test the separator finder
            all_ = np.r_[:OSEP, OSEP+1:0x110000].repeat(c//10000)
            np.random.shuffle(all_)
            split = np.arange(c-1) + \
                    np.sort(np.random.randint(0, len(all_) - c + 2, (c-1,))) 
            l = [x.view(f'U{x.size}').item(0) for x in np.split(all_, split)]
        else:
            l = ['a..b?!??', '%hgh&12','abc123!!!', '$$$1234'] * c
        df = pd.DataFrame({'text' : l})
        for f in res.index: 
            if f == res.index[0]:
                ref = globals()[f](df).text
            elif not (ref == globals()[f](df).text).all():
                res.at[f, c] = np.nan
                print(f, 'disagrees at', c)
                continue
            stmt = '{}(df)'.format(f)
            setp = 'from __main__ import df, {}'.format(f)
            res.at[f, c] = timeit(stmt, setp, number=16)
    
    ax = res.div(res.min()).T.plot(loglog=True) 
    ax.set_xlabel("N"); 
    ax.set_ylabel("time (relative)");
    
    plt.show()
    
