Coverage for benefits/enrollment_switchio/api.py: 100%

119 statements  

« prev     ^ index     » next       coverage.py v7.10.2, created at 2025-08-08 16:26 +0000

1from dataclasses import dataclass 

2from datetime import datetime 

3from enum import Enum 

4import hashlib 

5import hmac 

6import json 

7from tempfile import NamedTemporaryFile 

8import requests 

9 

10 

@dataclass
class Registration:
    """Response to a registration request made through ``TokenizationClient``.

    The attribute names match the keys of the API's JSON reply, which lets the
    client build an instance directly via ``Registration(**response.json())``.
    """

    # registration identifier returned by the API
    regId: str
    # gateway URL returned by the API
    gtwUrl: str

15 

16 

class RegistrationMode(Enum):
    """Allowed values for the ``mode`` field of a registration request.

    ``TokenizationClient.request_registration`` sends ``member.value`` on the wire.
    """

    REGISTER = "register"
    IDENTIFY = "identify"

20 

21 

class EshopResponseMode(Enum):
    """Allowed values for the ``eshopResponseMode`` field of a registration request.

    ``TokenizationClient.request_registration`` sends ``member.value`` on the wire.
    """

    FRAGMENT = "fragment"
    QUERY = "query"
    FORM_POST = "form_post"
    POST_MESSAGE = "post_message"

27 

28 

@dataclass
class RegistrationStatus:
    """Status of a registration, as returned by the Switchio tokenization API.

    Attribute names match the keys of the JSON reply so an instance can be built
    with ``RegistrationStatus(**response.json())``. The last three fields may be
    omitted and then default to ``None``.
    """

    regState: str
    # NOTE: when built from ``response.json()`` this holds the raw JSON value (JSON has
    # no datetime type, so most likely a timestamp string); it is not parsed here.
    created: datetime
    mode: str
    tokens: list[dict]
    eshopResponseMode: str
    # Optional fields. The annotations are quoted so ``X | None`` is never evaluated at
    # class-creation time and therefore works on any supported Python version.
    identType: "str | None" = None
    maskCln: "str | None" = None
    cardExp: "str | None" = None

39 

40 

class Client:
    """Base class for Switchio API clients that authenticate with a client TLS certificate.

    The credentials are held as text (presumably PEM — confirm with callers) and are
    written to temporary files for each request, because ``requests`` takes
    certificate/key/CA *file paths*, not their contents.
    """

    def __init__(
        self,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        # Certificate/key contents (written verbatim to temp files), not file paths.
        self.private_key = private_key
        self.client_certificate = client_certificate
        self.ca_certificate = ca_certificate

    # see https://github.com/cal-itp/benefits/issues/2848 for more context about this
    def _cert_request(self, request_func):
        """
        Creates named (on-disk) temp files for client cert auth and runs a request with them.

        * request_func: callable taking keyword arguments ``verify`` (path of the CA file)
          and ``cert`` (a ``(cert_path, key_path)`` tuple), e.g. a lambda wrapping
          ``requests.get``.

        Returns whatever ``request_func`` returns. The temp files are deleted as soon as
        ``request_func`` returns or raises, so the request must complete inside it.
        """
        # The "with" context destroys the temp files when the response comes back.
        with NamedTemporaryFile("w+") as cert, NamedTemporaryFile("w+") as key, NamedTemporaryFile("w+") as ca:
            for tmp_file, contents in (
                (cert, self.client_certificate),
                (key, self.private_key),
                (ca, self.ca_certificate),
            ):
                tmp_file.write(contents)
                # requests re-opens these files by *name*, so the buffered text must be
                # flushed to disk before request_func runs.
                tmp_file.flush()

            # request using temp file paths
            return request_func(verify=ca.name, cert=(cert.name, key.name))

74 

75 

class TokenizationClient(Client):
    """Client for the Switchio tokenization (card registration) API.

    Each request sends the client TLS certificate (via ``Client._cert_request``) and
    ``STP-APIKEY`` / ``STP-TIMESTAMP`` / ``STP-SIGNATURE`` headers. The signature is an
    HMAC-SHA256, keyed with ``api_secret``, over timestamp + method + path + body.
    """

    def __init__(
        self,
        api_url,
        api_key,
        api_secret,
        private_key,
        client_certificate,
        ca_certificate,
    ):
        super().__init__(private_key, client_certificate, ca_certificate)
        # Slashes are stripped so request paths (which start with "/") can be appended directly.
        self.api_url = api_url.strip("/")
        self.api_key = api_key
        self.api_secret = api_secret

    def _signature_input_string(self, timestamp: str, method: str, request_path: str, body: str = None):
        """Concatenate the signed fields; a missing body contributes an empty string."""
        if body is None:
            body = ""

        return f"{timestamp}{method}{request_path}{body}"

    def _stp_signature(self, timestamp: str, method: str, request_path: str, body: str = None) -> str:
        """Return the hex-encoded HMAC-SHA256 signature for one request."""
        input_string = self._signature_input_string(timestamp, method, request_path, body)

        # must encode inputs for hashing, according to https://stackoverflow.com/a/66958131
        byte_key = self.api_secret.encode("utf-8")
        message = input_string.encode("utf-8")

        return hmac.new(byte_key, message, hashlib.sha256).hexdigest()

    def _get_headers(self, method, request_path, request_body: dict = None):
        """Build the STP authentication headers for one request.

        * method: HTTP method name as sent (e.g. "POST").
        * request_path: path appended to ``api_url`` (no scheme/host).
        * request_body: the object passed to ``requests`` as ``json=``, or None for no body.
        """
        # Unix epoch seconds, as a string.
        timestamp = str(int(datetime.now().timestamp()))

        # Sign the body as ``requests`` will send ``json=request_body``. Only a missing
        # body (None) signs as "". An empty dict is still sent as "{}", so the check is
        # ``is not None`` rather than truthiness.
        # NOTE(review): assumes json.dumps defaults match requests' own JSON serialization.
        body = json.dumps(request_body) if request_body is not None else None

        return {
            "STP-APIKEY": self.api_key,
            "STP-TIMESTAMP": timestamp,
            "STP-SIGNATURE": self._stp_signature(
                timestamp=timestamp,
                method=method,
                request_path=request_path,
                body=body,
            ),
        }

    def request_registration(
        self,
        eshopRedirectUrl: str,
        mode: RegistrationMode,
        eshopResponseMode: EshopResponseMode,
        timeout=5,
    ) -> Registration:
        """POST a new registration and return the parsed ``Registration``.

        Raises ``requests.HTTPError`` (via ``raise_for_status``) on an error status, and
        ``TypeError`` if the JSON reply has keys that ``Registration`` does not declare.
        """
        registration_path = "/api/v1/registration"
        request_body = {
            "eshopRedirectUrl": eshopRedirectUrl,
            "mode": mode.value,
            "eshopResponseMode": eshopResponseMode.value,
        }

        response = self._cert_request(
            lambda verify, cert: requests.post(
                self.api_url + registration_path,
                json=request_body,
                headers=self._get_headers(method="POST", request_path=registration_path, request_body=request_body),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return Registration(**response.json())

    def get_registration_status(self, registration_id, timeout=5) -> RegistrationStatus:
        """GET the status of ``registration_id`` and return the parsed ``RegistrationStatus``.

        Raises ``requests.HTTPError`` (via ``raise_for_status``) on an error status.
        """
        request_path = f"/api/v1/registration/{registration_id}"

        response = self._cert_request(
            lambda verify, cert: requests.get(
                self.api_url + request_path,
                headers=self._get_headers(method="GET", request_path=request_path),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )

        response.raise_for_status()

        return RegistrationStatus(**response.json())

167 

168 

@dataclass
class Group:
    """One discount group, as listed by ``EnrollmentClient.get_groups``.

    Attribute names match the keys of each JSON object in the API reply, so an
    instance is built with ``Group(**item)``.
    """

    # group identifier
    id: int
    # identifier of the operator the group belongs to
    operatorId: int
    name: str
    code: str
    # NOTE(review): meaning of ``value`` is not visible from this module — confirm with API docs
    value: int

176 

177 

@dataclass
class GroupExpiry:
    """A group attached to a token, as listed by ``EnrollmentClient.get_groups_for_token``.

    Attribute names match the keys of each JSON object in the API reply.
    """

    group: str
    # NOTE: when built from ``response.json()`` this holds the raw JSON value (JSON has
    # no datetime type, so most likely a timestamp string); it is not parsed here.
    expiresAt: datetime

182 

183 

class EnrollmentClient(Client):
    """Client for the Switchio enrollment (discount group) API.

    Each request sends the client TLS certificate (via ``Client._cert_request``) and a
    fixed ``Authorization`` header value. Every public method raises
    ``requests.HTTPError`` (via ``raise_for_status``) on an error status.

    NOTE(review): ``pto_id`` and ``token`` are interpolated into URL paths unescaped —
    confirm they can never contain "/" or other reserved characters.
    """

    def __init__(self, api_url, authorization_header_value, private_key, client_certificate, ca_certificate):
        super().__init__(private_key, client_certificate, ca_certificate)
        # Slashes are stripped once here so request paths (which start with "/") can be appended directly.
        self.api_url = api_url.strip("/")
        self.authorization_header_value = authorization_header_value

    def _get_headers(self):
        return {"Authorization": self.authorization_header_value}

    def _get(self, request_path, timeout):
        """GET ``api_url + request_path`` with cert auth; return the response after raise_for_status."""
        response = self._cert_request(
            lambda verify, cert: requests.get(
                self.api_url + request_path,
                headers=self._get_headers(),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )
        response.raise_for_status()
        return response

    def _post(self, request_path, request_body, timeout):
        """POST ``request_body`` as JSON to ``api_url + request_path``; return the response after raise_for_status."""
        response = self._cert_request(
            lambda verify, cert: requests.post(
                self.api_url + request_path,
                json=request_body,
                headers=self._get_headers(),
                cert=cert,
                verify=verify,
                timeout=timeout,
            )
        )
        response.raise_for_status()
        return response

    def healthcheck(self, timeout=5):
        """Call the echo endpoint and return the response body text."""
        return self._get("/api/external/discount/echo", timeout).text

    def get_groups(self, pto_id, timeout=5):
        """Return the discount groups for ``pto_id`` as a list of ``Group``."""
        response = self._get(f"/api/external/discount/{pto_id}/groups", timeout)
        return [Group(**discount_group) for discount_group in response.json()]

    def get_groups_for_token(self, pto_id, token, timeout=5):
        """Return the groups attached to ``token`` under ``pto_id`` as a list of ``GroupExpiry``."""
        response = self._get(f"/api/external/discount/{pto_id}/token/{token}", timeout)
        return [GroupExpiry(**group_expiry) for group_expiry in response.json()]

    def add_group_to_token(self, pto_id, group_id, token, timeout=5):
        """Add ``group_id`` to ``token`` under ``pto_id``; return the response body text."""
        request_path = f"/api/external/discount/{pto_id}/token/{token}/add"
        return self._post(request_path, {"group": group_id}, timeout).text

    def remove_group_from_token(self, pto_id, group_id, token, timeout=5):
        """Remove ``group_id`` from ``token`` under ``pto_id``; return the response body text."""
        request_path = f"/api/external/discount/{pto_id}/token/{token}/remove"
        return self._post(request_path, {"group": group_id}, timeout).text