billing: Move checks from process_initial_upgrade into separate function.

This commit is contained in:
Rishi Gupta 2018-08-06 00:47:15 -04:00
parent 5719633992
commit 9f2b8a4a11
2 changed files with 24 additions and 19 deletions

View File

@ -9,7 +9,6 @@ from django.conf import settings
from django.db import transaction
from django.utils.translation import ugettext as _
from django.core.signing import Signer
from django.core import signing
import stripe
from zerver.lib.exceptions import JsonableError
@ -179,23 +178,11 @@ def do_subscribe_customer_to_plan(stripe_customer: stripe.Customer, stripe_plan_
requires_billing_update=True,
extra_data=ujson.dumps({'quantity': current_seat_count}))
def process_initial_upgrade(user: UserProfile, plan: str, signed_seat_count: str,
salt: str, stripe_token: str) -> None:
if plan not in [Plan.CLOUD_ANNUAL, Plan.CLOUD_MONTHLY]:
billing_logger.warning("Tampered plan during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered plan', BillingError.CONTACT_SUPPORT)
try:
seat_count = int(unsign_string(signed_seat_count, salt))
except signing.BadSignature:
billing_logger.warning("Tampered seat count during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered seat count', BillingError.CONTACT_SUPPORT)
def process_initial_upgrade(user: UserProfile, plan: Plan, seat_count: int, stripe_token: str) -> None:
stripe_customer = do_create_customer_with_payment_source(user, stripe_token)
do_subscribe_customer_to_plan(
stripe_customer=stripe_customer,
stripe_plan_id=Plan.objects.get(nickname=plan).stripe_plan_id,
stripe_plan_id=plan.stripe_plan_id,
seat_count=seat_count,
# TODO: billing address details are passed to us in the request;
# use that to calculate taxes.

View File

@ -1,6 +1,7 @@
from typing import Any, Dict, Optional, Union, cast
from typing import Any, Dict, Optional, Tuple, Union, cast
import logging
from django.core import signing
from django.core.exceptions import ValidationError
from django.core.validators import validate_email, URLValidator
from django.db import IntegrityError
@ -27,7 +28,7 @@ from zerver.views.push_notifications import validate_token
from zilencer.lib.stripe import STRIPE_PUBLISHABLE_KEY, \
get_stripe_customer, get_upcoming_invoice, get_seat_count, \
extract_current_subscription, process_initial_upgrade, sign_string, \
BillingError
unsign_string, BillingError
from zilencer.models import RemotePushDeviceToken, RemoteZulipServer, \
Customer, Plan
@ -158,6 +159,22 @@ def remote_server_notify_push(request: HttpRequest, entity: Union[UserProfile, R
return json_success()
def unsign_and_check_upgrade_parameters(user: UserProfile, plan_nickname: str,
signed_seat_count: str, salt: str) -> Tuple[Plan, int]:
if plan_nickname not in [Plan.CLOUD_ANNUAL, Plan.CLOUD_MONTHLY]:
billing_logger.warning("Tampered plan during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered plan', BillingError.CONTACT_SUPPORT)
plan = Plan.objects.get(nickname=plan_nickname)
try:
seat_count = int(unsign_string(signed_seat_count, salt))
except signing.BadSignature:
billing_logger.warning("Tampered seat count during realm upgrade. user: %s, realm: %s (%s)."
% (user.id, user.realm.id, user.realm.string_id))
raise BillingError('tampered seat count', BillingError.CONTACT_SUPPORT)
return plan, seat_count
@zulip_login_required
def initial_upgrade(request: HttpRequest) -> HttpResponse:
if not settings.DEVELOPMENT:
@ -172,8 +189,9 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
if request.method == 'POST':
try:
process_initial_upgrade(user, request.POST['plan'], request.POST['signed_seat_count'],
request.POST['salt'], request.POST['stripeToken'])
plan, seat_count = unsign_and_check_upgrade_parameters(
user, request.POST['plan'], request.POST['signed_seat_count'], request.POST['salt'])
process_initial_upgrade(user, plan, seat_count, request.POST['stripeToken'])
except BillingError as e:
error_message = e.message
error_description = e.description