import abc
import functools
import gzip
import os
import pickle
from contextlib import contextmanager
from itertools import product

from atomicwrites import AtomicWriter


def named_product(**items):
    """Return the Cartesian product of the keyword arguments as a list of dicts.

    Each keyword maps a name to an iterable of values; every combination of
    values becomes one dict keyed by those names, in ``itertools.product``
    order.

    >>> named_product(a=[1, 2], b=["x"])
    [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'x'}]
    """
    names = items.keys()
    vals = items.values()
    return [dict(zip(names, res)) for res in product(*vals)]


@contextmanager
def restore(*learners):
    """Snapshot the state of ``learners`` and restore it when the block exits.

    State is captured with ``__getstate__`` on entry and written back with
    ``__setstate__`` on exit, also when the body of the ``with`` raises.
    """
    states = [learner.__getstate__() for learner in learners]
    try:
        yield
    finally:
        for state, learner in zip(states, learners):
            learner.__setstate__(state)


def cache_latest(f):
    """Cache the latest return value of the function and add it as 'self._cache[f.__name__]'.

    The wrapped function is still called on every invocation; the decorator
    only records the most recent result on the instance passed as the first
    positional argument, creating ``self._cache`` on first use.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        self = args[0]
        if not hasattr(self, "_cache"):
            self._cache = {}
        self._cache[f.__name__] = f(*args, **kwargs)
        return self._cache[f.__name__]

    return wrapper


def save(fname, data, compress=True):
    """Pickle ``data`` to ``fname``, optionally gzip-compressed.

    ``~`` in ``fname`` is expanded and missing parent directories are
    created. The object is serialized with the highest pickle protocol,
    gzip-compressed when ``compress`` is true, and written through
    ``atomicwrites.AtomicWriter`` with ``overwrite=True`` so an existing file
    is replaced.
    """
    fname = os.path.expanduser(fname)
    dirname = os.path.dirname(fname)
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    blob = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
    if compress:
        blob = gzip.compress(blob)
    with AtomicWriter(fname, "wb", overwrite=True).open() as f:
        f.write(blob)


def load(fname, compress=True):
    """Unpickle and return the object stored in ``fname``.

    ``compress`` must match the value used when the file was written:
    ``True`` reads it through ``gzip.open``, ``False`` through plain ``open``.

    Warning: this uses ``pickle.load``, which can execute arbitrary code while
    deserializing; only load files from trusted sources.
    """
    fname = os.path.expanduser(fname)
    _open = gzip.open if compress else open
    with _open(fname, "rb") as f:
        return pickle.load(f)


def copy_docstring_from(other):
    """Return a decorator that copies ``other.__doc__`` onto the decorated method.

    Only the docstring is copied. ``functools.wraps`` would also overwrite
    ``__name__``/``__qualname__``/``__module__``, set ``__wrapped__`` and merge
    ``other.__dict__`` into the method. When ``other`` is an abstract method,
    that dict contains ``__isabstractmethod__ = True``, which would turn the
    concrete override into an abstract one as well.
    """

    def decorator(method):
        method.__doc__ = other.__doc__
        return method

    return decorator


class _RequireAttrsABCMeta(abc.ABCMeta):
    """ABC metaclass that checks annotated attributes once ``__init__`` returns.

    Every name in the instance's ``__annotations__`` must be set during
    construction, and its value must be an instance of the annotated type.
    """

    def __call__(cls, *args, **kwargs):
        obj = super().__call__(*args, **kwargs)
        # Instance lookup of ``__annotations__`` goes through the class MRO; if
        # no class defines annotations it raises AttributeError, so treat that
        # as "nothing required" instead of failing construction.
        annotations = getattr(obj, "__annotations__", {})
        for name, type_ in annotations.items():
            try:
                x = getattr(obj, name)
            except AttributeError:
                raise AttributeError(
                    f"Required attribute {name} not set in __init__."
                ) from None
            else:
                # NOTE(review): isinstance raises TypeError if an annotation is
                # not a runtime class (e.g. a string under
                # ``from __future__ import annotations``, or ``list[int]``).
                if not isinstance(x, type_):
                    msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
                    raise TypeError(msg)
        return obj