get_messages_backend: Pass required parameters.

Earlier `num_before` and `num_after` wasn't being passed while
testing.
Moving to typed_endpoint requires all the "required" parameters
to be passed to the function explicitly in tests.
This commit is contained in:
Kenneth Rodrigues 2024-06-09 16:40:17 +05:30 committed by Tim Abbott
parent d38d82edc3
commit a865977bd5
2 changed files with 83 additions and 18 deletions

View File

@ -3655,13 +3655,16 @@ class GetOldMessagesTest(ZulipTestCase):
) )
self.assertEqual(final_dict["content"], "<p>test content</p>") self.assertEqual(final_dict["content"], "<p>test content</p>")
def common_check_get_messages_query( def common_check_get_messages_query(self, query_params: Dict[str, Any], expected: str) -> None:
self, query_params: Dict[str, object], expected: str
) -> None:
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
with queries_captured() as queries: with queries_captured() as queries:
get_messages_backend(request, user_profile) get_messages_backend(
request,
user_profile,
num_before=query_params["num_before"],
num_after=query_params["num_after"],
)
for query in queries: for query in queries:
sql = str(query.sql) sql = str(query.sql)
@ -3721,7 +3724,12 @@ class GetOldMessagesTest(ZulipTestCase):
) )
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
self.assertEqual(result["anchor"], first_message_id) self.assertEqual(result["anchor"], first_message_id)
self.assertEqual(result["found_newest"], True) self.assertEqual(result["found_newest"], True)
@ -3758,7 +3766,12 @@ class GetOldMessagesTest(ZulipTestCase):
) )
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
self.assertEqual(result["anchor"], first_message_id) self.assertEqual(result["anchor"], first_message_id)
@ -3771,7 +3784,12 @@ class GetOldMessagesTest(ZulipTestCase):
) )
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
self.assertEqual(result["anchor"], 0) self.assertEqual(result["anchor"], 0)
@ -3785,7 +3803,12 @@ class GetOldMessagesTest(ZulipTestCase):
) )
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
self.assertEqual(result["anchor"], LARGER_THAN_MAX_MESSAGE_ID) self.assertEqual(result["anchor"], LARGER_THAN_MAX_MESSAGE_ID)
@ -3799,7 +3822,12 @@ class GetOldMessagesTest(ZulipTestCase):
) )
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
self.assertEqual(result["anchor"], 0) self.assertEqual(result["anchor"], 0)
@ -3813,7 +3841,12 @@ class GetOldMessagesTest(ZulipTestCase):
) )
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
payload = get_messages_backend(request, user_profile) payload = get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
result = orjson.loads(payload.content) result = orjson.loads(payload.content)
self.assertEqual(result["anchor"], LARGER_THAN_MAX_MESSAGE_ID) self.assertEqual(result["anchor"], LARGER_THAN_MAX_MESSAGE_ID)
@ -3842,7 +3875,12 @@ class GetOldMessagesTest(ZulipTestCase):
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
# Verify the query for old messages looks correct. # Verify the query for old messages looks correct.
queries = [q for q in all_queries if "/* get_messages */" in q.sql] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
@ -3889,7 +3927,12 @@ class GetOldMessagesTest(ZulipTestCase):
first_visible_message_id = first_unread_message_id + 2 first_visible_message_id = first_unread_message_id + 2
with first_visible_id_as(first_visible_message_id): with first_visible_id_as(first_visible_message_id):
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
queries = [q for q in all_queries if "/* get_messages */" in q.sql] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
self.assert_length(queries, 1) self.assert_length(queries, 1)
@ -3913,7 +3956,12 @@ class GetOldMessagesTest(ZulipTestCase):
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
queries = [q for q in all_queries if "/* get_messages */" in q.sql] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
self.assert_length(queries, 1) self.assert_length(queries, 1)
@ -3927,7 +3975,12 @@ class GetOldMessagesTest(ZulipTestCase):
first_visible_message_id = 5 first_visible_message_id = 5
with first_visible_id_as(first_visible_message_id): with first_visible_id_as(first_visible_message_id):
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(
request,
user_profile,
num_before=10,
num_after=10,
)
queries = [q for q in all_queries if "/* get_messages */" in q.sql] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
sql = queries[0].sql sql = queries[0].sql
self.assertNotIn("AND message_id <=", sql) self.assertNotIn("AND message_id <=", sql)
@ -3966,7 +4019,12 @@ class GetOldMessagesTest(ZulipTestCase):
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
with queries_captured() as all_queries: with queries_captured() as all_queries:
get_messages_backend(request, user_profile) get_messages_backend(
request,
user_profile,
num_before=0,
num_after=0,
)
# Do some tests on the main query, to verify the muting logic # Do some tests on the main query, to verify the muting logic
# runs on this code path. # runs on this code path.

View File

@ -21,9 +21,16 @@ class MockSession(SessionBase):
self.modified = False self.modified = False
def profile_request(request: HttpRequest) -> HttpResponseBase: def profile_request(request: HttpRequest, num_before: int, num_after: int) -> HttpResponseBase:
def get_response(request: HttpRequest) -> HttpResponseBase: def get_response(request: HttpRequest) -> HttpResponseBase:
return prof.runcall(get_messages_backend, request, request.user, apply_markdown=True) return prof.runcall(
get_messages_backend,
request,
request.user,
num_before=num_before,
num_after=num_after,
apply_markdown=True,
)
prof = cProfile.Profile() prof = cProfile.Profile()
with tempfile.NamedTemporaryFile(prefix="profile.data.", delete=False) as stats_file: with tempfile.NamedTemporaryFile(prefix="profile.data.", delete=False) as stats_file:
@ -58,4 +65,4 @@ class Command(ZulipBaseCommand):
mock_request.session = MockSession() mock_request.session = MockSession()
RequestNotes.get_notes(mock_request).log_data = None RequestNotes.get_notes(mock_request).log_data = None
profile_request(mock_request) profile_request(mock_request, num_before=1200, num_after=200)