Can't pickle static method - Multiprocessing - Python

野性不改 2020-12-10 04:36

I'm applying some parallelization to my code, in which I use classes. I knew that it is not possible to pickle a class method without any approach other than what Python…

3 Answers
  • 2020-12-10 04:52

    If you use a fork of multiprocessing called pathos.multiprocessing, you can use classes and class methods directly in multiprocessing's map functions. This is because pathos uses dill instead of pickle or cPickle, and dill can serialize almost anything in Python.

    pathos.multiprocessing also provides an asynchronous map function, and it can map functions that take multiple arguments (e.g. map(math.pow, [1,2,3], [4,5,6])).

    See: What can multiprocessing and dill do together?

    and: http://matthewrocklin.com/blog/work/2013/12/05/Parallelism-and-Serialization/

    >>> from pathos.multiprocessing import ProcessingPool as Pool
    >>> 
    >>> p = Pool(4)
    >>> 
    >>> def add(x,y):
    ...   return x+y
    ... 
    >>> x = [0,1,2,3]
    >>> y = [4,5,6,7]
    >>> 
    >>> p.map(add, x, y)
    [4, 6, 8, 10]
    >>> 
    >>> class Test(object):
    ...   def plus(self, x, y): 
    ...     return x+y
    ... 
    >>> t = Test()
    >>> 
    >>> p.map(Test.plus, [t]*4, x, y)
    [4, 6, 8, 10]
    >>> 
    >>> p.map(t.plus, x, y)
    [4, 6, 8, 10]
    

    Get the code here: https://github.com/uqfoundation/pathos

    pathos also has an asynchronous map (amap), as well as imap.
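    For comparison, a sketch using only the standard library, assuming Python 3.4+: Pool.starmap takes an iterable of argument tuples, so zip(x, y) plays the role of the extra sequences passed to pathos' map, and bound methods such as t.plus pickle by reference (the instance plus the method name). Standard pickle still fails on lambdas and locally defined functions, which is where dill helps.

```python
from multiprocessing import Pool


def add(x, y):
    return x + y


class Test(object):
    def plus(self, x, y):
        return x + y


def run():
    x = [0, 1, 2, 3]
    y = [4, 5, 6, 7]
    t = Test()
    with Pool(4) as p:
        # starmap unpacks each tuple into positional arguments: add(0, 4), ...
        sums = p.starmap(add, zip(x, y))
        # A bound method is pickled as (instance, method name) in Python 3.4+.
        method_sums = p.starmap(t.plus, zip(x, y))
        # The plain function plus the instance as first argument also works.
        explicit = p.starmap(Test.plus, zip([t] * 4, x, y))
    return sums, method_sums, explicit


if __name__ == "__main__":
    print(run())  # ([4, 6, 8, 10], [4, 6, 8, 10], [4, 6, 8, 10])
```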

  • 2020-12-10 04:56

    You could define a plain function at the module level and a staticmethod as well. This preserves the calling syntax, introspection and inheritability features of a staticmethod, while avoiding the pickling problem:

    def aux():
        return "VoG - Sucess" 
    
    class VariabilityOfGradients(object):
        aux = staticmethod(aux)
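    The alias works because pickle stores a function by reference (its module and name) rather than by value, and the receiving process looks it up again by that name. With the alias, the function is findable as the module-level global aux, which is exactly what Python 2's name-based lookup needs. A minimal check, assuming Python 3, where accessing a staticmethod through the class returns the underlying function:

```python
import pickle


def aux():
    return "VoG - Success"


class VariabilityOfGradients(object):
    # Same function object, also reachable as the module-level global "aux".
    aux = staticmethod(aux)


# Accessing a staticmethod through the class yields the plain function.
func = VariabilityOfGradients.aux
assert func is aux

# pickle records only a reference (module name + function name), not the code.
payload = pickle.dumps(func)
restored = pickle.loads(payload)

print(restored is aux)  # True: unpickling looks the function up by name
print(restored())       # VoG - Success
```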
    

    For example,

    import copy_reg
    import types
    from itertools import product
    import multiprocessing as mp
    
    def _pickle_method(method):
        """
        Author: Steven Bethard (author of argparse)
        http://bytes.com/topic/python/answers/552476-why-cant-you-pickle-instancemethods
        """
        func_name = method.im_func.__name__
        obj = method.im_self
        cls = method.im_class
        cls_name = ''
        if func_name.startswith('__') and not func_name.endswith('__'):
            cls_name = cls.__name__.lstrip('_')
        if cls_name:
            func_name = '_' + cls_name + func_name
        return _unpickle_method, (func_name, obj, cls)
    
    
    def _unpickle_method(func_name, obj, cls):
        """
        Author: Steven Bethard
        http://bytes.com/topic/python/answers/552476-why-cant-you-pickle-instancemethods
        """
        for cls in cls.mro():
            try:
                func = cls.__dict__[func_name]
            except KeyError:
                pass
            else:
                break
        return func.__get__(obj, cls)
    
    copy_reg.pickle(types.MethodType, _pickle_method, _unpickle_method)
    
    class ImageData(object):
    
        def __init__(self, width=60, height=60):
            self.width = width
            self.height = height
            self.data = []
            for i in range(width):
                self.data.append([0] * height)
    
        def shepard_interpolation(self, seeds=20):
            print "ImD - Success"       
    
    def aux():
        return "VoG - Sucess" 
    
    class VariabilityOfGradients(object):
        aux = staticmethod(aux)
    
        @staticmethod
        def calculate_orientation_uncertainty():
            pool = mp.Pool()
            results = []
            for x, y in product(range(1, 5), range(1, 5)):
                # result = pool.apply_async(aux) # this works too
                result = pool.apply_async(VariabilityOfGradients.aux, callback=results.append)
            pool.close()
            pool.join()
            print(results)
    
    
    if __name__ == '__main__':  
        results = []
        pool = mp.Pool()
        for _ in range(3):
            result = pool.apply_async(ImageData.shepard_interpolation, args=[ImageData()])
            results.append(result.get())
        pool.close()
        pool.join()
    
        VariabilityOfGradients.calculate_orientation_uncertainty()   
    

    yields

    ImD - Success
    ImD - Success
    ImD - Success
    ['VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess', 'VoG - Sucess']
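    The copy_reg registration above is Python 2 code (copy_reg, im_func, im_self, im_class). Python 3.4+ pickles bound methods natively, so nothing needs registering there, but the same reducer idea can be written in Python 3 spelling. A sketch, assuming Python 3.3+ for Pickler.dispatch_table; reduce_method, dumps_with_method_reducer and ImageData.area are hypothetical names, and a per-pickler dispatch table leaves the global copyreg registry untouched:

```python
import copyreg
import io
import pickle
import types


def reduce_method(method):
    # Hypothetical reducer: rebuild a bound method as getattr(instance, name).
    # Unlike Bethard's version, this does not handle name-mangled __private methods.
    return getattr, (method.__self__, method.__func__.__name__)


def dumps_with_method_reducer(obj):
    buf = io.BytesIO()
    pickler = pickle.Pickler(buf)
    # Copy the global table so only this pickler uses the custom reducer.
    pickler.dispatch_table = copyreg.dispatch_table.copy()
    pickler.dispatch_table[types.MethodType] = reduce_method
    pickler.dump(obj)
    return buf.getvalue()


class ImageData(object):
    def __init__(self, width=2, height=3):
        self.width = width
        self.height = height

    def area(self):
        return self.width * self.height


bound = ImageData(4, 5).area
restored = pickle.loads(dumps_with_method_reducer(bound))
print(restored())  # 20
```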
    

    By the way, result.get() blocks the calling process until the function called by pool.apply_async (e.g. ImageData.shepard_interpolation) is completed. So

    for _ in range(3):
        result = pool.apply_async(ImageData.shepard_interpolation, args=[ImageData()])
        results.append(result.get())
    

    is really calling ImageData.shepard_interpolation sequentially, defeating the purpose of the pool.

    Instead you could use

    for _ in range(3):
        pool.apply_async(ImageData.shepard_interpolation, args=[ImageData()],
                         callback=results.append)
    

    The callback function (e.g. results.append) is called in a thread of the calling process when the function completes. It receives one argument: the return value of the function. Thus nothing blocks the three pool.apply_async calls from being made quickly, and the work done by the three calls to ImageData.shepard_interpolation is performed concurrently.

    Alternatively, it might be simpler to just use pool.map here.

    results = pool.map(ImageData.shepard_interpolation, [ImageData()]*3)
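    One more option: submit every task first and call get() only afterwards, so the tasks overlap and the results still come back in submission order. A Python 3 sketch; here shepard_interpolation is a hypothetical stand-in that returns width * height instead of printing:

```python
import multiprocessing as mp


class ImageData(object):
    def __init__(self, width=60, height=60):
        self.width = width
        self.height = height

    def shepard_interpolation(self, seeds=20):
        # Placeholder work (hypothetical): return a value derived from the instance.
        return self.width * self.height


def run():
    with mp.Pool() as pool:
        # Submit all three tasks first; apply_async returns immediately.
        pending = [
            pool.apply_async(ImageData.shepard_interpolation, args=(ImageData(10, i),))
            for i in range(1, 4)
        ]
        # Only now block, once per task, while the remaining tasks keep running.
        return [r.get() for r in pending]


if __name__ == "__main__":
    print(run())  # [10, 20, 30]
```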
    
  • 2020-12-10 04:58

    I'm not sure if this is what you are looking for, but my use case was slightly different. I wanted to call a method of a class from within the same class, running it in multiple processes.

    This is how I implemented it:

    from multiprocessing import Pool
    
    class Product(object):
    
        def __init__(self):
            self.logger = "test"
    
        def f(self, x):
            print(self.logger)
            return x*x
    
        def multi(self):
            # Product.f is a plain function, so it pickles by name;
            # each tuple supplies (self, x) for one call of f.
            with Pool(5) as p:
                print(p.starmap(Product.f, [(Product(), 1), (Product(), 2), (Product(), 3)]))
    
    
    if __name__ == '__main__':
        obj = Product()
        obj.multi()
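    On Python 3.4+, the bound method self.f can be passed to the pool directly, since pickle sends it as a copy of the instance plus the method name; that avoids constructing a new Product() for every task. A sketch under that assumption:

```python
from multiprocessing import Pool


class Product(object):
    def __init__(self):
        self.logger = "test"

    def f(self, x):
        print(self.logger)
        return x * x

    def multi(self):
        with Pool(5) as p:
            # self.f is pickled together with a copy of self for each task.
            return p.map(self.f, [1, 2, 3])


if __name__ == "__main__":
    print(Product().multi())  # [1, 4, 9]
```

    This only works while the instance itself is picklable: a Pool stored as an attribute on self would make pickling fail, which is why the pool stays a local variable in multi.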
    