Coverage for eligibility_api/tokens.py: 22%

69 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2024-08-08 23:46 +0000

1import datetime 

2import json 

3import logging 

4import uuid 

5 

6from jwcrypto import jwe, jws, jwt, jwk 

7from typing import Iterable 

8import requests 

9 

10logger = logging.getLogger(__name__) 

11 

12 

def _create_jwk(pem_data):
    """Build a ``jwk.JWK`` from PEM-encoded key material.

    ``pem_data`` may be passed as ``str`` or as bytes; a ``str`` is encoded to
    UTF-8 bytes before being handed to ``jwk.JWK.from_pem``.
    """
    pem_bytes = bytes(pem_data, "utf-8") if isinstance(pem_data, str) else pem_data
    return jwk.JWK.from_pem(pem_bytes)

18 

19 

class TokenError(Exception):
    """Raised when an API request or response token cannot be processed."""

24 

25 

class RequestToken:
    """Eligibility Verification API request token.

    The claims are signed with the client's private key, and the signed token
    is then encrypted with the server's public key. ``str()`` (and ``repr()``)
    of an instance return the serialized encrypted token.
    """

    def __init__(
        self,
        types: Iterable[str],
        agency: str,
        jws_signing_alg: str,
        client_private_key,
        jwe_encryption_alg: str,
        jwe_cek_enc: str,
        server_public_key,
        sub: str,
        name: str,
        issuer: str,
    ):
        """Create, sign and encrypt a new request token.

        Args:
            types: Eligibility types, sent as the ``eligibility`` claim.
            agency: Value of the ``agency`` claim.
            jws_signing_alg: Algorithm name placed in the signing header.
            client_private_key: PEM data (``str`` or bytes) used to sign.
            jwe_encryption_alg: Algorithm name placed in the encryption header.
            jwe_cek_enc: Value of the ``enc`` field of the encryption header.
            server_public_key: PEM data (``str`` or bytes) used to encrypt.
            sub: Value of the ``sub`` claim.
            name: Value of the ``name`` claim.
            issuer: Value of the ``iss`` claim.
        """
        logger.info("Initialize new request token")

        self.client_private_jwk = _create_jwk(client_private_key)
        self.server_public_jwk = _create_jwk(server_public_key)

        # `types` is annotated as any iterable, but the claims must be
        # JSON-serializable (jwcrypto JSON-encodes a claims dict), so
        # materialize sets, generators, etc. Values that already serialize
        # (str, list, tuple) are passed through unchanged.
        eligibility = types if isinstance(types, (str, list, tuple)) else list(types)

        # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
        issued_at = datetime.datetime.now(datetime.timezone.utc)

        # craft the main token payload
        payload = dict(
            jti=str(uuid.uuid4()),
            iss=issuer,
            iat=int(issued_at.timestamp()),
            agency=agency,
            eligibility=eligibility,
            sub=sub,
            name=name,
        )

        logger.debug("Sign token payload with agency's private key")
        header = {"typ": "JWS", "alg": jws_signing_alg}
        signed_token = jwt.JWT(header=header, claims=payload)
        signed_token.make_signed_token(self.client_private_jwk)
        signed_payload = signed_token.serialize()

        logger.debug("Encrypt signed token payload with verifier's public key")
        header = {
            "typ": "JWE",
            "alg": jwe_encryption_alg,
            "enc": jwe_cek_enc,
        }
        # The serialized signed token becomes the claims of the encrypted token.
        encrypted_token = jwt.JWT(header=header, claims=signed_payload)
        encrypted_token.make_encrypted_token(self.server_public_jwk)

        logger.info("Signed and encrypted request token initialized")
        self._jwe = encrypted_token

    def __repr__(self):
        """Return the same serialized token as ``__str__``."""
        return str(self)

    def __str__(self):
        """Return the serialized signed-and-encrypted token."""
        return self._jwe.serialize()

81 

82 

class ResponseToken:
    """Eligibility Verification API response token.

    The response body is decrypted with the client's private key and the
    resulting token's signature is verified with the server's public key.

    Attributes:
        eligibility: List built from the payload's ``eligibility`` entry
            (empty when the entry is missing or null).
        error: The payload's ``error`` entry, or ``None`` when absent.
    """

    def __init__(
        self,
        response: requests.models.Response,
        jwe_encryption_alg: str,
        jwe_cek_enc: str,
        client_private_key,
        jws_signing_alg: str,
        server_public_key,
    ):
        """Decrypt, verify and parse the token carried in ``response``.

        Args:
            response: HTTP response whose ``text`` holds the encrypted token.
            jwe_encryption_alg: Allowed JWE algorithm name.
            jwe_cek_enc: Allowed JWE ``enc`` value.
            client_private_key: PEM data (``str`` or bytes) used to decrypt.
            jws_signing_alg: Algorithm passed to the signature verification.
            server_public_key: PEM data (``str`` or bytes) used to verify.

        Raises:
            TokenError: If the response body is empty, the token cannot be
                decrypted or verified, or the payload is not a JSON object.
        """
        logger.info("Read encrypted token from response")

        self.client_private_jwk = _create_jwk(client_private_key)
        self.server_public_jwk = _create_jwk(server_public_key)

        # Strip surrounding whitespace and wrapping quote chars *before* the
        # emptiness check, so a body made only of quotes/whitespace is reported
        # as a bad response format instead of reaching the JWE parser.
        encrypted_signed_token = (response.text or "").strip("'\" \t\r\n")
        if not encrypted_signed_token:
            raise TokenError("Invalid response format")

        logger.debug("Decrypt response token using agency's private key")
        allowed_algs = [jwe_encryption_alg, jwe_cek_enc]
        decrypted_token = jwe.JWE(algs=allowed_algs)
        try:
            decrypted_token.deserialize(encrypted_signed_token, key=self.client_private_jwk)
        except jwe.InvalidJWEData as err:
            raise TokenError("Invalid JWE token") from err
        except jwe.InvalidJWEOperation as err:
            raise TokenError("JWE token decryption failed") from err

        decrypted_payload = str(decrypted_token.payload, "utf-8")

        logger.debug("Verify decrypted response token's signature using verifier's public key")
        signed_token = jws.JWS()
        try:
            signed_token.deserialize(decrypted_payload, key=self.server_public_jwk, alg=jws_signing_alg)
        except jws.InvalidJWSObject as err:
            raise TokenError("Invalid JWS token") from err
        except jws.InvalidJWSSignature as err:
            raise TokenError("JWS token signature verification failed") from err

        logger.info("Response token decrypted and signature verified")

        # UnicodeDecodeError and json.JSONDecodeError are both ValueError
        # subclasses; surface either as the module's own error type.
        try:
            payload = json.loads(str(signed_token.payload, "utf-8"))
        except ValueError as err:
            raise TokenError("Invalid response payload") from err
        if not isinstance(payload, dict):
            raise TokenError("Invalid response payload")

        # `or []` also covers an explicit null "eligibility" entry.
        self.eligibility = list(payload.get("eligibility") or [])
        self.error = payload.get("error", None)

134 self.error = payload.get("error", None)