Fast punctuation removal with pandas

予麋鹿 2020-11-22 06:21

This is a self-answered post. Below I outline a common problem in the NLP domain and propose a few performant methods to solve it.

Oftentimes the need arises to remo

3 Answers
  •  面向向阳花
    2020-11-22 06:45

    Setup

    For the purpose of demonstration, let's consider this DataFrame.

    import pandas as pd

    df = pd.DataFrame({'text':['a..b?!??', '%hgh&12','abc123!!!', '$$$1234']})
    df
            text
    0   a..b?!??
    1    %hgh&12
    2  abc123!!!
    3    $$$1234
    

    Below, I list the alternatives one by one, in increasing order of performance.

    str.replace

    This option is included to establish the default method as a benchmark for comparing other, more performant solutions.

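    This uses pandas' built-in str.replace function, which performs regex-based replacement. Note that since pandas 2.0 the pattern is treated as a literal string unless regex=True is passed.

    df['text'] = df['text'].str.replace(r'[^\w\s]+', '', regex=True)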
    

    df
         text
    0      ab
    1   hgh12
    2  abc123
    3    1234
    

    This is very easy to code, and is quite readable, but slow.


    regex.sub

    This involves using the sub function from the re library. Pre-compile a regex pattern for performance, and call regex.sub inside a list comprehension. If you can spare some memory, convert df['text'] to a list beforehand; you'll get a nice little performance boost out of this.

    import re
    p = re.compile(r'[^\w\s]+')
    df['text'] = [p.sub('', x) for x in df['text'].tolist()]
    

    df
         text
    0      ab
    1   hgh12
    2  abc123
    3    1234
    

    Note: If your data has NaN values, this (as well as the next method below) will not work as is. See the section on "Other Considerations".


    str.translate

    Python's str.translate function is implemented in C, and is therefore very fast.

    How this works is:

    1. First, join all your strings together to form one huge string using a single (or more) character separator that you choose. You must use a character/substring that you can guarantee will not belong inside your data.
    2. Perform str.translate on the large string, removing punctuation (the separator from step 1 excluded).
    3. Split the string on the separator that was used to join in step 1. The resultant list must have the same length as your initial column.

    Here, in this example, we consider the pipe separator |. If your data contains the pipe, then you must choose another separator.

    import string
    
    punct = '!"#$%&\'()*+,-./:;<=>?@[\\]^_`{}~'   # `|` is not present here
    transtab = str.maketrans(dict.fromkeys(punct, ''))
    
    df['text'] = '|'.join(df['text'].tolist()).translate(transtab).split('|')
    

    df
         text
    0      ab
    1   hgh12
    2  abc123
    3    1234
    

    Performance

    str.translate performs the best, by far. Note that the graph below includes another variant Series.str.translate from MaxU's answer.

    (Interestingly, I reran this a second time, and the results are slightly different from before. During the second run, it seems re.sub was winning out over str.translate for really small amounts of data.)

    There is an inherent risk involved with using translate (particularly, the problem of automating the process of deciding which separator to use is non-trivial), but the trade-offs are worth the risk.
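
    One way to reduce the separator risk is to check a few candidates against the data before joining. The sketch below is only an illustration: the helper names pick_separator and translate_all and the candidate list are assumptions, not part of the original answer, and it works on a plain list of strings.

```python
import string

def pick_separator(strings, candidates=('|', '\x1f', '\x00', '\n')):
    """Return the first candidate that occurs in none of the strings."""
    for sep in candidates:
        if not any(sep in s for s in strings):
            return sep
    raise ValueError('every candidate separator occurs in the data')

def translate_all(strings):
    """Join, translate once, and split back; ASCII punctuation is removed."""
    if not strings:
        return []
    sep = pick_separator(strings)
    # Keep the separator out of the translation table so it survives translate.
    punct = string.punctuation.replace(sep, '')
    transtab = str.maketrans(dict.fromkeys(punct, ''))
    cleaned = sep.join(strings).translate(transtab).split(sep)
    # One output string per input string, otherwise the separator leaked.
    assert len(cleaned) == len(strings)
    return cleaned

data = ['a..b?!??', '%hgh&12', 'a|b', '$$$1234']
print(repr(pick_separator(data)))   # '\x1f', because '|' occurs in 'a|b'
print(translate_all(data))          # ['ab', 'hgh12', 'ab', '1234']
```

    The scan over the data costs an extra pass, so it trades a little of the speed advantage for safety.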


    Other Considerations

    Handling NaNs with list comprehension methods: note that the regex.sub and str.translate methods above will only work as long as your data does not have NaNs. When handling NaNs, you will have to determine the indices of non-null values and replace those only. Try something like this:

    import numpy as np

    df = pd.DataFrame({'text': [
        'a..b?!??', np.nan, '%hgh&12','abc123!!!', '$$$1234', np.nan]})
    
    idx = np.flatnonzero(df['text'].notna())
    col_idx = df.columns.get_loc('text')
    df.iloc[idx,col_idx] = [
        p.sub('', x) for x in df.iloc[idx,col_idx].tolist()]
    
    df
         text
    0      ab
    1     NaN
    2   hgh12
    3  abc123
    4    1234
    5     NaN
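
    A simpler alternative, if you do not need index-based assignment, is to test each item inside the comprehension and pass non-strings through untouched. This is a minimal sketch on a plain list; the helper name sub_skip_missing is hypothetical.

```python
import re

p = re.compile(r'[^\w\s]+')

def sub_skip_missing(values):
    """Apply p.sub to str items; pass NaN, None and other non-strings through."""
    return [p.sub('', x) if isinstance(x, str) else x for x in values]

values = ['a..b?!??', float('nan'), '%hgh&12', None]
result = sub_skip_missing(values)
print(result)   # ['ab', nan, 'hgh12', None]
```

    With a DataFrame this would be used as df['text'] = sub_skip_missing(df['text'].tolist()). The per-item isinstance check adds some overhead, so time it on your own data.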
    

    Dealing with DataFrames: if every column of your DataFrame requires replacement, the procedure is simple:

    v = pd.Series(df.values.ravel())
    df[:] = translate(v).values.reshape(df.shape)
    

    Or,

    v = df.stack()
    v[:] = translate(v)
    df = v.unstack()
    

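    Note that translate here stands for a variant that accepts a Series. The translate function defined below, with the benchmarking code, takes a DataFrame with a text column and would need to be adapted accordingly.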
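
    For illustration, a Series-accepting variant could look like the sketch below. The name translate_series is an assumption, and the example assumes every cell is a string and that the pipe does not occur in the data.

```python
import string
import pandas as pd

def translate_series(s, sep='|'):
    """Remove ASCII punctuation from a Series of strings (sep must not be in the data)."""
    punct = string.punctuation.replace(sep, '')
    transtab = str.maketrans(dict.fromkeys(punct, ''))
    cleaned = sep.join(s.tolist()).translate(transtab).split(sep)
    return pd.Series(cleaned, index=s.index)

df = pd.DataFrame({'a': ['x.y!', '%z'], 'b': ['1,2', '(3)']})

# Flatten all cells row by row, clean them in one pass, then restore the shape.
v = pd.Series(df.to_numpy().ravel())
out = pd.DataFrame(
    translate_series(v).to_numpy().reshape(df.shape),
    index=df.index, columns=df.columns)
print(out)
```

    Building a new DataFrame from the reshaped array avoids in-place assignment, at the cost of one more copy.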

    Every solution has tradeoffs, so deciding what solution best fits your needs will depend on what you're willing to sacrifice. Two very common considerations are performance (which we've already seen), and memory usage. str.translate is a memory-hungry solution, so use with caution.

    Another consideration is the complexity of your regex. Sometimes, you may want to remove anything that is not alphanumeric or whitespace. Other times, you will need to retain certain characters, such as hyphens, colons, and sentence terminators [.!?]. Specifying these explicitly adds complexity to your regex, which may in turn impact the performance of these solutions. Make sure you test these solutions on your data before deciding what to use.
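
    As a small illustration of retaining characters, the listed ones can be added to the negated character class. The pattern name keep_some and the sample sentence are made up for this sketch.

```python
import re

strip_all = re.compile(r'[^\w\s]+')          # remove everything but word chars and whitespace
keep_some = re.compile(r'[^\w\s\-:.!?]+')    # additionally keep - : . ! ?

text = 'Wait... state-of-the-art: yes! (really?) #1'
print(strip_all.sub('', text))  # Wait stateoftheart yes really 1
print(keep_some.sub('', text))  # Wait... state-of-the-art: yes! really? 1
```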

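    Lastly, the regex-based solutions also remove non-ASCII characters that are neither word characters nor whitespace (Unicode punctuation and symbols, for example). If you need to keep those, tweak your regex, or go with str.translate, which removes only the characters listed in its translation table.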
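
    The two families of solutions do not remove exactly the same characters, which the sketch below shows on a made-up sample string. In Python 3, \w matches accented letters and the underscore, while a table built from string.punctuation covers ASCII punctuation only (underscore included).

```python
import re
import string

p = re.compile(r'[^\w\s]+')
transtab = str.maketrans(dict.fromkeys(string.punctuation, ''))

# Escapes spell out: café_bar, naïve… ¿qué?
text = 'caf\u00e9_bar, na\u00efve\u2026 \u00bfqu\u00e9?'

by_regex = p.sub('', text)
by_translate = text.translate(transtab)
print(by_regex)       # café_bar naïve qué
print(by_translate)   # cafébar naïve… ¿qué
```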

    For even more performance (for larger N), take a look at this answer by Paul Panzer.


    Appendix

    Functions

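    import re
    import string

    def pd_replace(df):
        # regex=True is required since pandas 2.0, where the default became a literal match
        return df.assign(text=df['text'].str.replace(r'[^\w\s]+', '', regex=True))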
    
    
    def re_sub(df):
        p = re.compile(r'[^\w\s]+')
        return df.assign(text=[p.sub('', x) for x in df['text'].tolist()])
    
    def translate(df):
        punct = string.punctuation.replace('|', '')
        transtab = str.maketrans(dict.fromkeys(punct, ''))
    
        return df.assign(
            text='|'.join(df['text'].tolist()).translate(transtab).split('|')
        )
    
    # MaxU's version (https://stackoverflow.com/a/50444659/4909087)
    def pd_translate(df):
        punct = string.punctuation.replace('|', '')
        transtab = str.maketrans(dict.fromkeys(punct, ''))
    
        return df.assign(text=df['text'].str.translate(transtab))
    

    Performance Benchmarking Code

    from timeit import timeit
    
    import pandas as pd
    import matplotlib.pyplot as plt
    
    res = pd.DataFrame(
           index=['pd_replace', 're_sub', 'translate', 'pd_translate'],
           columns=[10, 50, 100, 500, 1000, 5000, 10000, 50000],
           dtype=float
    )
    
    for f in res.index: 
        for c in res.columns:
            l = ['a..b?!??', '%hgh&12','abc123!!!', '$$$1234'] * c
            df = pd.DataFrame({'text' : l})
            stmt = '{}(df)'.format(f)
            setp = 'from __main__ import df, {}'.format(f)
            res.at[f, c] = timeit(stmt, setp, number=30)
    
    ax = res.div(res.min()).T.plot(loglog=True)
    ax.set_xlabel("N")
    ax.set_ylabel("time (relative)")
    
    plt.show()
    
