Question
I am learning Python and Pandas and I am doing some exercises to understand how things work. My question is the following: can I use the GroupBy.filter() method to select the DataFrame's rows that have a value (in a specific column) greater than the mean of the respective group?
For this exercise, I am using the "planets" dataset included in Seaborn: 1035 rows x 6 columns (column names: "method", "number", "orbital_period", "mass", "distance", "year").
In Python:
import pandas as pd
import seaborn as sns
#Load the "planets" dataset included in Seaborn
data = sns.load_dataset("planets")
#Remove rows with NaN in "orbital_period"
data = data.dropna(how="all", subset=["orbital_period"])
#Set display of DataFrames for seeing all the columns:
pd.set_option("display.max_columns", 15)
#Group the DataFrame "data" by "method"
group1 = data.groupby("method")
#I obtain a DataFrameGroupBy object (group1) composed of 10 groups.
print(group1)
#Print the composition of the DataFrameGroupBy object "group1".
for lab, datafrm in group1:
    print(lab, "\n", datafrm, sep="", end="\n\n")
print()
print()
print()
#Define the filter function that will be used by the filter() method.
#I want a function that returns True whenever the "orbital_period" value for
#a row is greater than the mean of the corresponding group.
#This could also have been done directly with "lambda syntax" as the argument
#of filter().
def filter_funct(x):
    #print(type(x))
    #print(x)
    return x["orbital_period"] > x["orbital_period"].mean()
dataFiltered = group1.filter(filter_funct)
print("RESULT OF THE FILTER METHOD:")
print()
print(dataFiltered)
print()
print()
Unfortunately, I get the following error when I run the script:
TypeError: filter function returned a Series, but expected a scalar bool
It looks like x["orbital_period"] does not behave as a vector, i.e. the condition is not evaluated on the single values of the Series... Oddly enough, the transform() method does not suffer from this problem. Indeed, on the same dataset (prepared as above), if I run the following:
#Define the transform_function that will be used by the transform() method.
#I want this function to subtract from each value in "orbital_period" the mean
#of the corresponding group.
def transf_funct(x):
    #print(type(x))
    #print(x)
    return x - x.mean()
print("Transform method runs:")
print()
#I directly assign the transformed values to the "orbital_period" column of the DataFrame.
data["orbital_period"] = group1["orbital_period"].transform(transf_funct)
print("RESULT OF THE TRANSFORM METHOD:")
print()
print(data)
print()
print()
print()
I obtain the expected result...
Do DataFrameGroupBy.filter() and DataFrameGroupBy.transform() behave differently? I know I can achieve what I want in many other ways, but my question is: is there a way to achieve it using the DataFrameGroupBy.filter() method?
Answer 1:
Can I use DataFrameGroupBy.filter to exclude specific rows within a group?
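The answer is no. DataFrameGroupBy.filter uses a single Boolean value to characterize an entire group: the function is called once per group and must return a scalar bool (hence the TypeError when it returns a row-wise Series). The result of the filtering is to remove a group in its entirety if it is characterized as False.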
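A minimal sketch of this behavior on a small, made-up DataFrame (the data are hypothetical; the column names grp and value mirror the benchmark below). A group-level condition keeps or drops whole groups, while a function returning one Boolean per row is rejected with the same TypeError as in the question.

```python
import pandas as pd

# Hypothetical toy data: two groups of three rows each.
df = pd.DataFrame({
    "grp":   ["a", "a", "a", "b", "b", "b"],
    "value": [1, 2, 3, 10, 20, 30],
})

# The function receives each group as a DataFrame and returns ONE bool.
# Group "a" has mean 2 and group "b" has mean 20, so only group "b" is kept.
kept = df.groupby("grp").filter(lambda g: g["value"].mean() > 5)
print(kept)

# Returning a row-wise Boolean Series (one value per row) is not allowed.
raised = False
try:
    df.groupby("grp").filter(lambda g: g["value"] > g["value"].mean())
except TypeError as exc:
    raised = True
    print("TypeError:", exc)
```

filter can only answer "keep this group or not?", so selecting individual rows inside a group needs a different tool.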
DataFrameGroupBy.filter can be very slow, because it calls a Python function once per group, so it is often advised to use transform to broadcast the single truth value to all rows within a group and then to subset the DataFrame [1]. Here is an example of removing entire groups whose mean is <= 50; in this benchmark the filter method is roughly 100x slower.
import pandas as pd
import numpy as np
N = 10000
df = pd.DataFrame({'grp': np.arange(0, N, 1) // 10,
                   'value': np.arange(0, N, 1) % 100})
# Note: %timeit is an IPython/Jupyter magic command, not plain Python.
# With Filter
%timeit df.groupby('grp').filter(lambda x: x['value'].mean() > 50)
#327 ms ± 2.42 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
# With Transform
%timeit df[df.groupby('grp')['value'].transform('mean') > 50]
#2.7 ms ± 39.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
# Verify they are equivalent
(df.groupby('grp').filter(lambda x: x['value'].mean() > 50)
 == df[df.groupby('grp')['value'].transform('mean') > 50]).all().all()
#True
[1] The performance gain comes from the fact that transform may allow you to use a GroupBy operation implemented in Cython, which is the case for mean. If no such operation is available (for example, when transform is given a custom Python function), filter may be just as fast, if not slightly faster.
Finally, because DataFrameGroupBy.transform broadcasts a result to the entire group, it is the correct tool to use when you need to exclude specific rows within a group based on an overall group characteristic.
In the above example, if you want to keep the rows within each group that are above the group mean, use:
df[df['value'] > df.groupby('grp')['value'].transform('mean')]
# Compare each row to the mean of the group
# it belongs to
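Applied to the question's own column names, the same idea looks like the sketch below. The small DataFrame is a hypothetical stand-in for the seaborn "planets" data with invented values; with the real dataset you would load and clean it as in the question and keep the last two steps unchanged.

```python
import pandas as pd

# Hypothetical stand-in for the "planets" data (values are made up);
# only the two columns used in the question are included.
data = pd.DataFrame({
    "method": ["Radial Velocity"] * 3 + ["Transit"] * 3,
    "orbital_period": [100.0, 200.0, 600.0, 2.0, 4.0, 9.0],
})
group1 = data.groupby("method")

# One value per row: the mean orbital_period of that row's group.
group_mean = group1["orbital_period"].transform("mean")

# Keep only the rows whose orbital_period exceeds their own group's mean.
dataFiltered = data[data["orbital_period"] > group_mean]
print(dataFiltered)
```

Here the group means are 300 and 5, so only the 600.0 and 9.0 rows survive.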
Source: https://stackoverflow.com/questions/58731702/pandas-use-dataframegroupby-filter-method-to-select-dataframes-rows-with-a-v