import operator
import pickle
import random

import flaky
import pytest

from adaptive.learner import (
    AverageLearner,
    BalancingLearner,
    DataSaver,
    IntegratorLearner,
    Learner1D,
    Learner2D,
    LearnerND,
    SequenceLearner,
)
from adaptive.runner import simple

try:
    import cloudpickle

    with_cloudpickle = True
except ModuleNotFoundError:
    with_cloudpickle = False

try:
    import dill

    with_dill = True
except ModuleNotFoundError:
    with_dill = False


def goal_1(learner):
    return learner.npoints == 10


def goal_2(learner):
    return learner.npoints == 20


learners_pairs = [
    (Learner1D, dict(bounds=(-1, 1))),
    (Learner2D, dict(bounds=[(-1, 1), (-1, 1)])),
    (LearnerND, dict(bounds=[(-1, 1), (-1, 1), (-1, 1)])),
    (SequenceLearner, dict(sequence=list(range(100)))),
    (IntegratorLearner, dict(bounds=(0, 1), tol=1e-3)),
    (AverageLearner, dict(atol=0.1)),
]

serializers = [pickle]
if with_cloudpickle:
    serializers.append(cloudpickle)
if with_dill:
    serializers.append(dill)

learners = [
    (learner_type, learner_kwargs, serializer)
    for serializer in serializers
    for learner_type, learner_kwargs in learners_pairs
]


def f_for_pickle(x):
    return 1


def f_for_pickle_datasaver(x):
    return dict(x=x, y=x)


@flaky.flaky(max_runs=3)
@pytest.mark.parametrize(
    "learner_type, learner_kwargs, serializer",
    learners,
)
def test_serialization_for(learner_type, learner_kwargs, serializer):
    """Test serializing a learner using different serializers."""

    def f(x):
        return random.random()

    if serializer is pickle:
        # f from the local scope cannot be pickled
        f = f_for_pickle  # noqa: F811

    learner = learner_type(f, **learner_kwargs)

    if learner_type is Learner1D:
        learner._recompute_losses_factor = 1

    simple(learner, goal_1)
    learner_bytes = serializer.dumps(learner)
    loss = learner.loss()
    asked = learner.ask(1)

    if serializer is not pickle:
        # With pickle the functions are only pickled by reference
        del f
        del learner

    learner_loaded = serializer.loads(learner_bytes)
    assert learner_loaded.npoints == 10
    assert loss == learner_loaded.loss()
    if learner_type is not Learner2D:
        # cannot test this for Learner2D because of the
        # xfailing test_point_adding_order_is_irrelevant
        assert asked == learner_loaded.ask(1)

    # load again to undo the ask
    learner_loaded = serializer.loads(learner_bytes)

    simple(learner_loaded, goal_2)
    assert learner_loaded.npoints == 20


@pytest.mark.parametrize(
    "serializer",
    serializers,
)
def test_serialization_for_datasaver(serializer):
    def f(x):
        return dict(x=1, y=x ** 2)

    if serializer is pickle:
        # f from the local scope cannot be pickled
        f = f_for_pickle_datasaver  # noqa: F811

    _learner = Learner1D(f, bounds=(-1, 1))
    learner = DataSaver(_learner, arg_picker=operator.itemgetter("y"))

    simple(learner, goal_1)
    learner_bytes = serializer.dumps(learner)

    if serializer is not pickle:
        # With pickle the functions are only pickled by reference
        del f
        del _learner
        del learner

    learner_loaded = serializer.loads(learner_bytes)
    assert learner_loaded.npoints >= 10

    simple(learner_loaded, goal_2)
    assert learner_loaded.npoints >= 20


@pytest.mark.parametrize(
    "serializer",
    serializers,
)
def test_serialization_for_balancing_learner(serializer):
    def f(x):
        return x ** 2

    if serializer is pickle:
        # f from the local scope cannot be pickled
        f = f_for_pickle  # noqa: F811

    learner_1 = Learner1D(f, bounds=(-1, 1))
    learner_2 = Learner1D(f, bounds=(-2, 2))
    learner = BalancingLearner([learner_1, learner_2])

    simple(learner, goal_1)
    learner_bytes = serializer.dumps(learner)

    if serializer is not pickle:
        # With pickle the functions are only pickled by reference
        del f
        del learner_1
        del learner_2
        del learner

    learner_loaded = serializer.loads(learner_bytes)
    assert learner_loaded.npoints >= 10

    simple(learner_loaded, goal_2)
    assert learner_loaded.npoints >= 20