support scaling of the cluster inside the runner
See merge request qt/adaptive!63
... | ... |
@@ -93,7 +93,8 @@ class BaseRunner: |
93 | 93 |
self.executor = _ensure_executor(executor) |
94 | 94 |
self.goal = goal |
95 | 95 |
|
96 |
- self.ntasks = ntasks or _get_ncores(self.executor) |
|
96 |
+ self._max_tasks = ntasks |
|
97 |
+ |
|
97 | 98 |
# if we instantiate our own executor, then we are also responsible |
98 | 99 |
# for calling 'shutdown' |
99 | 100 |
self.shutdown_executor = shutdown_executor or (executor is None) |
... | ... |
@@ -102,6 +103,9 @@ class BaseRunner: |
102 | 103 |
self.log = [] if log else None |
103 | 104 |
self.task = None |
104 | 105 |
|
106 |
+ def max_tasks(self): |
|
107 |
+ return self._max_tasks or _get_ncores(self.executor) |
|
108 |
+ |
|
105 | 109 |
|
106 | 110 |
class BlockingRunner(BaseRunner): |
107 | 111 |
"""Run a learner synchronously in an executor. |
... | ... |
@@ -154,7 +158,6 @@ class BlockingRunner(BaseRunner): |
154 | 158 |
def _run(self): |
155 | 159 |
first_completed = concurrent.FIRST_COMPLETED |
156 | 160 |
xs = dict() |
157 |
- done = [None] * self.ntasks |
|
158 | 161 |
do_log = self.log is not None |
159 | 162 |
|
160 | 163 |
if len(done) == 0: |
... | ... |
@@ -163,11 +166,15 @@ class BlockingRunner(BaseRunner): |
163 | 166 |
try: |
164 | 167 |
while not self.goal(self.learner): |
165 | 168 |
# Launch tasks to replace the ones that completed |
166 |
- # on the last iteration. |
|
169 |
+ # on the last iteration, making sure to fill workers |
|
170 |
+ # that have started since the last iteration. |
|
171 |
+ n_new_tasks = max(0, self.max_tasks() - len(xs)) |
|
172 |
+ |
|
167 | 173 |
if do_log: |
168 |
- self.log.append(('ask', len(done))) |
|
174 |
+ self.log.append(('ask', n_new_tasks)) |
|
175 |
+ |
|
176 |
+ points, _ = self.learner.ask(n_new_tasks) |
|
169 | 177 |
|
170 |
- points, _ = self.learner.ask(len(done)) |
|
171 | 178 |
for x in points: |
172 | 179 |
xs[self._submit(x)] = x |
173 | 180 |
|
... | ... |
@@ -357,21 +364,23 @@ class AsyncRunner(BaseRunner): |
357 | 364 |
|
358 | 365 |
async def _run(self): |
359 | 366 |
first_completed = asyncio.FIRST_COMPLETED |
360 |
- xs = dict() |
|
361 |
- done = [None] * self.ntasks |
|
367 |
+ xs = dict() # The points we are waiting for |
|
362 | 368 |
do_log = self.log is not None |
363 | 369 |
|
364 |
- if len(done) == 0: |
|
370 |
+ if self.max_tasks() < 1: |
|
365 | 371 |
raise RuntimeError('Executor has no workers') |
366 | 372 |
|
367 | 373 |
try: |
368 | 374 |
while not self.goal(self.learner): |
369 | 375 |
# Launch tasks to replace the ones that completed |
370 |
- # on the last iteration. |
|
376 |
+ # on the last iteration, making sure to fill workers |
|
377 |
+ # that have started since the last iteration. |
|
378 |
+ n_new_tasks = max(0, self.max_tasks() - len(xs)) |
|
379 |
+ |
|
371 | 380 |
if do_log: |
372 |
- self.log.append(('ask', len(done))) |
|
381 |
+ self.log.append(('ask', n_new_tasks)) |
|
373 | 382 |
|
374 |
- points, _ = self.learner.ask(len(done)) |
|
383 |
+ points, _ = self.learner.ask(n_new_tasks) |
|
375 | 384 |
for x in points: |
376 | 385 |
xs[self._submit(x)] = x |
377 | 386 |
|