SimpleJSON and NumPy array

挽巷 2020-12-04 10:56

What is the most efficient way of serializing a numpy array using simplejson?

9 answers
  •  时光取名叫无心
    2020-12-04 11:41

    I just discovered tlausch's answer to this question and realized it gives an almost correct answer to my problem. At least for me, though, it does not work in Python 3.5 because of two errors: (1) infinite recursion and (2) the data was saved as None.
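
    As an illustration only (this is not a diagnosis of tlausch's code, which is not shown here), one common way to get None/null in the output is a `default()` that handles some types and implicitly returns None for everything else: the encoder then writes that None as null instead of raising. The class names `LossyEncoder` and `StrictEncoder` are hypothetical.

```python
import json

import numpy as np


class LossyEncoder(json.JSONEncoder):
    """Hypothetical encoder that handles sets but nothing else."""

    def default(self, obj):
        if isinstance(obj, set):
            return sorted(obj)
        # Falls off the end here: returns None, which json writes as null.


class StrictEncoder(json.JSONEncoder):
    """Same check, but hands unsupported types back to the base class."""

    def default(self, obj):
        if isinstance(obj, set):
            return sorted(obj)
        return super().default(obj)  # raises TypeError for unknown types


arr = np.arange(3)
print(json.dumps({"data": arr}, cls=LossyEncoder))  # {"data": null}

try:
    json.dumps({"data": arr}, cls=StrictEncoder)
except TypeError as exc:
    print("TypeError:", exc)
```

    Ending `default()` with `return super().default(obj)` is what turns the silent null into an explicit error.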

    Since I cannot comment directly on the original answer yet, here is my version:

    import base64
    import json

    import numpy as np


    class NumpyEncoder(json.JSONEncoder):
        def default(self, obj):
            """If obj is an ndarray, convert it into a dict holding the
            dtype, the shape and the raw data, base64 encoded.
            """
            if isinstance(obj, np.ndarray):
                # b64encode needs one contiguous buffer; copy only when the
                # array is not already C-contiguous (e.g. a transposed view).
                if obj.flags['C_CONTIGUOUS']:
                    obj_data = obj.data
                else:
                    cont_obj = np.ascontiguousarray(obj)
                    assert cont_obj.flags['C_CONTIGUOUS']
                    obj_data = cont_obj.data
                data_b64 = base64.b64encode(obj_data)
                return dict(__ndarray__=data_b64.decode('utf-8'),
                            dtype=str(obj.dtype),
                            shape=obj.shape)
            # Let the base class raise TypeError for unsupported types
            # instead of silently serializing them as null.
            return super().default(obj)


    def json_numpy_obj_hook(dct):
        """Decode a previously encoded numpy ndarray with proper shape and dtype.

        :param dct: (dict) json encoded ndarray
        :return: (ndarray) if input was an encoded ndarray, else dct unchanged
        """
        if isinstance(dct, dict) and '__ndarray__' in dct:
            data = base64.b64decode(dct['__ndarray__'])
            return np.frombuffer(data, dct['dtype']).reshape(dct['shape'])
        return dct


    # np.float was removed in NumPy 1.24; use np.float64 explicitly.
    expected = np.arange(100, dtype=np.float64)
    dumped = json.dumps(expected, cls=NumpyEncoder)
    result = json.loads(dumped, object_hook=json_numpy_obj_hook)

    # None of the following assertions fails.
    assert result.dtype == expected.dtype, "Wrong Type"
    assert result.shape == expected.shape, "Wrong Shape"
    assert np.allclose(expected, result), "Wrong Values"
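
    A shorter variant, sketched under the assumption that one extra copy per array is acceptable: `ndarray.tobytes()` always emits the bytes in C order (its `order` argument defaults to `'C'`), so the contiguity check can be dropped. The trade-off is that it copies even arrays that are already C-contiguous, which the `obj.data` path above avoids. Since the question asks about simplejson, the sketch imports it when available and otherwise falls back to the standard `json` module, which has the same `dumps(cls=...)`/`loads(object_hook=...)` interface; the encoder must subclass the `JSONEncoder` of whichever module does the dumping. The names `CompactNumpyEncoder` and `decode_ndarray` are hypothetical.

```python
import base64

import numpy as np

try:
    import simplejson as json  # preferred if installed
except ImportError:
    import json  # stdlib module with the same dumps/loads interface


class CompactNumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            # tobytes() copies the data in C order, whatever the memory layout.
            return {
                "__ndarray__": base64.b64encode(obj.tobytes()).decode("ascii"),
                "dtype": str(obj.dtype),
                "shape": obj.shape,
            }
        return super().default(obj)


def decode_ndarray(dct):
    if "__ndarray__" in dct:
        data = base64.b64decode(dct["__ndarray__"])
        # frombuffer returns a read-only array over the decoded bytes;
        # copy() makes the result writable.
        arr = np.frombuffer(data, dtype=dct["dtype"]).reshape(dct["shape"])
        return arr.copy()
    return dct


# A Fortran-ordered (not C-contiguous) array nested inside a dict.
values = np.asfortranarray(np.arange(6, dtype=np.int32).reshape(2, 3))
payload = {"name": "grid", "values": values}
text = json.dumps(payload, cls=CompactNumpyEncoder)
restored = json.loads(text, object_hook=decode_ndarray)
print(restored["values"])
```

    Drop the `.copy()` if a read-only result is fine and you want to avoid the second copy on the decoding side.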
    
