... | ... |
@@ -445,11 +445,11 @@ class LearnerND(BaseLearner): |
445 | 445 |
# n = int(0.658 / sqrt(volumes(ip).min())) # TODO fix this calculation |
446 | 446 |
n = 50 |
447 | 447 |
|
448 |
- xs = ys = np.linspace(-0.5, 0.5, n) |
|
448 |
+ xs = ys = np.linspace(0, 1, n) |
|
449 | 449 |
xs = xs[:, None] |
450 | 450 |
ys = ys[None, :] |
451 | 451 |
i = values.index(None) |
452 |
- j = values.index(None) |
|
452 |
+ j = values.index(None, i+1) |
|
453 | 453 |
|
454 | 454 |
bx, by = self.bounds[i], self.bounds[j] |
455 | 455 |
|
... | ... |
@@ -247,6 +247,94 @@ |
247 | 247 |
" learner2.plot(n, tri_alpha=0.4) + learner.plot(tri_alpha=0.4)).cols(2)" |
248 | 248 |
] |
249 | 249 |
}, |
250 |
+ { |
|
251 |
+ "cell_type": "markdown", |
|
252 |
+ "metadata": {}, |
|
253 |
+ "source": [ |
|
254 |
+ "# N-dimensional function learner\n", |
|
255 |
+ "Appart from the 1d and 2d learner, we can also learn N-d functions: $\\ f: ℝ^N → ℝ, N \\ge 2$\n", |
|
256 |
+ "\n", |
|
257 |
+ "Do keep in mind the speed of the learner drops quickly with increasing number of dimensions." |
|
258 |
+ ] |
|
259 |
+ }, |
|
260 |
+ { |
|
261 |
+ "cell_type": "code", |
|
262 |
+ "execution_count": null, |
|
263 |
+ "metadata": {}, |
|
264 |
+ "outputs": [], |
|
265 |
+ "source": [ |
|
266 |
+ "# this step takes a lot of time, it will finish at about 3300 points, which can take up to 6 minutes\n", |
|
267 |
+ "\n", |
|
268 |
+ "def sphere(xyz):\n", |
|
269 |
+ " x, y, z = xyz\n", |
|
270 |
+ " a = 0.4\n", |
|
271 |
+ " return x + z**2 + np.exp(-(x**2 + y**2 + z**2 - 0.75**2)**2/a**4)\n", |
|
272 |
+ "\n", |
|
273 |
+ "learner = adaptive.LearnerND(sphere, bounds=[(-1, 1), (-1, 1), (-1, 1)])\n", |
|
274 |
+ "runner = adaptive.Runner(learner, goal=lambda l: l.loss() < 0.02)\n", |
|
275 |
+ "runner.live_info()" |
|
276 |
+ ] |
|
277 |
+ }, |
|
278 |
+ { |
|
279 |
+ "cell_type": "code", |
|
280 |
+ "execution_count": null, |
|
281 |
+ "metadata": {}, |
|
282 |
+ "outputs": [], |
|
283 |
+ "source": [ |
|
284 |
+ "learner.loss()" |
|
285 |
+ ] |
|
286 |
+ }, |
|
287 |
+ { |
|
288 |
+ "cell_type": "code", |
|
289 |
+ "execution_count": null, |
|
290 |
+ "metadata": {}, |
|
291 |
+ "outputs": [], |
|
292 |
+ "source": [ |
|
293 |
+ "# We plot 2d slices of the 3d method at different values for x\n", |
|
294 |
+ "(learner.plot_slice((0.0, None, None), n=100).relabel(\"x=0\") + \n", |
|
295 |
+ " learner.plot_slice((0.2, None, None), n=100).relabel(\"x=0.2\") +\n", |
|
296 |
+ " learner.plot_slice((0.4, None, None), n=100).relabel(\"x=0.4\") + \n", |
|
297 |
+ " learner.plot_slice((0.6, None, None), n=100).relabel(\"x=0.6\") + \n", |
|
298 |
+ " learner.plot_slice((0.8, None, None), n=100).relabel(\"x=0.8\")\n", |
|
299 |
+ ").cols(2)" |
|
300 |
+ ] |
|
301 |
+ }, |
|
302 |
+ { |
|
303 |
+ "cell_type": "code", |
|
304 |
+ "execution_count": null, |
|
305 |
+ "metadata": {}, |
|
306 |
+ "outputs": [], |
|
307 |
+ "source": [ |
|
308 |
+ "# We can also plot slices in a different direction\n", |
|
309 |
+ "(learner.plot_slice((None, None, 0.0), n=100).relabel(\"z=0\") + \n", |
|
310 |
+ " learner.plot_slice((None, None, 0.2), n=100).relabel(\"z=0.2\") +\n", |
|
311 |
+ " learner.plot_slice((None, None, 0.4), n=100).relabel(\"z=0.4\") + \n", |
|
312 |
+ " learner.plot_slice((None, None, 0.6), n=100).relabel(\"z=0.6\") + \n", |
|
313 |
+ " learner.plot_slice((None, None, 0.8), n=100).relabel(\"z=0.8\")\n", |
|
314 |
+ ").cols(2)" |
|
315 |
+ ] |
|
316 |
+ }, |
|
317 |
+ { |
|
318 |
+ "cell_type": "code", |
|
319 |
+ "execution_count": null, |
|
320 |
+ "metadata": {}, |
|
321 |
+ "outputs": [], |
|
322 |
+ "source": [ |
|
323 |
+ "# Or we can plot 1d slices\n", |
|
324 |
+ "(learner.plot_slice((None, 0.0, 0.0), n=100).relabel(\"y=0 z=0\") + \n", |
|
325 |
+ " learner.plot_slice((None, 0.0, 0.5), n=100).relabel(\"y=0 z=0.5\") +\n", |
|
326 |
+ " learner.plot_slice((0.0, None, 0.0), n=100).relabel(\"x=0 z=0\") + \n", |
|
327 |
+ " learner.plot_slice((0.0, 0.5, None), n=100).relabel(\"x=0.5 y=0\")\n", |
|
328 |
+ ").cols(2)" |
|
329 |
+ ] |
|
330 |
+ }, |
|
331 |
+ { |
|
332 |
+ "cell_type": "markdown", |
|
333 |
+ "metadata": {}, |
|
334 |
+ "source": [ |
|
335 |
+ "The plots show some wobbles while the original function was smooth, this is a result of the fact that the learner chooses points in 3 dimensions and the simplices are not in the same face as we try to interpolate our lines. However, as always, when you sample more points the graph will become gradually smoother." |
|
336 |
+ ] |
|
337 |
+ }, |
|
250 | 338 |
{ |
251 | 339 |
"cell_type": "markdown", |
252 | 340 |
"metadata": {}, |