import asyncio
import datetime
import importlib
import random
import warnings
from contextlib import suppress

_async_enabled = False
_holoviews_enabled = False
_ipywidgets_enabled = False
_plotly_enabled = False


def notebook_extension(*, _inline_js=True):
    """Enable ipywidgets, holoviews, and asyncio notebook integration.

    Parameters
    ----------
    _inline_js : bool
        Forwarded to ``holoviews.notebook_extension`` as ``inline``.

    Raises
    ------
    RuntimeError
        If not called from within a Jupyter notebook.
    """
    if not in_ipynb():
        raise RuntimeError(
            '"adaptive.notebook_extension()" may only be run '
            "from a Jupyter notebook."
        )

    global _async_enabled, _holoviews_enabled, _ipywidgets_enabled

    # Load holoviews. After closing a notebook the injected JS is gone, so the
    # extension is (re-)loaded on every call rather than only once.
    _holoviews_enabled = False
    try:
        import holoviews

        holoviews.notebook_extension("bokeh", logo=False, inline=_inline_js)
        _holoviews_enabled = True
    except ModuleNotFoundError:
        warnings.warn(
            "holoviews is not installed; plotting is disabled.", RuntimeWarning
        )

    # Load ipywidgets (only needs to succeed once).
    if not _ipywidgets_enabled:
        try:
            import ipywidgets  # noqa: F401
        except ModuleNotFoundError:
            warnings.warn(
                "ipywidgets is not installed; live_info is disabled.", RuntimeWarning
            )
        else:
            _ipywidgets_enabled = True

    # Enable asyncio integration (only needed once per kernel).
    if not _async_enabled:
        # `run_line_magic` replaces the deprecated `InteractiveShell.magic`.
        get_ipython().run_line_magic("gui", "asyncio")  # noqa: F821
        _async_enabled = True


def ensure_holoviews():
    """Import and return the ``holoviews`` module.

    Raises
    ------
    RuntimeError
        If holoviews is not installed.
    """
    try:
        return importlib.import_module("holoviews")
    except ModuleNotFoundError as err:
        raise RuntimeError(
            "holoviews is not installed; plotting is disabled."
        ) from err


def ensure_plotly():
    """Import and return the ``plotly`` module, initialising notebook mode once.

    The first successful call also imports the plotly submodules used for
    plotting and calls ``plotly.offline.init_notebook_mode()``.

    Raises
    ------
    RuntimeError
        If plotly (or one of the submodules it needs) is not installed.
    """
    global _plotly_enabled
    try:
        import plotly

        if not _plotly_enabled:
            import plotly.graph_objs
            import plotly.figure_factory
            import plotly.offline

            # This injects javascript and should happen only once
            plotly.offline.init_notebook_mode()
            _plotly_enabled = True
        return plotly
    except ModuleNotFoundError as err:
        raise RuntimeError("plotly is not installed; plotting is disabled.") from err


def in_ipynb():
    """Return True when running inside a Jupyter notebook kernel.

    Returns False outside IPython, or inside an IPython shell whose class is
    not ``ZMQInteractiveShell``.
    """
    try:
        # If we are running in IPython, then `get_ipython()` is always a global
        return get_ipython().__class__.__name__ == "ZMQInteractiveShell"
    except NameError:
        return False


# Fancy displays in the Jupyter notebook
active_plotting_tasks = dict() def live_plot(runner, *, plotter=None, update_interval=2, name=None, normalize=True): """Live plotting of the learner's data. Parameters ---------- runner : `~adaptive.Runner` plotter : function A function that takes the learner as a argument and returns a holoviews object. By default ``learner.plot()`` will be called. update_interval : int Number of second between the updates of the plot. name : hasable Name for the `live_plot` task in `adaptive.active_plotting_tasks`. By default the name is None and if another task with the same name already exists that other `live_plot` is canceled. normalize : bool Normalize (scale to fit) the frame upon each update. Returns ------- dm : `holoviews.core.DynamicMap` The plot that automatically updates every `update_interval`. """ if not _holoviews_enabled: raise RuntimeError( "Live plotting is not enabled; did you run " "'adaptive.notebook_extension()'?" ) import holoviews as hv import ipywidgets from IPython.display import display if name in active_plotting_tasks: active_plotting_tasks[name].cancel() def plot_generator(): while True: if not plotter: yield runner.learner.plot() else: yield plotter(runner.learner) streams = [hv.streams.Stream.define("Next")()] dm = hv.DynamicMap(plot_generator(), streams=streams) dm.cache_size = 1 if normalize: # XXX: change when https://github.com/pyviz/holoviews/issues/3637 # is fixed. dm = dm.map(lambda obj: obj.opts(framewise=True), hv.Element) cancel_button = ipywidgets.Button( description="cancel live-plot", layout=ipywidgets.Layout(width="150px") ) # Could have used dm.periodic in the following, but this would either spin # off a thread (and learner is not threadsafe) or block the kernel. 
async def updater(): event = lambda: hv.streams.Stream.trigger( # noqa: E731 dm.streams ) # XXX: used to be dm.event() # see https://github.com/pyviz/holoviews/issues/3564 try: while not runner.task.done(): event() await asyncio.sleep(update_interval) event() # fire off one last update before we die finally: if active_plotting_tasks[name] is asyncio.Task.current_task(): active_plotting_tasks.pop(name, None) cancel_button.layout.display = "none" # remove cancel button def cancel(_): with suppress(KeyError): active_plotting_tasks[name].cancel() active_plotting_tasks[name] = runner.ioloop.create_task(updater()) cancel_button.on_click(cancel) display(cancel_button) return dm def should_update(status): try: # Get the length of the write buffer size buffer_size = len(status.comm.kernel.iopub_thread._events) # Make sure to only keep all the messages when the notebook # is viewed, this means 'buffer_size == 1'. However, when not # viewing the notebook the buffer fills up. When this happens # we decide to only add messages to it when a certain probability. # i.e. we're offline for 12h, with an update_interval of 0.5s, # and without the reduced probability, we have buffer_size=86400. # With the correction this is np.log(86400) / np.log(1.1) = 119.2 return 1.1 ** buffer_size * random.random() < 1 except Exception: # We catch any Exception because we are using a private API. return True def live_info(runner, *, update_interval=0.5): """Display live information about the runner. Returns an interactive ipywidget that can be visualized in a Jupyter notebook. """ if not _holoviews_enabled: raise RuntimeError( "Live plotting is not enabled; did you run " "'adaptive.notebook_extension()'?" 
) import ipywidgets from IPython.display import display status = ipywidgets.HTML(value=_info_html(runner)) cancel = ipywidgets.Button( description="cancel runner", layout=ipywidgets.Layout(width="100px") ) cancel.on_click(lambda _: runner.cancel()) async def update(): while not runner.task.done(): await asyncio.sleep(update_interval) if should_update(status): status.value = _info_html(runner) else: await asyncio.sleep(0.05) status.value = _info_html(runner) cancel.layout.display = "none" runner.ioloop.create_task(update()) display(ipywidgets.VBox((status, cancel))) def _table_row(i, key, value): """Style the rows of a table. Based on the default Jupyterlab table style.""" style = "text-align: right; padding: 0.5em 0.5em; line-height: 1.0;" if i % 2 == 1: style += " background: var(--md-grey-100);" return f'<tr><th style="{style}">{key}</th><th style="{style}">{value}</th></tr>' def _info_html(runner): status = runner.status() color = { "cancelled": "orange", "failed": "red", "running": "blue", "finished": "green", }[status] overhead = runner.overhead() red_level = max(0, min(int(255 * overhead / 100), 255)) overhead_color = "#{:02x}{:02x}{:02x}".format(red_level, 255 - red_level, 0) info = [ ("status", f'<font color="{color}">{status}</font>'), ("elapsed time", datetime.timedelta(seconds=runner.elapsed_time())), ("overhead", f'<font color="{overhead_color}">{overhead:.2f}%</font>'), ] with suppress(Exception): info.append(("# of points", runner.learner.npoints)) with suppress(Exception): info.append(("latest loss", f'{runner.learner._cache["loss"]:.3f}')) table = "\n".join(_table_row(i, k, v) for i, (k, v) in enumerate(info)) return f""" <table> {table} </table> """