diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 266e335b43..5f2fdb2070 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -6,7 +6,7 @@ import sys from datetime import datetime, timedelta, timezone from decimal import Decimal from functools import wraps -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, cast +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, TypeVar, cast from unittest.mock import Mock, patch import responses @@ -380,8 +380,8 @@ class StripeTest(StripeTestCase): iago = self.example_user('iago') with self.settings(BILLING_ENABLED=False): self.login_user(iago) - response = self.client_get("/upgrade/") - self.assert_in_success_response(["Page not found (404)"], response) + response = self.client_get("/upgrade/", follow=True) + self.assertEqual(response.status_code, 404) @mock_stripe(tested_timestamp_fields=["created"]) def test_upgrade_by_card(self, *mocks: Mock) -> None: @@ -855,6 +855,11 @@ class StripeTest(StripeTestCase): @mock_stripe() def test_billing_page_permissions(self, *mocks: Mock) -> None: + # Guest users can't access /upgrade page + self.login_user(self.example_user('polonius')) + response = self.client_get("/upgrade/", follow=True) + self.assertEqual(response.status_code, 404) + # Check that non-admins can access /upgrade via /billing, when there is no Customer object self.login_user(self.example_user('hamlet')) response = self.client_get("/billing/") @@ -1847,34 +1852,35 @@ class RequiresBillingAccessTest(ZulipTestCase): mocked2.assert_called() def test_who_cant_access_json_endpoints(self) -> None: - def verify_user_cant_access_endpoint(user: UserProfile, url: str, request_data: Dict[str, Any]={}) -> None: - self.login_user(user) - response = self.client_post(url, request_data) - self.assert_json_error_contains(response, "Must be a billing administrator or an organization owner") + def verify_user_cant_access_endpoint(username: str, endpoint: str, request_data: Dict[str, str], error_message: str) -> None: + self.login_user(self.example_user(username)) + response = self.client_post(endpoint, request_data) + self.assert_json_error_contains(response, error_message) - params: List[Tuple[str, Dict[str, Any]]] = [ - ("/json/billing/sources/change", {'stripe_token': ujson.dumps('token')}), - ("/json/billing/plan/change", {'status': ujson.dumps(1)}), - ] + verify_user_cant_access_endpoint("polonius", "/json/billing/upgrade", + {'billing_modality': ujson.dumps("charge_automatically"), 'schedule': ujson.dumps("annual"), + 'signed_seat_count': ujson.dumps('signed count'), 'salt': ujson.dumps('salt')}, + "Must be an organization member") - for (url, data) in params: - # Guests can't access - verify_user_cant_access_endpoint(self.example_user("polonius"), url, data) - # Members (not billing admin) can't access - verify_user_cant_access_endpoint(self.example_user("cordelia"), url, data) - # Realm admins (not billing admin) can't access - verify_user_cant_access_endpoint(self.example_user("iago"), url, data) + verify_user_cant_access_endpoint("polonius", "/json/billing/sponsorship", + {'organization-type': ujson.dumps("event"), 'description': ujson.dumps("event description"), + 'website': ujson.dumps("example.com")}, + "Must be an organization member") + + for username in ["cordelia", "iago"]: + self.login_user(self.example_user(username)) + verify_user_cant_access_endpoint(username, "/json/billing/sources/change", {'stripe_token': ujson.dumps('token')}, + "Must be a billing administrator or an organization owner") + + verify_user_cant_access_endpoint(username, "/json/billing/plan/change", {'status': ujson.dumps(1)}, + "Must be a billing administrator or an organization owner") # Make sure that we are testing all the JSON endpoints # Quite a hack, but probably fine for now string_with_all_endpoints = str(get_resolver('corporate.urls').reverse_dict) json_endpoints = {word.strip("\"'()[],$") for word in string_with_all_endpoints.split() if 'json/' in word} - # No need to test upgrade and sponsorship endpoints as they only require user to be logged in. - json_endpoints.remove("json/billing/upgrade") - json_endpoints.remove("json/billing/sponsorship") - - self.assertEqual(len(json_endpoints), len(params)) + self.assertEqual(len(json_endpoints), 4) class BillingHelpersTest(ZulipTestCase): def test_next_month(self) -> None: diff --git a/corporate/views.py b/corporate/views.py index 42c7e6b53e..5972344b69 100644 --- a/corporate/views.py +++ b/corporate/views.py @@ -37,7 +37,11 @@ from corporate.models import ( get_current_plan_by_realm, get_customer_by_realm, ) -from zerver.decorator import require_billing_access, zulip_login_required +from zerver.decorator import ( + require_billing_access, + require_organization_member, + zulip_login_required, +) from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_error, json_success from zerver.lib.send_email import FromAddress, send_email @@ -100,6 +104,7 @@ def payment_method_string(stripe_customer: stripe.Customer) -> str: email=settings.ZULIP_ADMINISTRATOR, ) # nocoverage +@require_organization_member @has_request_variables def upgrade(request: HttpRequest, user: UserProfile, billing_modality: str=REQ(validator=check_string), @@ -144,11 +149,11 @@ def upgrade(request: HttpRequest, user: UserProfile, @zulip_login_required def initial_upgrade(request: HttpRequest) -> HttpResponse: - if not settings.BILLING_ENABLED: - return render(request, "404.html") - user = request.user + if not settings.BILLING_ENABLED or user.is_guest: + return render(request, "404.html", status=404) + customer = get_customer_by_realm(user.realm) if customer is not None and (get_current_plan_by_customer(customer) is not None or customer.sponsorship_pending): billing_page_url = reverse('corporate.views.billing_home') @@ -184,6 +189,7 @@ def initial_upgrade(request: HttpRequest) -> HttpResponse: response = render(request, 'corporate/upgrade.html', context=context) return response +@require_organization_member @has_request_variables def sponsorship(request: HttpRequest, user: UserProfile, organization_type: str=REQ("organization-type", validator=check_string), diff --git a/zerver/decorator.py b/zerver/decorator.py index 4ce888e143..a57a491575 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -37,6 +37,7 @@ from zerver.lib.exceptions import ( InvalidJSONError, JsonableError, OrganizationAdministratorRequired, + OrganizationMemberRequired, OrganizationOwnerRequired, UnexpectedWebhookEventType, ) @@ -130,6 +131,14 @@ def require_realm_admin(func: ViewFuncT) -> ViewFuncT: return func(request, user_profile, *args, **kwargs) return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927 +def require_organization_member(func: ViewFuncT) -> ViewFuncT: + @wraps(func) + def wrapper(request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object) -> HttpResponse: + if user_profile.role > UserProfile.ROLE_MEMBER: + raise OrganizationMemberRequired() + return func(request, user_profile, *args, **kwargs) + return cast(ViewFuncT, wrapper) # https://github.com/python/mypy/issues/1927 + def require_billing_access(func: ViewFuncT) -> ViewFuncT: @wraps(func) def wrapper(request: HttpRequest, user_profile: UserProfile, *args: object, **kwargs: object) -> HttpResponse: diff --git a/zerver/lib/exceptions.py b/zerver/lib/exceptions.py index e488bb4a87..2a6df96b08 100644 --- a/zerver/lib/exceptions.py +++ b/zerver/lib/exceptions.py @@ -188,6 +188,18 @@ class InvalidJSONError(JsonableError): def msg_format() -> str: return _("Malformed JSON") +class OrganizationMemberRequired(JsonableError): + code: ErrorCode = ErrorCode.UNAUTHORIZED_PRINCIPAL + + MEMBER_REQUIRED_MESSAGE = _("Must be an organization member") + + def __init__(self) -> None: + super().__init__(self.MEMBER_REQUIRED_MESSAGE) + + @staticmethod + def msg_format() -> str: + return OrganizationMemberRequired.MEMBER_REQUIRED_MESSAGE + class OrganizationAdministratorRequired(JsonableError): code: ErrorCode = ErrorCode.UNAUTHORIZED_PRINCIPAL diff --git a/zerver/tests/test_docs.py b/zerver/tests/test_docs.py index 33edd6b298..f4897362ef 100644 --- a/zerver/tests/test_docs.py +++ b/zerver/tests/test_docs.py @@ -356,22 +356,29 @@ class SmtpConfigErrorTest(ZulipTestCase): class PlansPageTest(ZulipTestCase): def test_plans_auth(self) -> None: - # Test root domain - result = self.client_get("/plans/", subdomain="") + root_domain = "" + result = self.client_get("/plans/", subdomain=root_domain) self.assert_in_success_response(["Sign up now"], result) - # Test non-existent domain - result = self.client_get("/plans/", subdomain="moo") + + non_existent_domain = "moo" + result = self.client_get("/plans/", subdomain=non_existent_domain) self.assertEqual(result.status_code, 404) self.assert_in_response("does not exist", result) - # Test valid domain, no login + realm = get_realm("zulip") realm.plan_type = Realm.STANDARD_FREE realm.save(update_fields=["plan_type"]) result = self.client_get("/plans/", subdomain="zulip") self.assertEqual(result.status_code, 302) self.assertEqual(result["Location"], "/accounts/login/?next=plans") - # Test valid domain, with login - self.login('hamlet') + + guest_user = 'polonius' + self.login(guest_user) + result = self.client_get("/plans/", subdomain="zulip", follow=True) + self.assertEqual(result.status_code, 404) + + organization_member = 'hamlet' + self.login(organization_member) result = self.client_get("/plans/", subdomain="zulip") self.assert_in_success_response(["Current plan"], result) # Test root domain, with login on different domain diff --git a/zerver/tests/test_home.py b/zerver/tests/test_home.py index 5b75ba0c9b..2e23062430 100644 --- a/zerver/tests/test_home.py +++ b/zerver/tests/test_home.py @@ -742,12 +742,17 @@ class HomeTest(ZulipTestCase): def test_show_plans(self) -> None: realm = get_realm("zulip") - self.login('hamlet') - # Show plans link to all users if plan_type is LIMITED + # Don't show plans to guest users + self.login('polonius') realm.plan_type = Realm.LIMITED realm.save(update_fields=["plan_type"]) result_html = self._get_home_page().content.decode('utf-8') + self.assertNotIn('Plans', result_html) + + # Show plans link to all other users if plan_type is LIMITED + self.login('hamlet') + result_html = self._get_home_page().content.decode('utf-8') self.assertIn('Plans', result_html) # Show plans link to no one, including admins, if SELF_HOSTED or STANDARD diff --git a/zerver/views/home.py b/zerver/views/home.py index 296b5f01cd..ce9fd6e0d7 100644 --- a/zerver/views/home.py +++ b/zerver/views/home.py @@ -299,7 +299,7 @@ def home_real(request: HttpRequest) -> HttpResponse: elif CustomerPlan.objects.filter(customer=customer).exists(): show_billing = True - if user_profile.realm.plan_type == Realm.LIMITED: + if not user_profile.is_guest and user_profile.realm.plan_type == Realm.LIMITED: show_plans = True request._log_data['extra'] = "[{}]".format(register_ret["queue_id"]) diff --git a/zerver/views/portico.py b/zerver/views/portico.py index bb3d6cb803..4b17bff9ee 100644 --- a/zerver/views/portico.py +++ b/zerver/views/portico.py @@ -36,7 +36,8 @@ def plans_view(request: HttpRequest) -> HttpResponse: return HttpResponseRedirect('https://zulip.com/plans') if not request.user.is_authenticated: return redirect_to_login(next="plans") - + if request.user.is_guest: + return TemplateResponse(request, "404.html", status=404) if settings.CORPORATE_ENABLED: from corporate.models import get_customer_by_realm customer = get_customer_by_realm(realm)