Browse code

make separate methods to get JSON or text responses

aiohttp is very keen on context managers, and its easier
not to pass the response object around everywhere, but deal with
it completely in a single place.

Joseph Weston authored on 08/09/2017 18:45:30
Showing 1 changed files
... ...
@@ -79,9 +79,15 @@ class Client:
79 79
     async def __aexit__(self, exc_typ, exc, traceback):
80 80
         self.close()
81 81
 
82
-    async def _get(self, endpoint):
83
-        return self._session.get(''.join(self.api_url, endpoint),
84
-                                 headers=self.headers)
82
+    async def _get_json(self, endpoint):
83
+        url = ''.join((self.api_url, endpoint))
84
+        async with self._session.get(url, headers=self.headers) as resp:
85
+            return await resp.json()
86
+
87
+    async def _get_text(self, endpoint):
88
+        url = ''.join((self.api_url, endpoint))
89
+        async with self._session.get(url, headers=self.headers) as resp:
90
+            return await resp.text()
85 91
 
86 92
     # API methods
87 93
 
... ...
@@ -97,8 +103,8 @@ class Client:
97 103
         protocol: str, 'tcp' or 'udp'
98 104
         """
99 105
         host = normalized_hostname(host)
100
-        resp = await self._get(f'files/download/{_config_filename(host, protocol)}')
101
-        return await resp.text()
106
+        endpoint = f'files/download/{_config_filename(host, protocol)}'
107
+        return await self._get_text(endpoint)
102 108
 
103 109
 
104 110
     async def host_load(self, host=None):
... ...
@@ -120,8 +126,7 @@ class Client:
120 126
         if host:
121 127
             host = normalized_hostname(host)
122 128
         endpoint = f'server/stats/{host}' if host else 'server/stats'
123
-        resp = await self._get(endpoint)
124
-        resp = await resp.json()
129
+        resp = await self._get_json(endpoint)
125 130
         if host:
126 131
             if len(resp) != 1:
127 132
                 # Nord API returns load on all hosts if 'host' does not exist.
... ...
@@ -133,8 +138,7 @@ class Client:
133 138
 
134 139
     async def current_ip(self):
135 140
         """Return our current public IP address, as detected by NordVPN."""
136
-        resp = await self._get('user/address')
137
-        return await resp.text()
141
+        return await self._get_text('user/address')
138 142
 
139 143
 
140 144
     @async_lru_cache()
... ...
@@ -146,16 +150,14 @@ class Client:
146 150
         host_info : (dict: str → dict)
147 151
             A map from hostnames to host info dictionaries.
148 152
         """
149
-        resp = await self._get('server')
150
-        info = await resp.json()
153
+        info = await self._get_json('server')
151 154
         return {h['domain']: h for h in info}
152 155
 
153 156
 
154 157
     @async_lru_cache()
155 158
     async def dns_servers(self):
156 159
         """Return a list of ip addresses of NordVPN DNS servers."""
157
-        resp = await self._get('dns/smart')
158
-        return await resp.json()
160
+        return await self._get_json('dns/smart')
159 161
 
160 162
 
161 163
     async def valid_credentials(self, username, password):
... ...
@@ -170,13 +172,11 @@ class Client:
170 172
         ----------
171 173
         username, password : str
172 174
         """
173
-        resp = await self._get('token/token/{username}')
174
-        resp = await resp.json()
175
+        resp = await self._get_json(f'token/token/{username}')
175 176
         token, salt, key = (resp[k] for k in ['token', 'salt', 'key'])
176 177
 
177 178
         round1 = sha512(salt.encode() + password.encode())
178 179
         round2 = sha512(round1.hexdigest().encode() + key.encode())
179 180
         response = round2.hexdigest()
180 181
 
181
-        resp = await self._get(f'token/verify/{token}/{response}')
182
-        return await resp.json()
182
+        return await self._get_json(f'token/verify/{token}/{response}')