Question
For clarity, I have extracted an excerpt from my code and used generic names. I have a class Foo that stores a DataFrame in an attribute.
import pandas as pd
import pandas.testing as pdt  # pandas.util.testing in older pandas versions

class Foo():
    def __init__(self, bar):
        self.bar = bar                 # dict of dicts
        self.df = pd.DataFrame(bar)    # pandas object

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self.__dict__ == other.__dict__
        return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
However, when I try to compare two instances of Foo, I get an exception related to the ambiguity of comparing two DataFrames (the comparison works fine without the 'df' key in Foo.__dict__).
d1 = {'A': pd.Series([1, 2], index=['a', 'b']),
      'B': pd.Series([1, 2], index=['a', 'b'])}
d2 = d1.copy()
foo1 = Foo(d1)
foo2 = Foo(d2)
foo1.bar # dict
foo1.df # pandas DataFrame
foo1 == foo2 # ValueError
[Out] ValueError: The truth value of a DataFrame is ambiguous. Use a.empty, a.bool(), a.item(), a.any() or a.all().
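The error arises because dictionary equality evaluates value1 == value2 for each key and then needs a single truth value from the result. For DataFrames, == is element-wise and returns another DataFrame, whose truth value pandas declines to define. A minimal sketch reproducing this outside of Foo (the names left, right, elementwise and raised are illustrative only):

```python
import pandas as pd

left = {'df': pd.DataFrame({'A': [1, 2]})}
right = {'df': pd.DataFrame({'A': [1, 2]})}

# Element-wise comparison yields a DataFrame of booleans, not a single bool
elementwise = left['df'] == right['df']
print(type(elementwise).__name__)

# Dict equality asks for the truth value of that result, which raises
try:
    left == right
    raised = False
except ValueError:
    raised = True
print(raised)
```

Note that dictionary comparison skips the == call when both values are the very same object, which is why a shallow copy of a dict holding pandas objects can still compare equal.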
Fortunately, pandas has utility functions for asserting whether two DataFrames or Series are equal. I'd like to reuse their comparison logic if possible.
pdt.assert_frame_equal(pd.DataFrame(d1), pd.DataFrame(d2)) # does not raise
There are a few options to resolve the comparison of two Foo instances:
- compare a copy of __dict__, where new_dict lacks the df key
- delete the df key from __dict__ (not ideal)
- don't compare __dict__, but only parts of it contained in a tuple
- overload __eq__ to facilitate pandas DataFrame comparisons
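As a rough illustration of the first option, the sketch below compares copies of __dict__ with the df key filtered out. It assumes bar is a plain dict of dicts (as the comment in the class suggests); the helper name _comparable is hypothetical and not part of the original code.

```python
import copy
import pandas as pd

class Foo:
    def __init__(self, bar):
        self.bar = bar                 # assumed: plain dict of dicts
        self.df = pd.DataFrame(bar)

    def _comparable(self):
        # Hypothetical helper: a copy of __dict__ without the DataFrame
        return {k: v for k, v in self.__dict__.items() if k != 'df'}

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return self._comparable() == other._comparable()
        return NotImplemented

d1 = {'A': {'a': 1, 'b': 2}, 'B': {'a': 1, 'b': 2}}
d2 = copy.deepcopy(d1)
d3 = {'A': {'a': 1, 'b': 2}, 'B': {'a': 9, 'b': 0}}

print(Foo(d1) == Foo(d2))
print(Foo(d1) == Foo(d3))
```

This sidesteps the error but ignores the DataFrame entirely, so two instances whose df attributes had diverged would still compare equal.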
The last option seems the most robust in the long run, but I am not sure of the best approach. In the end, I would like to refactor __eq__ to compare all items from Foo.__dict__, including DataFrames (and Series). Any ideas on how to accomplish this?
Answer 1:
Solution adapted from these threads:
- Comparing two pandas dataframes for differences
- Pandas DataFrames with NaNs equality comparison
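from pandas.testing import assert_frame_equal

def df_equal(self):
    try:
        # csvdata and csvdata_old are the two objects being compared
        assert_frame_equal(csvdata, csvdata_old)
        return True
    except AssertionError:
        return False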
For a dictionary of dataframes:
def df_equal(df1, df2):
    try:
        assert_frame_equal(df1, df2)
        return True
    except AssertionError:
        return False

def __eq__(self, other):
    # Assumes self.df and other.df are both dicts of DataFrames
    if self.df.keys() != other.df.keys():
        return False
    for k in self.df.keys():
        if not df_equal(self.df[k], other.df[k]):
            return False
    return True
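A small usage sketch of the dictionary variant, written as a free function so it can be run on its own. The function name frames_dict_equal and the sample data are made up for illustration; the import path pandas.testing assumes a reasonably recent pandas.

```python
import pandas as pd
from pandas.testing import assert_frame_equal

def df_equal(df1, df2):
    """Turn the assertion into a boolean result."""
    try:
        assert_frame_equal(df1, df2)
        return True
    except AssertionError:
        return False

def frames_dict_equal(frames1, frames2):
    """Hypothetical helper: compare two dicts of DataFrames key by key."""
    if frames1.keys() != frames2.keys():
        return False
    return all(df_equal(frames1[k], frames2[k]) for k in frames1)

a = {'x': pd.DataFrame({'A': [1.0, float('nan')]})}
b = {'x': pd.DataFrame({'A': [1.0, float('nan')]})}
c = {'x': pd.DataFrame({'A': [1.0, 2.0]})}

print(frames_dict_equal(a, b))   # NaNs in matching positions count as equal
print(frames_dict_equal(a, c))
```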
Answer 2:
The following code seems to satisfy my original question completely. It handles both pandas DataFrames and Series. Simplifications are welcome.
The trick here is that __eq__ has been implemented to compare __dict__ and pandas objects separately. The truthiness of each is finally compared and the result returned. One interesting detail is exploited here: the and operator returns its second operand if the first one is truthy, and otherwise returns the first operand without evaluating the second.
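The behaviour of and can be checked in isolation. In this short sketch the function explode is a made-up stand-in for a comparison that would raise, such as comparing dicts that contain differing pandas objects:

```python
def explode():
    # Stand-in for a comparison that would raise if it were evaluated
    raise ValueError("should never be evaluated")

# First operand truthy: the second operand is returned as-is
first_truthy = True and 'second'

# First operand falsy: it is returned and the second is never evaluated
first_falsy = False and explode()

print(first_truthy)
print(first_falsy)
```

This short-circuiting is why the pandas check is placed on the left of and in __eq__: when the pandas objects differ, the plain dictionary comparison on the right is skipped.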
The idea for using error handling and an external comparison function was inspired by an answer submitted by @ate50eggs. Many thanks.
import pandas as pd
import pandas.testing as pdt  # pandas.util.testing in older pandas versions


def ndframe_equal(ndf1, ndf2):
    """Return True if two DataFrames (or two Series) are equal."""
    try:
        if isinstance(ndf1, pd.DataFrame) and isinstance(ndf2, pd.DataFrame):
            pdt.assert_frame_equal(ndf1, ndf2)
        elif isinstance(ndf1, pd.Series) and isinstance(ndf2, pd.Series):
            pdt.assert_series_equal(ndf1, ndf2)
        else:
            # Mismatched types (e.g. DataFrame vs. Series) are never equal
            return False
        return True
    except (ValueError, AssertionError, AttributeError):
        return False


class Foo(object):
    def __init__(self, bar):
        self.bar = bar
        try:
            self.ndf = pd.DataFrame(bar)
        except ValueError:
            self.ndf = pd.Series(bar)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            # Collect the names of attributes that hold DataFrames/Series
            blacklisted = [attr for attr in self.__dict__ if
                           isinstance(getattr(self, attr),
                                      (pd.DataFrame, pd.Series))]
            # Check DataFrames and Series; every pair must match
            ndf_eq = all(ndframe_equal(getattr(self, attr),
                                       getattr(other, attr, None))
                         for attr in blacklisted)
            # Ignore pandas objects; check rest of __dict__ and build new dicts
            self._dict = {
                key: value
                for key, value in self.__dict__.items()
                if key not in blacklisted}
            other._dict = {
                key: value
                for key, value in other.__dict__.items()
                if key not in blacklisted}
            return ndf_eq and self._dict == other._dict    # order is important
        return NotImplemented

    def __ne__(self, other):
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
Testing the latter code on DataFrames:
# Data for DataFrames
d1 = {'A': pd.Series([1, 2], index=['a', 'b']),
      'B': pd.Series([1, 2], index=['a', 'b'])}
d2 = d1.copy()
d3 = {'A': pd.Series([1, 2], index=['abc', 'b']),
      'B': pd.Series([9, 0], index=['abc', 'b'])}
# Test DataFrames
foo1 = Foo(d1)
foo2 = Foo(d2)
foo1.bar # dict of Series
foo1.ndf # pandas DataFrame
foo1 == foo2 # triggers _dict
#foo1.__dict__['_dict']
#foo1._dict
foo1 == foo2 # True
foo1 != foo2 # False
not foo1 == foo2 # False
not foo1 != foo2 # True
foo2 = Foo(d3)
foo1 == foo2 # False
foo1 != foo2 # True
not foo1 == foo2 # True
not foo1 != foo2 # False
Finally, testing on another common pandas object, the Series:
# Data for Series
s1 = {'a' : 0., 'b' : 1., 'c' : 2.}
s2 = s1.copy()
s3 = {'a' : 0., 'b' : 4, 'c' : 5}
# Test Series
foo3 = Foo(s1)
foo4 = Foo(s2)
foo3.bar # dict
foo4.ndf # pandas Series
foo3 == foo4 # True
foo3 != foo4 # False
not foo3 == foo4 # False
not foo3 != foo4 # True
foo4 = Foo(s3)
foo3 == foo4 # False
foo3 != foo4 # True
not foo3 == foo4 # True
not foo3 != foo4 # False
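One case the tests above do not cover is a bar dict whose Series are equal in content but are distinct objects; the plain dictionary comparison would then hit the same ambiguity error. A possible variation, sketched here under the assumption that the built-in equals method of DataFrame and Series is an acceptable substitute for the assert_* functions (it is stricter in some respects, for example about dtypes), compares values recursively. The helper name values_equal is hypothetical.

```python
import pandas as pd

def values_equal(a, b):
    """Hypothetical helper: equality that tolerates pandas objects."""
    if isinstance(a, (pd.DataFrame, pd.Series)):
        # .equals returns a single bool; also require matching types
        return type(a) is type(b) and a.equals(b)
    if isinstance(b, (pd.DataFrame, pd.Series)):
        return False
    if isinstance(a, dict) and isinstance(b, dict):
        # Recurse so pandas objects nested in dicts are handled too
        return a.keys() == b.keys() and all(
            values_equal(a[k], b[k]) for k in a)
    return a == b

class Foo:
    def __init__(self, bar):
        self.bar = bar
        self.df = pd.DataFrame(bar)

    def __eq__(self, other):
        if isinstance(other, self.__class__):
            return values_equal(self.__dict__, other.__dict__)
        return NotImplemented

d1 = {'A': pd.Series([1, 2], index=['a', 'b']),
      'B': pd.Series([1, 2], index=['a', 'b'])}
d2 = {k: v.copy() for k, v in d1.items()}   # equal content, distinct objects
d3 = {'A': pd.Series([1, 2], index=['abc', 'b']),
      'B': pd.Series([9, 0], index=['abc', 'b'])}

print(Foo(d1) == Foo(d2))
print(Foo(d1) != Foo(d3))
```

In Python 3, __ne__ defaults to the inverse of __eq__, so it is omitted here. Pandas objects nested inside lists or tuples are not handled by this sketch and would still need their own branch.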
Source: https://stackoverflow.com/questions/32770797/how-do-i-overload-eq-to-compare-pandas-dataframes-and-series