From 5633049292fbc450b21d5f7fde1e1f20a87b4966 Mon Sep 17 00:00:00 2001 From: Rishi Gupta Date: Fri, 21 Dec 2018 20:29:25 -0800 Subject: [PATCH] billing: Restructure validation of upgrade parameters. --- corporate/lib/stripe.py | 2 +- corporate/tests/test_stripe.py | 97 ++++++++++++++++++++++++---------- corporate/views.py | 70 +++++++++++++----------- 3 files changed, 109 insertions(+), 60 deletions(-) diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 836bead926..a574cf5433 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -56,7 +56,7 @@ class BillingError(Exception): TRY_RELOADING = _("Something went wrong. Please reload the page.") # description is used only for tests - def __init__(self, description: str, message: str) -> None: + def __init__(self, description: str, message: str=CONTACT_SUPPORT) -> None: self.description = description self.message = message diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 65b5e9c878..d64e129af7 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -239,19 +239,20 @@ class StripeTest(ZulipTestCase): return match.group(1) if match else None def upgrade(self, invoice: bool=False, talk_to_stripe: bool=True, - realm: Optional[Realm]=None, **kwargs: Any) -> HttpResponse: + realm: Optional[Realm]=None, del_args: List[str]=[], + **kwargs: Any) -> HttpResponse: host_args = {} if realm is not None: host_args['HTTP_HOST'] = realm.host response = self.client_get("/upgrade/", **host_args) params = { + 'schedule': 'annual', 'signed_seat_count': self.get_signed_seat_count_from_response(response), - 'salt': self.get_salt_from_response(response), - 'schedule': 'annual'} # type: Dict[str, Any] + 'salt': self.get_salt_from_response(response)} # type: Dict[str, Any] if invoice: # send_invoice params.update({ - 'licenses': 123, - 'billing_modality': 'send_invoice'}) + 'billing_modality': 'send_invoice', + 'licenses': 123}) else: # charge_automatically stripe_token = None if not talk_to_stripe: @@ -260,11 +261,15 @@ class StripeTest(ZulipTestCase): if stripe_token is None: stripe_token = stripe_create_token().id params.update({ - 'stripe_token': stripe_token, 'billing_modality': 'charge_automatically', + 'license_management': 'automatic', + 'stripe_token': stripe_token, }) params.update(kwargs) + for key in del_args: + if key in params: + del params[key] for key, value in params.items(): params[key] = ujson.dumps(value) return self.client_post("/json/billing/upgrade", params, **host_args) @@ -469,32 +474,66 @@ class StripeTest(ZulipTestCase): self.assert_json_error_contains(response, "Something went wrong. Please contact") self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered seat count') - def test_upgrade_with_tampered_schedule(self) -> None: - # Test with an unknown plan - self.login(self.example_email("hamlet")) - response = self.upgrade(talk_to_stripe=False, schedule='biweekly') - self.assert_json_error_contains(response, "Something went wrong. Please contact") - self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered schedule') - # Test with a plan that's valid, but not if you're paying by invoice - response = self.upgrade(invoice=True, talk_to_stripe=False, schedule='monthly') - self.assert_json_error_contains(response, "Something went wrong. Please contact") - self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered schedule') + def test_check_upgrade_parameters(self) -> None: + # Tests all the error paths except 'not enough licenses' + def check_error(error_description: str, upgrade_params: Dict[str, Any], + del_args: List[str]=[]) -> None: + response = self.upgrade(talk_to_stripe=False, del_args=del_args, **upgrade_params) + self.assert_json_error_contains(response, "Something went wrong. Please contact") + self.assertEqual(ujson.loads(response.content)['error_description'], error_description) - def test_upgrade_with_insufficient_invoiced_seat_count(self) -> None: self.login(self.example_email("hamlet")) - # Test invoicing for less than MIN_INVOICED_LICENSES - response = self.upgrade(invoice=True, talk_to_stripe=False, - licenses=MIN_INVOICED_LICENSES - 1) - self.assert_json_error_contains(response, "at least {} users.".format(MIN_INVOICED_LICENSES)) - self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses') - # Test invoicing for less than your user count + check_error('unknown billing_modality', {'billing_modality': 'invalid'}) + check_error('unknown schedule', {'schedule': 'invalid'}) + check_error('unknown license_management', {'license_management': 'invalid'}) + check_error('autopay with no card', {}, del_args=['stripe_token']) + + def test_upgrade_license_counts(self) -> None: + def check_error(invoice: bool, licenses: Optional[int], min_licenses_in_response: int, + upgrade_params: Dict[str, Any]={}) -> None: + if licenses is None: + del_args = ['licenses'] + else: + del_args = [] + upgrade_params['licenses'] = licenses + response = self.upgrade(invoice=invoice, talk_to_stripe=False, + del_args=del_args, **upgrade_params) + self.assert_json_error_contains(response, "at least {} users".format(min_licenses_in_response)) + self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses') + + def check_success(invoice: bool, licenses: Optional[int], upgrade_params: Dict[str, Any]={}) -> None: + if licenses is None: + del_args = ['licenses'] + else: + del_args = [] + upgrade_params['licenses'] = licenses + with patch('corporate.views.process_initial_upgrade'): + response = self.upgrade(invoice=invoice, talk_to_stripe=False, + del_args=del_args, **upgrade_params) + self.assert_json_success(response) + + self.login(self.example_email("hamlet")) + # Autopay with licenses < seat count + check_error(False, self.seat_count - 1, self.seat_count, {'license_management': 'manual'}) + # Autopay with not setting licenses + check_error(False, None, self.seat_count, {'license_management': 'manual'}) + # Invoice with licenses < MIN_INVOICED_LICENSES + check_error(True, MIN_INVOICED_LICENSES - 1, MIN_INVOICED_LICENSES) + # Invoice with licenses < seat count with patch("corporate.views.MIN_INVOICED_LICENSES", 3): - response = self.upgrade(invoice=True, talk_to_stripe=False, licenses=4) - self.assert_json_error_contains(response, "at least {} users.".format(self.seat_count)) - self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses') - # Test not setting licenses - response = self.upgrade(invoice=True, talk_to_stripe=False, licenses=None) - self.assert_json_error_contains(response, "licenses is not an integer") + check_error(True, 4, self.seat_count) + # Invoice with not setting licenses + check_error(True, None, MIN_INVOICED_LICENSES) + + # Autopay with automatic license_management + check_success(False, None) + # Autopay with automatic license_management, should just ignore the licenses entry + check_success(False, self.seat_count) + # Autopay + check_success(False, self.seat_count, {'license_management': 'manual'}) + check_success(False, self.seat_count + 10, {'license_management': 'mix'}) + # Invoice + check_success(True, self.seat_count + MIN_INVOICED_LICENSES) @patch("corporate.lib.stripe.billing_logger.error") def test_upgrade_with_uncaught_exception(self, mock_: Mock) -> None: diff --git a/corporate/views.py b/corporate/views.py index 852f5e289a..f30c268b04 100644 --- a/corporate/views.py +++ b/corporate/views.py @@ -26,25 +26,32 @@ from corporate.models import Customer, CustomerPlan, Plan billing_logger = logging.getLogger('corporate.stripe') -def unsign_and_check_upgrade_parameters(user: UserProfile, schedule: str, - signed_seat_count: str, salt: str, - billing_modality: str) -> Tuple[int, int]: - provided_schedules = { - 'charge_automatically': ['annual', 'monthly'], - 'send_invoice': ['annual'], - } - if schedule not in provided_schedules[billing_modality]: - billing_logger.warning("Tampered schedule during realm upgrade. user: %s, realm: %s (%s)." - % (user.id, user.realm.id, user.realm.string_id)) - raise BillingError('tampered schedule', BillingError.CONTACT_SUPPORT) - billing_schedule = {'annual': CustomerPlan.ANNUAL, 'monthly': CustomerPlan.MONTHLY}[schedule] +def unsign_seat_count(signed_seat_count: str, salt: str) -> int: try: - seat_count = int(unsign_string(signed_seat_count, salt)) + return int(unsign_string(signed_seat_count, salt)) except signing.BadSignature: - billing_logger.warning("Tampered seat count during realm upgrade. user: %s, realm: %s (%s)." - % (user.id, user.realm.id, user.realm.string_id)) - raise BillingError('tampered seat count', BillingError.CONTACT_SUPPORT) - return seat_count, billing_schedule + raise BillingError('tampered seat count') + +def check_upgrade_parameters( + billing_modality: str, schedule: str, license_management: str, licenses: int, + has_stripe_token: bool, seat_count: int) -> None: + if billing_modality not in ['send_invoice', 'charge_automatically']: + raise BillingError('unknown billing_modality') + if schedule not in ['annual', 'monthly']: + raise BillingError('unknown schedule') + if license_management not in ['automatic', 'manual', 'mix']: + raise BillingError('unknown license_management') + + if billing_modality == 'charge_automatically': + if not has_stripe_token: + raise BillingError('autopay with no card') + + min_licenses = seat_count + if billing_modality == 'send_invoice': + min_licenses = max(seat_count, MIN_INVOICED_LICENSES) + if licenses is None or licenses < min_licenses: + raise BillingError('not enough licenses', + _("You must invoice for at least {} users.".format(min_licenses))) def payment_method_string(stripe_customer: stripe.Customer) -> str: subscription = extract_current_subscription(stripe_customer) @@ -67,26 +74,29 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str: @has_request_variables def upgrade(request: HttpRequest, user: UserProfile, + billing_modality: str=REQ(validator=check_string), schedule: str=REQ(validator=check_string), license_management: str=REQ(validator=check_string, default=None), - signed_seat_count: str=REQ(validator=check_string), - salt: str=REQ(validator=check_string), - billing_modality: str=REQ(validator=check_string), licenses: int=REQ(validator=check_int, default=None), - stripe_token: str=REQ(validator=check_string, default=None)) -> HttpResponse: + stripe_token: str=REQ(validator=check_string, default=None), + signed_seat_count: str=REQ(validator=check_string), + salt: str=REQ(validator=check_string)) -> HttpResponse: try: - seat_count, billing_schedule = unsign_and_check_upgrade_parameters( - user, schedule, signed_seat_count, salt, billing_modality) - if billing_modality == 'send_invoice': - min_required_licenses = max(seat_count, MIN_INVOICED_LICENSES) - if licenses < min_required_licenses: - raise BillingError( - 'not enough licenses', - "You must invoice for at least %d users." % (min_required_licenses,)) - else: + seat_count = unsign_seat_count(signed_seat_count, salt) + if billing_modality == 'charge_automatically' and license_management == 'automatic': licenses = seat_count + if billing_modality == 'send_invoice': + schedule = 'annual' + license_management = 'manual' + check_upgrade_parameters( + billing_modality, schedule, license_management, licenses, + stripe_token is not None, seat_count) + + billing_schedule = {'annual': CustomerPlan.ANNUAL, + 'monthly': CustomerPlan.MONTHLY}[schedule] process_initial_upgrade(user, licenses, billing_schedule, stripe_token) except BillingError as e: + # TODO add a billing_logger.warning with all the upgrade parameters return json_error(e.message, data={'error_description': e.description}) except Exception as e: billing_logger.exception("Uncaught exception in billing: %s" % (e,))