"""Engines with multi-process parallelization."""
import logging
import multiprocessing
import os
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Any
import cloudpickle as pickle
from ..util import tqdm
from .base import Engine
from .task import Task
logger = logging.getLogger(__name__)
def work(pickled_task) -> Any:
    """Deserialize a task and run it in the worker process.

    Parameters
    ----------
    pickled_task:
        A pickled Task object to execute.

    Returns
    -------
    The result of executing the task.
    """
    # Unpickle in the child process and execute immediately; the result
    # travels back to the parent through the executor's future.
    return pickle.loads(pickled_task).execute()
# (removed stray Sphinx "[docs]" cross-reference marker — HTML-scrape
# artifact; as a bare list expression it would raise NameError on import)
class MultiProcessEngine(Engine):
    """
    Parallelize the task execution using multiprocessing.

    Parameters
    ----------
    n_procs:
        The maximum number of processes to use in parallel.
        Defaults to the number of CPUs available on the system according to
        `os.cpu_count()`.
        The effectively used number of processes will be the minimum of
        `n_procs` and the number of tasks submitted. Defaults to ``None``.
    method:
        Start method, any of "fork", "spawn", "forkserver", or None,
        giving the system specific default context. Defaults to ``None``.
    """

    def __init__(
        self,
        n_procs: int | None = None,
        method: str | None = None,
    ):
        super().__init__()

        if n_procs is None:
            # `os.cpu_count()` may itself return None on exotic platforms;
            # `execute` falls back to a single process in that case.
            n_procs = os.cpu_count()
            # Lazy %-style args: the message is only formatted if emitted.
            logger.info(
                "Engine will use up to %s processes (= CPU count).", n_procs
            )
        self.n_procs: int | None = n_procs
        self.method: str | None = method

    def execute(
        self, tasks: list[Task], progress_bar: bool | None = None
    ) -> list[Any]:
        """Pickle tasks and distribute work over parallel processes.

        Tasks are pickled on-demand as workers become available.

        Parameters
        ----------
        tasks:
            List of :class:`pypesto.engine.Task` to execute.
        progress_bar:
            Whether to display a progress bar.

        Returns
        -------
        A list of results in the same order as the input tasks.
        """
        n_tasks = len(tasks)
        if n_tasks == 0:
            # Nothing to do; also avoids ProcessPoolExecutor(max_workers=0),
            # which raises a ValueError.
            return []

        # Never spawn more workers than tasks; guard against a None CPU count.
        n_procs = min(self.n_procs or 1, n_tasks)
        logger.debug("Parallelizing on %s processes.", n_procs)

        ctx = multiprocessing.get_context(method=self.method)

        # Use ProcessPoolExecutor for on-demand pickling
        with ProcessPoolExecutor(
            max_workers=n_procs, mp_context=ctx
        ) as executor:
            # Submit tasks and track futures; pickling happens at submit time.
            future_to_index = {
                executor.submit(work, pickle.dumps(task)): i
                for i, task in enumerate(tasks)
            }

            # Collect results as they finish, but store them at the original
            # submission index so the output order matches the input order.
            results: list[Any] = [None] * n_tasks
            for future in tqdm(
                as_completed(future_to_index),
                total=n_tasks,
                enable=progress_bar,
            ):
                index = future_to_index[future]
                results[index] = future.result()

        return results