billing: Add function to sign strings.

This commit is contained in:
Vishnu Ks 2018-07-13 21:04:39 +05:30 committed by Rishi Gupta
parent 8de454ad2d
commit d75054fb15
2 changed files with 23 additions and 2 deletions

View File

@ -2,17 +2,19 @@ import datetime
from functools import wraps from functools import wraps
import logging import logging
import os import os
from typing import Any, Callable, Optional, TypeVar from typing import Any, Callable, Optional, TypeVar, Tuple
import ujson import ujson
from django.conf import settings from django.conf import settings
from django.db import transaction from django.db import transaction
from django.utils.translation import ugettext as _ from django.utils.translation import ugettext as _
from django.core.signing import Signer
import stripe import stripe
from zerver.lib.exceptions import JsonableError from zerver.lib.exceptions import JsonableError
from zerver.lib.logging_util import log_to_file from zerver.lib.logging_util import log_to_file
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import generate_random_token
from zerver.models import Realm, UserProfile, RealmAuditLog from zerver.models import Realm, UserProfile, RealmAuditLog
from zilencer.models import Customer, Plan from zilencer.models import Customer, Plan
from zproject.settings import get_secret from zproject.settings import get_secret
@ -52,6 +54,15 @@ CallableT = TypeVar('CallableT', bound=Callable[..., Any])
def get_seat_count(realm: Realm) -> int: def get_seat_count(realm: Realm) -> int:
return UserProfile.objects.filter(realm=realm, is_active=True, is_bot=False).count() return UserProfile.objects.filter(realm=realm, is_active=True, is_bot=False).count()
def sign_string(string: str) -> Tuple[str, str]:
salt = generate_random_token(64)
signer = Signer(salt=salt)
return signer.sign(string), salt
def unsign_string(signed_string: str, salt: str) -> str:
signer = Signer(salt=salt)
return signer.unsign(signed_string)
class StripeError(JsonableError): class StripeError(JsonableError):
pass pass

View File

@ -3,6 +3,8 @@ import os
from typing import Any from typing import Any
import ujson import ujson
from django.core import signing
import stripe import stripe
from stripe.api_resources.list_object import ListObject from stripe.api_resources.list_object import ListObject
@ -13,7 +15,7 @@ from zerver.lib.timestamp import timestamp_to_datetime
from zerver.models import Realm, UserProfile, get_realm, RealmAuditLog from zerver.models import Realm, UserProfile, get_realm, RealmAuditLog
from zilencer.lib.stripe import StripeError, catch_stripe_errors, \ from zilencer.lib.stripe import StripeError, catch_stripe_errors, \
do_create_customer_with_payment_source, do_subscribe_customer_to_plan, \ do_create_customer_with_payment_source, do_subscribe_customer_to_plan, \
get_seat_count, extract_current_subscription get_seat_count, extract_current_subscription, sign_string, unsign_string
from zilencer.models import Customer, Plan from zilencer.models import Customer, Plan
fixture_data_file = open(os.path.join(os.path.dirname(__file__), 'stripe_fixtures.json'), 'r') fixture_data_file = open(os.path.join(os.path.dirname(__file__), 'stripe_fixtures.json'), 'r')
@ -246,6 +248,14 @@ class StripeTest(ZulipTestCase):
subscription = extract_current_subscription(customer_with_subscription) subscription = extract_current_subscription(customer_with_subscription)
self.assertEqual(subscription["id"][:4], "sub_") self.assertEqual(subscription["id"][:4], "sub_")
def test_sign_string(self) -> None:
string = "abc"
signed_string, salt = sign_string(string)
self.assertEqual(string, unsign_string(signed_string, salt))
with self.assertRaises(signing.BadSignature):
unsign_string(signed_string, "randomsalt")
class BillingUpdateTest(ZulipTestCase): class BillingUpdateTest(ZulipTestCase):
def setUp(self) -> None: def setUp(self) -> None:
self.user = self.example_user("hamlet") self.user = self.example_user("hamlet")