diff --git a/zerver/tests.py b/zerver/tests.py index 2da857e143..ce5607dac5 100644 --- a/zerver/tests.py +++ b/zerver/tests.py @@ -1460,6 +1460,39 @@ class SubscriptionRestApiTest(AuthedTestCase): streams = self.get_streams(email) self.assertTrue('my_test_stream_1' not in streams) + def test_bad_add_parameters(self): + email = 'hamlet@zulip.com' + self.login(email) + + def check_for_error(val, expected_message): + request = { + 'add': ujson.dumps(val) + } + result = self.client_patch( + "/api/v1/users/me/subscriptions", + request, + **self.api_auth(email) + ) + self.assert_json_error(result, expected_message) + + check_for_error(['foo'], 'add[0] is not a dict') + check_for_error([{'bogus': 'foo'}], 'name key is missing from add[0]') + check_for_error([{'name': {}}], 'add[0]["name"] is not a string') + + def test_bad_delete_parameters(self): + email = 'hamlet@zulip.com' + self.login(email) + + request = { + 'delete': ujson.dumps([{'name': 'my_test_stream_1'}]) + } + result = self.client_patch( + "/api/v1/users/me/subscriptions", + request, + **self.api_auth(email) + ) + self.assert_json_error(result, "delete[0] is not a string") + class SubscriptionAPITest(AuthedTestCase): def setUp(self): diff --git a/zerver/views/__init__.py b/zerver/views/__init__.py index 9c5638ca1c..8cf506d608 100644 --- a/zerver/views/__init__.py +++ b/zerver/views/__init__.py @@ -52,6 +52,7 @@ from openid.consumer.consumer import SUCCESS as openid_SUCCESS from openid.extensions import ax from zerver.lib import bugdown from zerver.lib.alert_words import user_alert_words +from zerver.lib.validator import check_string, check_list, check_dict from zerver.decorator import require_post, \ authenticated_api_view, authenticated_json_post_view, \ @@ -1561,6 +1562,16 @@ def update_subscriptions_backend(request, user_profile, if not add and not delete: return json_error('Nothing to do. Specify at least one of "add" or "delete".') + # validate 'add' is a list of one-item dicts with key "name" and a string value + error = check_list(check_dict([['name', check_string]]))('add', add) + if error: + raise JsonableError(error) + + # validate 'delete' is a list of strings + error = check_list(check_string)('delete', delete) + if error: + raise JsonableError(error) + json_dict = {} for method, items in ((add_subscriptions_backend, add), (remove_subscriptions_backend, delete)): response = method(request, user_profile, streams_raw=items)