Browse code

derive new tasks to launch from number of cores and outstanding tasks

Joseph Weston authored on 28/05/2018 16:10:04
Showing 1 changed files
... ...
@@ -92,8 +92,7 @@ class BaseRunner:
92 92
         self.executor = _ensure_executor(executor)
93 93
         self.goal = goal
94 94
 
95
-        self._ntasks_is_user_specified = bool(ntasks)
96
-        self.ntasks = ntasks or _get_ncores(self.executor)
95
+        self._max_tasks = ntasks
97 96
             
98 97
         # if we instantiate our own executor, then we are also responsible
99 98
         # for calling 'shutdown'
... ...
@@ -103,17 +102,8 @@ class BaseRunner:
103 102
         self.log = [] if log else None
104 103
         self.task = None
105 104
 
106
-    def _get_n(self, done):
107
-        n = len(done)
108
-        if not self._ntasks_is_user_specified:
109
-            ntasks_new = _get_ncores(self.executor)
110
-            dn = ntasks_new - self.ntasks
111
-            if dn < 0:
112
-                # Cores have died or stopped
113
-                n = 0
114
-            n += dn
115
-            self.ntasks = ntasks_new
116
-        return n
105
+    def max_tasks(self):
106
+        return self._max_tasks or _get_ncores(self.executor)
117 107
 
118 108
 
119 109
 class BlockingRunner(BaseRunner):
... ...
@@ -167,7 +157,6 @@ class BlockingRunner(BaseRunner):
167 157
     def _run(self):
168 158
         first_completed = concurrent.FIRST_COMPLETED
169 159
         xs = dict()
170
-        done = [None] * self.ntasks
171 160
         do_log = self.log is not None
172 161
 
173 162
         if len(done) == 0:
... ...
@@ -176,14 +165,14 @@ class BlockingRunner(BaseRunner):
176 165
         try:
177 166
             while not self.goal(self.learner):
178 167
                 # Launch tasks to replace the ones that completed
179
-                # on the last iteration and check if new workers
180
-                # have started since the last iteration
181
-                n = self._get_n(done)
168
+                # on the last iteration, making sure to fill workers
169
+                # that have started since the last iteration.
170
+                n_new_tasks = max(0, self.max_tasks() - len(xs))
182 171
 
183 172
                 if do_log:
184
-                    self.log.append(('ask', n))
173
+                    self.log.append(('ask', n_new_tasks))
185 174
 
186
-                points, _ = self.learner.ask(n)
175
+                points, _ = self.learner.ask(n_new_tasks)
187 176
 
188 177
                 for x in points:
189 178
                     xs[self._submit(x)] = x
... ...
@@ -373,24 +362,23 @@ class AsyncRunner(BaseRunner):
373 362
 
374 363
     async def _run(self):
375 364
         first_completed = asyncio.FIRST_COMPLETED
376
-        xs = dict()
377
-        done = [None] * self.ntasks
365
+        xs = dict()  # The points we are waiting for
378 366
         do_log = self.log is not None
379 367
 
380
-        if len(done) == 0:
368
+        if self.max_tasks() < 1:
381 369
             raise RuntimeError('Executor has no workers')
382 370
 
383 371
         try:
384 372
             while not self.goal(self.learner):
385 373
                 # Launch tasks to replace the ones that completed
386
-                # on the last iteration and check if new workers
387
-                # have started since the last iteration
388
-                n = self._get_n(done)
374
+                # on the last iteration, making sure to fill workers
375
+                # that have started since the last iteration.
376
+                n_new_tasks = max(0, self.max_tasks() - len(xs))
389 377
 
390 378
                 if do_log:
391
-                    self.log.append(('ask', n))
379
+                    self.log.append(('ask', n_new_tasks))
392 380
 
393
-                points, _ = self.learner.ask(n)
381
+                points, _ = self.learner.ask(n_new_tasks)
394 382
                 for x in points:
395 383
                     xs[self._submit(x)] = x
396 384