...
|
...
|
@@ -10,7 +10,7 @@ import scipy.spatial
|
10
|
10
|
from ..notebook_integration import ensure_holoviews
|
11
|
11
|
from .base_learner import BaseLearner
|
12
|
12
|
|
13
|
|
-from .triangulation import Triangulation
|
|
13
|
+from .triangulation import Triangulation, point_in_simplex, circumsphere
|
14
|
14
|
import random
|
15
|
15
|
from math import factorial
|
16
|
16
|
|
...
|
...
|
@@ -115,12 +115,19 @@ def choose_point_in_simplex(simplex, transform=None):
|
115
|
115
|
if transform is not None:
|
116
|
116
|
simplex = np.dot(simplex, transform)
|
117
|
117
|
|
118
|
|
- distances = scipy.spatial.distance.pdist(simplex)
|
119
|
|
- distance_matrix = scipy.spatial.distance.squareform(distances)
|
120
|
|
- i, j = np.unravel_index(np.argmax(distance_matrix), distance_matrix.shape)
|
|
118
|
+ # choose center iff the shape of the simplex is nice,
|
|
119
|
+ # otherwise the longest edge
|
|
120
|
+ center, radius = circumsphere(simplex)
|
|
121
|
+ if point_in_simplex(center, simplex):
|
|
122
|
+ point = np.average(simplex, axis=0)
|
|
123
|
+ else:
|
|
124
|
+ distances = scipy.spatial.distance.pdist(simplex)
|
|
125
|
+ distance_matrix = scipy.spatial.distance.squareform(distances)
|
|
126
|
+ i, j = np.unravel_index(np.argmax(distance_matrix), distance_matrix.shape)
|
|
127
|
+
|
|
128
|
+ point = (simplex[i, :] + simplex[j, :]) / 2
|
121
|
129
|
|
122
|
|
- point = (simplex[i, :] + simplex[j, :]) / 2
|
123
|
|
- return np.linalg.solve(transform, point)
|
|
130
|
+ return np.linalg.solve(transform, point) # undo the transform
|
124
|
131
|
|
125
|
132
|
|
126
|
133
|
class LearnerND(BaseLearner):
|