Browse code

add scikit-optimize learner

Joseph Weston authored on 23/05/2018 17:31:15
Showing 3 changed files
... ...
@@ -7,6 +7,12 @@ from . import runner
7 7
 
8 8
 from .learner import (Learner1D, Learner2D, AverageLearner,
9 9
                       BalancingLearner, DataSaver, IntegratorLearner)
10
+try:
11
+    # Only available if 'scikit-optimize' is installed
12
+    from .learner import SKOptLearner
13
+except ImportError:
14
+    pass
15
+
10 16
 from .runner import Runner, BlockingRunner
11 17
 from . import version
12 18
 
... ...
@@ -6,3 +6,9 @@ from .learner1D import Learner1D
6 6
 from .learner2D import Learner2D
7 7
 from .integrator_learner import IntegratorLearner
8 8
 from .data_saver import DataSaver
9
+
10
+try:
11
+    # Only available if 'scikit-optimize' is installed
12
+    from .skopt_learner import SKOptLearner
13
+except ImportError:
14
+    pass
9 15
new file mode 100644
... ...
@@ -0,0 +1,92 @@
1
+# -*- coding: utf-8 -*-
2
+from copy import deepcopy
3
+import heapq
4
+import itertools
5
+import math
6
+
7
+import numpy as np
8
+import sortedcontainers
9
+
10
+from ..notebook_integration import ensure_holoviews
11
+from .base_learner import BaseLearner
12
+
13
+from skopt import Optimizer
14
+
15
+
16
+class SKOptLearner(Optimizer, BaseLearner):
17
+    """Learn a function minimum using 'skopt.Optimizer'.
18
+
19
+    This is an 'Optimizer' from 'scikit-optimize',
20
+    with the necessary methods added to make it conform
21
+    to the 'adaptive' learner interface.
22
+
23
+    Parameters
24
+    ----------
25
+    function : callable
26
+        The function to learn.
27
+    **kwargs :
28
+        Arguments to pass to 'skopt.Optimizer'.
29
+    """
30
+
31
+    def __init__(self, function, **kwargs):
32
+        self.function = function
33
+        super().__init__(**kwargs)
34
+
35
+    def add_point(self, x, y):
36
+        if y is not None:
37
+            # 'skopt.Optimizer' takes care of points we
38
+            # have not got results for.
39
+            self.tell([x], y)
40
+
41
+    def remove_unfinished(self):
42
+        pass
43
+
44
+    def loss(self, real=True):
45
+        if not self.models:
46
+            return np.inf
47
+        else:
48
+            model = self.models[-1]
49
+            # Return the in-sample error (i.e. test the model
50
+            # with the training data). This is not the best
51
+            # estimator of loss, but it is the cheapest.
52
+            return 1 / model.score(self.Xi, self.yi)
53
+
54
+    def choose_points(self, n, add_data=True):
55
+        points = self.ask(n)
56
+        if self.space.n_dims > 1:
57
+            return points, [np.inf] * n
58
+        else:
59
+            return [p[0] for p in points], [np.inf] * n
60
+
61
+    @property
62
+    def npoints(self):
63
+        return len(self.Xi)
64
+
65
+    def plot(self):
66
+        hv = ensure_holoviews()
67
+        if self.space.n_dims > 1:
68
+            raise ValueError('Can only plot 1D functions')
69
+        bounds = self.space.bounds[0]
70
+        if not self.Xi:
71
+            p = hv.Scatter([]) * hv.Area([])
72
+        else:
73
+            scatter = hv.Scatter(([p[0] for p in self.Xi], self.yi))
74
+            if self.models:
75
+                # Plot 95% confidence interval as colored area around points
76
+                model = self.models[-1]
77
+                xs = np.linspace(*bounds, 201)
78
+                y_pred, sigma = model.predict(np.atleast_2d(xs).transpose(),
79
+                                              return_std=True)
80
+                area = hv.Area(
81
+                    (xs, y_pred - 1.96 * sigma, y_pred + 1.96 * sigma),
82
+                    vdims=['y', 'y2'],
83
+                ).opts(style=dict(alpha=0.5, line_alpha=0))
84
+            else:
85
+                area = hv.Area([])
86
+            p = scatter * area
87
+
88
+        # Plot with 5% empty margins such that the boundary points are visible
89
+        margin = 0.05 * (bounds[1] - bounds[0])
90
+        plot_bounds = (bounds[0] - margin, bounds[1] + margin)
91
+
92
+        return p.redim(x=dict(range=plot_bounds))