1D: fix the rare case where the right boundary point exists before the left bound
Closes #94
See merge request qt/adaptive!95
... | ... |
@@ -261,17 +261,15 @@ class Learner1D(BaseLearner): |
261 | 261 |
return [], [] |
262 | 262 |
|
263 | 263 |
# If the bounds have not been chosen yet, we choose them first. |
264 |
- points = [b for b in self.bounds if b not in self.data |
|
265 |
- and b not in self.pending_points] |
|
264 |
+ missing_bounds = [b for b in self.bounds if b not in self.data |
|
265 |
+ and b not in self.pending_points] |
|
266 | 266 |
|
267 |
- if len(points) == 2: |
|
268 |
- # First time |
|
267 |
+ if missing_bounds: |
|
269 | 268 |
loss_improvements = [np.inf] * n |
270 |
- points = np.linspace(*self.bounds, n).tolist() |
|
271 |
- elif len(points) == 1: |
|
272 |
- # Second time, if we previously returned just self.bounds[0] |
|
273 |
- loss_improvements = [np.inf] * n |
|
274 |
- points = np.linspace(*self.bounds, n + 1)[1:].tolist() |
|
269 |
+ # XXX: should check if points are present in self.data or self.pending_points |
|
270 |
+ points = np.linspace(*self.bounds, n + 2 - len(missing_bounds)).tolist() |
|
271 |
+ if len(missing_bounds) == 1: |
|
272 |
+ points = points[1:] if missing_bounds[0] == self.bounds[1] else points[:-1] |
|
275 | 273 |
else: |
276 | 274 |
def xs(x_left, x_right, n): |
277 | 275 |
if n == 1: |
... | ... |
@@ -363,7 +363,7 @@ def test_learner1d_first_iteration(): |
363 | 363 |
"""Edge cases where we ask for a few points at the start.""" |
364 | 364 |
learner = Learner1D(lambda x: None, (-1, 1)) |
365 | 365 |
points, loss_improvements = learner.ask(2) |
366 |
- assert set(points) == set([-1, 1]) |
|
366 |
+ assert set(points) == set(learner.bounds) |
|
367 | 367 |
|
368 | 368 |
learner = Learner1D(lambda x: None, (-1, 1)) |
369 | 369 |
points, loss_improvements = learner.ask(3) |
... | ... |
@@ -371,17 +371,27 @@ def test_learner1d_first_iteration(): |
371 | 371 |
|
372 | 372 |
learner = Learner1D(lambda x: None, (-1, 1)) |
373 | 373 |
points, loss_improvements = learner.ask(1) |
374 |
- assert len(points) == 1 and points[0] in [-1, 1] |
|
374 |
+ assert len(points) == 1 and points[0] in learner.bounds |
|
375 | 375 |
rest = set([-1, 0, 1]) - set(points) |
376 | 376 |
points, loss_improvements = learner.ask(2) |
377 | 377 |
assert set(points) == set(rest) |
378 | 378 |
|
379 | 379 |
learner = Learner1D(lambda x: None, (-1, 1)) |
380 | 380 |
points, loss_improvements = learner.ask(1) |
381 |
- to_see = set([-1, 1]) - set(points) |
|
381 |
+ to_see = set(learner.bounds) - set(points) |
|
382 | 382 |
points, loss_improvements = learner.ask(1) |
383 | 383 |
assert set(points) == set(to_see) |
384 | 384 |
|
385 |
+ learner = Learner1D(lambda x: None, (-1, 1)) |
|
386 |
+ learner.tell(1, 0) |
|
387 |
+ points, loss_improvements = learner.ask(1) |
|
388 |
+ assert points == [-1] |
|
389 |
+ |
|
390 |
+ learner = Learner1D(lambda x: None, (-1, 1)) |
|
391 |
+ learner.tell(-1, 0) |
|
392 |
+ points, loss_improvements = learner.ask(1) |
|
393 |
+ assert points == [1] |
|
394 |
+ |
|
385 | 395 |
|
386 | 396 |
def _run_on_discontinuity(x_0, bounds): |
387 | 397 |
|