Browse code

added compatibility with n-dimensional function domain

caenrigen authored on 04/12/2019 20:04:01
Showing 1 changed files
... ...
@@ -2,6 +2,7 @@
2 2
 
3 3
 import numpy as np
4 4
 from skopt import Optimizer
5
+from collections import OrderedDict
5 6
 
6 7
 from adaptive.learner.base_learner import BaseLearner
7 8
 from adaptive.notebook_integration import ensure_holoviews
... ...
@@ -25,19 +26,24 @@ class SKOptLearner(Optimizer, BaseLearner):
25 26
 
26 27
     def __init__(self, function, **kwargs):
27 28
         self.function = function
28
-        self.pending_points = set()
29
-        self.data = {}
29
+        self.pending_points = list()
30
+        self.data = OrderedDict()
30 31
         super().__init__(**kwargs)
31 32
 
32 33
     def tell(self, x, y, fit=True):
33
-        self.pending_points.discard(x)
34
-        self.data[x] = y
35
-        super().tell([x], y, fit)
34
+        if x in self.pending_points:
35
+            self.pending_points.remove(x)
36
+        if hasattr(x, '__iter__'):
37
+            self.data[tuple(x)] = y
38
+            super().tell(x, y, fit)
39
+        else:
40
+            self.data[x] = y
41
+            super().tell([x], y, fit)
36 42
 
37 43
     def tell_pending(self, x):
38 44
         # 'skopt.Optimizer' takes care of points we
39 45
         # have not got results for.
40
-        self.pending_points.add(x)
46
+        self.pending_points.append(x)
41 47
 
42 48
     def remove_unfinished(self):
43 49
         pass