Compare (assert equality of) two complex data structures containing numpy arrays in unittest

醉酒成梦 2020-12-06 16:25

I use Python's unittest module and want to check if two complex data structures are equal. The objects can be lists of dicts with all sorts of values: numbers,

7 Answers
  •  猫巷女王i
    2020-12-06 17:03

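    assertEqual compares its arguments with ==, which calls each object's __eq__ method; for built-in containers such as lists and dicts, that comparison recurses into the elements. The exception is numpy: ndarray.__eq__ compares elementwise and returns an array of booleans rather than a single bool, so comparing containers that hold multi-element arrays raises a ValueError ("The truth value of an array with more than one element is ambiguous"). Using a numpy subclass from this question, you can restore sanity to the equality behavior: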

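    import copy
    import unittest

    import numpy


    class SaneEqualityArray(numpy.ndarray):
        """ndarray subclass whose == returns a single bool, not an array."""

        def __eq__(self, other):
            return (isinstance(other, SaneEqualityArray) and
                    self.shape == other.shape and
                    bool(numpy.ndarray.__eq__(self, other).all()))


    def sane(values):
        # numpy.ndarray(...) treats its first argument as a shape, not as data,
        # so build a regular array and view it as the subclass instead.
        return numpy.array(values).view(SaneEqualityArray)


    class TestAsserts(unittest.TestCase):

        def testAssert(self):
            tests = [
                [1, 2],
                {'foo': 2},
                [2, 'foo', {'d': 4}],
                sane([1, 2]),
                {'foo': {'hey': sane([2, 3])}},
                [{'foo': sane([3, 4]), 'd': {'doo': 3}},
                 sane([5, 6]), 34]
            ]
            for t in tests:
                # deepcopy yields an equal but distinct object, so the
                # comparison cannot succeed through identity alone.
                self.assertEqual(t, copy.deepcopy(t))


    if __name__ == '__main__':
        unittest.main()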
    

    This test passes.
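
To make the difference concrete, here is a small sketch (not part of the original answer) that compares lists holding plain arrays with lists holding SaneEqualityArray views. The helper name plain_lists_equal and the variable names are hypothetical, chosen only for this illustration; the subclass is repeated so the snippet runs on its own.

```python
import numpy

# Plain ndarray: == is elementwise, so the result is an array, not a bool.
a = numpy.array([1, 2])
b = numpy.array([1, 2])
elementwise = (a == b)


def plain_lists_equal(x, y):
    """Hypothetical helper: report what list == does with plain arrays inside."""
    try:
        return x == y
    except ValueError as exc:
        # Raised when Python asks for the truth value of a multi-element array.
        return "ValueError: {}".format(exc)


# Distinct array objects, so the list comparison cannot shortcut on identity.
outcome = plain_lists_equal([a], [b])


class SaneEqualityArray(numpy.ndarray):
    """Same idea as in the answer: == collapses to a single bool."""

    def __eq__(self, other):
        return (isinstance(other, SaneEqualityArray) and
                self.shape == other.shape and
                bool(numpy.ndarray.__eq__(self, other).all()))


# view() reinterprets existing data as the subclass without copying it.
sa = a.view(SaneEqualityArray)
sb = b.view(SaneEqualityArray)
sane_outcome = ([sa] == [sb])

print(elementwise)
print(outcome)
print(sane_outcome)
```

Because the subclass checks isinstance first, a SaneEqualityArray never compares equal to a plain ndarray, so both sides of an assertion need to be built from the subclass.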
