Compare (assert equality of) two complex data structures containing numpy arrays in unittest

醉酒成梦 2020-12-06 16:25

I use Python's unittest module and want to check if two complex data structures are equal. The objects can be lists of dicts with all sorts of values: numbers,

7 answers
  • 2020-12-06 16:44

    I've run into the same issue and developed a function that compares equality by computing a fixed hash of the object. This has the added advantage that you can test that an object is as expected by comparing its hash against a fixed hash stored in the code.

    The code (a stand-alone Python file) is here. There are two functions: fixed_hash_eq, which solves your problem, and compute_fixed_hash, which makes a hash from the structure. Tests are here.

    Here's a test:

    import numpy as np
    # fixed_hash_eq comes from the stand-alone file mentioned above
    obj1 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
    obj2 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
    obj3 = [1, 'd', {'a': 4, 'b': np.arange(10)}, (7, [1, 2, 3, 4, 5])]
    obj3[2]['b'][4] = 0  # change one array element
    assert fixed_hash_eq(obj1, obj2)
    assert not fixed_hash_eq(obj1, obj3)
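    A minimal sketch of the fixed-hash idea, assuming the structure contains only lists, tuples, dicts, plain scalars, strings and numeric (non-object) numpy arrays. The names structural_hash and hash_eq are hypothetical and are not the fixed_hash_eq / compute_fixed_hash functions from the answer's file; the sketch simply feeds a canonical, length-prefixed byte encoding of the structure into hashlib.sha256.

```python
import hashlib

import numpy as np


def _feed(h, data):
    # Length-prefix every chunk so that different structures cannot
    # produce the same concatenated byte stream.
    h.update(len(data).to_bytes(8, "little"))
    h.update(data)


def structural_hash(obj, h=None):
    """Hypothetical sketch: hash a nested structure of lists, tuples,
    dicts, scalars, strings and numeric numpy arrays."""
    if h is None:
        h = hashlib.sha256()
    if isinstance(obj, np.ndarray):
        _feed(h, b"ndarray")
        _feed(h, obj.dtype.str.encode())
        _feed(h, repr(obj.shape).encode())
        # Normalise to C order so memory layout does not change the hash.
        _feed(h, np.ascontiguousarray(obj).tobytes())
    elif isinstance(obj, dict):
        _feed(h, b"dict")
        _feed(h, str(len(obj)).encode())
        # Sort keys by repr so insertion order does not matter.
        for key in sorted(obj, key=repr):
            structural_hash(key, h)
            structural_hash(obj[key], h)
    elif isinstance(obj, (list, tuple)):
        _feed(h, type(obj).__name__.encode())
        _feed(h, str(len(obj)).encode())
        for item in obj:
            structural_hash(item, h)
    else:
        # Scalars and strings: type name plus repr.
        _feed(h, type(obj).__name__.encode())
        _feed(h, repr(obj).encode())
    return h


def hash_eq(a, b):
    # Two structures are "equal" if their digests match.
    return structural_hash(a).hexdigest() == structural_hash(b).hexdigest()
```

    For a given Python/NumPy version the digest is deterministic, so structural_hash(obj).hexdigest() could also be stored in a test as the expected value. Like any hash comparison it is exact: it cannot tolerate floating-point noise.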
    
  • 2020-12-06 16:47

    Would have commented, but it gets too long...

    Fun fact: you cannot use == to test whether arrays are the same, because it compares element-wise and returns an array of booleans instead of a single True or False. I would suggest using np.testing.assert_array_equal instead, because:

    1. It checks shape (and, with strict=True in NumPy 1.24 and later, dtype as well).
    2. It doesn't fail because of the neat little math of (float('nan') == float('nan')) == False: NaNs in the same positions count as equal. (Normal Python sequence == has an even more fun way of ignoring this sometimes, because it uses PyObject_RichCompareBool, which first does a quick identity (is) check; for NaNs that is incorrect, although for testing it is of course perfect.)
    3. There is also assert_allclose, because floating-point equality can get very tricky once you do actual calculations. You usually want almost the same values, since results can become hardware dependent, or possibly random, depending on what you do with them.

    I would almost suggest serializing it with pickle if you want to compare something this insanely nested, but that is overly strict (and a tolerance-based comparison as in point 3 is of course impossible then). For example, the memory layout of your array does not matter for equality, but it does matter for its serialization.
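    A short sketch of the behaviour described above, using only numpy.testing functions. It also shows numpy.testing.assert_equal, which the answer does not mention: it recurses into dicts, lists and tuples and compares any arrays it finds with assert_array_equal, so it can check a whole nested structure in one call. The variable names are illustrative.

```python
import numpy as np
import numpy.testing as npt

a = np.array([1.0, np.nan, 3.0])
b = np.array([1.0, np.nan, 3.0])

# == is element-wise and returns an array, not a single bool;
# NaN never compares equal, so the middle entry is False.
print(a == b)  # [ True False  True]

# assert_array_equal treats NaNs in matching positions as equal.
npt.assert_array_equal(a, b)

# assert_equal recurses into dicts, lists and tuples and compares
# the arrays inside them with assert_array_equal.
expected = [{'x': a, 'n': 1}, (2, [np.arange(3)])]
actual = [{'x': b.copy(), 'n': 1}, (2, [np.arange(3)])]
npt.assert_equal(actual, expected)

# For computed floats, compare with a tolerance instead.
npt.assert_allclose(np.array([0.1 + 0.2]), np.array([0.3]), rtol=1e-12)
```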

  • 2020-12-06 16:47

    So the idea illustrated by jterrace seems to work for me with a slight modification:

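    import numpy as np

    class SaneEqualityArray(np.ndarray):
        def __eq__(self, other):
            # np.asarray strips the subclass, so np.allclose works on plain
            # arrays and never calls back into this __eq__
            return (isinstance(other, np.ndarray) and self.shape == other.shape and
                    np.allclose(np.asarray(self), np.asarray(other)))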
    

    Like I said, the container with these objects should be on the left side of the equality check. I create SaneEqualityArray objects from existing numpy.ndarrays like this:

    SaneEqualityArray(my_array.shape, my_array.dtype, my_array)
    

    in accordance with ndarray constructor signature:

    ndarray(shape, dtype=float, buffer=None, offset=0,
            strides=None, order=None)
    

    This class is defined within the test suite and serves for testing purposes only. The RHS of the equality check is an actual object returned by the tested function and contains real numpy.ndarray objects.

    P.S. Thanks to the authors of both answers posted so far, they were both very helpful. If anyone sees any problems with this approach, I'd appreciate your feedback.
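    A minimal sketch of the approach, assuming the SaneEqualityArray class from this answer (repeated so the block runs on its own, with the comparison done on plain arrays via np.asarray). It uses ndarray.view, which reinterprets an existing array as the subclass without copying its data, as an alternative to calling the constructor with a buffer. The data and variable names are illustrative.

```python
import numpy as np


class SaneEqualityArray(np.ndarray):
    def __eq__(self, other):
        # Compare as plain ndarrays so np.allclose never calls back in here.
        return (isinstance(other, np.ndarray) and self.shape == other.shape and
                np.allclose(np.asarray(self), np.asarray(other)))


my_array = np.array([0.0, 0.25, 0.5, 0.75, 1.0])

# view() wraps the existing buffer in the subclass without copying,
# instead of SaneEqualityArray(my_array.shape, my_array.dtype, my_array).
expected = [{'weights': my_array.view(SaneEqualityArray), 'n': 5}]

# Stand-in for what a tested function might return: plain ndarrays
# with a little floating-point noise.
actual = [{'weights': my_array + 1e-12, 'n': 5}]

print(expected == actual)  # True
```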

  • 2020-12-06 17:02

    Check numpy.testing.assert_almost_equal, which "raises an AssertionError if two items are not equal up to desired precision", e.g.:

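    import numpy as np
    import numpy.testing as npt
    # Raises AssertionError: the second elements differ by about 6.7e-9,
    # more than the 1.5e-9 allowed by decimal=9.
    npt.assert_almost_equal(np.array([1.0, 2.3333333333333]),
                            np.array([1.0, 2.33333334]), decimal=9)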
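    To see where the cutoff for decimal lies, a sketch using the same arrays. The pass/fail thresholds follow NumPy's documented criterion abs(desired-actual) < 1.5 * 10**(-decimal); the strict_failed flag is only there to make the outcome visible.

```python
import numpy as np
import numpy.testing as npt

actual = np.array([1.0, 2.3333333333333])
desired = np.array([1.0, 2.33333334])

# Criterion: abs(desired - actual) < 1.5 * 10**(-decimal), element-wise.
# The second elements differ by about 6.7e-9.
npt.assert_almost_equal(actual, desired, decimal=8)  # passes: 6.7e-9 < 1.5e-8

strict_failed = False
try:
    npt.assert_almost_equal(actual, desired, decimal=9)  # 6.7e-9 > 1.5e-9
except AssertionError:
    strict_failed = True
print(strict_failed)  # True
```

    NumPy's own documentation recommends assert_allclose for more consistent floating-point comparisons; a relative tolerance (rtol) is often easier to reason about than a fixed number of decimals when values differ in magnitude.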
    
  • 2020-12-06 17:03

    The assertEqual method invokes the __eq__ method of the objects, which recurses into containers such as lists and dicts. The exception is numpy, whose ndarray.__eq__ compares element-wise and returns an array instead of a single bool. Using a numpy subclass from this question, you can restore sanity to the equality behavior:

    import copy
    import numpy
    import unittest
    
    class SaneEqualityArray(numpy.ndarray):
        def __eq__(self, other):
            return (isinstance(other, SaneEqualityArray) and
                    self.shape == other.shape and
                    bool(numpy.ndarray.__eq__(self, other).all()))
    
    def sane(values):
        # Wrap existing data. SaneEqualityArray([1, 2]) would instead create an
        # uninitialized array of *shape* (1, 2): ndarray's first argument is the shape.
        return numpy.asarray(values).view(SaneEqualityArray)
    
    class TestAsserts(unittest.TestCase):
    
        def testAssert(self):
            tests = [
                [1, 2],
                {'foo': 2},
                [2, 'foo', {'d': 4}],
                sane([1, 2]),
                {'foo': {'hey': sane([2, 3])}},
                [{'foo': sane([3, 4]), 'd': {'doo': 3}},
                 sane([5, 6]), 34]
            ]
            for t in tests:
                # deepcopy keeps the subclass, so both sides use the sane __eq__
                self.assertEqual(t, copy.deepcopy(t))
    
    if __name__ == '__main__':
        unittest.main()
    

    This test passes.

  • 2020-12-06 17:05

    Building on @dbw (with thanks), the following method inserted within the test-case subclass worked well for me:

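    def assertNumpyArraysEqual(self, this, that, msg=''):
        '''
        modified from http://stackoverflow.com/a/15399475/5459638
        '''
        # Assumes numpy is imported as np at module level.
        # Check shapes first: np.allclose broadcasts, so (3,) and (1, 3) would pass.
        if this.shape != that.shape:
            raise AssertionError(msg or "Shapes don't match")
        if not np.allclose(this, that):
            raise AssertionError(msg or "Elements don't match!")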
    

    I called it as self.assertNumpyArraysEqual(this, that) inside my test-case methods, and it worked like a charm.
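    A self-contained sketch of how the helper might sit inside a test case. The class and test names are hypothetical; the helper body follows the answer, except that the otherwise unused msg argument is passed to AssertionError. The shape check matters because np.allclose broadcasts, so arrays of shapes (3,) and (1, 3) would otherwise compare as close.

```python
import unittest

import numpy as np


class ArrayTests(unittest.TestCase):
    def assertNumpyArraysEqual(self, this, that, msg=''):
        # Shape check first: np.allclose alone would broadcast (3,) vs (1, 3).
        if this.shape != that.shape:
            raise AssertionError(msg or "Shapes don't match")
        # Element-wise comparison within np.allclose's default tolerances.
        if not np.allclose(this, that):
            raise AssertionError(msg or "Elements don't match!")

    def test_close_arrays(self):
        # 0.1 + 0.2 is not exactly 0.3, but it is within tolerance.
        self.assertNumpyArraysEqual(np.array([0.1 + 0.2, 1.0]),
                                    np.array([0.3, 1.0]))

    def test_shape_mismatch(self):
        with self.assertRaises(AssertionError):
            self.assertNumpyArraysEqual(np.zeros(3), np.zeros((1, 3)))
```

    Saved as a module, these tests run under python -m unittest like any other TestCase.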
