refactor: Integrate POSTRequestMock into HostRequestMock.

Minimized code duplication by integrating POSTRequestMock into
HostRequestMock and then updating the required files with
HostRequestMock.

Fixes part of #1211.
This commit is contained in:
Rex Ferrer 2021-02-07 15:34:01 -05:00 committed by Tim Abbott
parent 4bbfac8aa9
commit d4c0578560
4 changed files with 36 additions and 43 deletions

View File

@ -315,41 +315,34 @@ class DummyHandler(AsyncDjangoHandler):
allocate_handler_id(self) allocate_handler_id(self)
class POSTRequestMock:
method = "POST"
def __init__(self, post_data: Dict[str, Any], user_profile: Optional[UserProfile]) -> None:
self.GET: Dict[str, Any] = {}
# Convert any integer parameters passed into strings, even
# though of course the HTTP API would do so. Ideally, we'd
# get rid of this abstraction entirely and just use the HTTP
# API directly, but while it exists, we need this code.
self.POST: Dict[str, str] = {}
for key in post_data:
self.POST[key] = str(post_data[key])
self.user = user_profile
self._tornado_handler = DummyHandler()
self._log_data: Dict[str, Any] = {}
self.META = {"PATH_INFO": "test"}
self.path = ""
class HostRequestMock: class HostRequestMock:
"""A mock request object where get_host() works. Useful for testing """A mock request object where get_host() works. Useful for testing
routes that use Zulip's subdomains feature""" routes that use Zulip's subdomains feature"""
def __init__( def __init__(
self, user_profile: Optional[UserProfile] = None, host: str = settings.EXTERNAL_HOST self,
post_data: Dict[str, Any] = {},
user_profile: Optional[UserProfile] = None,
host: str = settings.EXTERNAL_HOST,
) -> None: ) -> None:
self.host = host self.host = host
self.GET: Dict[str, Any] = {} self.GET: Dict[str, Any] = {}
self.method = ""
# Convert any integer parameters passed into strings, even
# though of course the HTTP API would do so. Ideally, we'd
# get rid of this abstraction entirely and just use the HTTP
# API directly, but while it exists, we need this code
self.POST: Dict[str, Any] = {} self.POST: Dict[str, Any] = {}
for key in post_data:
self.POST[key] = str(post_data[key])
self.method = "POST"
self._tornado_handler = DummyHandler()
self._log_data: Dict[str, Any] = {}
self.META = {"PATH_INFO": "test"} self.META = {"PATH_INFO": "test"}
self.path = "" self.path = ""
self.user = user_profile self.user = user_profile
self.method = ""
self.body = "" self.body = ""
self.content_type = "" self.content_type = ""
@ -403,8 +396,8 @@ def instrument_url(f: UrlFuncT) -> UrlFuncT:
else: else:
extra_info = "" extra_info = ""
if isinstance(info, POSTRequestMock): if isinstance(info, HostRequestMock):
info = "<POSTRequestMock>" info = "<HostRequestMock>"
elif isinstance(info, bytes): elif isinstance(info, bytes):
info = "<bytes>" info = "<bytes>"
elif isinstance(info, dict): elif isinstance(info, dict):

View File

@ -7,7 +7,7 @@ from django.http import HttpRequest, HttpResponse
from zerver.lib.actions import do_change_subscription_property, do_mute_topic from zerver.lib.actions import do_change_subscription_property, do_mute_topic
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import POSTRequestMock, mock_queue_publish from zerver.lib.test_helpers import HostRequestMock, mock_queue_publish
from zerver.models import Recipient, Stream, Subscription, UserProfile, get_stream from zerver.models import Recipient, Stream, Subscription, UserProfile, get_stream
from zerver.tornado.event_queue import ( from zerver.tornado.event_queue import (
ClientDescriptor, ClientDescriptor,
@ -288,7 +288,7 @@ class MissedMessageNotificationsTest(ZulipTestCase):
user_profile: UserProfile, user_profile: UserProfile,
post_data: Dict[str, Any], post_data: Dict[str, Any],
) -> HttpResponse: ) -> HttpResponse:
request = POSTRequestMock(post_data, user_profile) request = HostRequestMock(post_data, user_profile)
return view_func(request, user_profile) return view_func(request, user_profile)
def test_stream_watchers(self) -> None: def test_stream_watchers(self) -> None:

View File

@ -9,7 +9,7 @@ from django.http import HttpRequest, HttpResponse
from zerver.lib.actions import check_send_message, do_change_user_role, do_set_realm_property from zerver.lib.actions import check_send_message, do_change_user_role, do_set_realm_property
from zerver.lib.events import fetch_initial_state_data, get_raw_user_data from zerver.lib.events import fetch_initial_state_data, get_raw_user_data
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import POSTRequestMock, queries_captured, stub_event_queue_user_events from zerver.lib.test_helpers import HostRequestMock, queries_captured, stub_event_queue_user_events
from zerver.lib.users import get_api_key from zerver.lib.users import get_api_key
from zerver.models import ( from zerver.models import (
Realm, Realm,
@ -145,13 +145,13 @@ class EventsEndpointTest(ZulipTestCase):
), ),
).decode(), ).decode(),
) )
req = POSTRequestMock(post_data, user_profile=None) req = HostRequestMock(post_data, user_profile=None)
req.META["REMOTE_ADDR"] = "127.0.0.1" req.META["REMOTE_ADDR"] = "127.0.0.1"
result = self.client_post_request("/notify_tornado", req) result = self.client_post_request("/notify_tornado", req)
self.assert_json_error(result, "Access denied", status_code=403) self.assert_json_error(result, "Access denied", status_code=403)
post_data["secret"] = settings.SHARED_SECRET post_data["secret"] = settings.SHARED_SECRET
req = POSTRequestMock(post_data, user_profile=None) req = HostRequestMock(post_data, user_profile=None)
req.META["REMOTE_ADDR"] = "127.0.0.1" req.META["REMOTE_ADDR"] = "127.0.0.1"
result = self.client_post_request("/notify_tornado", req) result = self.client_post_request("/notify_tornado", req)
self.assert_json_success(result) self.assert_json_success(result)
@ -164,7 +164,7 @@ class GetEventsTest(ZulipTestCase):
user_profile: UserProfile, user_profile: UserProfile,
post_data: Dict[str, Any], post_data: Dict[str, Any],
) -> HttpResponse: ) -> HttpResponse:
request = POSTRequestMock(post_data, user_profile) request = HostRequestMock(post_data, user_profile)
return view_func(request, user_profile) return view_func(request, user_profile)
def test_get_events(self) -> None: def test_get_events(self) -> None:

View File

@ -33,7 +33,7 @@ from zerver.lib.request import JsonableError
from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection
from zerver.lib.streams import StreamDict, create_streams_if_needed, get_public_streams_queryset from zerver.lib.streams import StreamDict, create_streams_if_needed, get_public_streams_queryset
from zerver.lib.test_classes import ZulipTestCase from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.test_helpers import POSTRequestMock, get_user_messages, queries_captured from zerver.lib.test_helpers import HostRequestMock, get_user_messages, queries_captured
from zerver.lib.topic import MATCH_TOPIC, TOPIC_NAME from zerver.lib.topic import MATCH_TOPIC, TOPIC_NAME
from zerver.lib.topic_mutes import set_topic_mutes from zerver.lib.topic_mutes import set_topic_mutes
from zerver.lib.types import DisplayRecipientT from zerver.lib.types import DisplayRecipientT
@ -2927,7 +2927,7 @@ class GetOldMessagesTest(ZulipTestCase):
self, query_params: Dict[str, object], expected: str self, query_params: Dict[str, object], expected: str
) -> None: ) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
with queries_captured() as queries: with queries_captured() as queries:
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)
@ -2986,7 +2986,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow='[["stream", "England"]]', narrow='[["stream", "England"]]',
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(request, user_profile)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
@ -3023,7 +3023,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow="[]", narrow="[]",
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(request, user_profile)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
@ -3036,7 +3036,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow="[]", narrow="[]",
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(request, user_profile)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
@ -3050,7 +3050,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow="[]", narrow="[]",
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(request, user_profile)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
@ -3064,7 +3064,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow="[]", narrow="[]",
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(request, user_profile)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
@ -3078,7 +3078,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow="[]", narrow="[]",
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(request, user_profile)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
@ -3106,7 +3106,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow="[]", narrow="[]",
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)
@ -3151,7 +3151,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow="[]", narrow="[]",
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
first_visible_message_id = first_unread_message_id + 2 first_visible_message_id = first_unread_message_id + 2
with first_visible_id_as(first_visible_message_id): with first_visible_id_as(first_visible_message_id):
@ -3177,7 +3177,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=10, num_after=10,
narrow="[]", narrow="[]",
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)
@ -3229,7 +3229,7 @@ class GetOldMessagesTest(ZulipTestCase):
num_after=0, num_after=0,
narrow='[["stream", "Scotland"]]', narrow='[["stream", "Scotland"]]',
) )
request = POSTRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(request, user_profile)