From 5209de026127dd82d0db1a8d9ac56cc995a1e3cb Mon Sep 17 00:00:00 2001 From: Steve Howell Date: Fri, 17 Jul 2020 07:13:10 +0000 Subject: [PATCH] event_schema: Extract check_update_message_flags. --- zerver/lib/event_schema.py | 25 +++++++++++++++++++++++++ zerver/tests/test_events.py | 21 ++++----------------- 2 files changed, 29 insertions(+), 17 deletions(-) diff --git a/zerver/lib/event_schema.py b/zerver/lib/event_schema.py index b01b285167..ac9a98e2de 100644 --- a/zerver/lib/event_schema.py +++ b/zerver/lib/event_schema.py @@ -82,6 +82,14 @@ def check_events_dict( ) +check_add_or_remove = check_union( + [ + # force vertical + equals("add"), + equals("remove"), + ] +) + check_value = check_union( [ # force vertical formatting @@ -588,3 +596,20 @@ check_update_message_embedded = check_events_dict( ("sender", check_string), ] ) + +_check_update_message_flags = check_events_dict( + required_keys=[ + ("type", equals("update_message_flags")), + ("operation", check_add_or_remove), + ("flag", check_string), + ("messages", check_list(check_int)), + ("all", check_bool), + ] +) + + +def check_update_message_flags( + var_name: str, event: Dict[str, Any], operation: str +) -> None: + _check_update_message_flags(var_name, event) + assert event["operation"] == operation diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index dcf43e22dc..aef623890a 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -108,6 +108,7 @@ from zerver.lib.event_schema import ( check_update_global_notifications, check_update_message, check_update_message_embedded, + check_update_message_flags, ) from zerver.lib.events import apply_events, fetch_initial_state_data, post_process_state from zerver.lib.markdown import MentionData @@ -442,14 +443,6 @@ class NormalActionsTest(BaseAction): def test_update_message_flags(self) -> None: # Test message flag update events - schema_checker = check_events_dict([ - ('all', check_bool), - ('type', equals('update_message_flags')), - ('flag', check_string), - ('messages', check_list(check_int)), - ('operation', equals("add")), - ]) - message = self.send_personal_message( self.example_user("cordelia"), self.example_user("hamlet"), @@ -460,19 +453,13 @@ class NormalActionsTest(BaseAction): lambda: do_update_message_flags(user_profile, get_client("website"), 'add', 'starred', [message]), state_change_expected=True, ) - schema_checker('events[0]', events[0]) - schema_checker = check_events_dict([ - ('all', check_bool), - ('type', equals('update_message_flags')), - ('flag', check_string), - ('messages', check_list(check_int)), - ('operation', equals("remove")), - ]) + check_update_message_flags('events[0]', events[0], 'add') + events = self.verify_action( lambda: do_update_message_flags(user_profile, get_client("website"), 'remove', 'starred', [message]), state_change_expected=True, ) - schema_checker('events[0]', events[0]) + check_update_message_flags('events[0]', events[0], 'remove') def test_update_read_flag_removes_unread_msg_ids(self) -> None: