Pandas: Use DataFrameGroupBy.filter() method to select DataFrame's rows with a value greater than the mean of the respective group

混江龙づ霸主 submitted on 2020-07-22 21:34:06

Question


I am learning Python and Pandas and I am doing some exercises to understand how things work. My question is the following: can I use the GroupBy.filter() method to select the DataFrame's rows that have a value (in a specific column) greater than the mean of the respective group?

For this exercise, I am using the "planets" dataset included in Seaborn: 1035 rows x 6 columns (column names: "method", "number", "orbital_period", "mass", "distance", "year").

In Python:

import pandas as pd
import seaborn as sns

#Load the "planets" dataset included in Seaborn
data = sns.load_dataset("planets")

#Remove rows with NaN in "orbital_period"
data = data.dropna(how = "all", subset = ["orbital_period"])

#Set display of DataFrames for seeing all the columns:
pd.set_option("display.max_columns", 15)

#Group the DataFrame "data" by "method"
group1 = data.groupby("method")
#I obtain a DataFrameGroupBy object (group1) composed of 10 groups.
print(group1)
#Print the composition of the DataFrameGroupBy object "group1".
for lab, datafrm in group1:
    print(lab, "\n", datafrm, sep="", end="\n\n")
print()
print()
print()


#Define the filter_function that will be used by the filter method.
#I want a function that returns True whenever the "orbital_period" value for
#a row is greater than the mean of the corresponding group.
#This could also have been written directly as a lambda passed to
#filter().
def filter_funct(x):
    #print(type(x))
    #print(x)
    return x["orbital_period"] > x["orbital_period"].mean()


dataFiltered = group1.filter(filter_funct)
print("RESULT OF THE FILTER METHOD:")
print()
print(dataFiltered)
print()
print()

Unfortunately, I get the following error when I run the script:

TypeError: filter function returned a Series, but expected a scalar bool
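A minimal reproduction, on a small hypothetical DataFrame standing in for the planets data, suggests what is going on: filter() passes each group to the function as a whole DataFrame, so the comparison returns a boolean Series with one value per row, rather than the single True/False that filter() needs for the group.

```python
import pandas as pd

# Hypothetical four-row stand-in for the planets data
df = pd.DataFrame({"method": ["A", "A", "B", "B"],
                   "orbital_period": [1.0, 3.0, 10.0, 30.0]})

def filter_funct(x):
    # x is a whole group (a DataFrame), so this comparison produces a
    # boolean Series with one entry per row of the group
    return x["orbital_period"] > x["orbital_period"].mean()

group_a = df[df["method"] == "A"]
print(type(filter_funct(group_a)).__name__)  # Series

message = None
try:
    df.groupby("method").filter(filter_funct)
except TypeError as exc:
    message = str(exc)
print(message)  # filter function returned a Series, but expected a scalar bool
```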

It looks as though x["orbital_period"] does not behave as a vector, meaning that the comparison is not applied to the individual values of the Series... Oddly enough, the transform() method does not have this problem. Indeed, on the same dataset (prepared as above), if I run the following:

#Define the transform_function that will be used by the transform() method.
#I want this function to subtract from each value in "orbital_period" the mean
#of the corresponding group.
def transf_funct(x):
    #print(type(x))
    #print(x)
    return x-x.mean()

print("Transform method runs:")
print()
#I directly assign the transformed values to the "orbital_period" column of the DataFrame.
data["orbital_period"] = group1["orbital_period"].transform(transf_funct)
print("RESULT OF THE TRANSFORM METHOD:")
print()
print(data)
print()
print()
print()

I get the expected result.

Do DataFrameGroupBy.filter() and DataFrameGroupBy.transform() behave differently? I know I can achieve what I want in many other ways, but my question is: is there a way to achieve it using the DataFrameGroupBy.filter() method?
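The difference can be checked directly. The sketch below uses a hypothetical four-row DataFrame: the function given to transform() may return one value per row (aligned back to the original index), and a per-group scalar such as a mean is broadcast to every row of its group.

```python
import pandas as pd

# Hypothetical four-row stand-in for the planets data
df = pd.DataFrame({"method": ["A", "A", "B", "B"],
                   "orbital_period": [1.0, 3.0, 10.0, 30.0]})
grouped = df.groupby("method")["orbital_period"]

# A per-row result is fine for transform(): it is aligned to df's index
centered = grouped.transform(lambda s: s - s.mean())

# A per-group scalar is broadcast back to every row of its group
group_mean = grouped.transform("mean")

print(centered.tolist())    # [-1.0, 1.0, -10.0, 10.0]
print(group_mean.tolist())  # [2.0, 2.0, 20.0, 20.0]
```

filter(), by contrast, expects exactly one bool per group, which is why the per-row comparison in filter_funct is rejected.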


Answer 1:


Can I use DataFrameGroupBy.filter to exclude specific rows within a group?

No. DataFrameGroupBy.filter uses a single Boolean value to characterize an entire group: if that value is False, the whole group is removed from the result.
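To see that filter() acts on whole groups, the sketch below (again on a small hypothetical DataFrame) also uses filter's dropna=False option, which keeps the original shape and fills the rows of rejected groups with NaN instead of dropping them.

```python
import pandas as pd

# Hypothetical four-row stand-in for the planets data
df = pd.DataFrame({"method": ["A", "A", "B", "B"],
                   "orbital_period": [1.0, 3.0, 10.0, 30.0]})

def group_mean_above_5(g):
    # One scalar bool for the whole group
    return g["orbital_period"].mean() > 5

# Default: all rows of a rejected group are dropped
dropped = df.groupby("method").filter(group_mean_above_5)

# dropna=False: same shape as df; rows of rejected groups become NaN
masked = df.groupby("method").filter(group_mean_above_5, dropna=False)

print(dropped["method"].tolist())                # ['B', 'B']
print(masked["orbital_period"].isna().tolist())  # [True, True, False, False]
```

Either way, the decision is made once per group; there is no way to keep only some rows of a group.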

DataFrameGroupBy.filter is slow because it calls a Python function once per group, so it is often advised to use transform to broadcast the single truth value to all rows within a group and then to subset the DataFrame[1]. Here is an example of removing entire groups whose mean is <= 50. The filter method is roughly 100x slower.

import pandas as pd
import numpy as np

N = 10000
df = pd.DataFrame({'grp': np.arange(0,N,1)//10,
                   'value': np.arange(0,N,1)%100})

# With Filter
%timeit df.groupby('grp').filter(lambda x: x['value'].mean() > 50)
#327 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# With Transform
%timeit df[df.groupby('grp')['value'].transform('mean') > 50]
#2.7 ms ± 39.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# Verify they are equivalent
(df.groupby('grp').filter(lambda x: x['value'].mean() > 50) 
  == df[df.groupby('grp')['value'].transform('mean') > 50]).all().all()
#True
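%timeit is an IPython/Jupyter magic, so the snippet above does not run as a plain Python script. A sketch of the same comparison with the standard-library timeit module follows; absolute timings depend on the machine and the pandas version, so none are claimed here.

```python
import timeit

import numpy as np
import pandas as pd

N = 10000
df = pd.DataFrame({"grp": np.arange(N) // 10,
                   "value": np.arange(N) % 100})

def with_filter():
    # Calls the Python lambda once per group (1000 groups here)
    return df.groupby("grp").filter(lambda x: x["value"].mean() > 50)

def with_transform():
    # Computes all group means at once, then applies a row-level mask
    return df[df.groupby("grp")["value"].transform("mean") > 50]

t_filter = timeit.timeit(with_filter, number=3) / 3
t_transform = timeit.timeit(with_transform, number=3) / 3
print(f"filter:    {t_filter * 1000:.1f} ms per call")
print(f"transform: {t_transform * 1000:.1f} ms per call")

# Both approaches select the same rows
assert with_filter().equals(with_transform())
```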

[1] The performance gain comes from the fact that transform can use a GroupBy operation implemented in Cython, as is the case for mean. If that is not the case, filter may be just as fast, if not slightly faster.
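In practice, the optimized path is what you get by passing the name of a built-in aggregation such as 'mean' to transform; a Python callable computing the same thing is evaluated group by group in Python. A sketch on the benchmark data above (rebuilt so it runs on its own), with no timings claimed:

```python
import numpy as np
import pandas as pd

N = 10000
df = pd.DataFrame({"grp": np.arange(N) // 10,
                   "value": np.arange(N) % 100})
grouped = df.groupby("grp")["value"]

# String alias: pandas can dispatch to its built-in groupby mean
fast = grouped.transform("mean")

# Python callable: evaluated separately for each group
slow = grouped.transform(lambda s: s.mean())

# Same values either way; only the execution path differs
print(np.allclose(fast, slow))  # True
```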


Finally, because DataFrameGroupBy.transform broadcasts a result to the entire group, it is the right tool when you need to exclude specific rows within a group based on an overall characteristic of that group.

In the example above, to keep only the rows whose value is above the mean of the group they belong to, use:

df[df['value'] > df.groupby('grp')['value'].transform('mean')]
# Compare each row's value to the mean of the group the row belongs to
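Applied to the question's setup, the same idiom selects the rows whose orbital_period is above the mean of their method group, which is what filter_funct was meant to do. The small DataFrame below is hypothetical and only mimics the two relevant planets columns; the real data can be loaded with sns.load_dataset("planets") as in the question.

```python
import pandas as pd

# Hypothetical stand-in with the two planets columns used in the question
data = pd.DataFrame({
    "method": ["Transit", "Radial Velocity", "Transit",
               "Radial Velocity", "Transit", "Imaging"],
    "orbital_period": [3.0, 400.0, 5.0, 100.0, 10.0, 7000.0],
})

# Group mean broadcast to every row, then a row-level comparison
group_mean = data.groupby("method")["orbital_period"].transform("mean")
above_mean = data[data["orbital_period"] > group_mean]

print(above_mean.index.tolist())  # [1, 4]
```

With a strict >, a group containing a single row (like "Imaging" here) never contributes a row, since its only value equals its own mean.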


Source: https://stackoverflow.com/questions/58731702/pandas-use-dataframegroupby-filter-method-to-select-dataframes-rows-with-a-v
