Coverage for benefits / core / models / transit.py: 99%
150 statements
« prev ^ index » next coverage.py v7.13.4, created at 2026-02-13 19:35 +0000
1import logging
2import os
4from django.contrib.auth.models import Group, User
5from django.core.exceptions import NON_FIELD_ERRORS, ValidationError
6from django.db import models
7from django.urls import reverse
8from multiselectfield import MultiSelectField
10from benefits.routes import routes
12from .common import Environment
14logger = logging.getLogger(__name__)
class CardSchemes:
    """Identifiers and display names for the contactless card schemes an agency can support."""

    VISA = "visa"
    MASTERCARD = "mastercard"
    DISCOVER = "discover"
    AMEX = "amex"

    # stored value -> human-readable label; insertion order is the display order
    CHOICES = {
        VISA: "Visa",
        MASTERCARD: "Mastercard",
        DISCOVER: "Discover",
        AMEX: "American Express",
    }
def agency_logo(instance, filename):
    """Build the storage path for an agency's uploaded logo.

    The stored name depends only on the agency's slug and the uploaded file's
    extension; the base name of the uploaded file is discarded.

    Args:
        instance: the object the file is attached to; only its ``slug`` is read.
        filename: the original name of the uploaded file.

    Returns:
        ``"agencies/<slug><ext>"``, where ``<ext>`` is the extension from
        ``os.path.splitext`` (it includes the leading dot, or is empty).
    """
    _, ext = os.path.splitext(filename)
    return f"agencies/{instance.slug}{ext}"
class TransitProcessorConfig(models.Model):
    """Transit processor settings for one environment, optionally linked to a single TransitAgency."""

    id = models.AutoField(primary_key=True)
    environment = models.TextField(
        choices=Environment,
        help_text="A label to indicate which environment this configuration is for.",
    )
    transit_agency = models.OneToOneField(
        "TransitAgency",
        on_delete=models.PROTECT,
        null=True,
        blank=True,
        default=None,
        help_text="The transit agency that uses this configuration.",
    )
    portal_url = models.TextField(
        default="",
        blank=True,
        help_text="The absolute base URL for the TransitProcessor's control portal, including https://.",
    )

    def __str__(self):
        # Placeholders keep incomplete configurations readable instead of failing.
        if self.environment:
            env_label = Environment(self.environment).label
        else:
            env_label = "unknown"

        agency = self.transit_agency
        slug = agency.slug if agency else "(no agency)"
        return "({}) {}".format(env_label, slug)
class AgencySlug(models.TextChoices):
    """Slugs of the supported transit agencies; each member's stored value and label are the same string."""

    # (raw value, display value)
    CST = ("cst", "cst")
    MST = ("mst", "mst")
    EDCTA = ("edcta", "edcta")
    NEVCO = ("nevco", "nevco")
    RABA = ("raba", "raba")
    ROSEVILLE = ("roseville", "roseville")
    SACRT = ("sacrt", "sacrt")
    SBMTD = ("sbmtd", "sbmtd")
    SLORTA = ("slorta", "slorta")
    VCTC = ("vctc", "vctc")
class TransitAgency(models.Model):
    """An agency offering transit service."""

    class Meta:
        verbose_name_plural = "transit agencies"

    id = models.AutoField(primary_key=True)
    active = models.BooleanField(default=False, help_text="Determines if this Agency is enabled for users")
    slug = models.SlugField(
        choices=AgencySlug,
        help_text="Used for URL navigation for this agency, e.g. the agency homepage url is /{slug}",
    )
    short_name = models.TextField(
        default="", help_text="The user-facing short name for this agency. Often an uppercase acronym."
    )
    long_name = models.TextField(
        default="",
        blank=True,
        help_text="The user-facing long name for this agency. Often the short_name acronym, spelled out.",
    )
    info_url = models.URLField(
        default="",
        blank=True,
        help_text="URL of a website/page with more information about the agency's discounts",
    )
    phone = models.TextField(default="", blank=True, help_text="Agency customer support phone number")
    # NOTE(review): Django hands out a non-callable default as-is, so new instances may share this
    # one list object. A callable default would avoid that, but changing it requires a migration.
    supported_card_schemes = MultiSelectField(
        choices=CardSchemes.CHOICES,
        min_choices=1,
        max_choices=len(CardSchemes.CHOICES),
        default=[CardSchemes.VISA, CardSchemes.MASTERCARD],
        help_text="The contactless card schemes this agency supports.",
    )
    sso_domain = models.TextField(
        blank=True,
        default="",
        help_text="The email domain of users to automatically add to this agency's staff group upon login.",
    )
    # related_name lets a Group reach its agency as `group.transit_agency` (used by for_user).
    customer_service_group = models.OneToOneField(
        Group,
        on_delete=models.PROTECT,
        null=True,
        blank=True,
        default=None,
        help_text="The group of users who are allowed to do in-person eligibility verification and enrollment.",
        related_name="transit_agency",
    )
    logo = models.ImageField(
        default="",
        blank=True,
        upload_to=agency_logo,
        help_text="The transit agency's logo.",
    )

    def __str__(self):
        return self.long_name

    @property
    def index_url(self):
        """Public-facing URL to the TransitAgency's landing page."""
        return reverse(routes.AGENCY_INDEX, args=[self.slug])

    @property
    def eligibility_index_url(self):
        """Public facing URL to the TransitAgency's eligibility page."""
        return reverse(routes.AGENCY_ELIGIBILITY_INDEX, args=[self.slug])

    @property
    def littlepay_config(self):
        """This agency's Littlepay configuration, or None when there is none.

        Reached through the reverse one-to-one `transitprocessorconfig` accessor
        (from TransitProcessorConfig.transit_agency). `littlepayconfig` is presumably
        a child model of TransitProcessorConfig defined elsewhere — TODO confirm.
        """
        if hasattr(self, "transitprocessorconfig") and hasattr(self.transitprocessorconfig, "littlepayconfig"):
            return self.transitprocessorconfig.littlepayconfig
        else:
            return None

    @property
    def switchio_config(self):
        """This agency's Switchio configuration, or None when there is none.

        Same lookup path as `littlepay_config`, via the `switchioconfig` accessor
        (presumably a TransitProcessorConfig child model — TODO confirm).
        """
        if hasattr(self, "transitprocessorconfig") and hasattr(self.transitprocessorconfig, "switchioconfig"):
            return self.transitprocessorconfig.switchioconfig
        else:
            return None

    @property
    def transit_processor(self):
        """Name of the configured processor: "littlepay", "switchio", or None (Littlepay is checked first)."""
        if self.littlepay_config:
            return "littlepay"
        elif self.switchio_config:
            return "switchio"
        else:
            return None

    @property
    def in_person_enrollment_index_route(self):
        """This Agency's in-person enrollment index route, based on its configured transit processor.

        Raises:
            ValueError: if neither a Littlepay nor a Switchio configuration is present.
        """
        if self.littlepay_config:
            return routes.IN_PERSON_ENROLLMENT_LITTLEPAY_INDEX
        elif self.switchio_config:
            return routes.IN_PERSON_ENROLLMENT_SWITCHIO_INDEX
        else:
            raise ValueError(
                (
                    "TransitAgency must have either a LittlepayConfig or SwitchioConfig "
                    "in order to show in-person enrollment index."
                )
            )

    @property
    def enrollment_index_route(self):
        """This Agency's enrollment index route, based on its configured transit processor.

        Raises:
            ValueError: if neither a Littlepay nor a Switchio configuration is present.
        """
        if self.littlepay_config:
            return routes.ENROLLMENT_LITTLEPAY_INDEX
        elif self.switchio_config:
            return routes.ENROLLMENT_SWITCHIO_INDEX
        else:
            raise ValueError(
                "TransitAgency must have either a LittlepayConfig or SwitchioConfig in order to show enrollment index."
            )

    @property
    def enrollment_flows(self):
        """Related manager for this agency's enrollment flows (reverse accessor `enrollmentflow_set`)."""
        return self.enrollmentflow_set

    @property
    def customer_service_group_name(self):
        """Returns the standardized name for this Agency's customer service group."""
        return f"{self.short_name} Customer Service"

    def clean(self):
        """Validate the agency as a whole, beyond per-field validation.

        Active agencies must have a long name, phone, info URL and logo, plus a
        Littlepay or Switchio configuration whose own `clean()` passes. For a saved
        agency, changing the short name is rejected while no customer service
        group is set.

        Raises:
            ValidationError: keyed by field name for field problems, and by
                NON_FIELD_ERRORS with every configuration problem found.
        """
        field_errors = {}
        non_field_errors = []

        if self.active:
            message = "This field is required for active transit agencies."
            needed = dict(
                long_name=self.long_name,
                phone=self.phone,
                info_url=self.info_url,
                logo=self.logo,
            )
            field_errors.update({k: ValidationError(message) for k, v in needed.items() if not v})

            if self.littlepay_config is None and self.switchio_config is None:
                non_field_errors.append(ValidationError("Must fill out configuration for either Littlepay or Switchio."))
            else:
                for label, config in (("Littlepay", self.littlepay_config), ("Switchio", self.switchio_config)):
                    error = self._processor_config_error(label, config)
                    if error is not None:
                        non_field_errors.append(error)

        # prohibit updating short_name with blank customer_service_group
        if self.pk:
            # filter().first() instead of get(): a pk with no saved row has no original value to compare
            original_obj = TransitAgency.objects.filter(pk=self.pk).first()
            if (
                original_obj is not None
                and self.short_name != original_obj.short_name
                and not self.customer_service_group
            ):
                field_errors["customer_service_group"] = ValidationError(
                    "Blank not allowed. Set to its original value if changing the Short Name."
                )

        all_errors = dict(field_errors)
        if non_field_errors:
            # Pass the whole list: a dict keyed on NON_FIELD_ERRORS can hold only one value per key,
            # so building it from a comprehension silently kept just the last error.
            all_errors[NON_FIELD_ERRORS] = non_field_errors
        if all_errors:
            raise ValidationError(all_errors)

    @staticmethod
    def _processor_config_error(label, config):
        """Run `config.clean()`; return one summarizing ValidationError on failure, else None (also None if unset)."""
        if not config:
            return None
        try:
            config.clean()
        except ValidationError as e:
            message = f"{label} configuration is missing fields that are required when this agency is active."
            if hasattr(e, "error_dict"):
                message += f" Missing fields: {', '.join(e.error_dict.keys())}"
            else:
                # a ValidationError raised without a field mapping has no error_dict; report its messages instead
                message += f" Errors: {'; '.join(e.messages)}"
            return ValidationError(message)
        return None

    @staticmethod
    def by_id(id):
        """Get a TransitAgency instance by its ID (raises DoesNotExist if there is none)."""
        logger.debug("Get %s by id: %s", TransitAgency.__name__, id)
        return TransitAgency.objects.get(id=id)

    @staticmethod
    def by_slug(slug):
        """Get a TransitAgency instance by its slug, or None if no agency matches."""
        logger.debug("Get %s by slug: %s", TransitAgency.__name__, slug)
        return TransitAgency.objects.filter(slug=slug).first()

    @staticmethod
    def all_active():
        """Get all TransitAgency instances marked active, ordered by long_name."""
        logger.debug("Get all active %s", TransitAgency.__name__)
        return TransitAgency.objects.filter(active=True).order_by("long_name")

    @staticmethod
    def for_user(user: User) -> "TransitAgency | None":
        """Return the agency whose customer_service_group is one of the user's groups, or None.

        Groups are checked in the order `user.groups.all()` yields them; the first match wins.
        """
        for group in user.groups.all():
            # `transit_agency` is the reverse of customer_service_group; it is absent for unlinked groups
            if hasattr(group, "transit_agency"):
                return group.transit_agency

        return None