73fa8d28 |
import asyncio
|
b00659d0 |
import datetime
|
06b9200f |
import importlib
|
243561d6 |
import random
|
9fc92402 |
import warnings
|
176c745c |
from contextlib import suppress
|
73fa8d28 |
|
f1e5e354 |
# Module-level flags tracking which notebook integrations are active.
# `notebook_extension` updates the asyncio / holoviews / ipywidgets flags;
# `ensure_plotly` updates the plotly flag.
_async_enabled = False
_holoviews_enabled = _ipywidgets_enabled = _plotly_enabled = False
|
73fa8d28 |
|
9fc92402 |
|
6b170238 |
def notebook_extension(*, _inline_js=True):
    """Enable ipywidgets, holoviews, and asyncio notebook integration.

    Parameters
    ----------
    _inline_js : bool, default: True
        Forwarded as ``inline`` to ``holoviews.notebook_extension``.

    Raises
    ------
    RuntimeError
        If not called from a Jupyter notebook (as detected by ``in_ipynb``).

    Warns
    -----
    RuntimeWarning
        If holoviews or ipywidgets cannot be imported; the corresponding
        feature stays disabled.
    """
    if not in_ipynb():
        raise RuntimeError(
            '"adaptive.notebook_extension()" may only be run '
            "from a Jupyter notebook."
        )

    global _async_enabled, _holoviews_enabled, _ipywidgets_enabled

    # Load holoviews. The flag is reset on every call because the injected
    # JavaScript is gone after closing a notebook, so it must be re-sent.
    _holoviews_enabled = False
    try:
        import holoviews

        holoviews.notebook_extension("bokeh", logo=False, inline=_inline_js)
        _holoviews_enabled = True
    except ModuleNotFoundError:
        warnings.warn(
            "holoviews is not installed; plotting is disabled.", RuntimeWarning
        )

    # Load ipywidgets
    try:
        if not _ipywidgets_enabled:
            import ipywidgets  # noqa: F401

            _ipywidgets_enabled = True
    except ModuleNotFoundError:
        warnings.warn(
            "ipywidgets is not installed; live_info is disabled.", RuntimeWarning
        )

    # Enable asyncio integration. `run_line_magic` replaces the deprecated
    # `InteractiveShell.magic("gui asyncio")` spelling of the same call.
    if not _async_enabled:
        get_ipython().run_line_magic("gui", "asyncio")  # noqa: F821
        _async_enabled = True
|
73fa8d28 |
|
416ee3af |
def ensure_holoviews():
    """Import and return the ``holoviews`` module.

    Raises
    ------
    RuntimeError
        If the import fails with ``ModuleNotFoundError``. The original error
        is attached as ``__cause__`` so the actually-missing module (which may
        be a dependency of holoviews) is visible in the traceback.
    """
    try:
        return importlib.import_module("holoviews")
    except ModuleNotFoundError as err:
        raise RuntimeError(
            "holoviews is not installed; plotting is disabled."
        ) from err
|
416ee3af |
|
b563e096 |
def ensure_plotly():
    """Import and return the ``plotly`` module, initializing notebook mode once.

    On the first successful call the ``plotly.graph_objs``,
    ``plotly.figure_factory`` and ``plotly.offline`` submodules are imported
    and ``plotly.offline.init_notebook_mode()`` is run. The module-level
    ``_plotly_enabled`` flag then prevents repeating that on later calls.

    Raises
    ------
    RuntimeError
        If any of these imports (or ``init_notebook_mode``) raises
        ``ModuleNotFoundError``. The original error is attached as
        ``__cause__`` to show which module was missing.
    """
    global _plotly_enabled
    try:
        import plotly

        if not _plotly_enabled:
            import plotly.graph_objs
            import plotly.figure_factory
            import plotly.offline

            # This injects javascript and should happen only once
            plotly.offline.init_notebook_mode()
            _plotly_enabled = True
        return plotly
    except ModuleNotFoundError as err:
        raise RuntimeError("plotly is not installed; plotting is disabled.") from err
|
b563e096 |
|
f1e5e354 |
def in_ipynb():
    """Report whether the current IPython shell is a ``ZMQInteractiveShell``.

    Inside IPython, ``get_ipython`` is always available as a global. Outside
    IPython the name is undefined, and the ``NameError`` is treated as
    "not in a notebook".
    """
    try:
        shell_class_name = get_ipython().__class__.__name__  # noqa: F821
    except NameError:
        return False
    return shell_class_name == "ZMQInteractiveShell"
|
73fa8d28 |
|
f1e5e354 |
# Fancy displays in the Jupyter notebook

# Maps each `live_plot` name to the asyncio task that keeps that plot updated.
active_plotting_tasks = {}
|
930b4efb |
|
716dbce8 |
def live_plot(runner, *, plotter=None, update_interval=2, name=None, normalize=True):
    """Live plotting of the learner's data.

    Parameters
    ----------
    runner : `~adaptive.Runner`
    plotter : function
        A function that takes the learner as an argument and returns a
        holoviews object. By default ``learner.plot()`` will be called.
    update_interval : int
        Number of seconds between the updates of the plot.
    name : hashable
        Name for the `live_plot` task in `adaptive.active_plotting_tasks`.
        By default the name is None and if another task with the same name
        already exists that other `live_plot` is canceled.
    normalize : bool
        Normalize (scale to fit) the frame upon each update.

    Returns
    -------
    dm : `holoviews.core.DynamicMap`
        The plot that automatically updates every `update_interval`.

    Raises
    ------
    RuntimeError
        If holoviews has not been enabled by ``adaptive.notebook_extension()``.
    """
    if not _holoviews_enabled:
        raise RuntimeError(
            "Live plotting is not enabled; did you run "
            "'adaptive.notebook_extension()'?"
        )

    import holoviews as hv
    import ipywidgets
    from IPython.display import display

    # Only one live plot per name: stop the previous one, if any.
    if name in active_plotting_tasks:
        active_plotting_tasks[name].cancel()

    def plot_generator():
        while True:
            if not plotter:
                yield runner.learner.plot()
            else:
                yield plotter(runner.learner)

    streams = [hv.streams.Stream.define("Next")()]
    dm = hv.DynamicMap(plot_generator(), streams=streams)
    dm.cache_size = 1

    if normalize:
        # XXX: change when https://github.com/pyviz/holoviews/issues/3637
        # is fixed.
        dm = dm.map(lambda obj: obj.opts(framewise=True), hv.Element)

    cancel_button = ipywidgets.Button(
        description="cancel live-plot", layout=ipywidgets.Layout(width="150px")
    )

    # Could have used dm.periodic in the following, but this would either spin
    # off a thread (and learner is not threadsafe) or block the kernel.
    async def updater():
        event = lambda: hv.streams.Stream.trigger(  # noqa: E731
            dm.streams
        )  # XXX: used to be dm.event()
        # see https://github.com/pyviz/holoviews/issues/3564
        try:
            while not runner.task.done():
                event()
                await asyncio.sleep(update_interval)
            event()  # fire off one last update before we die
        finally:
            # Deregister only if this task is still the registered one; a newer
            # live_plot with the same name may already have replaced it.
            # `.get` avoids a KeyError if the entry is already gone.
            if active_plotting_tasks.get(name) is _current_task():
                active_plotting_tasks.pop(name, None)
            cancel_button.layout.display = "none"  # remove cancel button

    def cancel(_):
        with suppress(KeyError):
            active_plotting_tasks[name].cancel()

    active_plotting_tasks[name] = runner.ioloop.create_task(updater())
    cancel_button.on_click(cancel)

    display(cancel_button)
    return dm


def _current_task():
    """Return the asyncio task that is currently running.

    ``asyncio.Task.current_task`` was removed in Python 3.9. The module-level
    ``asyncio.current_task`` (Python 3.7+) is preferred when it exists.
    """
    try:
        get_current = asyncio.current_task
    except AttributeError:  # Python < 3.7
        get_current = asyncio.Task.current_task
    return get_current()
|
b00659d0 |
|
243561d6 |
def should_update(status):
    """Decide whether a live-info refresh should be sent to the frontend now.

    Parameters
    ----------
    status : widget
        A widget whose kernel IOPub buffer is inspected. ``live_info`` passes
        its ``ipywidgets.HTML`` status widget.

    Returns
    -------
    bool
        True if the update should be sent. It is always True when the buffer
        size cannot be read. Otherwise it is True with probability
        ``1.1 ** -buffer_size``.
    """
    try:
        # Get the length of the write buffer size.
        # We catch any Exception because we are using a private API.
        buffer_size = len(status.comm.kernel.iopub_thread._events)
    except Exception:
        return True

    # Make sure to only keep all the messages when the notebook
    # is viewed, this means 'buffer_size == 1'. However, when not
    # viewing the notebook the buffer fills up. When this happens
    # we decide to only add messages to it with a certain probability.
    # i.e. we're offline for 12h, with an update_interval of 0.5s,
    # and without the reduced probability, we have buffer_size=86400.
    # With the correction this is np.log(86400) / np.log(1.1) = 119.2
    #
    # Mathematically equivalent to `1.1 ** buffer_size * random() < 1`.
    # The negative power underflows harmlessly to 0.0 for huge buffers,
    # whereas the positive power raises OverflowError once buffer_size
    # exceeds roughly 7400.
    return random.random() < 1.1 ** -buffer_size
|
b00659d0 |
def live_info(runner, *, update_interval=0.5):
    """Display live information about the runner.

    Displays an interactive ipywidget in the Jupyter notebook. The widget
    holds an HTML status table and a "cancel runner" button, and the table
    is refreshed from an asyncio task until the runner's task is done.
    Returns None.

    Parameters
    ----------
    runner : `~adaptive.Runner`
        The runner to monitor. The button calls ``runner.cancel()``.
    update_interval : float
        Seconds to wait between refresh attempts.

    Raises
    ------
    RuntimeError
        If ipywidgets has not been enabled by ``adaptive.notebook_extension()``.
    """
    # live_info only needs ipywidgets (notebook_extension's warning says a
    # missing ipywidgets disables live_info), so don't gate it on holoviews.
    if not _ipywidgets_enabled:
        raise RuntimeError(
            "Live info is not enabled; did you run "
            "'adaptive.notebook_extension()'?"
        )

    import ipywidgets
    from IPython.display import display

    status = ipywidgets.HTML(value=_info_html(runner))

    cancel = ipywidgets.Button(
        description="cancel runner", layout=ipywidgets.Layout(width="100px")
    )
    cancel.on_click(lambda _: runner.cancel())

    async def update():
        while not runner.task.done():
            await asyncio.sleep(update_interval)

            if should_update(status):
                status.value = _info_html(runner)
            else:
                # Refresh skipped (see `should_update`); wait a bit extra.
                await asyncio.sleep(0.05)

        # Final refresh once the runner is done, then hide the cancel button.
        status.value = _info_html(runner)
        cancel.layout.display = "none"

    runner.ioloop.create_task(update())

    display(ipywidgets.VBox((status, cancel)))
def _table_row(i, key, value):
"""Style the rows of a table. Based on the default Jupyterlab table style."""
|
dfeb5ef4 |
style = "text-align: right; padding: 0.5em 0.5em; line-height: 1.0;"
if i % 2 == 1:
style += " background: var(--md-grey-100);"
return f'<tr><th style="{style}">{key}</th><th style="{style}">{value}</th></tr>'
|
b00659d0 |
def _info_html(runner):
    """Render the runner's current state as an HTML table string.

    The table always contains:

    - the status, color-coded;
    - the elapsed time;
    - the overhead, shaded from green (0%) to red (100% or more).

    The number of points and the latest loss are appended only when they can
    be read from the learner; any error while reading them is ignored.
    """
    state = runner.status()
    palette = {
        "cancelled": "orange",
        "failed": "red",
        "running": "blue",
        "finished": "green",
    }
    state_color = palette[state]

    overhead = runner.overhead()
    # Clamp the red channel to 0..255; the green channel is its complement.
    red = min(max(int(255 * overhead / 100), 0), 255)
    overhead_color = f"#{red:02x}{255 - red:02x}00"

    rows = [
        ("status", f'<font color="{state_color}">{state}</font>'),
        ("elapsed time", datetime.timedelta(seconds=runner.elapsed_time())),
        ("overhead", f'<font color="{overhead_color}">{overhead:.2f}%</font>'),
    ]

    # Best-effort extras: not every learner exposes these.
    with suppress(Exception):
        rows.append(("# of points", runner.learner.npoints))
    with suppress(Exception):
        rows.append(("latest loss", f'{runner.learner._cache["loss"]:.3f}'))

    body = "\n".join(
        _table_row(index, label, value) for index, (label, value) in enumerate(rows)
    )
    return f"\n<table>\n{body}\n</table>\n"
|