import random import cloudpickle import pytest from adaptive.learner import ( AverageLearner, BalancingLearner, DataSaver, IntegratorLearner, Learner1D, Learner2D, LearnerND, SequenceLearner, ) from adaptive.runner import simple def goal_1(learner): return learner.npoints >= 10 def goal_2(learner): return learner.npoints >= 20 @pytest.mark.parametrize( "learner_type, learner_kwargs", [ (Learner1D, dict(bounds=(-1, 1))), (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])), (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])), (SequenceLearner, dict(sequence=list(range(100)))), (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)), (AverageLearner, dict(atol=0.1)), ], ) def test_cloudpickle_for(learner_type, learner_kwargs): """Test serializing a learner using cloudpickle. We use cloudpickle because with pickle the functions are only pickled by reference.""" def f(x): return random.random() learner = learner_type(f, **learner_kwargs) simple(learner, goal_1) learner_bytes = cloudpickle.dumps(learner) # Delete references del f del learner learner_loaded = cloudpickle.loads(learner_bytes) assert learner_loaded.npoints >= 10 simple(learner_loaded, goal_2) assert learner_loaded.npoints >= 20 def test_cloudpickle_for_datasaver(): def f(x): return dict(x=1, y=x ** 2) _learner = Learner1D(f, bounds=(-1, 1)) learner = DataSaver(_learner, arg_picker=lambda x: x["y"]) simple(learner, goal_1) learner_bytes = cloudpickle.dumps(learner) # Delete references del f del _learner del learner learner_loaded = cloudpickle.loads(learner_bytes) assert learner_loaded.npoints >= 10 simple(learner_loaded, goal_2) assert learner_loaded.npoints >= 20 def test_cloudpickle_for_balancing_learner(): def f(x): return x ** 2 learner_1 = Learner1D(f, bounds=(-1, 1)) learner_2 = Learner1D(f, bounds=(-2, 2)) learner = BalancingLearner([learner_1, learner_2]) simple(learner, goal_1) learner_bytes = cloudpickle.dumps(learner) # Delete references del f del learner_1 del learner_2 del learner learner_loaded = cloudpickle.loads(learner_bytes) assert learner_loaded.npoints >= 10 simple(learner_loaded, goal_2) assert learner_loaded.npoints >= 20