billing: Don't allow guest users to upgrade.

This commit is contained in:
Vishnu KS 2020-07-16 01:48:32 +05:30 committed by Tim Abbott
parent cb01a7f599
commit 67bacd6e31
8 changed files with 84 additions and 38 deletions

View File

@ -6,7 +6,7 @@ import sys
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from functools import wraps
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, cast
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, cast
from unittest.mock import Mock, patch
import responses
@ -380,8 +380,8 @@ class StripeTest(StripeTestCase):
iago = self.example_user('iago')
with self.settings(BILLING_ENABLED=False):
self.login_user(iago)
response = self.client_get("/upgrade/")
self.assert_in_success_response(["Page not found (404)"], response)
response = self.client_get("/upgrade/", follow=True)
self.assertEqual(response.status_code, 404)
@mock_stripe(tested_timestamp_fields=["created"])
def test_upgrade_by_card(self, *mocks: Mock) -> None:
@ -855,6 +855,11 @@ class StripeTest(StripeTestCase):
@mock_stripe()
def test_billing_page_permissions(self, *mocks: Mock) -> None:
# Guest users can't access /upgrade page
self.login_user(self.example_user('polonius'))
response = self.client_get("/upgrade/", follow=True)
self.assertEqual(response.status_code, 404)
# Check that non-admins can access /upgrade via /billing, when there is no Customer object
self.login_user(self.example_user('hamlet'))
response = self.client_get("/billing/")
@ -1847,34 +1852,35 @@ class RequiresBillingAccessTest(ZulipTestCase):
mocked2.assert_called()
def test_who_cant_access_json_endpoints(self) -> None:
def verify_user_cant_access_endpoint(user: UserProfile, url: str, request_data: Dict[str, Any]={}) -> None:
self.login_user(user)
response = self.client_post(url, request_data)
self.assert_json_error_contains(response, "Must be a billing administrator or an organization owner")
def verify_user_cant_access_endpoint(username: str, endpoint: str, request_data: Dict[str, str], error_message: str) -> None:
self.login_user(self.example_user(username))
response = self.client_post(endpoint, request_data)
self.assert_json_error_contains(response, error_message)
params: List[Tuple[str, Dict[str, Any]]] = [
("/json/billing/sources/change", {'stripe_token': ujson.dumps('token')}),
("/json/billing/plan/change", {'status': ujson.dumps(1)}),
]
verify_user_cant_access_endpoint("polonius", "/json/billing/upgrade",
{'billing_modality': ujson.dumps("charge_automatically"), 'schedule': ujson.dumps("annual"),
'signed_seat_count': ujson.dumps('signed count'), 'salt': ujson.dumps('salt')},
"Must be an organization member")
for (url, data) in params:
# Guests can't access
verify_user_cant_access_endpoint(self.example_user("polonius"), url, data)
# Members (not billing admin) can't access
verify_user_cant_access_endpoint(self.example_user("cordelia"), url, data)
# Realm admins (not billing admin) can't access
verify_user_cant_access_endpoint(self.example_user("iago"), url, data)
verify_user_cant_access_endpoint("polonius", "/json/billing/sponsorship",
{'organization-type': ujson.dumps("event"), 'description': ujson.dumps("event description"),
'website': ujson.dumps("example.com")},
"Must be an organization member")
for username in ["cordelia", "iago"]:
self.login_user(self.example_user(username))
verify_user_cant_access_endpoint(username, "/json/billing/sources/change", {'stripe_token': ujson.dumps('token')},
"Must be a billing administrator or an organization owner")
verify_user_cant_access_endpoint(username, "/json/billing/plan/change", {'status': ujson.dumps(1)},
"Must be a billing administrator or an organization owner")
# Make sure that we are testing all the JSON endpoints
# Quite a hack, but probably fine for now
string_with_all_endpoints = str(get_resolver('corporate.urls').reverse_dict)
json_endpoints = {word.strip("\"'()[],$") for word in string_with_all_endpoints.split()
if 'json/' in word}
# No need to test upgrade and sponsorship endpoints as they only require user to be logged in.
json_endpoints.remove("json/billing/upgrade")
json_endpoints.remove("json/billing/sponsorship")
self.assertEqual(len(json_endpoints), len(params))
self.assertEqual(len(json_endpoints), 4)
class BillingHelpersTest(ZulipTestCase):
def test_next_month(self) -> None:

View File

@ -37,7 +37,11 @@ from corporate.models import (
get_current_plan_by_realm,
get_customer_by_realm,
)
from zerver.decorator import require_billing_access, zulip_login_required
from zerver.decorator import (
require_billing_access,
require_organization_member,
zulip_login_required,
)
from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_error, json_success
from zerver.lib.send_email import FromAddress, send_email
@ -100,6 +104,7 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str:
email=settings.ZULIP_ADMINISTRATOR,
) # nocoverage
@require_organization_member
@has_request_variables
def upgrade(request: HttpRequest, user: UserProfile,
billing_modality: str=REQ(validator=check_string),
@ -144,11 +149,11 @@ def upgrade(request: HttpRequest, user: UserProfile,
@zulip_login_required
def initial_upgrade(request: HttpRequest) -> HttpResponse:
if not settings.BILLING_ENABLED:
return render(request, "404.html")
user = request.user
if not settings.BILLING_ENABLED or user.is_guest:
return render(request, "404.html", status=404)
customer = get_customer_by_realm(user.realm)
if customer is not None and (get_current_plan_by_customer(customer) is not None or customer.sponsorship_pending):
billing_page_url = reverse('corporate.views.billing_home')
@ -184,6 +189,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
response = render(request, 'corporate/upgrade.html', context=context)
return response
@require_organization_member
@has_request_variables
def sponsorship(request: HttpRequest, user: UserProfile,
organization_type: str=REQ("organization-type", validator=check_string),

View File

@ -37,6 +37,7 @@ from zerver.lib.exceptions import (
InvalidJSONError,
JsonableError,
OrganizationAdministratorRequired,
OrganizationMemberRequired,
OrganizationOwnerRequired,
UnexpectedWebhookEventType,
)
@ -130,6 +131,14 @@ def require_realm_admin(func: ViewFuncT) -> ViewFuncT:
return func(request, user_profile, *args, **kwargs)
return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927
def require_organization_member(func: ViewFuncT) -> ViewFuncT:
@wraps(func)
def wrapper(request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object) -> HttpResponse:
if user_profile.role > UserProfile.ROLE_MEMBER:
raise OrganizationMemberRequired()
return func(request, user_profile, *args, **kwargs)
return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927
def require_billing_access(func: ViewFuncT) -> ViewFuncT:
@wraps(func)
def wrapper(request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object) -> HttpResponse:

View File

@ -188,6 +188,18 @@ class InvalidJSONError(JsonableError):
def msg_format() -> str:
return _("Malformed JSON")
class OrganizationMemberRequired(JsonableError):
code: ErrorCode = ErrorCode.UNAUTHORIZED_PRINCIPAL
MEMBER_REQUIRED_MESSAGE = _("Must be an organization member")
def __init__(self) -> None:
super().__init__(self.MEMBER_REQUIRED_MESSAGE)
@staticmethod
def msg_format() -> str:
return OrganizationMemberRequired.MEMBER_REQUIRED_MESSAGE
class OrganizationAdministratorRequired(JsonableError):
code: ErrorCode = ErrorCode.UNAUTHORIZED_PRINCIPAL

View File

@ -356,22 +356,29 @@ class SmtpConfigErrorTest(ZulipTestCase):
class PlansPageTest(ZulipTestCase):
def test_plans_auth(self) -> None:
# Test root domain
result = self.client_get("/plans/", subdomain="")
root_domain = ""
result = self.client_get("/plans/", subdomain=root_domain)
self.assert_in_success_response(["Sign up now"], result)
# Test non-existent domain
result = self.client_get("/plans/", subdomain="moo")
non_existent_domain = "moo"
result = self.client_get("/plans/", subdomain=non_existent_domain)
self.assertEqual(result.status_code, 404)
self.assert_in_response("does not exist", result)
# Test valid domain, no login
realm = get_realm("zulip")
realm.plan_type = Realm.STANDARD_FREE
realm.save(update_fields=["plan_type"])
result = self.client_get("/plans/", subdomain="zulip")
self.assertEqual(result.status_code, 302)
self.assertEqual(result["Location"], "/accounts/login/?next=plans")
# Test valid domain, with login
self.login('hamlet')
guest_user = 'polonius'
self.login(guest_user)
result = self.client_get("/plans/", subdomain="zulip", follow=True)
self.assertEqual(result.status_code, 404)
organization_member = 'hamlet'
self.login(organization_member)
result = self.client_get("/plans/", subdomain="zulip")
self.assert_in_success_response(["Current plan"], result)
# Test root domain, with login on different domain

View File

@ -742,12 +742,17 @@ class HomeTest(ZulipTestCase):
def test_show_plans(self) -> None:
realm = get_realm("zulip")
self.login('hamlet')
# Show plans link to all users if plan_type is LIMITED
# Don't show plans to guest users
self.login('polonius')
realm.plan_type = Realm.LIMITED
realm.save(update_fields=["plan_type"])
result_html = self._get_home_page().content.decode('utf-8')
self.assertNotIn('Plans', result_html)
# Show plans link to all other users if plan_type is LIMITED
self.login('hamlet')
result_html = self._get_home_page().content.decode('utf-8')
self.assertIn('Plans', result_html)
# Show plans link to no one, including admins, if SELF_HOSTED or STANDARD

View File

@ -299,7 +299,7 @@ def home_real(request: HttpRequest) -> HttpResponse:
elif CustomerPlan.objects.filter(customer=customer).exists():
show_billing = True
if user_profile.realm.plan_type == Realm.LIMITED:
if not user_profile.is_guest and user_profile.realm.plan_type == Realm.LIMITED:
show_plans = True
request._log_data['extra'] = "[{}]".format(register_ret["queue_id"])

View File

@ -36,7 +36,8 @@ def plans_view(request: HttpRequest) -> HttpResponse:
return HttpResponseRedirect('https://zulip.com/plans')
if not request.user.is_authenticated:
return redirect_to_login(next="plans")
if request.user.is_guest:
return TemplateResponse(request, "404.html", status=404)
if settings.CORPORATE_ENABLED:
from corporate.models import get_customer_by_realm
customer = get_customer_by_realm(realm)