diff --git a/zerver/lib/addressee.py b/zerver/lib/addressee.py index cbe4d22175..58761baeee 100644 --- a/zerver/lib/addressee.py +++ b/zerver/lib/addressee.py @@ -11,8 +11,7 @@ from zerver.models import ( get_user_including_cross_realm, ) -def user_profiles_from_unvalidated_emails(emails, realm): - # type: (Iterable[Text], Realm) -> List[UserProfile] +def user_profiles_from_unvalidated_emails(emails: Iterable[Text], realm: Realm) -> List[UserProfile]: user_profiles = [] # type: List[UserProfile] for email in emails: try: @@ -22,8 +21,7 @@ def user_profiles_from_unvalidated_emails(emails, realm): user_profiles.append(user_profile) return user_profiles -def get_user_profiles(emails, realm): - # type: (Iterable[Text], Realm) -> List[UserProfile] +def get_user_profiles(emails: Iterable[Text], realm: Realm) -> List[UserProfile]: try: return user_profiles_from_unvalidated_emails(emails, realm) except ValidationError as e: @@ -42,44 +40,43 @@ class Addressee: # in memory. # # This should be treated as an immutable class. - def __init__(self, msg_type, user_profiles=None, stream_name=None, topic=None): - # type: (str, Optional[Sequence[UserProfile]], Optional[Text], Text) -> None + def __init__(self, msg_type: str, + user_profiles: Optional[Sequence[UserProfile]]=None, + stream_name: Optional[Text]=None, + topic: Text=None) -> None: assert(msg_type in ['stream', 'private']) self._msg_type = msg_type self._user_profiles = user_profiles self._stream_name = stream_name self._topic = topic - def msg_type(self): - # type: () -> str + def msg_type(self) -> str: return self._msg_type - def is_stream(self): - # type: () -> bool + def is_stream(self) -> bool: return self._msg_type == 'stream' - def is_private(self): - # type: () -> bool + def is_private(self) -> bool: return self._msg_type == 'private' - def user_profiles(self): - # type: () -> List[UserProfile] + def user_profiles(self) -> List[UserProfile]: assert(self.is_private()) return self._user_profiles # type: ignore # assertion protects us - def stream_name(self): - # type: () -> Text + def stream_name(self) -> Text: assert(self.is_stream()) return self._stream_name - def topic(self): - # type: () -> Text + def topic(self) -> Text: assert(self.is_stream()) return self._topic @staticmethod - def legacy_build(sender, message_type_name, message_to, topic_name, realm=None): - # type: (UserProfile, Text, Sequence[Text], Text, Optional[Realm]) -> Addressee + def legacy_build(sender: UserProfile, + message_type_name: Text, + message_to: Sequence[Text], + topic_name: Text, + realm: Optional[Realm]=None) -> 'Addressee': # For legacy reason message_to used to be either a list of # emails or a list of streams. We haven't fixed all of our @@ -111,8 +108,7 @@ class Addressee: raise JsonableError(_("Invalid message type")) @staticmethod - def for_stream(stream_name, topic): - # type: (Text, Text) -> Addressee + def for_stream(stream_name: Text, topic: Text) -> 'Addressee': return Addressee( msg_type='stream', stream_name=stream_name, @@ -120,8 +116,7 @@ class Addressee: ) @staticmethod - def for_private(emails, realm): - # type: (Sequence[Text], Realm) -> Addressee + def for_private(emails: Sequence[Text], realm: Realm) -> 'Addressee': user_profiles = get_user_profiles(emails, realm) return Addressee( msg_type='private', @@ -129,8 +124,7 @@ class Addressee: ) @staticmethod - def for_user_profile(user_profile): - # type: (UserProfile) -> Addressee + def for_user_profile(user_profile: UserProfile) -> 'Addressee': user_profiles = [user_profile] return Addressee( msg_type='private', diff --git a/zerver/lib/attachments.py b/zerver/lib/attachments.py index e48b1fd993..5650d8651d 100644 --- a/zerver/lib/attachments.py +++ b/zerver/lib/attachments.py @@ -6,13 +6,12 @@ from zerver.lib.request import JsonableError from zerver.lib.upload import delete_message_image from zerver.models import Attachment, UserProfile -def user_attachments(user_profile): - # type: (UserProfile) -> List[Dict[str, Any]] +def user_attachments(user_profile: UserProfile) -> List[Dict[str, Any]]: attachments = Attachment.objects.filter(owner=user_profile).prefetch_related('messages') return [a.to_dict() for a in attachments] -def access_attachment_by_id(user_profile, attachment_id, needs_owner=False): - # type: (UserProfile, int, bool) -> Attachment +def access_attachment_by_id(user_profile: UserProfile, attachment_id: int, + needs_owner: bool=False) -> Attachment: query = Attachment.objects.filter(id=attachment_id) if needs_owner: query = query.filter(owner=user_profile) @@ -22,8 +21,7 @@ def access_attachment_by_id(user_profile, attachment_id, needs_owner=False): raise JsonableError(_("Invalid attachment")) return attachment -def remove_attachment(user_profile, attachment): - # type: (UserProfile, Attachment) -> None +def remove_attachment(user_profile: UserProfile, attachment: Attachment) -> None: try: delete_message_image(attachment.path_id) except Exception: diff --git a/zerver/lib/bot_lib.py b/zerver/lib/bot_lib.py index ca60177de9..202a957f2e 100644 --- a/zerver/lib/bot_lib.py +++ b/zerver/lib/bot_lib.py @@ -24,8 +24,7 @@ our_dir = os.path.dirname(os.path.abspath(__file__)) from zulip_bots.lib import RateLimit -def get_bot_handler(service_name): - # type: (str) -> Any +def get_bot_handler(service_name: str) -> Any: # Check that this service is present in EMBEDDED_BOTS, add exception handling. is_present_in_registry = any(service_name == embedded_bot_service.name for @@ -40,31 +39,25 @@ def get_bot_handler(service_name): class StateHandler: state_size_limit = 10000000 # type: int # TODO: Store this in the server configuration model. - def __init__(self, user_profile): - # type: (UserProfile) -> None + def __init__(self, user_profile: UserProfile) -> None: self.user_profile = user_profile self.marshal = lambda obj: json.dumps(obj) self.demarshal = lambda obj: json.loads(obj) - def get(self, key): - # type: (Text) -> Text + def get(self, key: Text) -> Text: return self.demarshal(get_bot_state(self.user_profile, key)) - def put(self, key, value): - # type: (Text, Text) -> None + def put(self, key: Text, value: Text) -> None: set_bot_state(self.user_profile, key, self.marshal(value)) - def remove(self, key): - # type: (Text) -> None + def remove(self, key: Text) -> None: remove_bot_state(self.user_profile, key) - def contains(self, key): - # type: (Text) -> bool + def contains(self, key: Text) -> bool: return is_key_in_bot_state(self.user_profile, key) class EmbeddedBotHandler: - def __init__(self, user_profile): - # type: (UserProfile) -> None + def __init__(self, user_profile: UserProfile) -> None: # Only expose a subset of our UserProfile's functionality self.user_profile = user_profile self._rate_limit = RateLimit(20, 5) @@ -72,8 +65,7 @@ class EmbeddedBotHandler: self.email = user_profile.email self.storage = StateHandler(user_profile) - def send_message(self, message): - # type: (Dict[str, Any]) -> None + def send_message(self, message: Dict[str, Any]) -> None: if self._rate_limit.is_legal(): recipients = message['to'] if message['type'] == 'stream' else ','.join(message['to']) internal_send_message(realm=self.user_profile.realm, sender_email=self.user_profile.email, @@ -82,8 +74,7 @@ class EmbeddedBotHandler: else: self._rate_limit.show_error_and_exit() - def send_reply(self, message, response): - # type: (Dict[str, Any], str) -> None + def send_reply(self, message: Dict[str, Any], response: str) -> None: if message['type'] == 'private': self.send_message(dict( type='private', @@ -100,6 +91,5 @@ class EmbeddedBotHandler: sender_email=message['sender_email'], )) - def get_config_info(self): - # type: () -> Dict[Text, Text] + def get_config_info(self) -> Dict[Text, Text]: return get_bot_config(self.user_profile) diff --git a/zerver/lib/bugdown/testing_mocks.py b/zerver/lib/bugdown/testing_mocks.py index 03b59a2a90..aaf618d3ee 100644 --- a/zerver/lib/bugdown/testing_mocks.py +++ b/zerver/lib/bugdown/testing_mocks.py @@ -221,8 +221,7 @@ EMOJI_TWEET = """{ ] }""" -def twitter(tweet_id): - # type: (Text) -> Optional[Dict[Text, Any]] +def twitter(tweet_id: Text) -> Optional[Dict[Text, Any]]: if tweet_id in ["112652479837110273", "287977969287315456", "287977969287315457"]: return ujson.loads(NORMAL_TWEET) elif tweet_id == "287977969287315458": diff --git a/zerver/lib/cache_helpers.py b/zerver/lib/cache_helpers.py index 2ac4374468..9581f2e4c6 100644 --- a/zerver/lib/cache_helpers.py +++ b/zerver/lib/cache_helpers.py @@ -22,8 +22,7 @@ from django.db.models import Q MESSAGE_CACHE_SIZE = 75000 -def message_fetch_objects(): - # type: () -> List[Any] +def message_fetch_objects() -> List[Any]: try: max_id = Message.objects.only('id').order_by("-id")[0].id except IndexError: @@ -31,8 +30,8 @@ def message_fetch_objects(): return Message.objects.select_related().filter(~Q(sender__email='tabbott/extra@mit.edu'), id__gt=max_id - MESSAGE_CACHE_SIZE) -def message_cache_items(items_for_remote_cache, message): - # type: (Dict[Text, Tuple[bytes]], Message) -> None +def message_cache_items(items_for_remote_cache: Dict[Text, Tuple[bytes]], + message: Message) -> None: ''' Note: this code is untested, and the caller has been commented out for a while. @@ -41,32 +40,32 @@ def message_cache_items(items_for_remote_cache, message): value = MessageDict.to_dict_uncached(message) items_for_remote_cache[key] = (value,) -def user_cache_items(items_for_remote_cache, user_profile): - # type: (Dict[Text, Tuple[UserProfile]], UserProfile) -> None +def user_cache_items(items_for_remote_cache: Dict[Text, Tuple[UserProfile]], + user_profile: UserProfile) -> None: items_for_remote_cache[user_profile_by_email_cache_key(user_profile.email)] = (user_profile,) items_for_remote_cache[user_profile_by_id_cache_key(user_profile.id)] = (user_profile,) items_for_remote_cache[user_profile_by_api_key_cache_key(user_profile.api_key)] = (user_profile,) items_for_remote_cache[user_profile_cache_key(user_profile.email, user_profile.realm)] = (user_profile,) -def stream_cache_items(items_for_remote_cache, stream): - # type: (Dict[Text, Tuple[Stream]], Stream) -> None +def stream_cache_items(items_for_remote_cache: Dict[Text, Tuple[Stream]], + stream: Stream) -> None: items_for_remote_cache[get_stream_cache_key(stream.name, stream.realm_id)] = (stream,) -def client_cache_items(items_for_remote_cache, client): - # type: (Dict[Text, Tuple[Client]], Client) -> None +def client_cache_items(items_for_remote_cache: Dict[Text, Tuple[Client]], + client: Client) -> None: items_for_remote_cache[get_client_cache_key(client.name)] = (client,) -def huddle_cache_items(items_for_remote_cache, huddle): - # type: (Dict[Text, Tuple[Huddle]], Huddle) -> None +def huddle_cache_items(items_for_remote_cache: Dict[Text, Tuple[Huddle]], + huddle: Huddle) -> None: items_for_remote_cache[huddle_hash_cache_key(huddle.huddle_hash)] = (huddle,) -def recipient_cache_items(items_for_remote_cache, recipient): - # type: (Dict[Text, Tuple[Recipient]], Recipient) -> None +def recipient_cache_items(items_for_remote_cache: Dict[Text, Tuple[Recipient]], + recipient: Recipient) -> None: items_for_remote_cache[get_recipient_cache_key(recipient.type, recipient.type_id)] = (recipient,) session_engine = import_module(settings.SESSION_ENGINE) -def session_cache_items(items_for_remote_cache, session): - # type: (Dict[Text, Text], Session) -> None +def session_cache_items(items_for_remote_cache: Dict[Text, Text], + session: Session) -> None: store = session_engine.SessionStore(session_key=session.session_key) # type: ignore # import_module items_for_remote_cache[store.cache_key] = store.decode(session.session_data) @@ -89,8 +88,7 @@ cache_fillers = { 'session': (lambda: Session.objects.all(), session_cache_items, 3600*24*7, 10000), } # type: Dict[str, Tuple[Callable[[], List[Any]], Callable[[Dict[Text, Any], Any], None], int, int]] -def fill_remote_cache(cache): - # type: (str) -> None +def fill_remote_cache(cache: str) -> None: remote_cache_time_start = get_remote_cache_time() remote_cache_requests_start = get_remote_cache_requests() items_for_remote_cache = {} # type: Dict[Text, Any] diff --git a/zerver/lib/create_user.py b/zerver/lib/create_user.py index 60bf98a8ac..4e676d9515 100644 --- a/zerver/lib/create_user.py +++ b/zerver/lib/create_user.py @@ -9,8 +9,7 @@ import string from typing import Optional, Text -def random_api_key(): - # type: () -> Text +def random_api_key() -> Text: choices = string.ascii_letters + string.digits altchars = ''.join([choices[ord(os.urandom(1)) % 62] for _ in range(2)]).encode("utf-8") return base64.b64encode(os.urandom(24), altchars=altchars).decode("utf-8") diff --git a/zerver/lib/debug.py b/zerver/lib/debug.py index 8ca4b6e71f..c2c9234e4a 100644 --- a/zerver/lib/debug.py +++ b/zerver/lib/debug.py @@ -12,8 +12,7 @@ from typing import Optional # (that link also points to code for an interactive remote debugger # setup, which we might want if we move Tornado to run in a daemon # rather than via screen). -def interactive_debug(sig, frame): - # type: (int, FrameType) -> None +def interactive_debug(sig: int, frame: FrameType) -> None: """Interrupt running process, and provide a python prompt for interactive debugging.""" d = {'_frame': frame} # Allow access to frame object. @@ -27,7 +26,6 @@ def interactive_debug(sig, frame): # SIGUSR1 => Just print the stack # SIGUSR2 => Print stack + open interactive debugging shell -def interactive_debug_listen(): - # type: () -> None +def interactive_debug_listen() -> None: signal.signal(signal.SIGUSR1, lambda sig, stack: traceback.print_stack(stack)) signal.signal(signal.SIGUSR2, interactive_debug) diff --git a/zerver/lib/digest.py b/zerver/lib/digest.py index 467f1547e2..99b7fc74b3 100644 --- a/zerver/lib/digest.py +++ b/zerver/lib/digest.py @@ -31,8 +31,7 @@ DIGEST_CUTOFF = 5 # 4. Interesting stream traffic, as determined by the longest and most # diversely comment upon topics. -def inactive_since(user_profile, cutoff): - # type: (UserProfile, datetime.datetime) -> bool +def inactive_since(user_profile: UserProfile, cutoff: datetime.datetime) -> bool: # Hasn't used the app in the last DIGEST_CUTOFF (5) days. most_recent_visit = [row.last_visit for row in UserActivity.objects.filter( @@ -45,8 +44,7 @@ def inactive_since(user_profile, cutoff): last_visit = max(most_recent_visit) return last_visit < cutoff -def should_process_digest(realm_str): - # type: (str) -> bool +def should_process_digest(realm_str: str) -> bool: if realm_str in settings.SYSTEM_ONLY_REALMS: # Don't try to send emails to system-only realms return False @@ -54,15 +52,13 @@ def should_process_digest(realm_str): # Changes to this should also be reflected in # zerver/worker/queue_processors.py:DigestWorker.consume() -def queue_digest_recipient(user_profile, cutoff): - # type: (UserProfile, datetime.datetime) -> None +def queue_digest_recipient(user_profile: UserProfile, cutoff: datetime.datetime) -> None: # Convert cutoff to epoch seconds for transit. event = {"user_profile_id": user_profile.id, "cutoff": cutoff.strftime('%s')} queue_json_publish("digest_emails", event, lambda event: None, call_consume_in_tests=True) -def enqueue_emails(cutoff): - # type: (datetime.datetime) -> None +def enqueue_emails(cutoff: datetime.datetime) -> None: # To be really conservative while we don't have user timezones or # special-casing for companies with non-standard workweeks, only # try to send mail on Tuesdays. @@ -82,8 +78,7 @@ def enqueue_emails(cutoff): logger.info("%s is inactive, queuing for potential digest" % ( user_profile.email,)) -def gather_hot_conversations(user_profile, stream_messages): - # type: (UserProfile, QuerySet) -> List[Dict[str, Any]] +def gather_hot_conversations(user_profile: UserProfile, stream_messages: QuerySet) -> List[Dict[str, Any]]: # Gather stream conversations of 2 types: # 1. long conversations # 2. conversations where many different people participated @@ -146,8 +141,7 @@ def gather_hot_conversations(user_profile, stream_messages): hot_conversation_render_payloads.append(teaser_data) return hot_conversation_render_payloads -def gather_new_users(user_profile, threshold): - # type: (UserProfile, datetime.datetime) -> Tuple[int, List[Text]] +def gather_new_users(user_profile: UserProfile, threshold: datetime.datetime) -> Tuple[int, List[Text]]: # Gather information on users in the realm who have recently # joined. if user_profile.realm.is_zephyr_mirror_realm: @@ -160,8 +154,8 @@ def gather_new_users(user_profile, threshold): return len(user_names), user_names -def gather_new_streams(user_profile, threshold): - # type: (UserProfile, datetime.datetime) -> Tuple[int, Dict[str, List[Text]]] +def gather_new_streams(user_profile: UserProfile, + threshold: datetime.datetime) -> Tuple[int, Dict[str, List[Text]]]: if user_profile.realm.is_zephyr_mirror_realm: new_streams = [] # type: List[Stream] else: @@ -181,8 +175,7 @@ def gather_new_streams(user_profile, threshold): return len(new_streams), {"html": streams_html, "plain": streams_plain} -def enough_traffic(unread_pms, hot_conversations, new_streams, new_users): - # type: (Text, Text, int, int) -> bool +def enough_traffic(unread_pms: Text, hot_conversations: Text, new_streams: int, new_users: int) -> bool: if unread_pms or hot_conversations: # If you have any unread traffic, good enough. return True @@ -192,8 +185,7 @@ def enough_traffic(unread_pms, hot_conversations, new_streams, new_users): return True return False -def handle_digest_email(user_profile_id, cutoff): - # type: (int, float) -> None +def handle_digest_email(user_profile_id: int, cutoff: float) -> None: user_profile = get_user_profile_by_id(user_profile_id) # We are disabling digest emails for soft deactivated users for the time. diff --git a/zerver/lib/domains.py b/zerver/lib/domains.py index ae7c7641b2..0c6a0a6531 100644 --- a/zerver/lib/domains.py +++ b/zerver/lib/domains.py @@ -4,8 +4,7 @@ from django.utils.translation import ugettext as _ import re from typing import Text -def validate_domain(domain): - # type: (Text) -> None +def validate_domain(domain: Text) -> None: if domain is None or len(domain) == 0: raise ValidationError(_("Domain can't be empty.")) if '.' not in domain: diff --git a/zerver/lib/emoji.py b/zerver/lib/emoji.py index fc129f16d4..4101e27ed3 100644 --- a/zerver/lib/emoji.py +++ b/zerver/lib/emoji.py @@ -20,8 +20,7 @@ with open(NAME_TO_CODEPOINT_PATH) as fp: with open(CODEPOINT_TO_NAME_PATH) as fp: codepoint_to_name = ujson.load(fp) -def emoji_name_to_emoji_code(realm, emoji_name): - # type: (Realm, Text) -> Tuple[Text, Text] +def emoji_name_to_emoji_code(realm: Realm, emoji_name: Text) -> Tuple[Text, Text]: realm_emojis = realm.get_emoji() if emoji_name in realm_emojis and not realm_emojis[emoji_name]['deactivated']: return emoji_name, Reaction.REALM_EMOJI @@ -31,8 +30,7 @@ def emoji_name_to_emoji_code(realm, emoji_name): return name_to_codepoint[emoji_name], Reaction.UNICODE_EMOJI raise JsonableError(_("Emoji '%s' does not exist" % (emoji_name,))) -def check_valid_emoji(realm, emoji_name): - # type: (Realm, Text) -> None +def check_valid_emoji(realm: Realm, emoji_name: Text) -> None: emoji_name_to_emoji_code(realm, emoji_name) def check_emoji_request(realm: Realm, emoji_name: str, emoji_code: str, @@ -61,8 +59,7 @@ def check_emoji_request(realm: Realm, emoji_name: str, emoji_code: str, # The above are the only valid emoji types raise JsonableError(_("Invalid emoji type.")) -def check_emoji_admin(user_profile, emoji_name=None): - # type: (UserProfile, Optional[Text]) -> None +def check_emoji_admin(user_profile: UserProfile, emoji_name: Optional[Text]=None) -> None: """Raises an exception if the user cannot administer the target realm emoji name in their organization.""" @@ -84,18 +81,15 @@ def check_emoji_admin(user_profile, emoji_name=None): if not user_profile.is_realm_admin and not current_user_is_author: raise JsonableError(_("Must be a realm administrator or emoji author")) -def check_valid_emoji_name(emoji_name): - # type: (Text) -> None +def check_valid_emoji_name(emoji_name: Text) -> None: if re.match('^[0-9a-z.\-_]+(? Text +def get_emoji_url(emoji_file_name: Text, realm_id: int) -> Text: return upload_backend.get_emoji_url(emoji_file_name, realm_id) -def get_emoji_file_name(emoji_file_name, emoji_name): - # type: (Text, Text) -> Text +def get_emoji_file_name(emoji_file_name: Text, emoji_name: Text) -> Text: _, image_ext = os.path.splitext(emoji_file_name) return ''.join((emoji_name, image_ext)) diff --git a/zerver/lib/error_notify.py b/zerver/lib/error_notify.py index 188c860894..960d2adf6c 100644 --- a/zerver/lib/error_notify.py +++ b/zerver/lib/error_notify.py @@ -16,15 +16,13 @@ from zerver.lib.actions import internal_send_message from zerver.lib.response import json_success, json_error from version import ZULIP_VERSION -def format_subject(subject): - # type: (str) -> str +def format_subject(subject: str) -> str: """ Escape CR and LF characters. """ return subject.replace('\n', '\\n').replace('\r', '\\r') -def user_info_str(report): - # type: (Dict[str, Any]) -> str +def user_info_str(report: Dict[str, Any]) -> str: if report['user_full_name'] and report['user_email']: user_info = "%(user_full_name)s (%(user_email)s)" % (report) else: @@ -59,15 +57,13 @@ def deployment_repr() -> str: return deployment -def notify_browser_error(report): - # type: (Dict[str, Any]) -> None +def notify_browser_error(report: Dict[str, Any]) -> None: report = defaultdict(lambda: None, report) if settings.ERROR_BOT: zulip_browser_error(report) email_browser_error(report) -def email_browser_error(report): - # type: (Dict[str, Any]) -> None +def email_browser_error(report: Dict[str, Any]) -> None: subject = "Browser error for %s" % (user_info_str(report)) body = ("User: %(user_full_name)s <%(user_email)s> on %(deployment)s\n\n" @@ -89,8 +85,7 @@ def email_browser_error(report): mail_admins(subject, body) -def zulip_browser_error(report): - # type: (Dict[str, Any]) -> None +def zulip_browser_error(report: Dict[str, Any]) -> None: subject = "JS error: %s" % (report['user_email'],) user_info = user_info_str(report) @@ -103,15 +98,13 @@ def zulip_browser_error(report): internal_send_message(realm, settings.ERROR_BOT, "stream", "errors", format_subject(subject), body) -def notify_server_error(report): - # type: (Dict[str, Any]) -> None +def notify_server_error(report: Dict[str, Any]) -> None: report = defaultdict(lambda: None, report) email_server_error(report) if settings.ERROR_BOT: zulip_server_error(report) -def zulip_server_error(report): - # type: (Dict[str, Any]) -> None +def zulip_server_error(report: Dict[str, Any]) -> None: subject = '%(node)s: %(message)s' % (report) stack_trace = report['stack_trace'] or "No stack trace available" @@ -133,8 +126,7 @@ def zulip_server_error(report): "Error generated by %s\n\n~~~~ pytb\n%s\n\n~~~~\n%s\n%s" % (user_info, stack_trace, deployment, request_repr)) -def email_server_error(report): - # type: (Dict[str, Any]) -> None +def email_server_error(report: Dict[str, Any]) -> None: subject = '%(node)s: %(message)s' % (report) user_info = user_info_str(report) @@ -153,8 +145,7 @@ def email_server_error(report): mail_admins(format_subject(subject), message, fail_silently=True) -def do_report_error(deployment_name, type, report): - # type: (Text, Text, Dict[str, Any]) -> HttpResponse +def do_report_error(deployment_name: Text, type: Text, report: Dict[str, Any]) -> HttpResponse: report['deployment'] = deployment_name if type == 'browser': notify_browser_error(report) diff --git a/zerver/lib/events.py b/zerver/lib/events.py index ead78b37c6..4eeedf20b0 100644 --- a/zerver/lib/events.py +++ b/zerver/lib/events.py @@ -47,12 +47,10 @@ from zproject.backends import email_auth_enabled, password_auth_enabled from version import ZULIP_VERSION -def get_raw_user_data(realm_id, client_gravatar): - # type: (int, bool) -> Dict[int, Dict[str, Text]] +def get_raw_user_data(realm_id: int, client_gravatar: bool) -> Dict[int, Dict[str, Text]]: user_dicts = get_realm_user_dicts(realm_id) - def user_data(row): - # type: (Dict[str, Any]) -> Dict[str, Any] + def user_data(row: Dict[str, Any]) -> Dict[str, Any]: avatar_url = get_avatar_field( user_id=row['id'], realm_id= realm_id, @@ -81,8 +79,7 @@ def get_raw_user_data(realm_id, client_gravatar): for row in user_dicts } -def always_want(msg_type): - # type: (str) -> bool +def always_want(msg_type: str) -> bool: ''' This function is used as a helper in fetch_initial_state_data, when the user passes @@ -262,8 +259,8 @@ def fetch_initial_state_data(user_profile, event_types, queue_id, client_gravata return state -def remove_message_id_from_unread_mgs(state, message_id): - # type: (Dict[str, Dict[str, Any]], int) -> None +def remove_message_id_from_unread_mgs(state: Dict[str, Dict[str, Any]], + message_id: int) -> None: raw_unread = state['raw_unread_msgs'] for key in ['pm_dict', 'stream_dict', 'huddle_dict']: @@ -288,8 +285,11 @@ def apply_events(state, events, user_profile, client_gravatar, include_subscribe continue apply_event(state, event, user_profile, client_gravatar, include_subscribers) -def apply_event(state, event, user_profile, client_gravatar, include_subscribers): - # type: (Dict[str, Any], Dict[str, Any], UserProfile, bool, bool) -> None +def apply_event(state: Dict[str, Any], + event: Dict[str, Any], + user_profile: UserProfile, + client_gravatar: bool, + include_subscribers: bool) -> None: if event['type'] == "message": state['max_message_id'] = max(state['max_message_id'], event['message']['id']) if 'raw_unread_msgs' in state: @@ -440,8 +440,7 @@ def apply_event(state, event, user_profile, client_gravatar, include_subscribers event['subscriptions'][i] = copy.deepcopy(event['subscriptions'][i]) del event['subscriptions'][i]['subscribers'] - def name(sub): - # type: (Dict[str, Any]) -> Text + def name(sub: Dict[str, Any]) -> Text: return sub['name'].lower() if event['op'] == "add": diff --git a/zerver/lib/exceptions.py b/zerver/lib/exceptions.py index aadcd10e72..8d93cb1229 100644 --- a/zerver/lib/exceptions.py +++ b/zerver/lib/exceptions.py @@ -6,24 +6,20 @@ from django.core.exceptions import PermissionDenied class AbstractEnum(Enum): '''An enumeration whose members are used strictly for their names.''' - def __new__(cls): - # type: (Type[AbstractEnum]) -> AbstractEnum + def __new__(cls: Type['AbstractEnum']) -> 'AbstractEnum': obj = object.__new__(cls) obj._value_ = len(cls.__members__) + 1 return obj # Override all the `Enum` methods that use `_value_`. - def __repr__(self): - # type: () -> str + def __repr__(self) -> str: return str(self) - def value(self): - # type: () -> None + def value(self) -> None: assert False - def __reduce_ex__(self, proto): - # type: (int) -> None + def __reduce_ex__(self, proto: int) -> None: assert False class ErrorCode(AbstractEnum): @@ -69,13 +65,11 @@ class JsonableError(Exception): code = ErrorCode.NO_SUCH_WIDGET data_fields = ['widget_name'] - def __init__(self, widget_name): - # type: (str) -> None + def __init__(self, widget_name: str) -> None: self.widget_name = widget_name # type: str @staticmethod - def msg_format(): - # type: () -> str + def msg_format() -> str: return _("No such widget: {widget_name}") raise NoSuchWidgetError(widget_name) @@ -96,8 +90,7 @@ class JsonableError(Exception): # like 403 or 404. http_status_code = 400 # type: int - def __init__(self, msg, code=None): - # type: (Text, Optional[ErrorCode]) -> None + def __init__(self, msg: Text, code: Optional[ErrorCode]=None) -> None: if code is not None: self.code = code @@ -105,8 +98,7 @@ class JsonableError(Exception): self._msg = msg # type: Text @staticmethod - def msg_format(): - # type: () -> Text + def msg_format() -> Text: '''Override in subclasses. Gets the items in `data_fields` as format args. This should return (a translation of) a string literal. @@ -124,29 +116,24 @@ class JsonableError(Exception): # @property - def msg(self): - # type: () -> Text + def msg(self) -> Text: format_data = dict(((f, getattr(self, f)) for f in self.data_fields), _msg=getattr(self, '_msg', None)) return self.msg_format().format(**format_data) @property - def data(self): - # type: () -> Dict[str, Any] + def data(self) -> Dict[str, Any]: return dict(((f, getattr(self, f)) for f in self.data_fields), code=self.code.name) - def to_json(self): - # type: () -> Dict[str, Any] + def to_json(self) -> Dict[str, Any]: d = {'result': 'error', 'msg': self.msg} d.update(self.data) return d - def __str__(self): - # type: () -> str + def __str__(self) -> str: return self.msg class RateLimited(PermissionDenied): - def __init__(self, msg=""): - # type: (str) -> None + def __init__(self, msg: str="") -> None: super().__init__(msg) diff --git a/zerver/lib/export.py b/zerver/lib/export.py index 491e09a474..47223a98f5 100644 --- a/zerver/lib/export.py +++ b/zerver/lib/export.py @@ -125,8 +125,7 @@ DATE_FIELDS = { 'zerver_userprofile': ['date_joined', 'last_login', 'last_reminder'], } # type: Dict[TableName, List[Field]] -def sanity_check_output(data): - # type: (TableData) -> None +def sanity_check_output(data: TableData) -> None: tables = set(ALL_ZERVER_TABLES) tables -= set(NON_EXPORTED_TABLES) tables -= set(IMPLICIT_TABLES) @@ -137,13 +136,11 @@ def sanity_check_output(data): if table not in data: logging.warning('??? NO DATA EXPORTED FOR TABLE %s!!!' % (table,)) -def write_data_to_file(output_file, data): - # type: (Path, Any) -> None +def write_data_to_file(output_file: Path, data: Any) -> None: with open(output_file, "w") as f: f.write(ujson.dumps(data, indent=4)) -def make_raw(query, exclude=None): - # type: (Any, List[Field]) -> List[Record] +def make_raw(query: Any, exclude: List[Field]=None) -> List[Record]: ''' Takes a Django query and returns a JSONable list of dictionaries corresponding to the database rows. @@ -165,8 +162,7 @@ def make_raw(query, exclude=None): return rows -def floatify_datetime_fields(data, table): - # type: (TableData, TableName) -> None +def floatify_datetime_fields(data: TableData, table: TableName) -> None: for item in data[table]: for field in DATE_FIELDS[table]: orig_dt = item[field] @@ -261,8 +257,8 @@ class Config: self.virtual_parent.table)) -def export_from_config(response, config, seed_object=None, context=None): - # type: (TableData, Config, Any, Context) -> None +def export_from_config(response: TableData, config: Config, seed_object: Any=None, + context: Context=None) -> None: table = config.table parent = config.parent model = config.model @@ -372,8 +368,7 @@ def export_from_config(response, config, seed_object=None, context=None): context=context, ) -def get_realm_config(): - # type: () -> Config +def get_realm_config() -> Config: # This is common, public information about the realm that we can share # with all realm users. @@ -536,8 +531,7 @@ def get_realm_config(): return realm_config -def sanity_check_stream_data(response, config, context): - # type: (TableData, Config, Context) -> None +def sanity_check_stream_data(response: TableData, config: Config, context: Context) -> None: if context['exportable_user_ids'] is not None: # If we restrict which user ids are exportable, @@ -559,8 +553,7 @@ def sanity_check_stream_data(response, config, context): Please investigate! ''') -def fetch_user_profile(response, config, context): - # type: (TableData, Config, Context) -> None +def fetch_user_profile(response: TableData, config: Config, context: Context) -> None: realm = context['realm'] exportable_user_ids = context['exportable_user_ids'] @@ -589,8 +582,7 @@ def fetch_user_profile(response, config, context): response['zerver_userprofile'] = normal_rows response['zerver_userprofile_mirrordummy'] = dummy_rows -def fetch_user_profile_cross_realm(response, config, context): - # type: (TableData, Config, Context) -> None +def fetch_user_profile_cross_realm(response: TableData, config: Config, context: Context) -> None: realm = context['realm'] if realm.string_id == "zulip": @@ -602,8 +594,7 @@ def fetch_user_profile_cross_realm(response, config, context): get_system_bot(settings.WELCOME_BOT), ]] -def fetch_attachment_data(response, realm_id, message_ids): - # type: (TableData, int, Set[int]) -> None +def fetch_attachment_data(response: TableData, realm_id: int, message_ids: Set[int]) -> None: filter_args = {'realm_id': realm_id} query = Attachment.objects.filter(**filter_args) response['zerver_attachment'] = make_raw(list(query)) @@ -630,8 +621,7 @@ def fetch_attachment_data(response, realm_id, message_ids): row for row in response['zerver_attachment'] if row['messages']] -def fetch_huddle_objects(response, config, context): - # type: (TableData, Config, Context) -> None +def fetch_huddle_objects(response: TableData, config: Config, context: Context) -> None: realm = context['realm'] assert config.parent is not None @@ -667,8 +657,10 @@ def fetch_huddle_objects(response, config, context): response['_huddle_subscription'] = huddle_subscription_dicts response['zerver_huddle'] = make_raw(Huddle.objects.filter(id__in=huddle_ids)) -def fetch_usermessages(realm, message_ids, user_profile_ids, message_filename): - # type: (Realm, Set[int], Set[int], Path) -> List[Record] +def fetch_usermessages(realm: Realm, + message_ids: Set[int], + user_profile_ids: Set[int], + message_filename: Path) -> List[Record]: # UserMessage export security rule: You can export UserMessages # for the messages you exported for the users in your realm. user_message_query = UserMessage.objects.filter(user_profile__realm=realm, @@ -684,8 +676,7 @@ def fetch_usermessages(realm, message_ids, user_profile_ids, message_filename): logging.info("Fetched UserMessages for %s" % (message_filename,)) return user_message_chunk -def export_usermessages_batch(input_path, output_path): - # type: (Path, Path) -> None +def export_usermessages_batch(input_path: Path, output_path: Path) -> None: """As part of the system for doing parallel exports, this runs on one batch of Message objects and adds the corresponding UserMessage objects. (This is called by the export_usermessage_batch @@ -701,18 +692,18 @@ def export_usermessages_batch(input_path, output_path): write_message_export(output_path, output) os.unlink(input_path) -def write_message_export(message_filename, output): - # type: (Path, MessageOutput) -> None +def write_message_export(message_filename: Path, output: MessageOutput) -> None: write_data_to_file(output_file=message_filename, data=output) logging.info("Dumped to %s" % (message_filename,)) -def export_partial_message_files(realm, response, chunk_size=1000, output_dir=None): - # type: (Realm, TableData, int, Path) -> Set[int] +def export_partial_message_files(realm: Realm, + response: TableData, + chunk_size: int=1000, + output_dir: Path=None) -> Set[int]: if output_dir is None: output_dir = tempfile.mkdtemp(prefix="zulip-export") - def get_ids(records): - # type: (List[Record]) -> Set[int] + def get_ids(records: List[Record]) -> Set[int]: return set(x['id'] for x in records) # Basic security rule: You can export everything either... @@ -824,8 +815,7 @@ def write_message_partial_for_query(realm, message_query, dump_file_id, return dump_file_id -def export_uploads_and_avatars(realm, output_dir): - # type: (Realm, Path) -> None +def export_uploads_and_avatars(realm: Realm, output_dir: Path) -> None: uploads_output_dir = os.path.join(output_dir, 'uploads') avatars_output_dir = os.path.join(output_dir, 'avatars') @@ -851,8 +841,8 @@ def export_uploads_and_avatars(realm, output_dir): settings.S3_AUTH_UPLOADS_BUCKET, output_dir=uploads_output_dir) -def export_files_from_s3(realm, bucket_name, output_dir, processing_avatars=False): - # type: (Realm, str, Path, bool) -> None +def export_files_from_s3(realm: Realm, bucket_name: str, output_dir: Path, + processing_avatars: bool=False) -> None: conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY) bucket = conn.get_bucket(bucket_name, validate=True) records = [] @@ -932,8 +922,7 @@ def export_files_from_s3(realm, bucket_name, output_dir, processing_avatars=Fals with open(os.path.join(output_dir, "records.json"), "w") as records_file: ujson.dump(records, records_file, indent=4) -def export_uploads_from_local(realm, local_dir, output_dir): - # type: (Realm, Path, Path) -> None +def export_uploads_from_local(realm: Realm, local_dir: Path, output_dir: Path) -> None: count = 0 records = [] @@ -960,8 +949,7 @@ def export_uploads_from_local(realm, local_dir, output_dir): with open(os.path.join(output_dir, "records.json"), "w") as records_file: ujson.dump(records, records_file, indent=4) -def export_avatars_from_local(realm, local_dir, output_dir): - # type: (Realm, Path, Path) -> None +def export_avatars_from_local(realm: Realm, local_dir: Path, output_dir: Path) -> None: count = 0 records = [] @@ -1005,8 +993,7 @@ def export_avatars_from_local(realm, local_dir, output_dir): with open(os.path.join(output_dir, "records.json"), "w") as records_file: ujson.dump(records, records_file, indent=4) -def do_write_stats_file_for_realm_export(output_dir): - # type: (Path) -> None +def do_write_stats_file_for_realm_export(output_dir: Path) -> None: stats_file = os.path.join(output_dir, 'stats.txt') realm_file = os.path.join(output_dir, 'realm.json') attachment_file = os.path.join(output_dir, 'attachment.json') @@ -1033,8 +1020,8 @@ def do_write_stats_file_for_realm_export(output_dir): f.write('%5d records\n' % len(data)) f.write('\n') -def do_export_realm(realm, output_dir, threads, exportable_user_ids=None): - # type: (Realm, Path, int, Set[int]) -> None +def do_export_realm(realm: Realm, output_dir: Path, threads: int, + exportable_user_ids: Set[int]=None) -> None: response = {} # type: TableData # We need at least one thread running to export @@ -1084,16 +1071,14 @@ def do_export_realm(realm, output_dir, threads, exportable_user_ids=None): logging.info("Finished exporting %s" % (realm.string_id)) create_soft_link(source=output_dir, in_progress=False) -def export_attachment_table(realm, output_dir, message_ids): - # type: (Realm, Path, Set[int]) -> None +def export_attachment_table(realm: Realm, output_dir: Path, message_ids: Set[int]) -> None: response = {} # type: TableData fetch_attachment_data(response=response, realm_id=realm.id, message_ids=message_ids) output_file = os.path.join(output_dir, "attachment.json") logging.info('Writing attachment table data to %s' % (output_file,)) write_data_to_file(output_file=output_file, data=response) -def create_soft_link(source, in_progress=True): - # type: (Path, bool) -> None +def create_soft_link(source: Path, in_progress: bool=True) -> None: is_done = not in_progress in_progress_link = '/tmp/zulip-export-in-progress' done_link = '/tmp/zulip-export-most-recent' @@ -1109,12 +1094,10 @@ def create_soft_link(source, in_progress=True): logging.info('See %s for output files' % (new_target,)) -def launch_user_message_subprocesses(threads, output_dir): - # type: (int, Path) -> None +def launch_user_message_subprocesses(threads: int, output_dir: Path) -> None: logging.info('Launching %d PARALLEL subprocesses to export UserMessage rows' % (threads,)) - def run_job(shard): - # type: (str) -> int + def run_job(shard: str) -> int: subprocess.call(["./manage.py", 'export_usermessage_batch', '--path', str(output_dir), '--thread', shard]) return 0 @@ -1124,8 +1107,7 @@ def launch_user_message_subprocesses(threads, output_dir): threads=threads): print("Shard %s finished, status %s" % (job, status)) -def do_export_user(user_profile, output_dir): - # type: (UserProfile, Path) -> None +def do_export_user(user_profile: UserProfile, output_dir: Path) -> None: response = {} # type: TableData export_single_user(user_profile, response) @@ -1134,8 +1116,7 @@ def do_export_user(user_profile, output_dir): logging.info("Exporting messages") export_messages_single_user(user_profile, output_dir) -def export_single_user(user_profile, response): - # type: (UserProfile, TableData) -> None +def export_single_user(user_profile: UserProfile, response: TableData) -> None: config = get_single_user_config() export_from_config( @@ -1144,8 +1125,7 @@ def export_single_user(user_profile, response): seed_object=user_profile, ) -def get_single_user_config(): - # type: () -> Config +def get_single_user_config() -> Config: # zerver_userprofile user_profile_config = Config( @@ -1182,8 +1162,7 @@ def get_single_user_config(): return user_profile_config -def export_messages_single_user(user_profile, output_dir, chunk_size=1000): - # type: (UserProfile, Path, int) -> None +def export_messages_single_user(user_profile: UserProfile, output_dir: Path, chunk_size: int=1000) -> None: user_message_query = UserMessage.objects.filter(user_profile=user_profile).order_by("id") min_id = -1 dump_file_id = 1 @@ -1232,8 +1211,7 @@ id_maps = { 'user_profile': {}, } # type: Dict[str, Dict[int, int]] -def update_id_map(table, old_id, new_id): - # type: (TableName, int, int) -> None +def update_id_map(table: TableName, old_id: int, new_id: int) -> None: if table not in id_maps: raise Exception(''' Table %s is not initialized in id_maps, which could @@ -1242,15 +1220,13 @@ def update_id_map(table, old_id, new_id): ''' % (table,)) id_maps[table][old_id] = new_id -def fix_datetime_fields(data, table): - # type: (TableData, TableName) -> None +def fix_datetime_fields(data: TableData, table: TableName) -> None: for item in data[table]: for field_name in DATE_FIELDS[table]: if item[field_name] is not None: item[field_name] = datetime.datetime.fromtimestamp(item[field_name], tz=timezone_utc) -def convert_to_id_fields(data, table, field_name): - # type: (TableData, TableName, Field) -> None +def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -> None: ''' When Django gives us dict objects via model_to_dict, the foreign key fields are `foo`, but we want `foo_id` for the bulk insert. @@ -1262,8 +1238,11 @@ def convert_to_id_fields(data, table, field_name): item[field_name + "_id"] = item[field_name] del item[field_name] -def re_map_foreign_keys(data, table, field_name, related_table, verbose=False): - # type: (TableData, TableName, Field, TableName, bool) -> None +def re_map_foreign_keys(data: TableData, + table: TableName, + field_name: Field, + related_table: TableName, + verbose: bool=False) -> None: ''' We occasionally need to assign new ids to rows during the import/export process, to accommodate things like existing rows @@ -1288,14 +1267,12 @@ def re_map_foreign_keys(data, table, field_name, related_table, verbose=False): item[field_name + "_id"] = new_id del item[field_name] -def fix_bitfield_keys(data, table, field_name): - # type: (TableData, TableName, Field) -> None +def fix_bitfield_keys(data: TableData, table: TableName, field_name: Field) -> None: for item in data[table]: item[field_name] = item[field_name + '_mask'] del item[field_name + '_mask'] -def fix_realm_authentication_bitfield(data, table, field_name): - # type: (TableData, TableName, Field) -> None +def fix_realm_authentication_bitfield(data: TableData, table: TableName, field_name: Field) -> None: """Used to fixup the authentication_methods bitfield to be a string""" for item in data[table]: values_as_bitstring = ''.join(['1' if field[1] else '0' for field in @@ -1303,8 +1280,7 @@ def fix_realm_authentication_bitfield(data, table, field_name): values_as_int = int(values_as_bitstring, 2) item[field_name] = values_as_int -def bulk_import_model(data, model, table, dump_file_id=None): - # type: (TableData, Any, TableName, str) -> None +def bulk_import_model(data: TableData, model: Any, table: TableName, dump_file_id: str=None) -> None: # TODO, deprecate dump_file_id model.objects.bulk_create(model(**item) for item in data[table]) if dump_file_id is None: @@ -1316,8 +1292,7 @@ def bulk_import_model(data, model, table, dump_file_id=None): # correctly import multiple realms into the same server, we need to # check if a Client object already exists, and so we need to support # remap all Client IDs to the values in the new DB. -def bulk_import_client(data, model, table): - # type: (TableData, Any, TableName) -> None +def bulk_import_client(data: TableData, model: Any, table: TableName) -> None: for item in data[table]: try: client = Client.objects.get(name=item['name']) @@ -1325,8 +1300,7 @@ def bulk_import_client(data, model, table): client = Client.objects.create(name=item['name']) update_id_map(table='client', old_id=item['id'], new_id=client.id) -def import_uploads_local(import_dir, processing_avatars=False): - # type: (Path, bool) -> None +def import_uploads_local(import_dir: Path, processing_avatars: bool=False) -> None: records_filename = os.path.join(import_dir, "records.json") with open(records_filename) as records_file: records = ujson.loads(records_file.read()) @@ -1349,8 +1323,7 @@ def import_uploads_local(import_dir, processing_avatars=False): subprocess.check_call(["mkdir", "-p", os.path.dirname(file_path)]) shutil.copy(orig_file_path, file_path) -def import_uploads_s3(bucket_name, import_dir, processing_avatars=False): - # type: (str, Path, bool) -> None +def import_uploads_s3(bucket_name: str, import_dir: Path, processing_avatars: bool=False) -> None: conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY) bucket = conn.get_bucket(bucket_name, validate=True) @@ -1385,8 +1358,7 @@ def import_uploads_s3(bucket_name, import_dir, processing_avatars=False): key.set_contents_from_filename(os.path.join(import_dir, record['path']), headers=headers) -def import_uploads(import_dir, processing_avatars=False): - # type: (Path, bool) -> None +def import_uploads(import_dir: Path, processing_avatars: bool=False) -> None: if processing_avatars: logging.info("Importing avatars") else: @@ -1418,8 +1390,7 @@ def import_uploads(import_dir, processing_avatars=False): # Because the Python object => JSON conversion process is not fully # faithful, we have to use a set of fixers (e.g. on DateTime objects # and Foreign Keys) to do the import correctly. -def do_import_realm(import_dir): - # type: (Path) -> None +def do_import_realm(import_dir: Path) -> None: logging.info("Importing realm dump %s" % (import_dir,)) if not os.path.exists(import_dir): raise Exception("Missing import directory!") @@ -1527,8 +1498,7 @@ def do_import_realm(import_dir): import_attachments(data) -def import_message_data(import_dir): - # type: (Path) -> None +def import_message_data(import_dir: Path) -> None: dump_file_id = 1 while True: message_filename = os.path.join(import_dir, "messages-%06d.json" % (dump_file_id,)) @@ -1555,8 +1525,7 @@ def import_message_data(import_dir): dump_file_id += 1 -def import_attachments(data): - # type: (TableData) -> None +def import_attachments(data: TableData) -> None: # Clean up the data in zerver_attachment that is not # relevant to our many-to-many import. diff --git a/zerver/lib/i18n.py b/zerver/lib/i18n.py index fab1c3f108..f3452d8140 100644 --- a/zerver/lib/i18n.py +++ b/zerver/lib/i18n.py @@ -12,8 +12,7 @@ from typing import Any, List, Dict, Optional, Text import os import ujson -def with_language(string, language): - # type: (Text, Text) -> Text +def with_language(string: Text, language: Text) -> Text: """ This is an expensive function. If you are using it in a loop, it will make your code slow. @@ -25,15 +24,13 @@ def with_language(string, language): return result @lru_cache() -def get_language_list(): - # type: () -> List[Dict[str, Any]] +def get_language_list() -> List[Dict[str, Any]]: path = os.path.join(settings.STATIC_ROOT, 'locale', 'language_name_map.json') with open(path, 'r') as reader: languages = ujson.load(reader) return languages['name_map'] -def get_language_list_for_templates(default_language): - # type: (Text) -> List[Dict[str, Dict[str, str]]] +def get_language_list_for_templates(default_language: Text) -> List[Dict[str, Dict[str, str]]]: language_list = [l for l in get_language_list() if 'percent_translated' not in l or l['percent_translated'] >= 5.] @@ -70,15 +67,13 @@ def get_language_list_for_templates(default_language): return formatted_list -def get_language_name(code): - # type: (str) -> Optional[Text] +def get_language_name(code: str) -> Optional[Text]: for lang in get_language_list(): if code in (lang['code'], lang['locale']): return lang['name'] return None -def get_available_language_codes(): - # type: () -> List[Text] +def get_available_language_codes() -> List[Text]: language_list = get_language_list() codes = [language['code'] for language in language_list] return codes diff --git a/zerver/lib/logging_util.py b/zerver/lib/logging_util.py index 5d03ed3c34..a8d8d7fa4e 100644 --- a/zerver/lib/logging_util.py +++ b/zerver/lib/logging_util.py @@ -17,8 +17,7 @@ from logging import Logger class _RateLimitFilter: last_error = datetime.min.replace(tzinfo=timezone_utc) - def filter(self, record): - # type: (logging.LogRecord) -> bool + def filter(self, record: logging.LogRecord) -> bool: from django.conf import settings from django.core.cache import cache @@ -58,23 +57,19 @@ class EmailLimiter(_RateLimitFilter): pass class ReturnTrue(logging.Filter): - def filter(self, record): - # type: (logging.LogRecord) -> bool + def filter(self, record: logging.LogRecord) -> bool: return True class ReturnEnabled(logging.Filter): - def filter(self, record): - # type: (logging.LogRecord) -> bool + def filter(self, record: logging.LogRecord) -> bool: return settings.LOGGING_NOT_DISABLED class RequireReallyDeployed(logging.Filter): - def filter(self, record): - # type: (logging.LogRecord) -> bool + def filter(self, record: logging.LogRecord) -> bool: from django.conf import settings return settings.PRODUCTION -def skip_200_and_304(record): - # type: (logging.LogRecord) -> bool +def skip_200_and_304(record: logging.LogRecord) -> bool: # Apparently, `status_code` is added by Django and is not an actual # attribute of LogRecord; as a result, mypy throws an error if we # access the `status_code` attribute directly. @@ -91,8 +86,7 @@ IGNORABLE_404_URLS = [ re.compile(r'^/wp-login.php$'), ] -def skip_boring_404s(record): - # type: (logging.LogRecord) -> bool +def skip_boring_404s(record: logging.LogRecord) -> bool: """Prevents Django's 'Not Found' warnings from being logged for common 404 errors that don't reflect a problem in Zulip. The overall result is to keep the Zulip error logs cleaner than they would @@ -116,8 +110,7 @@ def skip_boring_404s(record): return False return True -def skip_site_packages_logs(record): - # type: (logging.LogRecord) -> bool +def skip_site_packages_logs(record: logging.LogRecord) -> bool: # This skips the log records that are generated from libraries # installed in site packages. # Workaround for https://code.djangoproject.com/ticket/26886 @@ -125,8 +118,7 @@ def skip_site_packages_logs(record): return False return True -def find_log_caller_module(record): - # type: (logging.LogRecord) -> Optional[str] +def find_log_caller_module(record: logging.LogRecord) -> Optional[str]: '''Find the module name corresponding to where this record was logged.''' # Repeat a search similar to that in logging.Logger.findCaller. # The logging call should still be on the stack somewhere; search until @@ -144,8 +136,7 @@ logger_nicknames = { 'zulip.requests': 'zr', # Super common. } -def find_log_origin(record): - # type: (logging.LogRecord) -> str +def find_log_origin(record: logging.LogRecord) -> str: logger_name = logger_nicknames.get(record.name, record.name) if settings.LOGGING_SHOW_MODULE: @@ -166,8 +157,7 @@ log_level_abbrevs = { 'CRITICAL': 'CRIT', } -def abbrev_log_levelname(levelname): - # type: (str) -> str +def abbrev_log_levelname(levelname: str) -> str: # It's unlikely someone will set a custom log level with a custom name, # but it's an option, so we shouldn't crash if someone does. return log_level_abbrevs.get(levelname, levelname[:4]) @@ -176,20 +166,17 @@ class ZulipFormatter(logging.Formatter): # Used in the base implementation. Default uses `,`. default_msec_format = '%s.%03d' - def __init__(self): - # type: () -> None + def __init__(self) -> None: super().__init__(fmt=self._compute_fmt()) - def _compute_fmt(self): - # type: () -> str + def _compute_fmt(self) -> str: pieces = ['%(asctime)s', '%(zulip_level_abbrev)-4s'] if settings.LOGGING_SHOW_PID: pieces.append('pid:%(process)d') pieces.extend(['[%(zulip_origin)s]', '%(message)s']) return ' '.join(pieces) - def format(self, record): - # type: (logging.LogRecord) -> str + def format(self, record: logging.LogRecord) -> str: if not getattr(record, 'zulip_decorated', False): # The `setattr` calls put this logic explicitly outside the bounds of the # type system; otherwise mypy would complain LogRecord lacks these attributes. @@ -198,8 +185,10 @@ class ZulipFormatter(logging.Formatter): setattr(record, 'zulip_decorated', True) return super().format(record) -def create_logger(name, log_file, log_level, log_format="%(asctime)s %(levelname)-8s %(message)s"): - # type: (str, str, str, str) -> Logger +def create_logger(name: str, + log_file: str, + log_level: str, + log_format: str="%(asctime)s%(levelname)-8s%(message)s") -> Logger: """Creates a named logger for use in logging content to a certain file. A few notes: diff --git a/zerver/lib/mention.py b/zerver/lib/mention.py index f581ab01c6..89fe06b765 100644 --- a/zerver/lib/mention.py +++ b/zerver/lib/mention.py @@ -10,12 +10,10 @@ user_group_mentions = r'(? bool +def user_mention_matches_wildcard(mention: Text) -> bool: return mention in wildcards -def extract_name(s): - # type: (Text) -> Optional[Text] +def extract_name(s: Text) -> Optional[Text]: if s.startswith("**") and s.endswith("**"): name = s[2:-2] if name in wildcards: @@ -25,18 +23,15 @@ def extract_name(s): # We don't care about @all or @everyone return None -def possible_mentions(content): - # type: (Text) -> Set[Text] +def possible_mentions(content: Text) -> Set[Text]: matches = re.findall(find_mentions, content) names_with_none = (extract_name(match) for match in matches) names = {name for name in names_with_none if name} return names -def extract_user_group(matched_text): - # type: (Text) -> Text +def extract_user_group(matched_text: Text) -> Text: return matched_text[1:-1] -def possible_user_group_mentions(content): - # type: (Text) -> Set[Text] +def possible_user_group_mentions(content: Text) -> Set[Text]: matches = re.findall(user_group_mentions, content) return {extract_user_group(match) for match in matches} diff --git a/zerver/lib/message.py b/zerver/lib/message.py index 1657c9ff9f..3295404943 100644 --- a/zerver/lib/message.py +++ b/zerver/lib/message.py @@ -97,9 +97,8 @@ def messages_for_ids(message_ids: List[int], return message_list - -def sew_messages_and_reactions(messages, reactions): - # type: (List[Dict[str, Any]], List[Dict[str, Any]]) -> List[Dict[str, Any]] +def sew_messages_and_reactions(messages: List[Dict[str, Any]], + reactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]: """Given a iterable of messages and reactions stitch reactions into messages. """ @@ -117,23 +116,19 @@ def sew_messages_and_reactions(messages, reactions): return list(converted_messages.values()) -def extract_message_dict(message_bytes): - # type: (bytes) -> Dict[str, Any] +def extract_message_dict(message_bytes: bytes) -> Dict[str, Any]: return ujson.loads(zlib.decompress(message_bytes).decode("utf-8")) -def stringify_message_dict(message_dict): - # type: (Dict[str, Any]) -> bytes +def stringify_message_dict(message_dict: Dict[str, Any]) -> bytes: return zlib.compress(ujson.dumps(message_dict).encode()) @cache_with_key(to_dict_cache_key, timeout=3600*24) -def message_to_dict_json(message): - # type: (Message) -> bytes +def message_to_dict_json(message: Message) -> bytes: return MessageDict.to_dict_uncached(message) class MessageDict: @staticmethod - def wide_dict(message): - # type: (Message) -> Dict[str, Any] + def wide_dict(message: Message) -> Dict[str, Any]: ''' The next two lines get the cachable field related to our message object, with the side effect of @@ -154,8 +149,7 @@ class MessageDict: return obj @staticmethod - def post_process_dicts(objs, apply_markdown, client_gravatar): - # type: (List[Dict[str, Any]], bool, bool) -> None + def post_process_dicts(objs: List[Dict[str, Any]], apply_markdown: bool, client_gravatar: bool) -> None: MessageDict.bulk_hydrate_sender_info(objs) for obj in objs: @@ -163,10 +157,10 @@ class MessageDict: MessageDict.finalize_payload(obj, apply_markdown, client_gravatar) @staticmethod - def finalize_payload(obj, apply_markdown, client_gravatar): - # type: (Dict[str, Any], bool, bool) -> None + def finalize_payload(obj: Dict[str, Any], + apply_markdown: bool, + client_gravatar: bool) -> None: MessageDict.set_sender_avatar(obj, client_gravatar) - if apply_markdown: obj['content_type'] = 'text/html' obj['content'] = obj['rendered_content'] @@ -184,14 +178,12 @@ class MessageDict: del obj['sender_is_mirror_dummy'] @staticmethod - def to_dict_uncached(message): - # type: (Message) -> bytes + def to_dict_uncached(message: Message) -> bytes: dct = MessageDict.to_dict_uncached_helper(message) return stringify_message_dict(dct) @staticmethod - def to_dict_uncached_helper(message): - # type: (Message) -> Dict[str, Any] + def to_dict_uncached_helper(message: Message) -> Dict[str, Any]: return MessageDict.build_message_dict( message = message, message_id = message.id, @@ -212,8 +204,7 @@ class MessageDict: ) @staticmethod - def get_raw_db_rows(needed_ids): - # type: (List[int]) -> List[Dict[str, Any]] + def get_raw_db_rows(needed_ids: List[int]) -> List[Dict[str, Any]]: # This is a special purpose function optimized for # callers like get_messages_backend(). fields = [ @@ -242,8 +233,7 @@ class MessageDict: return sew_messages_and_reactions(messages, reactions) @staticmethod - def build_dict_from_raw_db_row(row): - # type: (Dict[str, Any]) -> Dict[str, Any] + def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]: ''' row is a row from a .values() call, and it needs to have all the relevant fields populated @@ -352,8 +342,7 @@ class MessageDict: return obj @staticmethod - def bulk_hydrate_sender_info(objs): - # type: (List[Dict[str, Any]]) -> None + def bulk_hydrate_sender_info(objs: List[Dict[str, Any]]) -> None: sender_ids = list({ obj['sender_id'] @@ -393,8 +382,7 @@ class MessageDict: obj['sender_is_mirror_dummy'] = user_row['is_mirror_dummy'] @staticmethod - def hydrate_recipient_info(obj): - # type: (Dict[str, Any]) -> None + def hydrate_recipient_info(obj: Dict[str, Any]) -> None: ''' This method hyrdrates recipient info with things like full names and emails of senders. Eventually @@ -437,8 +425,7 @@ class MessageDict: obj['stream_id'] = recipient_type_id @staticmethod - def set_sender_avatar(obj, client_gravatar): - # type: (Dict[str, Any], bool) -> None + def set_sender_avatar(obj: Dict[str, Any], client_gravatar: bool) -> None: sender_id = obj['sender_id'] sender_realm_id = obj['sender_realm_id'] sender_email = obj['sender_email'] @@ -457,8 +444,7 @@ class MessageDict: class ReactionDict: @staticmethod - def build_dict_from_raw_db_row(row): - # type: (Dict[str, Any]) -> Dict[str, Any] + def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]: return {'emoji_name': row['emoji_name'], 'emoji_code': row['emoji_code'], 'reaction_type': row['reaction_type'], @@ -467,8 +453,7 @@ class ReactionDict: 'full_name': row['user_profile__full_name']}} -def access_message(user_profile, message_id): - # type: (UserProfile, int) -> Tuple[Message, UserMessage] +def access_message(user_profile: UserProfile, message_id: int) -> Tuple[Message, UserMessage]: """You can access a message by ID in our APIs that either: (1) You received or have previously accessed via starring (aka have a UserMessage row for). @@ -506,9 +491,13 @@ def access_message(user_profile, message_id): # stream in your realm, so return the message, user_message pair return (message, user_message) -def render_markdown(message, content, realm=None, realm_alert_words=None, user_ids=None, - mention_data=None, email_gateway=False): - # type: (Message, Text, Optional[Realm], Optional[RealmAlertWords], Optional[Set[int]], Optional[bugdown.MentionData], Optional[bool]) -> Text +def render_markdown(message: Message, + content: Text, + realm: Optional[Realm]=None, + realm_alert_words: Optional[RealmAlertWords]=None, + user_ids: Optional[Set[int]]=None, + mention_data: Optional[bugdown.MentionData]=None, + email_gateway: Optional[bool]=False) -> Text: """Return HTML for given markdown. Bugdown may add properties to the message object such as `mentions_user_ids`, `mentions_user_group_ids`, and `mentions_wildcard`. These are only on this Django object and are not @@ -565,8 +554,7 @@ def render_markdown(message, content, realm=None, realm_alert_words=None, user_i return rendered_content -def huddle_users(recipient_id): - # type: (int) -> str +def huddle_users(recipient_id: int) -> str: display_recipient = get_display_recipient_by_id(recipient_id, Recipient.HUDDLE, None) # type: Union[Text, List[Dict[str, Any]]] @@ -578,8 +566,9 @@ def huddle_users(recipient_id): user_ids = sorted(user_ids) return ','.join(str(uid) for uid in user_ids) -def aggregate_message_dict(input_dict, lookup_fields, collect_senders): - # type: (Dict[int, Dict[str, Any]], List[str], bool) -> List[Dict[str, Any]] +def aggregate_message_dict(input_dict: Dict[int, Dict[str, Any]], + lookup_fields: List[str], + collect_senders: bool) -> List[Dict[str, Any]]: lookup_dict = dict() # type: Dict[Tuple[Any, ...], Dict[str, Any]] ''' @@ -639,8 +628,7 @@ def aggregate_message_dict(input_dict, lookup_fields, collect_senders): return [lookup_dict[k] for k in sorted_keys] -def get_inactive_recipient_ids(user_profile): - # type: (UserProfile) -> List[int] +def get_inactive_recipient_ids(user_profile: UserProfile) -> List[int]: rows = get_stream_subscriptions_for_user(user_profile).filter( active=False, ).values( @@ -651,8 +639,7 @@ def get_inactive_recipient_ids(user_profile): for row in rows] return inactive_recipient_ids -def get_muted_stream_ids(user_profile): - # type: (UserProfile) -> List[int] +def get_muted_stream_ids(user_profile: UserProfile) -> List[int]: rows = get_stream_subscriptions_for_user(user_profile).filter( active=True, in_home_view=False, @@ -664,8 +651,7 @@ def get_muted_stream_ids(user_profile): for row in rows] return muted_stream_ids -def get_raw_unread_data(user_profile): - # type: (UserProfile) -> RawUnreadMessagesResult +def get_raw_unread_data(user_profile: UserProfile) -> RawUnreadMessagesResult: excluded_recipient_ids = get_inactive_recipient_ids(user_profile) @@ -694,8 +680,7 @@ def get_raw_unread_data(user_profile): topic_mute_checker = build_topic_mute_checker(user_profile) - def is_row_muted(stream_id, recipient_id, topic): - # type: (int, int, Text) -> bool + def is_row_muted(stream_id: int, recipient_id: int, topic: Text) -> bool: if stream_id in muted_stream_ids: return True @@ -706,8 +691,7 @@ def get_raw_unread_data(user_profile): huddle_cache = {} # type: Dict[int, str] - def get_huddle_users(recipient_id): - # type: (int) -> str + def get_huddle_users(recipient_id: int) -> str: if recipient_id in huddle_cache: return huddle_cache[recipient_id] @@ -762,8 +746,7 @@ def get_raw_unread_data(user_profile): mentions=mentions, ) -def aggregate_unread_data(raw_data): - # type: (RawUnreadMessagesResult) -> UnreadMessagesResult +def aggregate_unread_data(raw_data: RawUnreadMessagesResult) -> UnreadMessagesResult: pm_dict = raw_data['pm_dict'] stream_dict = raw_data['stream_dict'] @@ -807,8 +790,10 @@ def aggregate_unread_data(raw_data): return result -def apply_unread_message_event(user_profile, state, message, flags): - # type: (UserProfile, Dict[str, Any], Dict[str, Any], List[str]) -> None +def apply_unread_message_event(user_profile: UserProfile, + state: Dict[str, Any], + message: Dict[str, Any], + flags: List[str]) -> None: message_id = message['id'] if message['type'] == 'stream': message_type = 'stream' diff --git a/zerver/lib/migrate.py b/zerver/lib/migrate.py index c61e70500f..6145667770 100644 --- a/zerver/lib/migrate.py +++ b/zerver/lib/migrate.py @@ -3,8 +3,7 @@ from django.db.models.query import QuerySet import re import time -def timed_ddl(db, stmt): - # type: (Any, str) -> None +def timed_ddl(db: Any, stmt: str) -> None: print() print(time.asctime()) print(stmt) @@ -13,14 +12,17 @@ def timed_ddl(db, stmt): delay = time.time() - t print('Took %.2fs' % (delay,)) -def validate(sql_thingy): - # type: (str) -> None +def validate(sql_thingy: str) -> None: # Do basic validation that table/col name is safe. if not re.match('^[a-z][a-z\d_]+$', sql_thingy): raise Exception('Invalid SQL object: %s' % (sql_thingy,)) -def do_batch_update(db, table, cols, vals, batch_size=10000, sleep=0.1): - # type: (Any, str, List[str], List[str], int, float) -> None +def do_batch_update(db: Any, + table: str, + cols: List[str], + vals: List[str], + batch_size: int=10000, + sleep: float=0.1) -> None: validate(table) for col in cols: validate(col) @@ -46,8 +48,7 @@ def do_batch_update(db, table, cols, vals, batch_size=10000, sleep=0.1): min_id = upper time.sleep(sleep) -def add_bool_columns(db, table, cols): - # type: (Any, str, List[str]) -> None +def add_bool_columns(db: Any, table: str, cols: List[str]) -> None: validate(table) for col in cols: validate(col) @@ -72,8 +73,8 @@ def add_bool_columns(db, table, cols): ', '.join(['ALTER %s SET NOT NULL' % (col,) for col in cols])) timed_ddl(db, stmt) -def create_index_if_not_exist(index_name, table_name, column_string, where_clause): - # type: (Text, Text, Text, Text) -> Text +def create_index_if_not_exist(index_name: Text, table_name: Text, column_string: Text, + where_clause: Text) -> Text: # # FUTURE TODO: When we no longer need to support postgres 9.3 for Trusty, # we can use "IF NOT EXISTS", which is part of postgres 9.5 @@ -95,8 +96,11 @@ def create_index_if_not_exist(index_name, table_name, column_string, where_claus ''' % (index_name, index_name, table_name, column_string, where_clause) return stmt -def act_on_message_ranges(db, orm, tasks, batch_size=5000, sleep=0.5): - # type: (Any, Dict[str, Any], List[Tuple[Callable[[QuerySet], QuerySet], Callable[[QuerySet], None]]], int , float) -> None +def act_on_message_ranges(db: Any, + orm: Dict[str, Any], + tasks: List[Tuple[Callable[[QuerySet], QuerySet], Callable[[QuerySet], None]]], + batch_size: int=5000, + sleep: float=0.5) -> None: # tasks should be an array of (filterer, action) tuples # where filterer is a function that returns a filtered QuerySet # and action is a function that acts on a QuerySet diff --git a/zerver/lib/narrow.py b/zerver/lib/narrow.py index c60438c654..be17bdef45 100644 --- a/zerver/lib/narrow.py +++ b/zerver/lib/narrow.py @@ -4,21 +4,18 @@ from django.utils.translation import ugettext as _ from typing import Any, Callable, Iterable, Mapping, Sequence, Text -def check_supported_events_narrow_filter(narrow): - # type: (Iterable[Sequence[Text]]) -> None +def check_supported_events_narrow_filter(narrow: Iterable[Sequence[Text]]) -> None: for element in narrow: operator = element[0] if operator not in ["stream", "topic", "sender", "is"]: raise JsonableError(_("Operator %s not supported.") % (operator,)) -def build_narrow_filter(narrow): - # type: (Iterable[Sequence[Text]]) -> Callable[[Mapping[str, Any]], bool] +def build_narrow_filter(narrow: Iterable[Sequence[Text]]) -> Callable[[Mapping[str, Any]], bool]: """Changes to this function should come with corresponding changes to BuildNarrowFilterTest.""" check_supported_events_narrow_filter(narrow) - def narrow_filter(event): - # type: (Mapping[str, Any]) -> bool + def narrow_filter(event: Mapping[str, Any]) -> bool: message = event["message"] flags = event["flags"] for element in narrow: diff --git a/zerver/lib/outgoing_webhook.py b/zerver/lib/outgoing_webhook.py index 4994db3ffd..55e6370b27 100644 --- a/zerver/lib/outgoing_webhook.py +++ b/zerver/lib/outgoing_webhook.py @@ -21,8 +21,7 @@ from zerver.decorator import JsonableError class OutgoingWebhookServiceInterface: - def __init__(self, base_url, token, user_profile, service_name): - # type: (Text, Text, UserProfile, Text) -> None + def __init__(self, base_url: Text, token: Text, user_profile: UserProfile, service_name: Text) -> None: self.base_url = base_url # type: Text self.token = token # type: Text self.user_profile = user_profile # type: Text @@ -37,20 +36,17 @@ class OutgoingWebhookServiceInterface: # - base_url # - relative_url_path # - request_kwargs - def process_event(self, event): - # type: (Dict[Text, Any]) -> Tuple[Dict[str, Any], Any] + def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]: raise NotImplementedError() # Given a successful outgoing webhook REST operation, returns the message # to sent back to the user (or None if no message should be sent). - def process_success(self, response, event): - # type: (Response, Dict[Text, Any]) -> Optional[str] + def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]: raise NotImplementedError() class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface): - def process_event(self, event): - # type: (Dict[Text, Any]) -> Tuple[Dict[str, Any], Any] + def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]: rest_operation = {'method': 'POST', 'relative_url_path': '', 'base_url': self.base_url, @@ -60,8 +56,7 @@ class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface): "token": self.token} return rest_operation, json.dumps(request_data) - def process_success(self, response, event): - # type: (Response, Dict[Text, Any]) -> Optional[str] + def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]: response_json = json.loads(response.text) if "response_not_required" in response_json and response_json['response_not_required']: @@ -73,8 +68,7 @@ class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface): class SlackOutgoingWebhookService(OutgoingWebhookServiceInterface): - def process_event(self, event): - # type: (Dict[Text, Any]) -> Tuple[Dict[str, Any], Any] + def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]: rest_operation = {'method': 'POST', 'relative_url_path': '', 'base_url': self.base_url, @@ -99,8 +93,7 @@ class SlackOutgoingWebhookService(OutgoingWebhookServiceInterface): return rest_operation, request_data - def process_success(self, response, event): - # type: (Response, Dict[Text, Any]) -> Optional[str] + def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]: response_json = json.loads(response.text) if "text" in response_json: return response_json["text"] @@ -112,15 +105,13 @@ AVAILABLE_OUTGOING_WEBHOOK_INTERFACES = { SLACK_INTERFACE: SlackOutgoingWebhookService, } # type: Dict[Text, Any] -def get_service_interface_class(interface): - # type: (Text) -> Any +def get_service_interface_class(interface: Text) -> Any: if interface is None or interface not in AVAILABLE_OUTGOING_WEBHOOK_INTERFACES: return AVAILABLE_OUTGOING_WEBHOOK_INTERFACES[GENERIC_INTERFACE] else: return AVAILABLE_OUTGOING_WEBHOOK_INTERFACES[interface] -def get_outgoing_webhook_service_handler(service): - # type: (Service) -> Any +def get_outgoing_webhook_service_handler(service: Service) -> Any: service_interface_class = get_service_interface_class(service.interface_name()) service_interface = service_interface_class(base_url=service.base_url, @@ -129,8 +120,7 @@ def get_outgoing_webhook_service_handler(service): service_name=service.name) return service_interface -def send_response_message(bot_id, message, response_message_content): - # type: (str, Dict[str, Any], Text) -> None +def send_response_message(bot_id: str, message: Dict[str, Any], response_message_content: Text) -> None: recipient_type_name = message['type'] bot_user = get_user_profile_by_id(bot_id) realm = bot_user.realm @@ -146,18 +136,15 @@ def send_response_message(bot_id, message, response_message_content): else: raise JsonableError(_("Invalid message type")) -def succeed_with_message(event, success_message): - # type: (Dict[str, Any], Text) -> None +def succeed_with_message(event: Dict[str, Any], success_message: Text) -> None: success_message = "Success! " + success_message send_response_message(event['user_profile_id'], event['message'], success_message) -def fail_with_message(event, failure_message): - # type: (Dict[str, Any], Text) -> None +def fail_with_message(event: Dict[str, Any], failure_message: Text) -> None: failure_message = "Failure! " + failure_message send_response_message(event['user_profile_id'], event['message'], failure_message) -def get_message_url(event, request_data): - # type: (Dict[str, Any], Dict[str, Any]) -> Text +def get_message_url(event: Dict[str, Any], request_data: Dict[str, Any]) -> Text: bot_user = get_user_profile_by_id(event['user_profile_id']) message = event['message'] if message['type'] == 'stream': @@ -175,8 +162,11 @@ def get_message_url(event, request_data): 'id': str(message['id'])}) return message_url -def notify_bot_owner(event, request_data, status_code=None, response_content=None, exception=None): - # type: (Dict[str, Any], Dict[str, Any], Optional[int], Optional[AnyStr], Optional[Exception]) -> None +def notify_bot_owner(event: Dict[str, Any], + request_data: Dict[str, Any], + status_code: Optional[int]=None, + response_content: Optional[AnyStr]=None, + exception: Optional[Exception]=None) -> None: message_url = get_message_url(event, request_data) bot_id = event['user_profile_id'] bot_owner = get_user_profile_by_id(bot_id).bot_owner @@ -194,10 +184,11 @@ def notify_bot_owner(event, request_data, status_code=None, response_content=Non type(exception).__name__, str(exception)) send_response_message(bot_id, message_info, notification_message) -def request_retry(event, request_data, failure_message, exception=None): - # type: (Dict[str, Any], Dict[str, Any], Text, Optional[Exception]) -> None - def failure_processor(event): - # type: (Dict[str, Any]) -> None +def request_retry(event: Dict[str, Any], + request_data: Dict[str, Any], + failure_message: Text, + exception: Optional[Exception]=None) -> None: + def failure_processor(event: Dict[str, Any]) -> None: """ The name of the argument is 'event' on purpose. This argument will hide the 'event' argument of the request_retry function. Keeping the same name @@ -211,8 +202,11 @@ def request_retry(event, request_data, failure_message, exception=None): retry_event('outgoing_webhooks', event, failure_processor) -def do_rest_call(rest_operation, request_data, event, service_handler, timeout=None): - # type: (Dict[str, Any], Optional[Dict[str, Any]], Dict[str, Any], Any, Any) -> None +def do_rest_call(rest_operation: Dict[str, Any], + request_data: Optional[Dict[str, Any]], + event: Dict[str, Any], + service_handler: Any, + timeout: Any=None) -> None: rest_operation_validator = check_dict([ ('method', check_string), ('relative_url_path', check_string), diff --git a/zerver/lib/rate_limiter.py b/zerver/lib/rate_limiter.py index 073dc7afee..7fa764ae4b 100644 --- a/zerver/lib/rate_limiter.py +++ b/zerver/lib/rate_limiter.py @@ -21,32 +21,26 @@ rules = settings.RATE_LIMITING_RULES # type: List[Tuple[int, int]] KEY_PREFIX = '' class RateLimitedObject: - def get_keys(self): - # type: () -> List[Text] + def get_keys(self) -> List[Text]: key_fragment = self.key_fragment() return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key_fragment, keytype) for keytype in ['list', 'zset', 'block']] - def key_fragment(self): - # type: () -> Text + def key_fragment(self) -> Text: raise NotImplementedError() - def rules(self): - # type: () -> List[Tuple[int, int]] + def rules(self) -> List[Tuple[int, int]]: raise NotImplementedError() class RateLimitedUser(RateLimitedObject): - def __init__(self, user, domain='all'): - # type: (UserProfile, Text) -> None + def __init__(self, user: UserProfile, domain: Text='all') -> None: self.user = user self.domain = domain - def key_fragment(self): - # type: () -> Text + def key_fragment(self) -> Text: return "{}:{}:{}".format(type(self.user), self.user.id, self.domain) - def rules(self): - # type: () -> List[Tuple[int, int]] + def rules(self) -> List[Tuple[int, int]]: if self.user.rate_limits != "": result = [] # type: List[Tuple[int, int]] for limit in self.user.rate_limits.split(','): @@ -55,36 +49,30 @@ class RateLimitedUser(RateLimitedObject): return result return rules -def bounce_redis_key_prefix_for_testing(test_name): - # type: (Text) -> None +def bounce_redis_key_prefix_for_testing(test_name: Text) -> None: global KEY_PREFIX KEY_PREFIX = test_name + ':' + Text(os.getpid()) + ':' -def max_api_calls(entity): - # type: (RateLimitedObject) -> int +def max_api_calls(entity: RateLimitedObject) -> int: "Returns the API rate limit for the highest limit" return entity.rules()[-1][1] -def max_api_window(entity): - # type: (RateLimitedObject) -> int +def max_api_window(entity: RateLimitedObject) -> int: "Returns the API time window for the highest limit" return entity.rules()[-1][0] -def add_ratelimit_rule(range_seconds, num_requests): - # type: (int , int) -> None +def add_ratelimit_rule(range_seconds: int, num_requests: int) -> None: "Add a rate-limiting rule to the ratelimiter" global rules rules.append((range_seconds, num_requests)) rules.sort(key=lambda x: x[0]) -def remove_ratelimit_rule(range_seconds, num_requests): - # type: (int , int) -> None +def remove_ratelimit_rule(range_seconds: int, num_requests: int) -> None: global rules rules = [x for x in rules if x[0] != range_seconds and x[1] != num_requests] -def block_access(entity, seconds): - # type: (RateLimitedObject, int) -> None +def block_access(entity: RateLimitedObject, seconds: int) -> None: "Manually blocks an entity for the desired number of seconds" _, _, blocking_key = entity.get_keys() with client.pipeline() as pipe: @@ -92,13 +80,11 @@ def block_access(entity, seconds): pipe.expire(blocking_key, seconds) pipe.execute() -def unblock_access(entity): - # type: (RateLimitedObject) -> None +def unblock_access(entity: RateLimitedObject) -> None: _, _, blocking_key = entity.get_keys() client.delete(blocking_key) -def clear_history(entity): - # type: (RateLimitedObject) -> None +def clear_history(entity: RateLimitedObject) -> None: ''' This is only used by test code now, where it's very helpful in allowing us to run tests quickly, by giving a user a clean slate. @@ -106,8 +92,7 @@ def clear_history(entity): for key in entity.get_keys(): client.delete(key) -def _get_api_calls_left(entity, range_seconds, max_calls): - # type: (RateLimitedObject, int, int) -> Tuple[int, float] +def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls: int) -> Tuple[int, float]: list_key, set_key, _ = entity.get_keys() # Count the number of values in our sorted set # that are between now and the cutoff @@ -134,16 +119,14 @@ def _get_api_calls_left(entity, range_seconds, max_calls): return calls_left, time_reset -def api_calls_left(entity): - # type: (RateLimitedObject) -> Tuple[int, float] +def api_calls_left(entity: RateLimitedObject) -> Tuple[int, float]: """Returns how many API calls in this range this client has, as well as when the rate-limit will be reset to 0""" max_window = max_api_window(entity) max_calls = max_api_calls(entity) return _get_api_calls_left(entity, max_window, max_calls) -def is_ratelimited(entity): - # type: (RateLimitedObject) -> Tuple[bool, float] +def is_ratelimited(entity: RateLimitedObject) -> Tuple[bool, float]: "Returns a tuple of (rate_limited, time_till_free)" list_key, set_key, blocking_key = entity.get_keys() @@ -192,8 +175,7 @@ def is_ratelimited(entity): # No api calls recorded yet return False, 0.0 -def incr_ratelimit(entity): - # type: (RateLimitedObject) -> None +def incr_ratelimit(entity: RateLimitedObject) -> None: """Increases the rate-limit for the specified entity""" list_key, set_key, _ = entity.get_keys() now = time.time() diff --git a/zerver/lib/rest.py b/zerver/lib/rest.py index cb7bfde86f..05f7293770 100644 --- a/zerver/lib/rest.py +++ b/zerver/lib/rest.py @@ -15,8 +15,7 @@ METHODS = ('GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'PATCH') FLAGS = ('override_api_url_scheme') @csrf_exempt -def rest_dispatch(request, **kwargs): - # type: (HttpRequest, **Any) -> HttpResponse +def rest_dispatch(request: HttpRequest, **kwargs: Any) -> HttpResponse: """Dispatch to a REST API endpoint. Unauthenticated endpoints should not use this, as authentication is verified diff --git a/zerver/lib/send_email.py b/zerver/lib/send_email.py index b09adc0870..7034872f1b 100644 --- a/zerver/lib/send_email.py +++ b/zerver/lib/send_email.py @@ -93,8 +93,7 @@ def send_email(template_prefix, to_user_id=None, to_email=None, from_name=None, logger.error("Error sending %s email to %s" % (template, mail.to)) raise EmailNotDeliveredException -def send_email_from_dict(email_dict): - # type: (Mapping[str, Any]) -> None +def send_email_from_dict(email_dict: Mapping[str, Any]) -> None: send_email(**dict(email_dict)) def send_future_email(template_prefix, to_user_id=None, to_email=None, from_name=None, diff --git a/zerver/lib/soft_deactivation.py b/zerver/lib/soft_deactivation.py index 7e766b07b2..0839e376a0 100644 --- a/zerver/lib/soft_deactivation.py +++ b/zerver/lib/soft_deactivation.py @@ -17,8 +17,7 @@ def filter_by_subscription_history( # type: (UserProfile, DefaultDict[int, List[Message]], DefaultDict[int, List[RealmAuditLog]]) -> List[UserMessage] user_messages_to_insert = [] # type: List[UserMessage] - def store_user_message_to_insert(message): - # type: (Message) -> None + def store_user_message_to_insert(message: Message) -> None: message = UserMessage(user_profile=user_profile, message_id=message['id'], flags=0) user_messages_to_insert.append(message) @@ -60,8 +59,7 @@ def filter_by_subscription_history( store_user_message_to_insert(stream_message) return user_messages_to_insert -def add_missing_messages(user_profile): - # type: (UserProfile) -> None +def add_missing_messages(user_profile: UserProfile) -> None: """This function takes a soft-deactivated user, and computes and adds to the database any UserMessage rows that were not created while the user was soft-deactivated. The end result is that from the @@ -156,8 +154,7 @@ def add_missing_messages(user_profile): if len(user_messages_to_insert) > 0: UserMessage.objects.bulk_create(user_messages_to_insert) -def do_soft_deactivate_user(user_profile): - # type: (UserProfile) -> None +def do_soft_deactivate_user(user_profile: UserProfile) -> None: user_profile.last_active_message_id = UserMessage.objects.filter( user_profile=user_profile).order_by( '-message__id')[0].message_id @@ -168,8 +165,7 @@ def do_soft_deactivate_user(user_profile): logger.info('Soft Deactivated user %s (%s)' % (user_profile.id, user_profile.email)) -def do_soft_deactivate_users(users): - # type: (List[UserProfile]) -> List[UserProfile] +def do_soft_deactivate_users(users: List[UserProfile]) -> List[UserProfile]: users_soft_deactivated = [] with transaction.atomic(): realm_logs = [] @@ -187,8 +183,7 @@ def do_soft_deactivate_users(users): RealmAuditLog.objects.bulk_create(realm_logs) return users_soft_deactivated -def maybe_catch_up_soft_deactivated_user(user_profile): - # type: (UserProfile) -> Union[UserProfile, None] +def maybe_catch_up_soft_deactivated_user(user_profile: UserProfile) -> Union[UserProfile, None]: if user_profile.long_term_idle: add_missing_messages(user_profile) user_profile.long_term_idle = False @@ -204,8 +199,7 @@ def maybe_catch_up_soft_deactivated_user(user_profile): return user_profile return None -def get_users_for_soft_deactivation(inactive_for_days, filter_kwargs): - # type: (int, **Any) -> List[UserProfile] +def get_users_for_soft_deactivation(inactive_for_days: int, filter_kwargs: Any) -> List[UserProfile]: users_activity = list(UserActivity.objects.filter( user_profile__is_active=True, user_profile__is_bot=False, @@ -221,8 +215,7 @@ def get_users_for_soft_deactivation(inactive_for_days, filter_kwargs): id__in=user_ids_to_deactivate)) return users_to_deactivate -def do_soft_activate_users(users): - # type: (List[UserProfile]) -> List[UserProfile] +def do_soft_activate_users(users: List[UserProfile]) -> List[UserProfile]: users_soft_activated = [] for user_profile in users: user_activated = maybe_catch_up_soft_deactivated_user(user_profile) diff --git a/zerver/lib/streams.py b/zerver/lib/streams.py index 7d4c72c751..d023046c54 100644 --- a/zerver/lib/streams.py +++ b/zerver/lib/streams.py @@ -10,8 +10,7 @@ from zerver.models import UserProfile, Stream, Subscription, \ Realm, Recipient, bulk_get_recipients, get_stream_recipient, get_stream, \ bulk_get_streams, get_realm_stream, DefaultStreamGroup -def access_stream_for_delete(user_profile, stream_id): - # type: (UserProfile, int) -> Stream +def access_stream_for_delete(user_profile: UserProfile, stream_id: int) -> Stream: # We should only ever use this for realm admins, who are allowed # to delete all streams on their realm, even private streams to @@ -30,8 +29,8 @@ def access_stream_for_delete(user_profile, stream_id): return stream -def access_stream_common(user_profile, stream, error): - # type: (UserProfile, Stream, Text) -> Tuple[Recipient, Subscription] +def access_stream_common(user_profile: UserProfile, stream: Stream, + error: Text) -> Tuple[Recipient, Subscription]: """Common function for backend code where the target use attempts to access the target stream, returning all the data fetched along the way. If that user does not have permission to access that stream, @@ -63,8 +62,7 @@ def access_stream_common(user_profile, stream, error): # an error. raise JsonableError(error) -def access_stream_by_id(user_profile, stream_id): - # type: (UserProfile, int) -> Tuple[Stream, Recipient, Subscription] +def access_stream_by_id(user_profile: UserProfile, stream_id: int) -> Tuple[Stream, Recipient, Subscription]: error = _("Invalid stream id") try: stream = Stream.objects.get(id=stream_id) @@ -74,8 +72,7 @@ def access_stream_by_id(user_profile, stream_id): (recipient, sub) = access_stream_common(user_profile, stream, error) return (stream, recipient, sub) -def check_stream_name_available(realm, name): - # type: (Realm, Text) -> None +def check_stream_name_available(realm: Realm, name: Text) -> None: check_stream_name(name) try: get_stream(name, realm) @@ -83,8 +80,8 @@ def check_stream_name_available(realm, name): except Stream.DoesNotExist: pass -def access_stream_by_name(user_profile, stream_name): - # type: (UserProfile, Text) -> Tuple[Stream, Recipient, Subscription] +def access_stream_by_name(user_profile: UserProfile, + stream_name: Text) -> Tuple[Stream, Recipient, Subscription]: error = _("Invalid stream name '%s'" % (stream_name,)) try: stream = get_realm_stream(stream_name, user_profile.realm_id) @@ -94,8 +91,7 @@ def access_stream_by_name(user_profile, stream_name): (recipient, sub) = access_stream_common(user_profile, stream, error) return (stream, recipient, sub) -def access_stream_for_unmute_topic(user_profile, stream_name, error): - # type: (UserProfile, Text, Text) -> Stream +def access_stream_for_unmute_topic(user_profile: UserProfile, stream_name: Text, error: Text) -> Stream: """ It may seem a little silly to have this helper function for unmuting topics, but it gets around a linter warning, and it helps to be able @@ -115,8 +111,7 @@ def access_stream_for_unmute_topic(user_profile, stream_name, error): raise JsonableError(error) return stream -def is_public_stream_by_name(stream_name, realm): - # type: (Text, Realm) -> bool +def is_public_stream_by_name(stream_name: Text, realm: Realm) -> bool: """Determine whether a stream is public, so that our caller can decide whether we can get historical messages for a narrowing search. @@ -136,8 +131,8 @@ def is_public_stream_by_name(stream_name, realm): return False return stream.is_public() -def filter_stream_authorization(user_profile, streams): - # type: (UserProfile, Iterable[Stream]) -> Tuple[List[Stream], List[Stream]] +def filter_stream_authorization(user_profile: UserProfile, + streams: Iterable[Stream]) -> Tuple[List[Stream], List[Stream]]: streams_subscribed = set() # type: Set[int] recipients_map = bulk_get_recipients(Recipient.STREAM, [stream.id for stream in streams]) subs = Subscription.objects.filter(user_profile=user_profile, @@ -161,8 +156,9 @@ def filter_stream_authorization(user_profile, streams): stream.id not in set(stream.id for stream in unauthorized_streams)] return authorized_streams, unauthorized_streams -def list_to_streams(streams_raw, user_profile, autocreate=False): - # type: (Iterable[Mapping[str, Any]], UserProfile, bool) -> Tuple[List[Stream], List[Stream]] +def list_to_streams(streams_raw: Iterable[Mapping[str, Any]], + user_profile: UserProfile, + autocreate: bool=False) -> Tuple[List[Stream], List[Stream]]: """Converts list of dicts to a list of Streams, validating input in the process For each stream name, we validate it to ensure it meets our diff --git a/zerver/lib/topic_mutes.py b/zerver/lib/topic_mutes.py index 1b05317491..c8878c5d21 100644 --- a/zerver/lib/topic_mutes.py +++ b/zerver/lib/topic_mutes.py @@ -15,8 +15,7 @@ from sqlalchemy.sql import ( Selectable ) -def get_topic_mutes(user_profile): - # type: (UserProfile) -> List[List[Text]] +def get_topic_mutes(user_profile: UserProfile) -> List[List[Text]]: rows = MutedTopic.objects.filter( user_profile=user_profile, ).values( @@ -28,8 +27,7 @@ def get_topic_mutes(user_profile): for row in rows ] -def set_topic_mutes(user_profile, muted_topics): - # type: (UserProfile, List[List[Text]]) -> None +def set_topic_mutes(user_profile: UserProfile, muted_topics: List[List[Text]]) -> None: ''' This is only used in tests. @@ -50,8 +48,7 @@ def set_topic_mutes(user_profile, muted_topics): topic_name=topic_name, ) -def add_topic_mute(user_profile, stream_id, recipient_id, topic_name): - # type: (UserProfile, int, int, str) -> None +def add_topic_mute(user_profile: UserProfile, stream_id: int, recipient_id: int, topic_name: str) -> None: MutedTopic.objects.create( user_profile=user_profile, stream_id=stream_id, @@ -59,8 +56,7 @@ def add_topic_mute(user_profile, stream_id, recipient_id, topic_name): topic_name=topic_name, ) -def remove_topic_mute(user_profile, stream_id, topic_name): - # type: (UserProfile, int, str) -> None +def remove_topic_mute(user_profile: UserProfile, stream_id: int, topic_name: str) -> None: row = MutedTopic.objects.get( user_profile=user_profile, stream_id=stream_id, @@ -68,8 +64,7 @@ def remove_topic_mute(user_profile, stream_id, topic_name): ) row.delete() -def topic_is_muted(user_profile, stream_id, topic_name): - # type: (UserProfile, int, Text) -> bool +def topic_is_muted(user_profile: UserProfile, stream_id: int, topic_name: Text) -> bool: is_muted = MutedTopic.objects.filter( user_profile=user_profile, stream_id=stream_id, @@ -77,8 +72,9 @@ def topic_is_muted(user_profile, stream_id, topic_name): ).exists() return is_muted -def exclude_topic_mutes(conditions, user_profile, stream_id): - # type: (List[Selectable], UserProfile, Optional[int]) -> List[Selectable] +def exclude_topic_mutes(conditions: List[Selectable], + user_profile: UserProfile, + stream_id: Optional[int]) -> List[Selectable]: query = MutedTopic.objects.filter( user_profile=user_profile, ) @@ -97,8 +93,7 @@ def exclude_topic_mutes(conditions, user_profile, stream_id): if not rows: return conditions - def mute_cond(row): - # type: (Dict[str, Any]) -> Selectable + def mute_cond(row: Dict[str, Any]) -> Selectable: recipient_id = row['recipient_id'] topic_name = row['topic_name'] stream_cond = column("recipient_id") == recipient_id @@ -108,8 +103,7 @@ def exclude_topic_mutes(conditions, user_profile, stream_id): condition = not_(or_(*list(map(mute_cond, rows)))) return conditions + [condition] -def build_topic_mute_checker(user_profile): - # type: (UserProfile) -> Callable[[int, Text], bool] +def build_topic_mute_checker(user_profile: UserProfile) -> Callable[[int, Text], bool]: rows = MutedTopic.objects.filter( user_profile=user_profile, ).values( @@ -124,8 +118,7 @@ def build_topic_mute_checker(user_profile): topic_name = row['topic_name'] tups.add((recipient_id, topic_name.lower())) - def is_muted(recipient_id, topic): - # type: (int, Text) -> bool + def is_muted(recipient_id: int, topic: Text) -> bool: return (recipient_id, topic.lower()) in tups return is_muted diff --git a/zerver/lib/url_preview/oembed/__init__.py b/zerver/lib/url_preview/oembed/__init__.py index bc4bc09344..c9c5f186d1 100644 --- a/zerver/lib/url_preview/oembed/__init__.py +++ b/zerver/lib/url_preview/oembed/__init__.py @@ -2,8 +2,9 @@ from typing import Optional, Text, Dict, Any from pyoembed import oEmbed, PyOembedException -def get_oembed_data(url, maxwidth=640, maxheight=480): - # type: (Text, Optional[int], Optional[int]) -> Optional[Dict[Any, Any]] +def get_oembed_data(url: Text, + maxwidth: Optional[int]=640, + maxheight: Optional[int]=480) -> Optional[Dict[Any, Any]]: try: data = oEmbed(url, maxwidth=maxwidth, maxheight=maxheight) except PyOembedException: diff --git a/zerver/lib/url_preview/parsers/base.py b/zerver/lib/url_preview/parsers/base.py index 6b09119bd0..1888802952 100644 --- a/zerver/lib/url_preview/parsers/base.py +++ b/zerver/lib/url_preview/parsers/base.py @@ -3,10 +3,8 @@ from bs4 import BeautifulSoup class BaseParser: - def __init__(self, html_source): - # type: (Text) -> None + def __init__(self, html_source: Text) -> None: self._soup = BeautifulSoup(html_source, "lxml") - def extract_data(self): - # type: () -> Any + def extract_data(self) -> Any: raise NotImplementedError() diff --git a/zerver/lib/url_preview/parsers/generic.py b/zerver/lib/url_preview/parsers/generic.py index 89bd7bb39e..d49a0c6651 100644 --- a/zerver/lib/url_preview/parsers/generic.py +++ b/zerver/lib/url_preview/parsers/generic.py @@ -3,15 +3,13 @@ from zerver.lib.url_preview.parsers.base import BaseParser class GenericParser(BaseParser): - def extract_data(self): - # type: () -> Dict[str, Optional[Text]] + def extract_data(self) -> Dict[str, Optional[Text]]: return { 'title': self._get_title(), 'description': self._get_description(), 'image': self._get_image()} - def _get_title(self): - # type: () -> Optional[Text] + def _get_title(self) -> Optional[Text]: soup = self._soup if (soup.title and soup.title.text != ''): return soup.title.text @@ -19,8 +17,7 @@ class GenericParser(BaseParser): return soup.h1.text return None - def _get_description(self): - # type: () -> Optional[Text] + def _get_description(self) -> Optional[Text]: soup = self._soup meta_description = soup.find('meta', attrs={'name': 'description'}) if (meta_description and meta_description['content'] != ''): @@ -35,8 +32,7 @@ class GenericParser(BaseParser): return first_p.string return None - def _get_image(self): - # type: () -> Optional[Text] + def _get_image(self) -> Optional[Text]: """ Finding a first image after the h1 header. Presumably it will be the main image. diff --git a/zerver/lib/url_preview/parsers/open_graph.py b/zerver/lib/url_preview/parsers/open_graph.py index f2728d0f9c..37ac51abce 100644 --- a/zerver/lib/url_preview/parsers/open_graph.py +++ b/zerver/lib/url_preview/parsers/open_graph.py @@ -4,8 +4,7 @@ from .base import BaseParser class OpenGraphParser(BaseParser): - def extract_data(self): - # type: () -> Dict[str, Text] + def extract_data(self) -> Dict[str, Text]: meta = self._soup.findAll('meta') content = {} for tag in meta: diff --git a/zerver/lib/url_preview/preview.py b/zerver/lib/url_preview/preview.py index 72ba3d1f74..184d3becb2 100644 --- a/zerver/lib/url_preview/preview.py +++ b/zerver/lib/url_preview/preview.py @@ -20,19 +20,18 @@ link_regex = re.compile( r'(?:/?|[/?]\S+)$', re.IGNORECASE) -def is_link(url): - # type: (Text) -> Match[Text] +def is_link(url: Text) -> Match[Text]: return link_regex.match(smart_text(url)) -def cache_key_func(url): - # type: (Text) -> Text +def cache_key_func(url: Text) -> Text: return url @cache_with_key(cache_key_func, cache_name=CACHE_NAME, with_statsd_key="urlpreview_data") -def get_link_embed_data(url, maxwidth=640, maxheight=480): - # type: (Text, Optional[int], Optional[int]) -> Optional[Dict[Any, Any]] +def get_link_embed_data(url: Text, + maxwidth: Optional[int]=640, + maxheight: Optional[int]=480) -> Optional[Dict[Any, Any]]: if not is_link(url): return None # Fetch information from URL. @@ -60,6 +59,5 @@ def get_link_embed_data(url, maxwidth=640, maxheight=480): @get_cache_with_key(cache_key_func, cache_name=CACHE_NAME) -def link_embed_data_from_cache(url, maxwidth=640, maxheight=480): - # type: (Text, Optional[int], Optional[int]) -> Any +def link_embed_data_from_cache(url: Text, maxwidth: Optional[int]=640, maxheight: Optional[int]=480) -> Any: return diff --git a/zerver/lib/users.py b/zerver/lib/users.py index 72a4391ba0..2fbd750a0a 100644 --- a/zerver/lib/users.py +++ b/zerver/lib/users.py @@ -8,8 +8,7 @@ from zerver.lib.request import JsonableError from zerver.models import UserProfile, Service, Realm, \ get_user_profile_by_id, user_profile_by_email_cache_key -def check_full_name(full_name_raw): - # type: (Text) -> Text +def check_full_name(full_name_raw: Text) -> Text: full_name = full_name_raw.strip() if len(full_name) > UserProfile.MAX_NAME_LENGTH: raise JsonableError(_("Name too long!")) @@ -19,20 +18,17 @@ def check_full_name(full_name_raw): raise JsonableError(_("Invalid characters in name!")) return full_name -def check_short_name(short_name_raw): - # type: (Text) -> Text +def check_short_name(short_name_raw: Text) -> Text: short_name = short_name_raw.strip() if len(short_name) == 0: raise JsonableError(_("Bad name or username")) return short_name -def check_valid_bot_type(bot_type): - # type: (int) -> None +def check_valid_bot_type(bot_type: int) -> None: if bot_type not in UserProfile.ALLOWED_BOT_TYPES: raise JsonableError(_('Invalid bot type')) -def check_valid_interface_type(interface_type): - # type: (int) -> None +def check_valid_interface_type(interface_type: int) -> None: if interface_type not in Service.ALLOWED_INTERFACE_TYPES: raise JsonableError(_('Invalid interface type')) diff --git a/zerver/lib/utils.py b/zerver/lib/utils.py index 8a91238cba..4bab5390cb 100644 --- a/zerver/lib/utils.py +++ b/zerver/lib/utils.py @@ -15,8 +15,7 @@ from django.conf import settings T = TypeVar('T') -def statsd_key(val, clean_periods=False): - # type: (Any, bool) -> str +def statsd_key(val: Any, clean_periods: bool=False) -> str: if not isinstance(val, str): val = str(val) @@ -35,8 +34,7 @@ class StatsDWrapper: # Backported support for gauge deltas # as our statsd server supports them but supporting # pystatsd is not released yet - def _our_gauge(self, stat, value, rate=1, delta=False): - # type: (str, float, float, bool) -> None + def _our_gauge(self, stat: str, value: float, rate: float=1, delta: bool=False) -> None: """Set a gauge value.""" from django_statsd.clients import statsd if delta: @@ -45,8 +43,7 @@ class StatsDWrapper: value_str = '%g|g' % (value,) statsd._send(stat, value_str, rate) - def __getattr__(self, name): - # type: (str) -> Any + def __getattr__(self, name: str) -> Any: # Hand off to statsd if we have it enabled # otherwise do nothing if name in ['timer', 'timing', 'incr', 'decr', 'gauge']: @@ -64,8 +61,11 @@ class StatsDWrapper: statsd = StatsDWrapper() # Runs the callback with slices of all_list of a given batch_size -def run_in_batches(all_list, batch_size, callback, sleep_time = 0, logger = None): - # type: (Sequence[T], int, Callable[[Sequence[T]], None], int, Optional[Callable[[str], None]]) -> None +def run_in_batches(all_list: Sequence[T], + batch_size: int, + callback: Callable[[Sequence[T]], None], + sleep_time: int=0, + logger: Optional[Callable[[str], None]]=None) -> None: if len(all_list) == 0: return @@ -85,8 +85,8 @@ def run_in_batches(all_list, batch_size, callback, sleep_time = 0, logger = None if i != limit - 1: sleep(sleep_time) -def make_safe_digest(string, hash_func=hashlib.sha1): - # type: (Text, Callable[[bytes], Any]) -> Text +def make_safe_digest(string: Text, + hash_func: Callable[[bytes], Any]=hashlib.sha1) -> Text: """ return a hex digest of `string`. """ @@ -95,8 +95,7 @@ def make_safe_digest(string, hash_func=hashlib.sha1): return hash_func(string.encode('utf-8')).hexdigest() -def log_statsd_event(name): - # type: (str) -> None +def log_statsd_event(name: str) -> None: """ Sends a single event to statsd with the desired name and the current timestamp @@ -110,12 +109,13 @@ def log_statsd_event(name): event_name = "events.%s" % (name,) statsd.incr(event_name) -def generate_random_token(length): - # type: (int) -> str +def generate_random_token(length: int) -> str: return str(base64.b16encode(os.urandom(length // 2)).decode('utf-8').lower()) -def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=None): - # type: (List[Any], Set[int], int, int) -> Iterable[Any] +def query_chunker(queries: List[Any], + id_collector: Set[int]=None, + chunk_size: int=1000, + db_chunk_size: int=None) -> Iterable[Any]: ''' This merges one or more Django ascending-id queries into a generator that returns chunks of chunk_size row objects @@ -142,8 +142,7 @@ def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=Non else: id_collector = set() - def chunkify(q, i): - # type: (Any, int) -> Iterable[Tuple[int, int, Any]] + def chunkify(q: Any, i: int) -> Iterable[Tuple[int, int, Any]]: q = q.order_by('id') min_id = -1 while True: @@ -171,8 +170,7 @@ def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=Non yield [row for row_id, i, row in tup_chunk] -def split_by(array, group_size, filler): - # type: (List[Any], int, Any) -> List[List[Any]] +def split_by(array: List[Any], group_size: int, filler: Any) -> List[List[Any]]: """ Group elements into list of size `group_size` and fill empty cells with `filler`. Recipe from https://docs.python.org/3/library/itertools.html @@ -180,8 +178,7 @@ def split_by(array, group_size, filler): args = [iter(array)] * group_size return list(map(list, zip_longest(*args, fillvalue=filler))) -def is_remote_server(identifier): - # type: (Text) -> bool +def is_remote_server(identifier: Text) -> bool: """ This function can be used to identify the source of API auth request. We can have two types of sources, Remote Zulip Servers diff --git a/zerver/lib/webhooks/git.py b/zerver/lib/webhooks/git.py index b8856812da..c72d9eba7b 100644 --- a/zerver/lib/webhooks/git.py +++ b/zerver/lib/webhooks/git.py @@ -98,8 +98,7 @@ def get_push_commits_event_message(user_name, compare_url, branch_name, commits_data=get_commits_content(commits_data, is_truncated), ).rstrip() -def get_force_push_commits_event_message(user_name, url, branch_name, head): - # type: (Text, Text, Text, Text) -> Text +def get_force_push_commits_event_message(user_name: Text, url: Text, branch_name: Text, head: Text) -> Text: return FORCE_PUSH_COMMITS_MESSAGE_TEMPLATE.format( user_name=user_name, url=url, @@ -107,16 +106,14 @@ def get_force_push_commits_event_message(user_name, url, branch_name, head): head=head ) -def get_create_branch_event_message(user_name, url, branch_name): - # type: (Text, Text, Text) -> Text +def get_create_branch_event_message(user_name: Text, url: Text, branch_name: Text) -> Text: return CREATE_BRANCH_MESSAGE_TEMPLATE.format( user_name=user_name, url=url, branch_name=branch_name, ) -def get_remove_branch_event_message(user_name, branch_name): - # type: (Text, Text) -> Text +def get_remove_branch_event_message(user_name: Text, branch_name: Text) -> Text: return REMOVE_BRANCH_MESSAGE_TEMPLATE.format( user_name=user_name, branch_name=branch_name, @@ -147,15 +144,18 @@ def get_pull_request_event_message( main_message += '\n' + CONTENT_MESSAGE_TEMPLATE.format(message=message) return main_message.rstrip() -def get_setup_webhook_message(integration, user_name=None): - # type: (Text, Optional[Text]) -> Text +def get_setup_webhook_message(integration: Text, user_name: Optional[Text]=None) -> Text: content = SETUP_MESSAGE_TEMPLATE.format(integration=integration) if user_name: content += SETUP_MESSAGE_USER_PART.format(user_name=user_name) return content -def get_issue_event_message(user_name, action, url, number=None, message=None, assignee=None): - # type: (Text, Text, Text, Optional[int], Optional[Text], Optional[Text]) -> Text +def get_issue_event_message(user_name: Text, + action: Text, + url: Text, + number: Optional[int]=None, + message: Optional[Text]=None, + assignee: Optional[Text]=None) -> Text: return get_pull_request_event_message( user_name, action, @@ -166,8 +166,10 @@ def get_issue_event_message(user_name, action, url, number=None, message=None, a type='Issue' ) -def get_push_tag_event_message(user_name, tag_name, tag_url=None, action='pushed'): - # type: (Text, Text, Optional[Text], Optional[Text]) -> Text +def get_push_tag_event_message(user_name: Text, + tag_name: Text, + tag_url: Optional[Text]=None, + action: Optional[Text]='pushed') -> Text: if tag_url: tag_part = TAG_WITH_URL_TEMPLATE.format(tag_name=tag_name, tag_url=tag_url) else: @@ -178,8 +180,11 @@ def get_push_tag_event_message(user_name, tag_name, tag_url=None, action='pushed tag=tag_part ) -def get_commits_comment_action_message(user_name, action, commit_url, sha, message=None): - # type: (Text, Text, Text, Text, Optional[Text]) -> Text +def get_commits_comment_action_message(user_name: Text, + action: Text, + commit_url: Text, + sha: Text, + message: Optional[Text]=None) -> Text: content = COMMITS_COMMENT_MESSAGE_TEMPLATE.format( user_name=user_name, action=action, @@ -192,8 +197,7 @@ def get_commits_comment_action_message(user_name, action, commit_url, sha, messa ) return content -def get_commits_content(commits_data, is_truncated=False): - # type: (List[Dict[str, Any]], Optional[bool]) -> Text +def get_commits_content(commits_data: List[Dict[str, Any]], is_truncated: Optional[bool]=False) -> Text: commits_content = '' for commit in commits_data[:COMMITS_LIMIT]: commits_content += COMMIT_ROW_TEMPLATE.format( @@ -212,12 +216,10 @@ def get_commits_content(commits_data, is_truncated=False): ).replace(' ', ' ') return commits_content.rstrip() -def get_short_sha(sha): - # type: (Text) -> Text +def get_short_sha(sha: Text) -> Text: return sha[:7] -def get_all_committers(commits_data): - # type: (List[Dict[str, Any]]) -> List[Tuple[str, int]] +def get_all_committers(commits_data: List[Dict[str, Any]]) -> List[Tuple[str, int]]: committers = defaultdict(int) # type: Dict[str, int] for commit in commits_data: