Browse code

add support for mpi4py

Bas Nijholt authored on 30/04/2019 20:58:14
Showing 1 changed files
... ...
@@ -26,6 +26,12 @@ try:
26 26
 except ModuleNotFoundError:
27 27
     with_distributed = False
28 28
 
29
+try:
30
+    import mpi4py.futures
31
+    with_mpi4py = True
32
+except ModuleNotFoundError:
33
+    with_mpi4py = False
34
+
29 35
 with suppress(ModuleNotFoundError):
30 36
     import uvloop
31 37
     asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
... ...
@@ -66,7 +72,7 @@ class BaseRunner(metaclass=abc.ABCMeta):
66 72
         the learner as its sole argument, and return True when we should
67 73
         stop requesting more points.
68 74
     executor : `concurrent.futures.Executor`, `distributed.Client`,\
69
-               or `ipyparallel.Client`, optional
75
+               `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional
70 76
         The executor in which to evaluate the function to be learned.
71 77
         If not provided, a new `~concurrent.futures.ProcessPoolExecutor`
72 78
         is used on Unix systems while on Windows a `distributed.Client`
... ...
@@ -281,7 +287,7 @@ class BlockingRunner(BaseRunner):
281 287
         the learner as its sole argument, and return True when we should
282 288
         stop requesting more points.
283 289
     executor : `concurrent.futures.Executor`, `distributed.Client`,\
284
-               or `ipyparallel.Client`, optional
290
+               `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional
285 291
         The executor in which to evaluate the function to be learned.
286 292
         If not provided, a new `~concurrent.futures.ProcessPoolExecutor`
287 293
         is used on Unix systems while on Windows a `distributed.Client`
... ...
@@ -386,7 +392,7 @@ class AsyncRunner(BaseRunner):
386 392
         stop requesting more points. If not provided, the runner will run
387 393
         forever, or until ``self.task.cancel()`` is called.
388 394
     executor : `concurrent.futures.Executor`, `distributed.Client`,\
389
-               or `ipyparallel.Client`, optional
395
+               `mpi4py.futures.MPIPoolExecutor`, or `ipyparallel.Client`, optional
390 396
         The executor in which to evaluate the function to be learned.
391 397
         If not provided, a new `~concurrent.futures.ProcessPoolExecutor`
392 398
         is used on Unix systems while on Windows a `distributed.Client`
... ...
@@ -693,6 +699,9 @@ def _get_ncores(ex):
693 699
         return 1
694 700
     elif with_distributed and isinstance(ex, distributed.cfexecutor.ClientExecutor):
695 701
         return sum(n for n in ex._client.ncores().values())
702
+    elif with_mpi4py and isinstance(ex, mpi4py.futures.MPIPoolExecutor):
703
+        ex.bootup() # wait until all workers are up and running
704
+        return ex._pool.size  # not public API!
696 705
     else:
697 706
         raise TypeError('Cannot get number of cores for {}'
698 707
                         .format(ex.__class__))