... | ... |
@@ -26,24 +26,24 @@ class SKOptLearner(Optimizer, BaseLearner): |
26 | 26 |
|
27 | 27 |
def __init__(self, function, **kwargs): |
28 | 28 |
self.function = function |
29 |
- self.pending_points = list() |
|
29 |
+ self.pending_points = set() |
|
30 | 30 |
self.data = OrderedDict() |
31 | 31 |
super().__init__(**kwargs) |
32 | 32 |
|
33 | 33 |
def tell(self, x, y, fit=True): |
34 |
- if x in self.pending_points: |
|
35 |
- self.pending_points.remove(x) |
|
36 | 34 |
if hasattr(x, '__iter__'): |
35 |
+ self.pending_points.discard(tuple(x)) |
|
37 | 36 |
self.data[tuple(x)] = y |
38 | 37 |
super().tell(x, y, fit) |
39 | 38 |
else: |
39 |
+ self.pending_points.discard(x) |
|
40 | 40 |
self.data[x] = y |
41 | 41 |
super().tell([x], y, fit) |
42 | 42 |
|
43 | 43 |
def tell_pending(self, x): |
44 | 44 |
# 'skopt.Optimizer' takes care of points we |
45 | 45 |
# have not got results for. |
46 |
- self.pending_points.append(x) |
|
46 |
+ self.pending_points.add(tuple(x)) |
|
47 | 47 |
|
48 | 48 |
def remove_unfinished(self): |
49 | 49 |
pass |
... | ... |
@@ -25,3 +25,23 @@ def test_skopt_learner_runs(): |
25 | 25 |
for _ in range(11): |
26 | 26 |
(x,), _ = learner.ask(1) |
27 | 27 |
learner.tell(x, learner.function(x)) |
28 |
+ |
|
29 |
+ |
|
30 |
+@pytest.mark.skipif(not with_scikit_optimize, reason="scikit-optimize is not installed") |
|
31 |
+def test_skopt_learner_4D_runs(): |
|
32 |
+ """The SKOptLearner provides very few guarantees about its |
|
33 |
+ behaviour, so we only test the most basic usage |
|
34 |
+ In this case we test also for 2D domain |
|
35 |
+ """ |
|
36 |
+ |
|
37 |
+ def g(x, noise_level=0.1): |
|
38 |
+ return np.sin(5 * (x[0] + x[1] + x[2] + x[3])) * ( |
|
39 |
+ 1 - np.tanh(x[0] ** 2 + x[1] ** 2 + x[2] ** 2 + x[3] ** 2) |
|
40 |
+ ) + np.random.randn() * noise_level |
|
41 |
+ |
|
42 |
+ learner = SKOptLearner(g, dimensions=[(-2.0, 2.0), (-2.0, 2.0), |
|
43 |
+ (-2.0, 2.0), (-2.0, 2.0)]) |
|
44 |
+ |
|
45 |
+ for _ in range(11): |
|
46 |
+ (x,), _ = learner.ask(1) |
|
47 |
+ learner.tell(x, learner.function(x)) |