Browse code

update TranslationalSymmetry to work with SiteArrays

Now 'which' and 'act' (and, by extension, 'to_fd') now accept
SiteArrays as well as Sites.

Joseph Weston authored on 09/09/2019 12:02:36
Showing 2 changed files
... ...
@@ -139,19 +139,44 @@ class Symmetry(metaclass=abc.ABCMeta):
139 139
     def which(self, site):
140 140
         """Calculate the domain of the site.
141 141
 
142
-        Return the group element whose action on a certain site from the
143
-        fundamental domain will result in the given ``site``.
142
+        Parameters
143
+        ----------
144
+        site : `~kwant.system.Site` or `~kwant.system.SiteArray`
145
+
146
+        Returns
147
+        -------
148
+        group_element : tuple or sequence of tuples
149
+            A single tuple if ``site`` is a Site, or a sequence of tuples if
150
+            ``site`` is a SiteArray.  The group element(s) whose action
151
+            on a certain site(s) from the fundamental domain will result
152
+            in the given ``site``.
144 153
         """
145 154
         pass
146 155
 
147 156
     @abc.abstractmethod
148 157
     def act(self, element, a, b=None):
149
-        """Act with a symmetry group element on a site or hopping."""
158
+        """Act with symmetry group element(s) on site(s) or hopping(s).
159
+
160
+        Parameters
161
+        ----------
162
+        element : tuple or sequence of tuples
163
+            Group element(s) with which to act on the provided site(s)
164
+            or hopping(s)
165
+        a, b : `~kwant.system.Site` or `~kwant.system.SiteArray`
166
+            If Site then ``element`` is a single tuple, if SiteArray then
167
+            ``element`` is a sequence of tuples. If only ``a`` is provided then
168
+            ``element`` acts on the site(s) of ``a``. If ``b`` is also provided
169
+            then ``element`` acts on the hopping(s) ``(a, b)``.
170
+        """
150 171
         pass
151 172
 
152 173
     def to_fd(self, a, b=None):
153 174
         """Map a site or hopping to the fundamental domain.
154 175
 
176
+        Parameters
177
+        ----------
178
+        a, b : `~kwant.system.Site` or `~kwant.system.SiteArray`
179
+
155 180
         If ``b`` is None, return a site equivalent to ``a`` within the
156 181
         fundamental domain.  Otherwise, return a hopping equivalent to ``(a,
157 182
         b)`` but where the first element belongs to the fundamental domain.
... ...
@@ -161,11 +186,30 @@ class Symmetry(metaclass=abc.ABCMeta):
161 186
         return self.act(-self.which(a), a, b)
162 187
 
163 188
     def in_fd(self, site):
164
-        """Tell whether ``site`` lies within the fundamental domain."""
165
-        for d in self.which(site):
166
-            if d != 0:
167
-                return False
168
-        return True
189
+        """Tell whether ``site`` lies within the fundamental domain.
190
+
191
+        Parameters
192
+        ----------
193
+        site : `~kwant.system.Site` or `~kwant.system.SiteArray`
194
+
195
+        Returns
196
+        -------
197
+        in_fd : bool or sequence of bool
198
+            single bool if ``site`` is a Site, or a sequence of
199
+            bool if ``site`` is a SiteArray. In the latter case
200
+            we return whether each site in the SiteArray is in
201
+            the fundamental domain.
202
+        """
203
+        if isinstance(site, Site):
204
+            for d in self.which(site):
205
+                if d != 0:
206
+                    return False
207
+            return True
208
+        elif isinstance(site, SiteArray):
209
+            which = self.which(site)
210
+            return np.logical_and.reduce(which != 0, axis=1)
211
+        else:
212
+            raise TypeError("'site' must be a Site or SiteArray")
169 213
 
170 214
     @abc.abstractmethod
171 215
     def subgroup(self, *generators):
... ...
@@ -698,26 +698,48 @@ class TranslationalSymmetry(builder.Symmetry):
698 698
 
699 699
     def which(self, site):
700 700
         det_x_inv_m_part, det_m = self._get_site_family_data(site.family)[-2:]
701
-        result = ta.dot(det_x_inv_m_part, site.tag) // det_m
701
+        if isinstance(site, system.Site):
702
+            result = ta.dot(det_x_inv_m_part, site.tag) // det_m
703
+        elif isinstance(site, system.SiteArray):
704
+            result = np.dot(det_x_inv_m_part, site.tags.transpose()) // det_m
705
+        else:
706
+            raise TypeError("'site' must be a Site or a SiteArray")
707
+
702 708
         return -result if self.is_reversed else result
703 709
 
704 710
     def act(self, element, a, b=None):
705
-        element = ta.array(element)
706
-        if element.dtype is not int:
711
+        is_site = isinstance(a, system.Site)
712
+        # Tinyarray for small arrays (single site) else numpy
713
+        array_mod = ta if is_site else np
714
+        element = array_mod.array(element)
715
+        if not np.issubdtype(element.dtype, np.integer):
707 716
             raise ValueError("group element must be a tuple of integers")
717
+        if (len(element.shape) == 2 and is_site):
718
+            raise ValueError("must provide a single group element when "
719
+                             "acting on single sites.")
720
+        if (len(element.shape) == 1 and not is_site):
721
+            raise ValueError("must provide a sequence of group elements "
722
+                             "when acting on site arrays.")
708 723
         m_part = self._get_site_family_data(a.family)[0]
709 724
         try:
710
-            delta = ta.dot(m_part, element)
725
+            delta = array_mod.dot(m_part, element)
711 726
         except ValueError:
712 727
             msg = 'Expecting a {0}-tuple group element, but got `{1}` instead.'
713 728
             raise ValueError(msg.format(self.num_directions, element))
714 729
         if self.is_reversed:
715 730
             delta = -delta
716 731
         if b is None:
717
-            return builder.Site(a.family, a.tag + delta, True)
732
+            if is_site:
733
+                return system.Site(a.family, a.tag + delta, True)
734
+            else:
735
+                return system.SiteArray(a.family, a.tags + delta.transpose())
718 736
         elif b.family == a.family:
719
-            return (builder.Site(a.family, a.tag + delta, True),
720
-                    builder.Site(b.family, b.tag + delta, True))
737
+            if is_site:
738
+                return (system.Site(a.family, a.tag + delta, True),
739
+                        system.Site(b.family, b.tag + delta, True))
740
+            else:
741
+                return (system.SiteArray(a.family, a.tags + delta.transpose()),
742
+                        system.SiteArray(b.family, b.tags + delta.transpose()))
721 743
         else:
722 744
             m_part = self._get_site_family_data(b.family)[0]
723 745
             try:
... ...
@@ -728,8 +750,12 @@ class TranslationalSymmetry(builder.Symmetry):
728 750
                 raise ValueError(msg.format(self.num_directions, element))
729 751
             if self.is_reversed:
730 752
                 delta2 = -delta2
731
-            return (builder.Site(a.family, a.tag + delta, True),
732
-                    builder.Site(b.family, b.tag + delta2, True))
753
+            if is_site:
754
+                return (system.Site(a.family, a.tag + delta, True),
755
+                        system.Site(b.family, b.tag + delta2, True))
756
+            else:
757
+                return (system.SiteArray(a.family, a.tags + delta.transpose()),
758
+                        system.SiteArray(b.family, b.tags + delta2.transpose()))
733 759
 
734 760
     def reversed(self):
735 761
         """Return a reversed copy of the symmetry.