Browse code

2D: suggest new point on the edge of the triangle if the triangle is very narrow

Bas Nijholt authored on 21/11/2017 19:04:57
Showing 1 changed files
... ...
@@ -1,6 +1,7 @@
1 1
 # -*- coding: utf-8 -*-
2 2
 import collections
3 3
 import itertools
4
+import math
4 5
 
5 6
 import holoviews as hv
6 7
 import numpy as np
... ...
@@ -48,6 +49,43 @@ def _default_loss_per_triangle(ip):
48 49
     return losses
49 50
 
50 51
 
52
+def choose_point_in_triangle(triangle, max_badness):
53
+    """Choose a new point in inside a triangle.
54
+
55
+    If the ratio of the longest edge of the triangle squared
56
+    over the area is bigger than the `max_badness` the new point
57
+    is chosen on the middle of the longest edge. Otherwise
58
+    a point in the center of the triangle is chosen. The badness
59
+    is 1 for a equilateral triangle.
60
+
61
+    Parameters
62
+    ----------
63
+    triangle : numpy array
64
+        The coordinates of a triangle with shape (3, 2)
65
+    max_badness : int
66
+        The badness at which the point is either chosen on a edge or
67
+        in the middle.
68
+
69
+    Returns
70
+    -------
71
+    point : numpy array
72
+        The x and y coordinate of the suggested new point.
73
+    """
74
+    a, b, c = triangle
75
+    area = 0.5 * np.cross(b - a, c - a)
76
+    triangle_roll = np.roll(triangle, 1, axis=0)
77
+    edge_lengths = np.linalg.norm(triangle - triangle_roll, axis=1)
78
+    i = edge_lengths.argmax()
79
+
80
+    # We multiply by sqrt(3) / 4 such that a equilateral triangle has badness=1
81
+    badness = (edge_lengths[i]**2 / area) * (math.sqrt(3) / 4)
82
+    if badness > max_badness:
83
+        point = (triangle_roll[i] + triangle[i]) / 2
84
+    else:
85
+        point = triangle.mean(axis=0)
86
+    return point
87
+
88
+
51 89
 class Learner2D(BaseLearner):
52 90
     """Learns and predicts a function 'f: ℝ^2 → ℝ^N'.
53 91
 
... ...
@@ -255,9 +293,9 @@ class Learner2D(BaseLearner):
255 293
 
256 294
         for j, _ in enumerate(losses):
257 295
             jsimplex = np.argmax(losses)
258
-            point_new = ip.tri.points[ip.tri.vertices[jsimplex]]
259
-            point_new = self.unscale(point_new.mean(axis=-2))
260
-            point_new = np.clip(point_new, *zip(*self.bounds))
296
+            triangle = ip.tri.points[ip.tri.vertices[jsimplex]]
297
+            point_new = choose_point_in_triangle(triangle, max_badness=5)
298
+            point_new = np.clip(self.unscale(point_new), *zip(*self.bounds))
261 299
 
262 300
             # Check if it is really new
263 301
             if point_exists(point_new):
... ...
@@ -323,7 +361,7 @@ class Learner2D(BaseLearner):
323 361
             x = np.linspace(-0.5, 0.5, n_x)
324 362
             y = np.linspace(-0.5, 0.5, n_y)
325 363
             ip = self.ip()
326
-            z = ip(x[:, None], y[None, :])
364
+            z = ip(x[:, None], y[None, :]).squeeze()
327 365
             plot = hv.Image(np.rot90(z), bounds=lbrt)
328 366
 
329 367
             if triangles_alpha: