... | ... |
@@ -683,15 +683,19 @@ class Learner2D(BaseLearner): |
683 | 683 |
self._bounds_points = list(itertools.product(*bounds)) |
684 | 684 |
|
685 | 685 |
# Add the loss improvement to the bounds in the stack |
686 |
- self._stack = [p + (np.inf,) for p in self._bounds_points] |
|
686 |
+ self._bounds_points = [(x / self.x_scale, y / self.y_scale) |
|
687 |
+ for x, y in self._bounds_points] |
|
688 |
+ self._stack = [(*p, np.inf) for p in self._bounds_points] |
|
687 | 689 |
|
688 |
- def f(xy): |
|
690 |
+ self.original_function = function |
|
691 |
+ |
|
692 |
+ def scaled_function(xy): |
|
689 | 693 |
x, y = xy |
690 |
- x /= self.x_scale |
|
691 |
- y /= self.y_scale |
|
692 |
- return function((x, y)) |
|
694 |
+ x *= self.x_scale |
|
695 |
+ y *= self.y_scale |
|
696 |
+ return self.original_function((x, y)) |
|
693 | 697 |
|
694 |
- self.function = f |
|
698 |
+ self.function = scaled_function |
|
695 | 699 |
|
696 | 700 |
@property |
697 | 701 |
def points_combined(self): |