... | ... |
@@ -1,6 +1,7 @@ |
1 | 1 |
# -*- coding: utf-8 -*- |
2 | 2 |
import collections |
3 | 3 |
import itertools |
4 |
+import math |
|
4 | 5 |
|
5 | 6 |
import holoviews as hv |
6 | 7 |
import numpy as np |
... | ... |
@@ -48,6 +49,43 @@ def _default_loss_per_triangle(ip): |
48 | 49 |
return losses |
49 | 50 |
|
50 | 51 |
|
52 |
+def choose_point_in_triangle(triangle, max_badness): |
|
53 |
+ """Choose a new point in inside a triangle. |
|
54 |
+ |
|
55 |
+ If the ratio of the longest edge of the triangle squared |
|
56 |
+ over the area is bigger than the `max_badness` the new point |
|
57 |
+ is chosen on the middle of the longest edge. Otherwise |
|
58 |
+ a point in the center of the triangle is chosen. The badness |
|
59 |
+ is 1 for a equilateral triangle. |
|
60 |
+ |
|
61 |
+ Parameters |
|
62 |
+ ---------- |
|
63 |
+ triangle : numpy array |
|
64 |
+ The coordinates of a triangle with shape (3, 2) |
|
65 |
+ max_badness : int |
|
66 |
+ The badness at which the point is either chosen on a edge or |
|
67 |
+ in the middle. |
|
68 |
+ |
|
69 |
+ Returns |
|
70 |
+ ------- |
|
71 |
+ point : numpy array |
|
72 |
+ The x and y coordinate of the suggested new point. |
|
73 |
+ """ |
|
74 |
+ a, b, c = triangle |
|
75 |
+ area = 0.5 * np.cross(b - a, c - a) |
|
76 |
+ triangle_roll = np.roll(triangle, 1, axis=0) |
|
77 |
+ edge_lengths = np.linalg.norm(triangle - triangle_roll, axis=1) |
|
78 |
+ i = edge_lengths.argmax() |
|
79 |
+ |
|
80 |
+ # We multiply by sqrt(3) / 4 such that a equilateral triangle has badness=1 |
|
81 |
+ badness = (edge_lengths[i]**2 / area) * (math.sqrt(3) / 4) |
|
82 |
+ if badness > max_badness: |
|
83 |
+ point = (triangle_roll[i] + triangle[i]) / 2 |
|
84 |
+ else: |
|
85 |
+ point = triangle.mean(axis=0) |
|
86 |
+ return point |
|
87 |
+ |
|
88 |
+ |
|
51 | 89 |
class Learner2D(BaseLearner): |
52 | 90 |
"""Learns and predicts a function 'f: ℝ^2 → ℝ^N'. |
53 | 91 |
|
... | ... |
@@ -255,9 +293,9 @@ class Learner2D(BaseLearner): |
255 | 293 |
|
256 | 294 |
for j, _ in enumerate(losses): |
257 | 295 |
jsimplex = np.argmax(losses) |
258 |
- point_new = ip.tri.points[ip.tri.vertices[jsimplex]] |
|
259 |
- point_new = self.unscale(point_new.mean(axis=-2)) |
|
260 |
- point_new = np.clip(point_new, *zip(*self.bounds)) |
|
296 |
+ triangle = ip.tri.points[ip.tri.vertices[jsimplex]] |
|
297 |
+ point_new = choose_point_in_triangle(triangle, max_badness=5) |
|
298 |
+ point_new = np.clip(self.unscale(point_new), *zip(*self.bounds)) |
|
261 | 299 |
|
262 | 300 |
# Check if it is really new |
263 | 301 |
if point_exists(point_new): |
... | ... |
@@ -323,7 +361,7 @@ class Learner2D(BaseLearner): |
323 | 361 |
x = np.linspace(-0.5, 0.5, n_x) |
324 | 362 |
y = np.linspace(-0.5, 0.5, n_y) |
325 | 363 |
ip = self.ip() |
326 |
- z = ip(x[:, None], y[None, :]) |
|
364 |
+ z = ip(x[:, None], y[None, :]).squeeze() |
|
327 | 365 |
plot = hv.Image(np.rot90(z), bounds=lbrt) |
328 | 366 |
|
329 | 367 |
if triangles_alpha: |