Coverage for benefits / enrollment_switchio / api.py: 99%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.4, created at 2026-02-13 19:35 +0000

1import hashlib 

2import hmac 

3import json 

4from dataclasses import dataclass 

5from datetime import datetime, timezone 

6from enum import Enum 

7from tempfile import NamedTemporaryFile 

8 

9import requests 

10 

11 

@dataclass
class Registration:
    """Result of a successful card registration request to the tokenization API.

    Built by unpacking the JSON response body, so the field names mirror the
    response's keys exactly.
    """

    # registration identifier (later passed to get_registration_status)
    regId: str
    # presumably the gateway URL the user is sent to for card entry — confirm against the API spec
    gtwUrl: str

16 

17 

class RegistrationMode(Enum):
    """Allowed values for the ``mode`` field of a registration request."""

    REGISTER = "register"
    IDENTIFY = "identify"

21 

22 

class EshopResponseMode(Enum):
    """Allowed values for the ``eshopResponseMode`` field of a registration request."""

    FRAGMENT = "fragment"
    QUERY = "query"
    FORM_POST = "form_post"
    POST_MESSAGE = "post_message"

28 

29 

30@dataclass 

31class RegistrationStatus: 

32 regState: str 

33 created: datetime 

34 mode: str 

35 tokens: list[dict] 

36 eshopResponseMode: str 

37 identType: str = None 

38 maskCln: str = None 

39 cardExp: str = None 

40 

41 

class Client:
    """Base class for API clients that authenticate with a TLS client certificate.

    The PEM material is held in memory as strings and is only written to disk
    for the duration of a single request (see ``_cert_request``).
    """

    def __init__(
        self,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        self.private_key = private_key
        self.client_certificate = client_certificate
        self.ca_certificate = ca_certificate

    # see https://github.com/cal-itp/benefits/issues/2848 for more context about this
    def _cert_request(self, request_func):
        """
        Call ``request_func`` with client-certificate auth backed by named (on-disk) temp files.

        * request_func: callable accepting ``verify`` and ``cert`` keyword arguments, typically a
          lambda wrapping a ``requests`` function (e.g. ``requests.get``). ``verify`` receives the
          CA certificate file path; ``cert`` receives a ``(client_cert_path, private_key_path)`` tuple.

        Returns whatever ``request_func`` returns. The temp files are removed when the ``with``
        block exits, i.e. once ``request_func`` has returned.
        """
        # the requests library reads certificate material from file paths, not in-memory strings
        with NamedTemporaryFile("w+") as cert, NamedTemporaryFile("w+") as key, NamedTemporaryFile("w+") as ca:
            for handle, contents in (
                (cert, self.client_certificate),
                (key, self.private_key),
                (ca, self.ca_certificate),
            ):
                handle.write(contents)
                # rewind after writing; seeking a text file flushes its write buffer to disk,
                # so the file can be read back in full by requests
                handle.seek(0)

            return request_func(verify=ca.name, cert=(cert.name, key.name))

75 

76 

class TokenizationClient(Client):
    """Client for the tokenization (card registration) API.

    Every request carries ``STP-APIKEY``, ``STP-TIMESTAMP`` and an HMAC-SHA256
    ``STP-SIGNATURE`` header, and is sent with client-certificate auth via
    ``Client._cert_request``.
    """

    def __init__(
        self,
        api_url,
        api_key,
        api_secret,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        super().__init__(private_key, client_certificate, ca_certificate)
        # slashes are stripped so request paths (which begin with "/") can be appended directly
        self.api_url = api_url.strip("/")
        self.api_key = api_key
        self.api_secret = api_secret

    def _signature_input_string(self, timestamp: str, method: str, request_path: str, body: str | None = None):
        """Concatenate timestamp, method, path and body; a missing body contributes nothing."""
        signed_body = "" if body is None else body
        return f"{timestamp}{method}{request_path}{signed_body}"

    def _stp_signature(self, timestamp: str, method: str, request_path, body: str | None = None):
        """Return the hex HMAC-SHA256 of the signature input string, keyed by the API secret."""
        # hmac works on bytes, so the key and message are UTF-8 encoded first
        # (per https://stackoverflow.com/a/66958131)
        signing_input = self._signature_input_string(timestamp, method, request_path, body)
        mac = hmac.new(self.api_secret.encode("utf-8"), signing_input.encode("utf-8"), hashlib.sha256)
        return mac.hexdigest()

    def _get_headers(self, method, request_path, request_body: dict | None = None):
        """Build the STP authentication headers for one request.

        The timestamp is whole seconds since the epoch, as a string. A falsy
        ``request_body`` is signed as if there were no body.
        """
        now_seconds = str(int(datetime.now().timestamp()))
        # presumably this must match the JSON requests serializes from json=request_body — confirm
        signed_body = json.dumps(request_body) if request_body else None
        signature = self._stp_signature(
            timestamp=now_seconds,
            method=method,
            request_path=request_path,
            body=signed_body,
        )
        return {
            "STP-APIKEY": self.api_key,
            "STP-TIMESTAMP": now_seconds,
            "STP-SIGNATURE": signature,
        }

    def request_registration(
        self,
        eshopRedirectUrl: str,
        mode: RegistrationMode,
        eshopResponseMode: EshopResponseMode,
        timeout=5,
    ) -> Registration:
        """POST a new registration and return the parsed ``Registration``.

        Raises the ``requests`` HTTP error from ``raise_for_status`` on a 4xx/5xx response.
        """
        path = "/api/v1/registration"
        payload = {
            "eshopRedirectUrl": eshopRedirectUrl,
            "mode": mode.value,
            "eshopResponseMode": eshopResponseMode.value,
        }

        def send(verify, cert):
            # headers are built here so the signed timestamp is taken just before sending
            return requests.post(
                self.api_url + path,
                json=payload,
                headers=self._get_headers(method="POST", request_path=path, request_body=payload),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )

        response = self._cert_request(send)
        response.raise_for_status()
        return Registration(**response.json())

    def get_registration_status(self, registration_id, timeout=5) -> RegistrationStatus:
        """GET the status of ``registration_id`` and return the parsed ``RegistrationStatus``.

        Raises the ``requests`` HTTP error from ``raise_for_status`` on a 4xx/5xx response.
        """
        path = f"/api/v1/registration/{registration_id}"

        def send(verify, cert):
            return requests.get(
                self.api_url + path,
                headers=self._get_headers(method="GET", request_path=path),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )

        response = self._cert_request(send)
        response.raise_for_status()
        return RegistrationStatus(**response.json())

168 

169 

@dataclass
class Group:
    """A discount group, built by unpacking one item of the groups endpoint's JSON list.

    Field names mirror the response's keys exactly.
    """

    id: int
    operatorId: int
    name: str
    code: str
    value: int

177 

178 

179@dataclass 

180class GroupExpiry: 

181 group: str 

182 expiresAt: datetime | None = None 

183 

184 def __post_init__(self): 

185 """Parses any date parameters into aware Python datetime objects. 

186 

187 For @dataclasses with a generated __init__ function, this function is called automatically. 

188 

189 https://docs.python.org/3.12/library/datetime.html#datetime.datetime.fromisoformat 

190 """ 

191 if self.expiresAt: 

192 # per the spec, we expect expiresAt to be in ISO format UTC 

193 # make the resulting datetime "aware" by replacing tzinfo explicitly 

194 self.expiresAt = datetime.fromisoformat(self.expiresAt).replace(tzinfo=timezone.utc) 

195 else: 

196 self.expiresAt = None 

197 

198 

class EnrollmentClient(Client):
    """Client for the enrollment (discount group) API.

    Requests authenticate with a fixed ``Authorization`` header value and are sent
    with client-certificate auth via ``Client._cert_request``.
    """

    def __init__(self, api_url, authorization_header_value, private_key, client_certificate, ca_certificate):
        super().__init__(private_key, client_certificate, ca_certificate)
        # slashes are stripped once here so request paths (which begin with "/") can be appended directly
        self.api_url = api_url.strip("/")
        self.authorization_header_value = authorization_header_value

    def _format_expiry(self, expiry: datetime) -> str:
        """Formats an expiry datetime into a string suitable for using in an API request body.

        Returns a second-precision UTC timestamp like ``2024-03-01T12:00:00Z``.
        Naive datetimes are assumed to already be in UTC; aware ones are converted to UTC.

        Raises TypeError if ``expiry`` is not a ``datetime`` instance.
        """
        if not isinstance(expiry, datetime):
            raise TypeError("expiry must be a Python datetime instance")
        # determine if expiry is an "aware" or "naive" datetime instance
        # https://docs.python.org/3/library/datetime.html#determining-if-an-object-is-aware-or-naive
        if expiry.tzinfo is not None and expiry.tzinfo.utcoffset(expiry) is not None:
            # aware: express the same instant in UTC
            expiry = expiry.astimezone(timezone.utc)
        else:
            # naive: assume the wall-clock time was provided in UTC
            expiry = expiry.replace(tzinfo=timezone.utc)
        # isoformat() on an aware UTC datetime ends with "+00:00";
        # drop those 6 characters and use the "Z" designator instead
        return f"{expiry.isoformat(timespec='seconds')[:-6]}Z"

    def _get_headers(self):
        return {"Authorization": self.authorization_header_value}

    def _request(self, request_func, request_path, timeout, **kwargs):
        """Send one authenticated request to ``api_url + request_path`` and check its status.

        * request_func: a ``requests`` function such as ``requests.get`` or ``requests.post``.
        * kwargs: extra keyword arguments forwarded to ``request_func`` (e.g. ``json=``).

        Returns the response. Raises the ``requests`` HTTP error from ``raise_for_status``
        on a 4xx/5xx status.
        """
        response = self._cert_request(
            lambda verify, cert: request_func(
                self.api_url + request_path,
                headers=self._get_headers(),
                cert=cert,
                verify=verify,
                timeout=timeout,
                **kwargs,
            )
        )
        response.raise_for_status()
        return response

    def healthcheck(self, timeout=5):
        """Call the echo endpoint and return the response body text."""
        return self._request(requests.get, "/api/external/discount/echo", timeout).text

    def get_groups(self, pto_id, timeout=5):
        """Return the discount groups defined for ``pto_id`` as ``Group`` objects."""
        response = self._request(requests.get, f"/api/external/discount/{pto_id}/groups", timeout)
        return [Group(**discount_group) for discount_group in response.json()]

    def get_groups_for_token(self, pto_id, token, timeout=5):
        """Return the groups linked to ``token`` under ``pto_id`` as ``GroupExpiry`` objects."""
        response = self._request(requests.get, f"/api/external/discount/{pto_id}/token/{token}", timeout)
        return [GroupExpiry(**group_expiry) for group_expiry in response.json()]

    def add_group_to_token(self, pto_id, group_id, token, expiry: datetime | None = None, timeout=5):
        """Link ``group_id`` to ``token``, optionally with an expiry; return the response body text.

        ``expiry`` is sent as ``expiresAt`` in the format produced by ``_format_expiry``
        and is omitted from the body when None.
        """
        request_body = {"group": group_id}
        if expiry is not None:
            request_body["expiresAt"] = self._format_expiry(expiry)

        response = self._request(
            requests.post,
            f"/api/external/discount/{pto_id}/token/{token}/add",
            timeout,
            json=request_body,
        )
        return response.text

    def remove_group_from_token(self, pto_id, group_id, token, timeout=5):
        """Unlink ``group_id`` from ``token`` and return the response body text."""
        response = self._request(
            requests.post,
            f"/api/external/discount/{pto_id}/token/{token}/remove",
            timeout,
            json={"group": group_id},
        )
        return response.text