... | ... |
@@ -278,39 +278,36 @@ class Learner1D(BaseLearner): |
278 | 278 |
loss_improvements = [np.inf] * n |
279 | 279 |
points = np.linspace(*self.bounds, n + 1)[1:].tolist() |
280 | 280 |
else: |
281 |
- def xs(x, n): |
|
281 |
+ def xs(x_left, x_right, n): |
|
282 | 282 |
if n == 1: |
283 |
+ # This is just an optimization |
|
283 | 284 |
return [] |
284 | 285 |
else: |
285 |
- step = (x[1] - x[0]) / n |
|
286 |
- return [x[0] + step * i for i in range(1, n)] |
|
286 |
+ step = (x_right - x_left) / n |
|
287 |
+ return [x_left + step * i for i in range(1, n)] |
|
287 | 288 |
|
288 | 289 |
# Calculate how many points belong to each interval. |
289 | 290 |
x_scale = self._scale[0] |
290 | 291 |
|
291 | 292 |
quals = [] |
292 | 293 |
for ((x_left, x_right), loss) in self.losses_combined.items(): |
293 |
- dx = x_right - x_left |
|
294 |
- if abs(dx) < self._dx_eps: |
|
295 |
- # The interval is too small and should not be subdivided |
|
296 |
- quality = 0 |
|
297 |
- elif not math.isinf(loss): |
|
298 |
- quality = -loss |
|
299 |
- else: |
|
300 |
- quality = -dx / x_scale |
|
294 |
+ quality = -loss if not math.isinf(loss) else -(x_right - x_left) / x_scale |
|
301 | 295 |
quals.append((quality, (x_left, x_right), 1)) |
302 | 296 |
|
303 | 297 |
heapq.heapify(quals) |
304 | 298 |
|
305 | 299 |
for point_number in range(n): |
306 | 300 |
quality, x, n = quals[0] |
301 |
+ if abs(x[1] - x[0]) / (n + 1) <= self._dx_eps: |
|
302 |
+ # The interval is too small and should not be subdivided |
|
303 |
+ quality = np.inf |
|
307 | 304 |
heapq.heapreplace(quals, (quality * n / (n + 1), x, n + 1)) |
308 | 305 |
|
309 |
- points = list(itertools.chain.from_iterable(xs(x, n) |
|
310 |
- for quality, x, n in quals)) |
|
306 |
+ points = list(itertools.chain.from_iterable( |
|
307 |
+ xs(*x, n) for quality, x, n in quals)) |
|
311 | 308 |
|
312 | 309 |
loss_improvements = list(itertools.chain.from_iterable( |
313 |
- itertools.repeat(-quality, n-1) |
|
310 |
+ itertools.repeat(-quality, n - 1) |
|
314 | 311 |
for quality, x, n in quals)) |
315 | 312 |
|
316 | 313 |
if add_data: |