Browse code

Add LearnerND to adaptive's standard imports; add plot_slice function to plot a slice of the evaluated function; adapt the ask function to return a point

Jorn Hoofwijk authored on 27/05/2018 17:54:04 • Bas Nijholt committed on 11/07/2018 05:27:20
Showing 4 changed files
... ...
@@ -8,7 +8,7 @@ from . import learner
8 8
 from . import runner
9 9
 from . import utils
10 10
 
11
-from .learner import (Learner1D, Learner2D, AverageLearner,
11
+from .learner import (Learner1D, Learner2D, LearnerND, AverageLearner,
12 12
                       BalancingLearner, make_datasaver, DataSaver,
13 13
                       IntegratorLearner)
14 14
 
... ...
@@ -6,6 +6,7 @@ from .base_learner import BaseLearner
6 6
 from .balancing_learner import BalancingLearner
7 7
 from .learner1D import Learner1D
8 8
 from .learner2D import Learner2D
9
+from .learnerND import LearnerND
9 10
 from .integrator_learner import IntegratorLearner
10 11
 from .data_saver import DataSaver, make_datasaver
11 12
 
... ...
@@ -14,9 +14,7 @@ from .base_learner import BaseLearner
14 14
 def volumes(ip):
15 15
     p = ip.tri.points[ip.tri.vertices]
16 16
     matrices = p[:, :-1, :] - p[:, -1, None, :]
17
-    # vol = abs(q[:, 0, 0] * q[:, 1, 1] - q[:, 0, 1] * q[:, 1, 0]) / 2
18 17
     n_points, points_per_simplex, dim = np.shape(p)
19
-    # assert(points_per_simplex == dim + 1)
20 18
 
21 19
     # See https://www.jstor.org/stable/2315353
22 20
     vols = np.abs(np.linalg.det(matrices)) / np.math.factorial(dim)
... ...
@@ -42,9 +40,11 @@ def uniform_loss(ip):
42 40
     """
43 41
     return volumes(ip)
44 42
 
43
+
45 44
 def default_loss(ip):
46 45
     return uniform_loss(ip)
47 46
 
47
+
48 48
 def choose_point_in_simplex(simplex):
49 49
     """Choose a new point in inside a simplex.
50 50
 
... ...
@@ -68,7 +68,7 @@ def choose_point_in_simplex(simplex):
68 68
 
69 69
     for i in range(2, N+1):
70 70
         for j in range(i):
71
-            length = np.norm(simplex[i, :] - simplex[j, :])
71
+            length = np.linalg.norm(simplex[i, :] - simplex[j, :])
72 72
             if length > longest:
73 73
                 longest = length
74 74
                 point = (simplex[i, :] + simplex[j, :]) / 2
... ...
@@ -162,7 +162,7 @@ class LearnerND(BaseLearner):
162 162
     def unscale(self, points):
163 163
         # this function converts the points from equalised coordinates to real coordinates
164 164
         points = np.asarray(points, dtype=float)
165
-        return points * self._mean + self._ptp_scale
165
+        return points * self._ptp_scale + self._mean
166 166
 
167 167
     @property
168 168
     def npoints(self):
... ...
@@ -180,10 +180,10 @@ class LearnerND(BaseLearner):
180 180
 
181 181
     @property
182 182
     def bounds_are_done(self):
183
-        return not any((p in self._pending or p in self._stack)
184
-                       for p in self._bounds_points)
183
+        return all(p in self.data for p in self._bounds_points)
185 184
 
186 185
     def ip(self):
186
+        # raise DeprecationWarning('usage of LinearNDInterpolator should be reduced')
187 187
         # returns a scipy.interpolate.LinearNDInterpolator object with the given data as sources
188 188
         if self._ip is None:
189 189
             points = self.scale(list(self.data.keys()))
... ...
@@ -201,68 +201,42 @@ class LearnerND(BaseLearner):
201 201
             self._pending.discard(point)
202 202
             self._ip = None
203 203
 
204
-        self._stack.pop(point, None)
205
-
206
-    def _fill_stack(self, stack_till=1):
207
-        # Do notice that it is a heap and not a stack
208
-        # TODO modify this function
209
-        if len(self.data) + len(self._interp) < self.ndim + 1:
210
-            raise ValueError("too few points...")
204
+    def ask(self, n=1, tell=True):
205
+        # Complexity: O(N log N + n * N)
206
+        # TODO adapt this function
207
+        # Even if tell is False we add the point such that _fill_stack
208
+        # will return new points, later we remove these points if needed.
209
+        assert(n == 1)
211 210
 
212
-        # Interpolate
213
-        ip = self.ip_combined()
211
+        new_points = []
212
+        new_loss_improvements = []
213
+        if not self.bounds_are_done:
214
+            bounds_to_do = [p for p in self._bounds_points if p not in self.data and p not in self._pending]
215
+            new_points = bounds_to_do[:n]
216
+            new_loss_improvements = [-np.inf] * n
217
+            n = n - len(new_points)
214 218
 
215
-        losses = self.loss_per_triangle(ip)  # compute the losses of all interpolated triangles
219
+        if n > 0:
220
+            # Interpolate
221
+            ip = self.ip()  # O(N log N) for triangulation
222
+            losses = self.loss_per_simplex(ip)  # O(N), compute the losses of all interpolated triangles
216 223
 
217
-        points_new = []
218
-        losses_new = []
219
-        for j, _ in enumerate(losses):
220
-            jsimplex = np.argmax(losses)  # Find the index of the simplex with the highest loss
221
-            triangle = ip.tri.points[ip.tri.vertices[jsimplex]]  # get the corner points the the worst simplex
222
-            point_new = choose_point_in_simplex(triangle, max_badness=5)  # choose a new point in the triangle
223
-            point_new = tuple(self.unscale(point_new))  # relative coordinates to real coordinates
224
-            loss_new = losses[jsimplex]
224
+            for _ in range(n):
225
+                simplex_index = np.argmax(losses)  # O(N), Find the index of the simplex with the highest loss
226
+                simplex = ip.tri.points[ip.tri.vertices[simplex_index]]  # get the corner points of the worst simplex
227
+                point_new = choose_point_in_simplex(simplex)  # choose a new point in the triangle
228
+                point_new = tuple(self.unscale(point_new))  # relative coordinates to real coordinates
229
+                loss_new = losses[simplex_index]
225 230
 
226
-            points_new.append(point_new)
227
-            losses_new.append(loss_new)
231
+                new_points.append(point_new)
232
+                new_loss_improvements.append(loss_new)
228 233
 
229
-            self._stack[point_new] = loss_new
234
+                losses[simplex_index] = -np.inf
230 235
 
231
-            if len(self._stack) >= stack_till:
232
-                break
233
-            else:
234
-                losses[jsimplex] = -np.inf
236
+        if tell:
237
+            self.tell(new_points, itertools.repeat(None))
235 238
 
236
-        return points_new, losses_new
237
-
238
-    def ask(self, n, tell=True):
239
-        # TODO adapt this function
240
-        # Even if tell is False we add the point such that _fill_stack
241
-        # will return new points, later we remove these points if needed.
242
-        points = list(self._stack.keys())
243
-        loss_improvements = list(self._stack.values())
244
-        n_left = n - len(points)
245
-        self.tell(points[:n], itertools.repeat(None))
246
-
247
-        while n_left > 0:
248
-            # The while loop is needed because `stack_till` could be larger
249
-            # than the number of triangles between the points. Therefore
250
-            # it could fill up till a length smaller than `stack_till`.
251
-            new_points, new_loss_improvements = self._fill_stack(
252
-                stack_till=max(n_left, self.stack_size))
253
-            self.tell(new_points[:n_left], itertools.repeat(None))
254
-            n_left -= len(new_points)
255
-
256
-            points += new_points
257
-            loss_improvements += new_loss_improvements
258
-
259
-        if not tell:
260
-            self._stack = OrderedDict(zip(points[:self.stack_size],
261
-                                          loss_improvements))
262
-            for point in points[:n]:
263
-                self._interp.discard(point)
264
-
265
-        return points[:n], loss_improvements[:n]
239
+        return new_points, new_loss_improvements
266 240
 
267 241
     # def loss(self, real=True):
268 242
     #     if not self.bounds_are_done:
... ...
@@ -285,7 +259,6 @@ class LearnerND(BaseLearner):
285 259
             if p not in self.data:
286 260
                 self._stack[p] = np.inf
287 261
 
288
-
289 262
     def plot(self, n=None, tri_alpha=0):
290 263
         hv = ensure_holoviews()
291 264
         if self.vdim > 1:
... ...
@@ -300,7 +273,7 @@ class LearnerND(BaseLearner):
300 273
             if n is None:
301 274
                 # Calculate how many grid points are needed.
302 275
                 # factor from A=√3/4 * a² (equilateral triangle)
303
-                n = int(0.658 / sqrt(areas(ip).min()))
276
+                n = int(0.658 / sqrt(volumes(ip).min()))
304 277
                 n = max(n, 10)
305 278
 
306 279
             x = y = np.linspace(-0.5, 0.5, n)
... ...
@@ -326,3 +299,32 @@ class LearnerND(BaseLearner):
326 299
         no_hover = dict(plot=dict(inspection_policy=None, tools=[]))
327 300
 
328 301
         return im.opts(style=im_opts) * tris.opts(style=tri_opts, **no_hover)
302
+
303
+    def plot_slice(self, values, n=None, tri_alpha=0):
304
+        values = list(values)
305
+        count_none = values.count(None)
306
+        assert(count_none == 1 or count_none == 2)
307
+        if count_none == 2:
308
+            raise NotImplementedError('plot slice does currently not support 2D plotting')
309
+        else:
310
+            hv = ensure_holoviews()
311
+            if not self.data:
312
+                p = hv.Scatter([]) * hv.Path([])
313
+            elif not self.vdim > 1:
314
+                ind = values.index(None)
315
+                x = np.linspace(-0.5, 0.5, 500)
316
+                values[ind] = 0
317
+                values = list(self.scale(values))
318
+                values[ind] = x
319
+                ip = self.ip()
320
+                y = ip(*values)
321
+                x = x * self._ptp_scale[ind] + self._mean[ind]
322
+                p = hv.Path((x, y))
323
+            else:
324
+                raise NotImplementedError('multidimensional output not yet supported by plotSlice')
325
+
326
+            # Plot with 5% empty margins such that the boundary points are visible
327
+            margin = 0.05 * self._ptp_scale[ind]
328
+            plot_bounds = (x[0] - margin, x[-1] + margin)
329
+
330
+            return p.redim(x=dict(range=plot_bounds))
329 331
\ No newline at end of file
... ...
@@ -780,9 +780,7 @@
780 780
   {
781 781
    "cell_type": "code",
782 782
    "execution_count": null,
783
-   "metadata": {
784
-    "collapsed": true
785
-   },
783
+   "metadata": {},
786 784
    "outputs": [],
787 785
    "source": [
788 786
     "def g(x, noise_level=0.1):\n",