Browse code

comply with adaptive structure, fixed 1D test, added 4D test

caenrigen authored on 05/12/2019 14:47:29
Showing 2 changed files
... ...
@@ -26,24 +26,24 @@ class SKOptLearner(Optimizer, BaseLearner):
26 26
 
27 27
     def __init__(self, function, **kwargs):
28 28
         self.function = function
29
-        self.pending_points = list()
29
+        self.pending_points = set()
30 30
         self.data = OrderedDict()
31 31
         super().__init__(**kwargs)
32 32
 
33 33
     def tell(self, x, y, fit=True):
34
-        if x in self.pending_points:
35
-            self.pending_points.remove(x)
36 34
         if hasattr(x, '__iter__'):
35
+            self.pending_points.discard(tuple(x))
37 36
             self.data[tuple(x)] = y
38 37
             super().tell(x, y, fit)
39 38
         else:
39
+            self.pending_points.discard(x)
40 40
             self.data[x] = y
41 41
             super().tell([x], y, fit)
42 42
 
43 43
     def tell_pending(self, x):
44 44
         # 'skopt.Optimizer' takes care of points we
45 45
         # have not got results for.
46
-        self.pending_points.append(x)
46
+        self.pending_points.add(tuple(x))
47 47
 
48 48
     def remove_unfinished(self):
49 49
         pass
... ...
@@ -25,3 +25,23 @@ def test_skopt_learner_runs():
25 25
     for _ in range(11):
26 26
         (x,), _ = learner.ask(1)
27 27
         learner.tell(x, learner.function(x))
28
+
29
+
30
+@pytest.mark.skipif(not with_scikit_optimize, reason="scikit-optimize is not installed")
31
+def test_skopt_learner_4D_runs():
32
+    """The SKOptLearner provides very few guarantees about its
33
+       behaviour, so we only test the most basic usage
34
+       In this case we test also for 2D domain
35
+    """
36
+
37
+    def g(x, noise_level=0.1):
38
+        return np.sin(5 * (x[0] + x[1] + x[2] + x[3])) * (
39
+            1 - np.tanh(x[0] ** 2 + x[1] ** 2 + x[2] ** 2 + x[3] ** 2)
40
+        ) + np.random.randn() * noise_level
41
+
42
+    learner = SKOptLearner(g, dimensions=[(-2.0, 2.0), (-2.0, 2.0),
43
+                                            (-2.0, 2.0), (-2.0, 2.0)])
44
+
45
+    for _ in range(11):
46
+        (x,), _ = learner.ask(1)
47
+        learner.tell(x, learner.function(x))