I want to run a function in parallel using joblib and wait until all parallel workers are done, as in this example:
from math import sqrt
from joblib import Par
As noted above, solutions that simply wrap the iterable passed to joblib.Parallel() do not truly monitor the progress of execution: they only track when tasks are dispatched, not when they complete.
Instead, I suggest subclassing Parallel and overriding the print_progress() method, as follows:
import joblib
from tqdm.auto import tqdm


class ProgressParallel(joblib.Parallel):
    def __call__(self, *args, **kwargs):
        # Keep the progress bar open for the whole parallel run.
        with tqdm() as self._pbar:
            return joblib.Parallel.__call__(self, *args, **kwargs)

    def print_progress(self):
        # joblib calls this as tasks complete; mirror its counters in the bar.
        self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()
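
For completeness, here is a minimal usage sketch built on the sqrt example from the question. The subclass is repeated so the snippet runs on its own. The helper name run_demo and its default arguments are my own, chosen purely for illustration, and I am assuming a reasonably recent joblib and tqdm.

```python
from math import sqrt

import joblib
from joblib import delayed
from tqdm.auto import tqdm


class ProgressParallel(joblib.Parallel):
    """Same subclass as above, repeated so this snippet is self-contained."""

    def __call__(self, *args, **kwargs):
        with tqdm() as self._pbar:
            return joblib.Parallel.__call__(self, *args, **kwargs)

    def print_progress(self):
        self._pbar.total = self.n_dispatched_tasks
        self._pbar.n = self.n_completed_tasks
        self._pbar.refresh()


def run_demo(n_items=10, n_jobs=2):
    # run_demo is a hypothetical helper, not part of joblib.
    # The call blocks until every task has finished and returns
    # the results as a list in input order.
    return ProgressParallel(n_jobs=n_jobs)(
        delayed(sqrt)(i ** 2) for i in range(n_items)
    )


if __name__ == "__main__":
    print(run_demo())
```

Because the total is taken from n_dispatched_tasks, the bar's total can grow while joblib is still dispatching batches. It only settles once everything has been handed out.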