From a5324b6ea79c86df8e572662f6b625b553d16172 Mon Sep 17 00:00:00 2001 From: Rishi Gupta Date: Tue, 29 Jan 2019 07:01:31 -0800 Subject: [PATCH] billing: Add a test for a race condition in process_initial_upgrade. --- corporate/lib/stripe.py | 3 +-- corporate/tests/test_stripe.py | 9 +++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 7474269ca3..d46589dfc4 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -258,8 +258,7 @@ def process_initial_upgrade(user: UserProfile, licenses: int, automanage_license billing_schedule: int, stripe_token: Optional[str]) -> None: realm = user.realm customer = update_or_create_stripe_customer(user, stripe_token=stripe_token) - # TODO write a test for this - if CustomerPlan.objects.filter(customer=customer, status=CustomerPlan.ACTIVE).exists(): # nocoverage + if CustomerPlan.objects.filter(customer=customer, status=CustomerPlan.ACTIVE).exists(): # Unlikely race condition from two people upgrading (clicking "Make payment") # at exactly the same time. Doesn't fully resolve the race condition, but having # a check here reduces the likelihood. diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index cdc67db9d4..9f7b01c53d 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -628,6 +628,15 @@ class StripeTest(StripeTestCase): self.assert_json_error_contains(response, "Something went wrong. Please contact") self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered seat count') + def test_upgrade_race_condition(self) -> None: + self.login(self.example_email("hamlet")) + self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, 'token') + with patch("corporate.lib.stripe.billing_logger.warning") as mock_billing_logger: + with self.assertRaises(BillingError) as context: + self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL, 'token') + self.assertEqual('subscribing with existing subscription', context.exception.description) + mock_billing_logger.assert_called() + def test_check_upgrade_parameters(self) -> None: # Tests all the error paths except 'not enough licenses' def check_error(error_description: str, upgrade_params: Dict[str, Any],