... | ... |
@@ -14,6 +14,7 @@ from adaptive.learner import ( |
14 | 14 |
Learner2D, |
15 | 15 |
LearnerND, |
16 | 16 |
make_datasaver, |
17 |
+ SequenceLearner, |
|
17 | 18 |
) |
18 | 19 |
from adaptive.notebook_integration import ( |
19 | 20 |
active_plotting_tasks, |
... | ... |
@@ -36,6 +37,7 @@ __all__ = [ |
36 | 37 |
"Learner2D", |
37 | 38 |
"LearnerND", |
38 | 39 |
"make_datasaver", |
40 |
+ "SequenceLearner", |
|
39 | 41 |
"active_plotting_tasks", |
40 | 42 |
"live_plot", |
41 | 43 |
"notebook_extension", |
... | ... |
@@ -10,6 +10,7 @@ from adaptive.learner.integrator_learner import IntegratorLearner |
10 | 10 |
from adaptive.learner.learner1D import Learner1D |
11 | 11 |
from adaptive.learner.learner2D import Learner2D |
12 | 12 |
from adaptive.learner.learnerND import LearnerND |
13 |
+from adaptive.learner.sequence_learner import SequenceLearner |
|
13 | 14 |
|
14 | 15 |
__all__ = [ |
15 | 16 |
"AverageLearner", |
... | ... |
@@ -21,6 +22,7 @@ __all__ = [ |
21 | 22 |
"Learner1D", |
22 | 23 |
"Learner2D", |
23 | 24 |
"LearnerND", |
25 |
+ "SequenceLearner", |
|
24 | 26 |
] |
25 | 27 |
|
26 | 28 |
with suppress(ImportError): |
27 | 29 |
new file mode 100644 |
... | ... |
@@ -0,0 +1,82 @@ |
1 |
+from copy import copy |
|
2 |
+import sys |
|
3 |
+ |
|
4 |
+from adaptive.learner.base_learner import BaseLearner |
|
5 |
+ |
|
6 |
+inf = sys.float_info.max |
|
7 |
+ |
|
8 |
+ |
|
9 |
def ensure_hashable(x):
    """Return ``x`` unchanged if it is hashable, otherwise as a tuple.

    Lets unhashable sequence elements (e.g. lists or numpy rows) be used
    as set members and dictionary keys.
    """
    try:
        hash(x)
    except TypeError:
        # Not hashable (typically a list or array row): a tuple of the
        # same elements is, and compares equal across calls.
        return tuple(x)
    return x
|
15 |
+ |
|
16 |
+ |
|
17 |
class SequenceLearner(BaseLearner):
    """A learner that evaluates ``function`` at each point of ``sequence``.

    Parameters
    ----------
    function : callable
        Function to evaluate; called with one element of ``sequence``.
    sequence : sequence
        The points at which to evaluate ``function``.  Unhashable elements
        are converted to tuples internally (see ``ensure_hashable``), but
        ``result`` returns values in the original order of ``sequence``.
    """

    def __init__(self, function, sequence):
        self.function = function
        # Points not yet handed out by 'ask'; stored in hashable form so
        # they can live in a set and be used as 'self.data' keys.
        self._to_do_seq = {ensure_hashable(x) for x in sequence}
        self._npoints = len(sequence)
        self.sequence = copy(sequence)
        self.data = {}  # hashable point -> function value
        self.pending_points = set()

    def ask(self, n, tell_pending=True):
        """Return at most ``n`` to-do points and their loss improvements.

        BUG FIX: the original condition ``if i > n`` broke one iteration
        too late, handing out ``n + 1`` points per call.
        """
        points = []
        loss_improvements = []
        for i, point in enumerate(self._to_do_seq):
            if i >= n:
                break
            points.append(point)
            # Every remaining point is equally (maximally) important.
            loss_improvements.append(inf / self._npoints)

        if tell_pending:
            for p in points:
                self.tell_pending(p)

        return points, loss_improvements

    def _get_data(self):
        return self.data

    def _set_data(self, data):
        if data:
            # 'tell_many' is provided by BaseLearner.
            self.tell_many(*zip(*data.items()))

    def loss(self, real=True):
        """Return 0 when finished, otherwise a large loss shrinking with data."""
        if not (self._to_do_seq or self.pending_points):
            return 0
        npoints = self.npoints + (0 if real else len(self.pending_points))
        if npoints == 0:
            # BUG FIX: before any data arrives 'inf / npoints' raised
            # ZeroDivisionError; an unstarted learner has maximal loss.
            return inf
        return inf / npoints

    def remove_unfinished(self):
        """Move all pending points back into the to-do set."""
        self._to_do_seq.update(self.pending_points)
        self.pending_points = set()

    def tell(self, point, value):
        """Record ``value`` for ``point`` and retire it from the queues."""
        self.data[point] = value
        self.pending_points.discard(point)
        self._to_do_seq.discard(point)

    def tell_pending(self, point):
        """Mark ``point`` as being evaluated so 'ask' won't return it again."""
        self.pending_points.add(point)
        self._to_do_seq.discard(point)

    def done(self):
        """Return True when every point of the sequence has a value."""
        return not self._to_do_seq and not self.pending_points

    def result(self):
        """Get back the data in the same order as ``sequence``."""
        if not self.done():
            raise Exception("Learner is not yet complete.")
        return [self.data[ensure_hashable(x)] for x in self.sequence]

    @property
    def npoints(self):
        # Number of points with a known function value.
        return len(self.data)
... | ... |
@@ -24,6 +24,7 @@ from adaptive.learner import ( |
24 | 24 |
Learner1D, |
25 | 25 |
Learner2D, |
26 | 26 |
LearnerND, |
27 |
+ SequenceLearner, |
|
27 | 28 |
) |
28 | 29 |
from adaptive.runner import simple |
29 | 30 |
|
... | ... |
@@ -116,6 +117,7 @@ def quadratic(x, m: uniform(0, 10), b: uniform(0, 1)): |
116 | 117 |
|
117 | 118 |
|
118 | 119 |
@learn_with(Learner1D, bounds=(-1, 1)) |
120 |
+@learn_with(SequenceLearner, sequence=np.linspace(-1, 1, 201)) |
|
119 | 121 |
def linear_with_peak(x, d: uniform(-1, 1)): |
120 | 122 |
a = 0.01 |
121 | 123 |
return x + a ** 2 / (a ** 2 + (x - d) ** 2) |
... | ... |
@@ -123,6 +125,7 @@ def linear_with_peak(x, d: uniform(-1, 1)): |
123 | 125 |
|
124 | 126 |
@learn_with(LearnerND, bounds=((-1, 1), (-1, 1))) |
125 | 127 |
@learn_with(Learner2D, bounds=((-1, 1), (-1, 1))) |
128 |
+@learn_with(SequenceLearner, sequence=np.random.rand(1000, 2)) |
|
126 | 129 |
def ring_of_fire(xy, d: uniform(0.2, 1)): |
127 | 130 |
a = 0.2 |
128 | 131 |
x, y = xy |
... | ... |
@@ -130,12 +133,14 @@ def ring_of_fire(xy, d: uniform(0.2, 1)): |
130 | 133 |
|
131 | 134 |
|
132 | 135 |
@learn_with(LearnerND, bounds=((-1, 1), (-1, 1), (-1, 1))) |
136 |
+@learn_with(SequenceLearner, sequence=np.random.rand(1000, 3)) |
|
133 | 137 |
def sphere_of_fire(xyz, d: uniform(0.2, 1)): |
134 | 138 |
a = 0.2 |
135 | 139 |
x, y, z = xyz |
136 | 140 |
return x + math.exp(-(x ** 2 + y ** 2 + z ** 2 - d ** 2) ** 2 / a ** 4) + z ** 2 |
137 | 141 |
|
138 | 142 |
|
143 |
+@learn_with(SequenceLearner, sequence=range(1000)) |
|
139 | 144 |
@learn_with(AverageLearner, rtol=1) |
140 | 145 |
def gaussian(n): |
141 | 146 |
return random.gauss(0, 1) |
... | ... |
@@ -247,7 +252,7 @@ def test_learner_accepts_lists(learner_type, bounds): |
247 | 252 |
simple(learner, goal=lambda l: l.npoints > 10) |
248 | 253 |
|
249 | 254 |
|
250 |
-@run_with(Learner1D, Learner2D, LearnerND) |
|
255 |
+@run_with(Learner1D, Learner2D, LearnerND, SequenceLearner) |
|
251 | 256 |
def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs): |
252 | 257 |
"""Adding already existing data is an idempotent operation. |
253 | 258 |
|
... | ... |
@@ -283,7 +288,7 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs): |
283 | 288 |
|
284 | 289 |
# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55) |
285 | 290 |
# but we xfail it now, as Learner2D will be deprecated anyway |
286 |
-@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner) |
|
291 |
+@run_with(Learner1D, xfail(Learner2D), LearnerND, AverageLearner, SequenceLearner) |
|
287 | 292 |
def test_adding_non_chosen_data(learner_type, f, learner_kwargs): |
288 | 293 |
"""Adding data for a point that was not returned by 'ask'.""" |
289 | 294 |
# XXX: learner, control and bounds are not defined |
... | ... |
@@ -429,7 +434,12 @@ def test_learner_performance_is_invariant_under_scaling( |
429 | 434 |
|
430 | 435 |
|
431 | 436 |
@run_with( |
432 |
- Learner1D, Learner2D, LearnerND, AverageLearner, with_all_loss_functions=False |
|
437 |
+ Learner1D, |
|
438 |
+ Learner2D, |
|
439 |
+ LearnerND, |
|
440 |
+ AverageLearner, |
|
441 |
+ SequenceLearner, |
|
442 |
+ with_all_loss_functions=False, |
|
433 | 443 |
) |
434 | 444 |
def test_balancing_learner(learner_type, f, learner_kwargs): |
435 | 445 |
"""Test if the BalancingLearner works with the different types of learners.""" |
... | ... |
@@ -474,6 +484,7 @@ def test_balancing_learner(learner_type, f, learner_kwargs): |
474 | 484 |
AverageLearner, |
475 | 485 |
maybe_skip(SKOptLearner), |
476 | 486 |
IntegratorLearner, |
487 |
+ SequenceLearner, |
|
477 | 488 |
with_all_loss_functions=False, |
478 | 489 |
) |
479 | 490 |
def test_saving(learner_type, f, learner_kwargs): |
... | ... |
@@ -504,6 +515,7 @@ def test_saving(learner_type, f, learner_kwargs): |
504 | 515 |
AverageLearner, |
505 | 516 |
maybe_skip(SKOptLearner), |
506 | 517 |
IntegratorLearner, |
518 |
+ SequenceLearner, |
|
507 | 519 |
with_all_loss_functions=False, |
508 | 520 |
) |
509 | 521 |
def test_saving_of_balancing_learner(learner_type, f, learner_kwargs): |
... | ... |
@@ -541,6 +553,7 @@ def test_saving_of_balancing_learner(learner_type, f, learner_kwargs): |
541 | 553 |
AverageLearner, |
542 | 554 |
maybe_skip(SKOptLearner), |
543 | 555 |
IntegratorLearner, |
556 |
+ SequenceLearner, |
|
544 | 557 |
with_all_loss_functions=False, |
545 | 558 |
) |
546 | 559 |
def test_saving_with_datasaver(learner_type, f, learner_kwargs): |