billing: Don't allow guest users to upgrade.

This commit is contained in:
Vishnu KS 2020-07-16 01:48:32 +05:30 committed by Tim Abbott
parent cb01a7f599
commit 67bacd6e31
8 changed files with 84 additions and 38 deletions

View File

@ -6,7 +6,7 @@ import sys
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from decimal import Decimal from decimal import Decimal
from functools import wraps from functools import wraps
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, cast from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, cast
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
import responses import responses
@ -380,8 +380,8 @@ class StripeTest(StripeTestCase):
iago = self.example_user('iago') iago = self.example_user('iago')
with self.settings(BILLING_ENABLED=False): with self.settings(BILLING_ENABLED=False):
self.login_user(iago) self.login_user(iago)
response = self.client_get("/upgrade/") response = self.client_get("/upgrade/", follow=True)
self.assert_in_success_response(["Page not found (404)"], response) self.assertEqual(response.status_code, 404)
@mock_stripe(tested_timestamp_fields=["created"]) @mock_stripe(tested_timestamp_fields=["created"])
def test_upgrade_by_card(self, *mocks: Mock) -> None: def test_upgrade_by_card(self, *mocks: Mock) -> None:
@ -855,6 +855,11 @@ class StripeTest(StripeTestCase):
@mock_stripe() @mock_stripe()
def test_billing_page_permissions(self, *mocks: Mock) -> None: def test_billing_page_permissions(self, *mocks: Mock) -> None:
# Guest users can't access /upgrade page
self.login_user(self.example_user('polonius'))
response = self.client_get("/upgrade/", follow=True)
self.assertEqual(response.status_code, 404)
# Check that non-admins can access /upgrade via /billing, when there is no Customer object # Check that non-admins can access /upgrade via /billing, when there is no Customer object
self.login_user(self.example_user('hamlet')) self.login_user(self.example_user('hamlet'))
response = self.client_get("/billing/") response = self.client_get("/billing/")
@ -1847,34 +1852,35 @@ class RequiresBillingAccessTest(ZulipTestCase):
mocked2.assert_called() mocked2.assert_called()
def test_who_cant_access_json_endpoints(self) -> None: def test_who_cant_access_json_endpoints(self) -> None:
def verify_user_cant_access_endpoint(user: UserProfile, url: str, request_data: Dict[str, Any]={}) -> None: def verify_user_cant_access_endpoint(username: str, endpoint: str, request_data: Dict[str, str], error_message: str) -> None:
self.login_user(user) self.login_user(self.example_user(username))
response = self.client_post(url, request_data) response = self.client_post(endpoint, request_data)
self.assert_json_error_contains(response, "Must be a billing administrator or an organization owner") self.assert_json_error_contains(response, error_message)
params: List[Tuple[str, Dict[str, Any]]] = [ verify_user_cant_access_endpoint("polonius", "/json/billing/upgrade",
("/json/billing/sources/change", {'stripe_token': ujson.dumps('token')}), {'billing_modality': ujson.dumps("charge_automatically"), 'schedule': ujson.dumps("annual"),
("/json/billing/plan/change", {'status': ujson.dumps(1)}), 'signed_seat_count': ujson.dumps('signed count'), 'salt': ujson.dumps('salt')},
] "Must be an organization member")
for (url, data) in params: verify_user_cant_access_endpoint("polonius", "/json/billing/sponsorship",
# Guests can't access {'organization-type': ujson.dumps("event"), 'description': ujson.dumps("event description"),
verify_user_cant_access_endpoint(self.example_user("polonius"), url, data) 'website': ujson.dumps("example.com")},
# Members (not billing admin) can't access "Must be an organization member")
verify_user_cant_access_endpoint(self.example_user("cordelia"), url, data)
# Realm admins (not billing admin) can't access for username in ["cordelia", "iago"]:
verify_user_cant_access_endpoint(self.example_user("iago"), url, data) self.login_user(self.example_user(username))
verify_user_cant_access_endpoint(username, "/json/billing/sources/change", {'stripe_token': ujson.dumps('token')},
"Must be a billing administrator or an organization owner")
verify_user_cant_access_endpoint(username, "/json/billing/plan/change", {'status': ujson.dumps(1)},
"Must be a billing administrator or an organization owner")
# Make sure that we are testing all the JSON endpoints # Make sure that we are testing all the JSON endpoints
# Quite a hack, but probably fine for now # Quite a hack, but probably fine for now
string_with_all_endpoints = str(get_resolver('corporate.urls').reverse_dict) string_with_all_endpoints = str(get_resolver('corporate.urls').reverse_dict)
json_endpoints = {word.strip("\"'()[],$") for word in string_with_all_endpoints.split() json_endpoints = {word.strip("\"'()[],$") for word in string_with_all_endpoints.split()
if 'json/' in word} if 'json/' in word}
# No need to test upgrade and sponsorship endpoints as they only require user to be logged in. self.assertEqual(len(json_endpoints), 4)
json_endpoints.remove("json/billing/upgrade")
json_endpoints.remove("json/billing/sponsorship")
self.assertEqual(len(json_endpoints), len(params))
class BillingHelpersTest(ZulipTestCase): class BillingHelpersTest(ZulipTestCase):
def test_next_month(self) -> None: def test_next_month(self) -> None:

View File

@ -37,7 +37,11 @@ from corporate.models import (
get_current_plan_by_realm, get_current_plan_by_realm,
get_customer_by_realm, get_customer_by_realm,
) )
from zerver.decorator import require_billing_access, zulip_login_required from zerver.decorator import (
require_billing_access,
require_organization_member,
zulip_login_required,
)
from zerver.lib.request import REQ, has_request_variables from zerver.lib.request import REQ, has_request_variables
from zerver.lib.response import json_error, json_success from zerver.lib.response import json_error, json_success
from zerver.lib.send_email import FromAddress, send_email from zerver.lib.send_email import FromAddress, send_email
@ -100,6 +104,7 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str:
email=settings.ZULIP_ADMINISTRATOR, email=settings.ZULIP_ADMINISTRATOR,
) # nocoverage ) # nocoverage
@require_organization_member
@has_request_variables @has_request_variables
def upgrade(request: HttpRequest, user: UserProfile, def upgrade(request: HttpRequest, user: UserProfile,
billing_modality: str=REQ(validator=check_string), billing_modality: str=REQ(validator=check_string),
@ -144,11 +149,11 @@ def upgrade(request: HttpRequest, user: UserProfile,
@zulip_login_required @zulip_login_required
def initial_upgrade(request: HttpRequest) -> HttpResponse: def initial_upgrade(request: HttpRequest) -> HttpResponse:
if not settings.BILLING_ENABLED:
return render(request, "404.html")
user = request.user user = request.user
if not settings.BILLING_ENABLED or user.is_guest:
return render(request, "404.html", status=404)
customer = get_customer_by_realm(user.realm) customer = get_customer_by_realm(user.realm)
if customer is not None and (get_current_plan_by_customer(customer) is not None or customer.sponsorship_pending): if customer is not None and (get_current_plan_by_customer(customer) is not None or customer.sponsorship_pending):
billing_page_url = reverse('corporate.views.billing_home') billing_page_url = reverse('corporate.views.billing_home')
@ -184,6 +189,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse:
response = render(request, 'corporate/upgrade.html', context=context) response = render(request, 'corporate/upgrade.html', context=context)
return response return response
@require_organization_member
@has_request_variables @has_request_variables
def sponsorship(request: HttpRequest, user: UserProfile, def sponsorship(request: HttpRequest, user: UserProfile,
organization_type: str=REQ("organization-type", validator=check_string), organization_type: str=REQ("organization-type", validator=check_string),

View File

@ -37,6 +37,7 @@ from zerver.lib.exceptions import (
InvalidJSONError, InvalidJSONError,
JsonableError, JsonableError,
OrganizationAdministratorRequired, OrganizationAdministratorRequired,
OrganizationMemberRequired,
OrganizationOwnerRequired, OrganizationOwnerRequired,
UnexpectedWebhookEventType, UnexpectedWebhookEventType,
) )
@ -130,6 +131,14 @@ def require_realm_admin(func: ViewFuncT) -> ViewFuncT:
return func(request, user_profile, *args, **kwargs) return func(request, user_profile, *args, **kwargs)
return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927 return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927
def require_organization_member(func: ViewFuncT) -> ViewFuncT:
@wraps(func)
def wrapper(request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object) -> HttpResponse:
if user_profile.role > UserProfile.ROLE_MEMBER:
raise OrganizationMemberRequired()
return func(request, user_profile, *args, **kwargs)
return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927
def require_billing_access(func: ViewFuncT) -> ViewFuncT: def require_billing_access(func: ViewFuncT) -> ViewFuncT:
@wraps(func) @wraps(func)
def wrapper(request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object) -> HttpResponse: def wrapper(request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object) -> HttpResponse:

View File

@ -188,6 +188,18 @@ class InvalidJSONError(JsonableError):
def msg_format() -> str: def msg_format() -> str:
return _("Malformed JSON") return _("Malformed JSON")
class OrganizationMemberRequired(JsonableError):
code: ErrorCode = ErrorCode.UNAUTHORIZED_PRINCIPAL
MEMBER_REQUIRED_MESSAGE = _("Must be an organization member")
def __init__(self) -> None:
super().__init__(self.MEMBER_REQUIRED_MESSAGE)
@staticmethod
def msg_format() -> str:
return OrganizationMemberRequired.MEMBER_REQUIRED_MESSAGE
class OrganizationAdministratorRequired(JsonableError): class OrganizationAdministratorRequired(JsonableError):
code: ErrorCode = ErrorCode.UNAUTHORIZED_PRINCIPAL code: ErrorCode = ErrorCode.UNAUTHORIZED_PRINCIPAL

View File

@ -356,22 +356,29 @@ class SmtpConfigErrorTest(ZulipTestCase):
class PlansPageTest(ZulipTestCase): class PlansPageTest(ZulipTestCase):
def test_plans_auth(self) -> None: def test_plans_auth(self) -> None:
# Test root domain root_domain = ""
result = self.client_get("/plans/", subdomain="") result = self.client_get("/plans/", subdomain=root_domain)
self.assert_in_success_response(["Sign up now"], result) self.assert_in_success_response(["Sign up now"], result)
# Test non-existent domain
result = self.client_get("/plans/", subdomain="moo") non_existent_domain = "moo"
result = self.client_get("/plans/", subdomain=non_existent_domain)
self.assertEqual(result.status_code, 404) self.assertEqual(result.status_code, 404)
self.assert_in_response("does not exist", result) self.assert_in_response("does not exist", result)
# Test valid domain, no login
realm = get_realm("zulip") realm = get_realm("zulip")
realm.plan_type = Realm.STANDARD_FREE realm.plan_type = Realm.STANDARD_FREE
realm.save(update_fields=["plan_type"]) realm.save(update_fields=["plan_type"])
result = self.client_get("/plans/", subdomain="zulip") result = self.client_get("/plans/", subdomain="zulip")
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result["Location"], "/accounts/login/?next=plans") self.assertEqual(result["Location"], "/accounts/login/?next=plans")
# Test valid domain, with login
self.login('hamlet') guest_user = 'polonius'
self.login(guest_user)
result = self.client_get("/plans/", subdomain="zulip", follow=True)
self.assertEqual(result.status_code, 404)
organization_member = 'hamlet'
self.login(organization_member)
result = self.client_get("/plans/", subdomain="zulip") result = self.client_get("/plans/", subdomain="zulip")
self.assert_in_success_response(["Current plan"], result) self.assert_in_success_response(["Current plan"], result)
# Test root domain, with login on different domain # Test root domain, with login on different domain

View File

@ -742,12 +742,17 @@ class HomeTest(ZulipTestCase):
def test_show_plans(self) -> None: def test_show_plans(self) -> None:
realm = get_realm("zulip") realm = get_realm("zulip")
self.login('hamlet')
# Show plans link to all users if plan_type is LIMITED # Don't show plans to guest users
self.login('polonius')
realm.plan_type = Realm.LIMITED realm.plan_type = Realm.LIMITED
realm.save(update_fields=["plan_type"]) realm.save(update_fields=["plan_type"])
result_html = self._get_home_page().content.decode('utf-8') result_html = self._get_home_page().content.decode('utf-8')
self.assertNotIn('Plans', result_html)
# Show plans link to all other users if plan_type is LIMITED
self.login('hamlet')
result_html = self._get_home_page().content.decode('utf-8')
self.assertIn('Plans', result_html) self.assertIn('Plans', result_html)
# Show plans link to no one, including admins, if SELF_HOSTED or STANDARD # Show plans link to no one, including admins, if SELF_HOSTED or STANDARD

View File

@ -299,7 +299,7 @@ def home_real(request: HttpRequest) -> HttpResponse:
elif CustomerPlan.objects.filter(customer=customer).exists(): elif CustomerPlan.objects.filter(customer=customer).exists():
show_billing = True show_billing = True
if user_profile.realm.plan_type == Realm.LIMITED: if not user_profile.is_guest and user_profile.realm.plan_type == Realm.LIMITED:
show_plans = True show_plans = True
request._log_data['extra'] = "[{}]".format(register_ret["queue_id"]) request._log_data['extra'] = "[{}]".format(register_ret["queue_id"])

View File

@ -36,7 +36,8 @@ def plans_view(request: HttpRequest) -> HttpResponse:
return HttpResponseRedirect('https://zulip.com/plans') return HttpResponseRedirect('https://zulip.com/plans')
if not request.user.is_authenticated: if not request.user.is_authenticated:
return redirect_to_login(next="plans") return redirect_to_login(next="plans")
if request.user.is_guest:
return TemplateResponse(request, "404.html", status=404)
if settings.CORPORATE_ENABLED: if settings.CORPORATE_ENABLED:
from corporate.models import get_customer_by_realm from corporate.models import get_customer_by_realm
customer = get_customer_by_realm(realm) customer = get_customer_by_realm(realm)