Coverage for benefits / enrollment_switchio / api.py: 99%
132 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 19:35 +0000
1import hashlib
2import hmac
3import json
4from dataclasses import dataclass
5from datetime import datetime, timezone
6from enum import Enum
7from tempfile import NamedTemporaryFile
9import requests
@dataclass
class Registration:
    """A newly created registration, as returned by the registration API.

    Field names mirror the JSON keys of the response body, so an instance can be
    built directly with ``Registration(**response_json)``.
    """

    # identifier of the registration
    regId: str
    # gateway URL returned alongside the registration id
    gtwUrl: str
class RegistrationMode(Enum):
    """Mode requested when starting a registration; ``.value`` is sent in the request body."""

    REGISTER = "register"  # sent as "register"
    IDENTIFY = "identify"  # sent as "identify"
class EshopResponseMode(Enum):
    """How the registration result is handed back to the eshop; ``.value`` is sent in the request body."""

    FRAGMENT = "fragment"
    QUERY = "query"
    FORM_POST = "form_post"
    POST_MESSAGE = "post_message"
30@dataclass
31class RegistrationStatus:
32 regState: str
33 created: datetime
34 mode: str
35 tokens: list[dict]
36 eshopResponseMode: str
37 identType: str = None
38 maskCln: str = None
39 cardExp: str = None
class Client:
    """Base class for API clients that authenticate with a client certificate.

    The certificate, private key and CA certificate are held as text and written to
    temporary files for the duration of each request (see ``_cert_request``).
    """

    def __init__(
        self,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        self.private_key = private_key
        self.client_certificate = client_certificate
        self.ca_certificate = ca_certificate

    # see https://github.com/cal-itp/benefits/issues/2848 for more context about this
    def _cert_request(self, request_func):
        """
        Run ``request_func`` with the client certificate material available as on-disk files.

        * request_func: callable accepting ``verify`` (path to the CA file) and ``cert``
          (tuple of certificate path and key path), e.g. a lambda wrapping ``requests.get``.

        Returns whatever ``request_func`` returns. The temp files are removed as soon as it returns.
        """
        # requests reads certificates by file path, so the material must live on disk;
        # leaving the "with" block deletes all three files
        with NamedTemporaryFile("w+") as cert_file, NamedTemporaryFile("w+") as key_file, NamedTemporaryFile(
            "w+"
        ) as ca_file:
            pending_writes = (
                (cert_file, self.client_certificate),
                (key_file, self.private_key),
                (ca_file, self.ca_certificate),
            )
            for handle, contents in pending_writes:
                handle.write(contents)
                # rewinding also flushes the buffer so requests sees the full contents
                handle.seek(0)

            return request_func(verify=ca_file.name, cert=(cert_file.name, key_file.name))
class TokenizationClient(Client):
    """Client for the tokenization (registration) API.

    Every request uses client-certificate authentication (via ``_cert_request``) and
    carries an HMAC-SHA256 signature in the ``STP-*`` headers built by ``_get_headers``.
    """

    def __init__(
        self,
        api_url,
        api_key,
        api_secret,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        super().__init__(private_key, client_certificate, ca_certificate)
        # strip slashes so request paths (which start with "/") can be appended directly
        self.api_url = api_url.strip("/")
        self.api_key = api_key
        self.api_secret = api_secret

    def _signature_input_string(self, timestamp: str, method: str, request_path: str, body: str | None = None) -> str:
        """Build the string to sign: timestamp + method + path + body (empty when there is no body)."""
        if body is None:
            body = ""

        return f"{timestamp}{method}{request_path}{body}"

    def _stp_signature(self, timestamp: str, method: str, request_path: str, body: str | None = None) -> str:
        """Return the hex-encoded HMAC-SHA256 of the signature input, keyed with the API secret."""
        input_string = self._signature_input_string(timestamp, method, request_path, body)

        # must encode inputs for hashing, according to https://stackoverflow.com/a/66958131
        byte_key = self.api_secret.encode("utf-8")
        message = input_string.encode("utf-8")
        return hmac.new(byte_key, message, hashlib.sha256).hexdigest()

    def _get_headers(self, method, request_path, request_body: dict | None = None):
        """Build the authentication headers for one request.

        * method: HTTP method name, e.g. "GET" or "POST".
        * request_path: path portion of the URL that is signed.
        * request_body: the dict that will be sent as JSON, or None when there is no body.
        """
        # whole seconds since the Unix epoch
        timestamp = str(int(datetime.now(timezone.utc).timestamp()))

        # Sign the body whenever one is sent. An empty dict is still serialized and sent
        # as "{}" by requests, so it must be part of the signature too (a truthiness
        # check would drop it).
        # NOTE(review): assumes requests serializes json= with the same default settings
        # as json.dumps, so the signed text matches the bytes on the wire.
        body = json.dumps(request_body) if request_body is not None else None

        return {
            "STP-APIKEY": self.api_key,
            "STP-TIMESTAMP": timestamp,
            "STP-SIGNATURE": self._stp_signature(
                timestamp=timestamp,
                method=method,
                request_path=request_path,
                body=body,
            ),
        }

    def request_registration(
        self,
        eshopRedirectUrl: str,
        mode: RegistrationMode,
        eshopResponseMode: EshopResponseMode,
        timeout=5,
    ) -> Registration:
        """POST a new registration and return the parsed ``Registration``.

        Raises ``requests.HTTPError`` (via ``raise_for_status``) on a 4xx/5xx response.
        """
        registration_path = "/api/v1/registration"
        request_body = {
            "eshopRedirectUrl": eshopRedirectUrl,
            "mode": mode.value,
            "eshopResponseMode": eshopResponseMode.value,
        }

        response = self._cert_request(
            lambda verify, cert: requests.post(
                self.api_url + registration_path,
                json=request_body,
                # headers are built inside the lambda so the timestamp is taken at send time
                headers=self._get_headers(method="POST", request_path=registration_path, request_body=request_body),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return Registration(**response.json())

    def get_registration_status(self, registration_id, timeout=5) -> RegistrationStatus:
        """GET the status of ``registration_id`` and return the parsed ``RegistrationStatus``.

        Raises ``requests.HTTPError`` (via ``raise_for_status``) on a 4xx/5xx response.
        """
        request_path = f"/api/v1/registration/{registration_id}"

        response = self._cert_request(
            lambda verify, cert: requests.get(
                self.api_url + request_path,
                headers=self._get_headers(method="GET", request_path=request_path),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return RegistrationStatus(**response.json())
@dataclass
class Group:
    """A discount group, as returned by the enrollment API's group listing.

    Field names mirror the JSON keys of each list item, so an instance can be built
    directly with ``Group(**item)``.
    """

    id: int
    operatorId: int
    name: str
    code: str
    value: int
179@dataclass
180class GroupExpiry:
181 group: str
182 expiresAt: datetime | None = None
184 def __post_init__(self):
185 """Parses any date parameters into aware Python datetime objects.
187 For @dataclasses with a generated __init__ function, this function is called automatically.
189 https://docs.python.org/3.12/library/datetime.html#datetime.datetime.fromisoformat
190 """
191 if self.expiresAt:
192 # per the spec, we expect expiresAt to be in ISO format UTC
193 # make the resulting datetime "aware" by replacing tzinfo explicitly
194 self.expiresAt = datetime.fromisoformat(self.expiresAt).replace(tzinfo=timezone.utc)
195 else:
196 self.expiresAt = None
class EnrollmentClient(Client):
    """Client for the discount-group enrollment API.

    Every request uses client-certificate authentication (via ``_cert_request``) and sends
    the fixed ``Authorization`` header value supplied at construction.
    """

    def __init__(self, api_url, authorization_header_value, private_key, client_certificate, ca_certificate):
        super().__init__(private_key, client_certificate, ca_certificate)
        # stripped once here, so request paths (which start with "/") can be appended directly
        self.api_url = api_url.strip("/")
        self.authorization_header_value = authorization_header_value

    def _format_expiry(self, expiry: datetime) -> str:
        """Format an expiry datetime as a UTC string like ``2024-01-01T12:00:00Z`` for a request body.

        Naive datetimes are assumed to already be in UTC; aware ones are converted to UTC.
        Sub-second precision is dropped.

        Raises TypeError if ``expiry`` is not a ``datetime``.
        """
        if not isinstance(expiry, datetime):
            raise TypeError("expiry must be a Python datetime instance")
        # determine if expiry is an "aware" or "naive" datetime instance
        # https://docs.python.org/3/library/datetime.html#determining-if-an-object-is-aware-or-naive
        if expiry.tzinfo is not None and expiry.tzinfo.utcoffset(expiry) is not None:
            # aware: express the same instant in UTC
            expiry = expiry.astimezone(timezone.utc)
        else:
            # naive: assume this datetime instance was provided in UTC
            expiry = expiry.replace(tzinfo=timezone.utc)
        # isoformat() of a UTC datetime ends with "+00:00"; swap those 6 characters for "Z"
        return f"{expiry.isoformat(timespec='seconds')[:-6]}Z"

    def _get_headers(self):
        return {"Authorization": self.authorization_header_value}

    def _send_request(self, request_func, request_path, timeout, **kwargs):
        """Send one authenticated request and return the response.

        * request_func: the ``requests`` function to call (``requests.get`` or ``requests.post``).
        * request_path: path appended to ``api_url``; must start with "/".
        * kwargs: extra keyword arguments forwarded to ``request_func`` (e.g. ``json=``).

        Raises ``requests.HTTPError`` (via ``raise_for_status``) on a 4xx/5xx response.
        """
        response = self._cert_request(
            lambda verify, cert: request_func(
                self.api_url + request_path,
                headers=self._get_headers(),
                cert=cert,
                verify=verify,
                timeout=timeout,
                **kwargs,
            )
        )

        response.raise_for_status()

        return response

    def healthcheck(self, timeout=5):
        """Call the echo endpoint and return the response body as text."""
        return self._send_request(requests.get, "/api/external/discount/echo", timeout).text

    def get_groups(self, pto_id, timeout=5):
        """Return the discount groups configured for ``pto_id`` as a list of ``Group``."""
        request_path = f"/api/external/discount/{pto_id}/groups"
        response = self._send_request(requests.get, request_path, timeout)
        return [Group(**discount_group) for discount_group in response.json()]

    def get_groups_for_token(self, pto_id, token, timeout=5):
        """Return the groups ``token`` belongs to for ``pto_id`` as a list of ``GroupExpiry``."""
        request_path = f"/api/external/discount/{pto_id}/token/{token}"
        response = self._send_request(requests.get, request_path, timeout)
        return [GroupExpiry(**group_expiry) for group_expiry in response.json()]

    def add_group_to_token(self, pto_id, group_id, token, expiry: datetime = None, timeout=5):
        """Add ``token`` to ``group_id`` and return the response body as text.

        When ``expiry`` is given it is sent as ``expiresAt`` in UTC (see ``_format_expiry``).
        """
        request_path = f"/api/external/discount/{pto_id}/token/{token}/add"

        request_body = {"group": group_id}
        if expiry:
            request_body["expiresAt"] = self._format_expiry(expiry)

        return self._send_request(requests.post, request_path, timeout, json=request_body).text

    def remove_group_from_token(self, pto_id, group_id, token, timeout=5):
        """Remove ``token`` from ``group_id`` and return the response body as text."""
        request_path = f"/api/external/discount/{pto_id}/token/{token}/remove"

        request_body = {"group": group_id}

        return self._send_request(requests.post, request_path, timeout, json=request_body).text