... | ... |
@@ -2,6 +2,7 @@ |
2 | 2 |
|
3 | 3 |
import numpy as np |
4 | 4 |
from skopt import Optimizer |
5 |
+from collections import OrderedDict |
|
5 | 6 |
|
6 | 7 |
from adaptive.learner.base_learner import BaseLearner |
7 | 8 |
from adaptive.notebook_integration import ensure_holoviews |
... | ... |
@@ -25,19 +26,24 @@ class SKOptLearner(Optimizer, BaseLearner): |
25 | 26 |
|
26 | 27 |
def __init__(self, function, **kwargs): |
27 | 28 |
self.function = function |
28 |
- self.pending_points = set() |
|
29 |
- self.data = {} |
|
29 |
+ self.pending_points = list() |
|
30 |
+ self.data = OrderedDict() |
|
30 | 31 |
super().__init__(**kwargs) |
31 | 32 |
|
32 | 33 |
def tell(self, x, y, fit=True): |
33 |
- self.pending_points.discard(x) |
|
34 |
- self.data[x] = y |
|
35 |
- super().tell([x], y, fit) |
|
34 |
+ if x in self.pending_points: |
|
35 |
+ self.pending_points.remove(x) |
|
36 |
+ if hasattr(x, '__iter__'): |
|
37 |
+ self.data[tuple(x)] = y |
|
38 |
+ super().tell(x, y, fit) |
|
39 |
+ else: |
|
40 |
+ self.data[x] = y |
|
41 |
+ super().tell([x], y, fit) |
|
36 | 42 |
|
37 | 43 |
def tell_pending(self, x): |
38 | 44 |
# 'skopt.Optimizer' takes care of points we |
39 | 45 |
# have not got results for. |
40 |
- self.pending_points.add(x) |
|
46 |
+ self.pending_points.append(x) |
|
41 | 47 |
|
42 | 48 |
def remove_unfinished(self): |
43 | 49 |
pass |