Browse code

Merge branch 'fixup/vectorization' into 'master'

make system terms immutable

See merge request kwant/kwant!348

Joseph Weston authored on 10/12/2019 16:36:20
Showing 3 changed files
... ...
@@ -384,7 +384,7 @@ def _vectorized_make_sparse(subgraphs, hams, long [:] norbs, long [:] orb_offset
384 384
 
385 385
     cdef long i, j, k, m, N, M, P, to_off, from_off,\
386 386
               ta, fa, to_norbs, from_norbs
387
-    cdef long [:] ts_offs, fs_offs
387
+    cdef const long [:] ts_offs, fs_offs
388 388
     cdef complex [:, :, :] h
389 389
 
390 390
     m = 0
... ...
@@ -428,7 +428,7 @@ def _vectorized_make_dense(subgraphs, hams, long [:] norbs, long [:] orb_offsets
428 428
 
429 429
     cdef long i, j, k, N, M, P, to_off, from_off,\
430 430
               ta, fa, to_norbs, from_norbs
431
-    cdef long [:] ts_offs, fs_offs
431
+    cdef const long [:] ts_offs, fs_offs
432 432
     cdef complex [:, :, :] h
433 433
 
434 434
     # This outer loop zip() is pure Python, but that's ok, as it
... ...
@@ -2485,6 +2485,7 @@ def _make_onsite_terms(builder, sites, site_offsets, term_offset):
2485 2485
     tmp = []
2486 2486
     for (_, which), s in zip(onsite_to_term_nr, onsite_subgraphs):
2487 2487
         s = s - site_offsets[which]
2488
+        s.flags.writeable = False
2488 2489
         tmp.append(((which, which), (s, s)))
2489 2490
     onsite_subgraphs = tmp
2490 2491
     # onsite_term_errors[i] contains an exception if the corresponding
... ...
@@ -2593,7 +2594,9 @@ def _make_hopping_terms(builder, graph, sites, site_offsets, cell_size, term_off
2593 2594
         # Transpose to get a pair of arrays rather than array of pairs
2594 2595
         # We use the fact that the underlying array is stored in
2595 2596
         # array-of-pairs order to search through it in 'hamiltonian'.
2596
-        tmp.append(((tail_which, head_which), (h - start).transpose()))
2597
+        pairs = (h - start).transpose()
2598
+        pairs.flags.writeable = False
2599
+        tmp.append(((tail_which, head_which), pairs))
2597 2600
     hopping_subgraphs = tmp
2598 2601
     # hopping_term_errors[i] contains an exception if the corresponding
2599 2602
     # term has a value function with an illegal signature. We only raise
... ...
@@ -471,8 +471,8 @@ def _vectorized_make_onsite_terms(syst, where):
471 471
     ret = []
472 472
     for term_id, which in terms.items():
473 473
         term = syst.terms[term_id]
474
-        ((term_sa, _), (term_sites, _)) = syst.subgraphs[term.subgraph]
475
-        term_sites += site_offsets[term_sa]
474
+        ((term_sa, _), (term_site_offsets, _)) = syst.subgraphs[term.subgraph]
475
+        term_sites = site_offsets[term_sa] + term_site_offsets
476 476
         which = np.asarray(which, dtype=gint_dtype)
477 477
         sites = _select(where, which).reshape(-1)
478 478
         selector = np.searchsorted(term_sites, sites)