adaptive/notebook_integration.py
73fa8d28
 import asyncio
b00659d0
 import datetime
06b9200f
 import importlib
243561d6
 import random
9fc92402
 import warnings
176c745c
 from contextlib import suppress
73fa8d28
 
f1e5e354
 # Flags recording which optional notebook integrations have been set up.
 # `notebook_extension` sets the first three; `ensure_plotly` sets the last.
 _async_enabled = _holoviews_enabled = _ipywidgets_enabled = _plotly_enabled = False
73fa8d28
 
9fc92402
 
6b170238
 def notebook_extension(*, _inline_js=True):
     """Enable ipywidgets, holoviews, and asyncio notebook integration.

     Parameters
     ----------
     _inline_js : bool
         Passed as ``inline`` to ``holoviews.notebook_extension``.

     Raises
     ------
     RuntimeError
         If not called from a Jupyter notebook (see `in_ipynb`).

     A missing optional dependency (holoviews or ipywidgets) only emits a
     ``RuntimeWarning``, and the corresponding feature stays disabled.
     """
     if not in_ipynb():
         raise RuntimeError(
             '"adaptive.notebook_extension()" may only be run '
             "from a Jupyter notebook."
         )

     global _async_enabled, _holoviews_enabled, _ipywidgets_enabled

     # Load holoviews. This runs on every call (not only the first one)
     # because after closing a notebook the injected js is gone.
     _holoviews_enabled = False
     try:
         import holoviews

         holoviews.notebook_extension("bokeh", logo=False, inline=_inline_js)
         _holoviews_enabled = True
     except ModuleNotFoundError:
         warnings.warn(
             "holoviews is not installed; plotting is disabled.", RuntimeWarning
         )

     # Load ipywidgets (only needs to succeed once).
     if not _ipywidgets_enabled:
         try:
             import ipywidgets  # noqa: F401
         except ModuleNotFoundError:
             warnings.warn(
                 "ipywidgets is not installed; live_info is disabled.", RuntimeWarning
             )
         else:
             _ipywidgets_enabled = True

     # Enable asyncio integration. `InteractiveShell.magic` is deprecated;
     # `run_line_magic(name, line)` is the supported equivalent.
     if not _async_enabled:
         get_ipython().run_line_magic("gui", "asyncio")  # noqa: F821
         _async_enabled = True
73fa8d28
 
 
416ee3af
 def ensure_holoviews():
     """Import and return the ``holoviews`` module.

     Raises
     ------
     RuntimeError
         If holoviews is not installed. The original ``ModuleNotFoundError``
         is chained as ``__cause__``.
     """
     try:
         return importlib.import_module("holoviews")
     except ModuleNotFoundError as err:
         raise RuntimeError(
             "holoviews is not installed; plotting is disabled."
         ) from err
416ee3af
 
 
b563e096
 def ensure_plotly():
     """Import plotly, initialise its notebook mode once, and return it.

     Returns
     -------
     module
         The top-level ``plotly`` package, with ``plotly.graph_objs``,
         ``plotly.figure_factory`` and ``plotly.offline`` imported.

     Raises
     ------
     RuntimeError
         If plotly (or one of the submodules above) cannot be imported. The
         original ``ModuleNotFoundError`` is chained as ``__cause__``.
     """
     global _plotly_enabled
     # Only the imports are guarded, so that errors raised while initialising
     # notebook mode are not misreported as "plotly is not installed".
     # Re-importing already-loaded submodules is a cheap sys.modules lookup.
     try:
         import plotly
         import plotly.figure_factory  # noqa: F401
         import plotly.graph_objs  # noqa: F401
         import plotly.offline
     except ModuleNotFoundError as err:
         raise RuntimeError(
             "plotly is not installed; plotting is disabled."
         ) from err

     if not _plotly_enabled:
         # This injects javascript and should happen only once
         plotly.offline.init_notebook_mode()
         _plotly_enabled = True
     return plotly
b563e096
 
 
f1e5e354
 def in_ipynb():
     """Report whether the active IPython shell is a ``ZMQInteractiveShell``.

     Outside of IPython the ``get_ipython`` global does not exist, in which
     case the result is False.
     """
     try:
         shell = get_ipython()  # noqa: F821
     except NameError:
         return False
     return shell.__class__.__name__ == "ZMQInteractiveShell"
73fa8d28
 
 
f1e5e354
 # Fancy displays in the Jupyter notebook
73fa8d28
 
c1dc9bbe
 # Live-plot update tasks created by `live_plot`, keyed by their ``name``.
 active_plotting_tasks = {}
930b4efb
 
 
716dbce8
 def live_plot(runner, *, plotter=None, update_interval=2, name=None, normalize=True):
     """Live plotting of the learner's data.

     Parameters
     ----------
     runner : `~adaptive.Runner`
     plotter : function
         A function that takes the learner as an argument and returns a
         holoviews object. By default ``learner.plot()`` will be called.
     update_interval : int
         Number of seconds between the updates of the plot.
     name : hashable
         Name for the `live_plot` task in `adaptive.active_plotting_tasks`.
         By default the name is None and if another task with the same name
         already exists that other `live_plot` is canceled.
     normalize : bool
         Normalize (scale to fit) the frame upon each update.

     Returns
     -------
     dm : `holoviews.core.DynamicMap`
         The plot that automatically updates every `update_interval`.

     Raises
     ------
     RuntimeError
         If holoviews was not enabled by `adaptive.notebook_extension`.
     """
     if not _holoviews_enabled:
         raise RuntimeError(
             "Live plotting is not enabled; did you run "
             "'adaptive.notebook_extension()'?"
         )

     import holoviews as hv
     import ipywidgets
     from IPython.display import display

     # Only one live plot per name: stop the previous one.
     if name in active_plotting_tasks:
         active_plotting_tasks[name].cancel()

     def plot_generator():
         while True:
             if plotter is None:
                 yield runner.learner.plot()
             else:
                 yield plotter(runner.learner)

     streams = [hv.streams.Stream.define("Next")()]
     dm = hv.DynamicMap(plot_generator(), streams=streams)
     dm.cache_size = 1

     if normalize:
         # XXX: change when https://github.com/pyviz/holoviews/issues/3637
         # is fixed.
         dm = dm.map(lambda obj: obj.opts(framewise=True), hv.Element)

     cancel_button = ipywidgets.Button(
         description="cancel live-plot", layout=ipywidgets.Layout(width="150px")
     )

     # `asyncio.Task.current_task` was removed in Python 3.9;
     # `asyncio.current_task` is available from Python 3.7 on.
     current_task = getattr(asyncio, "current_task", None) or (
         asyncio.Task.current_task
     )

     # Could have used dm.periodic in the following, but this would either spin
     # off a thread (and learner is not threadsafe) or block the kernel.

     async def updater():
         def event():
             # Fire the streams of `dm` (the returned, possibly mapped, map).
             # XXX: used to be dm.event()
             # see https://github.com/pyviz/holoviews/issues/3564
             hv.streams.Stream.trigger(dm.streams)

         try:
             while not runner.task.done():
                 event()
                 await asyncio.sleep(update_interval)
             event()  # fire off one last update before we die
         finally:
             # Unregister only if this task is still the registered one; a newer
             # live_plot with the same name may already have replaced it.
             if active_plotting_tasks.get(name) is current_task():
                 active_plotting_tasks.pop(name, None)
             cancel_button.layout.display = "none"  # remove cancel button

     def cancel(_):
         with suppress(KeyError):
             active_plotting_tasks[name].cancel()

     active_plotting_tasks[name] = runner.ioloop.create_task(updater())
     cancel_button.on_click(cancel)

     display(cancel_button)
     return dm
b00659d0
 
 
243561d6
 def should_update(status):
     """Decide (randomly) whether a live widget should push another update.

     Parameters
     ----------
     status : widget
         Widget whose ``comm.kernel.iopub_thread._events`` write buffer is
         inspected. This is a private API.

     Returns
     -------
     bool
         True with probability ``1.1 ** -buffer_size``. Always True if the
         buffer size cannot be determined.
     """
     try:
         # Get the length of the write buffer size
         buffer_size = len(status.comm.kernel.iopub_thread._events)
     except Exception:
         # We catch any Exception because we are using a private API.
         return True

     # Make sure to only keep all the messages when the notebook
     # is viewed, this means 'buffer_size == 1'. However, when not
     # viewing the notebook the buffer fills up. When this happens
     # we decide to only add messages to it with a certain probability.
     # i.e. we're offline for 12h, with an update_interval of 0.5s,
     # and without the reduced probability, we have buffer_size=86400.
     # With the correction this is np.log(86400) / np.log(1.1) = 119.2
     #
     # Written as `random() < 1.1 ** -buffer_size` rather than
     # `1.1 ** buffer_size * random() < 1`: the latter raises OverflowError
     # once buffer_size exceeds ~7447, which (when caught) made this always
     # return True exactly when throttling matters most. A negative exponent
     # underflows harmlessly to 0.0 instead.
     return random.random() < 1.1 ** -buffer_size
 
 
b00659d0
 def live_info(runner, *, update_interval=0.5):
     """Display live information about the runner.

     Displays an interactive ipywidget (a status table plus a
     "cancel runner" button) in a Jupyter notebook; returns None.

     Parameters
     ----------
     runner : `~adaptive.Runner`
     update_interval : float
         Number of seconds between updates of the displayed information.

     Raises
     ------
     RuntimeError
         If ipywidgets was not enabled by `adaptive.notebook_extension`.
     """
     # This widget only needs ipywidgets (holoviews is not used here), so
     # check the flag that `notebook_extension` sets for ipywidgets.
     if not _ipywidgets_enabled:
         raise RuntimeError(
             "Live info is not enabled; did you run "
             "'adaptive.notebook_extension()'?"
         )

     import ipywidgets
     from IPython.display import display

     status = ipywidgets.HTML(value=_info_html(runner))

     cancel = ipywidgets.Button(
         description="cancel runner", layout=ipywidgets.Layout(width="100px")
     )
     cancel.on_click(lambda _: runner.cancel())

     async def update():
         while not runner.task.done():
             await asyncio.sleep(update_interval)

             if should_update(status):
                 status.value = _info_html(runner)
             else:
                 # Skipped this update; wait a little longer before re-checking.
                 await asyncio.sleep(0.05)

         # Final refresh once the runner is done, then hide the cancel button.
         status.value = _info_html(runner)
         cancel.layout.display = "none"

     runner.ioloop.create_task(update())

     display(ipywidgets.VBox((status, cancel)))
 
 
 def _table_row(i, key, value):
     """Style the rows of a table. Based on the default Jupyterlab table style."""
dfeb5ef4
     style = "text-align: right; padding: 0.5em 0.5em; line-height: 1.0;"
     if i % 2 == 1:
         style += " background: var(--md-grey-100);"
     return f'<tr><th style="{style}">{key}</th><th style="{style}">{value}</th></tr>'
b00659d0
 
 
 def _info_html(runner):
     """Build the HTML table shown by `live_info`.

     Rows: the runner's status (colour-coded) and elapsed time, plus the
     overhead, shaded from green at 0% to red at 100%. The learner's number
     of points and latest loss are added when they can be read.
     """
     status = runner.status()
     status_colors = {
         "cancelled": "orange",
         "failed": "red",
         "running": "blue",
         "finished": "green",
     }
     color = status_colors[status]

     overhead = runner.overhead()
     # Map the overhead percentage onto a clamped 0..255 red channel;
     # the green channel is its complement and blue is always 0.
     red = min(max(int(255 * overhead / 100), 0), 255)
     overhead_color = f"#{red:02x}{255 - red:02x}00"

     rows = [
         ("status", f'<font color="{color}">{status}</font>'),
         ("elapsed time", datetime.timedelta(seconds=runner.elapsed_time())),
         ("overhead", f'<font color="{overhead_color}">{overhead:.2f}%</font>'),
     ]

     # Best effort: skip these rows if the learner cannot provide them.
     with suppress(Exception):
         rows.append(("# of points", runner.learner.npoints))
     with suppress(Exception):
         rows.append(("latest loss", f'{runner.learner._cache["loss"]:.3f}'))

     table = "\n".join(
         _table_row(index, label, value) for index, (label, value) in enumerate(rows)
     )

     return f"""
         <table>
         {table}
         </table>
     """