zerver/lib: Use Python 3 syntax for typing.

Extracted from a larger commit by tabbott because these changes will
not create significant merge conflicts.
rht 2017-11-05 11:15:10 +01:00 committed by Tim Abbott
parent 561ba33f69
commit 3f4bf2d22f
35 changed files with 388 additions and 573 deletions
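The conversion is mechanical throughout: each two-line signature (a def line followed by a PEP 484 `# type:` comment) collapses into a single annotated def. A minimal before/after sketch, using a hypothetical function:

from typing import List, Optional

# Before: a PEP 484 type comment, needed while Python 2 was still supported.
def get_items(names, limit=None):
    # type: (List[str], Optional[int]) -> List[str]
    return names[:limit]

# After: Python 3 annotation syntax; mypy checks both forms identically.
def get_items(names: List[str], limit: Optional[int]=None) -> List[str]:
    return names[:limit]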

View File

@ -11,8 +11,7 @@ from zerver.models import (
get_user_including_cross_realm,
)
def user_profiles_from_unvalidated_emails(emails, realm):
# type: (Iterable[Text], Realm) -> List[UserProfile]
def user_profiles_from_unvalidated_emails(emails: Iterable[Text], realm: Realm) -> List[UserProfile]:
user_profiles = [] # type: List[UserProfile]
for email in emails:
try:
@ -22,8 +21,7 @@ def user_profiles_from_unvalidated_emails(emails, realm):
user_profiles.append(user_profile)
return user_profiles
def get_user_profiles(emails, realm):
# type: (Iterable[Text], Realm) -> List[UserProfile]
def get_user_profiles(emails: Iterable[Text], realm: Realm) -> List[UserProfile]:
try:
return user_profiles_from_unvalidated_emails(emails, realm)
except ValidationError as e:
@ -42,44 +40,43 @@ class Addressee:
# in memory.
#
# This should be treated as an immutable class.
def __init__(self, msg_type, user_profiles=None, stream_name=None, topic=None):
# type: (str, Optional[Sequence[UserProfile]], Optional[Text], Text) -> None
def __init__(self, msg_type: str,
user_profiles: Optional[Sequence[UserProfile]]=None,
stream_name: Optional[Text]=None,
topic: Text=None) -> None:
assert(msg_type in ['stream', 'private'])
self._msg_type = msg_type
self._user_profiles = user_profiles
self._stream_name = stream_name
self._topic = topic
def msg_type(self):
# type: () -> str
def msg_type(self) -> str:
return self._msg_type
def is_stream(self):
# type: () -> bool
def is_stream(self) -> bool:
return self._msg_type == 'stream'
def is_private(self):
# type: () -> bool
def is_private(self) -> bool:
return self._msg_type == 'private'
def user_profiles(self):
# type: () -> List[UserProfile]
def user_profiles(self) -> List[UserProfile]:
assert(self.is_private())
return self._user_profiles # type: ignore # assertion protects us
def stream_name(self):
# type: () -> Text
def stream_name(self) -> Text:
assert(self.is_stream())
return self._stream_name
def topic(self):
# type: () -> Text
def topic(self) -> Text:
assert(self.is_stream())
return self._topic
@staticmethod
def legacy_build(sender, message_type_name, message_to, topic_name, realm=None):
# type: (UserProfile, Text, Sequence[Text], Text, Optional[Realm]) -> Addressee
def legacy_build(sender: UserProfile,
message_type_name: Text,
message_to: Sequence[Text],
topic_name: Text,
realm: Optional[Realm]=None) -> 'Addressee':
# For legacy reasons message_to used to be either a list of
# emails or a list of streams. We haven't fixed all of our
@ -111,8 +108,7 @@ class Addressee:
raise JsonableError(_("Invalid message type"))
@staticmethod
def for_stream(stream_name, topic):
# type: (Text, Text) -> Addressee
def for_stream(stream_name: Text, topic: Text) -> 'Addressee':
return Addressee(
msg_type='stream',
stream_name=stream_name,
@ -120,8 +116,7 @@ class Addressee:
)
@staticmethod
def for_private(emails, realm):
# type: (Sequence[Text], Realm) -> Addressee
def for_private(emails: Sequence[Text], realm: Realm) -> 'Addressee':
user_profiles = get_user_profiles(emails, realm)
return Addressee(
msg_type='private',
@ -129,8 +124,7 @@ class Addressee:
)
@staticmethod
def for_user_profile(user_profile):
# type: (UserProfile) -> Addressee
def for_user_profile(user_profile: UserProfile) -> 'Addressee':
user_profiles = [user_profile]
return Addressee(
msg_type='private',

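Note the quoted return annotations above (`-> 'Addressee'`): a method that returns its own enclosing class must use the string form, because annotations are evaluated when the def executes, before the class name is bound. A small sketch with a hypothetical class:

class Node:
    @staticmethod
    def root() -> 'Node':
        # A bare `Node` here would raise NameError: the class name is not
        # bound until the class body finishes executing.
        return Node()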
View File

@ -6,13 +6,12 @@ from zerver.lib.request import JsonableError
from zerver.lib.upload import delete_message_image
from zerver.models import Attachment, UserProfile
def user_attachments(user_profile):
# type: (UserProfile) -> List[Dict[str, Any]]
def user_attachments(user_profile: UserProfile) -> List[Dict[str, Any]]:
attachments = Attachment.objects.filter(owner=user_profile).prefetch_related('messages')
return [a.to_dict() for a in attachments]
def access_attachment_by_id(user_profile, attachment_id, needs_owner=False):
# type: (UserProfile, int, bool) -> Attachment
def access_attachment_by_id(user_profile: UserProfile, attachment_id: int,
needs_owner: bool=False) -> Attachment:
query = Attachment.objects.filter(id=attachment_id)
if needs_owner:
query = query.filter(owner=user_profile)
@ -22,8 +21,7 @@ def access_attachment_by_id(user_profile, attachment_id, needs_owner=False):
raise JsonableError(_("Invalid attachment"))
return attachment
def remove_attachment(user_profile, attachment):
# type: (UserProfile, Attachment) -> None
def remove_attachment(user_profile: UserProfile, attachment: Attachment) -> None:
try:
delete_message_image(attachment.path_id)
except Exception:

View File

@ -24,8 +24,7 @@ our_dir = os.path.dirname(os.path.abspath(__file__))
from zulip_bots.lib import RateLimit
def get_bot_handler(service_name):
# type: (str) -> Any
def get_bot_handler(service_name: str) -> Any:
# Check that this service is present in EMBEDDED_BOTS, add exception handling.
is_present_in_registry = any(service_name == embedded_bot_service.name for
@ -40,31 +39,25 @@ def get_bot_handler(service_name):
class StateHandler:
state_size_limit = 10000000 # type: int # TODO: Store this in the server configuration model.
def __init__(self, user_profile):
# type: (UserProfile) -> None
def __init__(self, user_profile: UserProfile) -> None:
self.user_profile = user_profile
self.marshal = lambda obj: json.dumps(obj)
self.demarshal = lambda obj: json.loads(obj)
def get(self, key):
# type: (Text) -> Text
def get(self, key: Text) -> Text:
return self.demarshal(get_bot_state(self.user_profile, key))
def put(self, key, value):
# type: (Text, Text) -> None
def put(self, key: Text, value: Text) -> None:
set_bot_state(self.user_profile, key, self.marshal(value))
def remove(self, key):
# type: (Text) -> None
def remove(self, key: Text) -> None:
remove_bot_state(self.user_profile, key)
def contains(self, key):
# type: (Text) -> bool
def contains(self, key: Text) -> bool:
return is_key_in_bot_state(self.user_profile, key)
class EmbeddedBotHandler:
def __init__(self, user_profile):
# type: (UserProfile) -> None
def __init__(self, user_profile: UserProfile) -> None:
# Only expose a subset of our UserProfile's functionality
self.user_profile = user_profile
self._rate_limit = RateLimit(20, 5)
@ -72,8 +65,7 @@ class EmbeddedBotHandler:
self.email = user_profile.email
self.storage = StateHandler(user_profile)
def send_message(self, message):
# type: (Dict[str, Any]) -> None
def send_message(self, message: Dict[str, Any]) -> None:
if self._rate_limit.is_legal():
recipients = message['to'] if message['type'] == 'stream' else ','.join(message['to'])
internal_send_message(realm=self.user_profile.realm, sender_email=self.user_profile.email,
@ -82,8 +74,7 @@ class EmbeddedBotHandler:
else:
self._rate_limit.show_error_and_exit()
def send_reply(self, message, response):
# type: (Dict[str, Any], str) -> None
def send_reply(self, message: Dict[str, Any], response: str) -> None:
if message['type'] == 'private':
self.send_message(dict(
type='private',
@ -100,6 +91,5 @@ class EmbeddedBotHandler:
sender_email=message['sender_email'],
))
def get_config_info(self):
# type: () -> Dict[Text, Text]
def get_config_info(self) -> Dict[Text, Text]:
return get_bot_config(self.user_profile)

View File

@ -221,8 +221,7 @@ EMOJI_TWEET = """{
]
}"""
def twitter(tweet_id):
# type: (Text) -> Optional[Dict[Text, Any]]
def twitter(tweet_id: Text) -> Optional[Dict[Text, Any]]:
if tweet_id in ["112652479837110273", "287977969287315456", "287977969287315457"]:
return ujson.loads(NORMAL_TWEET)
elif tweet_id == "287977969287315458":

View File

@ -22,8 +22,7 @@ from django.db.models import Q
MESSAGE_CACHE_SIZE = 75000
def message_fetch_objects():
# type: () -> List[Any]
def message_fetch_objects() -> List[Any]:
try:
max_id = Message.objects.only('id').order_by("-id")[0].id
except IndexError:
@ -31,8 +30,8 @@ def message_fetch_objects():
return Message.objects.select_related().filter(~Q(sender__email='tabbott/extra@mit.edu'),
id__gt=max_id - MESSAGE_CACHE_SIZE)
def message_cache_items(items_for_remote_cache, message):
# type: (Dict[Text, Tuple[bytes]], Message) -> None
def message_cache_items(items_for_remote_cache: Dict[Text, Tuple[bytes]],
message: Message) -> None:
'''
Note: this code is untested, and the caller has been
commented out for a while.
@ -41,32 +40,32 @@ def message_cache_items(items_for_remote_cache, message):
value = MessageDict.to_dict_uncached(message)
items_for_remote_cache[key] = (value,)
def user_cache_items(items_for_remote_cache, user_profile):
# type: (Dict[Text, Tuple[UserProfile]], UserProfile) -> None
def user_cache_items(items_for_remote_cache: Dict[Text, Tuple[UserProfile]],
user_profile: UserProfile) -> None:
items_for_remote_cache[user_profile_by_email_cache_key(user_profile.email)] = (user_profile,)
items_for_remote_cache[user_profile_by_id_cache_key(user_profile.id)] = (user_profile,)
items_for_remote_cache[user_profile_by_api_key_cache_key(user_profile.api_key)] = (user_profile,)
items_for_remote_cache[user_profile_cache_key(user_profile.email, user_profile.realm)] = (user_profile,)
def stream_cache_items(items_for_remote_cache, stream):
# type: (Dict[Text, Tuple[Stream]], Stream) -> None
def stream_cache_items(items_for_remote_cache: Dict[Text, Tuple[Stream]],
stream: Stream) -> None:
items_for_remote_cache[get_stream_cache_key(stream.name, stream.realm_id)] = (stream,)
def client_cache_items(items_for_remote_cache, client):
# type: (Dict[Text, Tuple[Client]], Client) -> None
def client_cache_items(items_for_remote_cache: Dict[Text, Tuple[Client]],
client: Client) -> None:
items_for_remote_cache[get_client_cache_key(client.name)] = (client,)
def huddle_cache_items(items_for_remote_cache, huddle):
# type: (Dict[Text, Tuple[Huddle]], Huddle) -> None
def huddle_cache_items(items_for_remote_cache: Dict[Text, Tuple[Huddle]],
huddle: Huddle) -> None:
items_for_remote_cache[huddle_hash_cache_key(huddle.huddle_hash)] = (huddle,)
def recipient_cache_items(items_for_remote_cache, recipient):
# type: (Dict[Text, Tuple[Recipient]], Recipient) -> None
def recipient_cache_items(items_for_remote_cache: Dict[Text, Tuple[Recipient]],
recipient: Recipient) -> None:
items_for_remote_cache[get_recipient_cache_key(recipient.type, recipient.type_id)] = (recipient,)
session_engine = import_module(settings.SESSION_ENGINE)
def session_cache_items(items_for_remote_cache, session):
# type: (Dict[Text, Text], Session) -> None
def session_cache_items(items_for_remote_cache: Dict[Text, Text],
session: Session) -> None:
store = session_engine.SessionStore(session_key=session.session_key) # type: ignore # import_module
items_for_remote_cache[store.cache_key] = store.decode(session.session_data)
@ -89,8 +88,7 @@ cache_fillers = {
'session': (lambda: Session.objects.all(), session_cache_items, 3600*24*7, 10000),
} # type: Dict[str, Tuple[Callable[[], List[Any]], Callable[[Dict[Text, Any], Any], None], int, int]]
def fill_remote_cache(cache):
# type: (str) -> None
def fill_remote_cache(cache: str) -> None:
remote_cache_time_start = get_remote_cache_time()
remote_cache_requests_start = get_remote_cache_requests()
items_for_remote_cache = {} # type: Dict[Text, Any]

View File

@ -9,8 +9,7 @@ import string
from typing import Optional, Text
def random_api_key():
# type: () -> Text
def random_api_key() -> Text:
choices = string.ascii_letters + string.digits
altchars = ''.join([choices[ord(os.urandom(1)) % 62] for _ in range(2)]).encode("utf-8")
return base64.b64encode(os.urandom(24), altchars=altchars).decode("utf-8")

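Signatures across these files use `typing.Text` rather than `str`. On Python 3 it is a plain alias of `str`; it existed so that the same annotation meant `unicode` under Python 2:

from typing import Text

assert Text is str  # True on Python 3; under Python 2, Text aliased `unicode`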
View File

@ -12,8 +12,7 @@ from typing import Optional
# (that link also points to code for an interactive remote debugger
# setup, which we might want if we move Tornado to run in a daemon
# rather than via screen).
def interactive_debug(sig, frame):
# type: (int, FrameType) -> None
def interactive_debug(sig: int, frame: FrameType) -> None:
"""Interrupt running process, and provide a python prompt for
interactive debugging."""
d = {'_frame': frame} # Allow access to frame object.
@ -27,7 +26,6 @@ def interactive_debug(sig, frame):
# SIGUSR1 => Just print the stack
# SIGUSR2 => Print stack + open interactive debugging shell
def interactive_debug_listen():
# type: () -> None
def interactive_debug_listen() -> None:
signal.signal(signal.SIGUSR1, lambda sig, stack: traceback.print_stack(stack))
signal.signal(signal.SIGUSR2, interactive_debug)

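As the comments above describe, once `interactive_debug_listen()` has registered the handlers, sending SIGUSR1 prints the stack and SIGUSR2 also opens the interactive shell. A sketch of triggering the first from inside the process (in practice you would run `kill -USR1 <pid>` from a terminal):

import os
import signal

# Assumes interactive_debug_listen() was already called at process startup.
os.kill(os.getpid(), signal.SIGUSR1)  # print the current stack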
View File

@ -31,8 +31,7 @@ DIGEST_CUTOFF = 5
# 4. Interesting stream traffic, as determined by the longest and most
# diversely comment upon topics.
def inactive_since(user_profile, cutoff):
# type: (UserProfile, datetime.datetime) -> bool
def inactive_since(user_profile: UserProfile, cutoff: datetime.datetime) -> bool:
# Hasn't used the app in the last DIGEST_CUTOFF (5) days.
most_recent_visit = [row.last_visit for row in
UserActivity.objects.filter(
@ -45,8 +44,7 @@ def inactive_since(user_profile, cutoff):
last_visit = max(most_recent_visit)
return last_visit < cutoff
def should_process_digest(realm_str):
# type: (str) -> bool
def should_process_digest(realm_str: str) -> bool:
if realm_str in settings.SYSTEM_ONLY_REALMS:
# Don't try to send emails to system-only realms
return False
@ -54,15 +52,13 @@ def should_process_digest(realm_str):
# Changes to this should also be reflected in
# zerver/worker/queue_processors.py:DigestWorker.consume()
def queue_digest_recipient(user_profile, cutoff):
# type: (UserProfile, datetime.datetime) -> None
def queue_digest_recipient(user_profile: UserProfile, cutoff: datetime.datetime) -> None:
# Convert cutoff to epoch seconds for transit.
event = {"user_profile_id": user_profile.id,
"cutoff": cutoff.strftime('%s')}
queue_json_publish("digest_emails", event, lambda event: None, call_consume_in_tests=True)
def enqueue_emails(cutoff):
# type: (datetime.datetime) -> None
def enqueue_emails(cutoff: datetime.datetime) -> None:
# To be really conservative while we don't have user timezones or
# special-casing for companies with non-standard workweeks, only
# try to send mail on Tuesdays.
@ -82,8 +78,7 @@ def enqueue_emails(cutoff):
logger.info("%s is inactive, queuing for potential digest" % (
user_profile.email,))
def gather_hot_conversations(user_profile, stream_messages):
# type: (UserProfile, QuerySet) -> List[Dict[str, Any]]
def gather_hot_conversations(user_profile: UserProfile, stream_messages: QuerySet) -> List[Dict[str, Any]]:
# Gather stream conversations of 2 types:
# 1. long conversations
# 2. conversations where many different people participated
@ -146,8 +141,7 @@ def gather_hot_conversations(user_profile, stream_messages):
hot_conversation_render_payloads.append(teaser_data)
return hot_conversation_render_payloads
def gather_new_users(user_profile, threshold):
# type: (UserProfile, datetime.datetime) -> Tuple[int, List[Text]]
def gather_new_users(user_profile: UserProfile, threshold: datetime.datetime) -> Tuple[int, List[Text]]:
# Gather information on users in the realm who have recently
# joined.
if user_profile.realm.is_zephyr_mirror_realm:
@ -160,8 +154,8 @@ def gather_new_users(user_profile, threshold):
return len(user_names), user_names
def gather_new_streams(user_profile, threshold):
# type: (UserProfile, datetime.datetime) -> Tuple[int, Dict[str, List[Text]]]
def gather_new_streams(user_profile: UserProfile,
threshold: datetime.datetime) -> Tuple[int, Dict[str, List[Text]]]:
if user_profile.realm.is_zephyr_mirror_realm:
new_streams = [] # type: List[Stream]
else:
@ -181,8 +175,7 @@ def gather_new_streams(user_profile, threshold):
return len(new_streams), {"html": streams_html, "plain": streams_plain}
def enough_traffic(unread_pms, hot_conversations, new_streams, new_users):
# type: (Text, Text, int, int) -> bool
def enough_traffic(unread_pms: Text, hot_conversations: Text, new_streams: int, new_users: int) -> bool:
if unread_pms or hot_conversations:
# If you have any unread traffic, good enough.
return True
@ -192,8 +185,7 @@ def enough_traffic(unread_pms, hot_conversations, new_streams, new_users):
return True
return False
def handle_digest_email(user_profile_id, cutoff):
# type: (int, float) -> None
def handle_digest_email(user_profile_id: int, cutoff: float) -> None:
user_profile = get_user_profile_by_id(user_profile_id)
# We are disabling digest emails for soft deactivated users for the time.

View File

@ -4,8 +4,7 @@ from django.utils.translation import ugettext as _
import re
from typing import Text
def validate_domain(domain):
# type: (Text) -> None
def validate_domain(domain: Text) -> None:
if domain is None or len(domain) == 0:
raise ValidationError(_("Domain can't be empty."))
if '.' not in domain:

View File

@ -20,8 +20,7 @@ with open(NAME_TO_CODEPOINT_PATH) as fp:
with open(CODEPOINT_TO_NAME_PATH) as fp:
codepoint_to_name = ujson.load(fp)
def emoji_name_to_emoji_code(realm, emoji_name):
# type: (Realm, Text) -> Tuple[Text, Text]
def emoji_name_to_emoji_code(realm: Realm, emoji_name: Text) -> Tuple[Text, Text]:
realm_emojis = realm.get_emoji()
if emoji_name in realm_emojis and not realm_emojis[emoji_name]['deactivated']:
return emoji_name, Reaction.REALM_EMOJI
@ -31,8 +30,7 @@ def emoji_name_to_emoji_code(realm, emoji_name):
return name_to_codepoint[emoji_name], Reaction.UNICODE_EMOJI
raise JsonableError(_("Emoji '%s' does not exist" % (emoji_name,)))
def check_valid_emoji(realm, emoji_name):
# type: (Realm, Text) -> None
def check_valid_emoji(realm: Realm, emoji_name: Text) -> None:
emoji_name_to_emoji_code(realm, emoji_name)
def check_emoji_request(realm: Realm, emoji_name: str, emoji_code: str,
@ -61,8 +59,7 @@ def check_emoji_request(realm: Realm, emoji_name: str, emoji_code: str,
# The above are the only valid emoji types
raise JsonableError(_("Invalid emoji type."))
def check_emoji_admin(user_profile, emoji_name=None):
# type: (UserProfile, Optional[Text]) -> None
def check_emoji_admin(user_profile: UserProfile, emoji_name: Optional[Text]=None) -> None:
"""Raises an exception if the user cannot administer the target realm
emoji name in their organization."""
@ -84,18 +81,15 @@ def check_emoji_admin(user_profile, emoji_name=None):
if not user_profile.is_realm_admin and not current_user_is_author:
raise JsonableError(_("Must be a realm administrator or emoji author"))
def check_valid_emoji_name(emoji_name):
# type: (Text) -> None
def check_valid_emoji_name(emoji_name: Text) -> None:
if re.match('^[0-9a-z.\-_]+(?<![.\-_])$', emoji_name):
return
raise JsonableError(_("Invalid characters in emoji name"))
def get_emoji_url(emoji_file_name, realm_id):
# type: (Text, int) -> Text
def get_emoji_url(emoji_file_name: Text, realm_id: int) -> Text:
return upload_backend.get_emoji_url(emoji_file_name, realm_id)
def get_emoji_file_name(emoji_file_name, emoji_name):
# type: (Text, Text) -> Text
def get_emoji_file_name(emoji_file_name: Text, emoji_name: Text) -> Text:
_, image_ext = os.path.splitext(emoji_file_name)
return ''.join((emoji_name, image_ext))

View File

@ -16,15 +16,13 @@ from zerver.lib.actions import internal_send_message
from zerver.lib.response import json_success, json_error
from version import ZULIP_VERSION
def format_subject(subject):
# type: (str) -> str
def format_subject(subject: str) -> str:
"""
Escape CR and LF characters.
"""
return subject.replace('\n', '\\n').replace('\r', '\\r')
def user_info_str(report):
# type: (Dict[str, Any]) -> str
def user_info_str(report: Dict[str, Any]) -> str:
if report['user_full_name'] and report['user_email']:
user_info = "%(user_full_name)s (%(user_email)s)" % (report)
else:
@ -59,15 +57,13 @@ def deployment_repr() -> str:
return deployment
def notify_browser_error(report):
# type: (Dict[str, Any]) -> None
def notify_browser_error(report: Dict[str, Any]) -> None:
report = defaultdict(lambda: None, report)
if settings.ERROR_BOT:
zulip_browser_error(report)
email_browser_error(report)
def email_browser_error(report):
# type: (Dict[str, Any]) -> None
def email_browser_error(report: Dict[str, Any]) -> None:
subject = "Browser error for %s" % (user_info_str(report))
body = ("User: %(user_full_name)s <%(user_email)s> on %(deployment)s\n\n"
@ -89,8 +85,7 @@ def email_browser_error(report):
mail_admins(subject, body)
def zulip_browser_error(report):
# type: (Dict[str, Any]) -> None
def zulip_browser_error(report: Dict[str, Any]) -> None:
subject = "JS error: %s" % (report['user_email'],)
user_info = user_info_str(report)
@ -103,15 +98,13 @@ def zulip_browser_error(report):
internal_send_message(realm, settings.ERROR_BOT,
"stream", "errors", format_subject(subject), body)
def notify_server_error(report):
# type: (Dict[str, Any]) -> None
def notify_server_error(report: Dict[str, Any]) -> None:
report = defaultdict(lambda: None, report)
email_server_error(report)
if settings.ERROR_BOT:
zulip_server_error(report)
def zulip_server_error(report):
# type: (Dict[str, Any]) -> None
def zulip_server_error(report: Dict[str, Any]) -> None:
subject = '%(node)s: %(message)s' % (report)
stack_trace = report['stack_trace'] or "No stack trace available"
@ -133,8 +126,7 @@ def zulip_server_error(report):
"Error generated by %s\n\n~~~~ pytb\n%s\n\n~~~~\n%s\n%s"
% (user_info, stack_trace, deployment, request_repr))
def email_server_error(report):
# type: (Dict[str, Any]) -> None
def email_server_error(report: Dict[str, Any]) -> None:
subject = '%(node)s: %(message)s' % (report)
user_info = user_info_str(report)
@ -153,8 +145,7 @@ def email_server_error(report):
mail_admins(format_subject(subject), message, fail_silently=True)
def do_report_error(deployment_name, type, report):
# type: (Text, Text, Dict[str, Any]) -> HttpResponse
def do_report_error(deployment_name: Text, type: Text, report: Dict[str, Any]) -> HttpResponse:
report['deployment'] = deployment_name
if type == 'browser':
notify_browser_error(report)

View File

@ -47,12 +47,10 @@ from zproject.backends import email_auth_enabled, password_auth_enabled
from version import ZULIP_VERSION
def get_raw_user_data(realm_id, client_gravatar):
# type: (int, bool) -> Dict[int, Dict[str, Text]]
def get_raw_user_data(realm_id: int, client_gravatar: bool) -> Dict[int, Dict[str, Text]]:
user_dicts = get_realm_user_dicts(realm_id)
def user_data(row):
# type: (Dict[str, Any]) -> Dict[str, Any]
def user_data(row: Dict[str, Any]) -> Dict[str, Any]:
avatar_url = get_avatar_field(
user_id=row['id'],
realm_id=realm_id,
@ -81,8 +79,7 @@ def get_raw_user_data(realm_id, client_gravatar):
for row in user_dicts
}
def always_want(msg_type):
# type: (str) -> bool
def always_want(msg_type: str) -> bool:
'''
This function is used as a helper in
fetch_initial_state_data, when the user passes
@ -262,8 +259,8 @@ def fetch_initial_state_data(user_profile, event_types, queue_id, client_gravatar
return state
def remove_message_id_from_unread_mgs(state, message_id):
# type: (Dict[str, Dict[str, Any]], int) -> None
def remove_message_id_from_unread_mgs(state: Dict[str, Dict[str, Any]],
message_id: int) -> None:
raw_unread = state['raw_unread_msgs']
for key in ['pm_dict', 'stream_dict', 'huddle_dict']:
@ -288,8 +285,11 @@ def apply_events(state, events, user_profile, client_gravatar, include_subscribers
continue
apply_event(state, event, user_profile, client_gravatar, include_subscribers)
def apply_event(state, event, user_profile, client_gravatar, include_subscribers):
# type: (Dict[str, Any], Dict[str, Any], UserProfile, bool, bool) -> None
def apply_event(state: Dict[str, Any],
event: Dict[str, Any],
user_profile: UserProfile,
client_gravatar: bool,
include_subscribers: bool) -> None:
if event['type'] == "message":
state['max_message_id'] = max(state['max_message_id'], event['message']['id'])
if 'raw_unread_msgs' in state:
@ -440,8 +440,7 @@ def apply_event(state, event, user_profile, client_gravatar, include_subscribers):
event['subscriptions'][i] = copy.deepcopy(event['subscriptions'][i])
del event['subscriptions'][i]['subscribers']
def name(sub):
# type: (Dict[str, Any]) -> Text
def name(sub: Dict[str, Any]) -> Text:
return sub['name'].lower()
if event['op'] == "add":

View File

@ -6,24 +6,20 @@ from django.core.exceptions import PermissionDenied
class AbstractEnum(Enum):
'''An enumeration whose members are used strictly for their names.'''
def __new__(cls):
# type: (Type[AbstractEnum]) -> AbstractEnum
def __new__(cls: Type['AbstractEnum']) -> 'AbstractEnum':
obj = object.__new__(cls)
obj._value_ = len(cls.__members__) + 1
return obj
# Override all the `Enum` methods that use `_value_`.
def __repr__(self):
# type: () -> str
def __repr__(self) -> str:
return str(self)
def value(self):
# type: () -> None
def value(self) -> None:
assert False
def __reduce_ex__(self, proto):
# type: (int) -> None
def __reduce_ex__(self, proto: int) -> None:
assert False
class ErrorCode(AbstractEnum):
@ -69,13 +65,11 @@ class JsonableError(Exception):
code = ErrorCode.NO_SUCH_WIDGET
data_fields = ['widget_name']
def __init__(self, widget_name):
# type: (str) -> None
def __init__(self, widget_name: str) -> None:
self.widget_name = widget_name # type: str
@staticmethod
def msg_format():
# type: () -> str
def msg_format() -> str:
return _("No such widget: {widget_name}")
raise NoSuchWidgetError(widget_name)
@ -96,8 +90,7 @@ class JsonableError(Exception):
# like 403 or 404.
http_status_code = 400 # type: int
def __init__(self, msg, code=None):
# type: (Text, Optional[ErrorCode]) -> None
def __init__(self, msg: Text, code: Optional[ErrorCode]=None) -> None:
if code is not None:
self.code = code
@ -105,8 +98,7 @@ class JsonableError(Exception):
self._msg = msg # type: Text
@staticmethod
def msg_format():
# type: () -> Text
def msg_format() -> Text:
'''Override in subclasses. Gets the items in `data_fields` as format args.
This should return (a translation of) a string literal.
@ -124,29 +116,24 @@ class JsonableError(Exception):
#
@property
def msg(self):
# type: () -> Text
def msg(self) -> Text:
format_data = dict(((f, getattr(self, f)) for f in self.data_fields),
_msg=getattr(self, '_msg', None))
return self.msg_format().format(**format_data)
@property
def data(self):
# type: () -> Dict[str, Any]
def data(self) -> Dict[str, Any]:
return dict(((f, getattr(self, f)) for f in self.data_fields),
code=self.code.name)
def to_json(self):
# type: () -> Dict[str, Any]
def to_json(self) -> Dict[str, Any]:
d = {'result': 'error', 'msg': self.msg}
d.update(self.data)
return d
def __str__(self):
# type: () -> str
def __str__(self) -> str:
return self.msg
class RateLimited(PermissionDenied):
def __init__(self, msg=""):
# type: (str) -> None
def __init__(self, msg: str="") -> None:
super().__init__(msg)

View File

@ -125,8 +125,7 @@ DATE_FIELDS = {
'zerver_userprofile': ['date_joined', 'last_login', 'last_reminder'],
} # type: Dict[TableName, List[Field]]
def sanity_check_output(data):
# type: (TableData) -> None
def sanity_check_output(data: TableData) -> None:
tables = set(ALL_ZERVER_TABLES)
tables -= set(NON_EXPORTED_TABLES)
tables -= set(IMPLICIT_TABLES)
@ -137,13 +136,11 @@ def sanity_check_output(data):
if table not in data:
logging.warning('??? NO DATA EXPORTED FOR TABLE %s!!!' % (table,))
def write_data_to_file(output_file, data):
# type: (Path, Any) -> None
def write_data_to_file(output_file: Path, data: Any) -> None:
with open(output_file, "w") as f:
f.write(ujson.dumps(data, indent=4))
def make_raw(query, exclude=None):
# type: (Any, List[Field]) -> List[Record]
def make_raw(query: Any, exclude: List[Field]=None) -> List[Record]:
'''
Takes a Django query and returns a JSONable list
of dictionaries corresponding to the database rows.
@ -165,8 +162,7 @@ def make_raw(query, exclude=None):
return rows
def floatify_datetime_fields(data, table):
# type: (TableData, TableName) -> None
def floatify_datetime_fields(data: TableData, table: TableName) -> None:
for item in data[table]:
for field in DATE_FIELDS[table]:
orig_dt = item[field]
@ -261,8 +257,8 @@ class Config:
self.virtual_parent.table))
def export_from_config(response, config, seed_object=None, context=None):
# type: (TableData, Config, Any, Context) -> None
def export_from_config(response: TableData, config: Config, seed_object: Any=None,
context: Context=None) -> None:
table = config.table
parent = config.parent
model = config.model
@ -372,8 +368,7 @@ def export_from_config(response, config, seed_object=None, context=None):
context=context,
)
def get_realm_config():
# type: () -> Config
def get_realm_config() -> Config:
# This is common, public information about the realm that we can share
# with all realm users.
@ -536,8 +531,7 @@ def get_realm_config():
return realm_config
def sanity_check_stream_data(response, config, context):
# type: (TableData, Config, Context) -> None
def sanity_check_stream_data(response: TableData, config: Config, context: Context) -> None:
if context['exportable_user_ids'] is not None:
# If we restrict which user ids are exportable,
@ -559,8 +553,7 @@ def sanity_check_stream_data(response, config, context):
Please investigate!
''')
def fetch_user_profile(response, config, context):
# type: (TableData, Config, Context) -> None
def fetch_user_profile(response: TableData, config: Config, context: Context) -> None:
realm = context['realm']
exportable_user_ids = context['exportable_user_ids']
@ -589,8 +582,7 @@ def fetch_user_profile(response, config, context):
response['zerver_userprofile'] = normal_rows
response['zerver_userprofile_mirrordummy'] = dummy_rows
def fetch_user_profile_cross_realm(response, config, context):
# type: (TableData, Config, Context) -> None
def fetch_user_profile_cross_realm(response: TableData, config: Config, context: Context) -> None:
realm = context['realm']
if realm.string_id == "zulip":
@ -602,8 +594,7 @@ def fetch_user_profile_cross_realm(response, config, context):
get_system_bot(settings.WELCOME_BOT),
]]
def fetch_attachment_data(response, realm_id, message_ids):
# type: (TableData, int, Set[int]) -> None
def fetch_attachment_data(response: TableData, realm_id: int, message_ids: Set[int]) -> None:
filter_args = {'realm_id': realm_id}
query = Attachment.objects.filter(**filter_args)
response['zerver_attachment'] = make_raw(list(query))
@ -630,8 +621,7 @@ def fetch_attachment_data(response, realm_id, message_ids):
row for row in response['zerver_attachment']
if row['messages']]
def fetch_huddle_objects(response, config, context):
# type: (TableData, Config, Context) -> None
def fetch_huddle_objects(response: TableData, config: Config, context: Context) -> None:
realm = context['realm']
assert config.parent is not None
@ -667,8 +657,10 @@ def fetch_huddle_objects(response, config, context):
response['_huddle_subscription'] = huddle_subscription_dicts
response['zerver_huddle'] = make_raw(Huddle.objects.filter(id__in=huddle_ids))
def fetch_usermessages(realm, message_ids, user_profile_ids, message_filename):
# type: (Realm, Set[int], Set[int], Path) -> List[Record]
def fetch_usermessages(realm: Realm,
message_ids: Set[int],
user_profile_ids: Set[int],
message_filename: Path) -> List[Record]:
# UserMessage export security rule: You can export UserMessages
# for the messages you exported for the users in your realm.
user_message_query = UserMessage.objects.filter(user_profile__realm=realm,
@ -684,8 +676,7 @@ def fetch_usermessages(realm, message_ids, user_profile_ids, message_filename):
logging.info("Fetched UserMessages for %s" % (message_filename,))
return user_message_chunk
def export_usermessages_batch(input_path, output_path):
# type: (Path, Path) -> None
def export_usermessages_batch(input_path: Path, output_path: Path) -> None:
"""As part of the system for doing parallel exports, this runs on one
batch of Message objects and adds the corresponding UserMessage
objects. (This is called by the export_usermessage_batch
@ -701,18 +692,18 @@ def export_usermessages_batch(input_path, output_path):
write_message_export(output_path, output)
os.unlink(input_path)
def write_message_export(message_filename, output):
# type: (Path, MessageOutput) -> None
def write_message_export(message_filename: Path, output: MessageOutput) -> None:
write_data_to_file(output_file=message_filename, data=output)
logging.info("Dumped to %s" % (message_filename,))
def export_partial_message_files(realm, response, chunk_size=1000, output_dir=None):
# type: (Realm, TableData, int, Path) -> Set[int]
def export_partial_message_files(realm: Realm,
response: TableData,
chunk_size: int=1000,
output_dir: Path=None) -> Set[int]:
if output_dir is None:
output_dir = tempfile.mkdtemp(prefix="zulip-export")
def get_ids(records):
# type: (List[Record]) -> Set[int]
def get_ids(records: List[Record]) -> Set[int]:
return set(x['id'] for x in records)
# Basic security rule: You can export everything either...
@ -824,8 +815,7 @@ def write_message_partial_for_query(realm, message_query, dump_file_id,
return dump_file_id
def export_uploads_and_avatars(realm, output_dir):
# type: (Realm, Path) -> None
def export_uploads_and_avatars(realm: Realm, output_dir: Path) -> None:
uploads_output_dir = os.path.join(output_dir, 'uploads')
avatars_output_dir = os.path.join(output_dir, 'avatars')
@ -851,8 +841,8 @@ def export_uploads_and_avatars(realm, output_dir):
settings.S3_AUTH_UPLOADS_BUCKET,
output_dir=uploads_output_dir)
def export_files_from_s3(realm, bucket_name, output_dir, processing_avatars=False):
# type: (Realm, str, Path, bool) -> None
def export_files_from_s3(realm: Realm, bucket_name: str, output_dir: Path,
processing_avatars: bool=False) -> None:
conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY)
bucket = conn.get_bucket(bucket_name, validate=True)
records = []
@ -932,8 +922,7 @@ def export_files_from_s3(realm, bucket_name, output_dir, processing_avatars=False):
with open(os.path.join(output_dir, "records.json"), "w") as records_file:
ujson.dump(records, records_file, indent=4)
def export_uploads_from_local(realm, local_dir, output_dir):
# type: (Realm, Path, Path) -> None
def export_uploads_from_local(realm: Realm, local_dir: Path, output_dir: Path) -> None:
count = 0
records = []
@ -960,8 +949,7 @@ def export_uploads_from_local(realm, local_dir, output_dir):
with open(os.path.join(output_dir, "records.json"), "w") as records_file:
ujson.dump(records, records_file, indent=4)
def export_avatars_from_local(realm, local_dir, output_dir):
# type: (Realm, Path, Path) -> None
def export_avatars_from_local(realm: Realm, local_dir: Path, output_dir: Path) -> None:
count = 0
records = []
@ -1005,8 +993,7 @@ def export_avatars_from_local(realm, local_dir, output_dir):
with open(os.path.join(output_dir, "records.json"), "w") as records_file:
ujson.dump(records, records_file, indent=4)
def do_write_stats_file_for_realm_export(output_dir):
# type: (Path) -> None
def do_write_stats_file_for_realm_export(output_dir: Path) -> None:
stats_file = os.path.join(output_dir, 'stats.txt')
realm_file = os.path.join(output_dir, 'realm.json')
attachment_file = os.path.join(output_dir, 'attachment.json')
@ -1033,8 +1020,8 @@ def do_write_stats_file_for_realm_export(output_dir):
f.write('%5d records\n' % len(data))
f.write('\n')
def do_export_realm(realm, output_dir, threads, exportable_user_ids=None):
# type: (Realm, Path, int, Set[int]) -> None
def do_export_realm(realm: Realm, output_dir: Path, threads: int,
exportable_user_ids: Set[int]=None) -> None:
response = {} # type: TableData
# We need at least one thread running to export
@ -1084,16 +1071,14 @@ def do_export_realm(realm, output_dir, threads, exportable_user_ids=None):
logging.info("Finished exporting %s" % (realm.string_id))
create_soft_link(source=output_dir, in_progress=False)
def export_attachment_table(realm, output_dir, message_ids):
# type: (Realm, Path, Set[int]) -> None
def export_attachment_table(realm: Realm, output_dir: Path, message_ids: Set[int]) -> None:
response = {} # type: TableData
fetch_attachment_data(response=response, realm_id=realm.id, message_ids=message_ids)
output_file = os.path.join(output_dir, "attachment.json")
logging.info('Writing attachment table data to %s' % (output_file,))
write_data_to_file(output_file=output_file, data=response)
def create_soft_link(source, in_progress=True):
# type: (Path, bool) -> None
def create_soft_link(source: Path, in_progress: bool=True) -> None:
is_done = not in_progress
in_progress_link = '/tmp/zulip-export-in-progress'
done_link = '/tmp/zulip-export-most-recent'
@ -1109,12 +1094,10 @@ def create_soft_link(source, in_progress=True):
logging.info('See %s for output files' % (new_target,))
def launch_user_message_subprocesses(threads, output_dir):
# type: (int, Path) -> None
def launch_user_message_subprocesses(threads: int, output_dir: Path) -> None:
logging.info('Launching %d PARALLEL subprocesses to export UserMessage rows' % (threads,))
def run_job(shard):
# type: (str) -> int
def run_job(shard: str) -> int:
subprocess.call(["./manage.py", 'export_usermessage_batch', '--path',
str(output_dir), '--thread', shard])
return 0
@ -1124,8 +1107,7 @@ def launch_user_message_subprocesses(threads, output_dir):
threads=threads):
print("Shard %s finished, status %s" % (job, status))
def do_export_user(user_profile, output_dir):
# type: (UserProfile, Path) -> None
def do_export_user(user_profile: UserProfile, output_dir: Path) -> None:
response = {} # type: TableData
export_single_user(user_profile, response)
@ -1134,8 +1116,7 @@ def do_export_user(user_profile, output_dir):
logging.info("Exporting messages")
export_messages_single_user(user_profile, output_dir)
def export_single_user(user_profile, response):
# type: (UserProfile, TableData) -> None
def export_single_user(user_profile: UserProfile, response: TableData) -> None:
config = get_single_user_config()
export_from_config(
@ -1144,8 +1125,7 @@ def export_single_user(user_profile, response):
seed_object=user_profile,
)
def get_single_user_config():
# type: () -> Config
def get_single_user_config() -> Config:
# zerver_userprofile
user_profile_config = Config(
@ -1182,8 +1162,7 @@ def get_single_user_config():
return user_profile_config
def export_messages_single_user(user_profile, output_dir, chunk_size=1000):
# type: (UserProfile, Path, int) -> None
def export_messages_single_user(user_profile: UserProfile, output_dir: Path, chunk_size: int=1000) -> None:
user_message_query = UserMessage.objects.filter(user_profile=user_profile).order_by("id")
min_id = -1
dump_file_id = 1
@ -1232,8 +1211,7 @@ id_maps = {
'user_profile': {},
} # type: Dict[str, Dict[int, int]]
def update_id_map(table, old_id, new_id):
# type: (TableName, int, int) -> None
def update_id_map(table: TableName, old_id: int, new_id: int) -> None:
if table not in id_maps:
raise Exception('''
Table %s is not initialized in id_maps, which could
@ -1242,15 +1220,13 @@ def update_id_map(table, old_id, new_id):
''' % (table,))
id_maps[table][old_id] = new_id
def fix_datetime_fields(data, table):
# type: (TableData, TableName) -> None
def fix_datetime_fields(data: TableData, table: TableName) -> None:
for item in data[table]:
for field_name in DATE_FIELDS[table]:
if item[field_name] is not None:
item[field_name] = datetime.datetime.fromtimestamp(item[field_name], tz=timezone_utc)
def convert_to_id_fields(data, table, field_name):
# type: (TableData, TableName, Field) -> None
def convert_to_id_fields(data: TableData, table: TableName, field_name: Field) -> None:
'''
When Django gives us dict objects via model_to_dict, the foreign
key fields are `foo`, but we want `foo_id` for the bulk insert.
@ -1262,8 +1238,11 @@ def convert_to_id_fields(data, table, field_name):
item[field_name + "_id"] = item[field_name]
del item[field_name]
def re_map_foreign_keys(data, table, field_name, related_table, verbose=False):
# type: (TableData, TableName, Field, TableName, bool) -> None
def re_map_foreign_keys(data: TableData,
table: TableName,
field_name: Field,
related_table: TableName,
verbose: bool=False) -> None:
'''
We occasionally need to assign new ids to rows during the
import/export process, to accommodate things like existing rows
@ -1288,14 +1267,12 @@ def re_map_foreign_keys(data, table, field_name, related_table, verbose=False):
item[field_name + "_id"] = new_id
del item[field_name]
def fix_bitfield_keys(data, table, field_name):
# type: (TableData, TableName, Field) -> None
def fix_bitfield_keys(data: TableData, table: TableName, field_name: Field) -> None:
for item in data[table]:
item[field_name] = item[field_name + '_mask']
del item[field_name + '_mask']
def fix_realm_authentication_bitfield(data, table, field_name):
# type: (TableData, TableName, Field) -> None
def fix_realm_authentication_bitfield(data: TableData, table: TableName, field_name: Field) -> None:
"""Used to fixup the authentication_methods bitfield to be a string"""
for item in data[table]:
values_as_bitstring = ''.join(['1' if field[1] else '0' for field in
@ -1303,8 +1280,7 @@ def fix_realm_authentication_bitfield(data, table, field_name):
values_as_int = int(values_as_bitstring, 2)
item[field_name] = values_as_int
def bulk_import_model(data, model, table, dump_file_id=None):
# type: (TableData, Any, TableName, str) -> None
def bulk_import_model(data: TableData, model: Any, table: TableName, dump_file_id: str=None) -> None:
# TODO, deprecate dump_file_id
model.objects.bulk_create(model(**item) for item in data[table])
if dump_file_id is None:
@ -1316,8 +1292,7 @@ def bulk_import_model(data, model, table, dump_file_id=None):
# correctly import multiple realms into the same server, we need to
# check if a Client object already exists, and so we need to support
# remap all Client IDs to the values in the new DB.
def bulk_import_client(data, model, table):
# type: (TableData, Any, TableName) -> None
def bulk_import_client(data: TableData, model: Any, table: TableName) -> None:
for item in data[table]:
try:
client = Client.objects.get(name=item['name'])
@ -1325,8 +1300,7 @@ def bulk_import_client(data, model, table):
client = Client.objects.create(name=item['name'])
update_id_map(table='client', old_id=item['id'], new_id=client.id)
def import_uploads_local(import_dir, processing_avatars=False):
# type: (Path, bool) -> None
def import_uploads_local(import_dir: Path, processing_avatars: bool=False) -> None:
records_filename = os.path.join(import_dir, "records.json")
with open(records_filename) as records_file:
records = ujson.loads(records_file.read())
@ -1349,8 +1323,7 @@ def import_uploads_local(import_dir, processing_avatars=False):
subprocess.check_call(["mkdir", "-p", os.path.dirname(file_path)])
shutil.copy(orig_file_path, file_path)
def import_uploads_s3(bucket_name, import_dir, processing_avatars=False):
# type: (str, Path, bool) -> None
def import_uploads_s3(bucket_name: str, import_dir: Path, processing_avatars: bool=False) -> None:
conn = S3Connection(settings.S3_KEY, settings.S3_SECRET_KEY)
bucket = conn.get_bucket(bucket_name, validate=True)
@ -1385,8 +1358,7 @@ def import_uploads_s3(bucket_name, import_dir, processing_avatars=False):
key.set_contents_from_filename(os.path.join(import_dir, record['path']), headers=headers)
def import_uploads(import_dir, processing_avatars=False):
# type: (Path, bool) -> None
def import_uploads(import_dir: Path, processing_avatars: bool=False) -> None:
if processing_avatars:
logging.info("Importing avatars")
else:
@ -1418,8 +1390,7 @@ def import_uploads(import_dir, processing_avatars=False):
# Because the Python object => JSON conversion process is not fully
# faithful, we have to use a set of fixers (e.g. on DateTime objects
# and Foreign Keys) to do the import correctly.
def do_import_realm(import_dir):
# type: (Path) -> None
def do_import_realm(import_dir: Path) -> None:
logging.info("Importing realm dump %s" % (import_dir,))
if not os.path.exists(import_dir):
raise Exception("Missing import directory!")
@ -1527,8 +1498,7 @@ def do_import_realm(import_dir):
import_attachments(data)
def import_message_data(import_dir):
# type: (Path) -> None
def import_message_data(import_dir: Path) -> None:
dump_file_id = 1
while True:
message_filename = os.path.join(import_dir, "messages-%06d.json" % (dump_file_id,))
@ -1555,8 +1525,7 @@ def import_message_data(import_dir):
dump_file_id += 1
def import_attachments(data):
# type: (TableData) -> None
def import_attachments(data: TableData) -> None:
# Clean up the data in zerver_attachment that is not
# relevant to our many-to-many import.

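Several converted signatures in this file annotate a None default without Optional (for example `exclude: List[Field]=None` and `output_dir: Path=None`), relying on mypy's implicit-Optional behavior; stricter mypy settings (no_implicit_optional) require the explicit spelling. A minimal sketch with a hypothetical function:

from typing import List, Optional

def chunk_ids(ids: Optional[List[int]]=None) -> List[int]:
    # Spelling out Optional makes the None default explicit to the reader
    # and keeps the signature valid under mypy's no_implicit_optional.
    if ids is None:
        return []
    return ids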
View File

@ -12,8 +12,7 @@ from typing import Any, List, Dict, Optional, Text
import os
import ujson
def with_language(string, language):
# type: (Text, Text) -> Text
def with_language(string: Text, language: Text) -> Text:
"""
This is an expensive function. If you are using it in a loop, it will
make your code slow.
@ -25,15 +24,13 @@ def with_language(string, language):
return result
@lru_cache()
def get_language_list():
# type: () -> List[Dict[str, Any]]
def get_language_list() -> List[Dict[str, Any]]:
path = os.path.join(settings.STATIC_ROOT, 'locale', 'language_name_map.json')
with open(path, 'r') as reader:
languages = ujson.load(reader)
return languages['name_map']
def get_language_list_for_templates(default_language):
# type: (Text) -> List[Dict[str, Dict[str, str]]]
def get_language_list_for_templates(default_language: Text) -> List[Dict[str, Dict[str, str]]]:
language_list = [l for l in get_language_list()
if 'percent_translated' not in l or
l['percent_translated'] >= 5.]
@ -70,15 +67,13 @@ def get_language_list_for_templates(default_language):
return formatted_list
def get_language_name(code):
# type: (str) -> Optional[Text]
def get_language_name(code: str) -> Optional[Text]:
for lang in get_language_list():
if code in (lang['code'], lang['locale']):
return lang['name']
return None
def get_available_language_codes():
# type: () -> List[Text]
def get_available_language_codes() -> List[Text]:
language_list = get_language_list()
codes = [language['code'] for language in language_list]
return codes

View File

@ -17,8 +17,7 @@ from logging import Logger
class _RateLimitFilter:
last_error = datetime.min.replace(tzinfo=timezone_utc)
def filter(self, record):
# type: (logging.LogRecord) -> bool
def filter(self, record: logging.LogRecord) -> bool:
from django.conf import settings
from django.core.cache import cache
@ -58,23 +57,19 @@ class EmailLimiter(_RateLimitFilter):
pass
class ReturnTrue(logging.Filter):
def filter(self, record):
# type: (logging.LogRecord) -> bool
def filter(self, record: logging.LogRecord) -> bool:
return True
class ReturnEnabled(logging.Filter):
def filter(self, record):
# type: (logging.LogRecord) -> bool
def filter(self, record: logging.LogRecord) -> bool:
return settings.LOGGING_NOT_DISABLED
class RequireReallyDeployed(logging.Filter):
def filter(self, record):
# type: (logging.LogRecord) -> bool
def filter(self, record: logging.LogRecord) -> bool:
from django.conf import settings
return settings.PRODUCTION
def skip_200_and_304(record):
# type: (logging.LogRecord) -> bool
def skip_200_and_304(record: logging.LogRecord) -> bool:
# Apparently, `status_code` is added by Django and is not an actual
# attribute of LogRecord; as a result, mypy throws an error if we
# access the `status_code` attribute directly.
@ -91,8 +86,7 @@ IGNORABLE_404_URLS = [
re.compile(r'^/wp-login.php$'),
]
def skip_boring_404s(record):
# type: (logging.LogRecord) -> bool
def skip_boring_404s(record: logging.LogRecord) -> bool:
"""Prevents Django's 'Not Found' warnings from being logged for common
404 errors that don't reflect a problem in Zulip. The overall
result is to keep the Zulip error logs cleaner than they would
@ -116,8 +110,7 @@ def skip_boring_404s(record):
return False
return True
def skip_site_packages_logs(record):
# type: (logging.LogRecord) -> bool
def skip_site_packages_logs(record: logging.LogRecord) -> bool:
# This skips the log records that are generated from libraries
# installed in site packages.
# Workaround for https://code.djangoproject.com/ticket/26886
@ -125,8 +118,7 @@ def skip_site_packages_logs(record):
return False
return True
def find_log_caller_module(record):
# type: (logging.LogRecord) -> Optional[str]
def find_log_caller_module(record: logging.LogRecord) -> Optional[str]:
'''Find the module name corresponding to where this record was logged.'''
# Repeat a search similar to that in logging.Logger.findCaller.
# The logging call should still be on the stack somewhere; search until
@ -144,8 +136,7 @@ logger_nicknames = {
'zulip.requests': 'zr', # Super common.
}
def find_log_origin(record):
# type: (logging.LogRecord) -> str
def find_log_origin(record: logging.LogRecord) -> str:
logger_name = logger_nicknames.get(record.name, record.name)
if settings.LOGGING_SHOW_MODULE:
@ -166,8 +157,7 @@ log_level_abbrevs = {
'CRITICAL': 'CRIT',
}
def abbrev_log_levelname(levelname):
# type: (str) -> str
def abbrev_log_levelname(levelname: str) -> str:
# It's unlikely someone will set a custom log level with a custom name,
# but it's an option, so we shouldn't crash if someone does.
return log_level_abbrevs.get(levelname, levelname[:4])
@ -176,20 +166,17 @@ class ZulipFormatter(logging.Formatter):
# Used in the base implementation. Default uses `,`.
default_msec_format = '%s.%03d'
def __init__(self):
# type: () -> None
def __init__(self) -> None:
super().__init__(fmt=self._compute_fmt())
def _compute_fmt(self):
# type: () -> str
def _compute_fmt(self) -> str:
pieces = ['%(asctime)s', '%(zulip_level_abbrev)-4s']
if settings.LOGGING_SHOW_PID:
pieces.append('pid:%(process)d')
pieces.extend(['[%(zulip_origin)s]', '%(message)s'])
return ' '.join(pieces)
def format(self, record):
# type: (logging.LogRecord) -> str
def format(self, record: logging.LogRecord) -> str:
if not getattr(record, 'zulip_decorated', False):
# The `setattr` calls put this logic explicitly outside the bounds of the
# type system; otherwise mypy would complain LogRecord lacks these attributes.
@ -198,8 +185,10 @@ class ZulipFormatter(logging.Formatter):
setattr(record, 'zulip_decorated', True)
return super().format(record)
def create_logger(name, log_file, log_level, log_format="%(asctime)s %(levelname)-8s %(message)s"):
# type: (str, str, str, str) -> Logger
def create_logger(name: str,
log_file: str,
log_level: str,
log_format: str="%(asctime)s %(levelname)-8s %(message)s") -> Logger:
"""Creates a named logger for use in logging content to a certain
file. A few notes:

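The comment in `skip_200_and_304` earlier in this file notes that `status_code` is attached to the record by Django and is not declared on LogRecord, so direct attribute access upsets mypy. A sketch of the usual getattr workaround (hypothetical function name):

import logging

def skip_200_and_304_sketch(record: logging.LogRecord) -> bool:
    # `status_code` is attached dynamically by Django, so getattr with a
    # default avoids both the mypy error and an AttributeError at runtime.
    return getattr(record, 'status_code', None) not in (200, 304)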
View File

@ -10,12 +10,10 @@ user_group_mentions = r'(?<![^\s\'\"\(,:<])@(\*[^\*]+\*)'
wildcards = ['all', 'everyone']
def user_mention_matches_wildcard(mention):
# type: (Text) -> bool
def user_mention_matches_wildcard(mention: Text) -> bool:
return mention in wildcards
def extract_name(s):
# type: (Text) -> Optional[Text]
def extract_name(s: Text) -> Optional[Text]:
if s.startswith("**") and s.endswith("**"):
name = s[2:-2]
if name in wildcards:
@ -25,18 +23,15 @@ def extract_name(s):
# We don't care about @all or @everyone
return None
def possible_mentions(content):
# type: (Text) -> Set[Text]
def possible_mentions(content: Text) -> Set[Text]:
matches = re.findall(find_mentions, content)
names_with_none = (extract_name(match) for match in matches)
names = {name for name in names_with_none if name}
return names
def extract_user_group(matched_text):
# type: (Text) -> Text
def extract_user_group(matched_text: Text) -> Text:
return matched_text[1:-1]
def possible_user_group_mentions(content):
# type: (Text) -> Set[Text]
def possible_user_group_mentions(content: Text) -> Set[Text]:
matches = re.findall(user_group_mentions, content)
return {extract_user_group(match) for match in matches}

View File

@ -97,9 +97,8 @@ def messages_for_ids(message_ids: List[int],
return message_list
def sew_messages_and_reactions(messages, reactions):
# type: (List[Dict[str, Any]], List[Dict[str, Any]]) -> List[Dict[str, Any]]
def sew_messages_and_reactions(messages: List[Dict[str, Any]],
reactions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Given a iterable of messages and reactions stitch reactions
into messages.
"""
@ -117,23 +116,19 @@ def sew_messages_and_reactions(messages, reactions):
return list(converted_messages.values())
def extract_message_dict(message_bytes):
# type: (bytes) -> Dict[str, Any]
def extract_message_dict(message_bytes: bytes) -> Dict[str, Any]:
return ujson.loads(zlib.decompress(message_bytes).decode("utf-8"))
def stringify_message_dict(message_dict):
# type: (Dict[str, Any]) -> bytes
def stringify_message_dict(message_dict: Dict[str, Any]) -> bytes:
return zlib.compress(ujson.dumps(message_dict).encode())
@cache_with_key(to_dict_cache_key, timeout=3600*24)
def message_to_dict_json(message):
# type: (Message) -> bytes
def message_to_dict_json(message: Message) -> bytes:
return MessageDict.to_dict_uncached(message)
class MessageDict:
@staticmethod
def wide_dict(message):
# type: (Message) -> Dict[str, Any]
def wide_dict(message: Message) -> Dict[str, Any]:
'''
The next two lines get the cachable field related
to our message object, with the side effect of
@ -154,8 +149,7 @@ class MessageDict:
return obj
@staticmethod
def post_process_dicts(objs, apply_markdown, client_gravatar):
# type: (List[Dict[str, Any]], bool, bool) -> None
def post_process_dicts(objs: List[Dict[str, Any]], apply_markdown: bool, client_gravatar: bool) -> None:
MessageDict.bulk_hydrate_sender_info(objs)
for obj in objs:
@ -163,10 +157,10 @@ class MessageDict:
MessageDict.finalize_payload(obj, apply_markdown, client_gravatar)
@staticmethod
def finalize_payload(obj, apply_markdown, client_gravatar):
# type: (Dict[str, Any], bool, bool) -> None
def finalize_payload(obj: Dict[str, Any],
apply_markdown: bool,
client_gravatar: bool) -> None:
MessageDict.set_sender_avatar(obj, client_gravatar)
if apply_markdown:
obj['content_type'] = 'text/html'
obj['content'] = obj['rendered_content']
@ -184,14 +178,12 @@ class MessageDict:
del obj['sender_is_mirror_dummy']
@staticmethod
def to_dict_uncached(message):
# type: (Message) -> bytes
def to_dict_uncached(message: Message) -> bytes:
dct = MessageDict.to_dict_uncached_helper(message)
return stringify_message_dict(dct)
@staticmethod
def to_dict_uncached_helper(message):
# type: (Message) -> Dict[str, Any]
def to_dict_uncached_helper(message: Message) -> Dict[str, Any]:
return MessageDict.build_message_dict(
message = message,
message_id = message.id,
@ -212,8 +204,7 @@ class MessageDict:
)
@staticmethod
def get_raw_db_rows(needed_ids):
# type: (List[int]) -> List[Dict[str, Any]]
def get_raw_db_rows(needed_ids: List[int]) -> List[Dict[str, Any]]:
# This is a special purpose function optimized for
# callers like get_messages_backend().
fields = [
@ -242,8 +233,7 @@ class MessageDict:
return sew_messages_and_reactions(messages, reactions)
@staticmethod
def build_dict_from_raw_db_row(row):
# type: (Dict[str, Any]) -> Dict[str, Any]
def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]:
'''
row is a row from a .values() call, and it needs to have
all the relevant fields populated
@ -352,8 +342,7 @@ class MessageDict:
return obj
@staticmethod
def bulk_hydrate_sender_info(objs):
# type: (List[Dict[str, Any]]) -> None
def bulk_hydrate_sender_info(objs: List[Dict[str, Any]]) -> None:
sender_ids = list({
obj['sender_id']
@ -393,8 +382,7 @@ class MessageDict:
obj['sender_is_mirror_dummy'] = user_row['is_mirror_dummy']
@staticmethod
def hydrate_recipient_info(obj):
# type: (Dict[str, Any]) -> None
def hydrate_recipient_info(obj: Dict[str, Any]) -> None:
'''
This method hydrates recipient info with things
like full names and emails of senders. Eventually
@ -437,8 +425,7 @@ class MessageDict:
obj['stream_id'] = recipient_type_id
@staticmethod
def set_sender_avatar(obj, client_gravatar):
# type: (Dict[str, Any], bool) -> None
def set_sender_avatar(obj: Dict[str, Any], client_gravatar: bool) -> None:
sender_id = obj['sender_id']
sender_realm_id = obj['sender_realm_id']
sender_email = obj['sender_email']
@ -457,8 +444,7 @@ class MessageDict:
class ReactionDict:
@staticmethod
def build_dict_from_raw_db_row(row):
# type: (Dict[str, Any]) -> Dict[str, Any]
def build_dict_from_raw_db_row(row: Dict[str, Any]) -> Dict[str, Any]:
return {'emoji_name': row['emoji_name'],
'emoji_code': row['emoji_code'],
'reaction_type': row['reaction_type'],
@ -467,8 +453,7 @@ class ReactionDict:
'full_name': row['user_profile__full_name']}}
def access_message(user_profile, message_id):
# type: (UserProfile, int) -> Tuple[Message, UserMessage]
def access_message(user_profile: UserProfile, message_id: int) -> Tuple[Message, UserMessage]:
"""You can access a message by ID in our APIs that either:
(1) You received or have previously accessed via starring
(aka have a UserMessage row for).
@ -506,9 +491,13 @@ def access_message(user_profile, message_id):
# stream in your realm, so return the message, user_message pair
return (message, user_message)
def render_markdown(message, content, realm=None, realm_alert_words=None, user_ids=None,
mention_data=None, email_gateway=False):
# type: (Message, Text, Optional[Realm], Optional[RealmAlertWords], Optional[Set[int]], Optional[bugdown.MentionData], Optional[bool]) -> Text
def render_markdown(message: Message,
content: Text,
realm: Optional[Realm]=None,
realm_alert_words: Optional[RealmAlertWords]=None,
user_ids: Optional[Set[int]]=None,
mention_data: Optional[bugdown.MentionData]=None,
email_gateway: Optional[bool]=False) -> Text:
"""Return HTML for given markdown. Bugdown may add properties to the
message object such as `mentions_user_ids`, `mentions_user_group_ids`, and
`mentions_wildcard`. These are only on this Django object and are not
@ -565,8 +554,7 @@ def render_markdown(message, content, realm=None, realm_alert_words=None, user_i
return rendered_content
def huddle_users(recipient_id):
# type: (int) -> str
def huddle_users(recipient_id: int) -> str:
display_recipient = get_display_recipient_by_id(recipient_id,
Recipient.HUDDLE,
None) # type: Union[Text, List[Dict[str, Any]]]
@ -578,8 +566,9 @@ def huddle_users(recipient_id):
user_ids = sorted(user_ids)
return ','.join(str(uid) for uid in user_ids)
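For illustration, the same sort-and-join step on a hand-built display_recipient payload (trimmed to the one field the function reads):
recipients = [{'id': 7}, {'id': 3}, {'id': 11}]  # illustrative rows
user_ids = sorted(obj['id'] for obj in recipients)
assert ','.join(str(uid) for uid in user_ids) == '3,7,11'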
def aggregate_message_dict(input_dict, lookup_fields, collect_senders):
# type: (Dict[int, Dict[str, Any]], List[str], bool) -> List[Dict[str, Any]]
def aggregate_message_dict(input_dict: Dict[int, Dict[str, Any]],
lookup_fields: List[str],
collect_senders: bool) -> List[Dict[str, Any]]:
lookup_dict = dict() # type: Dict[Tuple[Any, ...], Dict[str, Any]]
'''
@ -639,8 +628,7 @@ def aggregate_message_dict(input_dict, lookup_fields, collect_senders):
return [lookup_dict[k] for k in sorted_keys]
def get_inactive_recipient_ids(user_profile):
# type: (UserProfile) -> List[int]
def get_inactive_recipient_ids(user_profile: UserProfile) -> List[int]:
rows = get_stream_subscriptions_for_user(user_profile).filter(
active=False,
).values(
@ -651,8 +639,7 @@ def get_inactive_recipient_ids(user_profile):
for row in rows]
return inactive_recipient_ids
def get_muted_stream_ids(user_profile):
# type: (UserProfile) -> List[int]
def get_muted_stream_ids(user_profile: UserProfile) -> List[int]:
rows = get_stream_subscriptions_for_user(user_profile).filter(
active=True,
in_home_view=False,
@ -664,8 +651,7 @@ def get_muted_stream_ids(user_profile):
for row in rows]
return muted_stream_ids
def get_raw_unread_data(user_profile):
# type: (UserProfile) -> RawUnreadMessagesResult
def get_raw_unread_data(user_profile: UserProfile) -> RawUnreadMessagesResult:
excluded_recipient_ids = get_inactive_recipient_ids(user_profile)
@ -694,8 +680,7 @@ def get_raw_unread_data(user_profile):
topic_mute_checker = build_topic_mute_checker(user_profile)
def is_row_muted(stream_id, recipient_id, topic):
# type: (int, int, Text) -> bool
def is_row_muted(stream_id: int, recipient_id: int, topic: Text) -> bool:
if stream_id in muted_stream_ids:
return True
@ -706,8 +691,7 @@ def get_raw_unread_data(user_profile):
huddle_cache = {} # type: Dict[int, str]
def get_huddle_users(recipient_id):
# type: (int) -> str
def get_huddle_users(recipient_id: int) -> str:
if recipient_id in huddle_cache:
return huddle_cache[recipient_id]
@ -762,8 +746,7 @@ def get_raw_unread_data(user_profile):
mentions=mentions,
)
def aggregate_unread_data(raw_data):
# type: (RawUnreadMessagesResult) -> UnreadMessagesResult
def aggregate_unread_data(raw_data: RawUnreadMessagesResult) -> UnreadMessagesResult:
pm_dict = raw_data['pm_dict']
stream_dict = raw_data['stream_dict']
@ -807,8 +790,10 @@ def aggregate_unread_data(raw_data):
return result
def apply_unread_message_event(user_profile, state, message, flags):
# type: (UserProfile, Dict[str, Any], Dict[str, Any], List[str]) -> None
def apply_unread_message_event(user_profile: UserProfile,
state: Dict[str, Any],
message: Dict[str, Any],
flags: List[str]) -> None:
message_id = message['id']
if message['type'] == 'stream':
message_type = 'stream'

View File

@ -3,8 +3,7 @@ from django.db.models.query import QuerySet
import re
import time
def timed_ddl(db, stmt):
# type: (Any, str) -> None
def timed_ddl(db: Any, stmt: str) -> None:
print()
print(time.asctime())
print(stmt)
@ -13,14 +12,17 @@ def timed_ddl(db, stmt):
delay = time.time() - t
print('Took %.2fs' % (delay,))
def validate(sql_thingy):
# type: (str) -> None
def validate(sql_thingy: str) -> None:
# Do basic validation that table/col name is safe.
if not re.match(r'^[a-z][a-z\d_]+$', sql_thingy):
raise Exception('Invalid SQL object: %s' % (sql_thingy,))
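A quick sketch of what the pattern accepts; the identifiers are illustrative:
validate('zerver_message')     # lowercase snake_case: passes
validate('user_profile_id')    # passes
# validate('1col') or validate('Drop Table') would raise Exception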
def do_batch_update(db, table, cols, vals, batch_size=10000, sleep=0.1):
# type: (Any, str, List[str], List[str], int, float) -> None
def do_batch_update(db: Any,
table: str,
cols: List[str],
vals: List[str],
batch_size: int=10000,
sleep: float=0.1) -> None:
validate(table)
for col in cols:
validate(col)
@ -46,8 +48,7 @@ def do_batch_update(db, table, cols, vals, batch_size=10000, sleep=0.1):
min_id = upper
time.sleep(sleep)
def add_bool_columns(db, table, cols):
# type: (Any, str, List[str]) -> None
def add_bool_columns(db: Any, table: str, cols: List[str]) -> None:
validate(table)
for col in cols:
validate(col)
@ -72,8 +73,8 @@ def add_bool_columns(db, table, cols):
', '.join(['ALTER %s SET NOT NULL' % (col,) for col in cols]))
timed_ddl(db, stmt)
def create_index_if_not_exist(index_name, table_name, column_string, where_clause):
# type: (Text, Text, Text, Text) -> Text
def create_index_if_not_exist(index_name: Text, table_name: Text, column_string: Text,
where_clause: Text) -> Text:
#
# FUTURE TODO: When we no longer need to support postgres 9.3 for Trusty,
# we can use "IF NOT EXISTS", which is part of postgres 9.5
@ -95,8 +96,11 @@ def create_index_if_not_exist(index_name, table_name, column_string, where_claus
''' % (index_name, index_name, table_name, column_string, where_clause)
return stmt
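A hypothetical call; the index and table names below are illustrative, and the returned DDL string would be executed via timed_ddl():
stmt = create_index_if_not_exist(
    index_name='zerver_mutedtopic_stream_topic',  # hypothetical
    table_name='zerver_mutedtopic',
    column_string='stream_id, topic_name',
    where_clause='',
)
# timed_ddl(db, stmt)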
def act_on_message_ranges(db, orm, tasks, batch_size=5000, sleep=0.5):
# type: (Any, Dict[str, Any], List[Tuple[Callable[[QuerySet], QuerySet], Callable[[QuerySet], None]]], int , float) -> None
def act_on_message_ranges(db: Any,
orm: Dict[str, Any],
tasks: List[Tuple[Callable[[QuerySet], QuerySet], Callable[[QuerySet], None]]],
batch_size: int=5000,
sleep: float=0.5) -> None:
# tasks should be an array of (filterer, action) tuples
# where filterer is a function that returns a filtered QuerySet
# and action is a function that acts on a QuerySet

View File

@ -4,21 +4,18 @@ from django.utils.translation import ugettext as _
from typing import Any, Callable, Iterable, Mapping, Sequence, Text
def check_supported_events_narrow_filter(narrow):
# type: (Iterable[Sequence[Text]]) -> None
def check_supported_events_narrow_filter(narrow: Iterable[Sequence[Text]]) -> None:
for element in narrow:
operator = element[0]
if operator not in ["stream", "topic", "sender", "is"]:
raise JsonableError(_("Operator %s not supported.") % (operator,))
def build_narrow_filter(narrow):
# type: (Iterable[Sequence[Text]]) -> Callable[[Mapping[str, Any]], bool]
def build_narrow_filter(narrow: Iterable[Sequence[Text]]) -> Callable[[Mapping[str, Any]], bool]:
"""Changes to this function should come with corresponding changes to
BuildNarrowFilterTest."""
check_supported_events_narrow_filter(narrow)
def narrow_filter(event):
# type: (Mapping[str, Any]) -> bool
def narrow_filter(event: Mapping[str, Any]) -> bool:
message = event["message"]
flags = event["flags"]
for element in narrow:

View File

@ -21,8 +21,7 @@ from zerver.decorator import JsonableError
class OutgoingWebhookServiceInterface:
def __init__(self, base_url, token, user_profile, service_name):
# type: (Text, Text, UserProfile, Text) -> None
def __init__(self, base_url: Text, token: Text, user_profile: UserProfile, service_name: Text) -> None:
self.base_url = base_url # type: Text
self.token = token # type: Text
self.user_profile = user_profile # type: UserProfile
@ -37,20 +36,17 @@ class OutgoingWebhookServiceInterface:
# - base_url
# - relative_url_path
# - request_kwargs
def process_event(self, event):
# type: (Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]
def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]:
raise NotImplementedError()
# Given a successful outgoing webhook REST operation, returns the message
# to send back to the user (or None if no message should be sent).
def process_success(self, response, event):
# type: (Response, Dict[Text, Any]) -> Optional[str]
def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]:
raise NotImplementedError()
class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface):
def process_event(self, event):
# type: (Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]
def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]:
rest_operation = {'method': 'POST',
'relative_url_path': '',
'base_url': self.base_url,
@ -60,8 +56,7 @@ class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface):
"token": self.token}
return rest_operation, json.dumps(request_data)
def process_success(self, response, event):
# type: (Response, Dict[Text, Any]) -> Optional[str]
def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]:
response_json = json.loads(response.text)
if "response_not_required" in response_json and response_json['response_not_required']:
@ -73,8 +68,7 @@ class GenericOutgoingWebhookService(OutgoingWebhookServiceInterface):
class SlackOutgoingWebhookService(OutgoingWebhookServiceInterface):
def process_event(self, event):
# type: (Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]
def process_event(self, event: Dict[Text, Any]) -> Tuple[Dict[str, Any], Any]:
rest_operation = {'method': 'POST',
'relative_url_path': '',
'base_url': self.base_url,
@ -99,8 +93,7 @@ class SlackOutgoingWebhookService(OutgoingWebhookServiceInterface):
return rest_operation, request_data
def process_success(self, response, event):
# type: (Response, Dict[Text, Any]) -> Optional[str]
def process_success(self, response: Response, event: Dict[Text, Any]) -> Optional[str]:
response_json = json.loads(response.text)
if "text" in response_json:
return response_json["text"]
@ -112,15 +105,13 @@ AVAILABLE_OUTGOING_WEBHOOK_INTERFACES = {
SLACK_INTERFACE: SlackOutgoingWebhookService,
} # type: Dict[Text, Any]
def get_service_interface_class(interface):
# type: (Text) -> Any
def get_service_interface_class(interface: Text) -> Any:
if interface is None or interface not in AVAILABLE_OUTGOING_WEBHOOK_INTERFACES:
return AVAILABLE_OUTGOING_WEBHOOK_INTERFACES[GENERIC_INTERFACE]
else:
return AVAILABLE_OUTGOING_WEBHOOK_INTERFACES[interface]
def get_outgoing_webhook_service_handler(service):
# type: (Service) -> Any
def get_outgoing_webhook_service_handler(service: Service) -> Any:
service_interface_class = get_service_interface_class(service.interface_name())
service_interface = service_interface_class(base_url=service.base_url,
@ -129,8 +120,7 @@ def get_outgoing_webhook_service_handler(service):
service_name=service.name)
return service_interface
def send_response_message(bot_id, message, response_message_content):
# type: (str, Dict[str, Any], Text) -> None
def send_response_message(bot_id: str, message: Dict[str, Any], response_message_content: Text) -> None:
recipient_type_name = message['type']
bot_user = get_user_profile_by_id(bot_id)
realm = bot_user.realm
@ -146,18 +136,15 @@ def send_response_message(bot_id, message, response_message_content):
else:
raise JsonableError(_("Invalid message type"))
def succeed_with_message(event, success_message):
# type: (Dict[str, Any], Text) -> None
def succeed_with_message(event: Dict[str, Any], success_message: Text) -> None:
success_message = "Success! " + success_message
send_response_message(event['user_profile_id'], event['message'], success_message)
def fail_with_message(event, failure_message):
# type: (Dict[str, Any], Text) -> None
def fail_with_message(event: Dict[str, Any], failure_message: Text) -> None:
failure_message = "Failure! " + failure_message
send_response_message(event['user_profile_id'], event['message'], failure_message)
def get_message_url(event, request_data):
# type: (Dict[str, Any], Dict[str, Any]) -> Text
def get_message_url(event: Dict[str, Any], request_data: Dict[str, Any]) -> Text:
bot_user = get_user_profile_by_id(event['user_profile_id'])
message = event['message']
if message['type'] == 'stream':
@ -175,8 +162,11 @@ def get_message_url(event, request_data):
'id': str(message['id'])})
return message_url
def notify_bot_owner(event, request_data, status_code=None, response_content=None, exception=None):
# type: (Dict[str, Any], Dict[str, Any], Optional[int], Optional[AnyStr], Optional[Exception]) -> None
def notify_bot_owner(event: Dict[str, Any],
request_data: Dict[str, Any],
status_code: Optional[int]=None,
response_content: Optional[AnyStr]=None,
exception: Optional[Exception]=None) -> None:
message_url = get_message_url(event, request_data)
bot_id = event['user_profile_id']
bot_owner = get_user_profile_by_id(bot_id).bot_owner
@ -194,10 +184,11 @@ def notify_bot_owner(event, request_data, status_code=None, response_content=Non
type(exception).__name__, str(exception))
send_response_message(bot_id, message_info, notification_message)
def request_retry(event, request_data, failure_message, exception=None):
# type: (Dict[str, Any], Dict[str, Any], Text, Optional[Exception]) -> None
def failure_processor(event):
# type: (Dict[str, Any]) -> None
def request_retry(event: Dict[str, Any],
request_data: Dict[str, Any],
failure_message: Text,
exception: Optional[Exception]=None) -> None:
def failure_processor(event: Dict[str, Any]) -> None:
"""
The name of the argument is 'event' on purpose. This argument will hide
the 'event' argument of the request_retry function. Keeping the same name
@ -211,8 +202,11 @@ def request_retry(event, request_data, failure_message, exception=None):
retry_event('outgoing_webhooks', event, failure_processor)
def do_rest_call(rest_operation, request_data, event, service_handler, timeout=None):
# type: (Dict[str, Any], Optional[Dict[str, Any]], Dict[str, Any], Any, Any) -> None
def do_rest_call(rest_operation: Dict[str, Any],
request_data: Optional[Dict[str, Any]],
event: Dict[str, Any],
service_handler: Any,
timeout: Any=None) -> None:
rest_operation_validator = check_dict([
('method', check_string),
('relative_url_path', check_string),

View File

@ -21,32 +21,26 @@ rules = settings.RATE_LIMITING_RULES # type: List[Tuple[int, int]]
KEY_PREFIX = ''
class RateLimitedObject:
def get_keys(self):
# type: () -> List[Text]
def get_keys(self) -> List[Text]:
key_fragment = self.key_fragment()
return ["{}ratelimit:{}:{}".format(KEY_PREFIX, key_fragment, keytype)
for keytype in ['list', 'zset', 'block']]
def key_fragment(self):
# type: () -> Text
def key_fragment(self) -> Text:
raise NotImplementedError()
def rules(self):
# type: () -> List[Tuple[int, int]]
def rules(self) -> List[Tuple[int, int]]:
raise NotImplementedError()
class RateLimitedUser(RateLimitedObject):
def __init__(self, user, domain='all'):
# type: (UserProfile, Text) -> None
def __init__(self, user: UserProfile, domain: Text='all') -> None:
self.user = user
self.domain = domain
def key_fragment(self):
# type: () -> Text
def key_fragment(self) -> Text:
return "{}:{}:{}".format(type(self.user), self.user.id, self.domain)
def rules(self):
# type: () -> List[Tuple[int, int]]
def rules(self) -> List[Tuple[int, int]]:
if self.user.rate_limits != "":
result = [] # type: List[Tuple[int, int]]
for limit in self.user.rate_limits.split(','):
@ -55,36 +49,30 @@ class RateLimitedUser(RateLimitedObject):
return result
return rules
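A sketch of the resulting redis key scheme, assuming an empty KEY_PREFIX and an illustrative key fragment:
fragment = "<class 'zerver.models.UserProfile'>:42:all"  # hypothetical user 42
keys = ['ratelimit:{}:{}'.format(fragment, keytype)
        for keytype in ['list', 'zset', 'block']]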
def bounce_redis_key_prefix_for_testing(test_name):
# type: (Text) -> None
def bounce_redis_key_prefix_for_testing(test_name: Text) -> None:
global KEY_PREFIX
KEY_PREFIX = test_name + ':' + Text(os.getpid()) + ':'
def max_api_calls(entity):
# type: (RateLimitedObject) -> int
def max_api_calls(entity: RateLimitedObject) -> int:
"Returns the API rate limit for the highest limit"
return entity.rules()[-1][1]
def max_api_window(entity):
# type: (RateLimitedObject) -> int
def max_api_window(entity: RateLimitedObject) -> int:
"Returns the API time window for the highest limit"
return entity.rules()[-1][0]
def add_ratelimit_rule(range_seconds, num_requests):
# type: (int , int) -> None
def add_ratelimit_rule(range_seconds: int, num_requests: int) -> None:
"Add a rate-limiting rule to the ratelimiter"
global rules
rules.append((range_seconds, num_requests))
rules.sort(key=lambda x: x[0])
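Hypothetical test usage: install a tight rule for the duration of a test, then take it back out:
add_ratelimit_rule(1, 5)  # at most 5 requests per 1-second window
try:
    pass  # exercise the rate-limited endpoint here
finally:
    remove_ratelimit_rule(1, 5)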
def remove_ratelimit_rule(range_seconds, num_requests):
# type: (int , int) -> None
def remove_ratelimit_rule(range_seconds: int, num_requests: int) -> None:
global rules
rules = [x for x in rules if x != (range_seconds, num_requests)]
def block_access(entity, seconds):
# type: (RateLimitedObject, int) -> None
def block_access(entity: RateLimitedObject, seconds: int) -> None:
"Manually blocks an entity for the desired number of seconds"
_, _, blocking_key = entity.get_keys()
with client.pipeline() as pipe:
@ -92,13 +80,11 @@ def block_access(entity, seconds):
pipe.expire(blocking_key, seconds)
pipe.execute()
def unblock_access(entity):
# type: (RateLimitedObject) -> None
def unblock_access(entity: RateLimitedObject) -> None:
_, _, blocking_key = entity.get_keys()
client.delete(blocking_key)
def clear_history(entity):
# type: (RateLimitedObject) -> None
def clear_history(entity: RateLimitedObject) -> None:
'''
This is only used by test code now, where it's very helpful in
allowing us to run tests quickly, by giving a user a clean slate.
@ -106,8 +92,7 @@ def clear_history(entity):
for key in entity.get_keys():
client.delete(key)
def _get_api_calls_left(entity, range_seconds, max_calls):
# type: (RateLimitedObject, int, int) -> Tuple[int, float]
def _get_api_calls_left(entity: RateLimitedObject, range_seconds: int, max_calls: int) -> Tuple[int, float]:
list_key, set_key, _ = entity.get_keys()
# Count the number of values in our sorted set
# that are between now and the cutoff
@ -134,16 +119,14 @@ def _get_api_calls_left(entity, range_seconds, max_calls):
return calls_left, time_reset
def api_calls_left(entity):
# type: (RateLimitedObject) -> Tuple[int, float]
def api_calls_left(entity: RateLimitedObject) -> Tuple[int, float]:
"""Returns how many API calls in this range this client has, as well as when
the rate-limit will be reset to 0"""
max_window = max_api_window(entity)
max_calls = max_api_calls(entity)
return _get_api_calls_left(entity, max_window, max_calls)
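Hypothetical usage, assuming a UserProfile is at hand:
entity = RateLimitedUser(user_profile)
calls_left, seconds_until_reset = api_calls_left(entity)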
def is_ratelimited(entity):
# type: (RateLimitedObject) -> Tuple[bool, float]
def is_ratelimited(entity: RateLimitedObject) -> Tuple[bool, float]:
"Returns a tuple of (rate_limited, time_till_free)"
list_key, set_key, blocking_key = entity.get_keys()
@ -192,8 +175,7 @@ def is_ratelimited(entity):
# No api calls recorded yet
return False, 0.0
def incr_ratelimit(entity):
# type: (RateLimitedObject) -> None
def incr_ratelimit(entity: RateLimitedObject) -> None:
"""Increases the rate-limit for the specified entity"""
list_key, set_key, _ = entity.get_keys()
now = time.time()

View File

@ -15,8 +15,7 @@ METHODS = ('GET', 'HEAD', 'POST', 'PUT', 'DELETE', 'PATCH')
FLAGS = ('override_api_url_scheme',)
@csrf_exempt
def rest_dispatch(request, **kwargs):
# type: (HttpRequest, **Any) -> HttpResponse
def rest_dispatch(request: HttpRequest, **kwargs: Any) -> HttpResponse:
"""Dispatch to a REST API endpoint.
Unauthenticated endpoints should not use this, as authentication is verified

View File

@ -93,8 +93,7 @@ def send_email(template_prefix, to_user_id=None, to_email=None, from_name=None,
logger.error("Error sending %s email to %s" % (template, mail.to))
raise EmailNotDeliveredException
def send_email_from_dict(email_dict):
# type: (Mapping[str, Any]) -> None
def send_email_from_dict(email_dict: Mapping[str, Any]) -> None:
send_email(**dict(email_dict))
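A sketch of the expected mapping; the keys mirror send_email()'s keyword arguments, and the values here are illustrative:
send_email_from_dict({
    'template_prefix': 'zerver/emails/invitation',  # hypothetical template
    'to_email': 'new-user@example.com',
    'from_name': 'Zulip',
})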
def send_future_email(template_prefix, to_user_id=None, to_email=None, from_name=None,

View File

@ -17,8 +17,7 @@ def filter_by_subscription_history(
# type: (UserProfile, DefaultDict[int, List[Message]], DefaultDict[int, List[RealmAuditLog]]) -> List[UserMessage]
user_messages_to_insert = [] # type: List[UserMessage]
def store_user_message_to_insert(message):
# type: (Message) -> None
def store_user_message_to_insert(message: Message) -> None:
message = UserMessage(user_profile=user_profile,
message_id=message['id'], flags=0)
user_messages_to_insert.append(message)
@ -60,8 +59,7 @@ def filter_by_subscription_history(
store_user_message_to_insert(stream_message)
return user_messages_to_insert
def add_missing_messages(user_profile):
# type: (UserProfile) -> None
def add_missing_messages(user_profile: UserProfile) -> None:
"""This function takes a soft-deactivated user, and computes and adds
to the database any UserMessage rows that were not created while
the user was soft-deactivated. The end result is that from the
@ -156,8 +154,7 @@ def add_missing_messages(user_profile):
if len(user_messages_to_insert) > 0:
UserMessage.objects.bulk_create(user_messages_to_insert)
def do_soft_deactivate_user(user_profile):
# type: (UserProfile) -> None
def do_soft_deactivate_user(user_profile: UserProfile) -> None:
user_profile.last_active_message_id = UserMessage.objects.filter(
user_profile=user_profile).order_by(
'-message__id')[0].message_id
@ -168,8 +165,7 @@ def do_soft_deactivate_user(user_profile):
logger.info('Soft Deactivated user %s (%s)' %
(user_profile.id, user_profile.email))
def do_soft_deactivate_users(users):
# type: (List[UserProfile]) -> List[UserProfile]
def do_soft_deactivate_users(users: List[UserProfile]) -> List[UserProfile]:
users_soft_deactivated = []
with transaction.atomic():
realm_logs = []
@ -187,8 +183,7 @@ def do_soft_deactivate_users(users):
RealmAuditLog.objects.bulk_create(realm_logs)
return users_soft_deactivated
def maybe_catch_up_soft_deactivated_user(user_profile):
# type: (UserProfile) -> Union[UserProfile, None]
def maybe_catch_up_soft_deactivated_user(user_profile: UserProfile) -> Union[UserProfile, None]:
if user_profile.long_term_idle:
add_missing_messages(user_profile)
user_profile.long_term_idle = False
@ -204,8 +199,7 @@ def maybe_catch_up_soft_deactivated_user(user_profile):
return user_profile
return None
def get_users_for_soft_deactivation(inactive_for_days, filter_kwargs):
# type: (int, **Any) -> List[UserProfile]
def get_users_for_soft_deactivation(inactive_for_days: int, filter_kwargs: Any) -> List[UserProfile]:
users_activity = list(UserActivity.objects.filter(
user_profile__is_active=True,
user_profile__is_bot=False,
@ -221,8 +215,7 @@ def get_users_for_soft_deactivation(inactive_for_days, filter_kwargs):
id__in=user_ids_to_deactivate))
return users_to_deactivate
def do_soft_activate_users(users):
# type: (List[UserProfile]) -> List[UserProfile]
def do_soft_activate_users(users: List[UserProfile]) -> List[UserProfile]:
users_soft_activated = []
for user_profile in users:
user_activated = maybe_catch_up_soft_deactivated_user(user_profile)

View File

@ -10,8 +10,7 @@ from zerver.models import UserProfile, Stream, Subscription, \
Realm, Recipient, bulk_get_recipients, get_stream_recipient, get_stream, \
bulk_get_streams, get_realm_stream, DefaultStreamGroup
def access_stream_for_delete(user_profile, stream_id):
# type: (UserProfile, int) -> Stream
def access_stream_for_delete(user_profile: UserProfile, stream_id: int) -> Stream:
# We should only ever use this for realm admins, who are allowed
# to delete all streams on their realm, even private streams to
@ -30,8 +29,8 @@ def access_stream_for_delete(user_profile, stream_id):
return stream
def access_stream_common(user_profile, stream, error):
# type: (UserProfile, Stream, Text) -> Tuple[Recipient, Subscription]
def access_stream_common(user_profile: UserProfile, stream: Stream,
error: Text) -> Tuple[Recipient, Subscription]:
"""Common function for backend code where the target use attempts to
access the target stream, returning all the data fetched along the
way. If that user does not have permission to access that stream,
@ -63,8 +62,7 @@ def access_stream_common(user_profile, stream, error):
# an error.
raise JsonableError(error)
def access_stream_by_id(user_profile, stream_id):
# type: (UserProfile, int) -> Tuple[Stream, Recipient, Subscription]
def access_stream_by_id(user_profile: UserProfile, stream_id: int) -> Tuple[Stream, Recipient, Subscription]:
error = _("Invalid stream id")
try:
stream = Stream.objects.get(id=stream_id)
@ -74,8 +72,7 @@ def access_stream_by_id(user_profile, stream_id):
(recipient, sub) = access_stream_common(user_profile, stream, error)
return (stream, recipient, sub)
def check_stream_name_available(realm, name):
# type: (Realm, Text) -> None
def check_stream_name_available(realm: Realm, name: Text) -> None:
check_stream_name(name)
try:
get_stream(name, realm)
@ -83,8 +80,8 @@ def check_stream_name_available(realm, name):
except Stream.DoesNotExist:
pass
def access_stream_by_name(user_profile, stream_name):
# type: (UserProfile, Text) -> Tuple[Stream, Recipient, Subscription]
def access_stream_by_name(user_profile: UserProfile,
stream_name: Text) -> Tuple[Stream, Recipient, Subscription]:
error = _("Invalid stream name '%s'" % (stream_name,))
try:
stream = get_realm_stream(stream_name, user_profile.realm_id)
@ -94,8 +91,7 @@ def access_stream_by_name(user_profile, stream_name):
(recipient, sub) = access_stream_common(user_profile, stream, error)
return (stream, recipient, sub)
def access_stream_for_unmute_topic(user_profile, stream_name, error):
# type: (UserProfile, Text, Text) -> Stream
def access_stream_for_unmute_topic(user_profile: UserProfile, stream_name: Text, error: Text) -> Stream:
"""
It may seem a little silly to have this helper function for unmuting
topics, but it gets around a linter warning, and it helps to be able
@ -115,8 +111,7 @@ def access_stream_for_unmute_topic(user_profile, stream_name, error):
raise JsonableError(error)
return stream
def is_public_stream_by_name(stream_name, realm):
# type: (Text, Realm) -> bool
def is_public_stream_by_name(stream_name: Text, realm: Realm) -> bool:
"""Determine whether a stream is public, so that
our caller can decide whether we can get
historical messages for a narrowing search.
@ -136,8 +131,8 @@ def is_public_stream_by_name(stream_name, realm):
return False
return stream.is_public()
def filter_stream_authorization(user_profile, streams):
# type: (UserProfile, Iterable[Stream]) -> Tuple[List[Stream], List[Stream]]
def filter_stream_authorization(user_profile: UserProfile,
streams: Iterable[Stream]) -> Tuple[List[Stream], List[Stream]]:
streams_subscribed = set() # type: Set[int]
recipients_map = bulk_get_recipients(Recipient.STREAM, [stream.id for stream in streams])
subs = Subscription.objects.filter(user_profile=user_profile,
@ -161,8 +156,9 @@ def filter_stream_authorization(user_profile, streams):
stream.id not in set(stream.id for stream in unauthorized_streams)]
return authorized_streams, unauthorized_streams
def list_to_streams(streams_raw, user_profile, autocreate=False):
# type: (Iterable[Mapping[str, Any]], UserProfile, bool) -> Tuple[List[Stream], List[Stream]]
def list_to_streams(streams_raw: Iterable[Mapping[str, Any]],
user_profile: UserProfile,
autocreate: bool=False) -> Tuple[List[Stream], List[Stream]]:
"""Converts list of dicts to a list of Streams, validating input in the process
For each stream name, we validate it to ensure it meets our

View File

@ -15,8 +15,7 @@ from sqlalchemy.sql import (
Selectable
)
def get_topic_mutes(user_profile):
# type: (UserProfile) -> List[List[Text]]
def get_topic_mutes(user_profile: UserProfile) -> List[List[Text]]:
rows = MutedTopic.objects.filter(
user_profile=user_profile,
).values(
@ -28,8 +27,7 @@ def get_topic_mutes(user_profile):
for row in rows
]
def set_topic_mutes(user_profile, muted_topics):
# type: (UserProfile, List[List[Text]]) -> None
def set_topic_mutes(user_profile: UserProfile, muted_topics: List[List[Text]]) -> None:
'''
This is only used in tests.
@ -50,8 +48,7 @@ def set_topic_mutes(user_profile, muted_topics):
topic_name=topic_name,
)
def add_topic_mute(user_profile, stream_id, recipient_id, topic_name):
# type: (UserProfile, int, int, str) -> None
def add_topic_mute(user_profile: UserProfile, stream_id: int, recipient_id: int, topic_name: str) -> None:
MutedTopic.objects.create(
user_profile=user_profile,
stream_id=stream_id,
@ -59,8 +56,7 @@ def add_topic_mute(user_profile, stream_id, recipient_id, topic_name):
topic_name=topic_name,
)
def remove_topic_mute(user_profile, stream_id, topic_name):
# type: (UserProfile, int, str) -> None
def remove_topic_mute(user_profile: UserProfile, stream_id: int, topic_name: str) -> None:
row = MutedTopic.objects.get(
user_profile=user_profile,
stream_id=stream_id,
@ -68,8 +64,7 @@ def remove_topic_mute(user_profile, stream_id, topic_name):
)
row.delete()
def topic_is_muted(user_profile, stream_id, topic_name):
# type: (UserProfile, int, Text) -> bool
def topic_is_muted(user_profile: UserProfile, stream_id: int, topic_name: Text) -> bool:
is_muted = MutedTopic.objects.filter(
user_profile=user_profile,
stream_id=stream_id,
@ -77,8 +72,9 @@ def topic_is_muted(user_profile, stream_id, topic_name):
).exists()
return is_muted
def exclude_topic_mutes(conditions, user_profile, stream_id):
# type: (List[Selectable], UserProfile, Optional[int]) -> List[Selectable]
def exclude_topic_mutes(conditions: List[Selectable],
user_profile: UserProfile,
stream_id: Optional[int]) -> List[Selectable]:
query = MutedTopic.objects.filter(
user_profile=user_profile,
)
@ -97,8 +93,7 @@ def exclude_topic_mutes(conditions, user_profile, stream_id):
if not rows:
return conditions
def mute_cond(row):
# type: (Dict[str, Any]) -> Selectable
def mute_cond(row: Dict[str, Any]) -> Selectable:
recipient_id = row['recipient_id']
topic_name = row['topic_name']
stream_cond = column("recipient_id") == recipient_id
@ -108,8 +103,7 @@ def exclude_topic_mutes(conditions, user_profile, stream_id):
condition = not_(or_(*list(map(mute_cond, rows))))
return conditions + [condition]
def build_topic_mute_checker(user_profile):
# type: (UserProfile) -> Callable[[int, Text], bool]
def build_topic_mute_checker(user_profile: UserProfile) -> Callable[[int, Text], bool]:
rows = MutedTopic.objects.filter(
user_profile=user_profile,
).values(
@ -124,8 +118,7 @@ def build_topic_mute_checker(user_profile):
topic_name = row['topic_name']
tups.add((recipient_id, topic_name.lower()))
def is_muted(recipient_id, topic):
# type: (int, Text) -> bool
def is_muted(recipient_id: int, topic: Text) -> bool:
return (recipient_id, topic.lower()) in tups
return is_muted
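Hypothetical usage: hydrate the checker once per request, then test rows cheaply (topic matching is case-insensitive):
is_muted = build_topic_mute_checker(user_profile)
if is_muted(recipient_id, 'Lunch'):
    pass  # skip this unread row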

View File

@ -2,8 +2,9 @@ from typing import Optional, Text, Dict, Any
from pyoembed import oEmbed, PyOembedException
def get_oembed_data(url, maxwidth=640, maxheight=480):
# type: (Text, Optional[int], Optional[int]) -> Optional[Dict[Any, Any]]
def get_oembed_data(url: Text,
maxwidth: Optional[int]=640,
maxheight: Optional[int]=480) -> Optional[Dict[Any, Any]]:
try:
data = oEmbed(url, maxwidth=maxwidth, maxheight=maxheight)
except PyOembedException:

View File

@ -3,10 +3,8 @@ from bs4 import BeautifulSoup
class BaseParser:
def __init__(self, html_source):
# type: (Text) -> None
def __init__(self, html_source: Text) -> None:
self._soup = BeautifulSoup(html_source, "lxml")
def extract_data(self):
# type: () -> Any
def extract_data(self) -> Any:
raise NotImplementedError()

View File

@ -3,15 +3,13 @@ from zerver.lib.url_preview.parsers.base import BaseParser
class GenericParser(BaseParser):
def extract_data(self):
# type: () -> Dict[str, Optional[Text]]
def extract_data(self) -> Dict[str, Optional[Text]]:
return {
'title': self._get_title(),
'description': self._get_description(),
'image': self._get_image()}
def _get_title(self):
# type: () -> Optional[Text]
def _get_title(self) -> Optional[Text]:
soup = self._soup
if (soup.title and soup.title.text != ''):
return soup.title.text
@ -19,8 +17,7 @@ class GenericParser(BaseParser):
return soup.h1.text
return None
def _get_description(self):
# type: () -> Optional[Text]
def _get_description(self) -> Optional[Text]:
soup = self._soup
meta_description = soup.find('meta', attrs={'name': 'description'})
if (meta_description and meta_description['content'] != ''):
@ -35,8 +32,7 @@ class GenericParser(BaseParser):
return first_p.string
return None
def _get_image(self):
# type: () -> Optional[Text]
def _get_image(self) -> Optional[Text]:
"""
Find the first image after the h1 header.
Presumably it will be the main image.

View File

@ -4,8 +4,7 @@ from .base import BaseParser
class OpenGraphParser(BaseParser):
def extract_data(self):
# type: () -> Dict[str, Text]
def extract_data(self) -> Dict[str, Text]:
meta = self._soup.findAll('meta')
content = {}
for tag in meta:

View File

@ -20,19 +20,18 @@ link_regex = re.compile(
r'(?:/?|[/?]\S+)$', re.IGNORECASE)
def is_link(url):
# type: (Text) -> Match[Text]
def is_link(url: Text) -> Optional[Match[Text]]:
return link_regex.match(smart_text(url))
def cache_key_func(url):
# type: (Text) -> Text
def cache_key_func(url: Text) -> Text:
return url
@cache_with_key(cache_key_func, cache_name=CACHE_NAME, with_statsd_key="urlpreview_data")
def get_link_embed_data(url, maxwidth=640, maxheight=480):
# type: (Text, Optional[int], Optional[int]) -> Optional[Dict[Any, Any]]
def get_link_embed_data(url: Text,
maxwidth: Optional[int]=640,
maxheight: Optional[int]=480) -> Optional[Dict[Any, Any]]:
if not is_link(url):
return None
# Fetch information from URL.
@ -60,6 +59,5 @@ def get_link_embed_data(url, maxwidth=640, maxheight=480):
@get_cache_with_key(cache_key_func, cache_name=CACHE_NAME)
def link_embed_data_from_cache(url, maxwidth=640, maxheight=480):
# type: (Text, Optional[int], Optional[int]) -> Any
def link_embed_data_from_cache(url: Text, maxwidth: Optional[int]=640, maxheight: Optional[int]=480) -> Any:
return

View File

@ -8,8 +8,7 @@ from zerver.lib.request import JsonableError
from zerver.models import UserProfile, Service, Realm, \
get_user_profile_by_id, user_profile_by_email_cache_key
def check_full_name(full_name_raw):
# type: (Text) -> Text
def check_full_name(full_name_raw: Text) -> Text:
full_name = full_name_raw.strip()
if len(full_name) > UserProfile.MAX_NAME_LENGTH:
raise JsonableError(_("Name too long!"))
@ -19,20 +18,17 @@ def check_full_name(full_name_raw):
raise JsonableError(_("Invalid characters in name!"))
return full_name
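A sketch of the normalization, assuming the name clears the length and character checks:
assert check_full_name('  Ada Lovelace  ') == 'Ada Lovelace'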
def check_short_name(short_name_raw):
# type: (Text) -> Text
def check_short_name(short_name_raw: Text) -> Text:
short_name = short_name_raw.strip()
if len(short_name) == 0:
raise JsonableError(_("Bad name or username"))
return short_name
def check_valid_bot_type(bot_type):
# type: (int) -> None
def check_valid_bot_type(bot_type: int) -> None:
if bot_type not in UserProfile.ALLOWED_BOT_TYPES:
raise JsonableError(_('Invalid bot type'))
def check_valid_interface_type(interface_type):
# type: (int) -> None
def check_valid_interface_type(interface_type: int) -> None:
if interface_type not in Service.ALLOWED_INTERFACE_TYPES:
raise JsonableError(_('Invalid interface type'))

View File

@ -15,8 +15,7 @@ from django.conf import settings
T = TypeVar('T')
def statsd_key(val, clean_periods=False):
# type: (Any, bool) -> str
def statsd_key(val: Any, clean_periods: bool=False) -> str:
if not isinstance(val, str):
val = str(val)
@ -35,8 +34,7 @@ class StatsDWrapper:
# Backported support for gauge deltas
# as our statsd server supports them but supporting
# pystatsd is not released yet
def _our_gauge(self, stat, value, rate=1, delta=False):
# type: (str, float, float, bool) -> None
def _our_gauge(self, stat: str, value: float, rate: float=1, delta: bool=False) -> None:
"""Set a gauge value."""
from django_statsd.clients import statsd
if delta:
@ -45,8 +43,7 @@ class StatsDWrapper:
value_str = '%g|g' % (value,)
statsd._send(stat, value_str, rate)
def __getattr__(self, name):
# type: (str) -> Any
def __getattr__(self, name: str) -> Any:
# Hand off to statsd if we have it enabled
# otherwise do nothing
if name in ['timer', 'timing', 'incr', 'decr', 'gauge']:
@ -64,8 +61,11 @@ class StatsDWrapper:
statsd = StatsDWrapper()
# Runs the callback with slices of all_list of a given batch_size
def run_in_batches(all_list, batch_size, callback, sleep_time = 0, logger = None):
# type: (Sequence[T], int, Callable[[Sequence[T]], None], int, Optional[Callable[[str], None]]) -> None
def run_in_batches(all_list: Sequence[T],
batch_size: int,
callback: Callable[[Sequence[T]], None],
sleep_time: int=0,
logger: Optional[Callable[[str], None]]=None) -> None:
if len(all_list) == 0:
return
@ -85,8 +85,8 @@ def run_in_batches(all_list, batch_size, callback, sleep_time = 0, logger = None
if i != limit - 1:
sleep(sleep_time)
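Hypothetical usage; all_ids and process_batch are illustrative stand-ins:
run_in_batches(all_ids, 100, process_batch, sleep_time=1, logger=print)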
def make_safe_digest(string, hash_func=hashlib.sha1):
# type: (Text, Callable[[bytes], Any]) -> Text
def make_safe_digest(string: Text,
hash_func: Callable[[bytes], Any]=hashlib.sha1) -> Text:
"""
return a hex digest of `string`.
"""
@ -95,8 +95,7 @@ def make_safe_digest(string, hash_func=hashlib.sha1):
return hash_func(string.encode('utf-8')).hexdigest()
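A sanity-check sketch: the default digest is SHA-1 over the UTF-8 encoding.
import hashlib
assert make_safe_digest('zulip') == hashlib.sha1(b'zulip').hexdigest()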
def log_statsd_event(name):
# type: (str) -> None
def log_statsd_event(name: str) -> None:
"""
Sends a single event to statsd with the desired name and the current timestamp
@ -110,12 +109,13 @@ def log_statsd_event(name):
event_name = "events.%s" % (name,)
statsd.incr(event_name)
def generate_random_token(length):
# type: (int) -> str
def generate_random_token(length: int) -> str:
return str(base64.b16encode(os.urandom(length // 2)).decode('utf-8').lower())
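A sketch of the contract: length counts hex characters, so 32 draws 16 random bytes.
token = generate_random_token(32)
assert len(token) == 32
assert all(c in '0123456789abcdef' for c in token)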
def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=None):
# type: (List[Any], Set[int], int, int) -> Iterable[Any]
def query_chunker(queries: List[Any],
id_collector: Optional[Set[int]]=None,
chunk_size: int=1000,
db_chunk_size: Optional[int]=None) -> Iterable[Any]:
'''
This merges one or more Django ascending-id queries into
a generator that returns chunks of chunk_size row objects
@ -142,8 +142,7 @@ def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=Non
else:
id_collector = set()
def chunkify(q, i):
# type: (Any, int) -> Iterable[Tuple[int, int, Any]]
def chunkify(q: Any, i: int) -> Iterable[Tuple[int, int, Any]]:
q = q.order_by('id')
min_id = -1
while True:
@ -171,8 +170,7 @@ def query_chunker(queries, id_collector=None, chunk_size=1000, db_chunk_size=Non
yield [row for row_id, i, row in tup_chunk]
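Hypothetical usage over a single ascending-id queryset (Message is assumed importable here):
seen_ids = set()  # type: Set[int]
for chunk in query_chunker([Message.objects.all()],
                           id_collector=seen_ids,
                           chunk_size=500):
    pass  # each chunk is a list of up to 500 rows in ascending id order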
def split_by(array, group_size, filler):
# type: (List[Any], int, Any) -> List[List[Any]]
def split_by(array: List[Any], group_size: int, filler: Any) -> List[List[Any]]:
"""
Group elements into list of size `group_size` and fill empty cells with
`filler`. Recipe from https://docs.python.org/3/library/itertools.html
@ -180,8 +178,7 @@ def split_by(array, group_size, filler):
args = [iter(array)] * group_size
return list(map(list, zip_longest(*args, fillvalue=filler)))
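The grouper recipe in action; the final group is padded with the filler:
assert split_by([1, 2, 3, 4, 5], group_size=2, filler=None) == \
    [[1, 2], [3, 4], [5, None]]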
def is_remote_server(identifier):
# type: (Text) -> bool
def is_remote_server(identifier: Text) -> bool:
"""
This function can be used to identify the source of an API auth
request. We can have two types of sources: Remote Zulip Servers

View File

@ -98,8 +98,7 @@ def get_push_commits_event_message(user_name, compare_url, branch_name,
commits_data=get_commits_content(commits_data, is_truncated),
).rstrip()
def get_force_push_commits_event_message(user_name, url, branch_name, head):
# type: (Text, Text, Text, Text) -> Text
def get_force_push_commits_event_message(user_name: Text, url: Text, branch_name: Text, head: Text) -> Text:
return FORCE_PUSH_COMMITS_MESSAGE_TEMPLATE.format(
user_name=user_name,
url=url,
@ -107,16 +106,14 @@ def get_force_push_commits_event_message(user_name, url, branch_name, head):
head=head
)
def get_create_branch_event_message(user_name, url, branch_name):
# type: (Text, Text, Text) -> Text
def get_create_branch_event_message(user_name: Text, url: Text, branch_name: Text) -> Text:
return CREATE_BRANCH_MESSAGE_TEMPLATE.format(
user_name=user_name,
url=url,
branch_name=branch_name,
)
def get_remove_branch_event_message(user_name, branch_name):
# type: (Text, Text) -> Text
def get_remove_branch_event_message(user_name: Text, branch_name: Text) -> Text:
return REMOVE_BRANCH_MESSAGE_TEMPLATE.format(
user_name=user_name,
branch_name=branch_name,
@ -147,15 +144,18 @@ def get_pull_request_event_message(
main_message += '\n' + CONTENT_MESSAGE_TEMPLATE.format(message=message)
return main_message.rstrip()
def get_setup_webhook_message(integration, user_name=None):
# type: (Text, Optional[Text]) -> Text
def get_setup_webhook_message(integration: Text, user_name: Optional[Text]=None) -> Text:
content = SETUP_MESSAGE_TEMPLATE.format(integration=integration)
if user_name:
content += SETUP_MESSAGE_USER_PART.format(user_name=user_name)
return content
def get_issue_event_message(user_name, action, url, number=None, message=None, assignee=None):
# type: (Text, Text, Text, Optional[int], Optional[Text], Optional[Text]) -> Text
def get_issue_event_message(user_name: Text,
action: Text,
url: Text,
number: Optional[int]=None,
message: Optional[Text]=None,
assignee: Optional[Text]=None) -> Text:
return get_pull_request_event_message(
user_name,
action,
@ -166,8 +166,10 @@ def get_issue_event_message(user_name, action, url, number=None, message=None, a
type='Issue'
)
def get_push_tag_event_message(user_name, tag_name, tag_url=None, action='pushed'):
# type: (Text, Text, Optional[Text], Optional[Text]) -> Text
def get_push_tag_event_message(user_name: Text,
tag_name: Text,
tag_url: Optional[Text]=None,
action: Optional[Text]='pushed') -> Text:
if tag_url:
tag_part = TAG_WITH_URL_TEMPLATE.format(tag_name=tag_name, tag_url=tag_url)
else:
@ -178,8 +180,11 @@ def get_push_tag_event_message(user_name, tag_name, tag_url=None, action='pushed
tag=tag_part
)
def get_commits_comment_action_message(user_name, action, commit_url, sha, message=None):
# type: (Text, Text, Text, Text, Optional[Text]) -> Text
def get_commits_comment_action_message(user_name: Text,
action: Text,
commit_url: Text,
sha: Text,
message: Optional[Text]=None) -> Text:
content = COMMITS_COMMENT_MESSAGE_TEMPLATE.format(
user_name=user_name,
action=action,
@ -192,8 +197,7 @@ def get_commits_comment_action_message(user_name, action, commit_url, sha, messa
)
return content
def get_commits_content(commits_data, is_truncated=False):
# type: (List[Dict[str, Any]], Optional[bool]) -> Text
def get_commits_content(commits_data: List[Dict[str, Any]], is_truncated: Optional[bool]=False) -> Text:
commits_content = ''
for commit in commits_data[:COMMITS_LIMIT]:
commits_content += COMMIT_ROW_TEMPLATE.format(
@ -212,12 +216,10 @@ def get_commits_content(commits_data, is_truncated=False):
).replace(' ', ' ')
return commits_content.rstrip()
def get_short_sha(sha):
# type: (Text) -> Text
def get_short_sha(sha: Text) -> Text:
return sha[:7]
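A sketch: abbreviates to the conventional 7-character short form.
assert get_short_sha('a1b2c3d4e5f6a7b8c9d0') == 'a1b2c3d'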
def get_all_committers(commits_data):
# type: (List[Dict[str, Any]]) -> List[Tuple[str, int]]
def get_all_committers(commits_data: List[Dict[str, Any]]) -> List[Tuple[str, int]]:
committers = defaultdict(int) # type: Dict[str, int]
for commit in commits_data: