Browse code

move SimpleSiteFamily to 'tests.builder'

It is only used in the tests anyway. This way we export less
API from the builder module.

Joseph Weston authored on 26/11/2019 15:08:26
Showing 3 changed files
... ...
@@ -10,7 +10,6 @@ Types
10 10
 
11 11
    Builder
12 12
    HoppingKind
13
-   SimpleSiteFamily
14 13
    BuilderLead
15 14
    SelfEnergyLead
16 15
    ModesLead
... ...
@@ -11,7 +11,7 @@ import warnings
11 11
 import operator
12 12
 import collections
13 13
 import copy
14
-from functools import total_ordering, wraps, update_wrapper
14
+from functools import wraps, update_wrapper
15 15
 from itertools import islice, chain
16 16
 import textwrap
17 17
 import bisect
... ...
@@ -29,42 +29,9 @@ from ._common import (ensure_isinstance, get_parameters, reraise_warnings,
29 29
                       interleave, deprecate_args, memoize)
30 30
 
31 31
 
32
-__all__ = ['Builder', 'SimpleSiteFamily', 'Symmetry', 'HoppingKind', 'Lead',
32
+__all__ = ['Builder', 'Symmetry', 'HoppingKind', 'Lead',
33 33
            'BuilderLead', 'SelfEnergyLead', 'ModesLead', 'add_peierls_phase']
34 34
 
35
-
36
-################ Site families
37
-
38
-@total_ordering
39
-class SimpleSiteFamily(SiteFamily):
40
-    """A site family used as an example and for testing.
41
-
42
-    A family of sites tagged by any python objects where object satisfied
43
-    condition ``object == eval(repr(object))``.
44
-
45
-    It exists to provide a basic site family that can be used for testing the
46
-    builder module without other dependencies.  It can be also used to tag
47
-    sites with non-numeric objects like strings should this every be useful.
48
-
49
-    Due to its low storage efficiency for numbers it is not recommended to use
50
-    `SimpleSiteFamily` when `kwant.lattice.Monatomic` would also work.
51
-    """
52
-
53
-    def __init__(self, name=None, norbs=None):
54
-        canonical_repr = '{0}({1}, {2})'.format(self.__class__, repr(name),
55
-                                                repr(norbs))
56
-        super().__init__(canonical_repr, name, norbs)
57
-
58
-    def normalize_tag(self, tag):
59
-        tag = tuple(tag)
60
-        try:
61
-            if eval(repr(tag)) != tag:
62
-                raise RuntimeError()
63
-        except:
64
-            raise TypeError('It must be possible to recreate the tag from '
65
-                            'its representation.')
66
-        return tag
67
-
68 35
 
69 36
 def validate_hopping(hopping):
70 37
     """Verify that the argument is a valid hopping."""
... ...
@@ -23,12 +23,43 @@ from kwant import builder, system
23 23
 from kwant._common import ensure_rng, KwantDeprecationWarning
24 24
 
25 25
 
26
+@ft.total_ordering
27
+class SimpleSiteFamily(system.SiteFamily):
28
+    """A site family used as an example and for testing.
29
+
30
+    A family of sites tagged by any python objects where object satisfied
31
+    condition ``object == eval(repr(object))``.
32
+
33
+    It exists to provide a basic site family that can be used for testing the
34
+    builder module without other dependencies.  It can be also used to tag
35
+    sites with non-numeric objects like strings should this every be useful.
36
+
37
+    Due to its low storage efficiency for numbers it is not recommended to use
38
+    `SimpleSiteFamily` when `kwant.lattice.Monatomic` would also work.
39
+    """
40
+
41
+    def __init__(self, name=None, norbs=None):
42
+        canonical_repr = '{0}({1}, {2})'.format(self.__class__, repr(name),
43
+                                                repr(norbs))
44
+        super().__init__(canonical_repr, name, norbs)
45
+
46
+    def normalize_tag(self, tag):
47
+        tag = tuple(tag)
48
+        try:
49
+            if eval(repr(tag)) != tag:
50
+                raise RuntimeError()
51
+        except:
52
+            raise TypeError('It must be possible to recreate the tag from '
53
+                            'its representation.')
54
+        return tag
55
+
56
+
26 57
 def test_bad_keys():
27 58
 
28 59
     def setitem(key):
29 60
         syst[key] = None
30 61
 
31
-    fam = builder.SimpleSiteFamily(norbs=1)
62
+    fam = SimpleSiteFamily(norbs=1)
32 63
     syst = builder.Builder()
33 64
 
34 65
     failures = [
... ...
@@ -97,9 +128,9 @@ def test_bad_keys():
97 128
 
98 129
 def test_site_families():
99 130
     syst = builder.Builder()
100
-    fam = builder.SimpleSiteFamily(norbs=1)
101
-    ofam = builder.SimpleSiteFamily(norbs=1)
102
-    yafam = builder.SimpleSiteFamily('another_name', norbs=1)
131
+    fam = SimpleSiteFamily(norbs=1)
132
+    ofam = SimpleSiteFamily(norbs=1)
133
+    yafam = SimpleSiteFamily('another_name', norbs=1)
103 134
 
104 135
     syst[fam(0)] = 7
105 136
     assert syst[fam(0)] == 7
... ...
@@ -117,8 +148,8 @@ def test_site_families():
117 148
     assert fam != 'a'
118 149
 
119 150
     # test site families sorting
120
-    fam1 = builder.SimpleSiteFamily(norbs=1)
121
-    fam2 = builder.SimpleSiteFamily(norbs=2)
151
+    fam1 = SimpleSiteFamily(norbs=1)
152
+    fam2 = SimpleSiteFamily(norbs=2)
122 153
     assert fam1 < fam2  # string '1' is lexicographically less than '2'
123 154
 
124 155
 
... ...
@@ -162,7 +193,7 @@ class VerySimpleSymmetry(builder.Symmetry):
162 193
 # made.
163 194
 def check_construction_and_indexing(sites, sites_fd, hoppings, hoppings_fd,
164 195
                                     unknown_hoppings, sym=None):
165
-    fam = builder.SimpleSiteFamily(norbs=1)
196
+    fam = SimpleSiteFamily(norbs=1)
166 197
     syst = builder.Builder(sym)
167 198
     t, V = 1.0j, 0.0
168 199
     syst[sites] = V
... ...
@@ -212,7 +243,7 @@ def check_construction_and_indexing(sites, sites_fd, hoppings, hoppings_fd,
212 243
 
213 244
 def test_construction_and_indexing():
214 245
     # Without symmetry
215
-    fam = builder.SimpleSiteFamily(norbs=1)
246
+    fam = SimpleSiteFamily(norbs=1)
216 247
     sites = [fam(0, 0), fam(0, 1), fam(1, 0)]
217 248
     hoppings = [(fam(0, 0), fam(0, 1)),
218 249
                 (fam(0, 1), fam(1, 0)),
... ...
@@ -250,7 +281,7 @@ def test_hermitian_conjugation():
250 281
             raise ValueError
251 282
 
252 283
     syst = builder.Builder()
253
-    fam = builder.SimpleSiteFamily(norbs=1)
284
+    fam = SimpleSiteFamily(norbs=1)
254 285
     syst[fam(0)] = syst[fam(1)] = ta.identity(2)
255 286
 
256 287
     syst[fam(0), fam(1)] = f
... ...
@@ -266,7 +297,7 @@ def test_hermitian_conjugation():
266 297
 def test_value_equality_and_identity():
267 298
     m = ta.array([[1, 2], [3j, 4j]])
268 299
     syst = builder.Builder()
269
-    fam = builder.SimpleSiteFamily(norbs=1)
300
+    fam = SimpleSiteFamily(norbs=1)
270 301
 
271 302
     syst[fam(0)] = m
272 303
     syst[fam(1)] = m
... ...
@@ -573,7 +604,7 @@ def test_hamiltonian_evaluation(vectorize):
573 604
     edges = [(0, 1), (0, 2), (0, 3), (1, 2)]
574 605
 
575 606
     syst = builder.Builder(vectorize=vectorize)
576
-    fam = builder.SimpleSiteFamily(norbs=1)
607
+    fam = SimpleSiteFamily(norbs=1)
577 608
     sites = [fam(*tag) for tag in tags]
578 609
     hoppings = [(sites[i], sites[j]) for i, j in edges]
579 610
     if vectorize:
... ...
@@ -664,7 +695,7 @@ def test_vectorized_hamiltonian_evaluation():
664 695
     tags = [(0, 0), (1, 1), (2, 2), (3, 3)]
665 696
     edges = [(0, 1), (0, 2), (0, 3), (1, 2)]
666 697
 
667
-    fam = builder.SimpleSiteFamily(norbs=1)
698
+    fam = SimpleSiteFamily(norbs=1)
668 699
     sites = [fam(*tag) for tag in tags]
669 700
     hops = [(fam(*tags[i]), fam(*tags[j])) for (i, j) in edges]
670 701
 
... ...
@@ -804,7 +835,7 @@ def test_dangling():
804 835
         #       / \
805 836
         #    3-0---2-4-5  6-7  8
806 837
         syst = builder.Builder()
807
-        fam = builder.SimpleSiteFamily(norbs=1)
838
+        fam = SimpleSiteFamily(norbs=1)
808 839
         syst[(fam(i) for i in range(9))] = None
809 840
         syst[[(fam(0), fam(1)), (fam(1), fam(2)), (fam(2), fam(0))]] = None
810 841
         syst[[(fam(0), fam(3)), (fam(2), fam(4)), (fam(4), fam(5))]] = None
... ...
@@ -1077,8 +1108,8 @@ def test_fill_sticky():
1077 1108
 
1078 1109
 
1079 1110
 def test_attach_lead():
1080
-    fam = builder.SimpleSiteFamily(norbs=1)
1081
-    fam_noncommensurate = builder.SimpleSiteFamily(name='other', norbs=1)
1111
+    fam = SimpleSiteFamily(norbs=1)
1112
+    fam_noncommensurate = SimpleSiteFamily(name='other', norbs=1)
1082 1113
 
1083 1114
     syst = builder.Builder()
1084 1115
     syst[fam(1)] = 0
... ...
@@ -1146,7 +1177,7 @@ def test_attach_lead_incomplete_unit_cell():
1146 1177
 def test_neighbors_not_in_single_domain():
1147 1178
     sr = builder.Builder()
1148 1179
     lead = builder.Builder(VerySimpleSymmetry(-1))
1149
-    fam = builder.SimpleSiteFamily(norbs=1)
1180
+    fam = SimpleSiteFamily(norbs=1)
1150 1181
     sr[(fam(x, y) for x in range(3) for y in range(3) if x >= y)] = 0
1151 1182
     sr[builder.HoppingKind((1, 0), fam)] = 1
1152 1183
     sr[builder.HoppingKind((0, 1), fam)] = 1
... ...
@@ -1201,7 +1232,7 @@ def test_closest():
1201 1232
 
1202 1233
 
1203 1234
 def test_update():
1204
-    lat = builder.SimpleSiteFamily(norbs=1)
1235
+    lat = SimpleSiteFamily(norbs=1)
1205 1236
 
1206 1237
     syst = builder.Builder()
1207 1238
     syst[[lat(0,), lat(1,)]] = 1
... ...
@@ -1296,7 +1327,7 @@ def test_invalid_HoppingKind():
1296 1327
 
1297 1328
 
1298 1329
 def test_ModesLead_and_SelfEnergyLead():
1299
-    lat = builder.SimpleSiteFamily(norbs=1)
1330
+    lat = SimpleSiteFamily(norbs=1)
1300 1331
     hoppings = [builder.HoppingKind((1, 0), lat),
1301 1332
                 builder.HoppingKind((0, 1), lat)]
1302 1333
     rng = Random(123)
... ...
@@ -1378,9 +1409,9 @@ def test_site_pickle():
1378 1409
 
1379 1410
 
1380 1411
 def test_discrete_symmetries():
1381
-    lat = builder.SimpleSiteFamily(name='ccc', norbs=2)
1382
-    lat2 = builder.SimpleSiteFamily(name='bla', norbs=1)
1383
-    lat3 = builder.SimpleSiteFamily(name='dd', norbs=4)
1412
+    lat = SimpleSiteFamily(name='ccc', norbs=2)
1413
+    lat2 = SimpleSiteFamily(name='bla', norbs=1)
1414
+    lat3 = SimpleSiteFamily(name='dd', norbs=4)
1384 1415
 
1385 1416
     cons_law = {lat: np.diag([1, 2]), lat2: 0}
1386 1417
     syst = builder.Builder(conservation_law=cons_law,