Browse code

Merge branch 'fixup/parameter-errors' into 'master'

add a test for invalid value functions

See merge request kwant/kwant!347

Joseph Weston authored on 10/12/2019 17:04:18
Showing 2 changed files
... ...
@@ -2509,7 +2509,7 @@ def _make_onsite_terms(builder, sites, site_offsets, term_offset):
2509 2509
     _onsite_term_by_site_id = np.array(_onsite_term_by_site_id)
2510 2510
 
2511 2511
     return (onsite_subgraphs, onsite_terms, onsite_term_values,
2512
-            onsite_term_parameters, onsite_term_errors, _onsite_term_by_site_id)
2512
+            onsite_term_errors, _onsite_term_by_site_id)
2513 2513
 
2514 2514
 
2515 2515
 def _make_hopping_terms(builder, graph, sites, site_offsets, cell_size, term_offset):
... ...
@@ -2619,8 +2619,7 @@ def _make_hopping_terms(builder, graph, sites, site_offsets, cell_size, term_off
2619 2619
     _hopping_term_by_edge_id = np.array(_hopping_term_by_edge_id)
2620 2620
 
2621 2621
     return (hopping_subgraphs, hopping_terms, hopping_term_values,
2622
-            hopping_term_parameters, hopping_term_errors,
2623
-            _hopping_term_by_edge_id)
2622
+            hopping_term_errors, _hopping_term_by_edge_id)
2624 2623
 
2625 2624
 
2626 2625
 class FiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.FiniteVectorizedSystem):
... ...
@@ -2667,12 +2666,11 @@ class FiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.FiniteVect
2667 2666
         site_offsets = np.cumsum([0] + [len(arr) for arr in site_arrays])
2668 2667
 
2669 2668
         (onsite_subgraphs, onsite_terms, onsite_term_values,
2670
-         onsite_term_parameters, onsite_term_errors, _onsite_term_by_site_id) =\
2669
+         onsite_term_errors, _onsite_term_by_site_id) =\
2671 2670
             _make_onsite_terms(builder, sites, site_offsets, term_offset=0)
2672 2671
 
2673 2672
         (hopping_subgraphs, hopping_terms, hopping_term_values,
2674
-         hopping_term_parameters, hopping_term_errors,
2675
-         _hopping_term_by_edge_id) =\
2673
+         hopping_term_errors, _hopping_term_by_edge_id) =\
2676 2674
             _make_hopping_terms(builder, graph, sites, site_offsets,
2677 2675
                                 len(sites), term_offset=len(onsite_terms))
2678 2676
 
... ...
@@ -2684,9 +2682,9 @@ class FiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.FiniteVect
2684 2682
 
2685 2683
         # Construct system parameters
2686 2684
         parameters = set()
2687
-        for params in chain(onsite_term_parameters, hopping_term_parameters):
2688
-            if params is not None:
2689
-                parameters.update(params)
2685
+        for term in terms:
2686
+            if term.parameters is not None:
2687
+                parameters.update(term.parameters)
2690 2688
         parameters = frozenset(parameters)
2691 2689
 
2692 2690
         self.site_arrays = site_arrays
... ...
@@ -2979,12 +2977,11 @@ class InfiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.Infinite
2979 2977
         site_offsets = np.cumsum([0] + [len(arr) for arr in site_arrays])
2980 2978
 
2981 2979
         (onsite_subgraphs, onsite_terms, onsite_term_values,
2982
-         onsite_term_parameters, onsite_term_errors, _onsite_term_by_site_id) =\
2980
+         onsite_term_errors, _onsite_term_by_site_id) =\
2983 2981
             _make_onsite_terms(builder, sites, site_offsets, term_offset=0)
2984 2982
 
2985 2983
         (hopping_subgraphs, hopping_terms, hopping_term_values,
2986
-         hopping_term_parameters, hopping_term_errors,
2987
-         _hopping_term_by_edge_id) =\
2984
+         hopping_term_errors, _hopping_term_by_edge_id) =\
2988 2985
             _make_hopping_terms(builder, graph, sites, site_offsets,
2989 2986
                                 cell_size, term_offset=len(onsite_terms))
2990 2987
 
... ...
@@ -2996,9 +2993,9 @@ class InfiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.Infinite
2996 2993
 
2997 2994
         # Construct system parameters
2998 2995
         parameters = set()
2999
-        for params in chain(onsite_term_parameters, hopping_term_parameters):
3000
-            if params is not None:
3001
-                parameters.update(params)
2996
+        for term in terms:
2997
+            if term.parameters is not None:
2998
+                parameters.update(term.parameters)
3002 2999
         parameters = frozenset(parameters)
3003 3000
 
3004 3001
         self.site_arrays = site_arrays
... ...
@@ -1547,6 +1547,25 @@ def test_argument_passing(vectorize):
1547 1547
         expected_hamiltonian(**params)
1548 1548
     )
1549 1549
 
1550
+@pytest.mark.parametrize("vectorize", [False, True])
1551
+def test_invalid_value_functions(vectorize):
1552
+
1553
+    invalid_value_functions = [
1554
+        lambda _, *args: 1,  # uses *args
1555
+        lambda _, **kwargs: 1,  # uses **kwargs
1556
+        lambda _, d=10: 1,  # has default arguments
1557
+        lambda _, *, d: 1  # has keyword-only parameters
1558
+    ]
1559
+
1560
+    lat = kwant.lattice.chain(norbs=1)
1561
+
1562
+    for f in invalid_value_functions:
1563
+        syst = builder.Builder(vectorize=vectorize)
1564
+        syst[lat(0)] = f
1565
+        syst = syst.finalized()
1566
+        with pytest.raises(ValueError):
1567
+            syst.hamiltonian_submatrix(params=dict(d=1))
1568
+
1550 1569
 
1551 1570
 def test_parameter_substitution():
1552 1571