3b2ae110 |
import abc
|
a17c9212 |
import functools
import gzip
import os
import pickle
|
176c745c |
from contextlib import contextmanager
from itertools import product
|
e1100a29 |
|
66404722 |
from atomicwrites import AtomicWriter
|
e1100a29 |
|
d2791d69 |
def named_product(**items):
    """Return the Cartesian product of the keyword iterables as dicts.

    Each keyword argument names an iterable of candidate values. The
    result is a list with one dict per combination, mapping every
    keyword to the value chosen for it, in the order produced by
    ``itertools.product`` (last keyword varies fastest).

    >>> named_product(a=[1, 2], b=[3])
    [{'a': 1, 'b': 3}, {'a': 2, 'b': 3}]
    """
    keys = list(items)
    combinations = []
    for combo in product(*items.values()):
        combinations.append(dict(zip(keys, combo)))
    return combinations
|
5261c860 |
|
71c0fc17 |
|
5261c860 |
@contextmanager
def restore(*learners):
    """Context manager that puts the given objects back in their entry state.

    On entry, ``__getstate__()`` is called on every object and the result
    is remembered. On exit, whether the body finished normally or raised,
    each remembered state is handed back to the same object through
    ``__setstate__``.
    """
    # Snapshot everything before yielding so the body may mutate freely.
    snapshots = [(learner, learner.__getstate__()) for learner in learners]
    try:
        yield
    finally:
        for learner, snapshot in snapshots:
            learner.__setstate__(snapshot)
|
6d5cc14e |
def cache_latest(f):
    """Decorate a method so its most recent return value is remembered.

    After every call the result is stored on the first positional
    argument (the instance) as ``_cache[f.__name__]``. The ``_cache``
    dict is created on the instance the first time it is needed.
    """

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        owner = args[0]
        # Make sure the per-instance store exists before running ``f``.
        if not hasattr(owner, "_cache"):
            owner._cache = {}
        result = f(*args, **kwargs)
        key = f.__name__
        owner._cache[key] = result
        return owner._cache[key]

    return wrapper
|
a17c9212 |
def save(fname, data, compress=True):
    """Pickle ``data`` and write it to ``fname``.

    ``fname`` has a leading ``~`` expanded, and its parent directory is
    created when the path has one. The object is pickled with the highest
    available protocol, gzip-compressed when ``compress`` is true, and the
    resulting bytes are written through ``AtomicWriter`` with
    ``overwrite=True``.
    """
    path = os.path.expanduser(fname)
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)

    payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
    payload = gzip.compress(payload) if compress else payload

    writer = AtomicWriter(path, "wb", overwrite=True)
    with writer.open() as handle:
        handle.write(payload)
|
a17c9212 |
def load(fname, compress=True):
    """Read a pickled object back from ``fname``.

    A leading ``~`` in ``fname`` is expanded. When ``compress`` is true
    the file is opened with ``gzip.open``, otherwise with the builtin
    ``open``; in both cases the content is unpickled and returned.

    NOTE: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.
    """
    path = os.path.expanduser(fname)
    if compress:
        opener = gzip.open
    else:
        opener = open
    with opener(path, "rb") as stream:
        return pickle.load(stream)
def copy_docstring_from(other):
    """Return a decorator that gives a function the metadata of ``other``.

    The decorated function is updated the way ``functools.wraps`` does it:
    besides ``__doc__`` this also copies ``__name__``, ``__qualname__``
    and ``__module__`` from ``other``, merges ``other.__dict__`` into the
    function's ``__dict__``, and sets ``__wrapped__`` to ``other``. The
    same function object is returned.
    """

    def apply(method):
        return functools.update_wrapper(method, other)

    return apply
|
3b2ae110 |
class _RequireAttrsABCMeta(abc.ABCMeta):
def __call__(self, *args, **kwargs):
obj = super().__call__(*args, **kwargs)
|
c5970fba |
for name, type_ in obj.__annotations__.items():
|
2b89d7bf |
try:
x = getattr(obj, name)
|
8681b7fe |
except AttributeError:
|
b431fe3e |
raise AttributeError(
f"Required attribute {name} not set in __init__."
) from None
else:
if not isinstance(x, type_):
msg = f"The attribute '{name}' should be of type {type_}, not {type(x)}."
raise TypeError(msg)
|
3b2ae110 |
return obj
|