add a test for invalid value functions
See merge request kwant/kwant!347
... | ... |
@@ -2509,7 +2509,7 @@ def _make_onsite_terms(builder, sites, site_offsets, term_offset): |
2509 | 2509 |
_onsite_term_by_site_id = np.array(_onsite_term_by_site_id) |
2510 | 2510 |
|
2511 | 2511 |
return (onsite_subgraphs, onsite_terms, onsite_term_values, |
2512 |
- onsite_term_parameters, onsite_term_errors, _onsite_term_by_site_id) |
|
2512 |
+ onsite_term_errors, _onsite_term_by_site_id) |
|
2513 | 2513 |
|
2514 | 2514 |
|
2515 | 2515 |
def _make_hopping_terms(builder, graph, sites, site_offsets, cell_size, term_offset): |
... | ... |
@@ -2619,8 +2619,7 @@ def _make_hopping_terms(builder, graph, sites, site_offsets, cell_size, term_off |
2619 | 2619 |
_hopping_term_by_edge_id = np.array(_hopping_term_by_edge_id) |
2620 | 2620 |
|
2621 | 2621 |
return (hopping_subgraphs, hopping_terms, hopping_term_values, |
2622 |
- hopping_term_parameters, hopping_term_errors, |
|
2623 |
- _hopping_term_by_edge_id) |
|
2622 |
+ hopping_term_errors, _hopping_term_by_edge_id) |
|
2624 | 2623 |
|
2625 | 2624 |
|
2626 | 2625 |
class FiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.FiniteVectorizedSystem): |
... | ... |
@@ -2667,12 +2666,11 @@ class FiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.FiniteVect |
2667 | 2666 |
site_offsets = np.cumsum([0] + [len(arr) for arr in site_arrays]) |
2668 | 2667 |
|
2669 | 2668 |
(onsite_subgraphs, onsite_terms, onsite_term_values, |
2670 |
- onsite_term_parameters, onsite_term_errors, _onsite_term_by_site_id) =\ |
|
2669 |
+ onsite_term_errors, _onsite_term_by_site_id) =\ |
|
2671 | 2670 |
_make_onsite_terms(builder, sites, site_offsets, term_offset=0) |
2672 | 2671 |
|
2673 | 2672 |
(hopping_subgraphs, hopping_terms, hopping_term_values, |
2674 |
- hopping_term_parameters, hopping_term_errors, |
|
2675 |
- _hopping_term_by_edge_id) =\ |
|
2673 |
+ hopping_term_errors, _hopping_term_by_edge_id) =\ |
|
2676 | 2674 |
_make_hopping_terms(builder, graph, sites, site_offsets, |
2677 | 2675 |
len(sites), term_offset=len(onsite_terms)) |
2678 | 2676 |
|
... | ... |
@@ -2684,9 +2682,9 @@ class FiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.FiniteVect |
2684 | 2682 |
|
2685 | 2683 |
# Construct system parameters |
2686 | 2684 |
parameters = set() |
2687 |
- for params in chain(onsite_term_parameters, hopping_term_parameters): |
|
2688 |
- if params is not None: |
|
2689 |
- parameters.update(params) |
|
2685 |
+ for term in terms: |
|
2686 |
+ if term.parameters is not None: |
|
2687 |
+ parameters.update(term.parameters) |
|
2690 | 2688 |
parameters = frozenset(parameters) |
2691 | 2689 |
|
2692 | 2690 |
self.site_arrays = site_arrays |
... | ... |
@@ -2979,12 +2977,11 @@ class InfiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.Infinite |
2979 | 2977 |
site_offsets = np.cumsum([0] + [len(arr) for arr in site_arrays]) |
2980 | 2978 |
|
2981 | 2979 |
(onsite_subgraphs, onsite_terms, onsite_term_values, |
2982 |
- onsite_term_parameters, onsite_term_errors, _onsite_term_by_site_id) =\ |
|
2980 |
+ onsite_term_errors, _onsite_term_by_site_id) =\ |
|
2983 | 2981 |
_make_onsite_terms(builder, sites, site_offsets, term_offset=0) |
2984 | 2982 |
|
2985 | 2983 |
(hopping_subgraphs, hopping_terms, hopping_term_values, |
2986 |
- hopping_term_parameters, hopping_term_errors, |
|
2987 |
- _hopping_term_by_edge_id) =\ |
|
2984 |
+ hopping_term_errors, _hopping_term_by_edge_id) =\ |
|
2988 | 2985 |
_make_hopping_terms(builder, graph, sites, site_offsets, |
2989 | 2986 |
cell_size, term_offset=len(onsite_terms)) |
2990 | 2987 |
|
... | ... |
@@ -2996,9 +2993,9 @@ class InfiniteVectorizedSystem(_VectorizedFinalizedBuilderMixin, system.Infinite |
2996 | 2993 |
|
2997 | 2994 |
# Construct system parameters |
2998 | 2995 |
parameters = set() |
2999 |
- for params in chain(onsite_term_parameters, hopping_term_parameters): |
|
3000 |
- if params is not None: |
|
3001 |
- parameters.update(params) |
|
2996 |
+ for term in terms: |
|
2997 |
+ if term.parameters is not None: |
|
2998 |
+ parameters.update(term.parameters) |
|
3002 | 2999 |
parameters = frozenset(parameters) |
3003 | 3000 |
|
3004 | 3001 |
self.site_arrays = site_arrays |
... | ... |
@@ -1547,6 +1547,25 @@ def test_argument_passing(vectorize): |
1547 | 1547 |
expected_hamiltonian(**params) |
1548 | 1548 |
) |
1549 | 1549 |
|
1550 |
+@pytest.mark.parametrize("vectorize", [False, True]) |
|
1551 |
+def test_invalid_value_functions(vectorize): |
|
1552 |
+ |
|
1553 |
+ invalid_value_functions = [ |
|
1554 |
+ lambda _, *args: 1, # uses *args |
|
1555 |
+ lambda _, **kwargs: 1, # uses **kwargs |
|
1556 |
+ lambda _, d=10: 1, # has default arguments |
|
1557 |
+ lambda _, *, d: 1 # has keyword-only parameters |
|
1558 |
+ ] |
|
1559 |
+ |
|
1560 |
+ lat = kwant.lattice.chain(norbs=1) |
|
1561 |
+ |
|
1562 |
+ for f in invalid_value_functions: |
|
1563 |
+ syst = builder.Builder(vectorize=vectorize) |
|
1564 |
+ syst[lat(0)] = f |
|
1565 |
+ syst = syst.finalized() |
|
1566 |
+ with pytest.raises(ValueError): |
|
1567 |
+ syst.hamiltonian_submatrix(params=dict(d=1)) |
|
1568 |
+ |
|
1550 | 1569 |
|
1551 | 1570 |
def test_parameter_substitution(): |
1552 | 1571 |
|