zerver/lib: Use python 3 syntax for typing.

Extracted from a larger commit by tabbott because these changes will
not create significant merge conflicts.
This commit is contained in:
rht 2017-11-05 11:15:10 +01:00 committed by Tim Abbott
parent 561ba33f69
commit 3f4bf2d22f
35 changed files with 388 additions and 573 deletions

View File

@ -11,8 +11,7 @@ from zerver.models import (
get_user_including_cross_realm, get_user_including_cross_realm,
) )
def user_profiles_from_unvalidated_emails(emails, realm): def user_profiles_from_unvalidated_emails(emails: Iterable[Text], realm: Realm) -> List[UserProfile]:
# type: (Iterable[Text], Realm) -> List[UserProfile]
user_profiles = [] # type: List[UserProfile] user_profiles = [] # type: List[UserProfile]
for email in emails: for email in emails:
try: try:
@ -22,8 +21,7 @@ def user_profiles_from_unvalidated_emails(emails, realm):
user_profiles.append(user_profile) user_profiles.append(user_profile)
return user_profiles return user_profiles
def get_user_profiles(emails, realm): def get_user_profiles(emails: Iterable[Text], realm: Realm) -> List[UserProfile]:
# type: (Iterable[Text], Realm) -> List[UserProfile]
try: try:
return user_profiles_from_unvalidated_emails(emails, realm) return user_profiles_from_unvalidated_emails(emails, realm)
except ValidationError as e: except ValidationError as e:
@ -42,44 +40,43 @@ class Addressee:
# in memory. # in memory.
# #
# This should be treated as an immutable class. # This should be treated as an immutable class.
def __init__(self, msg_type, user_profiles=None, stream_name=None, topic=None): def __init__(self, msg_type: str,
# type: (str, Optional[Sequence[UserProfile]], Optional[Text], Text) -> None user_profiles: Optional[Sequence[UserProfile]]=None,
stream_name: Optional[Text]=None,
topic: Optional[Text]=None) -> None:
assert(msg_type in ['stream', 'private']) assert(msg_type in ['stream', 'private'])
self._msg_type = msg_type self._msg_type = msg_type
self._user_profiles = user_profiles self._user_profiles = user_profiles
self._stream_name = stream_name self._stream_name = stream_name
self._topic = topic self._topic = topic
def msg_type(self): def msg_type(self) -> str:
# type: () -> str
return self._msg_type return self._msg_type
def is_stream(self): def is_stream(self) -> bool:
# type: () -> bool
return self._msg_type == 'stream' return self._msg_type == 'stream'
def is_private(self): def is_private(self) -> bool:
# type: () -> bool
return self._msg_type == 'private' return self._msg_type == 'private'
def user_profiles(self): def user_profiles(self) -> List[UserProfile]:
# type: () -> List[UserProfile]
assert(self.is_private()) assert(self.is_private())
return self._user_profiles # type: ignore # assertion protects us return self._user_profiles # type: ignore # assertion protects us
def stream_name(self): def stream_name(self) -> Text:
# type: () -> Text
assert(self.is_stream()) assert(self.is_stream())
return self._stream_name return self._stream_name
def topic(self): def topic(self) -> Text:
# type: () -> Text
assert(self.is_stream()) assert(self.is_stream())
return self._topic return self._topic
@staticmethod @staticmethod
def legacy_build(sender, message_type_name, message_to, topic_name, realm=None): def legacy_build(sender: UserProfile,
# type: (UserProfile, Text, Sequence[Text], Text, Optional[Realm]) -> Addressee message_type_name: Text,
message_to: Sequence[Text],
topic_name: Text,
realm: Optional[Realm]=None) -> 'Addressee':
# For legacy reasons message_to used to be either a list of # For legacy reasons message_to used to be either a list of
# emails or a list of streams. We haven't fixed all of our # emails or a list of streams. We haven't fixed all of our
@ -111,8 +108,7 @@ class Addressee:
raise JsonableError(_("Invalid message type")) raise JsonableError(_("Invalid message type"))
@staticmethod @staticmethod
def for_stream(stream_name, topic): def for_stream(stream_name: Text, topic: Text) -> 'Addressee':
# type: (Text, Text) -> Addressee
return Addressee( return Addressee(
msg_type='stream', msg_type='stream',
stream_name=stream_name, stream_name=stream_name,
@ -120,8 +116,7 @@ class Addressee:
) )
@staticmethod @staticmethod
def for_private(emails, realm): def for_private(emails: Sequence[Text], realm: Realm) -> 'Addressee':
# type: (Sequence[Text], Realm) -> Addressee
user_profiles = get_user_profiles(emails, realm) user_profiles = get_user_profiles(emails, realm)
return Addressee( return Addressee(
msg_type='private', msg_type='private',
@ -129,8 +124,7 @@ class Addressee:
) )
@staticmethod @staticmethod
def for_user_profile(user_profile): def for_user_profile(user_profile: UserProfile) -> 'Addressee':
# type: (UserProfile) -> Addressee
user_profiles = [user_profile] user_profiles = [user_profile]
return Addressee( return Addressee(
msg_type='private', msg_type='private',

View File

@ -6,13 +6,12 @@ from zerver.lib.request import JsonableError
from zerver.lib.upload import delete_message_image from zerver.lib.upload import delete_message_image
from zerver.models import Attachment, UserProfile from zerver.models import Attachment, UserProfile
def user_attachments(user_profile): def user_attachments(user_profile: UserProfile) -> List[Dict[str, Any]]:
# type: (UserProfile) -> List[Dict[str, Any]]
attachments = Attachment.objects.filter(owner=user_profile).prefetch_related('messages') attachments = Attachment.objects.filter(owner=user_profile).prefetch_related('messages')
return [a.to_dict() for a in attachments] return [a.to_dict() for a in attachments]
def access_attachment_by_id(user_profile, attachment_id, needs_owner=False): def access_attachment_by_id(user_profile: UserProfile, attachment_id: int,
# type: (UserProfile, int, bool) -> Attachment needs_owner: bool=False) -> Attachment:
query = Attachment.objects.filter(id=attachment_id) query = Attachment.objects.filter(id=attachment_id)
if needs_owner: if needs_owner:
query = query.filter(owner=user_profile) query = query.filter(owner=user_profile)
@ -22,8 +21,7 @@ def access_attachment_by_id(user_profile, attachment_id, needs_owner=False):
raise JsonableError(_("Invalid attachment")) raise JsonableError(_("Invalid attachment"))
return attachment return attachment
def remove_attachment(user_profile, attachment): def remove_attachment(user_profile: UserProfile, attachment: Attachment) -> None:
# type: (UserProfile, Attachment) -> None
try: try:
delete_message_image(attachment.path_id) delete_message_image(attachment.path_id)
except Exception: except Exception:

View File

@ -24,8 +24,7 @@ our_dir = os.path.dirname(os.path.abspath(__file__))
from zulip_bots.lib import RateLimit from zulip_bots.lib import RateLimit
def get_bot_handler(service_name): def get_bot_handler(service_name: str) -> Any:
# type: (str) -> Any
# Check that this service is present in EMBEDDED_BOTS, add exception handling. # Check that this service is present in EMBEDDED_BOTS, add exception handling.
is_present_in_registry = any(service_name == embedded_bot_service.name for is_present_in_registry = any(service_name == embedded_bot_service.name for
@ -40,31 +39,25 @@ def get_bot_handler(service_name):
class StateHandler: class StateHandler:
state_size_limit = 10000000 # type: int # TODO: Store this in the server configuration model. state_size_limit = 10000000 # type: int # TODO: Store this in the server configuration model.
def __init__(self, user_profile): def __init__(self, user_profile: UserProfile) -> None:
# type: (UserProfile) -> None
self.user_profile = user_profile self.user_profile = user_profile
self.marshal = lambda obj: json.dumps(obj) self.marshal = lambda obj: json.dumps(obj)
self.demarshal = lambda obj: json.loads(obj) self.demarshal = lambda obj: json.loads(obj)
def get(self, key): def get(self, key: Text) -> Text:
# type: (Text) -> Text
return self.demarshal(get_bot_state(self.user_profile, key)) return self.demarshal(get_bot_state(self.user_profile, key))
def put(self, key, value): def put(self, key: Text, value: Text) -> None:
# type: (Text, Text) -> None
set_bot_state(self.user_profile, key, self.marshal(value)) set_bot_state(self.user_profile, key, self.marshal(value))
def remove(self, key): def remove(self, key: Text) -> None:
# type: (Text) -> None
remove_bot_state(self.user_profile, key) remove_bot_state(self.user_profile, key)
def contains(self, key): def contains(self, key: Text) -> bool:
# type: (Text) -> bool
return is_key_in_bot_state(self.user_profile, key) return is_key_in_bot_state(self.user_profile, key)
class EmbeddedBotHandler: class EmbeddedBotHandler:
def __init__(self, user_profile): def __init__(self, user_profile: UserProfile) -> None:
# type: (UserProfile) -> None
# Only expose a subset of our UserProfile's functionality # Only expose a subset of our UserProfile's functionality
self.user_profile = user_profile self.user_profile = user_profile
self._rate_limit = RateLimit(20, 5) self._rate_limit = RateLimit(20, 5)
@ -72,8 +65,7 @@ class EmbeddedBotHandler:
self.email = user_profile.email self.email = user_profile.email
self.storage = StateHandler(user_profile) self.storage = StateHandler(user_profile)
def send_message(self, message): def send_message(self, message: Dict[str, Any]) -> None:
# type: (Dict[str, Any]) -> None
if self._rate_limit.is_legal(): if self._rate_limit.is_legal():
recipients = message['to'] if message['type'] == 'stream' else ','.join(message['to']) recipients = message['to'] if message['type'] == 'stream' else ','.join(message['to'])
internal_send_message(realm=self.user_profile.realm, sender_email=self.user_profile.email, internal_send_message(realm=self.user_profile.realm, sender_email=self.user_profile.email,
@ -82,8 +74,7 @@ class EmbeddedBotHandler:
else: else:
self._rate_limit.show_error_and_exit() self._rate_limit.show_error_and_exit()
def send_reply(self, message, response): def send_reply(self, message: Dict[str, Any], response: str) -> None:
# type: (Dict[str, Any], str) -> None
if message['type'] == 'private': if message['type'] == 'private':
self.send_message(dict( self.send_message(dict(
type='private', type='private',
@ -100,6 +91,5 @@ class EmbeddedBotHandler:
sender_email=message['sender_email'], sender_email=message['sender_email'],
)) ))
def get_config_info(self): def get_config_info(self) -> Dict[Text, Text]:
# type: () -> Dict[Text, Text]
return get_bot_config(self.user_profile) return get_bot_config(self.user_profile)

View File

@ -221,8 +221,7 @@ EMOJI_TWEET = """{
] ]
}""" }"""
def twitter(tweet_id): def twitter(tweet_id: Text) -> Optional[Dict[Text, Any]]:
# type: (Text) -> Optional[Dict[Text, Any]]
if tweet_id in ["112652479837110273", "287977969287315456", "287977969287315457"]: if tweet_id in ["112652479837110273", "287977969287315456", "287977969287315457"]:
return ujson.loads(NORMAL_TWEET) return ujson.loads(NORMAL_TWEET)
elif tweet_id == "287977969287315458": elif tweet_id == "287977969287315458":

View File

@ -22,8 +22,7 @@ from django.db.models import Q
MESSAGE_CACHE_SIZE = 75000 MESSAGE_CACHE_SIZE = 75000
def message_fetch_objects(): def message_fetch_objects() -> List[Any]:
# type: () -> List[Any]
try: try:
max_id = Message.objects.only('id').order_by("-id")[0].id max_id = Message.objects.only('id').order_by("-id")[0].id
except IndexError: except IndexError:
@ -31,8 +30,8 @@ def message_fetch_objects():
return Message.objects.select_related().filter(~Q(sender__email='tabbott/extra@mit.edu'), return Message.objects.select_related().filter(~Q(sender__email='tabbott/extra@mit.edu'),
id__gt=max_id - MESSAGE_CACHE_SIZE) id__gt=max_id - MESSAGE_CACHE_SIZE)
def message_cache_items(items_for_remote_cache, message): def message_cache_items(items_for_remote_cache: Dict[Text, Tuple[bytes]],
# type: (Dict[Text, Tuple[bytes]], Message) -> None message: Message) -> None:
''' '''
Note: this code is untested, and the caller has been Note: this code is untested, and the caller has been
commented out for a while. commented out for a while.
@ -41,32 +40,32 @@ def message_cache_items(items_for_remote_cache, message):
value = MessageDict.to_dict_uncached(message) value = MessageDict.to_dict_uncached(message)
items_for_remote_cache[key] = (value,) items_for_remote_cache[key] = (value,)
def user_cache_items(items_for_remote_cache, user_profile): def user_cache_items(items_for_remote_cache: Dict[Text, Tuple[UserProfile]],
# type: (Dict[Text, Tuple[UserProfile]], UserProfile) -> None user_profile: UserProfile) -> None:
items_for_remote_cache[user_profile_by_email_cache_key(user_profile.email)] = (user_profile,) items_for_remote_cache[user_profile_by_email_cache_key(user_profile.email)] = (user_profile,)
items_for_remote_cache[user_profile_by_id_cache_key(user_profile.id)] = (user_profile,) items_for_remote_cache[user_profile_by_id_cache_key(user_profile.id)] = (user_profile,)
items_for_remote_cache[user_profile_by_api_key_cache_key(user_profile.api_key)] = (user_profile,) items_for_remote_cache[user_profile_by_api_key_cache_key(user_profile.api_key)] = (user_profile,)
items_for_remote_cache[user_profile_cache_key(user_profile.email, user_profile.realm)] = (user_profile,) items_for_remote_cache[user_profile_cache_key(user_profile.email, user_profile.realm)] = (user_profile,)
def stream_cache_items(items_for_remote_cache, stream): def stream_cache_items(items_for_remote_cache: Dict[Text, Tuple[Stream]],
# type: (Dict[Text, Tuple[Stream]], Stream) -> None stream: Stream) -> None:
items_for_remote_cache[get_stream_cache_key(stream.name, stream.realm_id)] = (stream,) items_for_remote_cache[get_stream_cache_key(stream.name, stream.realm_id)] = (stream,)
def client_cache_items(items_for_remote_cache, client): def client_cache_items(items_for_remote_cache: Dict[Text, Tuple[Client]],
# type: (Dict[Text, Tuple[Client]], Client) -> None client: Client) -> None:
items_for_remote_cache[get_client_cache_key(client.name)] = (client,) items_for_remote_cache[get_client_cache_key(client.name)] = (client,)
def huddle_cache_items(items_for_remote_cache, huddle): def huddle_cache_items(items_for_remote_cache: Dict[Text, Tuple[Huddle]],
# type: (Dict[Text, Tuple[Huddle]], Huddle) -> None huddle: Huddle) -> None:
items_for_remote_cache[huddle_hash_cache_key(huddle.huddle_hash)] = (huddle,) items_for_remote_cache[huddle_hash_cache_key(huddle.huddle_hash)] = (huddle,)
def recipient_cache_items(items_for_remote_cache, recipient): def recipient_cache_items(items_for_remote_cache: Dict[Text, Tuple[Recipient]],
# type: (Dict[Text, Tuple[Recipient]], Recipient) -> None recipient: Recipient) -> None:
items_for_remote_cache[get_recipient_cache_key(recipient.type, recipient.type_id)] = (recipient,) items_for_remote_cache[get_recipient_cache_key(recipient.type, recipient.type_id)] = (recipient,)
session_engine = import_module(settings.SESSION_ENGINE) session_engine = import_module(settings.SESSION_ENGINE)
def session_cache_items(items_for_remote_cache, session): def session_cache_items(items_for_remote_cache: Dict[Text, Text],
# type: (Dict[Text, Text], Session) -> None session: Session) -> None:
store = session_engine.SessionStore(session_key=session.session_key) # type: ignore # import_module store = session_engine.SessionStore(session_key=session.session_key) # type: ignore # import_module
items_for_remote_cache[store.cache_key] = store.decode(session.session_data) items_for_remote_cache[store.cache_key] = store.decode(session.session_data)
@ -89,8 +88,7 @@ cache_fillers = {
'session': (lambda: Session.objects.all(), session_cache_items, 3600*24*7, 10000), 'session': (lambda: Session.objects.all(), session_cache_items, 3600*24*7, 10000),
} # type: Dict[str, Tuple[Callable[[], List[Any]], Callable[[Dict[Text, Any], Any], None], int, int]] } # type: Dict[str, Tuple[Callable[[], List[Any]], Callable[[Dict[Text, Any], Any], None], int, int]]
def fill_remote_cache(cache): def fill_remote_cache(cache: str) -> None:
# type: (str) -> None
remote_cache_time_start = get_remote_cache_time() remote_cache_time_start = get_remote_cache_time()
remote_cache_requests_start = get_remote_cache_requests() remote_cache_requests_start = get_remote_cache_requests()
items_for_remote_cache = {} # type: Dict[Text, Any] items_for_remote_cache = {} # type: Dict[Text, Any]

View File

@ -9,8 +9,7 @@ import string
from typing import Optional, Text from typing import Optional, Text
def random_api_key(): def random_api_key() -> Text:
# type: () -> Text
choices = string.ascii_letters + string.digits choices = string.ascii_letters + string.digits
altchars = ''.join([choices[ord(os.urandom(1)) % 62] for _ in range(2)]).encode("utf-8") altchars = ''.join([choices[ord(os.urandom(1)) % 62] for _ in range(2)]).encode("utf-8")
return base64.b64encode(os.urandom(24), altchars=altchars).decode("utf-8") return base64.b64encode(os.urandom(24), altchars=altchars).decode("utf-8")

View File

@ -12,8 +12,7 @@ from typing import Optional
# (that link also points to code for an interactive remote debugger # (that link also points to code for an interactive remote debugger
# setup, which we might want if we move Tornado to run in a daemon # setup, which we might want if we move Tornado to run in a daemon
# rather than via screen). # rather than via screen).
def interactive_debug(sig, frame): def interactive_debug(sig: int, frame: FrameType) -> None:
# type: (int, FrameType) -> None
"""Interrupt running process, and provide a python prompt for """Interrupt running process, and provide a python prompt for
interactive debugging.""" interactive debugging."""
d = {'_frame': frame} # Allow access to frame object. d = {'_frame': frame} # Allow access to frame object.
@ -27,7 +26,6 @@ def interactive_debug(sig, frame):
# SIGUSR1 => Just print the stack # SIGUSR1 => Just print the stack
# SIGUSR2 => Print stack + open interactive debugging shell # SIGUSR2 => Print stack + open interactive debugging shell
def interactive_debug_listen(): def interactive_debug_listen() -> None:
# type: () -> None
signal.signal(signal.SIGUSR1, lambda sig, stack: traceback.print_stack(stack)) signal.signal(signal.SIGUSR1, lambda sig, stack: traceback.print_stack(stack))
signal.signal(signal.SIGUSR2, interactive_debug) signal.signal(signal.SIGUSR2, interactive_debug)

View File

@ -31,8 +31,7 @@ DIGEST_CUTOFF = 5
# 4. Interesting stream traffic, as determined by the longest and most # 4. Interesting stream traffic, as determined by the longest and most
# diversely comment upon topics. # diversely comment upon topics.
def inactive_since(user_profile, cutoff): def inactive_since(user_profile: UserProfile, cutoff: datetime.datetime) -> bool:
# type: (UserProfile, datetime.datetime) -> bool
# Hasn't used the app in the last DIGEST_CUTOFF (5) days. # Hasn't used the app in the last DIGEST_CUTOFF (5) days.
most_recent_visit = [row.last_visit for row in most_recent_visit = [row.last_visit for row in
UserActivity.objects.filter( UserActivity.objects.filter(
@ -45,8 +44,7 @@ def inactive_since(user_profile, cutoff):
last_visit = max(most_recent_visit) last_visit = max(most_recent_visit)
return last_visit < cutoff return last_visit < cutoff
def should_process_digest(realm_str): def should_process_digest(realm_str: str) -> bool:
# type: (str) -> bool
if realm_str in settings.SYSTEM_ONLY_REALMS: if realm_str in settings.SYSTEM_ONLY_REALMS:
# Don't try to send emails to system-only realms # Don't try to send emails to system-only realms
return False return False
@ -54,15 +52,13 @@ def should_process_digest(realm_str):
# Changes to this should also be reflected in # Changes to this should also be reflected in
# zerver/worker/queue_processors.py:DigestWorker.consume() # zerver/worker/queue_processors.py:DigestWorker.consume()
def queue_digest_recipient(user_profile, cutoff): def queue_digest_recipient(user_profile: UserProfile, cutoff: datetime.datetime) -> None:
# type: (UserProfile, datetime.datetime) -> None
# Convert cutoff to epoch seconds for transit. # Convert cutoff to epoch seconds for transit.
event = {"user_profile_id": user_profile.id, event = {"user_profile_id": user_profile.id,
"cutoff": cutoff.strftime('%s')} "cutoff": cutoff.strftime('%s')}
queue_json_publish("digest_emails", event, lambda event: None, call_consume_in_tests=True) queue_json_publish("digest_emails", event, lambda event: None, call_consume_in_tests=True)
def enqueue_emails(cutoff): def enqueue_emails(cutoff: datetime.datetime) -> None:
# type: (datetime.datetime) -> None
# To be really conservative while we don't have user timezones or # To be really conservative while we don't have user timezones or
# special-casing for companies with non-standard workweeks, only # special-casing for companies with non-standard workweeks, only
# try to send mail on Tuesdays. # try to send mail on Tuesdays.
@ -82,8 +78,7 @@ def enqueue_emails(cutoff):
logger.info("%s is inactive, queuing for potential digest" % ( logger.info("%s is inactive, queuing for potential digest" % (
user_profile.email,)) user_profile.email,))
def gather_hot_conversations(user_profile, stream_messages): def gather_hot_conversations(user_profile: UserProfile, stream_messages: QuerySet) -> List[Dict[str, Any]]:
# type: (UserProfile, QuerySet) -> List[Dict[str, Any]]
# Gather stream conversations of 2 types: # Gather stream conversations of 2 types:
# 1. long conversations # 1. long conversations
# 2. conversations where many different people participated # 2. conversations where many different people participated
@ -146,8 +141,7 @@ def gather_hot_conversations(user_profile, stream_messages):
hot_conversation_render_payloads.append(teaser_data) hot_conversation_render_payloads.append(teaser_data)
return hot_conversation_render_payloads return hot_conversation_render_payloads
def gather_new_users(user_profile, threshold): def gather_new_users(user_profile: UserProfile, threshold: datetime.datetime) -> Tuple[int, List[Text]]:
# type: (UserProfile, datetime.datetime) -> Tuple[int, List[Text]]
# Gather information on users in the realm who have recently # Gather information on users in the realm who have recently
# joined. # joined.
if user_profile.realm.is_zephyr_mirror_realm: if user_profile.realm.is_zephyr_mirror_realm:
@ -160,8 +154,8 @@ def gather_new_users(user_profile, threshold):
return len(user_names), user_names return len(user_names), user_names
def gather_new_streams(user_profile, threshold): def gather_new_streams(user_profile: UserProfile,
# type: (UserProfile, datetime.datetime) -> Tuple[int, Dict[str, List[Text]]] threshold: datetime.datetime) -> Tuple[int, Dict[str, List[Text]]]:
if user_profile.realm.is_zephyr_mirror_realm: if user_profile.realm.is_zephyr_mirror_realm:
new_streams = [] # type: List[Stream] new_streams = [] # type: List[Stream]
else: else:
@ -181,8 +175,7 @@ def gather_new_streams(user_profile, threshold):
return len(new_streams), {"html": streams_html, "plain": streams_plain} return len(new_streams), {"html": streams_html, "plain": streams_plain}
def enough_traffic(unread_pms, hot_conversations, new_streams, new_users): def enough_traffic(unread_pms: Text, hot_conversations: Text, new_streams: int, new_users: int) -> bool:
# type: (Text, Text, int, int) -> bool
if unread_pms or hot_conversations: if unread_pms or hot_conversations:
# If you have any unread traffic, good enough. # If you have any unread traffic, good enough.
return True return True
@ -192,8 +185,7 @@ def enough_traffic(unread_pms, hot_conversations, new_streams, new_users):
return True return True
return False return False
def handle_digest_email(user_profile_id, cutoff): def handle_digest_email(user_profile_id: int, cutoff: float) -> None:
# type: (int, float) -> None
user_profile = get_user_profile_by_id(user_profile_id) user_profile = get_user_profile_by_id(user_profile_id)
# We are disabling digest emails for soft deactivated users for the time. # We are disabling digest emails for soft deactivated users for the time.

View File

@ -4,8 +4,7 @@ from django.utils.translation import ugettext as _
import re import re
from typing import Text from typing import Text
def validate_domain(domain): def validate_domain(domain: Text) -> None:
# type: (Text) -> None
if domain is None or len(domain) == 0: if domain is None or len(domain) == 0:
raise ValidationError(_("Domain can't be empty.")) raise ValidationError(_("Domain can't be empty."))
if '.' not in domain: if '.' not in domain:

View File

@ -20,8 +20,7 @@ with open(NAME_TO_CODEPOINT_PATH) as fp:
with open(CODEPOINT_TO_NAME_PATH) as fp: with open(CODEPOINT_TO_NAME_PATH) as fp:
codepoint_to_name = ujson.load(fp) codepoint_to_name = ujson.load(fp)
def emoji_name_to_emoji_code(realm, emoji_name): def emoji_name_to_emoji_code(realm: Realm, emoji_name: Text) -> Tuple[Text, Text]:
# type: (Realm, Text) -> Tuple[Text, Text]
realm_emojis = realm.get_emoji() realm_emojis = realm.get_emoji()
if emoji_name in realm_emojis and not realm_emojis[emoji_name]['deactivated']: if emoji_name in realm_emojis and not realm_emojis[emoji_name]['deactivated']:
return emoji_name, Reaction.REALM_EMOJI return emoji_name, Reaction.REALM_EMOJI
@ -31,8 +30,7 @@ def emoji_name_to_emoji_code(realm, emoji_name):
return name_to_codepoint[emoji_name], Reaction.UNICODE_EMOJI return name_to_codepoint[emoji_name], Reaction.UNICODE_EMOJI
raise JsonableError(_("Emoji '%s' does not exist" % (emoji_name,))) raise JsonableError(_("Emoji '%s' does not exist" % (emoji_name,)))
def check_valid_emoji(realm, emoji_name): def check_valid_emoji(realm: Realm, emoji_name: Text) -> None:
# type: (Realm, Text) -> None
emoji_name_to_emoji_code(realm, emoji_name) emoji_name_to_emoji_code(realm, emoji_name)
def check_emoji_request(realm: Realm, emoji_name: str, emoji_code: str, def check_emoji_request(realm: Realm, emoji_name: str, emoji_code: str,
@ -61,8 +59,7 @@ def check_emoji_request(realm: Realm, emoji_name: str, emoji_code: str,
# The above are the only valid emoji types # The above are the only valid emoji types
raise JsonableError(_("Invalid emoji type.")) raise JsonableError(_("Invalid emoji type."))
def check_emoji_admin(user_profile, emoji_name=None): def check_emoji_admin(user_profile: UserProfile, emoji_name: Optional[Text]=None) -> None:
# type: (UserProfile, Optional[Text]) -> None
"""Raises an exception if the user cannot administer the target realm """Raises an exception if the user cannot administer the target realm
emoji name in their organization.""" emoji name in their organization."""
@ -84,18 +81,15 @@ def check_emoji_admin(user_profile, emoji_name=None):
if not user_profile.is_realm_admin and not current_user_is_author: if not user_profile.is_realm_admin and not current_user_is_author:
raise JsonableError(_("Must be a realm administrator or emoji author")) raise JsonableError(_("Must be a realm administrator or emoji author"))
def check_valid_emoji_name(emoji_name): def check_valid_emoji_name(emoji_name: Text) -> None:
# type: (Text) -> None
if re.match('^[0-9a-z.\-_]+(?<![.\-_])$', emoji_name): if re.match('^[0-9a-z.\-_]+(?<![.\-_])$', emoji_name):
return return
raise JsonableError(_("Invalid characters in emoji name")) raise JsonableError(_("Invalid characters in emoji name"))
def get_emoji_url(emoji_file_name, realm_id): def get_emoji_url(emoji_file_name: Text, realm_id: int) -> Text:
# type: (Text, int) -> Text
return upload_backend.get_emoji_url(emoji_file_name, realm_id) return upload_backend.get_emoji_url(emoji_file_name, realm_id)
def get_emoji_file_name(emoji_file_name, emoji_name): def get_emoji_file_name(emoji_file_name: Text, emoji_name: Text) -> Text:
# type: (Text, Text) -> Text
_, image_ext = os.path.splitext(emoji_file_name) _, image_ext = os.path.splitext(emoji_file_name)
return ''.join((emoji_name, image_ext)) return ''.join((emoji_name, image_ext))

View File

@ -16,15 +16,13 @@ from zerver.lib.actions import internal_send_message
from zerver.lib.response import json_success, json_error from zerver.lib.response import json_success, json_error
from version import ZULIP_VERSION from version import ZULIP_VERSION
def format_subject(subject): def format_subject(subject: str) -> str:
# type: (str) -> str
""" """
Escape CR and LF characters. Escape CR and LF characters.
""" """
return subject.replace('\n', '\\n').replace('\r', '\\r') return subject.replace('\n', '\\n').replace('\r', '\\r')
def user_info_str(report): def user_info_str(report: Dict[str, Any]) -> str:
# type: (Dict[str, Any]) -> str
if report['user_full_name'] and report['user_email']: if report['user_full_name'] and report['user_email']:
user_info = "%(user_full_name)s (%(user_email)s)" % (report) user_info = "%(user_full_name)s (%(user_email)s)" % (report)
else: else:
@ -59,15 +57,13 @@ def deployment_repr() -> str:
return deployment return deployment
def notify_browser_error(report): def notify_browser_error(report: Dict[str, Any]) -> None:
# type: (Dict[str, Any]) -> None
report = defaultdict(lambda: None, report) report = defaultdict(lambda: None, report)
if settings.ERROR_BOT: if settings.ERROR_BOT:
zulip_browser_error(report) zulip_browser_error(report)
email_browser_error(report) email_browser_error(report)
def email_browser_error(report): def email_browser_error(report: Dict[str, Any]) -> None:
# type: (Dict[str, Any]) -> None
subject = "Browser error for %s" % (user_info_str(report)) subject = "Browser error for %s" % (user_info_str(report))
body = ("User: %(user_full_name)s <%(user_email)s> on %(deployment)s\n\n" body = ("User: %(user_full_name)s <%(user_email)s> on %(deployment)s\n\n"
@ -89,8 +85,7 @@ def email_browser_error(report):
mail_admins(subject, body) mail_admins(subject, body)
def zulip_browser_error(report): def zulip_browser_error(report: Dict[str, Any]) -> None:
# type: (Dict[str, Any]) -> None
subject = "JS error: %s" % (report['user_email'],) subject = "JS error: %s" % (report['user_email'],)
user_info = user_info_str(report) user_info = user_info_str(report)
@ -103,15 +98,13 @@ def zulip_browser_error(report):
internal_send_message(realm, settings.ERROR_BOT, internal_send_message(realm, settings.ERROR_BOT,
"stream", "errors", format_subject(subject), body) "stream", "errors", format_subject(subject), body)
def notify_server_error(report): def notify_server_error(report: Dict[str, Any]) -> None:
# type: (Dict[str, Any]) -> None
report = defaultdict(lambda: None, report) report = defaultdict(lambda: None, report)
email_server_error(report) email_server_error(report)
if settings.ERROR_BOT: if settings.ERROR_BOT:
zulip_server_error(report) zulip_server_error(report)
def zulip_server_error(report): def zulip_server_error(report: Dict[str, Any]) -> None:
# type: (Dict[str, Any]) -> None
subject = '%(node)s: %(message)s' % (report) subject = '%(node)s: %(message)s' % (report)
stack_trace = report['stack_trace'] or "No stack trace available" stack_trace = report['stack_trace'] or "No stack trace available"
@ -133,8 +126,7 @@ def zulip_server_error(report):
"Error generated by %s\n\n~~~~ pytb\n%s\n\n~~~~\n%s\n%s" "Error generated by %s\n\n~~~~ pytb\n%s\n\n~~~~\n%s\n%s"
% (user_info, stack_trace, deployment, request_repr)) % (user_info, stack_trace, deployment, request_repr))
def email_server_error(report): def email_server_error(report: Dict[str, Any]) -> None:
# type: (Dict[str, Any]) -> None
subject = '%(node)s: %(message)s' % (report) subject = '%(node)s: %(message)s' % (report)
user_info = user_info_str(report) user_info = user_info_str(report)
@ -153,8 +145,7 @@ def email_server_error(report):
mail_admins(format_subject(subject), message, fail_silently=True) mail_admins(format_subject(subject), message, fail_silently=True)
def do_report_error(deployment_name, type, report): def do_report_error(deployment_name: Text, type: Text, report: Dict[str, Any]) -> HttpResponse:
# type: (Text, Text, Dict[str, Any]) -> HttpResponse
report['deployment'] = deployment_name report['deployment'] = deployment_name
if type == 'browser': if type == 'browser':
notify_browser_error(report) notify_browser_error(report)

View File

@ -47,12 +47,10 @@ from zproject.backends import email_auth_enabled, password_auth_enabled
from version import ZULIP_VERSION from version import ZULIP_VERSION
def get_raw_user_data(realm_id, client_gravatar): def get_raw_user_data(realm_id: int, client_gravatar: bool) -> Dict[int, Dict[str, Text]]:
# type: (int, bool) -> Dict[int, Dict[str, Text]]
user_dicts = get_realm_user_dicts(realm_id) user_dicts = get_realm_user_dicts(realm_id)
def user_data(row): def user_data(row: Dict[str, Any]) -> Dict[str, Any]:
# type: (Dict[str, Any]) -> Dict[str, Any]
avatar_url = get_avatar_field( avatar_url = get_avatar_field(
user_id=row['id'], user_id=row['id'],
realm_id= realm_id, realm_id= realm_id,
@ -81,8 +79,7 @@ def get_raw_user_data(realm_id, client_gravatar):
for row in user_dicts for row in user_dicts
} }
def always_want(msg_type): def always_want(msg_type: str) -> bool:
# type: (str) -> bool
''' '''
This function is used as a helper in This function is used as a helper in
fetch_initial_state_data, when the user passes fetch_initial_state_data, when the user passes
@ -262,8 +259,8 @@ def fetch_initial_state_data(user_profile, event_types, queue_id, client_gravata
return state return state
def remove_message_id_from_unread_mgs(state, message_id): def remove_message_id_from_unread_mgs(state: Dict[str, Dict[str, Any]],
# type: (Dict[str, Dict[str, Any]], int) -> None message_id: int) -> None:
raw_unread = state['raw_unread_msgs'] raw_unread = state['raw_unread_msgs']
for key in ['pm_dict', 'stream_dict', 'huddle_dict']: for key in ['pm_dict', 'stream_dict', 'huddle_dict']:
@ -288,8 +285,11 @@ def apply_events(state, events, user_profile, client_gravatar, include_subscribe
continue continue
apply_event(state, event, user_profile, client_gravatar, include_subscribers) apply_event(state, event, user_profile, client_gravatar, include_subscribers)
def apply_event(state, event, user_profile, client_gravatar, include_subscribers): def apply_event(state: Dict[str, Any],
# type: (Dict[str, Any], Dict[str, Any], UserProfile, bool, bool) -> None event: Dict[str, Any],
user_profile: UserProfile,
client_gravatar: bool,
include_subscribers: bool) -> None:
if event['type'] == "message": if event['type'] == "message":
state['max_message_id'] = max(state['max_message_id'], event['message']['id']) state['max_message_id'] = max(state['max_message_id'], event['message']['id'])
if 'raw_unread_msgs' in state: if 'raw_unread_msgs' in state:
@ -440,8 +440,7 @@ def apply_event(state, event, user_profile, client_gravatar, include_subscribers
event['subscriptions'][i] = copy.deepcopy(event['subscriptions'][i]) event['subscriptions'][i] = copy.deepcopy(event['subscriptions'][i])
del event['subscriptions'][i]['subscribers'] del event['subscriptions'][i]['subscribers']
def name(sub): def name(sub: Dict[str, Any]) -> Text:
# type: (Dict[str, Any]) -> Text
return sub['name'].lower() return sub['name'].lower()
if event['op'] == "add": if event['op'] == "add":

View File

@ -6,24 +6,20 @@ from django.core.exceptions import PermissionDenied
class AbstractEnum(Enum): class AbstractEnum(Enum):
'''An enumeration whose members are used strictly for their names.''' '''An enumeration whose members are used strictly for their names.'''
def __new__(cls): def __new__(cls: Type['AbstractEnum']) -> 'AbstractEnum':
# type: (Type[AbstractEnum]) -> AbstractEnum
obj = object.__new__(cls) obj = object.__new__(cls)
obj._value_ = len(cls.__members__) + 1 obj._value_ = len(cls.__members__) + 1
return obj return obj
# Override all the `Enum` methods that use `_value_`. # Override all the `Enum` methods that use `_value_`.
def __repr__(self): def __repr__(self) -> str:
# type: () -> str
return str(self) return str(self)
def value(self): def value(self) -> None:
# type: () -> None
assert False assert False
def __reduce_ex__(self, proto): def __reduce_ex__(self, proto: int) -> None:
# type: (int) -> None
assert False assert False
class ErrorCode(AbstractEnum): class ErrorCode(AbstractEnum):
@ -69,13 +65,11 @@ class JsonableError(Exception):
code = ErrorCode.NO_SUCH_WIDGET code = ErrorCode.NO_SUCH_WIDGET
data_fields = ['widget_name'] data_fields = ['widget_name']
def __init__(self, widget_name): def __init__(self, widget_name: str) -> None:
# type: (str) -> None
self.widget_name = widget_name # type: str self.widget_name = widget_name # type: str
@staticmethod @staticmethod
def msg_format(): def msg_format() -> str:
# type: () -> str
return _("No such widget: {widget_name}") return _("No such widget: {widget_name}")
raise NoSuchWidgetError(widget_name) raise NoSuchWidgetError(widget_name)
@ -96,8 +90,7 @@ class JsonableError(Exception):
# like 403 or 404. # like 403 or 404.
http_status_code = 400 # type: int http_status_code = 400 # type: int
def __init__(self, msg, code=None): def __init__(self, msg: Text, code: Optional[ErrorCode]=None) -> None:
# type: (Text, Optional[ErrorCode]) -> None
if code is not None: if code is not None:
self.code = code self.code = code
@ -105,8 +98,7 @@ class JsonableError(Exception):
self._msg = msg # type: Text self._msg = msg # type: Text
@staticmethod @staticmethod
def msg_format(): def msg_format() -> Text:
# type: () -> Text
'''Override in subclasses. Gets the items in `data_fields` as format args. '''Override in subclasses. Gets the items in `data_fields` as format args.
This should return (a translation of) a string literal. This should return (a translation of) a string literal.
@ -124,29 +116,24 @@ class JsonableError(Exception):
# #
@property @property
def msg(self): def msg(self) -> Text:
# type: () -> Text
format_data = dict(((f, getattr(self, f)) for f in self.data_fields), format_data = dict(((f, getattr(self, f)) for f in self.data_fields),
_msg=getattr(self, '_msg', None)) _msg=getattr(self, '_msg', None))
return self.msg_format().format(**format_data) return self.msg_format().format(**format_data)
@property @property
def data(self): def data(self) -> Dict[str, Any]:
# type: () -> Dict[str, Any]
return dict(((f, getattr(self, f)) for f in self.data_fields), return dict(((f, getattr(self, f)) for f in self.data_fields),
code=self.code.name) code=self.code.name)
def to_json(self): def to_json(self) -> Dict[str, Any]:
# type: () -> Dict[str, Any]
d = {'result': 'error', 'msg': self.msg} d = {'result': 'error', 'msg': self.msg}
d.update(self.data) d.update(self.data)
return d return d
def __str__(self): def __str__(self) -> str:
# type: () -> str
return self.msg return self.msg
class RateLimited(PermissionDenied): class RateLimited(PermissionDenied):
def __init__(self, msg=""): def __init__(self, msg: str="") -> None:
# type: (str) -> None
super().__init__(msg) super().__init__(msg)

View File

@ -125,8 +125,7 @@ DATE_FIELDS = {
'zerver_userprofile': ['date_joined', 'last_login', 'last_reminder'], 'zerver_userprofile': ['date_joined', 'last_login', 'last_reminder'],
} # type: Dict[TableName, List[Field]] } # type: Dict[TableName, List[Field]]
def sanity_check_output(data): def sanity_check_output(data: TableData) -> None:
# type: (TableData) -> None
tables = set(ALL_ZERVER_TABLES) tables = set(ALL_ZERVER_TABLES)
tables -= set(NON_EXPORTED_TABLES) tables -= set(NON_EXPORTED_TABLES)
tables -= set(IMPLICIT_TABLES) tables -= set(IMPLICIT_TABLES)
@ -137,13 +136,11 @@ def sanity_check_output(data):
if table not in data: if table not in data:
logging.warning('??? NO DATA EXPORTED FOR TABLE %s!!!' % (table,)) logging.warning('??? NO DATA EXPORTED FOR TABLE %s!!!' % (table,))
def write_data_to_file(output_file, data): def write_data_to_file(output_file: Path, data: Any) -> None:
# type: (Path, Any) -> None
with open(output_file, "w") as f: with open(output_file, "w") as f:
f.write(ujson.dumps(data, indent=4)) f.write(ujson.dumps(data, indent=4))
def make_raw(query, exclude=None): def make_raw(query: Any, exclude: List[Field]=None) -> List[Record]:
# type: (Any, List[Field]) -> List[Record]
''' '''
Takes a Django query and returns a JSONable list Takes a Django query and returns a JSONable list
of dictionaries corresponding to the database rows. of dictionaries corresponding to the database rows.
@ -165,8 +162,7 @@ def make_raw(query, exclude=None):
return rows return rows
def floatify_datetime_fields(data, table): def floatify_datetime_fields(data: TableData, table: TableName) -> None:
# type: (TableData, TableName) -> None
for item in data[table]: for item in data[table]:
for field in DATE_FIELDS[table]: for field in DATE_FIELDS[table]:
orig_dt = item[field] orig_dt = item[field]
@ -261,8 +257,8 @@ class Config:
self.virtual_parent.table)) self.virtual_parent.table))
def export_from_config(response, config, seed_object=None, context=None): def export_from_config(response: TableData, config: Config, seed_object: Any=None,
# type: (TableData, Config, Any, Context) -> None context: Context=None) -> None:
table = config.table table = config.table
parent = config.parent parent = config.parent
model = config.model model = config.model
@ -372,8 +368,7 @@ def export_from_config(response, config, seed_object=None, context=None):
context=context, context=context,
) )
def get_realm_config(): def get_realm_config() -> Config:
# type: () -> Config
# This is common, public information about the realm that we can share # This is common, public information about the realm that we can share
# with all realm users. # with all realm users.
@ -536,8 +531,7 @@ def get_realm_config():
return realm_config return realm_config
def sanity_check_stream_data(response, config, context): def sanity_check_stream_data(response: TableData, config: Config, context: Context) -> None:
# type: (TableData, Config, Context) -> None
if context['exportable_user_ids'] is not None: if context['exportable_user_ids'] is not None:
# If we restrict which user ids are exportable, # If we restrict which user ids are exportable,
@ -559,8 +553,7 @@ def sanity_check_stream_data(response, config, context):
Please investigate! Please investigate!
''') ''')
def fetch_user_profile(response, config, context): def fetch_user_profile(response: TableData, config: Config, context: Context) -> None:
# type: (TableData, Config, Context) -> None
realm = context['realm'] realm = context['realm']
exportable_user_ids = context['exportable_user_ids'] exportable_user_ids = context['exportable_user_ids']
@ -589,8 +582,7 @@ def fetch_user_profile(response, config, context):
response['zerver_userprofile'] = normal_rows response['zerver_userprofile'] = normal_rows
response['zerver_userprofile_mirrordummy'] = dummy_rows response['zerver_userprofile_mirrordummy'] = dummy_rows
def fetch_user_profile_cross_realm(response, config, context): def fetch_user_profile_cross_realm(response: TableData, config: Config, context: Context) -> None:
# type: (TableData, Config, Context) -> None
realm = context['realm'] realm = context['realm']
if realm.string_id == "zulip": if realm.string_id == "zulip":
@ -602,8 +594,7 @@ def fetch_user_profile_cross_realm(response, config, context):
get_system_bot(settings.WELCOME_BOT), get_system_bot(settings.WELCOME_BOT),
]] ]]
def fetch_attachment_data(response, realm_id, message_ids): def fetch_attachment_data(response: TableData, realm_id: int, message_ids: Set[int]) -> None:
# type: (TableData, int, Set[int]) -> None
filter_args = {'realm_id': realm_id} filter_args = {'realm_id': realm_id}
query = Attachment.objects.filter(**filter_args) query = Attachment.objects.filter(**filter_args)
response['zerver_attachment'] = make_raw(list(query)) response['zerver_attachment'] = make_raw(list(query))
@ -630,8 +621,7 @@ def fetch_attachment_data(response, realm_id, message_ids):
row for row in response['zerver_attachment'] row for row in response['zerver_attachment']
if row['messages']] if row['messages']]
def fetch_huddle_objects(response, config, context): def fetch_huddle_objects(response: TableData, config: Config, context: Context) -> None:
# type: (TableData, Config, Context) -> None
realm = context['realm'] realm = context['realm']
assert config.parent is not None assert config.parent is not None
@ -667,8 +657,10 @@ def fetch_huddle_objects(response, config, context):
response['_huddle_subscription'] = huddle_subscription_dicts response['_huddle_subscription'] = huddle_subscription_dicts
response['zerver_huddle'] = make_raw(Huddle.objects.filter(id__in=huddle_ids)) response['zerver_huddle'] = make_raw(Huddle.objects.filter(id__in=huddle_ids))
def fetch_usermessages(realm, message_ids, user_profile_ids, message_filename): def fetch_usermessages(realm: Realm,
# type: (Realm, Set[int], Set[int], Path) -> List[Record] message_ids: Set[int],
user_profile_ids: Set[int],
message_filename: Path) -> List[Record]:
# UserMessage export security rule: You can export UserMessages # UserMessage export security rule: You can export UserMessages
# for the messages you exported for the users in your realm. # for the messages you exported for the users in your realm.
user_message_query = UserMessage.objects.filter(user_profile__realm=realm, user_message_query = UserMessage.objects.filter(user_profile__realm=realm,
@ -684,8 +676,7 @@ def fetch_usermessages(realm, message_ids, user_profile_ids, message_filename):
logging.info("Fetched UserMessages for %s" % (message_filename,)) logging.info("Fetched UserMessages for %s" % (message_filename,))
return user_message_chunk return user_message_chunk
def export_usermessages_batch(input_path, output_path): def export_usermessages_batch(input_path: Path, output_path: Path) -> None:
# type: (Path, Path) -> None
"""As part of the system for doing parallel exports, this runs on one """As part of the system for doing parallel exports, this runs on one
batch of Message objects and adds the corresponding UserMessage batch of Message objects and adds the corresponding UserMessage
objects. (This is called by the export_usermessage_batch objects. (This is called by the export_usermessage_batch
@ -701,18 +692,18 @@ def export_usermessages_batch(input_path, output_path):
write_message_export(output_path, output) write_message_export(output_path, output)
os.unlink(input_path) os.unlink(input_path)
def write_message_export(message_filename, output): def write_message_export(message_filename: Path, output: MessageOutput) -> None:
# type: (Path, MessageOutput) -> None
write_data_to_file(output_file=message_filename, data=output) write_data_to_file(output_file=message_filename, data=output)
logging.info("Dumped to %s" % (message_filename,)) logging.info("Dumped to %s" % (message_filename,))
def export_partial_message_files(realm, response, chunk_size=1000, output_dir=None): def export_partial_message_files(realm: Realm,
# type: (Realm, TableData, int, Path) -> Set[int] response: TableData,
chunk_size: int=1000,
output_dir: Path=None) -> Set[int]:
if output_dir is None: if output_dir is None:
output_dir = tempfile.mkdtemp(prefix="zulip-export") output_dir = tempfile.mkdtemp(prefix="zulip-export")
def get_ids(records): def get_ids(records: List[Record]) -> Set[int]:
# type: (List[Record]) -> Set[int]
return set(x['id'] for x in records) return set(x['id'] for x in records)
# Basic security rule: You can export everything either... # Basic security rule: You can export everything either...
@ -824,8 +815,7 @@ def write_message_partial_for_query(realm, message_query, dump_file_id,
return dump_file_id return dump_file_id
def export_uploads_and_avatars(realm, output_dir): def export_uploads_and_avatars(realm: Realm, output_dir: Path) -> None:
# type: (Realm, Path) -> None
uploads_output_dir = os.path.join(output_dir, 'uploads') uploads_output_dir = os.path.join(output_dir, 'uploads')
avatars_output_dir = os.path.join(output_dir, 'avatars') avatars_output_dir = os.path.join(output_dir, 'avatars')
@ -851,8 +841,8 @@ def export_uploads_and_avatars(realm, output_dir):
settings.S3_AUTH_UPLOADS_BUCKET, settings.S3_AUTH_UPLOADS_BUCKET,
output_dir=uploads_output_dir) output_dir=uploads_output_dir)
def export_files_from_s3(realm, bucket_name, output_dir, processing_avatars=False): def export_files_from_s3(realm: Realm, bucket_name: str, output_dir: Path,
# type: (Realm, str, Path, bool) -> None processing_avatars: bool=False) -> None:
conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY) conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY)
bucket = conn.get_bucket(bucket_name, validate=True) bucket = conn.get_bucket(bucket_name, validate=True)
records = [] records = []
@ -932,8 +922,7 @@ def export_files_from_s3(realm, bucket_name, output_dir, processing_avatars=Fals
with open(os.path.join(output_dir, "records.json"), "w") as records_file: with open(os.path.join(output_dir, "records.json"), "w") as records_file:
ujson.dump(records, records_file, indent=4) ujson.dump(records, records_file, indent=4)
def export_uploads_from_local(realm, local_dir, output_dir): def export_uploads_from_local(realm: Realm, local_dir: Path, output_dir: Path) -> None:
# type: (Realm, Path, Path) -> None
count = 0 count = 0
records = [] records = []
@ -960,8 +949,7 @@ def export_uploads_from_local(realm, local_dir, output_dir):
with open(os.path.join(output_dir, "records.json"), "w") as records_file: with open(os.path.join(output_dir, "records.json"), "w") as records_file:
ujson.dump(records, records_file, indent=4) ujson.dump(records, records_file, indent=4)
def export_avatars_from_local(realm, local_dir, output_dir): def export_avatars_from_local(realm: Realm, local_dir: Path, output_dir: Path) -> None:
# type: (Realm, Path, Path) -> None
count = 0 count = 0
records = [] records = []
@ -1005,8 +993,7 @@ def export_avatars_from_local(realm, local_dir, output_dir):
with open(os.path.join(output_dir, "records.json"), "w") as records_file: with open(os.path.join(output_dir, "records.json"), "w") as records_file:
ujson.dump(records, records_file, indent=4) ujson.dump(records, records_file, indent=4)
def do_write_stats_file_for_realm_export(output_dir): def do_write_stats_file_for_realm_export(output_dir: Path) -> None:
# type: (Path) -> None
stats_file = os.path.join(output_dir, 'stats.txt') stats_file = os.path.join(output_dir, 'stats.txt')
realm_file = os.path.join(output_dir, 'realm.json') realm_file = os.path.join(output_dir, 'realm.json')
attachment_file = os.path.join(output_dir, 'attachment.json') attachment_file = os.path.join(output_dir, 'attachment.json')
@ -1033,8 +1020,8 @@ def do_write_stats_file_for_realm_export(output_dir):
f.write('%5d records\n' % len(data)) f.write('%5d records\n' % len(data))
f.write('\n') f.write('\n')
def do_export_realm(realm, output_dir, threads, exportable_user_ids=None): def do_export_realm(realm: Realm, output_dir: Path, threads: int,
# type: (Realm, Path, int, Set[int]) -> None exportable_user_ids: Set[int]=None) -> None:
response = {} # type: TableData response = {} # type: TableData
# We need at least one thread running to export # We need at least one thread running to export
@ -1084,16 +1071,14 @@ def do_export_realm(realm, output_dir, threads, exportable_user_ids=None):
logging.info("Finished exporting %s" % (realm.string_id)) logging.info("Finished exporting %s" % (realm.string_id))
create_soft_link(source=output_dir, in_progress=False) create_soft_link(source=output_dir, in_progress=False)
def export_attachment_table(realm, output_dir, message_ids): def export_attachment_table(realm: Realm, output_dir: Path, message_ids: Set[int]) -> None:
# type: (Realm, Path, Set[int]) -> None
response = {} # type: TableData response = {} # type: TableData
fetch_attachment_data(response=response, realm_id=realm.id, message_ids=message_ids) fetch_attachment_data(response=response, realm_id=realm.id, message_ids=message_ids)
output_file = os.path.join(output_dir, "attachment.json") output_file = os.path.join(output_dir, "attachment.json")
logging.info('Writing attachment table data to %s' % (output_file,)) logging.info('Writing attachment table data to %s' % (output_file,))
write_data_to_file(output_file=output_file, data=response) write_data_to_file(output_file=output_file, data=response)
def create_soft_link(source, in_progress=True): def create_soft_link(source: Path, in_progress: bool=True) -> None:
# type: (Path, bool) -> None
is_done = not in_progress is_done = not in_progress
in_progress_link = '/tmp/zulip-export-in-progress' in_progress_link = '/tmp/zulip-export-in-progress'
done_link = '/tmp/zulip-export-most-recent' done_link = '/tmp/zulip-export-most-recent'
@ -1109,12 +1094,10 @@ def create_soft_link(source, in_progress=True):
logging.info('See %s for output files' % (new_target,)) logging.info('See %s for output files' % (new_target,))
def launch_user_message_subprocesses(threads, output_dir): def launch_user_message_subprocesses(threads: int, output_dir: Path) -> None:
# type: (int, Path) -> None
logging.info('Launching %d PARALLEL subprocesses to export UserMessage rows' % (threads,)) logging.info('Launching %d PARALLEL subprocesses to export UserMessage rows' % (threads,))
def run_job(shard): def run_job(shard: str) -> int:
# type: (str) -> int
subprocess.call(["./manage.py", 'export_usermessage_batch', '--path', subprocess.call(["./manage.py", 'export_usermessage_batch', '--path',
str(output_dir), '--thread', shard]) str(output_dir), '--thread', shard])
return 0 return 0
@ -1124,8 +1107,7 @@ def launch_user_message_subprocesses(threads, output_dir):
threads=threads): threads=threads):
print("Shard %s finished, status %s" % (job, status)) print("Shard %s finished, status %s" % (job, status))
def do_export_user(user_profile, output_dir): def do_export_user(user_profile: UserProfile, output_dir: Path) -> None:
# type: (UserProfile, Path) -> None
response = {} # type: TableData response = {} # type: TableData
export_single_user(user_profile, response) export_single_user(user_profile, response)
@ -1134,8 +1116,7 @@ def do_export_user(user_profile, output_dir):
logging.info("Exporting messages") logging.info("Exporting messages")
export_messages_single_user(user_profile, output_dir) export_messages_single_user(user_profile, output_dir)
def export_single_user(user_profile, response): def export_single_user(user_profile: UserProfile, response: TableData) -> None:
# type: (UserProfile, TableData) -> None
config = get_single_user_config() config = get_single_user_config()
export_from_config( export_from_config(
@ -1144,8 +1125,7 @@ def export_single_user(user_profile, response):
seed_object=user_profile, seed_object=user_profile,
) )
def get_single_user_config(): def get_single_user_config() -> Config:
# type: () -> Config
# zerver_userprofile # zerver_userprofile
user_profile_config = Config( user_profile_config = Config(
@ -1182,8 +1162,7 @@ def get_single_user_config():
return user_profile_config return user_profile_config
def export_messages_single_user(user_profile, output_dir, chunk_size=1000): def export_messages_single_user(user_profile: UserProfile, output_dir: Path, chunk_size: int=1000) -> None:
# type: (UserProfile, Path, int) -> None
user_message_query = UserMessage.objects.filter(user_profile=user_profile).order_by("id") user_message_query = UserMessage.objects.filter(user_profile=user_profile).order_by("id")
min_id = -1 min_id = -1
dump_file_id = 1 dump_file_id = 1
@ -1232,8 +1211,7 @@ id_maps = {
'user_profile': {}, 'user_profile': {},
} # type: Dict[str, Dict[int, int]] } # type: Dict[str, Dict[int, int]]
def update_id_map(table, old_id, new_id): def update_id_map(table: TableName, old_id: int, new_id: int) -> None:
# type: (TableName, int, int) -> None
if table not in id_maps: if table not in id_maps:
raise Exception(''' raise Exception('''
Table %s is not initialized in id_maps, which could Table %s is not initialized in id_maps, which could
@ -1242,15 +1220,13 @@ def update_id_map(table, old_id, new_id):
''' % (table,)) ''' % (table,))
id_maps[table][old_id] = new_id id_maps[table][old_id] = new_id
def fix_datetime_fields(data, table): def fix_datetime_fields(data: TableData, table: TableName) -> None:
# type: (TableData, TableName) -> None
for item in data[table]: for item in data[table]:
for field_name in DATE_FIELDS[table]: for field_name in DATE_FIELDS[table]:
if item[field_name] is not None: if item[field_name] is not None:
item[field_name] = datetime.datetime.fromtimestamp(item[field_name], tz=timezone_utc) item[field_name] = datetime.datetime.fromtimestamp(item[field_name], tz=timezone_utc)
def convert_to_id_fields(data, table, field_name): def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -> None:
# type: (TableData, TableName, Field) -> None
''' '''
When Django gives us dict objects via model_to_dict, the foreign When Django gives us dict objects via model_to_dict, the foreign
key fields are `foo`, but we want `foo_id` for the bulk insert. key fields are `foo`, but we want `foo_id` for the bulk insert.
@ -1262,8 +1238,11 @@ def convert_to_id_fields(data, table, field_name):
item[field_name + "_id"] = item[field_name] item[field_name + "_id"] = item[field_name]
del item[field_name] del item[field_name]
def re_map_foreign_keys(data, table, field_name, related_table, verbose=False): def re_map_foreign_keys(data: TableData,
# type: (TableData, TableName, Field, TableName, bool) -> None table: TableName,
field_name: Field,
related_table: TableName,
verbose: bool=False) -> None:
''' '''
We occasionally need to assign new ids to rows during the We occasionally need to assign new ids to rows during the
import/export process, to accommodate things like existing rows import/export process, to accommodate things like existing rows
@ -1288,14 +1267,12 @@ def re_map_foreign_keys(data, table, field_name, related_table, verbose=False):
item[field_name + "_id"] = new_id item[field_name + "_id"] = new_id
del item[field_name] del item[field_name]
def fix_bitfield_keys(data, table, field_name): def fix_bitfield_keys(data: TableData, table: TableName, field_name: Field) -> None:
# type: (TableData, TableName, Field) -> None
for item in data[table]: for item in data[table]:
item[field_name] = item[field_name + '_mask'] item[field_name] = item[field_name + '_mask']
del item[field_name + '_mask'] del item[field_name + '_mask']
def fix_realm_authentication_bitfield(data, table, field_name): def fix_realm_authentication_bitfield(data: TableData, table: TableName, field_name: Field) -> None:
# type: (TableData, TableName, Field) -> None
"""Used to fixup the authentication_methods bitfield to be a string""" """Used to fixup the authentication_methods bitfield to be a string"""
for item in data[table]: for item in data[table]:
values_as_bitstring = ''.join(['1' if field[1] else '0' for field in values_as_bitstring = ''.join(['1' if field[1] else '0' for field in
@ -1303,8 +1280,7 @@ def fix_realm_authentication_bitfield(data, table, field_name):
values_as_int = int(values_as_bitstring, 2) values_as_int = int(values_as_bitstring, 2)
item[field_name] = values_as_int item[field_name] = values_as_int
def bulk_import_model(data, model, table, dump_file_id=None): def bulk_import_model(data: TableData, model: Any, table: TableName, dump_file_id: str=None) -> None:
# type: (TableData, Any, TableName, str) -> None
# TODO, deprecate dump_file_id # TODO, deprecate dump_file_id
model.objects.bulk_create(model(**item) for item in data[table]) model.objects.bulk_create(model(**item) for item in data[table])
if dump_file_id is None: if dump_file_id is None:
@ -1316,8 +1292,7 @@ def bulk_import_model(data, model, table, dump_file_id=None):
# correctly import multiple realms into the same server, we need to # correctly import multiple realms into the same server, we need to
# check if a Client object already exists, and so we need to support # check if a Client object already exists, and so we need to support
# remap all Client IDs to the values in the new DB. # remap all Client IDs to the values in the new DB.
def bulk_import_client(data, model, table): def bulk_import_client(data: TableData, model: Any, table: TableName) -> None:
# type: (TableData, Any, TableName) -> None
for item in data[table]: for item in data[table]:
try: try:
client = Client.objects.get(name=item['name']) client = Client.objects.get(name=item['name'])
@ -1325,8 +1300,7 @@ def bulk_import_client(data, model, table):
client = Client.objects.create(name=item['name']) client = Client.objects.create(name=item['name'])
update_id_map(table='client', old_id=item['id'], new_id=client.id) update_id_map(table='client', old_id=item['id'], new_id=client.id)
def import_uploads_local(import_dir, processing_avatars=False): def import_uploads_local(import_dir: Path, processing_avatars: bool=False) -> None:
# type: (Path, bool) -> None
records_filename = os.path.join(import_dir, "records.json") records_filename = os.path.join(import_dir, "records.json")
with open(records_filename) as records_file: with open(records_filename) as records_file:
records = ujson.loads(records_file.read()) records = ujson.loads(records_file.read())
@ -1349,8 +1323,7 @@ def import_uploads_local(import_dir, processing_avatars=False):
subprocess.check_call(["mkdir", "-p", os.path.dirname(file_path)]) subprocess.check_call(["mkdir", "-p", os.path.dirname(file_path)])
shutil.copy(orig_file_path, file_path) shutil.copy(orig_file_path, file_path)
def import_uploads_s3(bucket_name, import_dir, processing_avatars=False): def import_uploads_s3(bucket_name: str, import_dir: Path, processing_avatars: bool=False) -> None:
# type: (str, Path, bool) -> None
conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY) conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY)
bucket = conn.get_bucket(bucket_name, validate=True) bucket = conn.get_bucket(bucket_name, validate=True)
@ -1385,8 +1358,7 @@ def import_uploads_s3(bucket_name, import_dir, processing_avatars=False):
key.set_contents_from_filename(os.path.join(import_dir, record['path']), headers=headers) key.set_contents_from_filename(os.path.join(import_dir, record['path']), headers=headers)
def import_uploads(import_dir, processing_avatars=False): def import_uploads(import_dir: Path, processing_avatars: bool=False) -> None:
# type: (Path, bool) -> None
if processing_avatars: if processing_avatars:
logging.info("Importing avatars") logging.info("Importing avatars")
else: else:
@ -1418,8 +1390,7 @@ def import_uploads(import_dir, processing_avatars=False):
# Because the Python object => JSON conversion process is not fully # Because the Python object => JSON conversion process is not fully
# faithful, we have to use a set of fixers (e.g. on DateTime objects # faithful, we have to use a set of fixers (e.g. on DateTime objects
# and Foreign Keys) to do the import correctly. # and Foreign Keys) to do the import correctly.
def do_import_realm(import_dir): def do_import_realm(import_dir: Path) -> None:
# type: (Path) -> None
logging.info("Importing realm dump %s" % (import_dir,)) logging.info("Importing realm dump %s" % (import_dir,))
if not os.path.exists(import_dir): if not os.path.exists(import_dir):
raise Exception("Missing import directory!") raise Exception("Missing import directory!")
@ -1527,8 +1498,7 @@ def do_import_realm(import_dir):
import_attachments(data) import_attachments(data)
def import_message_data(import_dir): def import_message_data(import_dir: Path) -> None:
# type: (Path) -> None
dump_file_id = 1 dump_file_id = 1
while True: while True:
message_filename = os.path.join(import_dir, "messages-%06d.json" % (dump_file_id,)) message_filename = os.path.join(import_dir, "messages-%06d.json" % (dump_file_id,))
@ -1555,8 +1525,7 @@ def import_message_data(import_dir):
dump_file_id += 1 dump_file_id += 1
def import_attachments(data): def import_attachments(data: TableData) -> None:
# type: (TableData) -> None
# Clean up the data in zerver_attachment that is not # Clean up the data in zerver_attachment that is not
# relevant to our many-to-many import. # relevant to our many-to-many import.

View File

@ -12,8 +12,7 @@ from typing import Any, List, Dict, Optional, Text
import os import os
import ujson import ujson
def with_language(string, language): def with_language(string: Text, language: Text) -> Text:
# type: (Text, Text) -> Text
""" """
This is an expensive function. If you are using it in a loop, it will This is an expensive function. If you are using it in a loop, it will
make your code slow. make your code slow.
@ -25,15 +24,13 @@ def with_language(string, language):
return result return result
@lru_cache()
def get_language_list() -> List[Dict[str, Any]]:
    """Return the list of supported languages.

    Loaded once (then cached by lru_cache) from the
    language_name_map.json file shipped under STATIC_ROOT.
    """
    path = os.path.join(settings.STATIC_ROOT, 'locale', 'language_name_map.json')
    with open(path, 'r') as reader:
        languages = ujson.load(reader)
        return languages['name_map']
def get_language_list_for_templates(default_language): def get_language_list_for_templates(default_language: Text) -> List[Dict[str, Dict[str, str]]]:
# type: (Text) -> List[Dict[str, Dict[str, str]]]
language_list = [l for l in get_language_list() language_list = [l for l in get_language_list()
if 'percent_translated' not in l or if 'percent_translated' not in l or
l['percent_translated'] >= 5.] l['percent_translated'] >= 5.]
@ -70,15 +67,13 @@ def get_language_list_for_templates(default_language):
return formatted_list return formatted_list
def get_language_name(code: str) -> Optional[Text]:
    """Map a language code or locale (e.g. 'en' or 'en_US') to its
    display name, or return None if the code is unknown."""
    for lang in get_language_list():
        if code in (lang['code'], lang['locale']):
            return lang['name']
    return None
def get_available_language_codes() -> List[Text]:
    """Return just the language codes of every supported language."""
    language_list = get_language_list()
    codes = [language['code'] for language in language_list]
    return codes

View File

@ -17,8 +17,7 @@ from logging import Logger
class _RateLimitFilter: class _RateLimitFilter:
last_error = datetime.min.replace(tzinfo=timezone_utc) last_error = datetime.min.replace(tzinfo=timezone_utc)
def filter(self, record): def filter(self, record: logging.LogRecord) -> bool:
# type: (logging.LogRecord) -> bool
from django.conf import settings from django.conf import settings
from django.core.cache import cache from django.core.cache import cache
@ -58,23 +57,19 @@ class EmailLimiter(_RateLimitFilter):
pass pass
class ReturnTrue(logging.Filter):
    """Logging filter that accepts every record unconditionally."""

    def filter(self, record: logging.LogRecord) -> bool:
        return True
class ReturnEnabled(logging.Filter):
    """Logging filter gated on the LOGGING_NOT_DISABLED setting.

    Lets records through only when logging has not been globally
    disabled (used to silence logging in some test setups).
    """

    def filter(self, record: logging.LogRecord) -> bool:
        return settings.LOGGING_NOT_DISABLED
class RequireReallyDeployed(logging.Filter):
    """Logging filter that passes records only in production deployments."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Imported locally so importing this module does not require
        # Django settings to be configured.
        from django.conf import settings
        return settings.PRODUCTION
def skip_200_and_304(record): def skip_200_and_304(record: logging.LogRecord) -> bool:
# type: (logging.LogRecord) -> bool
# Apparently, `status_code` is added by Django and is not an actual # Apparently, `status_code` is added by Django and is not an actual
# attribute of LogRecord; as a result, mypy throws an error if we # attribute of LogRecord; as a result, mypy throws an error if we
# access the `status_code` attribute directly. # access the `status_code` attribute directly.
@ -91,8 +86,7 @@ IGNORABLE_404_URLS = [
re.compile(r'^/wp-login.php$'), re.compile(r'^/wp-login.php$'),
] ]
def skip_boring_404s(record): def skip_boring_404s(record: logging.LogRecord) -> bool:
# type: (logging.LogRecord) -> bool
"""Prevents Django's 'Not Found' warnings from being logged for common """Prevents Django's 'Not Found' warnings from being logged for common
404 errors that don't reflect a problem in Zulip. The overall 404 errors that don't reflect a problem in Zulip. The overall
result is to keep the Zulip error logs cleaner than they would result is to keep the Zulip error logs cleaner than they would
@ -116,8 +110,7 @@ def skip_boring_404s(record):
return False return False
return True return True
def skip_site_packages_logs(record): def skip_site_packages_logs(record: logging.LogRecord) -> bool:
# type: (logging.LogRecord) -> bool
# This skips the log records that are generated from libraries # This skips the log records that are generated from libraries
# installed in site packages. # installed in site packages.
# Workaround for https://code.djangoproject.com/ticket/26886 # Workaround for https://code.djangoproject.com/ticket/26886
@ -125,8 +118,7 @@ def skip_site_packages_logs(record):
return False return False
return True return True
def find_log_caller_module(record): def find_log_caller_module(record: logging.LogRecord) -> Optional[str]:
# type: (logging.LogRecord) -> Optional[str]
'''Find the module name corresponding to where this record was logged.''' '''Find the module name corresponding to where this record was logged.'''
# Repeat a search similar to that in logging.Logger.findCaller. # Repeat a search similar to that in logging.Logger.findCaller.
# The logging call should still be on the stack somewhere; search until # The logging call should still be on the stack somewhere; search until
@ -144,8 +136,7 @@ logger_nicknames = {
'zulip.requests': 'zr', # Super common. 'zulip.requests': 'zr', # Super common.
} }
def find_log_origin(record): def find_log_origin(record: logging.LogRecord) -> str:
# type: (logging.LogRecord) -> str
logger_name = logger_nicknames.get(record.name, record.name) logger_name = logger_nicknames.get(record.name, record.name)
if settings.LOGGING_SHOW_MODULE: if settings.LOGGING_SHOW_MODULE:
@ -166,8 +157,7 @@ log_level_abbrevs = {
'CRITICAL': 'CRIT', 'CRITICAL': 'CRIT',
} }
def abbrev_log_levelname(levelname: str) -> str:
    """Abbreviate a log level name for compact log lines.

    It's unlikely someone will set a custom log level with a custom
    name, but it's an option, so we shouldn't crash if someone does:
    names missing from log_level_abbrevs fall back to their first
    four characters.
    """
    return log_level_abbrevs.get(levelname, levelname[:4])
@ -176,20 +166,17 @@ class ZulipFormatter(logging.Formatter):
# Used in the base implementation. Default uses `,`. # Used in the base implementation. Default uses `,`.
default_msec_format = '%s.%03d' default_msec_format = '%s.%03d'
# NOTE(review): method of ZulipFormatter(logging.Formatter) — the class
# header is outside this reconstructed span.
def __init__(self) -> None:
    """Initialize the base Formatter with the computed format string."""
    super().__init__(fmt=self._compute_fmt())
# NOTE(review): method of ZulipFormatter(logging.Formatter) — the class
# header is outside this reconstructed span.
def _compute_fmt(self) -> str:
    """Build the log-line format string, honoring LOGGING_SHOW_PID."""
    pieces = ['%(asctime)s', '%(zulip_level_abbrev)-4s']
    if settings.LOGGING_SHOW_PID:
        pieces.append('pid:%(process)d')
    pieces.extend(['[%(zulip_origin)s]', '%(message)s'])
    return ' '.join(pieces)
def format(self, record): def format(self, record: logging.LogRecord) -> str:
# type: (logging.LogRecord) -> str
if not getattr(record, 'zulip_decorated', False): if not getattr(record, 'zulip_decorated', False):
# The `setattr` calls put this logic explicitly outside the bounds of the # The `setattr` calls put this logic explicitly outside the bounds of the
# type system; otherwise mypy would complain LogRecord lacks these attributes. # type system; otherwise mypy would complain LogRecord lacks these attributes.
@ -198,8 +185,10 @@ class ZulipFormatter(logging.Formatter):
setattr(record, 'zulip_decorated', True) setattr(record, 'zulip_decorated', True)
return super().format(record) return super().format(record)
def create_logger(name, log_file, log_level, log_format="%(asctime)s %(levelname)-8s %(message)s"): def create_logger(name: str,
# type: (str, str, str, str) -> Logger log_file: str,
log_level: str,
log_format: str="%(asctime)s%(levelname)-8s%(message)s") -> Logger:
"""Creates a named logger for use in logging content to a certain """Creates a named logger for use in logging content to a certain
file. A few notes: file. A few notes:

View File

@ -10,12 +10,10 @@ user_group_mentions = r'(?<![^\s\'\"\(,:<])@(\*[^\*]+\*)'
# Mention names (e.g. @**all**) that address everyone rather than one user.
wildcards = ['all', 'everyone']

def user_mention_matches_wildcard(mention: Text) -> bool:
    """Return True if the mention text is a wildcard mention."""
    return mention in wildcards
def extract_name(s): def extract_name(s: Text) -> Optional[Text]:
# type: (Text) -> Optional[Text]
if s.startswith("**") and s.endswith("**"): if s.startswith("**") and s.endswith("**"):
name = s[2:-2] name = s[2:-2]
if name in wildcards: if name in wildcards:
@ -25,18 +23,15 @@ def extract_name(s):
# We don't care about @all or @everyone # We don't care about @all or @everyone
return None return None
def possible_mentions(content: Text) -> Set[Text]:
    """Scan message content for candidate @-mention names.

    Returns the set of extracted names; wildcard mentions are excluded
    because extract_name returns None for them.
    """
    matches = re.findall(find_mentions, content)
    names_with_none = (extract_name(match) for match in matches)
    names = {name for name in names_with_none if name}
    return names
def extract_user_group(matched_text: Text) -> Text:
    """Strip the surrounding '*...*' delimiters from a user-group
    mention regex match."""
    return matched_text[1:-1]
def possible_user_group_mentions(content: Text) -> Set[Text]:
    """Return the set of user-group names @-mentioned in the content,
    using the module-level user_group_mentions regex."""
    matches = re.findall(user_group_mentions, content)
    return {extract_user_group(match) for match in matches}

View File

@ -97,9 +97,8 @@ def messages_for_ids(message_ids: List[int],
return message_list return message_list
def sew_messages_and_reactions(messages: List[Dict[str, Any]],
def sew_messages_and_reactions(messages, reactions): reactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
# type: (List[Dict[str, Any]], List[Dict[str, Any]]) -> List[Dict[str, Any]]
"""Given a iterable of messages and reactions stitch reactions """Given a iterable of messages and reactions stitch reactions
into messages. into messages.
""" """
@ -117,23 +116,19 @@ def sew_messages_and_reactions(messages, reactions):
return list(converted_messages.values()) return list(converted_messages.values())
def extract_message_dict(message_bytes: bytes) -> Dict[str, Any]:
    """Decode a cached message payload: zlib-decompress, then JSON-parse.

    Inverse of stringify_message_dict.
    """
    return ujson.loads(zlib.decompress(message_bytes).decode("utf-8"))
def stringify_message_dict(message_dict: Dict[str, Any]) -> bytes:
    """Encode a message dict for caching: JSON-serialize, then
    zlib-compress.  Inverse of extract_message_dict."""
    return zlib.compress(ujson.dumps(message_dict).encode())
@cache_with_key(to_dict_cache_key, timeout=3600*24)
def message_to_dict_json(message: Message) -> bytes:
    """Return (computing and caching for 24 hours) the compressed
    JSON-dict form of a message."""
    return MessageDict.to_dict_uncached(message)
class MessageDict: class MessageDict:
@staticmethod @staticmethod
def wide_dict(message): def wide_dict(message: Message) -> Dict[str, Any]:
# type: (Message) -> Dict[str, Any]
''' '''
The next two lines get the cachable field related The next two lines get the cachable field related
to our message object, with the side effect of to our message object, with the side effect of
@ -154,8 +149,7 @@ class MessageDict:
return obj return obj
@staticmethod @staticmethod
def post_process_dicts(objs, apply_markdown, client_gravatar): def post_process_dicts(objs: List[Dict[str, Any]], apply_markdown: bool, client_gravatar: bool) -> None:
# type: (List[Dict[str, Any]], bool, bool) -> None
MessageDict.bulk_hydrate_sender_info(objs) MessageDict.bulk_hydrate_sender_info(objs)
for obj in objs: for obj in objs:
@ -163,10 +157,10 @@ class MessageDict:
MessageDict.finalize_payload(obj, apply_markdown, client_gravatar) MessageDict.finalize_payload(obj, apply_markdown, client_gravatar)
@staticmethod @staticmethod
def finalize_payload(obj, apply_markdown, client_gravatar): def finalize_payload(obj: Dict[str, Any],
# type: (Dict[str, Any], bool, bool) -> None apply_markdown: bool,
client_gravatar: bool) -> None:
MessageDict.set_sender_avatar(obj, client_gravatar) MessageDict.set_sender_avatar(obj, client_gravatar)
if apply_markdown: if apply_markdown:
obj['content_type'] = 'text/html' obj['content_type'] = 'text/html'
obj['content'] = obj['rendered_content'] obj['content'] = obj['rendered_content']
@ -184,14 +178,12 @@ class MessageDict:
del obj['sender_is_mirror_dummy'] del obj['sender_is_mirror_dummy']
@staticmethod @staticmethod
def to_dict_uncached(message): def to_dict_uncached(message: Message) -> bytes:
# type: (Message) -> bytes
dct = MessageDict.to_dict_uncached_helper(message) dct = MessageDict.to_dict_uncached_helper(message)
return stringify_message_dict(dct) return stringify_message_dict(dct)
@staticmethod @staticmethod
def to_dict_uncached_helper(message): def to_dict_uncached_helper(message: Message) -> Dict[str, Any]:
# type: (Message) -> Dict[str, Any]
return MessageDict.build_message_dict( return MessageDict.build_message_dict(
message = message, message = message,
message_id = message.id, message_id = message.id,
@ -212,8 +204,7 @@ class MessageDict:
) )
@staticmethod @staticmethod
def get_raw_db_rows(needed_ids): def get_raw_db_rows(needed_ids: List[int]) -> List[Dict[str, Any]]:
# type: (List[int]) -> List[Dict[str, Any]]
# This is a special purpose function optimized for # This is a special purpose function optimized for
# callers like get_messages_backend(). # callers like get_messages_backend().
fields = [ fields = [
@ -242,8 +233,7 @@ class MessageDict:
return sew_messages_and_reactions(messages, reactions) return sew_messages_and_reactions(messages, reactions)
@staticmethod @staticmethod
def build_dict_from_raw_db_row(row): def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]:
# type: (Dict[str, Any]) -> Dict[str, Any]
''' '''
row is a row from a .values() call, and it needs to have row is a row from a .values() call, and it needs to have
all the relevant fields populated all the relevant fields populated
@ -352,8 +342,7 @@ class MessageDict:
return obj return obj
@staticmethod @staticmethod
def bulk_hydrate_sender_info(objs): def bulk_hydrate_sender_info(objs: List[Dict[str, Any]]) -> None:
# type: (List[Dict[str, Any]]) -> None
sender_ids = list({ sender_ids = list({
obj['sender_id'] obj['sender_id']
@ -393,8 +382,7 @@ class MessageDict:
obj['sender_is_mirror_dummy'] = user_row['is_mirror_dummy'] obj['sender_is_mirror_dummy'] = user_row['is_mirror_dummy']
@staticmethod @staticmethod
def hydrate_recipient_info(obj): def hydrate_recipient_info(obj: Dict[str, Any]) -> None:
# type: (Dict[str, Any]) -> None
''' '''
This method hyrdrates recipient info with things This method hyrdrates recipient info with things
like full names and emails of senders. Eventually like full names and emails of senders. Eventually
@ -437,8 +425,7 @@ class MessageDict:
obj['stream_id'] = recipient_type_id obj['stream_id'] = recipient_type_id
@staticmethod @staticmethod
def set_sender_avatar(obj, client_gravatar): def set_sender_avatar(obj: Dict[str, Any], client_gravatar: bool) -> None:
# type: (Dict[str, Any], bool) -> None
sender_id = obj['sender_id'] sender_id = obj['sender_id']
sender_realm_id = obj['sender_realm_id'] sender_realm_id = obj['sender_realm_id']
sender_email = obj['sender_email'] sender_email = obj['sender_email']
@ -457,8 +444,7 @@ class MessageDict:
class ReactionDict: class ReactionDict:
@staticmethod @staticmethod
def build_dict_from_raw_db_row(row): def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]:
# type: (Dict[str, Any]) -> Dict[str, Any]
return {'emoji_name': row['emoji_name'], return {'emoji_name': row['emoji_name'],
'emoji_code': row['emoji_code'], 'emoji_code': row['emoji_code'],
'reaction_type': row['reaction_type'], 'reaction_type': row['reaction_type'],
@ -467,8 +453,7 @@ class ReactionDict:
'full_name': row['user_profile__full_name']}} 'full_name': row['user_profile__full_name']}}
def access_message(user_profile, message_id): def access_message(user_profile: UserProfile, message_id: int) -> Tuple[Message, UserMessage]:
# type: (UserProfile, int) -> Tuple[Message, UserMessage]
"""You can access a message by ID in our APIs that either: """You can access a message by ID in our APIs that either:
(1) You received or have previously accessed via starring (1) You received or have previously accessed via starring
(aka have a UserMessage row for). (aka have a UserMessage row for).
@ -506,9 +491,13 @@ def access_message(user_profile, message_id):
# stream in your realm, so return the message, user_message pair # stream in your realm, so return the message, user_message pair
return (message, user_message) return (message, user_message)
def render_markdown(message, content, realm=None, realm_alert_words=None, user_ids=None, def render_markdown(message: Message,
mention_data=None, email_gateway=False): content: Text,
# type: (Message, Text, Optional[Realm], Optional[RealmAlertWords], Optional[Set[int]], Optional[bugdown.MentionData], Optional[bool]) -> Text realm: Optional[Realm]=None,
realm_alert_words: Optional[RealmAlertWords]=None,
user_ids: Optional[Set[int]]=None,
mention_data: Optional[bugdown.MentionData]=None,
email_gateway: Optional[bool]=False) -> Text:
"""Return HTML for given markdown. Bugdown may add properties to the """Return HTML for given markdown. Bugdown may add properties to the
message object such as `mentions_user_ids`, `mentions_user_group_ids`, and message object such as `mentions_user_ids`, `mentions_user_group_ids`, and
`mentions_wildcard`. These are only on this Django object and are not `mentions_wildcard`. These are only on this Django object and are not
@ -565,8 +554,7 @@ def render_markdown(message, content, realm=None, realm_alert_words=None, user_i
return rendered_content return rendered_content
def huddle_users(recipient_id): def huddle_users(recipient_id: int) -> str:
# type: (int) -> str
display_recipient = get_display_recipient_by_id(recipient_id, display_recipient = get_display_recipient_by_id(recipient_id,
Recipient.HUDDLE, Recipient.HUDDLE,
None) # type: Union[Text, List[Dict[str, Any]]] None) # type: Union[Text, List[Dict[str, Any]]]
@ -578,8 +566,9 @@ def huddle_users(recipient_id):
user_ids = sorted(user_ids) user_ids = sorted(user_ids)
return ','.join(str(uid) for uid in user_ids) return ','.join(str(uid) for uid in user_ids)
def aggregate_message_dict(input_dict, lookup_fields, collect_senders): def aggregate_message_dict(input_dict: Dict[int, Dict[str, Any]],
# type: (Dict[int, Dict[str, Any]], List[str], bool) -> List[Dict[str, Any]] lookup_fields: List[str],
collect_senders: bool) -> List[Dict[str, Any]]:
lookup_dict = dict() # type: Dict[Tuple[Any, ...], Dict[str, Any]] lookup_dict = dict() # type: Dict[Tuple[Any, ...], Dict[str, Any]]
''' '''
@ -639,8 +628,7 @@ def aggregate_message_dict(input_dict, lookup_fields, collect_senders):
return [lookup_dict[k] for k in sorted_keys] return [lookup_dict[k] for k in sorted_keys]
def get_inactive_recipient_ids(user_profile): def get_inactive_recipient_ids(user_profile: UserProfile) -> List[int]:
# type: (UserProfile) -> List[int]
rows = get_stream_subscriptions_for_user(user_profile).filter( rows = get_stream_subscriptions_for_user(user_profile).filter(
active=False, active=False,
).values( ).values(
@ -651,8 +639,7 @@ def get_inactive_recipient_ids(user_profile):
for row in rows] for row in rows]
return inactive_recipient_ids return inactive_recipient_ids
def get_muted_stream_ids(user_profile): def get_muted_stream_ids(user_profile: UserProfile) -> List[int]:
# type: (UserProfile) -> List[int]
rows = get_stream_subscriptions_for_user(user_profile).filter( rows = get_stream_subscriptions_for_user(user_profile).filter(
active=True, active=True,
in_home_view=False, in_home_view=False,
@ -664,8 +651,7 @@ def get_muted_stream_ids(user_profile):
for row in rows] for row in rows]
return muted_stream_ids return muted_stream_ids
def get_raw_unread_data(user_profile): def get_raw_unread_data(user_profile: UserProfile) -> RawUnreadMessagesResult:
# type: (UserProfile) -> RawUnreadMessagesResult
excluded_recipient_ids = get_inactive_recipient_ids(user_profile) excluded_recipient_ids = get_inactive_recipient_ids(user_profile)
@ -694,8 +680,7 @@ def get_raw_unread_data(user_profile):
topic_mute_checker = build_topic_mute_checker(user_profile) topic_mute_checker = build_topic_mute_checker(user_profile)
def is_row_muted(stream_id, recipient_id, topic): def is_row_muted(stream_id: int, recipient_id: int, topic: Text) -> bool:
# type: (int, int, Text) -> bool
if stream_id in muted_stream_ids: if stream_id in muted_stream_ids:
return True return True
@ -706,8 +691,7 @@ def get_raw_unread_data(user_profile):
huddle_cache = {} # type: Dict[int, str] huddle_cache = {} # type: Dict[int, str]
def get_huddle_users(recipient_id): def get_huddle_users(recipient_id: int) -> str:
# type: (int) -> str
if recipient_id in huddle_cache: if recipient_id in huddle_cache:
return huddle_cache[recipient_id] return huddle_cache[recipient_id]
@ -762,8 +746,7 @@ def get_raw_unread_data(user_profile):
mentions=mentions, mentions=mentions,
) )
def aggregate_unread_data(raw_data): def aggregate_unread_data(raw_data: RawUnreadMessagesResult) -> UnreadMessagesResult:
# type: (RawUnreadMessagesResult) -> UnreadMessagesResult
pm_dict = raw_data['pm_dict'] pm_dict = raw_data['pm_dict']
stream_dict = raw_data['stream_dict'] stream_dict = raw_data['stream_dict']
@ -807,8 +790,10 @@ def aggregate_unread_data(raw_data):
return result return result
def apply_unread_message_event(user_profile, state, message, flags): def apply_unread_message_event(user_profile: UserProfile,
# type: (UserProfile, Dict[str, Any], Dict[str, Any], List[str]) -> None state: Dict[str, Any],
message: Dict[str, Any],
flags: List[str]) -> None:
message_id = message['id'] message_id = message['id']
if message['type'] == 'stream': if message['type'] == 'stream':
message_type = 'stream' message_type = 'stream'

View File

@ -3,8 +3,7 @@ from django.db.models.query import QuerySet
import re import re
import time import time
def timed_ddl(db, stmt): def timed_ddl(db: Any, stmt: str) -> None:
# type: (Any, str) -> None
print() print()
print(time.asctime()) print(time.asctime())
print(stmt) print(stmt)
@ -13,14 +12,17 @@ def timed_ddl(db, stmt):
delay = time.time() - t delay = time.time() - t
print('Took %.2fs' % (delay,)) print('Took %.2fs' % (delay,))
def validate(sql_thingy: str) -> None:
    """Do basic validation that a table/column name is safe to splice
    into DDL statements.

    Accepts a lowercase letter followed by one or more lowercase
    letters, digits, or underscores; raises Exception otherwise.
    """
    # Raw string: '\d' in a non-raw literal is an invalid escape
    # sequence (DeprecationWarning, then SyntaxError in newer Pythons).
    if not re.match(r'^[a-z][a-z\d_]+$', sql_thingy):
        raise Exception('Invalid SQL object: %s' % (sql_thingy,))
def do_batch_update(db, table, cols, vals, batch_size=10000, sleep=0.1): def do_batch_update(db: Any,
# type: (Any, str, List[str], List[str], int, float) -> None table: str,
cols: List[str],
vals: List[str],
batch_size: int=10000,
sleep: float=0.1) -> None:
validate(table) validate(table)
for col in cols: for col in cols:
validate(col) validate(col)
@ -46,8 +48,7 @@ def do_batch_update(db, table, cols, vals, batch_size=10000, sleep=0.1):
min_id = upper min_id = upper
time.sleep(sleep) time.sleep(sleep)
def add_bool_columns(db, table, cols): def add_bool_columns(db: Any, table: str, cols: List[str]) -> None:
# type: (Any, str, List[str]) -> None
validate(table) validate(table)
for col in cols: for col in cols:
validate(col) validate(col)
@ -72,8 +73,8 @@ def add_bool_columns(db, table, cols):
', '.join(['ALTER %s SET NOT NULL' % (col,) for col in cols])) ', '.join(['ALTER %s SET NOT NULL' % (col,) for col in cols]))
timed_ddl(db, stmt) timed_ddl(db, stmt)
def create_index_if_not_exist(index_name, table_name, column_string, where_clause): def create_index_if_not_exist(index_name: Text, table_name: Text, column_string: Text,
# type: (Text, Text, Text, Text) -> Text where_clause: Text) -> Text:
# #
# FUTURE TODO: When we no longer need to support postgres 9.3 for Trusty, # FUTURE TODO: When we no longer need to support postgres 9.3 for Trusty,
# we can use "IF NOT EXISTS", which is part of postgres 9.5 # we can use "IF NOT EXISTS", which is part of postgres 9.5
@ -95,8 +96,11 @@ def create_index_if_not_exist(index_name, table_name, column_string, where_claus
''' % (index_name, index_name, table_name, column_string, where_clause) ''' % (index_name, index_name, table_name, column_string, where_clause)
return stmt return stmt
def act_on_message_ranges(db, orm, tasks, batch_size=5000, sleep=0.5): def act_on_message_ranges(db: Any,
# type: (Any, Dict[str, Any], List[Tuple[Callable[[QuerySet], QuerySet], Callable[[QuerySet], None]]], int , float) -> None orm: Dict[str, Any],
tasks: List[Tuple[Callable[[QuerySet], QuerySet], Callable[[QuerySet], None]]],
batch_size: int=5000,
sleep: float=0.5) -> None:
# tasks should be an array of (filterer, action) tuples # tasks should be an array of (filterer, action) tuples
# where filterer is a function that returns a filtered QuerySet # where filterer is a function that returns a filtered QuerySet
# and action is a function that acts on a QuerySet # and action is a function that acts on a QuerySet

View File

@ -4,21 +4,18 @@ from django.utils.translation import ugettext as _
from typing import Any, Callable, Iterable, Mapping, Sequence, Text from typing import Any, Callable, Iterable, Mapping, Sequence, Text
def check_supported_events_narrow_filter(narrow: Iterable[Sequence[Text]]) -> None:
    """Raise JsonableError if any narrow operator is unsupported for the
    events API (only stream/topic/sender/is narrows are supported)."""
    for element in narrow:
        operator = element[0]
        if operator not in ["stream", "topic", "sender", "is"]:
            raise JsonableError(_("Operator %s not supported.") % (operator,))
def build_narrow_filter(narrow): def build_narrow_filter(narrow: Iterable[Sequence[Text]]) -> Callable[[Mapping[str, Any]], bool]:
# type: (Iterable[Sequence[Text]]) -> Callable[[Mapping[str, Any]], bool]
"""Changes to this function should come with corresponding changes to """Changes to this function should come with corresponding changes to
BuildNarrowFilterTest.""" BuildNarrowFilterTest."""
check_supported_events_narrow_filter(narrow) check_supported_events_narrow_filter(narrow)
def narrow_filter(event): def narrow_filter(event: Mapping[str, Any]) -> bool:
# type: (Mapping[str, Any]) -> bool
message = event["message"] message = event["message"]
flags = event["flags"] flags = event["flags"]
for element in narrow: for element in narrow:

View File

@ -21,8 +21,7 @@ from zerver.decorator import JsonableError
class OutgoingWebhookServiceInterface: class OutgoingWebhookServiceInterface:
def __init__(self, base_url, token, user_profile, service_name): def __init__(self, base_url: Text, token: Text, user_profile: UserProfile, service_name: Text) -> None:
# type: (Text, Text, UserProfile, Text) -> None
self.base_url = base_url # type: Text self.base_url = base_url # type: Text
self.token = token # type: Text self.token = token # type: Text
self.user_profile = user_profile # type: Text self.user_profile = user_profile # type: Text
@ -37,20 +36,17 @@ class OutgoingWebhookServiceInterface:
# - base_url # - base_url
# - relative_url_path # - relative_url_path
# - request_kwargs # - request_kwargs
# NOTE(review): method of OutgoingWebhookServiceInterface — the class
# header is outside this reconstructed span.
# Given an event, returns the REST operation to perform plus the request
# data to send; subclasses must override this stub.
def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]:
    raise NotImplementedError()
# Given a successful outgoing webhook REST operation, returns the message # Given a successful outgoing webhook REST operation, returns the message
# to sent back to the user (or None if no message should be sent). # to sent back to the user (or None if no message should be sent).
# NOTE(review): method of OutgoingWebhookServiceInterface — the class
# header is outside this reconstructed span.
# Given a successful outgoing webhook REST operation, returns the message
# to send back to the user (or None if no message should be sent);
# subclasses must override this stub.
def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]:
    raise NotImplementedError()
class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface): class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface):
def process_event(self, event): def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]:
# type: (Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]
rest_operation = {'method': 'POST', rest_operation = {'method': 'POST',
'relative_url_path': '', 'relative_url_path': '',
'base_url': self.base_url, 'base_url': self.base_url,
@ -60,8 +56,7 @@ class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface):
"token": self.token} "token": self.token}
return rest_operation, json.dumps(request_data) return rest_operation, json.dumps(request_data)
def process_success(self, response, event): def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]:
# type: (Response, Dict[Text, Any]) -> Optional[str]
response_json = json.loads(response.text) response_json = json.loads(response.text)
if "response_not_required" in response_json and response_json['response_not_required']: if "response_not_required" in response_json and response_json['response_not_required']:
@ -73,8 +68,7 @@ class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface):
class SlackOutgoingWebhookService(OutgoingWebhookServiceInterface): class SlackOutgoingWebhookService(OutgoingWebhookServiceInterface):
def process_event(self, event): def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]:
# type: (Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]
rest_operation = {'method': 'POST', rest_operation = {'method': 'POST',
'relative_url_path': '', 'relative_url_path': '',
'base_url': self.base_url, 'base_url': self.base_url,
@ -99,8 +93,7 @@ class SlackOutgoingWebhookService(OutgoingWebhookServiceInterface):
return rest_operation, request_data return rest_operation, request_data
def process_success(self, response, event): def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]:
# type: (Response, Dict[Text, Any]) -> Optional[str]
response_json = json.loads(response.text) response_json = json.loads(response.text)
if "text" in response_json: if "text" in response_json:
return response_json["text"] return response_json["text"]
@ -112,15 +105,13 @@ AVAILABLE_OUTGOING_WEBHOOK_INTERFACES = {
SLACK_INTERFACE: SlackOutgoingWebhookService, SLACK_INTERFACE: SlackOutgoingWebhookService,
} # type: Dict[Text, Any] } # type: Dict[Text, Any]
def get_service_interface_class(interface: Text) -> Any:
    """Map an interface name to its outgoing-webhook service class.

    Unknown (or None) interface names fall back to the generic
    interface, so callers always get a usable class.
    """
    if interface is None or interface not in AVAILABLE_OUTGOING_WEBHOOK_INTERFACES:
        return AVAILABLE_OUTGOING_WEBHOOK_INTERFACES[GENERIC_INTERFACE]
    else:
        return AVAILABLE_OUTGOING_WEBHOOK_INTERFACES[interface]
def get_outgoing_webhook_service_handler(service): def get_outgoing_webhook_service_handler(service: Service) -> Any:
# type: (Service) -> Any
service_interface_class = get_service_interface_class(service.interface_name()) service_interface_class = get_service_interface_class(service.interface_name())
service_interface = service_interface_class(base_url=service.base_url, service_interface = service_interface_class(base_url=service.base_url,
@ -129,8 +120,7 @@ def get_outgoing_webhook_service_handler(service):
service_name=service.name) service_name=service.name)
return service_interface return service_interface
def send_response_message(bot_id, message, response_message_content): def send_response_message(bot_id: str, message: Dict[str, Any], response_message_content: Text) -> None:
# type: (str, Dict[str, Any], Text) -> None
recipient_type_name = message['type'] recipient_type_name = message['type']
bot_user = get_user_profile_by_id(bot_id) bot_user = get_user_profile_by_id(bot_id)
realm = bot_user.realm realm = bot_user.realm
@ -146,18 +136,15 @@ def send_response_message(bot_id, message, response_message_content):
else: else:
raise JsonableError(_("Invalid message type")) raise JsonableError(_("Invalid message type"))
def succeed_with_message(event, success_message): def succeed_with_message(event: Dict[str, Any], success_message: Text) -> None:
# type: (Dict[str, Any], Text) -> None
success_message = "Success! " + success_message success_message = "Success! " + success_message
send_response_message(event['user_profile_id'], event['message'], success_message) send_response_message(event['user_profile_id'], event['message'], success_message)
def fail_with_message(event, failure_message): def fail_with_message(event: Dict[str, Any], failure_message: Text) -> None:
# type: (Dict[str, Any], Text) -> None
failure_message = "Failure! " + failure_message failure_message = "Failure! " + failure_message
send_response_message(event['user_profile_id'], event['message'], failure_message) send_response_message(event['user_profile_id'], event['message'], failure_message)
def get_message_url(event, request_data): def get_message_url(event: Dict[str, Any], request_data: Dict[str, Any]) -> Text:
# type: (Dict[str, Any], Dict[str, Any]) -> Text
bot_user = get_user_profile_by_id(event['user_profile_id']) bot_user = get_user_profile_by_id(event['user_profile_id'])
message = event['message'] message = event['message']
if message['type'] == 'stream': if message['type'] == 'stream':
@ -175,8 +162,11 @@ def get_message_url(event, request_data):
'id': str(message['id'])}) 'id': str(message['id'])})
return message_url return message_url
def notify_bot_owner(event, request_data, status_code=None, response_content=None, exception=None): def notify_bot_owner(event: Dict[str, Any],
# type: (Dict[str, Any], Dict[str, Any], Optional[int], Optional[AnyStr], Optional[Exception]) -> None request_data: Dict[str, Any],
status_code: Optional[int]=None,
response_content: Optional[AnyStr]=None,
exception: Optional[Exception]=None) -> None:
message_url = get_message_url(event, request_data) message_url = get_message_url(event, request_data)
bot_id = event['user_profile_id'] bot_id = event['user_profile_id']
bot_owner = get_user_profile_by_id(bot_id).bot_owner bot_owner = get_user_profile_by_id(bot_id).bot_owner
@ -194,10 +184,11 @@ def notify_bot_owner(event, request_data, status_code=None, response_content=Non
type(exception).__name__, str(exception)) type(exception).__name__, str(exception))
send_response_message(bot_id, message_info, notification_message) send_response_message(bot_id, message_info, notification_message)
def request_retry(event, request_data, failure_message, exception=None): def request_retry(event: Dict[str, Any],
# type: (Dict[str, Any], Dict[str, Any], Text, Optional[Exception]) -> None request_data: Dict[str, Any],
def failure_processor(event): failure_message: Text,
# type: (Dict[str, Any]) -> None exception: Optional[Exception]=None) -> None:
def failure_processor(event: Dict[str, Any]) -> None:
""" """
The name of the argument is 'event' on purpose. This argument will hide The name of the argument is 'event' on purpose. This argument will hide
the 'event' argument of the request_retry function. Keeping the same name the 'event' argument of the request_retry function. Keeping the same name
@ -211,8 +202,11 @@ def request_retry(event, request_data, failure_message, exception=None):
retry_event('outgoing_webhooks', event, failure_processor) retry_event('outgoing_webhooks', event, failure_processor)
def do_rest_call(rest_operation, request_data, event, service_handler, timeout=None): def do_rest_call(rest_operation: Dict[str, Any],
# type: (Dict[str, Any], Optional[Dict[str, Any]], Dict[str, Any], Any, Any) -> None request_data: Optional[Dict[str, Any]],
event: Dict[str, Any],
service_handler: Any,
timeout: Any=None) -> None:
rest_operation_validator = check_dict([ rest_operation_validator = check_dict([
('method', check_string), ('method', check_string),
('relative_url_path', check_string), ('relative_url_path', check_string),

View File

@ -21,32 +21,26 @@ rules = settings.RATE_LIMITING_RULES # type: List[Tuple[int, int]]
KEY_PREFIX = '' KEY_PREFIX = ''
class RateLimitedObject: class RateLimitedObject:
def get_keys(self): def get_keys(self) -> List[Text]:
# type: () -> List[Text]
key_fragment = self.key_fragment() key_fragment = self.key_fragment()
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key_fragment, keytype) return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key_fragment, keytype)
for keytype in ['list', 'zset', 'block']] for keytype in ['list', 'zset', 'block']]
def key_fragment(self): def key_fragment(self) -> Text:
# type: () -> Text
raise NotImplementedError() raise NotImplementedError()
def rules(self): def rules(self) -> List[Tuple[int, int]]:
# type: () -> List[Tuple[int, int]]
raise NotImplementedError() raise NotImplementedError()
class RateLimitedUser(RateLimitedObject): class RateLimitedUser(RateLimitedObject):
def __init__(self, user, domain='all'): def __init__(self, user: UserProfile, domain: Text='all') -> None:
# type: (UserProfile, Text) -> None
self.user = user self.user = user
self.domain = domain self.domain = domain
def key_fragment(self): def key_fragment(self) -> Text:
# type: () -> Text
return "{}:{}:{}".format(type(self.user), self.user.id, self.domain) return "{}:{}:{}".format(type(self.user), self.user.id, self.domain)
def rules(self): def rules(self) -> List[Tuple[int, int]]:
# type: () -> List[Tuple[int, int]]
if self.user.rate_limits != "": if self.user.rate_limits != "":
result = [] # type: List[Tuple[int, int]] result = [] # type: List[Tuple[int, int]]
for limit in self.user.rate_limits.split(','): for limit in self.user.rate_limits.split(','):
@ -55,36 +49,30 @@ class RateLimitedUser(RateLimitedObject):
return result return result
return rules return rules
def bounce_redis_key_prefix_for_testing(test_name): def bounce_redis_key_prefix_for_testing(test_name: Text) -> None:
# type: (Text) -> None
global KEY_PREFIX global KEY_PREFIX
KEY_PREFIX = test_name + ':' + Text(os.getpid()) + ':' KEY_PREFIX = test_name + ':' + Text(os.getpid()) + ':'
def max_api_calls(entity): def max_api_calls(entity: RateLimitedObject) -> int:
# type: (RateLimitedObject) -> int
"Returns the API rate limit for the highest limit" "Returns the API rate limit for the highest limit"
return entity.rules()[-1][1] return entity.rules()[-1][1]
def max_api_window(entity): def max_api_window(entity: RateLimitedObject) -> int:
# type: (RateLimitedObject) -> int
"Returns the API time window for the highest limit" "Returns the API time window for the highest limit"
return entity.rules()[-1][0] return entity.rules()[-1][0]
def add_ratelimit_rule(range_seconds, num_requests): def add_ratelimit_rule(range_seconds: int, num_requests: int) -> None:
# type: (int , int) -> None
"Add a rate-limiting rule to the ratelimiter" "Add a rate-limiting rule to the ratelimiter"
global rules global rules
rules.append((range_seconds, num_requests)) rules.append((range_seconds, num_requests))
rules.sort(key=lambda x: x[0]) rules.sort(key=lambda x: x[0])
def remove_ratelimit_rule(range_seconds, num_requests): def remove_ratelimit_rule(range_seconds: int, num_requests: int) -> None:
# type: (int , int) -> None
global rules global rules
rules = [x for x in rules if x[0] != range_seconds and x[1] != num_requests] rules = [x for x in rules if x[0] != range_seconds and x[1] != num_requests]
def block_access(entity, seconds): def block_access(entity: RateLimitedObject, seconds: int) -> None:
# type: (RateLimitedObject, int) -> None
"Manually blocks an entity for the desired number of seconds" "Manually blocks an entity for the desired number of seconds"
_, _, blocking_key = entity.get_keys() _, _, blocking_key = entity.get_keys()
with client.pipeline() as pipe: with client.pipeline() as pipe:
@ -92,13 +80,11 @@ def block_access(entity, seconds):
pipe.expire(blocking_key, seconds) pipe.expire(blocking_key, seconds)
pipe.execute() pipe.execute()
def unblock_access(entity): def unblock_access(entity: RateLimitedObject) -> None:
# type: (RateLimitedObject) -> None
_, _, blocking_key = entity.get_keys() _, _, blocking_key = entity.get_keys()
client.delete(blocking_key) client.delete(blocking_key)
def clear_history(entity): def clear_history(entity: RateLimitedObject) -> None:
# type: (RateLimitedObject) -> None
''' '''
This is only used by test code now, where it's very helpful in This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate. allowing us to run tests quickly, by giving a user a clean slate.
@ -106,8 +92,7 @@ def clear_history(entity):
for key in entity.get_keys(): for key in entity.get_keys():
client.delete(key) client.delete(key)
def _get_api_calls_left(entity, range_seconds, max_calls): def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls: int) -> Tuple[int, float]:
# type: (RateLimitedObject, int, int) -> Tuple[int, float]
list_key, set_key, _ = entity.get_keys() list_key, set_key, _ = entity.get_keys()
# Count the number of values in our sorted set # Count the number of values in our sorted set
# that are between now and the cutoff # that are between now and the cutoff
@ -134,16 +119,14 @@ def _get_api_calls_left(entity, range_seconds, max_calls):
return calls_left, time_reset return calls_left, time_reset
def api_calls_left(entity): def api_calls_left(entity: RateLimitedObject) -> Tuple[int, float]:
# type: (RateLimitedObject) -> Tuple[int, float]
"""Returns how many API calls in this range this client has, as well as when """Returns how many API calls in this range this client has, as well as when
the rate-limit will be reset to 0""" the rate-limit will be reset to 0"""
max_window = max_api_window(entity) max_window = max_api_window(entity)
max_calls = max_api_calls(entity) max_calls = max_api_calls(entity)
return _get_api_calls_left(entity, max_window, max_calls) return _get_api_calls_left(entity, max_window, max_calls)
def is_ratelimited(entity): def is_ratelimited(entity: RateLimitedObject) -> Tuple[bool, float]:
# type: (RateLimitedObject) -> Tuple[bool, float]
"Returns a tuple of (rate_limited, time_till_free)" "Returns a tuple of (rate_limited, time_till_free)"
list_key, set_key, blocking_key = entity.get_keys() list_key, set_key, blocking_key = entity.get_keys()
@ -192,8 +175,7 @@ def is_ratelimited(entity):
# No api calls recorded yet # No api calls recorded yet
return False, 0.0 return False, 0.0
def incr_ratelimit(entity): def incr_ratelimit(entity: RateLimitedObject) -> None:
# type: (RateLimitedObject) -> None
"""Increases the rate-limit for the specified entity""" """Increases the rate-limit for the specified entity"""
list_key, set_key, _ = entity.get_keys() list_key, set_key, _ = entity.get_keys()
now = time.time() now = time.time()

View File

@ -15,8 +15,7 @@ METHODS = ('GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'PATCH')
FLAGS = ('override_api_url_scheme') FLAGS = ('override_api_url_scheme')
@csrf_exempt @csrf_exempt
def rest_dispatch(request, **kwargs): def rest_dispatch(request: HttpRequest, **kwargs: Any) -> HttpResponse:
# type: (HttpRequest, **Any) -> HttpResponse
"""Dispatch to a REST API endpoint. """Dispatch to a REST API endpoint.
Unauthenticated endpoints should not use this, as authentication is verified Unauthenticated endpoints should not use this, as authentication is verified

View File

@ -93,8 +93,7 @@ def send_email(template_prefix, to_user_id=None, to_email=None, from_name=None,
logger.error("Error sending %s email to %s" % (template, mail.to)) logger.error("Error sending %s email to %s" % (template, mail.to))
raise EmailNotDeliveredException raise EmailNotDeliveredException
def send_email_from_dict(email_dict): def send_email_from_dict(email_dict: Mapping[str, Any]) -> None:
# type: (Mapping[str, Any]) -> None
send_email(**dict(email_dict)) send_email(**dict(email_dict))
def send_future_email(template_prefix, to_user_id=None, to_email=None, from_name=None, def send_future_email(template_prefix, to_user_id=None, to_email=None, from_name=None,

View File

@ -17,8 +17,7 @@ def filter_by_subscription_history(
# type: (UserProfile, DefaultDict[int, List[Message]], DefaultDict[int, List[RealmAuditLog]]) -> List[UserMessage] # type: (UserProfile, DefaultDict[int, List[Message]], DefaultDict[int, List[RealmAuditLog]]) -> List[UserMessage]
user_messages_to_insert = [] # type: List[UserMessage] user_messages_to_insert = [] # type: List[UserMessage]
def store_user_message_to_insert(message): def store_user_message_to_insert(message: Message) -> None:
# type: (Message) -> None
message = UserMessage(user_profile=user_profile, message = UserMessage(user_profile=user_profile,
message_id=message['id'], flags=0) message_id=message['id'], flags=0)
user_messages_to_insert.append(message) user_messages_to_insert.append(message)
@ -60,8 +59,7 @@ def filter_by_subscription_history(
store_user_message_to_insert(stream_message) store_user_message_to_insert(stream_message)
return user_messages_to_insert return user_messages_to_insert
def add_missing_messages(user_profile): def add_missing_messages(user_profile: UserProfile) -> None:
# type: (UserProfile) -> None
"""This function takes a soft-deactivated user, and computes and adds """This function takes a soft-deactivated user, and computes and adds
to the database any UserMessage rows that were not created while to the database any UserMessage rows that were not created while
the user was soft-deactivated. The end result is that from the the user was soft-deactivated. The end result is that from the
@ -156,8 +154,7 @@ def add_missing_messages(user_profile):
if len(user_messages_to_insert) > 0: if len(user_messages_to_insert) > 0:
UserMessage.objects.bulk_create(user_messages_to_insert) UserMessage.objects.bulk_create(user_messages_to_insert)
def do_soft_deactivate_user(user_profile): def do_soft_deactivate_user(user_profile: UserProfile) -> None:
# type: (UserProfile) -> None
user_profile.last_active_message_id = UserMessage.objects.filter( user_profile.last_active_message_id = UserMessage.objects.filter(
user_profile=user_profile).order_by( user_profile=user_profile).order_by(
'-message__id')[0].message_id '-message__id')[0].message_id
@ -168,8 +165,7 @@ def do_soft_deactivate_user(user_profile):
logger.info('Soft Deactivated user %s (%s)' % logger.info('Soft Deactivated user %s (%s)' %
(user_profile.id, user_profile.email)) (user_profile.id, user_profile.email))
def do_soft_deactivate_users(users): def do_soft_deactivate_users(users: List[UserProfile]) -> List[UserProfile]:
# type: (List[UserProfile]) -> List[UserProfile]
users_soft_deactivated = [] users_soft_deactivated = []
with transaction.atomic(): with transaction.atomic():
realm_logs = [] realm_logs = []
@ -187,8 +183,7 @@ def do_soft_deactivate_users(users):
RealmAuditLog.objects.bulk_create(realm_logs) RealmAuditLog.objects.bulk_create(realm_logs)
return users_soft_deactivated return users_soft_deactivated
def maybe_catch_up_soft_deactivated_user(user_profile): def maybe_catch_up_soft_deactivated_user(user_profile: UserProfile) -> Union[UserProfile, None]:
# type: (UserProfile) -> Union[UserProfile, None]
if user_profile.long_term_idle: if user_profile.long_term_idle:
add_missing_messages(user_profile) add_missing_messages(user_profile)
user_profile.long_term_idle = False user_profile.long_term_idle = False
@ -204,8 +199,7 @@ def maybe_catch_up_soft_deactivated_user(user_profile):
return user_profile return user_profile
return None return None
def get_users_for_soft_deactivation(inactive_for_days, filter_kwargs): def get_users_for_soft_deactivation(inactive_for_days: int, filter_kwargs: Any) -> List[UserProfile]:
# type: (int, **Any) -> List[UserProfile]
users_activity = list(UserActivity.objects.filter( users_activity = list(UserActivity.objects.filter(
user_profile__is_active=True, user_profile__is_active=True,
user_profile__is_bot=False, user_profile__is_bot=False,
@ -221,8 +215,7 @@ def get_users_for_soft_deactivation(inactive_for_days, filter_kwargs):
id__in=user_ids_to_deactivate)) id__in=user_ids_to_deactivate))
return users_to_deactivate return users_to_deactivate
def do_soft_activate_users(users): def do_soft_activate_users(users: List[UserProfile]) -> List[UserProfile]:
# type: (List[UserProfile]) -> List[UserProfile]
users_soft_activated = [] users_soft_activated = []
for user_profile in users: for user_profile in users:
user_activated = maybe_catch_up_soft_deactivated_user(user_profile) user_activated = maybe_catch_up_soft_deactivated_user(user_profile)

View File

@ -10,8 +10,7 @@ from zerver.models import UserProfile, Stream, Subscription, \
Realm, Recipient, bulk_get_recipients, get_stream_recipient, get_stream, \ Realm, Recipient, bulk_get_recipients, get_stream_recipient, get_stream, \
bulk_get_streams, get_realm_stream, DefaultStreamGroup bulk_get_streams, get_realm_stream, DefaultStreamGroup
def access_stream_for_delete(user_profile, stream_id): def access_stream_for_delete(user_profile: UserProfile, stream_id: int) -> Stream:
# type: (UserProfile, int) -> Stream
# We should only ever use this for realm admins, who are allowed # We should only ever use this for realm admins, who are allowed
# to delete all streams on their realm, even private streams to # to delete all streams on their realm, even private streams to
@ -30,8 +29,8 @@ def access_stream_for_delete(user_profile, stream_id):
return stream return stream
def access_stream_common(user_profile, stream, error): def access_stream_common(user_profile: UserProfile, stream: Stream,
# type: (UserProfile, Stream, Text) -> Tuple[Recipient, Subscription] error: Text) -> Tuple[Recipient, Subscription]:
"""Common function for backend code where the target use attempts to """Common function for backend code where the target use attempts to
access the target stream, returning all the data fetched along the access the target stream, returning all the data fetched along the
way. If that user does not have permission to access that stream, way. If that user does not have permission to access that stream,
@ -63,8 +62,7 @@ def access_stream_common(user_profile, stream, error):
# an error. # an error.
raise JsonableError(error) raise JsonableError(error)
def access_stream_by_id(user_profile, stream_id): def access_stream_by_id(user_profile: UserProfile, stream_id: int) -> Tuple[Stream, Recipient, Subscription]:
# type: (UserProfile, int) -> Tuple[Stream, Recipient, Subscription]
error = _("Invalid stream id") error = _("Invalid stream id")
try: try:
stream = Stream.objects.get(id=stream_id) stream = Stream.objects.get(id=stream_id)
@ -74,8 +72,7 @@ def access_stream_by_id(user_profile, stream_id):
(recipient, sub) = access_stream_common(user_profile, stream, error) (recipient, sub) = access_stream_common(user_profile, stream, error)
return (stream, recipient, sub) return (stream, recipient, sub)
def check_stream_name_available(realm, name): def check_stream_name_available(realm: Realm, name: Text) -> None:
# type: (Realm, Text) -> None
check_stream_name(name) check_stream_name(name)
try: try:
get_stream(name, realm) get_stream(name, realm)
@ -83,8 +80,8 @@ def check_stream_name_available(realm, name):
except Stream.DoesNotExist: except Stream.DoesNotExist:
pass pass
def access_stream_by_name(user_profile, stream_name): def access_stream_by_name(user_profile: UserProfile,
# type: (UserProfile, Text) -> Tuple[Stream, Recipient, Subscription] stream_name: Text) -> Tuple[Stream, Recipient, Subscription]:
error = _("Invalid stream name '%s'" % (stream_name,)) error = _("Invalid stream name '%s'" % (stream_name,))
try: try:
stream = get_realm_stream(stream_name, user_profile.realm_id) stream = get_realm_stream(stream_name, user_profile.realm_id)
@ -94,8 +91,7 @@ def access_stream_by_name(user_profile, stream_name):
(recipient, sub) = access_stream_common(user_profile, stream, error) (recipient, sub) = access_stream_common(user_profile, stream, error)
return (stream, recipient, sub) return (stream, recipient, sub)
def access_stream_for_unmute_topic(user_profile, stream_name, error): def access_stream_for_unmute_topic(user_profile: UserProfile, stream_name: Text, error: Text) -> Stream:
# type: (UserProfile, Text, Text) -> Stream
""" """
It may seem a little silly to have this helper function for unmuting It may seem a little silly to have this helper function for unmuting
topics, but it gets around a linter warning, and it helps to be able topics, but it gets around a linter warning, and it helps to be able
@ -115,8 +111,7 @@ def access_stream_for_unmute_topic(user_profile, stream_name, error):
raise JsonableError(error) raise JsonableError(error)
return stream return stream
def is_public_stream_by_name(stream_name, realm): def is_public_stream_by_name(stream_name: Text, realm: Realm) -> bool:
# type: (Text, Realm) -> bool
"""Determine whether a stream is public, so that """Determine whether a stream is public, so that
our caller can decide whether we can get our caller can decide whether we can get
historical messages for a narrowing search. historical messages for a narrowing search.
@ -136,8 +131,8 @@ def is_public_stream_by_name(stream_name, realm):
return False return False
return stream.is_public() return stream.is_public()
def filter_stream_authorization(user_profile, streams): def filter_stream_authorization(user_profile: UserProfile,
# type: (UserProfile, Iterable[Stream]) -> Tuple[List[Stream], List[Stream]] streams: Iterable[Stream]) -> Tuple[List[Stream], List[Stream]]:
streams_subscribed = set() # type: Set[int] streams_subscribed = set() # type: Set[int]
recipients_map = bulk_get_recipients(Recipient.STREAM, [stream.id for stream in streams]) recipients_map = bulk_get_recipients(Recipient.STREAM, [stream.id for stream in streams])
subs = Subscription.objects.filter(user_profile=user_profile, subs = Subscription.objects.filter(user_profile=user_profile,
@ -161,8 +156,9 @@ def filter_stream_authorization(user_profile, streams):
stream.id not in set(stream.id for stream in unauthorized_streams)] stream.id not in set(stream.id for stream in unauthorized_streams)]
return authorized_streams, unauthorized_streams return authorized_streams, unauthorized_streams
def list_to_streams(streams_raw, user_profile, autocreate=False): def list_to_streams(streams_raw: Iterable[Mapping[str, Any]],
# type: (Iterable[Mapping[str, Any]], UserProfile, bool) -> Tuple[List[Stream], List[Stream]] user_profile: UserProfile,
autocreate: bool=False) -> Tuple[List[Stream], List[Stream]]:
"""Converts list of dicts to a list of Streams, validating input in the process """Converts list of dicts to a list of Streams, validating input in the process
For each stream name, we validate it to ensure it meets our For each stream name, we validate it to ensure it meets our

View File

@ -15,8 +15,7 @@ from sqlalchemy.sql import (
Selectable Selectable
) )
def get_topic_mutes(user_profile): def get_topic_mutes(user_profile: UserProfile) -> List[List[Text]]:
# type: (UserProfile) -> List[List[Text]]
rows = MutedTopic.objects.filter( rows = MutedTopic.objects.filter(
user_profile=user_profile, user_profile=user_profile,
).values( ).values(
@ -28,8 +27,7 @@ def get_topic_mutes(user_profile):
for row in rows for row in rows
] ]
def set_topic_mutes(user_profile, muted_topics): def set_topic_mutes(user_profile: UserProfile, muted_topics: List[List[Text]]) -> None:
# type: (UserProfile, List[List[Text]]) -> None
''' '''
This is only used in tests. This is only used in tests.
@ -50,8 +48,7 @@ def set_topic_mutes(user_profile, muted_topics):
topic_name=topic_name, topic_name=topic_name,
) )
def add_topic_mute(user_profile, stream_id, recipient_id, topic_name): def add_topic_mute(user_profile: UserProfile, stream_id: int, recipient_id: int, topic_name: str) -> None:
# type: (UserProfile, int, int, str) -> None
MutedTopic.objects.create( MutedTopic.objects.create(
user_profile=user_profile, user_profile=user_profile,
stream_id=stream_id, stream_id=stream_id,
@ -59,8 +56,7 @@ def add_topic_mute(user_profile, stream_id, recipient_id, topic_name):
topic_name=topic_name, topic_name=topic_name,
) )
def remove_topic_mute(user_profile, stream_id, topic_name): def remove_topic_mute(user_profile: UserProfile, stream_id: int, topic_name: str) -> None:
# type: (UserProfile, int, str) -> None
row = MutedTopic.objects.get( row = MutedTopic.objects.get(
user_profile=user_profile, user_profile=user_profile,
stream_id=stream_id, stream_id=stream_id,
@ -68,8 +64,7 @@ def remove_topic_mute(user_profile, stream_id, topic_name):
) )
row.delete() row.delete()
def topic_is_muted(user_profile, stream_id, topic_name): def topic_is_muted(user_profile: UserProfile, stream_id: int, topic_name: Text) -> bool:
# type: (UserProfile, int, Text) -> bool
is_muted = MutedTopic.objects.filter( is_muted = MutedTopic.objects.filter(
user_profile=user_profile, user_profile=user_profile,
stream_id=stream_id, stream_id=stream_id,
@ -77,8 +72,9 @@ def topic_is_muted(user_profile, stream_id, topic_name):
).exists() ).exists()
return is_muted return is_muted
def exclude_topic_mutes(conditions, user_profile, stream_id): def exclude_topic_mutes(conditions: List[Selectable],
# type: (List[Selectable], UserProfile, Optional[int]) -> List[Selectable] user_profile: UserProfile,
stream_id: Optional[int]) -> List[Selectable]:
query = MutedTopic.objects.filter( query = MutedTopic.objects.filter(
user_profile=user_profile, user_profile=user_profile,
) )
@ -97,8 +93,7 @@ def exclude_topic_mutes(conditions, user_profile, stream_id):
if not rows: if not rows:
return conditions return conditions
def mute_cond(row): def mute_cond(row: Dict[str, Any]) -> Selectable:
# type: (Dict[str, Any]) -> Selectable
recipient_id = row['recipient_id'] recipient_id = row['recipient_id']
topic_name = row['topic_name'] topic_name = row['topic_name']
stream_cond = column("recipient_id") == recipient_id stream_cond = column("recipient_id") == recipient_id
@ -108,8 +103,7 @@ def exclude_topic_mutes(conditions, user_profile, stream_id):
condition = not_(or_(*list(map(mute_cond, rows)))) condition = not_(or_(*list(map(mute_cond, rows))))
return conditions + [condition] return conditions + [condition]
def build_topic_mute_checker(user_profile): def build_topic_mute_checker(user_profile: UserProfile) -> Callable[[int, Text], bool]:
# type: (UserProfile) -> Callable[[int, Text], bool]
rows = MutedTopic.objects.filter( rows = MutedTopic.objects.filter(
user_profile=user_profile, user_profile=user_profile,
).values( ).values(
@ -124,8 +118,7 @@ def build_topic_mute_checker(user_profile):
topic_name = row['topic_name'] topic_name = row['topic_name']
tups.add((recipient_id, topic_name.lower())) tups.add((recipient_id, topic_name.lower()))
def is_muted(recipient_id, topic): def is_muted(recipient_id: int, topic: Text) -> bool:
# type: (int, Text) -> bool
return (recipient_id, topic.lower()) in tups return (recipient_id, topic.lower()) in tups
return is_muted return is_muted

View File

@ -2,8 +2,9 @@ from typing import Optional, Text, Dict, Any
from pyoembed import oEmbed, PyOembedException from pyoembed import oEmbed, PyOembedException
def get_oembed_data(url, maxwidth=640, maxheight=480): def get_oembed_data(url: Text,
# type: (Text, Optional[int], Optional[int]) -> Optional[Dict[Any, Any]] maxwidth: Optional[int]=640,
maxheight: Optional[int]=480) -> Optional[Dict[Any, Any]]:
try: try:
data = oEmbed(url, maxwidth=maxwidth, maxheight=maxheight) data = oEmbed(url, maxwidth=maxwidth, maxheight=maxheight)
except PyOembedException: except PyOembedException:

View File

@ -3,10 +3,8 @@ from bs4 import BeautifulSoup
class BaseParser: class BaseParser:
def __init__(self, html_source): def __init__(self, html_source: Text) -> None:
# type: (Text) -> None
self._soup = BeautifulSoup(html_source, "lxml") self._soup = BeautifulSoup(html_source, "lxml")
def extract_data(self): def extract_data(self) -> Any:
# type: () -> Any
raise NotImplementedError() raise NotImplementedError()

View File

@ -3,15 +3,13 @@ from zerver.lib.url_preview.parsers.base import BaseParser
class GenericParser(BaseParser): class GenericParser(BaseParser):
def extract_data(self): def extract_data(self) -> Dict[str, Optional[Text]]:
# type: () -> Dict[str, Optional[Text]]
return { return {
'title': self._get_title(), 'title': self._get_title(),
'description': self._get_description(), 'description': self._get_description(),
'image': self._get_image()} 'image': self._get_image()}
def _get_title(self): def _get_title(self) -> Optional[Text]:
# type: () -> Optional[Text]
soup = self._soup soup = self._soup
if (soup.title and soup.title.text != ''): if (soup.title and soup.title.text != ''):
return soup.title.text return soup.title.text
@ -19,8 +17,7 @@ class GenericParser(BaseParser):
return soup.h1.text return soup.h1.text
return None return None
def _get_description(self): def _get_description(self) -> Optional[Text]:
# type: () -> Optional[Text]
soup = self._soup soup = self._soup
meta_description = soup.find('meta', attrs={'name': 'description'}) meta_description = soup.find('meta', attrs={'name': 'description'})
if (meta_description and meta_description['content'] != ''): if (meta_description and meta_description['content'] != ''):
@ -35,8 +32,7 @@ class GenericParser(BaseParser):
return first_p.string return first_p.string
return None return None
def _get_image(self): def _get_image(self) -> Optional[Text]:
# type: () -> Optional[Text]
""" """
Finding a first image after the h1 header. Finding a first image after the h1 header.
Presumably it will be the main image. Presumably it will be the main image.

View File

@ -4,8 +4,7 @@ from .base import BaseParser
class OpenGraphParser(BaseParser): class OpenGraphParser(BaseParser):
def extract_data(self): def extract_data(self) -> Dict[str, Text]:
# type: () -> Dict[str, Text]
meta = self._soup.findAll('meta') meta = self._soup.findAll('meta')
content = {} content = {}
for tag in meta: for tag in meta:

View File

@ -20,19 +20,18 @@ link_regex = re.compile(
r'(?:/?|[/?]\S+)$', re.IGNORECASE) r'(?:/?|[/?]\S+)$', re.IGNORECASE)
def is_link(url): def is_link(url: Text) -> Match[Text]:
# type: (Text) -> Match[Text]
return link_regex.match(smart_text(url)) return link_regex.match(smart_text(url))
def cache_key_func(url): def cache_key_func(url: Text) -> Text:
# type: (Text) -> Text
return url return url
@cache_with_key(cache_key_func, cache_name=CACHE_NAME, with_statsd_key="urlpreview_data") @cache_with_key(cache_key_func, cache_name=CACHE_NAME, with_statsd_key="urlpreview_data")
def get_link_embed_data(url, maxwidth=640, maxheight=480): def get_link_embed_data(url: Text,
# type: (Text, Optional[int], Optional[int]) -> Optional[Dict[Any, Any]] maxwidth: Optional[int]=640,
maxheight: Optional[int]=480) -> Optional[Dict[Any, Any]]:
if not is_link(url): if not is_link(url):
return None return None
# Fetch information from URL. # Fetch information from URL.
@ -60,6 +59,5 @@ def get_link_embed_data(url, maxwidth=640, maxheight=480):
@get_cache_with_key(cache_key_func, cache_name=CACHE_NAME) @get_cache_with_key(cache_key_func, cache_name=CACHE_NAME)
def link_embed_data_from_cache(url, maxwidth=640, maxheight=480): def link_embed_data_from_cache(url: Text, maxwidth: Optional[int]=640, maxheight: Optional[int]=480) -> Any:
# type: (Text, Optional[int], Optional[int]) -> Any
return return

View File

@ -8,8 +8,7 @@ from zerver.lib.request import JsonableError
from zerver.models import UserProfile, Service, Realm, \ from zerver.models import UserProfile, Service, Realm, \
get_user_profile_by_id, user_profile_by_email_cache_key get_user_profile_by_id, user_profile_by_email_cache_key
def check_full_name(full_name_raw): def check_full_name(full_name_raw: Text) -> Text:
# type: (Text) -> Text
full_name = full_name_raw.strip() full_name = full_name_raw.strip()
if len(full_name) > UserProfile.MAX_NAME_LENGTH: if len(full_name) > UserProfile.MAX_NAME_LENGTH:
raise JsonableError(_("Name too long!")) raise JsonableError(_("Name too long!"))
@ -19,20 +18,17 @@ def check_full_name(full_name_raw):
raise JsonableError(_("Invalid characters in name!")) raise JsonableError(_("Invalid characters in name!"))
return full_name return full_name
def check_short_name(short_name_raw): def check_short_name(short_name_raw: Text) -> Text:
# type: (Text) -> Text
short_name = short_name_raw.strip() short_name = short_name_raw.strip()
if len(short_name) == 0: if len(short_name) == 0:
raise JsonableError(_("Bad name or username")) raise JsonableError(_("Bad name or username"))
return short_name return short_name
def check_valid_bot_type(bot_type): def check_valid_bot_type(bot_type: int) -> None:
# type: (int) -> None
if bot_type not in UserProfile.ALLOWED_BOT_TYPES: if bot_type not in UserProfile.ALLOWED_BOT_TYPES:
raise JsonableError(_('Invalid bot type')) raise JsonableError(_('Invalid bot type'))
def check_valid_interface_type(interface_type): def check_valid_interface_type(interface_type: int) -> None:
# type: (int) -> None
if interface_type not in Service.ALLOWED_INTERFACE_TYPES: if interface_type not in Service.ALLOWED_INTERFACE_TYPES:
raise JsonableError(_('Invalid interface type')) raise JsonableError(_('Invalid interface type'))

View File

@ -15,8 +15,7 @@ from django.conf import settings
T = TypeVar('T') T = TypeVar('T')
def statsd_key(val, clean_periods=False): def statsd_key(val: Any, clean_periods: bool=False) -> str:
# type: (Any, bool) -> str
if not isinstance(val, str): if not isinstance(val, str):
val = str(val) val = str(val)
@ -35,8 +34,7 @@ class StatsDWrapper:
# Backported support for gauge deltas # Backported support for gauge deltas
# as our statsd server supports them but supporting # as our statsd server supports them but supporting
# pystatsd is not released yet # pystatsd is not released yet
def _our_gauge(self, stat, value, rate=1, delta=False): def _our_gauge(self, stat: str, value: float, rate: float=1, delta: bool=False) -> None:
# type: (str, float, float, bool) -> None
"""Set a gauge value.""" """Set a gauge value."""
from django_statsd.clients import statsd from django_statsd.clients import statsd
if delta: if delta:
@ -45,8 +43,7 @@ class StatsDWrapper:
value_str = '%g|g' % (value,) value_str = '%g|g' % (value,)
statsd._send(stat, value_str, rate) statsd._send(stat, value_str, rate)
def __getattr__(self, name): def __getattr__(self, name: str) -> Any:
# type: (str) -> Any
# Hand off to statsd if we have it enabled # Hand off to statsd if we have it enabled
# otherwise do nothing # otherwise do nothing
if name in ['timer', 'timing', 'incr', 'decr', 'gauge']: if name in ['timer', 'timing', 'incr', 'decr', 'gauge']:
@ -64,8 +61,11 @@ class StatsDWrapper:
statsd = StatsDWrapper() statsd = StatsDWrapper()
# Runs the callback with slices of all_list of a given batch_size # Runs the callback with slices of all_list of a given batch_size
def run_in_batches(all_list, batch_size, callback, sleep_time = 0, logger = None): def run_in_batches(all_list: Sequence[T],
# type: (Sequence[T], int, Callable[[Sequence[T]], None], int, Optional[Callable[[str], None]]) -> None batch_size: int,
callback: Callable[[Sequence[T]], None],
sleep_time: int=0,
logger: Optional[Callable[[str], None]]=None) -> None:
if len(all_list) == 0: if len(all_list) == 0:
return return
@ -85,8 +85,8 @@ def run_in_batches(all_list, batch_size, callback, sleep_time = 0, logger = None
if i != limit - 1: if i != limit - 1:
sleep(sleep_time) sleep(sleep_time)
def make_safe_digest(string, hash_func=hashlib.sha1): def make_safe_digest(string: Text,
# type: (Text, Callable[[bytes], Any]) -> Text hash_func: Callable[[bytes], Any]=hashlib.sha1) -> Text:
""" """
return a hex digest of `string`. return a hex digest of `string`.
""" """
@ -95,8 +95,7 @@ def make_safe_digest(string, hash_func=hashlib.sha1):
return hash_func(string.encode('utf-8')).hexdigest() return hash_func(string.encode('utf-8')).hexdigest()
def log_statsd_event(name): def log_statsd_event(name: str) -> None:
# type: (str) -> None
""" """
Sends a single event to statsd with the desired name and the current timestamp Sends a single event to statsd with the desired name and the current timestamp
@ -110,12 +109,13 @@ def log_statsd_event(name):
event_name = "events.%s" % (name,) event_name = "events.%s" % (name,)
statsd.incr(event_name) statsd.incr(event_name)
def generate_random_token(length): def generate_random_token(length: int) -> str:
# type: (int) -> str
return str(base64.b16encode(os.urandom(length // 2)).decode('utf-8').lower()) return str(base64.b16encode(os.urandom(length // 2)).decode('utf-8').lower())
def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=None): def query_chunker(queries: List[Any],
# type: (List[Any], Set[int], int, int) -> Iterable[Any] id_collector: Set[int]=None,
chunk_size: int=1000,
db_chunk_size: int=None) -> Iterable[Any]:
''' '''
This merges one or more Django ascending-id queries into This merges one or more Django ascending-id queries into
a generator that returns chunks of chunk_size row objects a generator that returns chunks of chunk_size row objects
@ -142,8 +142,7 @@ def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=Non
else: else:
id_collector = set() id_collector = set()
def chunkify(q, i): def chunkify(q: Any, i: int) -> Iterable[Tuple[int, int, Any]]:
# type: (Any, int) -> Iterable[Tuple[int, int, Any]]
q = q.order_by('id') q = q.order_by('id')
min_id = -1 min_id = -1
while True: while True:
@ -171,8 +170,7 @@ def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=Non
yield [row for row_id, i, row in tup_chunk] yield [row for row_id, i, row in tup_chunk]
def split_by(array, group_size, filler): def split_by(array: List[Any], group_size: int, filler: Any) -> List[List[Any]]:
# type: (List[Any], int, Any) -> List[List[Any]]
""" """
Group elements into list of size `group_size` and fill empty cells with Group elements into list of size `group_size` and fill empty cells with
`filler`. Recipe from https://docs.python.org/3/library/itertools.html `filler`. Recipe from https://docs.python.org/3/library/itertools.html
@ -180,8 +178,7 @@ def split_by(array, group_size, filler):
args = [iter(array)] * group_size args = [iter(array)] * group_size
return list(map(list, zip_longest(*args, fillvalue=filler))) return list(map(list, zip_longest(*args, fillvalue=filler)))
def is_remote_server(identifier): def is_remote_server(identifier: Text) -> bool:
# type: (Text) -> bool
""" """
This function can be used to identify the source of API auth This function can be used to identify the source of API auth
request. We can have two types of sources, Remote Zulip Servers request. We can have two types of sources, Remote Zulip Servers

View File

@ -98,8 +98,7 @@ def get_push_commits_event_message(user_name, compare_url, branch_name,
commits_data=get_commits_content(commits_data, is_truncated), commits_data=get_commits_content(commits_data, is_truncated),
).rstrip() ).rstrip()
def get_force_push_commits_event_message(user_name, url, branch_name, head): def get_force_push_commits_event_message(user_name: Text, url: Text, branch_name: Text, head: Text) -> Text:
# type: (Text, Text, Text, Text) -> Text
return FORCE_PUSH_COMMITS_MESSAGE_TEMPLATE.format( return FORCE_PUSH_COMMITS_MESSAGE_TEMPLATE.format(
user_name=user_name, user_name=user_name,
url=url, url=url,
@ -107,16 +106,14 @@ def get_force_push_commits_event_message(user_name, url, branch_name, head):
head=head head=head
) )
def get_create_branch_event_message(user_name, url, branch_name): def get_create_branch_event_message(user_name: Text, url: Text, branch_name: Text) -> Text:
# type: (Text, Text, Text) -> Text
return CREATE_BRANCH_MESSAGE_TEMPLATE.format( return CREATE_BRANCH_MESSAGE_TEMPLATE.format(
user_name=user_name, user_name=user_name,
url=url, url=url,
branch_name=branch_name, branch_name=branch_name,
) )
def get_remove_branch_event_message(user_name, branch_name): def get_remove_branch_event_message(user_name: Text, branch_name: Text) -> Text:
# type: (Text, Text) -> Text
return REMOVE_BRANCH_MESSAGE_TEMPLATE.format( return REMOVE_BRANCH_MESSAGE_TEMPLATE.format(
user_name=user_name, user_name=user_name,
branch_name=branch_name, branch_name=branch_name,
@ -147,15 +144,18 @@ def get_pull_request_event_message(
main_message += '\n' + CONTENT_MESSAGE_TEMPLATE.format(message=message) main_message += '\n' + CONTENT_MESSAGE_TEMPLATE.format(message=message)
return main_message.rstrip() return main_message.rstrip()
def get_setup_webhook_message(integration, user_name=None): def get_setup_webhook_message(integration: Text, user_name: Optional[Text]=None) -> Text:
# type: (Text, Optional[Text]) -> Text
content = SETUP_MESSAGE_TEMPLATE.format(integration=integration) content = SETUP_MESSAGE_TEMPLATE.format(integration=integration)
if user_name: if user_name:
content += SETUP_MESSAGE_USER_PART.format(user_name=user_name) content += SETUP_MESSAGE_USER_PART.format(user_name=user_name)
return content return content
def get_issue_event_message(user_name, action, url, number=None, message=None, assignee=None): def get_issue_event_message(user_name: Text,
# type: (Text, Text, Text, Optional[int], Optional[Text], Optional[Text]) -> Text action: Text,
url: Text,
number: Optional[int]=None,
message: Optional[Text]=None,
assignee: Optional[Text]=None) -> Text:
return get_pull_request_event_message( return get_pull_request_event_message(
user_name, user_name,
action, action,
@ -166,8 +166,10 @@ def get_issue_event_message(user_name, action, url, number=None, message=None, a
type='Issue' type='Issue'
) )
def get_push_tag_event_message(user_name, tag_name, tag_url=None, action='pushed'): def get_push_tag_event_message(user_name: Text,
# type: (Text, Text, Optional[Text], Optional[Text]) -> Text tag_name: Text,
tag_url: Optional[Text]=None,
action: Optional[Text]='pushed') -> Text:
if tag_url: if tag_url:
tag_part = TAG_WITH_URL_TEMPLATE.format(tag_name=tag_name, tag_url=tag_url) tag_part = TAG_WITH_URL_TEMPLATE.format(tag_name=tag_name, tag_url=tag_url)
else: else:
@ -178,8 +180,11 @@ def get_push_tag_event_message(user_name, tag_name, tag_url=None, action='pushed
tag=tag_part tag=tag_part
) )
def get_commits_comment_action_message(user_name, action, commit_url, sha, message=None): def get_commits_comment_action_message(user_name: Text,
# type: (Text, Text, Text, Text, Optional[Text]) -> Text action: Text,
commit_url: Text,
sha: Text,
message: Optional[Text]=None) -> Text:
content = COMMITS_COMMENT_MESSAGE_TEMPLATE.format( content = COMMITS_COMMENT_MESSAGE_TEMPLATE.format(
user_name=user_name, user_name=user_name,
action=action, action=action,
@ -192,8 +197,7 @@ def get_commits_comment_action_message(user_name, action, commit_url, sha, messa
) )
return content return content
def get_commits_content(commits_data, is_truncated=False): def get_commits_content(commits_data: List[Dict[str, Any]], is_truncated: Optional[bool]=False) -> Text:
# type: (List[Dict[str, Any]], Optional[bool]) -> Text
commits_content = '' commits_content = ''
for commit in commits_data[:COMMITS_LIMIT]: for commit in commits_data[:COMMITS_LIMIT]:
commits_content += COMMIT_ROW_TEMPLATE.format( commits_content += COMMIT_ROW_TEMPLATE.format(
@ -212,12 +216,10 @@ def get_commits_content(commits_data, is_truncated=False):
).replace(' ', ' ') ).replace(' ', ' ')
return commits_content.rstrip() return commits_content.rstrip()
def get_short_sha(sha): def get_short_sha(sha: Text) -> Text:
# type: (Text) -> Text
return sha[:7] return sha[:7]
def get_all_committers(commits_data): def get_all_committers(commits_data: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
# type: (List[Dict[str, Any]]) -> List[Tuple[str, int]]
committers = defaultdict(int) # type: Dict[str, int] committers = defaultdict(int) # type: Dict[str, int]
for commit in commits_data: for commit in commits_data: