pandas: Keep only top n values and set others to 0

﹥>﹥吖頭↗ submitted on 2021-02-18 11:39:27

Question


In a pandas DataFrame, for every row, I want to keep only the top N values and set everything else to 0. I can iterate through the rows and do it, but I am sure Python/pandas can do it elegantly in a single line.

For example, with N = 2:

Input:
A   B   C   D
4   10  10  6
5   20  50  90
6   30  6   4
7   40  12  9

Output:
A   B   C   D
0   10  10  0
0   0   50  90
6   30  6   0
0   40  12  0
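
For reference, the row-by-row approach mentioned in the question could look roughly like the following. This is only a minimal sketch: the helper name keep_top_n_loop is hypothetical, and it assumes that values tied with the N-th largest should all be kept, as in the expected output above.

```python
import pandas as pd

# Sample input from the question.
df = pd.DataFrame({
    "A": [4, 5, 6, 7],
    "B": [10, 20, 30, 40],
    "C": [10, 50, 6, 12],
    "D": [6, 90, 4, 9],
})

def keep_top_n_loop(frame, n):
    """Hypothetical baseline: zero out everything below each row's n-th largest value."""
    out = frame.copy()
    for idx, row in frame.iterrows():
        # The n-th largest value in the row is the cut-off; ties with it are kept.
        cutoff = row.sort_values(ascending=False).iloc[n - 1]
        out.loc[idx] = row.where(row >= cutoff, 0)
    return out

result = keep_top_n_loop(df, 2)
print(result)
```

The vectorised answers below avoid the explicit Python loop.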

Answer 1:


Use rank with axis=1, method='min' and ascending=False:

N = 2
df = df.mask(df.rank(axis=1, method='min', ascending=False) > N, 0)

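Or use np.where with pd.DataFrame, which is faster than the mask method (pass index=df.index as well so the original row labels are kept):

import numpy as np

df = pd.DataFrame(np.where(df.rank(axis=1, method='min', ascending=False) > N, 0, df),
                  index=df.index, columns=df.columns)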

print(df)
   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0

Explanation:

Step 1: First we need to find the N (here 2) largest numbers in each row, taking duplicates into account. Passing axis=1 ranks the values within each row, ascending=False gives the largest value rank 1, and method='min' gives tied values the lowest rank of their group, so values tied for a top-N place are all kept:

print(df.rank(axis=1, method='min', ascending=False))
     A    B    C    D
0  4.0  1.0  1.0  3.0
1  4.0  3.0  2.0  1.0
2  2.0  1.0  2.0  4.0
3  4.0  1.0  2.0  3.0
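
To see why method='min' matters for duplicates, the sketch below applies a few rank methods to the third row of the example, where two 6s tie for second place. It only illustrates the tie-handling options of DataFrame.rank; the variable names are made up for this example.

```python
import pandas as pd

# Third row of the example, which contains a tie (two 6s) for second place.
row = pd.DataFrame({"A": [6], "B": [30], "C": [6], "D": [4]})
N = 2

kept = {}
for method in ("min", "average", "first"):
    ranks = row.rank(axis=1, method=method, ascending=False)
    kept[method] = row.mask(ranks > N, 0).iloc[0].tolist()

print(kept["min"])      # both tied 6s share rank 2, so both survive
print(kept["average"])  # the tied 6s share rank 2.5, so both are zeroed
print(kept["first"])    # ties get distinct ranks, so exactly N values survive
```

If exactly N values per row are wanted regardless of ties, method='first' is the option to look at; the expected output in the question keeps the ties, which is what method='min' gives.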

Step 2: Next we flag the cells whose rank is greater than N and replace those values with 0 using mask:

print(df.rank(axis=1, method='min', ascending=False) > N)
       A      B      C      D
0   True  False  False   True
1   True   True  False  False
2  False  False  False   True
3   True  False  False   True

print(df.mask(df.rank(axis=1, method='min', ascending=False) > N, 0))
   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0
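
The two steps can be wrapped in a small reusable function. This is a sketch only; the name keep_top_n and the fill parameter are not from the answer and are introduced here for illustration.

```python
import pandas as pd

def keep_top_n(frame, n, fill=0):
    """Hypothetical helper: keep each row's n largest values (ties included)."""
    # Step 1: rank within each row, largest value gets rank 1.
    ranks = frame.rank(axis=1, method="min", ascending=False)
    # Step 2: replace everything ranked below the top n with `fill`.
    return frame.mask(ranks > n, fill)

# Same data as the question, with a non-default index to show it is preserved.
df = pd.DataFrame(
    {"A": [4, 5, 6, 7], "B": [10, 20, 30, 40], "C": [10, 50, 6, 12], "D": [6, 90, 4, 9]},
    index=list("wxyz"),
)

top1 = keep_top_n(df, 1)
top2 = keep_top_n(df, 2)
print(top2)
```

Because mask returns a new DataFrame aligned on the original labels, the input is left untouched and the row index carries over.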



Answer 2:


Use:

N = 2
df = df.where(df.apply(lambda x: x.isin(x.nlargest(N)), axis=1), 0)
print (df)
   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0

Or:

import heapq
N = 2
df = df.where(df.apply(lambda x: x.isin(heapq.nlargest(N, x)), axis=1), 0)
print (df)
   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0
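
It may help to look at what the lambda does for a single row. In the sketch below (variable names are illustrative), nlargest returns only N entries, but isin then matches by value, which is why a duplicate of the N-th largest value is kept as well.

```python
import pandas as pd

# Third row of the example: two 6s tie for second place.
row = pd.Series({"A": 6, "B": 30, "C": 6, "D": 4})
N = 2

top = row.nlargest(N)        # only N entries are returned
keep = row.isin(top)         # membership is tested by value, not by label
result = row.where(keep, 0)  # everything not matched becomes 0

print(top.tolist())
print(keep.tolist())
print(result.tolist())
```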



Answer 3:


Use nlargest to get the N largest numbers in each row:

df.mask(~df.apply(lambda x: x.isin(x.nlargest(2)), axis=1), 0)

Output:

    A   B   C   D
0   0   10  10  0
1   0   0   50  90
2   6   30  6   0
3   0   40  12  0



Answer 4:


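You can use scipy.stats.rankdata via np.apply_along_axis and feed the result to pd.DataFrame.where. rankdata ranks in ascending order, so the N largest values in a row are those whose rank exceeds the number of columns minus N (4 - 2 = 2 for this example):

import numpy as np
from scipy.stats import rankdata

N = 2
df[:] = df.where(np.apply_along_axis(rankdata, 1, df, method='max') > df.shape[1] - N, 0)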

print(df)

   A   B   C   D
0  0  10  10   0
1  0   0  50  90
2  6  30   6   0
3  0  40  12   0

Performance benchmarking

pd.DataFrame.rank is the most efficient of the solutions below; apply + lambda performs worst.

from scipy.stats import rankdata
from heapq import nlargest

df = pd.concat([df]*100, ignore_index=True)

%timeit df.mask(df.rank(axis=1, method='min', ascending=False) > 2, 0)       # 2.23 ms per loop
%timeit df.where(np.apply_along_axis(rankdata, 1, df, method='max') > 2, 0)  # 45 ms per loop
%timeit df.where(df.apply(lambda x: x.isin(nlargest(2, x)), axis=1), 0)      # 92.4 ms per loop
%timeit df.mask(~df.apply(lambda x: x.isin(x.nlargest(2)), axis=1), 0)       # 274 ms per loop
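
%timeit is an IPython magic, so the lines above only run in IPython or Jupyter. A rough equivalent as a plain script, using the standard library timeit module, might look like the sketch below. The function names are made up here, the scipy variant is left out to keep the script dependent on pandas and NumPy only, and no timings are quoted because they depend on the machine and library versions.

```python
import timeit

import numpy as np
import pandas as pd

base = pd.DataFrame(
    {"A": [4, 5, 6, 7], "B": [10, 20, 30, 40], "C": [10, 50, 6, 12], "D": [6, 90, 4, 9]}
)
big = pd.concat([base] * 100, ignore_index=True)  # 400 rows, as in the benchmark
N = 2

def by_rank_mask():
    return big.mask(big.rank(axis=1, method="min", ascending=False) > N, 0)

def by_rank_np_where():
    ranks = big.rank(axis=1, method="min", ascending=False)
    return pd.DataFrame(np.where(ranks > N, 0, big), index=big.index, columns=big.columns)

def by_apply_nlargest():
    return big.where(big.apply(lambda x: x.isin(x.nlargest(N)), axis=1), 0)

candidates = {
    "rank + mask": by_rank_mask,
    "rank + np.where": by_rank_np_where,
    "apply + nlargest": by_apply_nlargest,
}

# Best of 3 repeats, 5 calls each, reported per call.
timings = {
    name: min(timeit.repeat(fn, number=5, repeat=3)) / 5
    for name, fn in candidates.items()
}
for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e3:.2f} ms per call")
```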


Source: https://stackoverflow.com/questions/53169397/pandas-keep-only-top-n-values-and-set-others-to-0
