diff --git a/zerver/lib/bot_lib.py b/zerver/lib/bot_lib.py index b25b96f521..413b449666 100644 --- a/zerver/lib/bot_lib.py +++ b/zerver/lib/bot_lib.py @@ -1,7 +1,7 @@ import importlib import json import os -from typing import Any, Dict +from typing import Any, Callable, Dict from django.utils.translation import ugettext as _ @@ -45,13 +45,13 @@ class StateHandler: def __init__(self, user_profile: UserProfile) -> None: self.user_profile = user_profile - self.marshal = lambda obj: json.dumps(obj) - self.demarshal = lambda obj: json.loads(obj) + self.marshal: Callable[[object], str] = lambda obj: json.dumps(obj) + self.demarshal: Callable[[str], object] = lambda obj: json.loads(obj) - def get(self, key: str) -> str: + def get(self, key: str) -> object: return self.demarshal(get_bot_storage(self.user_profile, key)) - def put(self, key: str, value: str) -> None: + def put(self, key: str, value: object) -> None: set_bot_storage(self.user_profile, [(key, self.marshal(value))]) def remove(self, key: str) -> None: diff --git a/zerver/lib/bot_storage.py b/zerver/lib/bot_storage.py index 1b43644f0b..db6d1d7474 100644 --- a/zerver/lib/bot_storage.py +++ b/zerver/lib/bot_storage.py @@ -32,10 +32,8 @@ def set_bot_storage(bot_profile: UserProfile, entries: List[Tuple[str, str]]) -> storage_size_limit = settings.USER_STATE_SIZE_LIMIT storage_size_difference = 0 for key, value in entries: - if not isinstance(key, str): - raise StateError(f"Key type is {type(key)}, but should be str.") - if not isinstance(value, str): - raise StateError(f"Value type is {type(value)}, but should be str.") + assert isinstance(key, str), "Key type should be str." + assert isinstance(value, str), "Value type should be str." storage_size_difference += (len(key) + len(value)) - get_bot_storage_size(bot_profile, key) new_storage_size = get_bot_storage_size(bot_profile) + storage_size_difference if new_storage_size > storage_size_limit: diff --git a/zerver/tests/test_service_bot_system.py b/zerver/tests/test_service_bot_system.py index eba9854448..1a2aec0a23 100644 --- a/zerver/tests/test_service_bot_system.py +++ b/zerver/tests/test_service_bot_system.py @@ -10,6 +10,7 @@ from zerver.lib.bot_config import ConfigError, load_bot_config_template, set_bot from zerver.lib.bot_lib import EmbeddedBotEmptyRecipientsList, EmbeddedBotHandler, StateHandler from zerver.lib.bot_storage import StateError from zerver.lib.test_classes import ZulipTestCase +from zerver.lib.validator import check_string from zerver.models import Recipient, UserProfile, get_realm BOT_TYPE_TO_QUEUE_NAME = { @@ -198,19 +199,9 @@ class TestServiceBotStateHandler(ZulipTestCase): def test_marshaling(self) -> None: storage = StateHandler(self.bot_profile) serializable_obj = {'foo': 'bar', 'baz': [42, 'cux']} - storage.put('some key', serializable_obj) # type: ignore[arg-type] # Ignore for testing. + storage.put('some key', serializable_obj) self.assertEqual(storage.get('some key'), serializable_obj) - def test_invalid_calls(self) -> None: - storage = StateHandler(self.bot_profile) - storage.marshal = lambda obj: obj - storage.demarshal = lambda obj: obj - serializable_obj = {'foo': 'bar', 'baz': [42, 'cux']} - with self.assertRaisesMessage(StateError, "Value type is , but should be str."): - storage.put('some key', serializable_obj) # type: ignore[arg-type] # We intend to test an invalid type. - with self.assertRaisesMessage(StateError, "Key type is , but should be str."): - storage.put(serializable_obj, 'some value') # type: ignore[arg-type] # We intend to test an invalid type. - # Reduce maximal storage size for faster test string construction. @override_settings(USER_STATE_SIZE_LIMIT=100) def test_storage_limit(self) -> None: @@ -218,7 +209,7 @@ class TestServiceBotStateHandler(ZulipTestCase): # Disable marshaling for storing a string whose size is # equivalent to the size of the stored object. - storage.marshal = lambda obj: obj + storage.marshal = lambda obj: check_string("obj", obj) storage.demarshal = lambda obj: obj key = 'capacity-filling entry' @@ -283,10 +274,10 @@ class TestServiceBotStateHandler(ZulipTestCase): self.assertEqual(result.json()['storage'], updated_dict) # Assert errors on invalid requests. - params = { - 'keys': ["This is a list, but should be a serialized string."], # type: ignore[dict-item] # Ignore 'incompatible type "str": "List[str]"; expected "str": "str"' for testing + invalid_params = { + 'keys': ["This is a list, but should be a serialized string."], } - result = self.client_get('/json/bot_storage', params) + result = self.client_get('/json/bot_storage', invalid_params) self.assert_json_error(result, 'Argument "keys" is not valid JSON.') params = {