mirror of https://github.com/zulip/zulip.git
billing: Restructure validation of upgrade parameters.
This commit is contained in:
parent
b4a28f3147
commit
5633049292
|
@ -56,7 +56,7 @@ class BillingError(Exception):
|
|||
TRY_RELOADING = _("Something went wrong. Please reload the page.")
|
||||
|
||||
# description is used only for tests
|
||||
def __init__(self, description: str, message: str) -> None:
|
||||
def __init__(self, description: str, message: str=CONTACT_SUPPORT) -> None:
|
||||
self.description = description
|
||||
self.message = message
|
||||
|
||||
|
|
|
@ -239,19 +239,20 @@ class StripeTest(ZulipTestCase):
|
|||
return match.group(1) if match else None
|
||||
|
||||
def upgrade(self, invoice: bool=False, talk_to_stripe: bool=True,
|
||||
realm: Optional[Realm]=None, **kwargs: Any) -> HttpResponse:
|
||||
realm: Optional[Realm]=None, del_args: List[str]=[],
|
||||
**kwargs: Any) -> HttpResponse:
|
||||
host_args = {}
|
||||
if realm is not None:
|
||||
host_args['HTTP_HOST'] = realm.host
|
||||
response = self.client_get("/upgrade/", **host_args)
|
||||
params = {
|
||||
'schedule': 'annual',
|
||||
'signed_seat_count': self.get_signed_seat_count_from_response(response),
|
||||
'salt': self.get_salt_from_response(response),
|
||||
'schedule': 'annual'} # type: Dict[str, Any]
|
||||
'salt': self.get_salt_from_response(response)} # type: Dict[str, Any]
|
||||
if invoice: # send_invoice
|
||||
params.update({
|
||||
'licenses': 123,
|
||||
'billing_modality': 'send_invoice'})
|
||||
'billing_modality': 'send_invoice',
|
||||
'licenses': 123})
|
||||
else: # charge_automatically
|
||||
stripe_token = None
|
||||
if not talk_to_stripe:
|
||||
|
@ -260,11 +261,15 @@ class StripeTest(ZulipTestCase):
|
|||
if stripe_token is None:
|
||||
stripe_token = stripe_create_token().id
|
||||
params.update({
|
||||
'stripe_token': stripe_token,
|
||||
'billing_modality': 'charge_automatically',
|
||||
'license_management': 'automatic',
|
||||
'stripe_token': stripe_token,
|
||||
})
|
||||
|
||||
params.update(kwargs)
|
||||
for key in del_args:
|
||||
if key in params:
|
||||
del params[key]
|
||||
for key, value in params.items():
|
||||
params[key] = ujson.dumps(value)
|
||||
return self.client_post("/json/billing/upgrade", params, **host_args)
|
||||
|
@ -469,32 +474,66 @@ class StripeTest(ZulipTestCase):
|
|||
self.assert_json_error_contains(response, "Something went wrong. Please contact")
|
||||
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered seat count')
|
||||
|
||||
def test_upgrade_with_tampered_schedule(self) -> None:
|
||||
# Test with an unknown plan
|
||||
self.login(self.example_email("hamlet"))
|
||||
response = self.upgrade(talk_to_stripe=False, schedule='biweekly')
|
||||
self.assert_json_error_contains(response, "Something went wrong. Please contact")
|
||||
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered schedule')
|
||||
# Test with a plan that's valid, but not if you're paying by invoice
|
||||
response = self.upgrade(invoice=True, talk_to_stripe=False, schedule='monthly')
|
||||
self.assert_json_error_contains(response, "Something went wrong. Please contact")
|
||||
self.assertEqual(ujson.loads(response.content)['error_description'], 'tampered schedule')
|
||||
def test_check_upgrade_parameters(self) -> None:
|
||||
# Tests all the error paths except 'not enough licenses'
|
||||
def check_error(error_description: str, upgrade_params: Dict[str, Any],
|
||||
del_args: List[str]=[]) -> None:
|
||||
response = self.upgrade(talk_to_stripe=False, del_args=del_args, **upgrade_params)
|
||||
self.assert_json_error_contains(response, "Something went wrong. Please contact")
|
||||
self.assertEqual(ujson.loads(response.content)['error_description'], error_description)
|
||||
|
||||
def test_upgrade_with_insufficient_invoiced_seat_count(self) -> None:
|
||||
self.login(self.example_email("hamlet"))
|
||||
# Test invoicing for less than MIN_INVOICED_LICENSES
|
||||
response = self.upgrade(invoice=True, talk_to_stripe=False,
|
||||
licenses=MIN_INVOICED_LICENSES - 1)
|
||||
self.assert_json_error_contains(response, "at least {} users.".format(MIN_INVOICED_LICENSES))
|
||||
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses')
|
||||
# Test invoicing for less than your user count
|
||||
check_error('unknown billing_modality', {'billing_modality': 'invalid'})
|
||||
check_error('unknown schedule', {'schedule': 'invalid'})
|
||||
check_error('unknown license_management', {'license_management': 'invalid'})
|
||||
check_error('autopay with no card', {}, del_args=['stripe_token'])
|
||||
|
||||
def test_upgrade_license_counts(self) -> None:
|
||||
def check_error(invoice: bool, licenses: Optional[int], min_licenses_in_response: int,
|
||||
upgrade_params: Dict[str, Any]={}) -> None:
|
||||
if licenses is None:
|
||||
del_args = ['licenses']
|
||||
else:
|
||||
del_args = []
|
||||
upgrade_params['licenses'] = licenses
|
||||
response = self.upgrade(invoice=invoice, talk_to_stripe=False,
|
||||
del_args=del_args, **upgrade_params)
|
||||
self.assert_json_error_contains(response, "at least {} users".format(min_licenses_in_response))
|
||||
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses')
|
||||
|
||||
def check_success(invoice: bool, licenses: Optional[int], upgrade_params: Dict[str, Any]={}) -> None:
|
||||
if licenses is None:
|
||||
del_args = ['licenses']
|
||||
else:
|
||||
del_args = []
|
||||
upgrade_params['licenses'] = licenses
|
||||
with patch('corporate.views.process_initial_upgrade'):
|
||||
response = self.upgrade(invoice=invoice, talk_to_stripe=False,
|
||||
del_args=del_args, **upgrade_params)
|
||||
self.assert_json_success(response)
|
||||
|
||||
self.login(self.example_email("hamlet"))
|
||||
# Autopay with licenses < seat count
|
||||
check_error(False, self.seat_count - 1, self.seat_count, {'license_management': 'manual'})
|
||||
# Autopay with not setting licenses
|
||||
check_error(False, None, self.seat_count, {'license_management': 'manual'})
|
||||
# Invoice with licenses < MIN_INVOICED_LICENSES
|
||||
check_error(True, MIN_INVOICED_LICENSES - 1, MIN_INVOICED_LICENSES)
|
||||
# Invoice with licenses < seat count
|
||||
with patch("corporate.views.MIN_INVOICED_LICENSES", 3):
|
||||
response = self.upgrade(invoice=True, talk_to_stripe=False, licenses=4)
|
||||
self.assert_json_error_contains(response, "at least {} users.".format(self.seat_count))
|
||||
self.assertEqual(ujson.loads(response.content)['error_description'], 'not enough licenses')
|
||||
# Test not setting licenses
|
||||
response = self.upgrade(invoice=True, talk_to_stripe=False, licenses=None)
|
||||
self.assert_json_error_contains(response, "licenses is not an integer")
|
||||
check_error(True, 4, self.seat_count)
|
||||
# Invoice with not setting licenses
|
||||
check_error(True, None, MIN_INVOICED_LICENSES)
|
||||
|
||||
# Autopay with automatic license_management
|
||||
check_success(False, None)
|
||||
# Autopay with automatic license_management, should just ignore the licenses entry
|
||||
check_success(False, self.seat_count)
|
||||
# Autopay
|
||||
check_success(False, self.seat_count, {'license_management': 'manual'})
|
||||
check_success(False, self.seat_count + 10, {'license_management': 'mix'})
|
||||
# Invoice
|
||||
check_success(True, self.seat_count + MIN_INVOICED_LICENSES)
|
||||
|
||||
@patch("corporate.lib.stripe.billing_logger.error")
|
||||
def test_upgrade_with_uncaught_exception(self, mock_: Mock) -> None:
|
||||
|
|
|
@ -26,25 +26,32 @@ from corporate.models import Customer, CustomerPlan, Plan
|
|||
|
||||
billing_logger = logging.getLogger('corporate.stripe')
|
||||
|
||||
def unsign_and_check_upgrade_parameters(user: UserProfile, schedule: str,
|
||||
signed_seat_count: str, salt: str,
|
||||
billing_modality: str) -> Tuple[int, int]:
|
||||
provided_schedules = {
|
||||
'charge_automatically': ['annual', 'monthly'],
|
||||
'send_invoice': ['annual'],
|
||||
}
|
||||
if schedule not in provided_schedules[billing_modality]:
|
||||
billing_logger.warning("Tampered schedule during realm upgrade. user: %s, realm: %s (%s)."
|
||||
% (user.id, user.realm.id, user.realm.string_id))
|
||||
raise BillingError('tampered schedule', BillingError.CONTACT_SUPPORT)
|
||||
billing_schedule = {'annual': CustomerPlan.ANNUAL, 'monthly': CustomerPlan.MONTHLY}[schedule]
|
||||
def unsign_seat_count(signed_seat_count: str, salt: str) -> int:
|
||||
try:
|
||||
seat_count = int(unsign_string(signed_seat_count, salt))
|
||||
return int(unsign_string(signed_seat_count, salt))
|
||||
except signing.BadSignature:
|
||||
billing_logger.warning("Tampered seat count during realm upgrade. user: %s, realm: %s (%s)."
|
||||
% (user.id, user.realm.id, user.realm.string_id))
|
||||
raise BillingError('tampered seat count', BillingError.CONTACT_SUPPORT)
|
||||
return seat_count, billing_schedule
|
||||
raise BillingError('tampered seat count')
|
||||
|
||||
def check_upgrade_parameters(
|
||||
billing_modality: str, schedule: str, license_management: str, licenses: int,
|
||||
has_stripe_token: bool, seat_count: int) -> None:
|
||||
if billing_modality not in ['send_invoice', 'charge_automatically']:
|
||||
raise BillingError('unknown billing_modality')
|
||||
if schedule not in ['annual', 'monthly']:
|
||||
raise BillingError('unknown schedule')
|
||||
if license_management not in ['automatic', 'manual', 'mix']:
|
||||
raise BillingError('unknown license_management')
|
||||
|
||||
if billing_modality == 'charge_automatically':
|
||||
if not has_stripe_token:
|
||||
raise BillingError('autopay with no card')
|
||||
|
||||
min_licenses = seat_count
|
||||
if billing_modality == 'send_invoice':
|
||||
min_licenses = max(seat_count, MIN_INVOICED_LICENSES)
|
||||
if licenses is None or licenses < min_licenses:
|
||||
raise BillingError('not enough licenses',
|
||||
_("You must invoice for at least {} users.".format(min_licenses)))
|
||||
|
||||
def payment_method_string(stripe_customer: stripe.Customer) -> str:
|
||||
subscription = extract_current_subscription(stripe_customer)
|
||||
|
@ -67,26 +74,29 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str:
|
|||
|
||||
@has_request_variables
|
||||
def upgrade(request: HttpRequest, user: UserProfile,
|
||||
billing_modality: str=REQ(validator=check_string),
|
||||
schedule: str=REQ(validator=check_string),
|
||||
license_management: str=REQ(validator=check_string, default=None),
|
||||
signed_seat_count: str=REQ(validator=check_string),
|
||||
salt: str=REQ(validator=check_string),
|
||||
billing_modality: str=REQ(validator=check_string),
|
||||
licenses: int=REQ(validator=check_int, default=None),
|
||||
stripe_token: str=REQ(validator=check_string, default=None)) -> HttpResponse:
|
||||
stripe_token: str=REQ(validator=check_string, default=None),
|
||||
signed_seat_count: str=REQ(validator=check_string),
|
||||
salt: str=REQ(validator=check_string)) -> HttpResponse:
|
||||
try:
|
||||
seat_count, billing_schedule = unsign_and_check_upgrade_parameters(
|
||||
user, schedule, signed_seat_count, salt, billing_modality)
|
||||
if billing_modality == 'send_invoice':
|
||||
min_required_licenses = max(seat_count, MIN_INVOICED_LICENSES)
|
||||
if licenses < min_required_licenses:
|
||||
raise BillingError(
|
||||
'not enough licenses',
|
||||
"You must invoice for at least %d users." % (min_required_licenses,))
|
||||
else:
|
||||
seat_count = unsign_seat_count(signed_seat_count, salt)
|
||||
if billing_modality == 'charge_automatically' and license_management == 'automatic':
|
||||
licenses = seat_count
|
||||
if billing_modality == 'send_invoice':
|
||||
schedule = 'annual'
|
||||
license_management = 'manual'
|
||||
check_upgrade_parameters(
|
||||
billing_modality, schedule, license_management, licenses,
|
||||
stripe_token is not None, seat_count)
|
||||
|
||||
billing_schedule = {'annual': CustomerPlan.ANNUAL,
|
||||
'monthly': CustomerPlan.MONTHLY}[schedule]
|
||||
process_initial_upgrade(user, licenses, billing_schedule, stripe_token)
|
||||
except BillingError as e:
|
||||
# TODO add a billing_logger.warning with all the upgrade parameters
|
||||
return json_error(e.message, data={'error_description': e.description})
|
||||
except Exception as e:
|
||||
billing_logger.exception("Uncaught exception in billing: %s" % (e,))
|
||||
|
|
Loading…
Reference in New Issue