Browse code

add SequenceLearner

Bas Nijholt authored on 09/05/2019 00:59:12
Showing 4 changed files
... ...
@@ -14,6 +14,7 @@ from adaptive.learner import (
14 14
     Learner2D,
15 15
     LearnerND,
16 16
     make_datasaver,
17
+    SequenceLearner,
17 18
 )
18 19
 from adaptive.notebook_integration import (
19 20
     active_plotting_tasks,
... ...
@@ -36,6 +37,7 @@ __all__ = [
36 37
     "Learner2D",
37 38
     "LearnerND",
38 39
     "make_datasaver",
40
+    "SequenceLearner",
39 41
     "active_plotting_tasks",
40 42
     "live_plot",
41 43
     "notebook_extension",
... ...
@@ -10,6 +10,7 @@ from adaptive.learner.integrator_learner import IntegratorLearner
10 10
 from adaptive.learner.learner1D import Learner1D
11 11
 from adaptive.learner.learner2D import Learner2D
12 12
 from adaptive.learner.learnerND import LearnerND
13
+from adaptive.learner.sequence_learner import SequenceLearner
13 14
 
14 15
 __all__ = [
15 16
     "AverageLearner",
... ...
@@ -21,6 +22,7 @@ __all__ = [
21 22
     "Learner1D",
22 23
     "Learner2D",
23 24
     "LearnerND",
25
+    "SequenceLearner",
24 26
 ]
25 27
 
26 28
 with suppress(ImportError):
27 29
new file mode 100644
... ...
@@ -0,0 +1,82 @@
1
+from copy import copy
2
+import sys
3
+
4
+from adaptive.learner.base_learner import BaseLearner
5
+
6
+inf = sys.float_info.max
7
+
8
+
9
+def ensure_hashable(x):
10
+    try:
11
+        hash(x)
12
+        return x
13
+    except TypeError:
14
+        return tuple(x)
15
+
16
+
17
+class SequenceLearner(BaseLearner):
18
+    def __init__(self, function, sequence):
19
+        self.function = function
20
+        self._to_do_seq = {ensure_hashable(x) for x in sequence}
21
+        self._npoints = len(sequence)
22
+        self.sequence = copy(sequence)
23
+        self.data = {}
24
+        self.pending_points = set()
25
+
26
+    def ask(self, n, tell_pending=True):
27
+        points = []
28
+        loss_improvements = []
29
+        i = 0
30
+        for point in self._to_do_seq:
31
+            if i > n:
32
+                break
33
+            points.append(point)
34
+            loss_improvements.append(inf / self._npoints)
35
+            i += 1
36
+
37
+        if tell_pending:
38
+            for p in points:
39
+                self.tell_pending(p)
40
+
41
+        return points, loss_improvements
42
+
43
+    def _get_data(self):
44
+        return self.data
45
+
46
+    def _set_data(self, data):
47
+        if data:
48
+            self.tell_many(*zip(*data.items()))
49
+
50
+    def loss(self, real=True):
51
+        if not (self._to_do_seq or self.pending_points):
52
+            return 0
53
+        else:
54
+            npoints = self.npoints + (0 if real else len(self.pending_points))
55
+            return inf / npoints
56
+
57
+    def remove_unfinished(self):
58
+        for p in self.pending_points:
59
+            self._to_do_seq.add(p)
60
+        self.pending_points = set()
61
+
62
+    def tell(self, point, value):
63
+        self.data[point] = value
64
+        self.pending_points.discard(point)
65
+        self._to_do_seq.discard(point)
66
+
67
+    def tell_pending(self, point):
68
+        self.pending_points.add(point)
69
+        self._to_do_seq.discard(point)
70
+
71
+    def done(self):
72
+        return not self._to_do_seq and not self.pending_points
73
+
74
+    def result(self):
75
+        """Get back the data in the same order as ``sequence``."""
76
+        if not self.done():
77
+            raise Exception("Learner is not yet complete.")
78
+        return [self.data[ensure_hashable(x)] for x in self.sequence]
79
+
80
+    @property
81
+    def npoints(self):
82
+        return len(self.data)
... ...
@@ -24,6 +24,7 @@ from adaptive.learner import (
24 24
     Learner1D,
25 25
     Learner2D,
26 26
     LearnerND,
27
+    SequenceLearner,
27 28
 )
28 29
 from adaptive.runner import simple
29 30
 
... ...
@@ -116,6 +117,7 @@ def quadratic(x, m: uniform(0, 10), b: uniform(0, 1)):
116 117
 
117 118
 
118 119
 @learn_with(Learner1D, bounds=(-1, 1))
120
+@learn_with(SequenceLearner, sequence=np.linspace(-1, 1, 201))
119 121
 def linear_with_peak(x, d: uniform(-1, 1)):
120 122
     a = 0.01
121 123
     return x + a ** 2 / (a ** 2 + (x - d) ** 2)
... ...
@@ -123,6 +125,7 @@ def linear_with_peak(x, d: uniform(-1, 1)):
123 125
 
124 126
 @learn_with(LearnerND, bounds=((-1, 1), (-1, 1)))
125 127
 @learn_with(Learner2D, bounds=((-1, 1), (-1, 1)))
128
+@learn_with(SequenceLearner, sequence=np.random.rand(1000, 2))
126 129
 def ring_of_fire(xy, d: uniform(0.2, 1)):
127 130
     a = 0.2
128 131
     x, y = xy
... ...
@@ -130,12 +133,14 @@ def ring_of_fire(xy, d: uniform(0.2, 1)):
130 133
 
131 134
 
132 135
 @learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1)))
136
+@learn_with(SequenceLearner, sequence=np.random.rand(1000, 3))
133 137
 def sphere_of_fire(xyz, d: uniform(0.2, 1)):
134 138
     a = 0.2
135 139
     x, y, z = xyz
136 140
     return x + math.exp(-(x ** 2 + y ** 2 + z ** 2 - d ** 2) ** 2 / a ** 4) + z ** 2
137 141
 
138 142
 
143
+@learn_with(SequenceLearner, sequence=range(1000))
139 144
 @learn_with(AverageLearner, rtol=1)
140 145
 def gaussian(n):
141 146
     return random.gauss(0, 1)
... ...
@@ -247,7 +252,7 @@ def test_learner_accepts_lists(learner_type, bounds):
247 252
     simple(learner, goal=lambda l: l.npoints > 10)
248 253
 
249 254
 
250
-@run_with(Learner1D, Learner2D, LearnerND)
255
+@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner)
251 256
 def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
252 257
     """Adding already existing data is an idempotent operation.
253 258
 
... ...
@@ -283,7 +288,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
283 288
 
284 289
 # XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
285 290
 #      but we xfail it now, as Learner2D will be deprecated anyway
286
-@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner)
291
+@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner, SequenceLearner)
287 292
 def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
288 293
     """Adding data for a point that was not returned by 'ask'."""
289 294
     # XXX: learner, control and bounds are not defined
... ...
@@ -429,7 +434,12 @@ def test_learner_performance_is_invariant_under_scaling(
429 434
 
430 435
 
431 436
 @run_with(
432
-    Learner1D, Learner2D, LearnerND, AverageLearner, with_all_loss_functions=False
437
+    Learner1D,
438
+    Learner2D,
439
+    LearnerND,
440
+    AverageLearner,
441
+    SequenceLearner,
442
+    with_all_loss_functions=False,
433 443
 )
434 444
 def test_balancing_learner(learner_type, f, learner_kwargs):
435 445
     """Test if the BalancingLearner works with the different types of learners."""
... ...
@@ -474,6 +484,7 @@ def test_balancing_learner(learner_type, f, learner_kwargs):
474 484
     AverageLearner,
475 485
     maybe_skip(SKOptLearner),
476 486
     IntegratorLearner,
487
+    SequenceLearner,
477 488
     with_all_loss_functions=False,
478 489
 )
479 490
 def test_saving(learner_type, f, learner_kwargs):
... ...
@@ -504,6 +515,7 @@ def test_saving(learner_type, f, learner_kwargs):
504 515
     AverageLearner,
505 516
     maybe_skip(SKOptLearner),
506 517
     IntegratorLearner,
518
+    SequenceLearner,
507 519
     with_all_loss_functions=False,
508 520
 )
509 521
 def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
... ...
@@ -541,6 +553,7 @@ def test_saving_of_balancing_learner(learner_type, f, learner_kwargs):
541 553
     AverageLearner,
542 554
     maybe_skip(SKOptLearner),
543 555
     IntegratorLearner,
556
+    SequenceLearner,
544 557
     with_all_loss_functions=False,
545 558
 )
546 559
 def test_saving_with_datasaver(learner_type, f, learner_kwargs):