adaptive/utils.py
3b2ae110
 import abc
a17c9212
 import functools
 import gzip
 import os
 import pickle
176c745c
 from contextlib import contextmanager
 from itertools import product
e1100a29
 
66404722
 from atomicwrites import AtomicWriter
 
e1100a29
 
d2791d69
 def named_product(**items):
     """Return the Cartesian product of the keyword arguments as dicts.

     Each keyword maps a name to an iterable of candidate values.  The
     result is a list with one dict per combination, mapping every name
     to the value chosen for it, in the order ``itertools.product`` yields.
     """
     keys = tuple(items)
     combinations = []
     for combo in product(*items.values()):
         combinations.append(dict(zip(keys, combo)))
     return combinations
5261c860
 
71c0fc17
 
5261c860
 @contextmanager
 def restore(*learners):
     """Context manager that rolls the given learners back on exit.

     The state of every learner is captured with ``__getstate__`` before
     the body runs and re-applied with ``__setstate__`` afterwards, even
     when the body raises.
     """
     snapshots = []
     for learner in learners:
         snapshots.append((learner, learner.__getstate__()))
     try:
         yield
     finally:
         # Runs on normal exit and on exceptions alike.
         for learner, snapshot in snapshots:
             learner.__setstate__(snapshot)
6d5cc14e
 
 
 def cache_latest(f):
     """Decorate a method so that its most recent return value is kept.

     After each call the result is stored on the instance (the first
     positional argument) as ``self._cache[f.__name__]``.  The ``_cache``
     dict is created on the instance the first time it is needed.
     """

     @functools.wraps(f)
     def wrapper(*args, **kwargs):
         instance = args[0]
         # Make sure the cache exists before the wrapped function runs.
         try:
             instance._cache
         except AttributeError:
             instance._cache = {}
         result = f(*args, **kwargs)
         # Look the cache up again only now, after the call has finished.
         instance._cache[f.__name__] = result
         return result

     return wrapper
a17c9212
 
 
 def save(fname, data, compress=True):
     """Pickle ``data`` and write it to ``fname``.

     ``fname`` is passed through ``os.path.expanduser`` and its parent
     directory is created when missing.  The object is pickled with the
     highest available protocol and, if ``compress`` is true, gzip
     compressed.  The bytes are written through ``AtomicWriter`` with
     ``overwrite=True``.
     """
     fname = os.path.expanduser(fname)
     parent = os.path.dirname(fname)
     if parent:
         # A bare file name has no directory part; nothing to create then.
         os.makedirs(parent, exist_ok=True)

     payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
     payload = gzip.compress(payload) if compress else payload

     writer = AtomicWriter(fname, "wb", overwrite=True)
     with writer.open() as handle:
         handle.write(payload)
a17c9212
 
 
 def load(fname, compress=True):
     """Read a pickled object back from ``fname``.

     ``fname`` is passed through ``os.path.expanduser``.  When
     ``compress`` is true the file is opened with ``gzip.open``,
     otherwise with the builtin ``open``.

     NOTE: the content is unpickled, so only load files from a trusted
     source.
     """
     path = os.path.expanduser(fname)
     opener = open
     if compress:
         opener = gzip.open
     with opener(path, "rb") as handle:
         return pickle.load(handle)
 
 
 def copy_docstring_from(other):
     """Return a decorator that copies the metadata of ``other``.

     The decorated function itself is returned (no wrapper is created),
     updated with what ``functools.wraps`` copies from ``other``: the
     docstring, but also attributes such as ``__name__``, plus a
     ``__wrapped__`` reference to ``other``.
     """

     def decorator(method):
         # Same effect as ``functools.wraps(other)(method)``.
         return functools.update_wrapper(method, other)

     return decorator
3b2ae110
 
 
 class _RequireAttrsABCMeta(abc.ABCMeta):
     def __call__(self, *args, **kwargs):
         obj = super().__call__(*args, **kwargs)
c5970fba
         for name, type_ in obj.__annotations__.items():
2b89d7bf
             try:
                 x = getattr(obj, name)
8681b7fe
             except AttributeError:
b431fe3e
                 raise AttributeError(
                     f"Required attribute {name} not set in __init__."
                 ) from None
             else:
                 if not isinstance(x, type_):
                     msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
                     raise TypeError(msg)
3b2ae110
         return obj