zilencer: Tighten type annotations of views.

`remote_server_path` allows us to get rid of all the `validate_entity`
calls in `zilencer.views` and remove all the `Union` type annotations
in the signatures of the authenticated view functions.

Signed-off-by: Zixuan James Li <p359101898@gmail.com>
This commit is contained in:
Zixuan James Li 2022-08-01 17:51:10 -04:00 committed by Tim Abbott
parent 5c49e4ba06
commit eae3e1c3cc
1 changed files with 10 additions and 30 deletions

View File

@ -1,6 +1,6 @@
import datetime import datetime
import logging import logging
from typing import Any, Dict, List, Optional, Type, TypeVar, Union from typing import Any, Dict, List, Optional, Type, TypeVar
from uuid import UUID from uuid import UUID
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
@ -36,7 +36,6 @@ from zerver.lib.validator import (
check_string, check_string,
check_string_fixed_length, check_string_fixed_length,
) )
from zerver.models import UserProfile
from zerver.views.push_notifications import validate_token from zerver.views.push_notifications import validate_token
from zilencer.auth import InvalidZulipServerKeyError from zilencer.auth import InvalidZulipServerKeyError
from zilencer.models import ( from zilencer.models import (
@ -51,12 +50,6 @@ from zilencer.models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def validate_entity(entity: Union[UserProfile, RemoteZulipServer]) -> RemoteZulipServer:
if not isinstance(entity, RemoteZulipServer):
raise JsonableError(err_("Must validate with valid Zulip server API key"))
return entity
def validate_uuid(uuid: str) -> None: def validate_uuid(uuid: str) -> None:
try: try:
uuid_object = UUID(uuid, version=4) uuid_object = UUID(uuid, version=4)
@ -70,14 +63,10 @@ def validate_uuid(uuid: str) -> None:
raise ValidationError(err_("Invalid UUID")) raise ValidationError(err_("Invalid UUID"))
def validate_bouncer_token_request( def validate_bouncer_token_request(token: str, kind: int) -> None:
entity: Union[UserProfile, RemoteZulipServer], token: str, kind: int
) -> RemoteZulipServer:
if kind not in [RemotePushDeviceToken.APNS, RemotePushDeviceToken.GCM]: if kind not in [RemotePushDeviceToken.APNS, RemotePushDeviceToken.GCM]:
raise JsonableError(err_("Invalid token type")) raise JsonableError(err_("Invalid token type"))
server = validate_entity(entity)
validate_token(token, kind) validate_token(token, kind)
return server
@csrf_exempt @csrf_exempt
@ -156,14 +145,14 @@ def register_remote_server(
@has_request_variables @has_request_variables
def register_remote_push_device( def register_remote_push_device(
request: HttpRequest, request: HttpRequest,
entity: Union[UserProfile, RemoteZulipServer], server: RemoteZulipServer,
user_id: Optional[int] = REQ(json_validator=check_int, default=None), user_id: Optional[int] = REQ(json_validator=check_int, default=None),
user_uuid: Optional[str] = REQ(default=None), user_uuid: Optional[str] = REQ(default=None),
token: str = REQ(), token: str = REQ(),
token_kind: int = REQ(json_validator=check_int), token_kind: int = REQ(json_validator=check_int),
ios_app_id: Optional[str] = None, ios_app_id: Optional[str] = None,
) -> HttpResponse: ) -> HttpResponse:
server = validate_bouncer_token_request(entity, token, token_kind) validate_bouncer_token_request(token, token_kind)
if user_id is None and user_uuid is None: if user_id is None and user_uuid is None:
raise JsonableError(_("Missing user_id or user_uuid")) raise JsonableError(_("Missing user_id or user_uuid"))
@ -195,14 +184,14 @@ def register_remote_push_device(
@has_request_variables @has_request_variables
def unregister_remote_push_device( def unregister_remote_push_device(
request: HttpRequest, request: HttpRequest,
entity: Union[UserProfile, RemoteZulipServer], server: RemoteZulipServer,
token: str = REQ(), token: str = REQ(),
token_kind: int = REQ(json_validator=check_int), token_kind: int = REQ(json_validator=check_int),
user_id: Optional[int] = REQ(json_validator=check_int, default=None), user_id: Optional[int] = REQ(json_validator=check_int, default=None),
user_uuid: Optional[str] = REQ(default=None), user_uuid: Optional[str] = REQ(default=None),
ios_app_id: Optional[str] = None, ios_app_id: Optional[str] = None,
) -> HttpResponse: ) -> HttpResponse:
server = validate_bouncer_token_request(entity, token, token_kind) validate_bouncer_token_request(token, token_kind)
user_identity = UserPushIndentityCompat(user_id=user_id, user_uuid=user_uuid) user_identity = UserPushIndentityCompat(user_id=user_id, user_uuid=user_uuid)
deleted = RemotePushDeviceToken.objects.filter( deleted = RemotePushDeviceToken.objects.filter(
@ -217,11 +206,10 @@ def unregister_remote_push_device(
@has_request_variables @has_request_variables
def unregister_all_remote_push_devices( def unregister_all_remote_push_devices(
request: HttpRequest, request: HttpRequest,
entity: Union[UserProfile, RemoteZulipServer], server: RemoteZulipServer,
user_id: Optional[int] = REQ(json_validator=check_int, default=None), user_id: Optional[int] = REQ(json_validator=check_int, default=None),
user_uuid: Optional[str] = REQ(default=None), user_uuid: Optional[str] = REQ(default=None),
) -> HttpResponse: ) -> HttpResponse:
server = validate_entity(entity)
user_identity = UserPushIndentityCompat(user_id=user_id, user_uuid=user_uuid) user_identity = UserPushIndentityCompat(user_id=user_id, user_uuid=user_uuid)
RemotePushDeviceToken.objects.filter(user_identity.filter_q(), server=server).delete() RemotePushDeviceToken.objects.filter(user_identity.filter_q(), server=server).delete()
@ -231,11 +219,9 @@ def unregister_all_remote_push_devices(
@has_request_variables @has_request_variables
def remote_server_notify_push( def remote_server_notify_push(
request: HttpRequest, request: HttpRequest,
entity: Union[UserProfile, RemoteZulipServer], server: RemoteZulipServer,
payload: Dict[str, Any] = REQ(argument_type="body"), payload: Dict[str, Any] = REQ(argument_type="body"),
) -> HttpResponse: ) -> HttpResponse:
server = validate_entity(entity)
user_identity = UserPushIndentityCompat(payload.get("user_id"), payload.get("user_uuid")) user_identity = UserPushIndentityCompat(payload.get("user_id"), payload.get("user_uuid"))
gcm_payload = payload["gcm_payload"] gcm_payload = payload["gcm_payload"]
@ -343,7 +329,7 @@ def batch_create_table_data(
@has_request_variables @has_request_variables
def remote_server_post_analytics( def remote_server_post_analytics(
request: HttpRequest, request: HttpRequest,
entity: Union[UserProfile, RemoteZulipServer], server: RemoteZulipServer,
realm_counts: List[Dict[str, Any]] = REQ( realm_counts: List[Dict[str, Any]] = REQ(
json_validator=check_list( json_validator=check_list(
check_dict_only( check_dict_only(
@ -387,8 +373,6 @@ def remote_server_post_analytics(
default=None, default=None,
), ),
) -> HttpResponse: ) -> HttpResponse:
server = validate_entity(entity)
validate_incoming_table_data(server, RemoteRealmCount, realm_counts, True) validate_incoming_table_data(server, RemoteRealmCount, realm_counts, True)
validate_incoming_table_data(server, RemoteInstallationCount, installation_counts, True) validate_incoming_table_data(server, RemoteInstallationCount, installation_counts, True)
if realmauditlog_rows is not None: if realmauditlog_rows is not None:
@ -449,11 +433,7 @@ def get_last_id_from_server(server: RemoteZulipServer, model: Any) -> int:
@has_request_variables @has_request_variables
def remote_server_check_analytics( def remote_server_check_analytics(request: HttpRequest, server: RemoteZulipServer) -> HttpResponse:
request: HttpRequest, entity: Union[UserProfile, RemoteZulipServer]
) -> HttpResponse:
server = validate_entity(entity)
result = { result = {
"last_realm_count_id": get_last_id_from_server(server, RemoteRealmCount), "last_realm_count_id": get_last_id_from_server(server, RemoteRealmCount),
"last_installation_count_id": get_last_id_from_server(server, RemoteInstallationCount), "last_installation_count_id": get_last_id_from_server(server, RemoteInstallationCount),