diff --git a/zilencer/tests/test_stripe.py b/zilencer/tests/test_stripe.py index 5eba6ac56e..a59810a236 100644 --- a/zilencer/tests/test_stripe.py +++ b/zilencer/tests/test_stripe.py @@ -218,6 +218,18 @@ class StripeTest(ZulipTestCase): }) self.assert_in_success_response(["Something went wrong. Please contact"], result) + @mock.patch("zilencer.lib.stripe.STRIPE_PUBLISHABLE_KEY", "stripe_publishable_key") + @mock.patch("zilencer.views.STRIPE_PUBLISHABLE_KEY", "stripe_publishable_key") + def test_upgrade_with_tampered_plan(self) -> None: + self.login(self.user.email) + result = self.client_post("/upgrade/", { + 'stripeToken': self.token, + 'signed_seat_count': self.signed_seat_count, + 'salt': self.salt, + 'plan': "invalid" + }) + self.assert_in_success_response(["Something went wrong. Please contact"], result) + @mock.patch("zilencer.lib.stripe.STRIPE_PUBLISHABLE_KEY", "stripe_publishable_key") @mock.patch("zilencer.views.STRIPE_PUBLISHABLE_KEY", "stripe_publishable_key") @mock.patch("stripe.Customer.retrieve", side_effect=mock_retrieve_customer) diff --git a/zilencer/views.py b/zilencer/views.py index 394758e763..bf3452a3d3 100644 --- a/zilencer/views.py +++ b/zilencer/views.py @@ -168,6 +168,11 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse: return HttpResponseRedirect(reverse('zilencer.views.billing_home')) if request.method == 'POST': + plan = request.POST['plan'] + if plan not in [Plan.CLOUD_ANNUAL, Plan.CLOUD_MONTHLY]: + billing_logger.warning("Tampered plan during realm upgrade. user: %s, realm: %s (%s)." + % (user.id, user.realm.id, user.realm.string_id)) + error_message = "Something went wrong. Please contact support@zulipchat.com" try: seat_count = int(unsign_string(request.POST['signed_seat_count'], request.POST['salt'])) except signing.BadSignature: @@ -179,7 +184,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse: stripe_customer = do_create_customer_with_payment_source(user, request.POST['stripeToken']) do_subscribe_customer_to_plan( stripe_customer=stripe_customer, - stripe_plan_id=Plan.objects.get(nickname=request.POST['plan']).stripe_plan_id, + stripe_plan_id=Plan.objects.get(nickname=plan).stripe_plan_id, seat_count=seat_count, # TODO: billing address details are passed to us in the request; # use that to calculate taxes.