...
|
...
|
@@ -86,6 +86,8 @@ def test_serialization_for(learner_type, learner_kwargs, serializer):
|
86
|
86
|
|
87
|
87
|
simple(learner, goal_1)
|
88
|
88
|
learner_bytes = serializer.dumps(learner)
|
|
89
|
+ loss = learner.loss()
|
|
90
|
+ asked = learner.ask(1)
|
89
|
91
|
|
90
|
92
|
if serializer is not pickle:
|
91
|
93
|
# With pickle the functions are only pickled by reference
|
...
|
...
|
@@ -94,6 +96,15 @@ def test_serialization_for(learner_type, learner_kwargs, serializer):
|
94
|
96
|
|
95
|
97
|
learner_loaded = serializer.loads(learner_bytes)
|
96
|
98
|
assert learner_loaded.npoints >= 10
|
|
99
|
+ assert loss == learner_loaded.loss()
|
|
100
|
+
|
|
101
|
+ if learner_type is not Learner2D:
|
|
102
|
+ # cannot test this for Learner2D because
|
|
103
|
+ # xfailing test_point_adding_order_is_irrelevant
|
|
104
|
+ assert asked == learner_loaded.ask(1)
|
|
105
|
+ # load again to undo the ask
|
|
106
|
+ learner_loaded = serializer.loads(learner_bytes)
|
|
107
|
+
|
97
|
108
|
simple(learner_loaded, goal_2)
|
98
|
109
|
assert learner_loaded.npoints >= 20
|
99
|
110
|
|