diff --git a/zerver/middleware.py b/zerver/middleware.py index ac49338223..bb21467b26 100644 --- a/zerver/middleware.py +++ b/zerver/middleware.py @@ -40,7 +40,7 @@ from zerver.lib.exceptions import ErrorCode, JsonableError, MissingAuthenticatio from zerver.lib.html_to_text import get_content_description from zerver.lib.markdown import get_markdown_requests, get_markdown_time from zerver.lib.rate_limiter import RateLimitResult -from zerver.lib.request import RequestNotes, set_request, unset_request +from zerver.lib.request import REQ, RequestNotes, has_request_variables, set_request, unset_request from zerver.lib.response import json_response, json_response_from_error, json_unauthorized from zerver.lib.subdomains import get_subdomain from zerver.lib.types import ViewFuncT @@ -306,14 +306,21 @@ class RequestContext(MiddlewareMixin): unset_request() -def parse_client(request: HttpRequest) -> Tuple[str, Optional[str]]: +# We take advantage of `has_request_variables` being called multiple times +# when processing a request in order to process any `client` parameter that +# may have been sent in the request content. +@has_request_variables +def parse_client( + request: HttpRequest, + # As `client` is a common element to all API endpoints, we choose + # not to document on every endpoint's individual parameters. + req_client: Optional[str] = REQ("client", default=None, intentionally_undocumented=True), +) -> Tuple[str, Optional[str]]: # If the API request specified a client in the request content, - # that has priority. Otherwise, extract the client from the - # User-Agent. - if "client" in request.GET: # nocoverage - return request.GET["client"], None - if "client" in request.POST: - return request.POST["client"], None + # that has priority. Otherwise, extract the client from the + # USER_AGENT. + if req_client is not None: + return req_client, None if "HTTP_USER_AGENT" in request.META: user_agent: Optional[Dict[str, str]] = parse_user_agent(request.META["HTTP_USER_AGENT"]) else: @@ -353,7 +360,13 @@ class LogRequests(MiddlewareMixin): # Avoid re-initializing request_notes.log_data if it's already there. return - request_notes.client_name, request_notes.client_version = parse_client(request) + try: + request_notes.client_name, request_notes.client_version = parse_client(request) + except JsonableError as e: + logging.exception(e) + request_notes.client_name = "Unparsable" + request_notes.client_version = None + request_notes.log_data = {} record_request_start_data(request_notes.log_data) diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index 9b284ce9e4..02c3b13ca4 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -85,7 +85,7 @@ from zerver.lib.validator import ( to_non_negative_int, to_wild_value, ) -from zerver.middleware import parse_client +from zerver.middleware import LogRequests, parse_client from zerver.models import Realm, UserProfile, get_realm, get_user if settings.ZILENCER_ENABLED: @@ -129,6 +129,28 @@ class DecoratorTestCase(ZulipTestCase): ] = "Mozilla/5.0 (Linux; Android 8.0.0; SM-G930F) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/80.0.3987.132 Mobile Safari/537.36" self.assertEqual(parse_client(req), ("Mozilla", None)) + post_req_with_client = HostRequestMock() + post_req_with_client.POST["client"] = "test_client_1" + post_req_with_client.META["HTTP_USER_AGENT"] = "ZulipMobile/26.22.145 (iOS 13.3.1)" + self.assertEqual(parse_client(post_req_with_client), ("test_client_1", None)) + + get_req_with_client = HostRequestMock() + get_req_with_client.GET["client"] = "test_client_2" + get_req_with_client.META["HTTP_USER_AGENT"] = "ZulipMobile/26.22.145 (iOS 13.3.1)" + self.assertEqual(parse_client(get_req_with_client), ("test_client_2", None)) + + def test_unparsable_user_agent(self) -> None: + request = HttpRequest() + request.POST["param"] = "test" + request.META["HTTP_USER_AGENT"] = "mocked should fail" + with mock.patch( + "zerver.middleware.parse_client", side_effect=JsonableError("message") + ) as m, self.assertLogs(level="ERROR"): + LogRequests.process_request(self, request) + request_notes = RequestNotes.get_notes(request) + self.assertEqual(request_notes.client_name, "Unparsable") + m.assert_called_once() + def test_REQ_aliases(self) -> None: @has_request_variables def double(