diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 004373599e..783de3fdc4 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -11,6 +11,7 @@ import json from django.core import signing from django.core.management import call_command +from django.core.urlresolvers import get_resolver from django.http import HttpResponse from django.utils.timezone import utc as timezone_utc @@ -27,6 +28,7 @@ from corporate.lib.stripe import catch_stripe_errors, \ get_next_billing_log_entry, run_billing_processor_one_step, \ BillingError, StripeCardError, StripeConnectionError, stripe_get_customer from corporate.models import Customer, Plan, Coupon, BillingProcessor +import corporate.urls CallableT = TypeVar('CallableT', bound=Callable[..., Any]) @@ -561,26 +563,6 @@ class StripeTest(ZulipTestCase): self.assertEqual(ujson.loads(response.content)['error_description'], 'downgrade without subscription') mock_save_customer.assert_not_called() - def test_downgrade_permissions(self) -> None: - self.login(self.example_email('hamlet')) - response = self.client_post("/json/billing/downgrade", {}) - self.assert_json_error_contains(response, "Access denied") - # billing admin but not realm admin - user = self.example_user('hamlet') - user.is_billing_admin = True - user.save(update_fields=['is_billing_admin']) - with patch('corporate.views.process_downgrade') as mocked1: - self.client_post("/json/billing/downgrade", {}) - mocked1.assert_called() - # realm admin but not billing admin - user = self.example_user('hamlet') - user.is_billing_admin = False - user.is_realm_admin = True - user.save(update_fields=['is_billing_admin', 'is_realm_admin']) - with patch('corporate.views.process_downgrade') as mocked2: - self.client_post("/json/billing/downgrade", {}) - mocked2.assert_called() - @patch("stripe.Subscription.delete") @patch("stripe.Customer.retrieve", side_effect=mock_customer_with_account_balance(1234)) def test_downgrade_credits(self, mock_retrieve_customer: Mock, @@ -632,31 +614,6 @@ class StripeTest(ZulipTestCase): self.assertFalse(RealmAuditLog.objects.filter( event_type=RealmAuditLog.STRIPE_CARD_CHANGED).exists()) - def test_update_payment_source_permissions(self) -> None: - # This can be removed / merged with e.g. test_downgrade_permissions - # once we have a decorator that handles billing page permissions - self.login(self.example_email('hamlet')) - response = self.client_post("/json/billing/sources/change", - {'stripe_token': ujson.dumps('token')}) - self.assert_json_error_contains(response, "Access denied") - # billing admin but not realm admin - user = self.example_user('hamlet') - user.is_billing_admin = True - user.save(update_fields=['is_billing_admin']) - with patch('corporate.views.do_replace_payment_source') as mocked1: - self.client_post("/json/billing/sources/change", - {'stripe_token': ujson.dumps('token')}) - mocked1.assert_called() - # realm admin but not billing admin - user = self.example_user('hamlet') - user.is_billing_admin = False - user.is_realm_admin = True - user.save(update_fields=['is_billing_admin', 'is_realm_admin']) - with patch('corporate.views.do_replace_payment_source') as mocked2: - self.client_post("/json/billing/sources/change", - {'stripe_token': ujson.dumps('token')}) - mocked2.assert_called() - @patch("stripe.Customer.create", side_effect=mock_create_customer) @patch("stripe.Subscription.create", side_effect=mock_create_subscription) @patch("stripe.Customer.retrieve", side_effect=mock_customer_with_subscription) @@ -739,6 +696,53 @@ class RequiresBillingUpdateTest(ZulipTestCase): do_activate_user(user2) self.assertEqual(4, RealmAuditLog.objects.filter(requires_billing_update=True).count()) +class RequiresBillingAccessTest(ZulipTestCase): + def setUp(self) -> None: + hamlet = self.example_user("hamlet") + hamlet.is_billing_admin = True + hamlet.save(update_fields=["is_billing_admin"]) + + # mocked_function_name will typically be something imported from + # stripe.py. In theory we could have endpoints that need to mock + # multiple functions, but we'll cross that bridge when we get there. + def _test_endpoint(self, url: str, mocked_function_name: str, + request_data: Optional[Dict[str, Any]]={}) -> None: + # Normal users do not have access + self.login(self.example_email('cordelia')) + response = self.client_post(url, request_data) + self.assert_json_error_contains(response, "Access denied") + + # Billing admins have access + self.login(self.example_email('hamlet')) + with patch("corporate.views.{}".format(mocked_function_name)) as mocked1: + response = self.client_post(url, request_data) + self.assert_json_success(response) + mocked1.assert_called() + + # Realm admins have access, even if they are not billing admins + self.login(self.example_email('iago')) + with patch("corporate.views.{}".format(mocked_function_name)) as mocked2: + response = self.client_post(url, request_data) + self.assert_json_success(response) + mocked2.assert_called() + + def test_json_endpoints(self) -> None: + params = [ + ("/json/billing/sources/change", "do_replace_payment_source", + {'stripe_token': ujson.dumps('token')}), + ("/json/billing/downgrade", "process_downgrade", {}) + ] # type: List[Tuple[str, str, Dict[str, Any]]] + + for (url, mocked_function_name, data) in params: + self._test_endpoint(url, mocked_function_name, data) + + # Make sure that we are testing all the JSON endpoints + # Quite a hack, but probably fine for now + string_with_all_endpoints = str(get_resolver('corporate.urls').reverse_dict) + json_endpoints = set([word.strip("\"'()[],$") for word in string_with_all_endpoints.split() + if 'json' in word]) + self.assertEqual(len(json_endpoints), len(params)) + class BillingProcessorTest(ZulipTestCase): def add_log_entry(self, realm: Realm=get_realm('zulip'), event_type: str=RealmAuditLog.USER_CREATED, diff --git a/corporate/urls.py b/corporate/urls.py index 3f99603b6a..1f7d6f3ab4 100644 --- a/corporate/urls.py +++ b/corporate/urls.py @@ -19,7 +19,7 @@ i18n_urlpatterns = [ v1_api_and_json_patterns = [ url(r'^billing/downgrade$', rest_dispatch, {'POST': 'corporate.views.downgrade'}), - url(r'billing/sources/change', rest_dispatch, + url(r'^billing/sources/change', rest_dispatch, {'POST': 'corporate.views.replace_payment_source'}), ] diff --git a/corporate/views.py b/corporate/views.py index b2ca1692fb..0167bbdec0 100644 --- a/corporate/views.py +++ b/corporate/views.py @@ -9,7 +9,7 @@ from django.shortcuts import redirect, render from django.urls import reverse from django.conf import settings -from zerver.decorator import zulip_login_required +from zerver.decorator import zulip_login_required, require_billing_access from zerver.lib.request import REQ, has_request_variables from zerver.lib.response import json_error, json_success from zerver.lib.validator import check_string @@ -144,20 +144,18 @@ def billing_home(request: HttpRequest) -> HttpResponse: return render(request, 'corporate/billing.html', context=context) +@require_billing_access def downgrade(request: HttpRequest, user: UserProfile) -> HttpResponse: - if not user.is_realm_admin and not user.is_billing_admin: - return json_error(_('Access denied')) try: process_downgrade(user) except BillingError as e: return json_error(e.message, data={'error_description': e.description}) return json_success() +@require_billing_access @has_request_variables def replace_payment_source(request: HttpRequest, user: UserProfile, stripe_token: str=REQ("stripe_token", validator=check_string)) -> HttpResponse: - if not user.is_realm_admin and not user.is_billing_admin: - return json_error(_("Access denied")) try: do_replace_payment_source(user, stripe_token) except BillingError as e: diff --git a/zerver/decorator.py b/zerver/decorator.py index ded7cb3a42..1bf737223c 100644 --- a/zerver/decorator.py +++ b/zerver/decorator.py @@ -135,6 +135,14 @@ def require_realm_admin(func: ViewFuncT) -> ViewFuncT: return func(request, user_profile, *args, **kwargs) return wrapper # type: ignore # https://github.com/python/mypy/issues/1927 +def require_billing_access(func: ViewFuncT) -> ViewFuncT: + @wraps(func) + def wrapper(request: HttpRequest, user_profile: UserProfile, *args: Any, **kwargs: Any) -> HttpResponse: + if not user_profile.is_realm_admin and not user_profile.is_billing_admin: + raise JsonableError(_("Access denied")) + return func(request, user_profile, *args, **kwargs) + return wrapper # type: ignore # https://github.com/python/mypy/issues/1927 + from zerver.lib.user_agent import parse_user_agent def get_client_name(request: HttpRequest, is_browser_view: bool) -> str: