Coverage for benefits/enrollment_switchio/api.py: 99%

132 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-10-10 16:52 +0000

1from dataclasses import dataclass 

2from datetime import datetime, timezone 

3from enum import Enum 

4import hashlib 

5import hmac 

6import json 

7from tempfile import NamedTemporaryFile 

8import requests 

9 

10 

@dataclass
class Registration:
    """Details of a newly requested registration.

    Built by ``TokenizationClient.request_registration`` directly from the JSON body returned by the
    registration endpoint, so the field names mirror that payload's keys.
    """

    # registration identifier as returned by the API
    # (presumably the id later passed to get_registration_status — confirm against callers)
    regId: str
    # URL returned by the API under the "gtwUrl" key
    gtwUrl: str

15 

16 

class RegistrationMode(Enum):
    """Modes accepted by ``TokenizationClient.request_registration``.

    The member's ``value`` is what gets sent as the ``mode`` field of the request body.
    """

    REGISTER = "register"
    IDENTIFY = "identify"

20 

21 

class EshopResponseMode(Enum):
    """Values accepted for the ``eshopResponseMode`` field of a registration request.

    ``TokenizationClient.request_registration`` sends the member's ``value`` string.
    """

    FRAGMENT = "fragment"
    QUERY = "query"
    FORM_POST = "form_post"
    POST_MESSAGE = "post_message"

27 

28 

29@dataclass 

30class RegistrationStatus: 

31 regState: str 

32 created: datetime 

33 mode: str 

34 tokens: list[dict] 

35 eshopResponseMode: str 

36 identType: str = None 

37 maskCln: str = None 

38 cardExp: str = None 

39 

40 

class Client:
    """Base class for API clients that authenticate with a client certificate (mutual TLS).

    Certificate material is held in memory as strings. For each request, ``_cert_request``
    writes it to temporary files only for the duration of that request.
    """

    def __init__(
        self,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        """
        * private_key: contents of the client's private key
        * client_certificate: contents of the client certificate
        * ca_certificate: contents of the CA certificate, passed to `requests` as ``verify``
        """
        self.private_key = private_key
        self.client_certificate = client_certificate
        self.ca_certificate = ca_certificate

    # see https://github.com/cal-itp/benefits/issues/2848 for more context about this
    def _cert_request(self, request_func):
        """
        Creates named (on-disk) temp files for client cert auth and calls ``request_func`` with their paths.

        * request_func: callable accepting ``verify`` and ``cert`` keyword arguments, typically a lambda
          wrapping a `requests` call (e.g. `requests.get`). ``verify`` is the CA file path and ``cert`` is
          a ``(certificate path, key path)`` tuple.

        Returns whatever ``request_func`` returns. The "with" context deletes the temp files as soon as
        ``request_func`` returns or raises.
        """
        # requests reads certificate material from file paths, not from in-memory strings
        with NamedTemporaryFile("w+") as cert, NamedTemporaryFile("w+") as key, NamedTemporaryFile("w+") as ca:
            for temp_file, contents in (
                (cert, self.client_certificate),
                (key, self.private_key),
                (ca, self.ca_certificate),
            ):
                temp_file.write(contents)
                # requests opens the files by name, so the buffered data must be on disk first;
                # flush explicitly instead of relying on seek() flushing as a side effect
                temp_file.flush()

            # request using temp file paths
            return request_func(verify=ca.name, cert=(cert.name, key.name))

74 

75 

class TokenizationClient(Client):
    """Client for the Switchio tokenization (registration) API.

    Requests use client certificate auth (via ``_cert_request``). Each request is also signed with an
    HMAC-SHA256 signature sent in the ``STP-*`` headers.
    """

    def __init__(
        self,
        api_url,
        api_key,
        api_secret,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        super().__init__(private_key, client_certificate, ca_certificate)
        # drop surrounding slashes so request paths (which start with "/") can be appended directly
        self.api_url = api_url.strip("/")
        self.api_key = api_key
        self.api_secret = api_secret

    def _signature_input_string(self, timestamp: str, method: str, request_path: str, body: str | None = None) -> str:
        """Builds the string to sign: timestamp, method, path and body concatenated with no separators."""
        if body is None:
            body = ""

        return f"{timestamp}{method}{request_path}{body}"

    def _stp_signature(self, timestamp: str, method: str, request_path: str, body: str | None = None) -> str:
        """Returns the hex HMAC-SHA256 of the signature input string, keyed with the API secret."""
        input_string = self._signature_input_string(timestamp, method, request_path, body)

        # must encode inputs for hashing, according to https://stackoverflow.com/a/66958131
        byte_key = self.api_secret.encode("utf-8")
        message = input_string.encode("utf-8")

        return hmac.new(byte_key, message, hashlib.sha256).hexdigest()

    def _get_headers(self, method: str, request_path: str, request_body: dict | None = None) -> dict:
        """Builds the STP authentication headers for a single request.

        * method: HTTP method, e.g. "GET" or "POST"
        * request_path: path portion of the URL, starting with "/"
        * request_body: the dict sent as the JSON body, or None for requests without a body
        """
        # Unix epoch seconds; an aware UTC "now" avoids a round trip through local time
        timestamp = str(int(datetime.now(timezone.utc).timestamp()))

        # Sign the body whenever one is sent, including an empty dict: `requests` still sends "{}" for
        # json={}, so a truthiness check would sign "" while "{}" goes over the wire.
        # NOTE(review): assumes json.dumps() output matches the body `requests` serializes for json=...;
        # both use the default separators today — confirm if the serializer ever changes.
        body = json.dumps(request_body) if request_body is not None else None

        return {
            "STP-APIKEY": self.api_key,
            "STP-TIMESTAMP": timestamp,
            "STP-SIGNATURE": self._stp_signature(
                timestamp=timestamp,
                method=method,
                request_path=request_path,
                body=body,
            ),
        }

    def request_registration(
        self,
        eshopRedirectUrl: str,
        mode: RegistrationMode,
        eshopResponseMode: EshopResponseMode,
        timeout=5,
    ) -> Registration:
        """Requests a new registration and returns the API's response.

        * eshopRedirectUrl: sent as ``eshopRedirectUrl`` in the request body
        * mode: sent as its string value in ``mode``
        * eshopResponseMode: sent as its string value in ``eshopResponseMode``
        * timeout: passed to `requests` as the request timeout

        Raises ``requests.HTTPError`` (via ``raise_for_status``) for an unsuccessful HTTP status.
        """
        registration_path = "/api/v1/registration"
        request_body = {
            "eshopRedirectUrl": eshopRedirectUrl,
            "mode": mode.value,
            "eshopResponseMode": eshopResponseMode.value,
        }

        response = self._cert_request(
            lambda verify, cert: requests.post(
                self.api_url + registration_path,
                json=request_body,
                # the signature covers the body, so sign the same dict that is sent
                headers=self._get_headers(method="POST", request_path=registration_path, request_body=request_body),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return Registration(**response.json())

    def get_registration_status(self, registration_id, timeout=5) -> RegistrationStatus:
        """Fetches the status of the registration identified by ``registration_id``.

        * timeout: passed to `requests` as the request timeout

        Raises ``requests.HTTPError`` (via ``raise_for_status``) for an unsuccessful HTTP status.
        """
        request_path = f"/api/v1/registration/{registration_id}"

        response = self._cert_request(
            lambda verify, cert: requests.get(
                self.api_url + request_path,
                headers=self._get_headers(method="GET", request_path=request_path),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return RegistrationStatus(**response.json())

167 

168 

@dataclass
class Group:
    """A discount group, as listed by the enrollment API.

    ``EnrollmentClient.get_groups`` builds one per JSON object in the response, so the field names
    mirror the API's keys.
    """

    # field names (including ``id``) must match the API payload keys, so they are not renamed
    id: int
    operatorId: int
    name: str
    code: str
    value: int

176 

177 

178@dataclass 

179class GroupExpiry: 

180 group: str 

181 expiresAt: datetime | None = None 

182 

183 def __post_init__(self): 

184 """Parses any date parameters into aware Python datetime objects. 

185 

186 For @dataclasses with a generated __init__ function, this function is called automatically. 

187 

188 https://docs.python.org/3.12/library/datetime.html#datetime.datetime.fromisoformat 

189 """ 

190 if self.expiresAt: 

191 # per the spec, we expect expiresAt to be in ISO format UTC 

192 # make the resulting datetime "aware" by replacing tzinfo explicitly 

193 self.expiresAt = datetime.fromisoformat(self.expiresAt).replace(tzinfo=timezone.utc) 

194 else: 

195 self.expiresAt = None 

196 

197 

class EnrollmentClient(Client):
    """Client for the Switchio enrollment (discount group) API.

    Requests use client certificate auth (via ``_cert_request``) plus a fixed ``Authorization`` header.
    Every public method raises ``requests.HTTPError`` (via ``raise_for_status``) for an unsuccessful
    HTTP status.
    """

    def __init__(self, api_url, authorization_header_value, private_key, client_certificate, ca_certificate):
        super().__init__(private_key, client_certificate, ca_certificate)
        # drop surrounding slashes so request paths (which start with "/") can be appended directly
        self.api_url = api_url.strip("/")
        self.authorization_header_value = authorization_header_value

    def _format_expiry(self, expiry: datetime) -> str:
        """Formats an expiry datetime as a UTC string like ``2024-01-01T00:00:00Z`` for an API request body.

        A naive datetime is assumed to already be in UTC; an aware one is converted to UTC.
        Raises TypeError if ``expiry`` is not a datetime.
        """
        if not isinstance(expiry, datetime):
            raise TypeError("expiry must be a Python datetime instance")
        # determine if expiry is an "aware" or "naive" datetime instance
        # https://docs.python.org/3/library/datetime.html#determining-if-an-object-is-aware-or-naive
        if expiry.tzinfo is not None and expiry.tzinfo.utcoffset(expiry) is not None:
            # aware: it has time zone information, so express it in UTC
            expiry = expiry.astimezone(timezone.utc)
        else:
            # naive: no time zone information, so assume it was provided in UTC
            expiry = expiry.replace(tzinfo=timezone.utc)
        # isoformat() of an aware UTC datetime ends with "+00:00";
        # keep everything but those last 6 characters and add the Z offset character instead
        return f"{expiry.isoformat(timespec='seconds')[:-6]}Z"

    def _get_headers(self):
        return {"Authorization": self.authorization_header_value}

    def _send_get(self, request_path: str, timeout):
        """Sends an authenticated GET for ``request_path`` and returns the response, raising on HTTP errors."""
        response = self._cert_request(
            lambda verify, cert: requests.get(
                self.api_url + request_path,
                headers=self._get_headers(),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return response

    def _send_post(self, request_path: str, request_body: dict, timeout):
        """Sends an authenticated POST with a JSON body and returns the response, raising on HTTP errors."""
        response = self._cert_request(
            lambda verify, cert: requests.post(
                self.api_url + request_path,
                json=request_body,
                headers=self._get_headers(),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return response

    def healthcheck(self, timeout=5):
        """Calls the echo endpoint and returns the response body as text."""
        # api_url is already slash-stripped in __init__, like every other endpoint here
        return self._send_get("/api/external/discount/echo", timeout).text

    def get_groups(self, pto_id, timeout=5):
        """Returns the discount groups of transit provider ``pto_id`` as ``Group`` instances."""
        response = self._send_get(f"/api/external/discount/{pto_id}/groups", timeout)

        return [Group(**discount_group) for discount_group in response.json()]

    def get_groups_for_token(self, pto_id, token, timeout=5):
        """Returns the groups currently linked to ``token`` as ``GroupExpiry`` instances."""
        response = self._send_get(f"/api/external/discount/{pto_id}/token/{token}", timeout)

        return [GroupExpiry(**group_expiry) for group_expiry in response.json()]

    def add_group_to_token(self, pto_id, group_id, token, expiry: datetime | None = None, timeout=5):
        """Links ``group_id`` to ``token`` and returns the response body as text.

        * expiry: optional datetime, sent as ``expiresAt`` after formatting with ``_format_expiry``
        """
        request_body = {"group": group_id}
        if expiry is not None:
            request_body["expiresAt"] = self._format_expiry(expiry)

        return self._send_post(f"/api/external/discount/{pto_id}/token/{token}/add", request_body, timeout).text

    def remove_group_from_token(self, pto_id, group_id, token, timeout=5):
        """Unlinks ``group_id`` from ``token`` and returns the response body as text."""
        request_body = {"group": group_id}

        return self._send_post(f"/api/external/discount/{pto_id}/token/{token}/remove", request_body, timeout).text