billing: Restructure validation of upgrade parameters.

This commit is contained in:
Rishi Gupta 2018-12-21 20:29:25 -08:00
parent b4a28f3147
commit 5633049292
3 changed files with 109 additions and 60 deletions

View File

@ -56,7 +56,7 @@ class BillingError(Exception):
TRY_RELOADING = _("Something went wrong. Please reload the page.")
# description is used only for tests
def __init__(self, description: str, message: str) -> None:
def __init__(self, description: str, message: str=CONTACT_SUPPORT) -> None:
self.description = description
self.message = message

View File

@ -239,19 +239,20 @@ class StripeTest(ZulipTestCase):
return match.group(1) if match else None
def upgrade(self, invoice: bool=False, talk_to_stripe: bool=True,
realm: Optional[Realm]=None, **kwargs: Any) -> HttpResponse:
realm: Optional[Realm]=None, del_args: List[str]=[],
**kwargs: Any) -> HttpResponse:
host_args = {}
if realm is not None:
host_args['HTTP_HOST'] = realm.host
response = self.client_get("/upgrade/", **host_args)
params = {
'schedule': 'annual',
'signed_seat_count': self.get_signed_seat_count_from_response(response),
'salt': self.get_salt_from_response(response),
'schedule': 'annual'} # type: Dict[str, Any]
'salt': self.get_salt_from_response(response)} # type: Dict[str, Any]
if invoice: # send_invoice
params.update({
'licenses': 123,
'billing_modality': 'send_invoice'})
'billing_modality': 'send_invoice',
'licenses': 123})
else: # charge_automatically
stripe_token = None
if not talk_to_stripe:
@ -260,11 +261,15 @@ class StripeTest(ZulipTestCase):
if stripe_token is None:
stripe_token = stripe_create_token().id
params.update({
'stripe_token': stripe_token,
'billing_modality': 'charge_automatically',
'license_management': 'automatic',
'stripe_token': stripe_token,
})
params.update(kwargs)
for key in del_args:
if key in params:
del params[key]
for key, value in params.items():
params[key] = ujson.dumps(value)
return self.client_post("/json/billing/upgrade", params, **host_args)
@ -469,32 +474,66 @@ class StripeTest(ZulipTestCase):
self.assert_json_error_contains(response, "Something went wrong. Please contact")
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered seat count')
def test_upgrade_with_tampered_schedule(self) -> None:
# Test with an unknown plan
self.login(self.example_email("hamlet"))
response = self.upgrade(talk_to_stripe=False, schedule='biweekly')
self.assert_json_error_contains(response, "Something went wrong. Please contact")
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered schedule')
# Test with a plan that's valid, but not if you're paying by invoice
response = self.upgrade(invoice=True, talk_to_stripe=False, schedule='monthly')
self.assert_json_error_contains(response, "Something went wrong. Please contact")
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered schedule')
def test_check_upgrade_parameters(self) -> None:
# Tests all the error paths except 'not enough licenses'
def check_error(error_description: str, upgrade_params: Dict[str, Any],
del_args: List[str]=[]) -> None:
response = self.upgrade(talk_to_stripe=False, del_args=del_args, **upgrade_params)
self.assert_json_error_contains(response, "Something went wrong. Please contact")
self.assertEqual(ujson.loads(response.content)['error_description'], error_description)
def test_upgrade_with_insufficient_invoiced_seat_count(self) -> None:
self.login(self.example_email("hamlet"))
# Test invoicing for less than MIN_INVOICED_LICENSES
response = self.upgrade(invoice=True, talk_to_stripe=False,
licenses=MIN_INVOICED_LICENSES - 1)
self.assert_json_error_contains(response, "at least {} users.".format(MIN_INVOICED_LICENSES))
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses')
# Test invoicing for less than your user count
check_error('unknown billing_modality', {'billing_modality': 'invalid'})
check_error('unknown schedule', {'schedule': 'invalid'})
check_error('unknown license_management', {'license_management': 'invalid'})
check_error('autopay with no card', {}, del_args=['stripe_token'])
def test_upgrade_license_counts(self) -> None:
def check_error(invoice: bool, licenses: Optional[int], min_licenses_in_response: int,
upgrade_params: Dict[str, Any]={}) -> None:
if licenses is None:
del_args = ['licenses']
else:
del_args = []
upgrade_params['licenses'] = licenses
response = self.upgrade(invoice=invoice, talk_to_stripe=False,
del_args=del_args, **upgrade_params)
self.assert_json_error_contains(response, "at least {} users".format(min_licenses_in_response))
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses')
def check_success(invoice: bool, licenses: Optional[int], upgrade_params: Dict[str, Any]={}) -> None:
if licenses is None:
del_args = ['licenses']
else:
del_args = []
upgrade_params['licenses'] = licenses
with patch('corporate.views.process_initial_upgrade'):
response = self.upgrade(invoice=invoice, talk_to_stripe=False,
del_args=del_args, **upgrade_params)
self.assert_json_success(response)
self.login(self.example_email("hamlet"))
# Autopay with licenses < seat count
check_error(False, self.seat_count - 1, self.seat_count, {'license_management': 'manual'})
# Autopay with not setting licenses
check_error(False, None, self.seat_count, {'license_management': 'manual'})
# Invoice with licenses < MIN_INVOICED_LICENSES
check_error(True, MIN_INVOICED_LICENSES - 1, MIN_INVOICED_LICENSES)
# Invoice with licenses < seat count
with patch("corporate.views.MIN_INVOICED_LICENSES", 3):
response = self.upgrade(invoice=True, talk_to_stripe=False, licenses=4)
self.assert_json_error_contains(response, "at least {} users.".format(self.seat_count))
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses')
# Test not setting licenses
response = self.upgrade(invoice=True, talk_to_stripe=False, licenses=None)
self.assert_json_error_contains(response, "licenses is not an integer")
check_error(True, 4, self.seat_count)
# Invoice with not setting licenses
check_error(True, None, MIN_INVOICED_LICENSES)
# Autopay with automatic license_management
check_success(False, None)
# Autopay with automatic license_management, should just ignore the licenses entry
check_success(False, self.seat_count)
# Autopay
check_success(False, self.seat_count, {'license_management': 'manual'})
check_success(False, self.seat_count + 10, {'license_management': 'mix'})
# Invoice
check_success(True, self.seat_count + MIN_INVOICED_LICENSES)
@patch("corporate.lib.stripe.billing_logger.error")
def test_upgrade_with_uncaught_exception(self, mock_: Mock) -> None:

View File

@ -26,25 +26,32 @@ from corporate.models import Customer, CustomerPlan, Plan
billing_logger = logging.getLogger('corporate.stripe')
def unsign_and_check_upgrade_parameters(user: UserProfile, schedule: str,
signed_seat_count: str, salt: str,
billing_modality: str) -> Tuple[int, int]:
provided_schedules = {
'charge_automatically': ['annual', 'monthly'],
'send_invoice': ['annual'],
}
if schedule not in provided_schedules[billing_modality]:
billing_logger.warning("Tampered schedule during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered schedule', BillingError.CONTACT_SUPPORT)
billing_schedule = {'annual': CustomerPlan.ANNUAL, 'monthly': CustomerPlan.MONTHLY}[schedule]
def unsign_seat_count(signed_seat_count: str, salt: str) -> int:
try:
seat_count = int(unsign_string(signed_seat_count, salt))
return int(unsign_string(signed_seat_count, salt))
except signing.BadSignature:
billing_logger.warning("Tampered seat count during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered seat count', BillingError.CONTACT_SUPPORT)
return seat_count, billing_schedule
raise BillingError('tampered seat count')
def check_upgrade_parameters(
billing_modality: str, schedule: str, license_management: str, licenses: int,
has_stripe_token: bool, seat_count: int) -> None:
if billing_modality not in ['send_invoice', 'charge_automatically']:
raise BillingError('unknown billing_modality')
if schedule not in ['annual', 'monthly']:
raise BillingError('unknown schedule')
if license_management not in ['automatic', 'manual', 'mix']:
raise BillingError('unknown license_management')
if billing_modality == 'charge_automatically':
if not has_stripe_token:
raise BillingError('autopay with no card')
min_licenses = seat_count
if billing_modality == 'send_invoice':
min_licenses = max(seat_count, MIN_INVOICED_LICENSES)
if licenses is None or licenses < min_licenses:
raise BillingError('not enough licenses',
_("You must invoice for at least {} users.".format(min_licenses)))
def payment_method_string(stripe_customer: stripe.Customer) -> str:
subscription = extract_current_subscription(stripe_customer)
@ -67,26 +74,29 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str:
@has_request_variables
def upgrade(request: HttpRequest, user: UserProfile,
billing_modality: str=REQ(validator=check_string),
schedule: str=REQ(validator=check_string),
license_management: str=REQ(validator=check_string, default=None),
signed_seat_count: str=REQ(validator=check_string),
salt: str=REQ(validator=check_string),
billing_modality: str=REQ(validator=check_string),
licenses: int=REQ(validator=check_int, default=None),
stripe_token: str=REQ(validator=check_string, default=None)) -> HttpResponse:
stripe_token: str=REQ(validator=check_string, default=None),
signed_seat_count: str=REQ(validator=check_string),
salt: str=REQ(validator=check_string)) -> HttpResponse:
try:
seat_count, billing_schedule = unsign_and_check_upgrade_parameters(
user, schedule, signed_seat_count, salt, billing_modality)
if billing_modality == 'send_invoice':
min_required_licenses = max(seat_count, MIN_INVOICED_LICENSES)
if licenses < min_required_licenses:
raise BillingError(
'not enough licenses',
"You must invoice for at least %d users." % (min_required_licenses,))
else:
seat_count = unsign_seat_count(signed_seat_count, salt)
if billing_modality == 'charge_automatically' and license_management == 'automatic':
licenses = seat_count
if billing_modality == 'send_invoice':
schedule = 'annual'
license_management = 'manual'
check_upgrade_parameters(
billing_modality, schedule, license_management, licenses,
stripe_token is not None, seat_count)
billing_schedule = {'annual': CustomerPlan.ANNUAL,
'monthly': CustomerPlan.MONTHLY}[schedule]
process_initial_upgrade(user, licenses, billing_schedule, stripe_token)
except BillingError as e:
# TODO add a billing_logger.warning with all the upgrade parameters
return json_error(e.message, data={'error_description': e.description})
except Exception as e:
billing_logger.exception("Uncaught exception in billing: %s" % (e,))