06b9200f |
import abc
312d4040 |
import asyncio
890f84c0 |
import concurrent.futures as concurrent
58f7b397 |
import functools
b3873839 |
import inspect
58f7b397 |
import itertools
f3c5ea34 |
import pickle
142bf496 |
import sys
4069d322 |
import time
ccfe27d6 |
import traceback
890f84c0 |
import warnings
176c745c |
from contextlib import suppress
b00659d0 |
176c745c |
from adaptive.notebook_integration import in_ipynb, live_info, live_plot
312d4040 |
d74bacde |
8d97deb4 |
if sys.version_info < (3, 8):
# XXX: remove when ipyparallel 6.2.5 is released
import ipyparallel
716dbce8 |
8d97deb4 |
with_ipyparallel = True
with_ipyparallel = False
d74bacde |
except ModuleNotFoundError:
with_ipyparallel = False
import distributed
716dbce8 |
d74bacde |
with_distributed = True
except ModuleNotFoundError:
with_distributed = False
312d4040 |
d538ee25 |
import mpi4py.futures
716dbce8 |
d538ee25 |
with_mpi4py = True
except ModuleNotFoundError:
with_mpi4py = False
498f56b3 |
import loky
with_loky = True
except ModuleNotFoundError:
with_loky = False
82bde4fb |
with suppress(ModuleNotFoundError):
25ea1f5a |
import uvloop
716dbce8 |
25ea1f5a |
# Prefer loky's `get_reusable_executor` when loky is installed, otherwise
# fall back to the standard-library `ProcessPoolExecutor`. The value is a
# factory: `_ensure_executor` calls it without arguments when the user does
# not supply an executor.
_default_executor = (
    loky.get_reusable_executor if with_loky else concurrent.ProcessPoolExecutor
)


class BaseRunner(metaclass=abc.ABCMeta):
    r"""Base class for runners that use `concurrent.futures.Executors`.

    Parameters
    ----------
    learner : `~adaptive.BaseLearner` instance
    goal : callable
        The end condition for the calculation. This function must take
        the learner as its sole argument, and return True when we should
        stop requesting more points.
    executor : `concurrent.futures.Executor`, `distributed.Client`,\
               `mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
               `loky.get_reusable_executor`, optional
        The executor in which to evaluate the function to be learned.
        If not provided, ``loky.get_reusable_executor()`` is used when loky
        is installed, and a new `~concurrent.futures.ProcessPoolExecutor`
        otherwise.
    ntasks : int, optional
        The number of concurrent function evaluations. Defaults to the number
        of cores available in `executor`.
    log : bool, default: False
        If True, record the method calls made to the learner by this runner.
    shutdown_executor : bool, default: False
        If True, shutdown the executor when the runner has completed. If
        `executor` is not provided then the executor created internally
        by the runner is shut down, regardless of this parameter.
    retries : int, default: 0
        Maximum amount of retries of a certain point ``x`` in
        ``learner.function(x)``. After `retries` is reached for ``x``
        the id of the point is present in ``runner.failed``.
    raise_if_retries_exceeded : bool, default: True
        Raise the error once a point ``x`` has failed more than `retries`
        times.

    Attributes
    ----------
    learner : `~adaptive.BaseLearner` instance
        The underlying learner. May be queried for its state.
    log : list or None
        Record of the method calls made to the learner, in the format
        ``(method_name, *args)``.
    to_retry : list of tuples
        List of ``(point, n_fails)``. When a point has failed
        ``runner.retries`` times it is removed but will be present
        in ``runner.tracebacks``.
    tracebacks : list of tuples
        List of ``(point, tb)`` for points that failed.
    pending_points : list of tuples
        A list of tuples with ``(concurrent.futures.Future, point)``.

    Methods
    -------
    overhead : callable
        The overhead in percent of using Adaptive. Essentially, this is
        ``100 * (1 - total_elapsed_function_time / self.elapsed_time())``.
    """

    def __init__(
        self,
        learner,
        goal,
        *,
        executor=None,
        ntasks=None,
        log=False,
        shutdown_executor=False,
        retries=0,
        raise_if_retries_exceeded=True,
    ):
        self.executor = _ensure_executor(executor)
        self.goal = goal

        # ``None`` falls back to the core count of the executor
        # (see `_get_max_tasks`).
        self._max_tasks = ntasks

        # future -> point id, for every submitted but unprocessed evaluation
        self._pending_points = {}

        # if we instantiate our own executor, then we are also responsible
        # for calling 'shutdown'
        self.shutdown_executor = shutdown_executor or (executor is None)

        self.learner = learner
        self.log = [] if log else None

        # Timing
        self.start_time = time.time()
        self.end_time = None
        self._elapsed_function_time = 0

        # Error handling attributes
        self.retries = retries
        self.raise_if_retries_exceeded = raise_if_retries_exceeded
        self._to_retry = {}  # point id -> number of failures so far
        self._tracebacks = {}  # point id -> formatted traceback of last failure

        self._id_to_point = {}
        # some unique id to be associated with each point
        self._next_id = functools.partial(next, itertools.count())

    def _get_max_tasks(self):
        """Return the number of evaluations allowed to run concurrently."""
        return self._max_tasks or _get_ncores(self.executor)

    def _do_raise(self, e, i):
        """Re-raise the failure of point id ``i`` with its stored traceback."""
        tb = self._tracebacks[i]
        x = self._id_to_point[i]
        raise RuntimeError(
            "An error occured while evaluating "
            f'"learner.function({x})". '
            f"See the traceback for details.:\n\n{tb}"
        ) from e

    @property
    def do_log(self):
        """Whether learner calls are being recorded in ``self.log``."""
        return self.log is not None

    def _ask(self, n):
        """Return up to ``n`` point ids to evaluate and their loss improvements.

        Failed points that are not currently being evaluated are retried
        first (reported with an infinite loss improvement); the remainder is
        requested from the learner and registered under fresh ids.
        """
        # Build the membership set once instead of scanning the dict values
        # for every candidate id.
        pending_ids = set(self._pending_points.values())
        retry_ids = (pid for pid in self._to_retry if pid not in pending_ids)
        pids = list(itertools.islice(retry_ids, n))
        loss_improvements = len(pids) * [float("inf")]

        if len(pids) < n:
            new_points, new_losses = self.learner.ask(n - len(pids))
            loss_improvements += new_losses
            for point in new_points:
                pid = self._next_id()
                self._id_to_point[pid] = point
                pids.append(pid)
        return pids, loss_improvements

    def overhead(self):
        """Overhead of using Adaptive and the executor in percent.

        This is measured as ``100 * (1 - t_function / t_elapsed)``.

        Notes
        -----
        This includes the overhead of the executor that is being used.
        The slower your function is, the lower the overhead will be. The
        learners take ~5-50 ms to suggest a point and sending that point to
        the executor also takes about ~5 ms, so you will benefit from using
        Adaptive whenever executing the function takes longer than 100 ms.
        This of course depends on the type of executor and the type of learner
        but is a rough rule of thumb.
        """
        t_function = self._elapsed_function_time
        if t_function == 0:
            # When no function is done executing, the overhead cannot
            # reliably be determined, so 0 is the best we can do.
            return 0
        t_total = self.elapsed_time()
        return (1 - t_function / t_total) * 100

    def _process_futures(self, done_futs):
        """Feed the results of finished futures back into the learner.

        A failed evaluation is queued for a retry. Once a point has failed
        more than ``self.retries`` times it is dropped from the retry queue
        (its traceback is kept, so it shows up in `failed`) and, if
        ``raise_if_retries_exceeded`` is set, a ``RuntimeError`` is raised.
        """
        for fut in done_futs:
            pid = self._pending_points.pop(fut)
            try:
                y = fut.result()
                t = time.time() - fut.start_time  # total execution time
            except Exception as e:
                self._tracebacks[pid] = traceback.format_exc()
                self._to_retry[pid] = self._to_retry.get(pid, 0) + 1
                if self._to_retry[pid] > self.retries:
                    # Give up on this point; otherwise `_ask` would keep
                    # resubmitting it forever.
                    self._to_retry.pop(pid)
                    if self.raise_if_retries_exceeded:
                        self._do_raise(e, pid)
            else:
                self._elapsed_function_time += t / self._get_max_tasks()
                self._to_retry.pop(pid, None)
                self._tracebacks.pop(pid, None)
                x = self._id_to_point.pop(pid)
                if self.do_log:
                    self.log.append(("tell", x, y))
                self.learner.tell(x, y)

    def _get_futures(self):
        """Submit new work to fill free slots and return all pending futures."""
        # Launch tasks to replace the ones that completed
        # on the last iteration, making sure to fill workers
        # that have started since the last iteration.
        n_new_tasks = max(0, self._get_max_tasks() - len(self._pending_points))

        if self.do_log:
            self.log.append(("ask", n_new_tasks))

        pids, _ = self._ask(n_new_tasks)

        for pid in pids:
            start_time = time.time()  # so we can measure execution time
            point = self._id_to_point[pid]
            fut = self._submit(point)
            fut.start_time = start_time
            self._pending_points[fut] = pid

        # Return every outstanding future so the caller can wait on them.
        return list(self._pending_points.keys())

    def _remove_unfinished(self):
        """Cancel outstanding work and return the futures that were pending."""
        # remove points with 'None' values from the learner
        self.learner.remove_unfinished()
        # cancel any outstanding tasks
        remaining = list(self._pending_points.keys())
        for fut in remaining:
            fut.cancel()
        return remaining

    def _cleanup(self):
        """Shut down the executor if we own it and record the end time."""
        if self.shutdown_executor:
            # XXX: temporary set wait=True because of a bug with Python ≥3.7
            # and loky in any Python version.
            # see https://github.com/python-adaptive/adaptive/issues/156
            # and https://github.com/python-adaptive/adaptive/pull/164
            # and https://bugs.python.org/issue36281
            # and https://github.com/joblib/loky/issues/241
            self.executor.shutdown(wait=True)
        self.end_time = time.time()

    @property
    def failed(self):
        """Set of point ids that failed more than ``runner.retries`` times."""
        return set(self._tracebacks) - set(self._to_retry)

    @abc.abstractmethod
    def elapsed_time(self):
        """Return the total time elapsed since the runner
        was started.

        Is called in `overhead`.
        """

    @abc.abstractmethod
    def _submit(self, x):
        """Is called in `_get_futures`."""

    @property
    def tracebacks(self):
        """List of ``(point, traceback)`` for every point that failed."""
        return [(self._id_to_point[pid], tb) for pid, tb in self._tracebacks.items()]

    @property
    def to_retry(self):
        """List of ``(point, n_fails)`` for points queued for a retry."""
        return [(self._id_to_point[pid], n) for pid, n in self._to_retry.items()]

    @property
    def pending_points(self):
        """List of ``(future, point)`` for every evaluation in flight."""
        return [
            (fut, self._id_to_point[pid]) for fut, pid in self._pending_points.items()
        ]


class BlockingRunner(BaseRunner):
    """Run a learner synchronously in an executor.

    Parameters
    ----------
    learner : `~adaptive.BaseLearner` instance
    goal : callable
        The end condition for the calculation. This function must take
        the learner as its sole argument, and return True when we should
        stop requesting more points.
    executor : `concurrent.futures.Executor`, `distributed.Client`,\
               `mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
               `loky.get_reusable_executor`, optional
        The executor in which to evaluate the function to be learned.
        If not provided, ``loky.get_reusable_executor()`` is used when loky
        is installed, and a new `~concurrent.futures.ProcessPoolExecutor`
        otherwise.
    ntasks : int, optional
        The number of concurrent function evaluations. Defaults to the number
        of cores available in `executor`.
    log : bool, default: False
        If True, record the method calls made to the learner by this runner.
    shutdown_executor : bool, default: False
        If True, shutdown the executor when the runner has completed. If
        `executor` is not provided then the executor created internally
        by the runner is shut down, regardless of this parameter.
    retries : int, default: 0
        Maximum amount of retries of a certain point ``x`` in
        ``learner.function(x)``. After `retries` is reached for ``x``
        the id of the point is present in ``runner.failed``.
    raise_if_retries_exceeded : bool, default: True
        Raise the error once a point ``x`` has failed more than `retries`
        times.

    Attributes
    ----------
    learner : `~adaptive.BaseLearner` instance
        The underlying learner. May be queried for its state.
    log : list or None
        Record of the method calls made to the learner, in the format
        ``(method_name, *args)``.
    to_retry : list of tuples
        List of ``(point, n_fails)``. When a point has failed
        ``runner.retries`` times it is removed but will be present
        in ``runner.tracebacks``.
    tracebacks : list of tuples
        List of ``(point, tb)`` for points that failed.
    pending_points : list of tuples
        A list of tuples with ``(concurrent.futures.Future, point)``.

    Methods
    -------
    elapsed_time : callable
        A method that returns the time elapsed since the runner
        was started.
    overhead : callable
        The overhead in percent of using Adaptive. This includes the
        overhead of the executor. Essentially, this is
        ``100 * (1 - total_elapsed_function_time / self.elapsed_time())``.
    """

    def __init__(
        self,
        learner,
        goal,
        *,
        executor=None,
        ntasks=None,
        log=False,
        shutdown_executor=False,
        retries=0,
        raise_if_retries_exceeded=True,
    ):
        # Checked before the base class creates (and possibly owns) an
        # executor, so nothing needs to be cleaned up on this error.
        if inspect.iscoroutinefunction(learner.function):
            raise ValueError(
                "Coroutine functions can only be used " "with 'AsyncRunner'."
            )
        super().__init__(
            learner,
            goal,
            executor=executor,
            ntasks=ntasks,
            log=log,
            shutdown_executor=shutdown_executor,
            retries=retries,
            raise_if_retries_exceeded=raise_if_retries_exceeded,
        )
        # Construction blocks until the goal is reached.
        self._run()

    def _submit(self, x):
        return self.executor.submit(self.learner.function, x)

    def _run(self):
        """Evaluate points until the goal is reached, then clean up."""
        first_completed = concurrent.FIRST_COMPLETED
        try:
            # Inside the ``try`` so that an executor owned by this runner is
            # still shut down when this check fails.
            if self._get_max_tasks() < 1:
                raise RuntimeError("Executor has no workers")

            while not self.goal(self.learner):
                futures = self._get_futures()
                done, _ = concurrent.wait(futures, return_when=first_completed)
                self._process_futures(done)
        finally:
            # Cancel outstanding work, wait for whatever could not be
            # cancelled, and shut down the executor if we own it.
            remaining = self._remove_unfinished()
            if remaining:
                concurrent.wait(remaining)
            self._cleanup()

    def elapsed_time(self):
        """Return the total time elapsed since the runner
        was started."""
        if self.end_time is None:
            # This shouldn't happen if the BlockingRunner
            # correctly finished.
            self.end_time = time.time()
        return self.end_time - self.start_time


class AsyncRunner(BaseRunner):
    r"""Run a learner asynchronously in an executor using `asyncio`.

    Parameters
    ----------
    learner : `~adaptive.BaseLearner` instance
    goal : callable, optional
        The end condition for the calculation. This function must take
        the learner as its sole argument, and return True when we should
        stop requesting more points. If not provided, the runner will run
        forever, or until ``self.task.cancel()`` is called.
    executor : `concurrent.futures.Executor`, `distributed.Client`,\
               `mpi4py.futures.MPIPoolExecutor`, `ipyparallel.Client` or\
               `loky.get_reusable_executor`, optional
        The executor in which to evaluate the function to be learned.
        If not provided, ``loky.get_reusable_executor()`` is used when loky
        is installed, and a new `~concurrent.futures.ProcessPoolExecutor`
        otherwise.
    ntasks : int, optional
        The number of concurrent function evaluations. Defaults to the number
        of cores available in `executor`.
    log : bool, default: False
        If True, record the method calls made to the learner by this runner.
    shutdown_executor : bool, default: False
        If True, shutdown the executor when the runner has completed. If
        `executor` is not provided then the executor created internally
        by the runner is shut down, regardless of this parameter.
    ioloop : ``asyncio.AbstractEventLoop``, optional
        The ioloop in which to run the learning algorithm. If not provided,
        the default event loop is used.
    retries : int, default: 0
        Maximum amount of retries of a certain point ``x`` in
        ``learner.function(x)``. After `retries` is reached for ``x``
        the id of the point is present in ``runner.failed``.
    raise_if_retries_exceeded : bool, default: True
        Raise the error once a point ``x`` has failed more than `retries`
        times.

    Attributes
    ----------
    task : `asyncio.Task`
        The underlying task. May be cancelled in order to stop the runner.
    learner : `~adaptive.BaseLearner` instance
        The underlying learner. May be queried for its state.
    log : list or None
        Record of the method calls made to the learner, in the format
        ``(method_name, *args)``.
    to_retry : list of tuples
        List of ``(point, n_fails)``. When a point has failed
        ``runner.retries`` times it is removed but will be present
        in ``runner.tracebacks``.
    tracebacks : list of tuples
        List of ``(point, tb)`` for points that failed.
    pending_points : list of tuples
        A list of tuples with ``(concurrent.futures.Future, point)``.

    Methods
    -------
    elapsed_time : callable
        A method that returns the time elapsed since the runner
        was started.
    overhead : callable
        The overhead in percent of using Adaptive. This includes the
        overhead of the executor. Essentially, this is
        ``100 * (1 - total_elapsed_function_time / self.elapsed_time())``.

    Notes
    -----
    This runner can be used when an async function (defined with
    ``async def``) has to be learned. In this case the function will be
    run directly on the event loop (and not in the executor).
    """

    def __init__(
        self,
        learner,
        goal=None,
        *,
        executor=None,
        ntasks=None,
        log=False,
        shutdown_executor=False,
        ioloop=None,
        retries=0,
        raise_if_retries_exceeded=True,
    ):
        if goal is None:

            def goal(_):
                return False

        if (
            executor is None
            and _default_executor is concurrent.ProcessPoolExecutor
            and not inspect.iscoroutinefunction(learner.function)
        ):
            # Fail early with a helpful message instead of later inside the
            # executor. Functions that cannot be looked up by name (e.g.
            # local functions) raise AttributeError rather than
            # PicklingError, and unpicklable callables may raise TypeError.
            try:
                pickle.dumps(learner.function)
            except (pickle.PicklingError, AttributeError, TypeError) as e:
                raise ValueError(
                    "`learner.function` cannot be pickled (is it a lamdba function?)"
                    " and therefore does not work with the default executor."
                    " Either make sure the function is pickleble or use an executor"
                    " that might work with 'hard to pickle'-functions"
                    " , e.g. `ipyparallel` with `dill`."
                ) from e

        super().__init__(
            learner,
            goal,
            executor=executor,
            ntasks=ntasks,
            log=log,
            shutdown_executor=shutdown_executor,
            retries=retries,
            raise_if_retries_exceeded=raise_if_retries_exceeded,
        )
        self.ioloop = ioloop or asyncio.get_event_loop()
        self.task = None

        # When the learned function is 'async def', we run it
        # directly on the event loop, and not in the executor.
        # The *whole point* of allowing learning of async functions is so that
        # the user can have more fine-grained control over the parallelism.
        if inspect.iscoroutinefunction(learner.function):
            # `is not None` so that a user-provided executor object is
            # detected even if it happens to be falsy.
            if executor is not None:  # user-provided argument
                raise RuntimeError(
                    "Cannot use an executor when learning an " "async function."
                )
            self.executor.shutdown()  # Make sure we don't shoot ourselves later

        self.task = self.ioloop.create_task(self._run())
        self.saving_task = None
        if in_ipynb() and not self.ioloop.is_running():
            warnings.warn(
                "The runner has been scheduled, but the asyncio "
                "event loop is not running! If you are "
                "in a Jupyter notebook, remember to run "
                "'adaptive.notebook_extension()'"
            )

    def _submit(self, x):
        """Schedule ``learner.function(x)`` and return an awaitable future."""
        ioloop = self.ioloop
        if inspect.iscoroutinefunction(self.learner.function):
            return ioloop.create_task(self.learner.function(x))
        else:
            return ioloop.run_in_executor(self.executor, self.learner.function, x)

    def status(self):
        """Return the runner status as a string.

        The possible statuses are: running, cancelled, failed, and finished.
        """
        try:
            self.task.result()
        except asyncio.CancelledError:
            return "cancelled"
        except asyncio.InvalidStateError:
            return "running"
        except Exception:
            return "failed"
        else:
            return "finished"

    def cancel(self):
        """Cancel the runner.

        This is equivalent to calling ``runner.task.cancel()``.
        """
        self.task.cancel()

    def live_plot(self, *, plotter=None, update_interval=2, name=None, normalize=True):
        """Live plotting of the learner's data.

        Parameters
        ----------
        plotter : function
            A function that takes the learner as a argument and returns a
            holoviews object. By default ``learner.plot()`` will be called.
        update_interval : int
            Number of seconds between the updates of the plot.
        name : hashable
            Name for the `live_plot` task in `adaptive.active_plotting_tasks`.
            By default the name is None and if another task with the same name
            already exists that other `live_plot` is canceled.
        normalize : bool
            Normalize (scale to fit) the frame upon each update.

        Returns
        -------
        dm : `holoviews.core.DynamicMap`
            The plot that automatically updates every `update_interval`.
        """
        # `normalize` used to be accepted but silently dropped here.
        return live_plot(
            self,
            plotter=plotter,
            update_interval=update_interval,
            name=name,
            normalize=normalize,
        )

    def live_info(self, *, update_interval=0.1):
        """Display live information about the runner.

        Returns an interactive ipywidget that can be
        visualized in a Jupyter notebook.
        """
        return live_info(self, update_interval=update_interval)

    async def _run(self):
        """Evaluate points until the goal is reached, then clean up."""
        first_completed = asyncio.FIRST_COMPLETED
        try:
            # Inside the ``try`` so that an executor owned by this runner is
            # still shut down when this check fails.
            if self._get_max_tasks() < 1:
                raise RuntimeError("Executor has no workers")

            while not self.goal(self.learner):
                futures = self._get_futures()
                # No ``loop=`` argument: it is deprecated since Python 3.8
                # and removed in 3.10; the running loop is used instead.
                done, _ = await asyncio.wait(futures, return_when=first_completed)
                self._process_futures(done)
        finally:
            remaining = self._remove_unfinished()
            if remaining:
                await asyncio.wait(remaining)
            self._cleanup()

    def elapsed_time(self):
        """Return the total time elapsed since the runner
        was started.

        Returns 0 when the task was cancelled before it started running.
        """
        if self.task.done():
            end_time = self.end_time
            if end_time is None:
                # task was cancelled before it began
                assert self.task.cancelled()
                return 0
        else:
            end_time = time.time()
        return end_time - self.start_time

    def start_periodic_saving(self, save_kwargs, interval):
        """Periodically save the learner's data.

        Parameters
        ----------
        save_kwargs : dict
            Key-word arguments for ``learner.save(**save_kwargs)``.
        interval : int
            Number of seconds between saving the learner.

        Returns
        -------
        asyncio.Task
            The saving task, also stored as ``self.saving_task``.

        Example
        -------
        >>> runner = Runner(learner)
        >>> runner.start_periodic_saving(
        ...     save_kwargs=dict(fname='data/test.pickle'),
        ...     interval=600)
        """

        async def _saver(save_kwargs=save_kwargs, interval=interval):
            while self.status() == "running":
                self.learner.save(**save_kwargs)
                await asyncio.sleep(interval)
            self.learner.save(**save_kwargs)  # one last time

        self.saving_task = self.ioloop.create_task(_saver())
        return self.saving_task


# Default runner: `adaptive.Runner` is the asyncio-based runner.
Runner = AsyncRunner


def simple(learner, goal):
    """Run the learner until the goal is reached.

    Requests a single point from the learner, evaluates
    the function to be learned, and adds the point to the
    learner, until the goal is reached, blocking the current
    thread.

    This function is useful for extracting error messages,
    as the learner's function is evaluated in the same thread,
    meaning that exceptions can simply be caught and inspected.

    Parameters
    ----------
    learner : `~adaptive.BaseLearner` instance
    goal : callable
        The end condition for the calculation. This function must take the
        learner as its sole argument, and return True if we should stop.
    """
    while not goal(learner):
        xs, _ = learner.ask(1)
        for x in xs:
            y = learner.function(x)
            learner.tell(x, y)


def replay_log(learner, log):
    """Apply a sequence of method calls to a learner.

    This is useful for debugging runners: the ``log`` recorded by a runner
    created with ``log=True`` can be replayed on a fresh learner.

    Parameters
    ----------
    learner : `~adaptive.BaseLearner` instance
        New learner where the log will be applied.
    log : list
        contains tuples: ``(method_name, *args)``.
    """
    for method, *args in log:
        getattr(learner, method)(*args)


# --- Useful runner goals


def stop_after(*, seconds=0, minutes=0, hours=0):
    """Stop a runner after a specified time.

    For example, to specify a runner that should stop after
    5 minutes, one could do the following:

    >>> runner = Runner(learner, goal=stop_after(minutes=5))

    To stop a runner after 2 hours, 10 minutes and 3 seconds,
    one could do the following:

    >>> runner = Runner(learner, goal=stop_after(hours=2, minutes=10, seconds=3))

    Parameters
    ----------
    seconds, minutes, hours : float, default: 0
        If more than one is specified, then they are added together.

    Returns
    -------
    goal : callable
        Can be used as the ``goal`` parameter when constructing
        a `Runner`.

    Notes
    -----
    The duration specified is only a *lower bound* on the time that the
    runner will run for, because the runner only checks its goal when
    it adds points to its learner.
    """
    # A monotonic clock is unaffected by system clock updates, so adjusting
    # the wall clock cannot make the runner stop early or run too long.
    stop_time = time.monotonic() + seconds + 60 * minutes + 3600 * hours
    return lambda _: time.monotonic() > stop_time


# -- Internal executor-related, things


class SequentialExecutor(concurrent.Executor):
    """A trivial executor that runs functions synchronously.

    This executor is mainly for testing.
    """

    def submit(self, fn, *args, **kwargs):
        """Run ``fn(*args, **kwargs)`` immediately and return a done future.

        An ``Exception`` raised by ``fn`` is stored on the future instead of
        propagating to the caller.
        """
        fut = concurrent.Future()
        try:
            fut.set_result(fn(*args, **kwargs))
        except Exception as e:
            fut.set_exception(e)
        return fut

    def map(self, fn, *iterables, timeout=None, chunksize=1):
        """Lazily apply ``fn`` to the items of ``iterables`` in lockstep.

        ``timeout`` and ``chunksize`` are accepted for API compatibility
        with `concurrent.futures.Executor.map` and ignored.
        """
        # Unpack: each iterable supplies one positional argument to ``fn``.
        return map(fn, *iterables)

    def shutdown(self, wait=True, *, cancel_futures=False):
        """Do nothing: there are no workers to stop."""


def _ensure_executor(executor):
if executor is None:
33e7cae5 |
executor = _default_executor()
890f84c0 |
if isinstance(executor, concurrent.Executor):
b659f24f |
return executor
elif with_ipyparallel and isinstance(executor, ipyparallel.Client):
return executor.executor()
elif with_distributed and isinstance(executor, distributed.Client):
return executor.get_executor()
716dbce8 |
raise TypeError(
"Only a concurrent.futures.Executor, distributed.Client,"
" or ipyparallel.Client can be used."
b3873839 |
b659f24f |
def _get_ncores(ex):
    """Return the maximum number of cores that an executor can use.

    Several branches read private attributes of the executor classes
    (marked "not public API!"), so they may need updating when those
    libraries change.

    Raises
    ------
    TypeError
        If the number of cores cannot be determined for ``type(ex)``.
    """
    if with_ipyparallel and isinstance(ex, ipyparallel.client.view.ViewExecutor):
        return len(ex.view)
    elif isinstance(
        ex, (concurrent.ProcessPoolExecutor, concurrent.ThreadPoolExecutor)
    ):
        return ex._max_workers  # not public API!
    elif with_loky and isinstance(ex, loky.reusable_executor._ReusablePoolExecutor):
        return ex._max_workers  # not public API!
    elif isinstance(ex, SequentialExecutor):
        return 1
    elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
        # sum of the per-worker core counts reported by the client
        return sum(ex._client.ncores().values())
    elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor):
        ex.bootup()  # wait until all workers are up and running
        return ex._pool.size  # not public API!
    else:
        raise TypeError(f"Cannot get number of cores for {ex.__class__}")