Coverage for benefits/enrollment_switchio/api.py: 99%
132 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-10 16:52 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-10-10 16:52 +0000
1from dataclasses import dataclass
2from datetime import datetime, timezone
3from enum import Enum
4import hashlib
5import hmac
6import json
7from tempfile import NamedTemporaryFile
8import requests
@dataclass()
class Registration:
    """Response payload of a successful registration request to the tokenization API.

    Instances are built directly from the JSON response body, field names matching its keys.
    """

    # identifier of the newly created registration
    regId: str
    # gateway URL returned alongside the registration
    gtwUrl: str
class RegistrationMode(Enum):
    """Allowed values for the `mode` field of a registration request.

    The member value is the exact string placed in the request body.
    """

    REGISTER = "register"  # sent as "register"
    IDENTIFY = "identify"  # sent as "identify"
class EshopResponseMode(Enum):
    """Allowed values for the `eshopResponseMode` field of a registration request.

    The member value is the exact string placed in the request body.
    """

    FRAGMENT = "fragment"
    QUERY = "query"
    FORM_POST = "form_post"
    POST_MESSAGE = "post_message"
29@dataclass
30class RegistrationStatus:
31 regState: str
32 created: datetime
33 mode: str
34 tokens: list[dict]
35 eshopResponseMode: str
36 identType: str = None
37 maskCln: str = None
38 cardExp: str = None
class Client:
    """Base class for API clients that authenticate with a client certificate (mutual TLS).

    Holds the PEM contents (as strings) of the private key, the client certificate and the
    CA certificate used to verify the server.
    """

    def __init__(
        self,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        self.private_key = private_key
        self.client_certificate = client_certificate
        self.ca_certificate = ca_certificate

    # see https://github.com/cal-itp/benefits/issues/2848 for more context about this
    def _cert_request(self, request_func):
        """
        Run `request_func` with the certificate material materialized as on-disk temp files.
        * request_func: curried callable from `requests` library (e.g. `requests.get`), invoked
          as `request_func(verify=<ca path>, cert=(<cert path>, <key path>))`.
        Returns whatever `request_func` returns; the temp files are removed once it returns.
        """
        # `requests` only accepts file paths for certificates, so the in-memory strings are
        # written to named temp files that exist only for the duration of the request
        with NamedTemporaryFile("w+") as cert_file, NamedTemporaryFile("w+") as key_file, NamedTemporaryFile(
            "w+"
        ) as ca_file:
            pending_writes = (
                (cert_file, self.client_certificate),
                (key_file, self.private_key),
                (ca_file, self.ca_certificate),
            )
            for handle, contents in pending_writes:
                handle.write(contents)
                # seeking pushes buffered data to disk and rewinds, so the file can be read back by path
                handle.seek(0)

            return request_func(verify=ca_file.name, cert=(cert_file.name, key_file.name))
class TokenizationClient(Client):
    """Client for the tokenization (card registration) API.

    Every request carries STP-APIKEY / STP-TIMESTAMP / STP-SIGNATURE headers, where the signature
    is an HMAC-SHA256 hex digest keyed with `api_secret`. Requests are sent with client-certificate
    auth through `Client._cert_request`.
    """

    def __init__(
        self,
        api_url,
        api_key,
        api_secret,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        super().__init__(private_key, client_certificate, ca_certificate)
        # request paths below all start with "/", so only trailing slashes need removing
        self.api_url = api_url.rstrip("/")
        self.api_key = api_key
        self.api_secret = api_secret

    def _signature_input_string(self, timestamp: str, method: str, request_path: str, body: str | None = None) -> str:
        """Return the string to sign: timestamp + method + path + body (empty string when there is no body)."""
        if body is None:
            body = ""

        return f"{timestamp}{method}{request_path}{body}"

    def _stp_signature(self, timestamp: str, method: str, request_path: str, body: str | None = None) -> str:
        """Return the hex HMAC-SHA256 of the signature input string, keyed with `api_secret`."""
        input_string = self._signature_input_string(timestamp, method, request_path, body)

        # must encode inputs for hashing, according to https://stackoverflow.com/a/66958131
        byte_key = self.api_secret.encode("utf-8")
        message = input_string.encode("utf-8")
        stp_signature = hmac.new(byte_key, message, hashlib.sha256).hexdigest()

        return stp_signature

    def _get_headers(self, method, request_path, request_body: dict | None = None) -> dict:
        """Build the authentication headers for one request.

        * method: HTTP method name as sent (e.g. "GET", "POST")
        * request_path: path portion of the URL, beginning with "/"
        * request_body: the dict sent as the JSON body, or None for requests without a body
        """
        # Unix epoch seconds; an aware UTC datetime makes the epoch conversion independent of local time
        timestamp = str(int(datetime.now(timezone.utc).timestamp()))

        # Use an explicit None check: an empty dict is still sent by `requests` as "{}", so it must be
        # signed as "{}" rather than as an empty body.
        # NOTE: assumes `requests` serializes `json=` bodies the same way as json.dumps (default separators)
        body = json.dumps(request_body) if request_body is not None else None

        return {
            "STP-APIKEY": self.api_key,
            "STP-TIMESTAMP": timestamp,
            "STP-SIGNATURE": self._stp_signature(
                timestamp=timestamp,
                method=method,
                request_path=request_path,
                body=body,
            ),
        }

    def request_registration(
        self,
        eshopRedirectUrl: str,
        mode: RegistrationMode,
        eshopResponseMode: EshopResponseMode,
        timeout=5,
    ) -> Registration:
        """POST a new registration and return the parsed `Registration`.

        Raises `requests.HTTPError` (via `raise_for_status`) on an error HTTP status.
        """
        registration_path = "/api/v1/registration"
        request_body = {
            "eshopRedirectUrl": eshopRedirectUrl,
            "mode": mode.value,
            "eshopResponseMode": eshopResponseMode.value,
        }

        response = self._cert_request(
            lambda verify, cert: requests.post(
                self.api_url + registration_path,
                json=request_body,
                headers=self._get_headers(method="POST", request_path=registration_path, request_body=request_body),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return Registration(**response.json())

    def get_registration_status(self, registration_id, timeout=5) -> RegistrationStatus:
        """GET the status of an existing registration and return the parsed `RegistrationStatus`.

        Raises `requests.HTTPError` (via `raise_for_status`) on an error HTTP status.
        """
        request_path = f"/api/v1/registration/{registration_id}"

        response = self._cert_request(
            lambda verify, cert: requests.get(
                self.api_url + request_path,
                headers=self._get_headers(method="GET", request_path=request_path),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return RegistrationStatus(**response.json())
@dataclass()
class Group:
    """A discount group as returned by the enrollment API's group listing.

    Built directly from each JSON object in the response; field names match its keys.
    """

    id: int
    operatorId: int
    name: str
    code: str
    value: int
178@dataclass
179class GroupExpiry:
180 group: str
181 expiresAt: datetime | None = None
183 def __post_init__(self):
184 """Parses any date parameters into aware Python datetime objects.
186 For @dataclasses with a generated __init__ function, this function is called automatically.
188 https://docs.python.org/3.12/library/datetime.html#datetime.datetime.fromisoformat
189 """
190 if self.expiresAt:
191 # per the spec, we expect expiresAt to be in ISO format UTC
192 # make the resulting datetime "aware" by replacing tzinfo explicitly
193 self.expiresAt = datetime.fromisoformat(self.expiresAt).replace(tzinfo=timezone.utc)
194 else:
195 self.expiresAt = None
class EnrollmentClient(Client):
    """Client for the enrollment (discount group) API.

    Every request carries a fixed Authorization header value and is sent with client-certificate
    auth through `Client._cert_request`.
    """

    def __init__(self, api_url, authorization_header_value, private_key, client_certificate, ca_certificate):
        super().__init__(private_key, client_certificate, ca_certificate)
        # request paths below all start with "/", so only trailing slashes need removing
        self.api_url = api_url.rstrip("/")
        self.authorization_header_value = authorization_header_value

    def _format_expiry(self, expiry: datetime) -> str:
        """Formats an expiry datetime into a string suitable for using in an API request body.

        Produces "YYYY-MM-DDTHH:MM:SSZ" in UTC with sub-second precision dropped. Naive datetimes
        are assumed to already be in UTC; aware datetimes are converted to UTC.
        Raises TypeError if `expiry` is not a datetime.
        """
        if not isinstance(expiry, datetime):
            raise TypeError("expiry must be a Python datetime instance")
        # determine if expiry is an "aware" or "naive" datetime instance
        # https://docs.python.org/3/library/datetime.html#determining-if-an-object-is-aware-or-naive
        if expiry.tzinfo is not None and expiry.tzinfo.utcoffset(expiry) is not None:
            # aware: express the same instant in UTC
            expiry = expiry.astimezone(timezone.utc)
        # expiry's wall-clock time is now UTC; drop tzinfo so isoformat() omits the "+00:00"
        # offset, then append the "Z" designator instead
        return f"{expiry.replace(tzinfo=None).isoformat(timespec='seconds')}Z"

    def _get_headers(self):
        """Return the headers sent with every request."""
        return {"Authorization": self.authorization_header_value}

    def _request(self, request_func, request_path, timeout, **kwargs):
        """Send one authenticated request over mutual TLS and return the response.

        * request_func: a `requests` function such as `requests.get` or `requests.post`
        * request_path: path appended to `self.api_url`, beginning with "/"
        * timeout: forwarded to `request_func`
        * kwargs: extra keyword arguments forwarded to `request_func` (e.g. `json=`)

        Raises `requests.HTTPError` (via `raise_for_status`) on an error HTTP status.
        """
        response = self._cert_request(
            lambda verify, cert: request_func(
                self.api_url + request_path,
                headers=self._get_headers(),
                cert=cert,
                verify=verify,
                timeout=timeout,
                **kwargs,
            )
        )

        response.raise_for_status()

        return response

    def healthcheck(self, timeout=5):
        """Call the echo endpoint and return the response body text."""
        return self._request(requests.get, "/api/external/discount/echo", timeout).text

    def get_groups(self, pto_id, timeout=5):
        """Return the discount groups of transit provider `pto_id` as a list of `Group`."""
        response = self._request(requests.get, f"/api/external/discount/{pto_id}/groups", timeout)

        return [Group(**discount_group) for discount_group in response.json()]

    def get_groups_for_token(self, pto_id, token, timeout=5):
        """Return the groups currently linked to `token` as a list of `GroupExpiry`."""
        response = self._request(requests.get, f"/api/external/discount/{pto_id}/token/{token}", timeout)

        return [GroupExpiry(**group_expiry) for group_expiry in response.json()]

    def add_group_to_token(self, pto_id, group_id, token, expiry: datetime | None = None, timeout=5):
        """Link `group_id` to `token`, optionally expiring at `expiry`; return the response body text."""
        request_body = {"group": group_id}
        if expiry:
            request_body["expiresAt"] = self._format_expiry(expiry)

        response = self._request(
            requests.post, f"/api/external/discount/{pto_id}/token/{token}/add", timeout, json=request_body
        )

        return response.text

    def remove_group_from_token(self, pto_id, group_id, token, timeout=5):
        """Unlink `group_id` from `token`; return the response body text."""
        request_body = {"group": group_id}

        response = self._request(
            requests.post, f"/api/external/discount/{pto_id}/token/{token}/remove", timeout, json=request_body
        )

        return response.text