remote_realm: Plumb RemoteRealmBillingUser into BillingSession.

Adds the RemoteRealmBillingUser object to the BillingSession in the
views decorated by authenticated_remote_realm_management_endpoint.
This commit is contained in:
Mateusz Mandera 2023-12-10 20:05:43 +01:00 committed by Tim Abbott
parent a0ea14bdb1
commit 7d62471d0b
3 changed files with 32 additions and 8 deletions

View File

@ -10,7 +10,7 @@ from typing_extensions import Concatenate, ParamSpec
from corporate.lib.remote_billing_util import ( from corporate.lib.remote_billing_util import (
RemoteBillingIdentityExpiredError, RemoteBillingIdentityExpiredError,
get_remote_realm_from_session, get_remote_realm_and_user_from_session,
get_remote_server_and_user_from_session, get_remote_server_and_user_from_session,
) )
from corporate.lib.stripe import RemoteRealmBillingSession, RemoteServerBillingSession from corporate.lib.stripe import RemoteRealmBillingSession, RemoteServerBillingSession
@ -59,7 +59,9 @@ def authenticated_remote_realm_management_endpoint(
raise TypeError("realm_uuid must be a string or None") raise TypeError("realm_uuid must be a string or None")
try: try:
remote_realm = get_remote_realm_from_session(request, realm_uuid) remote_realm, remote_billing_user = get_remote_realm_and_user_from_session(
request, realm_uuid
)
except RemoteBillingIdentityExpiredError as e: except RemoteBillingIdentityExpiredError as e:
# The user had an authenticated session with an identity_dict, # The user had an authenticated session with an identity_dict,
# but it expired. # but it expired.
@ -78,7 +80,7 @@ def authenticated_remote_realm_management_endpoint(
server_uuid = e.server_uuid server_uuid = e.server_uuid
uri_scheme = e.uri_scheme uri_scheme = e.uri_scheme
if realm_uuid is None: if realm_uuid is None:
# This doesn't make sense - if get_remote_realm_from_session # This doesn't make sense - if get_remote_realm_and_user_from_session
# found an expired identity dict, it should have had a realm_uuid. # found an expired identity dict, it should have had a realm_uuid.
raise AssertionError raise AssertionError
@ -104,7 +106,9 @@ def authenticated_remote_realm_management_endpoint(
return HttpResponseRedirect(url) return HttpResponseRedirect(url)
billing_session = RemoteRealmBillingSession(remote_realm) billing_session = RemoteRealmBillingSession(
remote_realm, remote_billing_user=remote_billing_user
)
return view_func(request, billing_session) return view_func(request, billing_session)
return _wrapped_view_func return _wrapped_view_func

View File

@ -7,7 +7,12 @@ from django.utils.translation import gettext as _
from zerver.lib.exceptions import JsonableError, RemoteBillingAuthenticationError from zerver.lib.exceptions import JsonableError, RemoteBillingAuthenticationError
from zerver.lib.timestamp import datetime_to_timestamp from zerver.lib.timestamp import datetime_to_timestamp
from zilencer.models import RemoteRealm, RemoteServerBillingUser, RemoteZulipServer from zilencer.models import (
RemoteRealm,
RemoteRealmBillingUser,
RemoteServerBillingUser,
RemoteZulipServer,
)
billing_logger = logging.getLogger("corporate.stripe") billing_logger = logging.getLogger("corporate.stripe")
@ -94,10 +99,10 @@ def get_identity_dict_from_session(
return result return result
def get_remote_realm_from_session( def get_remote_realm_and_user_from_session(
request: HttpRequest, request: HttpRequest,
realm_uuid: Optional[str], realm_uuid: Optional[str],
) -> RemoteRealm: ) -> Tuple[RemoteRealm, RemoteRealmBillingUser]:
# Cannot use isinstance with TypeDicts, to make mypy know # Cannot use isinstance with TypeDicts, to make mypy know
# which of the TypedDicts in the Union this is - so just cast it. # which of the TypedDicts in the Union this is - so just cast it.
identity_dict = cast( identity_dict = cast(
@ -127,7 +132,20 @@ def get_remote_realm_from_session(
): ):
raise JsonableError(_("Registration is deactivated")) raise JsonableError(_("Registration is deactivated"))
return remote_realm remote_billing_user_id = identity_dict["remote_billing_user_id"]
# We only put IdentityDicts with remote_billing_user_id in the session in this flow,
# because the RemoteRealmBillingUser already exists when this is inserted into the session
# at the end of authentication.
assert remote_billing_user_id is not None
try:
remote_billing_user = RemoteRealmBillingUser.objects.get(
id=remote_billing_user_id, remote_realm=remote_realm
)
except RemoteRealmBillingUser.DoesNotExist:
raise AssertionError
return remote_realm, remote_billing_user
def get_remote_server_and_user_from_session( def get_remote_server_and_user_from_session(

View File

@ -2821,9 +2821,11 @@ class RemoteRealmBillingSession(BillingSession): # nocoverage
def __init__( def __init__(
self, self,
remote_realm: RemoteRealm, remote_realm: RemoteRealm,
remote_billing_user: Optional[RemoteRealmBillingUser] = None,
support_staff: Optional[UserProfile] = None, support_staff: Optional[UserProfile] = None,
) -> None: ) -> None:
self.remote_realm = remote_realm self.remote_realm = remote_realm
self.remote_billing_user = remote_billing_user
if support_staff is not None: if support_staff is not None:
assert support_staff.is_staff assert support_staff.is_staff
self.support_session = True self.support_session = True