Tracking progress of joblib.Parallel execution

隐瞒了意图╮ 2020-12-24 12:08

Is there a simple way to track the overall progress of a joblib.Parallel execution?

I have a long-running execution composed of thousands of jobs, which I want to tr

8 answers
  •  醉话见心
    2020-12-24 12:28

    TLDR solution:

    Works with joblib 0.14.0 and tqdm 4.46.0 on Python 3.5. Credit to frenzykryger for the contextlib suggestion, and to dano and Connor for the monkey-patching idea.

    import contextlib
    import joblib
    from tqdm import tqdm
    from joblib import Parallel, delayed
    
    @contextlib.contextmanager
    def tqdm_joblib(tqdm_object):
        """Context manager to patch joblib to report into tqdm progress bar given as argument"""
    
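        def tqdm_print_progress(self):
            # Advance the bar only by the tasks completed since the last update,
            # so repeated calls cannot push it past the completed-task count.
            if self.n_completed_tasks > tqdm_object.n:
                n_completed = self.n_completed_tasks - tqdm_object.n
                tqdm_object.update(n=n_completed)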
    
        original_print_progress = joblib.parallel.Parallel.print_progress
        joblib.parallel.Parallel.print_progress = tqdm_print_progress
    
        try:
            yield tqdm_object
        finally:
            joblib.parallel.Parallel.print_progress = original_print_progress
            tqdm_object.close()
    

    You can use this the same way as described by frenzykryger:

    import time
    def some_method(wait_time):
        time.sleep(wait_time)
    
    with tqdm_joblib(tqdm(desc="My method", total=10)) as progress_bar:
        Parallel(n_jobs=2)(delayed(some_method)(0.2) for i in range(10))
    

    Longer explanation:

    The solution by Jon is simple to implement, but it only counts dispatched tasks. If a task takes a long time, the bar will sit at 100% while waiting for the last dispatched tasks to finish.
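    The gap between "dispatched" and "completed" can be seen in a small sketch. Jon's code is not reproduced here, so this assumes it counts items as they are pulled from the input iterable. The sketch uses the standard library's concurrent.futures rather than joblib, and the names counted, task and release are hypothetical. joblib hands out work more gradually than this, but the counter still runs ahead of the finished work in the same way.

```python
import threading
from concurrent.futures import ThreadPoolExecutor, wait

dispatched = 0
completed = 0
lock = threading.Lock()
release = threading.Event()  # tasks block on this to simulate slow work


def counted(iterable):
    """Count items as they are pulled from the iterable, i.e. at dispatch time."""
    global dispatched
    for item in iterable:
        dispatched += 1
        yield item


def task(x):
    """Hypothetical slow task: waits until released, then records completion."""
    global completed
    release.wait()
    with lock:
        completed += 1
    return x


with ThreadPoolExecutor(max_workers=2) as pool:
    futures = [pool.submit(task, x) for x in counted(range(4))]
    # Every task has been dispatched, but none has finished yet.
    snapshot = (dispatched, completed)
    release.set()
    wait(futures)

print(snapshot, (dispatched, completed))  # (4, 0) (4, 4)
```

    A progress bar driven by the dispatch counter would already read 4/4 at the snapshot, even though no result exists yet.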

    The context manager approach by frenzykryger, improved from dano and Connor, is better, but BatchCompletionCallBack can also be called with ImmediateResult before the task completes (see "Intermediate results from joblib"). This can push the count above 100%.

    Instead of monkey patching BatchCompletionCallBack, we can patch just the print_progress method of Parallel, which BatchCompletionCallBack already calls. If verbose is set (e.g. Parallel(n_jobs=2, verbose=100)), print_progress prints the completed tasks, though not as nicely as tqdm. Looking at the code, print_progress is a method of Parallel, so it has access to self.n_completed_tasks, which holds the number we want. All we have to do is compare this with the current state of the tqdm bar (tqdm_object.n) and update only if there is a difference.
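    The "update only by the difference" rule can be checked in isolation, without joblib or tqdm installed. The sketch below uses two hypothetical stand-ins, FakeBar and FakeParallel, that model only the attributes the patch touches (n, update and n_completed_tasks). They are not the real classes.

```python
class FakeBar:
    """Hypothetical stand-in for a tqdm bar: models only `n` and `update`."""

    def __init__(self, total):
        self.total = total
        self.n = 0

    def update(self, n=1):
        self.n += n


class FakeParallel:
    """Hypothetical stand-in for joblib.Parallel: models only the counter."""

    def __init__(self):
        self.n_completed_tasks = 0


def make_print_progress(bar):
    """Build a replacement print_progress that reports into `bar`."""

    def print_progress(self):
        # Move the bar forward only by the tasks finished since the last call.
        if self.n_completed_tasks > bar.n:
            bar.update(n=self.n_completed_tasks - bar.n)

    return print_progress


bar = FakeBar(total=10)
runner = FakeParallel()
report = make_print_progress(bar)

runner.n_completed_tasks = 3
report(runner)
report(runner)  # repeated call with no new completions: the bar must not move
after_repeat = bar.n

runner.n_completed_tasks = 10
report(runner)
```

    Because the bar is always set from the completed-task counter rather than incremented once per callback, extra or early callbacks cannot push it past the number of completed tasks.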

    This was tested with joblib 0.14.0 and tqdm 4.46.0 on Python 3.5.
