billing: Restructure validation of upgrade parameters.

This commit is contained in:
Rishi Gupta 2018-12-21 20:29:25 -08:00
parent b4a28f3147
commit 5633049292
3 changed files with 109 additions and 60 deletions

View File

@ -56,7 +56,7 @@ class BillingError(Exception):
TRY_RELOADING = _("Something went wrong. Please reload the page.") TRY_RELOADING = _("Something went wrong. Please reload the page.")
# description is used only for tests # description is used only for tests
def __init__(self, description: str, message: str) -> None: def __init__(self, description: str, message: str=CONTACT_SUPPORT) -> None:
self.description = description self.description = description
self.message = message self.message = message

View File

@ -239,19 +239,20 @@ class StripeTest(ZulipTestCase):
return match.group(1) if match else None return match.group(1) if match else None
def upgrade(self, invoice: bool=False, talk_to_stripe: bool=True, def upgrade(self, invoice: bool=False, talk_to_stripe: bool=True,
realm: Optional[Realm]=None, **kwargs: Any) -> HttpResponse: realm: Optional[Realm]=None, del_args: List[str]=[],
**kwargs: Any) -> HttpResponse:
host_args = {} host_args = {}
if realm is not None: if realm is not None:
host_args['HTTP_HOST'] = realm.host host_args['HTTP_HOST'] = realm.host
response = self.client_get("/upgrade/", **host_args) response = self.client_get("/upgrade/", **host_args)
params = { params = {
'schedule': 'annual',
'signed_seat_count': self.get_signed_seat_count_from_response(response), 'signed_seat_count': self.get_signed_seat_count_from_response(response),
'salt': self.get_salt_from_response(response), 'salt': self.get_salt_from_response(response)} # type: Dict[str, Any]
'schedule': 'annual'} # type: Dict[str, Any]
if invoice: # send_invoice if invoice: # send_invoice
params.update({ params.update({
'licenses': 123, 'billing_modality': 'send_invoice',
'billing_modality': 'send_invoice'}) 'licenses': 123})
else: # charge_automatically else: # charge_automatically
stripe_token = None stripe_token = None
if not talk_to_stripe: if not talk_to_stripe:
@ -260,11 +261,15 @@ class StripeTest(ZulipTestCase):
if stripe_token is None: if stripe_token is None:
stripe_token = stripe_create_token().id stripe_token = stripe_create_token().id
params.update({ params.update({
'stripe_token': stripe_token,
'billing_modality': 'charge_automatically', 'billing_modality': 'charge_automatically',
'license_management': 'automatic',
'stripe_token': stripe_token,
}) })
params.update(kwargs) params.update(kwargs)
for key in del_args:
if key in params:
del params[key]
for key, value in params.items(): for key, value in params.items():
params[key] = ujson.dumps(value) params[key] = ujson.dumps(value)
return self.client_post("/json/billing/upgrade", params, **host_args) return self.client_post("/json/billing/upgrade", params, **host_args)
@ -469,32 +474,66 @@ class StripeTest(ZulipTestCase):
self.assert_json_error_contains(response, "Something went wrong. Please contact") self.assert_json_error_contains(response, "Something went wrong. Please contact")
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered seat count') self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered seat count')
def test_upgrade_with_tampered_schedule(self) -> None: def test_check_upgrade_parameters(self) -> None:
# Test with an unknown plan # Tests all the error paths except 'not enough licenses'
self.login(self.example_email("hamlet")) def check_error(error_description: str, upgrade_params: Dict[str, Any],
response = self.upgrade(talk_to_stripe=False, schedule='biweekly') del_args: List[str]=[]) -> None:
self.assert_json_error_contains(response, "Something went wrong. Please contact") response = self.upgrade(talk_to_stripe=False, del_args=del_args, **upgrade_params)
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered schedule') self.assert_json_error_contains(response, "Something went wrong. Please contact")
# Test with a plan that's valid, but not if you're paying by invoice self.assertEqual(ujson.loads(response.content)['error_description'], error_description)
response = self.upgrade(invoice=True, talk_to_stripe=False, schedule='monthly')
self.assert_json_error_contains(response, "Something went wrong. Please contact")
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered schedule')
def test_upgrade_with_insufficient_invoiced_seat_count(self) -> None:
self.login(self.example_email("hamlet")) self.login(self.example_email("hamlet"))
# Test invoicing for less than MIN_INVOICED_LICENSES check_error('unknown billing_modality', {'billing_modality': 'invalid'})
response = self.upgrade(invoice=True, talk_to_stripe=False, check_error('unknown schedule', {'schedule': 'invalid'})
licenses=MIN_INVOICED_LICENSES - 1) check_error('unknown license_management', {'license_management': 'invalid'})
self.assert_json_error_contains(response, "at least {} users.".format(MIN_INVOICED_LICENSES)) check_error('autopay with no card', {}, del_args=['stripe_token'])
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses')
# Test invoicing for less than your user count def test_upgrade_license_counts(self) -> None:
def check_error(invoice: bool, licenses: Optional[int], min_licenses_in_response: int,
upgrade_params: Dict[str, Any]={}) -> None:
if licenses is None:
del_args = ['licenses']
else:
del_args = []
upgrade_params['licenses'] = licenses
response = self.upgrade(invoice=invoice, talk_to_stripe=False,
del_args=del_args, **upgrade_params)
self.assert_json_error_contains(response, "at least {} users".format(min_licenses_in_response))
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses')
def check_success(invoice: bool, licenses: Optional[int], upgrade_params: Dict[str, Any]={}) -> None:
if licenses is None:
del_args = ['licenses']
else:
del_args = []
upgrade_params['licenses'] = licenses
with patch('corporate.views.process_initial_upgrade'):
response = self.upgrade(invoice=invoice, talk_to_stripe=False,
del_args=del_args, **upgrade_params)
self.assert_json_success(response)
self.login(self.example_email("hamlet"))
# Autopay with licenses < seat count
check_error(False, self.seat_count - 1, self.seat_count, {'license_management': 'manual'})
# Autopay with not setting licenses
check_error(False, None, self.seat_count, {'license_management': 'manual'})
# Invoice with licenses < MIN_INVOICED_LICENSES
check_error(True, MIN_INVOICED_LICENSES - 1, MIN_INVOICED_LICENSES)
# Invoice with licenses < seat count
with patch("corporate.views.MIN_INVOICED_LICENSES", 3): with patch("corporate.views.MIN_INVOICED_LICENSES", 3):
response = self.upgrade(invoice=True, talk_to_stripe=False, licenses=4) check_error(True, 4, self.seat_count)
self.assert_json_error_contains(response, "at least {} users.".format(self.seat_count)) # Invoice with not setting licenses
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses') check_error(True, None, MIN_INVOICED_LICENSES)
# Test not setting licenses
response = self.upgrade(invoice=True, talk_to_stripe=False, licenses=None) # Autopay with automatic license_management
self.assert_json_error_contains(response, "licenses is not an integer") check_success(False, None)
# Autopay with automatic license_management, should just ignore the licenses entry
check_success(False, self.seat_count)
# Autopay
check_success(False, self.seat_count, {'license_management': 'manual'})
check_success(False, self.seat_count + 10, {'license_management': 'mix'})
# Invoice
check_success(True, self.seat_count + MIN_INVOICED_LICENSES)
@patch("corporate.lib.stripe.billing_logger.error") @patch("corporate.lib.stripe.billing_logger.error")
def test_upgrade_with_uncaught_exception(self, mock_: Mock) -> None: def test_upgrade_with_uncaught_exception(self, mock_: Mock) -> None:

View File

@ -26,25 +26,32 @@ from corporate.models import Customer, CustomerPlan, Plan
billing_logger = logging.getLogger('corporate.stripe') billing_logger = logging.getLogger('corporate.stripe')
def unsign_and_check_upgrade_parameters(user: UserProfile, schedule: str, def unsign_seat_count(signed_seat_count: str, salt: str) -> int:
signed_seat_count: str, salt: str,
billing_modality: str) -> Tuple[int, int]:
provided_schedules = {
'charge_automatically': ['annual', 'monthly'],
'send_invoice': ['annual'],
}
if schedule not in provided_schedules[billing_modality]:
billing_logger.warning("Tampered schedule during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered schedule', BillingError.CONTACT_SUPPORT)
billing_schedule = {'annual': CustomerPlan.ANNUAL, 'monthly': CustomerPlan.MONTHLY}[schedule]
try: try:
seat_count = int(unsign_string(signed_seat_count, salt)) return int(unsign_string(signed_seat_count, salt))
except signing.BadSignature: except signing.BadSignature:
billing_logger.warning("Tampered seat count during realm upgrade. user: %s, realm: %s (%s)." raise BillingError('tampered seat count')
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered seat count', BillingError.CONTACT_SUPPORT) def check_upgrade_parameters(
return seat_count, billing_schedule billing_modality: str, schedule: str, license_management: str, licenses: int,
has_stripe_token: bool, seat_count: int) -> None:
if billing_modality not in ['send_invoice', 'charge_automatically']:
raise BillingError('unknown billing_modality')
if schedule not in ['annual', 'monthly']:
raise BillingError('unknown schedule')
if license_management not in ['automatic', 'manual', 'mix']:
raise BillingError('unknown license_management')
if billing_modality == 'charge_automatically':
if not has_stripe_token:
raise BillingError('autopay with no card')
min_licenses = seat_count
if billing_modality == 'send_invoice':
min_licenses = max(seat_count, MIN_INVOICED_LICENSES)
if licenses is None or licenses < min_licenses:
raise BillingError('not enough licenses',
_("You must invoice for at least {} users.".format(min_licenses)))
def payment_method_string(stripe_customer: stripe.Customer) -> str: def payment_method_string(stripe_customer: stripe.Customer) -> str:
subscription = extract_current_subscription(stripe_customer) subscription = extract_current_subscription(stripe_customer)
@ -67,26 +74,29 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str:
@has_request_variables @has_request_variables
def upgrade(request: HttpRequest, user: UserProfile, def upgrade(request: HttpRequest, user: UserProfile,
billing_modality: str=REQ(validator=check_string),
schedule: str=REQ(validator=check_string), schedule: str=REQ(validator=check_string),
license_management: str=REQ(validator=check_string, default=None), license_management: str=REQ(validator=check_string, default=None),
signed_seat_count: str=REQ(validator=check_string),
salt: str=REQ(validator=check_string),
billing_modality: str=REQ(validator=check_string),
licenses: int=REQ(validator=check_int, default=None), licenses: int=REQ(validator=check_int, default=None),
stripe_token: str=REQ(validator=check_string, default=None)) -> HttpResponse: stripe_token: str=REQ(validator=check_string, default=None),
signed_seat_count: str=REQ(validator=check_string),
salt: str=REQ(validator=check_string)) -> HttpResponse:
try: try:
seat_count, billing_schedule = unsign_and_check_upgrade_parameters( seat_count = unsign_seat_count(signed_seat_count, salt)
user, schedule, signed_seat_count, salt, billing_modality) if billing_modality == 'charge_automatically' and license_management == 'automatic':
if billing_modality == 'send_invoice':
min_required_licenses = max(seat_count, MIN_INVOICED_LICENSES)
if licenses < min_required_licenses:
raise BillingError(
'not enough licenses',
"You must invoice for at least %d users." % (min_required_licenses,))
else:
licenses = seat_count licenses = seat_count
if billing_modality == 'send_invoice':
schedule = 'annual'
license_management = 'manual'
check_upgrade_parameters(
billing_modality, schedule, license_management, licenses,
stripe_token is not None, seat_count)
billing_schedule = {'annual': CustomerPlan.ANNUAL,
'monthly': CustomerPlan.MONTHLY}[schedule]
process_initial_upgrade(user, licenses, billing_schedule, stripe_token) process_initial_upgrade(user, licenses, billing_schedule, stripe_token)
except BillingError as e: except BillingError as e:
# TODO add a billing_logger.warning with all the upgrade parameters
return json_error(e.message, data={'error_description': e.description}) return json_error(e.message, data={'error_description': e.description})
except Exception as e: except Exception as e:
billing_logger.exception("Uncaught exception in billing: %s" % (e,)) billing_logger.exception("Uncaught exception in billing: %s" % (e,))