python: Guard against default value mutation with read-only types.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-06-12 20:24:42 -07:00 committed by Tim Abbott
parent edf411718c
commit 0d6c771baf
17 changed files with 108 additions and 59 deletions

View File

@ -2,7 +2,7 @@ import logging
import time
from collections import OrderedDict, defaultdict
from datetime import datetime, timedelta
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
from typing import Callable, Dict, Optional, Sequence, Tuple, Type, Union
from django.conf import settings
from django.db import connection
@ -70,7 +70,7 @@ class LoggingCountStat(CountStat):
class DependentCountStat(CountStat):
def __init__(self, property: str, data_collector: 'DataCollector', frequency: str,
interval: Optional[timedelta]=None, dependencies: List[str]=[]) -> None:
interval: Optional[timedelta] = None, dependencies: Sequence[str] = []) -> None:
CountStat.__init__(self, property, data_collector, frequency, interval=interval)
self.dependencies = dependencies

View File

@ -6,7 +6,7 @@ import urllib
from collections import defaultdict
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Type, Union
import pytz
from django.conf import settings
@ -411,7 +411,7 @@ def get_time_series_by_subgroup(stat: CountStat,
eastern_tz = pytz.timezone('US/Eastern')
def make_table(title: str, cols: List[str], rows: List[Any], has_row_class: bool=False) -> str:
def make_table(title: str, cols: Sequence[str], rows: Sequence[Any], has_row_class: bool = False) -> str:
if not has_row_class:
def fix_row(row: Any) -> Dict[str, Any]:
@ -818,8 +818,8 @@ def sent_messages_report(realm: str) -> str:
return make_table(title, cols, rows)
def ad_hoc_queries() -> List[Dict[str, str]]:
def get_page(query: Composable, cols: List[str], title: str,
totals_columns: List[int]=[]) -> Dict[str, str]:
def get_page(query: Composable, cols: Sequence[str], title: str,
totals_columns: Sequence[int]=[]) -> Dict[str, str]:
cursor = connection.cursor()
cursor.execute(query)
rows = cursor.fetchall()

View File

@ -6,7 +6,7 @@ import sys
from datetime import datetime, timedelta, timezone
from decimal import Decimal
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, cast
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TypeVar, cast
from unittest.mock import Mock, patch
import responses
@ -141,7 +141,7 @@ def delete_fixture_data(decorated_function: CallableT) -> None: # nocoverage
os.remove(fixture_file)
def normalize_fixture_data(decorated_function: CallableT,
tested_timestamp_fields: List[str]=[]) -> None: # nocoverage
tested_timestamp_fields: Sequence[str] = []) -> None: # nocoverage
# stripe ids are all of the form cus_D7OT2jf5YAtZQ2
id_lengths = [
('cus', 14), ('sub', 14), ('si', 14), ('sli', 14), ('req', 14), ('tok', 24), ('card', 24),
@ -208,7 +208,7 @@ MOCKED_STRIPE_FUNCTION_NAMES = [f"stripe.{name}" for name in [
"Token.create",
]]
def mock_stripe(tested_timestamp_fields: List[str]=[],
def mock_stripe(tested_timestamp_fields: Sequence[str]=[],
generate: Optional[bool]=None) -> Callable[[CallableT], CallableT]:
def _mock_stripe(decorated_function: CallableT) -> CallableT:
generate_fixture = generate
@ -299,7 +299,7 @@ class StripeTestCase(ZulipTestCase):
return match.group(1) if match else None
def upgrade(self, invoice: bool=False, talk_to_stripe: bool=True,
realm: Optional[Realm]=None, del_args: List[str]=[],
realm: Optional[Realm]=None, del_args: Sequence[str]=[],
**kwargs: Any) -> HttpResponse:
host_args = {}
if realm is not None: # nocoverage: TODO
@ -982,8 +982,8 @@ class StripeTest(StripeTestCase):
def test_check_upgrade_parameters(self) -> None:
# Tests all the error paths except 'not enough licenses'
def check_error(error_description: str, upgrade_params: Dict[str, Any],
del_args: List[str]=[]) -> None:
def check_error(error_description: str, upgrade_params: Mapping[str, Any],
del_args: Sequence[str] = []) -> None:
response = self.upgrade(talk_to_stripe=False, del_args=del_args, **upgrade_params)
self.assert_json_error_contains(response, "Something went wrong. Please contact")
self.assertEqual(ujson.loads(response.content)['error_description'], error_description)

View File

@ -377,8 +377,8 @@ def os_families() -> Set[str]:
distro_info = parse_os_release()
return {distro_info["ID"], *distro_info.get("ID_LIKE", "").split()}
def files_and_string_digest(filenames: List[str],
extra_strings: List[str]) -> str:
def files_and_string_digest(filenames: Sequence[str],
extra_strings: Sequence[str]) -> str:
# see is_digest_obsolete for more context
sha1sum = hashlib.sha1()
for fn in filenames:
@ -391,8 +391,8 @@ def files_and_string_digest(filenames: List[str],
return sha1sum.hexdigest()
def is_digest_obsolete(hash_name: str,
filenames: List[str],
extra_strings: List[str]=[]) -> bool:
filenames: Sequence[str],
extra_strings: Sequence[str] = []) -> bool:
'''
In order to determine if we need to run some
process, we calculate a digest of the important
@ -425,8 +425,8 @@ def is_digest_obsolete(hash_name: str,
return new_hash != old_hash
def write_new_digest(hash_name: str,
filenames: List[str],
extra_strings: List[str]=[]) -> None:
filenames: Sequence[str],
extra_strings: Sequence[str] = []) -> None:
hash_path = os.path.join(get_dev_uuid_var_path(), hash_name)
new_hash = files_and_string_digest(filenames, extra_strings)
with open(hash_path, 'w') as f:

View File

@ -74,3 +74,39 @@ rules:
- pattern: psycopg2.sql.SQL(... .format(...))
severity: ERROR
message: "Do not write a SQL injection vulnerability please"
- id: mutable-default-type
languages: [python]
pattern-either:
- pattern: |
def $F(..., $A: typing.List[...] = [...], ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Optional[typing.List[...]] = [...], ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.List[...] = zerver.lib.request.REQ(..., default=[...], ...), ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Optional[typing.List[...]] = zerver.lib.request.REQ(..., default=[...], ...), ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Dict[...] = {}, ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Optional[typing.Dict[...]] = {}, ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Dict[...] = zerver.lib.request.REQ(..., default={}, ...), ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Optional[typing.Dict[...]] = zerver.lib.request.REQ(..., default={}, ...), ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Set[...] = set(), ...) -> ...:
...
- pattern: |
def $F(..., $A: typing.Optional[typing.Set[...]] = set(), ...) -> ...:
...
severity: ERROR
message: "Guard mutable default with read-only type (Sequence, Mapping, AbstractSet)"

View File

@ -289,9 +289,9 @@ def build_recipient(type_id: int, recipient_id: int, type: int) -> ZerverFieldsT
recipient_dict = model_to_dict(recipient)
return recipient_dict
def build_recipients(zerver_userprofile: List[ZerverFieldsT],
zerver_stream: List[ZerverFieldsT],
zerver_huddle: List[ZerverFieldsT]=[]) -> List[ZerverFieldsT]:
def build_recipients(zerver_userprofile: Iterable[ZerverFieldsT],
zerver_stream: Iterable[ZerverFieldsT],
zerver_huddle: Iterable[ZerverFieldsT] = []) -> List[ZerverFieldsT]:
'''
As of this writing, we only use this in the HipChat
conversion. The Slack and Gitter conversions do it more

View File

@ -420,8 +420,8 @@ def add_new_user_history(user_profile: UserProfile, streams: Iterable[Stream]) -
# * subscribe the user to newsletter if newsletter_data is specified
def process_new_human_user(user_profile: UserProfile,
prereg_user: Optional[PreregistrationUser]=None,
newsletter_data: Optional[Dict[str, str]]=None,
default_stream_groups: List[DefaultStreamGroup]=[],
newsletter_data: Optional[Mapping[str, str]]=None,
default_stream_groups: Sequence[DefaultStreamGroup]=[],
realm_creation: bool=False) -> None:
mit_beta_user = user_profile.realm.is_zephyr_mirror_realm
if prereg_user is not None:
@ -565,7 +565,7 @@ def do_create_user(email: str, password: Optional[str], realm: Realm, full_name:
default_all_public_streams: Optional[bool]=None,
prereg_user: Optional[PreregistrationUser]=None,
newsletter_data: Optional[Dict[str, str]]=None,
default_stream_groups: List[DefaultStreamGroup]=[],
default_stream_groups: Sequence[DefaultStreamGroup]=[],
source_profile: Optional[UserProfile]=None,
realm_creation: bool=False) -> UserProfile:
@ -1342,7 +1342,7 @@ def do_schedule_messages(messages: Sequence[Mapping[str, Any]]) -> List[int]:
def do_send_messages(messages_maybe_none: Sequence[Optional[MutableMapping[str, Any]]],
email_gateway: bool=False,
mark_as_read: List[int]=[]) -> List[int]:
mark_as_read: Sequence[int]=[]) -> List[int]:
"""See
https://zulip.readthedocs.io/en/latest/subsystems/sending-messages.html
for high-level documentation on this subsystem.
@ -1614,12 +1614,12 @@ class UserMessageLite:
return UserMessage.flags_list_for_flags(self.flags)
def create_user_messages(message: Message,
um_eligible_user_ids: Set[int],
long_term_idle_user_ids: Set[int],
stream_push_user_ids: Set[int],
stream_email_user_ids: Set[int],
mentioned_user_ids: Set[int],
mark_as_read: List[int]=[]) -> List[UserMessageLite]:
um_eligible_user_ids: AbstractSet[int],
long_term_idle_user_ids: AbstractSet[int],
stream_push_user_ids: AbstractSet[int],
stream_email_user_ids: AbstractSet[int],
mentioned_user_ids: AbstractSet[int],
mark_as_read: Sequence[int] = []) -> List[UserMessageLite]:
ums_to_create = []
for user_profile_id in um_eligible_user_ids:
um = UserMessageLite(
@ -5140,7 +5140,7 @@ def do_get_user_invites(user_profile: UserProfile) -> List[Dict[str, Any]]:
return invites
def do_create_multiuse_invite_link(referred_by: UserProfile, invited_as: int,
streams: List[Stream]=[]) -> str:
streams: Sequence[Stream] = []) -> str:
realm = referred_by.realm
invite = MultiuseInvite.objects.create(realm=realm, referred_by=referred_by)
if streams:

View File

@ -6,7 +6,7 @@ import tempfile
import urllib
from contextlib import contextmanager
from email.utils import parseaddr
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, cast
from typing import Any, Dict, Iterable, Iterator, List, Optional, Sequence, Set, Tuple, Union, cast
from unittest import mock
import ujson
@ -407,7 +407,7 @@ class ZulipTestCase(TestCase):
realm_subdomain: str="zuliptest",
from_confirmation: str='', full_name: Optional[str]=None,
timezone: str='', realm_in_root_domain: Optional[str]=None,
default_stream_groups: List[str]=[],
default_stream_groups: Sequence[str]=[],
source_realm: str='',
key: Optional[str]=None, **kwargs: Any) -> HttpResponse:
"""

View File

@ -1,7 +1,7 @@
import re
import unicodedata
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from django.conf import settings
from django.db.models.query import QuerySet
@ -180,7 +180,7 @@ def bulk_get_users(emails: List[str], realm: Optional[Realm],
id_fetcher=user_to_email,
)
def user_ids_to_users(user_ids: List[int], realm: Realm) -> List[UserProfile]:
def user_ids_to_users(user_ids: Sequence[int], realm: Realm) -> List[UserProfile]:
# TODO: Consider adding a flag to control whether deactivated
# users should be included.

View File

@ -5,7 +5,7 @@ import json
import re
import time
import urllib
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple
from unittest import mock
import jwt
@ -3037,7 +3037,7 @@ class FetchAuthBackends(ZulipTestCase):
raise AssertionError(error)
def test_get_server_settings(self) -> None:
def check_result(result: HttpResponse, extra_fields: List[Tuple[str, Validator]]=[]) -> None:
def check_result(result: HttpResponse, extra_fields: Sequence[Tuple[str, Validator]] = []) -> None:
authentication_methods_list = [
('password', check_bool),
]
@ -3058,7 +3058,8 @@ class FetchAuthBackends(ZulipTestCase):
('push_notifications_enabled', check_bool),
('msg', check_string),
('result', check_string),
] + extra_fields)
*extra_fields,
])
self.assert_on_error(checker("data", result.json()))
result = self.client_get("/api/v1/server_settings", subdomain="", HTTP_USER_AGENT="")

View File

@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List
from typing import Any, Dict, Sequence
from unittest import mock
from urllib.parse import urlsplit
@ -34,8 +34,8 @@ class DocPageTest(ZulipTestCase):
print("ERROR: {}".format(content.get('msg')))
print()
def _test(self, url: str, expected_content: str, extra_strings: List[str]=[],
landing_missing_strings: List[str]=[], landing_page: bool=True,
def _test(self, url: str, expected_content: str, extra_strings: Sequence[str]=[],
landing_missing_strings: Sequence[str]=[], landing_page: bool=True,
doc_html_str: bool=False) -> None:
# Test the URL on the "zephyr" subdomain

View File

@ -4,7 +4,7 @@ import smtplib
import time
import urllib
from email.utils import parseaddr
from typing import Any, List, Optional
from typing import Any, List, Optional, Sequence
from unittest.mock import MagicMock, patch
import ujson
@ -772,7 +772,7 @@ class InviteUserBase(ZulipTestCase):
tokenized_no_reply_email = parseaddr(outbox[0].from_email)[1]
self.assertTrue(re.search(self.TOKENIZED_NOREPLY_REGEX, tokenized_no_reply_email))
def invite(self, invitee_emails: str, stream_names: List[str], body: str='',
def invite(self, invitee_emails: str, stream_names: Sequence[str], body: str='',
invite_as: int=1) -> HttpResponse:
"""
Invites the specified users to Zulip with the specified streams.
@ -3623,7 +3623,7 @@ class UserSignUpTest(InviteUserBase):
# Name comes from the POST request, not LDAP
self.assertEqual(user_profile.full_name, 'Non-LDAP Full Name')
def ldap_invite_and_signup_as(self, invite_as: int, streams: List[str]=['Denmark']) -> None:
def ldap_invite_and_signup_as(self, invite_as: int, streams: Sequence[str] = ['Denmark']) -> None:
self.init_default_ldap_database()
ldap_user_attr_map = {'full_name': 'cn'}

View File

@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Mapping, Sequence
from unittest import mock
from django.test import TestCase
@ -17,8 +17,8 @@ class SubdomainsTest(TestCase):
def test(expected: str, host: str, *, plusport: bool=True,
external_host: str='example.org',
realm_hosts: Dict[str, str]={},
root_aliases: List[str]=[]) -> None:
realm_hosts: Mapping[str, str]={},
root_aliases: Sequence[str]=[]) -> None:
with self.settings(EXTERNAL_HOST=external_host,
REALM_HOSTS=realm_hosts,
ROOT_SUBDOMAIN_ALIASES=root_aliases):

View File

@ -1,5 +1,5 @@
import re
from typing import List, Set
from typing import List, Sequence, Set
from django.http import HttpRequest, HttpResponse
from django.utils.translation import ugettext as _
@ -129,7 +129,7 @@ def resend_user_invite_email(request: HttpRequest, user_profile: UserProfile,
def generate_multiuse_invite_backend(
request: HttpRequest, user_profile: UserProfile,
invite_as: int=REQ(validator=check_int, default=PreregistrationUser.INVITE_AS['MEMBER']),
stream_ids: List[int]=REQ(validator=check_list(check_int), default=[])) -> HttpResponse:
stream_ids: Sequence[int]=REQ(validator=check_list(check_int), default=[])) -> HttpResponse:
streams = []
for stream_id in stream_ids:
try:

View File

@ -1,5 +1,17 @@
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
Optional,
Sequence,
Set,
Tuple,
Union,
)
import ujson
from django.conf import settings
@ -371,7 +383,7 @@ def add_subscriptions_backend(
Stream.STREAM_POST_POLICY_TYPES), default=Stream.STREAM_POST_POLICY_EVERYONE),
history_public_to_subscribers: Optional[bool]=REQ(validator=check_bool, default=None),
announce: bool=REQ(validator=check_bool, default=False),
principals: List[Union[str, int]]=REQ(validator=check_variable_type([
principals: Sequence[Union[str, int]]=REQ(validator=check_variable_type([
check_list(check_string), check_list(check_int)]), default=[]),
authorization_errors_fatal: bool=REQ(validator=check_bool, default=True),
) -> HttpResponse:

View File

@ -1,4 +1,4 @@
from typing import List
from typing import List, Sequence
from django.http import HttpRequest, HttpResponse
from django.utils.translation import ugettext as _
@ -31,7 +31,7 @@ from zerver.views.streams import FuncKwargPair, compose_views
@has_request_variables
def add_user_group(request: HttpRequest, user_profile: UserProfile,
name: str=REQ(),
members: List[int]=REQ(validator=check_list(check_int), default=[]),
members: Sequence[int]=REQ(validator=check_list(check_int), default=[]),
description: str=REQ()) -> HttpResponse:
user_profiles = user_ids_to_users(members, user_profile.realm)
check_add_user_group(user_profile.realm, name, user_profiles, description)
@ -74,8 +74,8 @@ def delete_user_group(request: HttpRequest, user_profile: UserProfile,
@has_request_variables
def update_user_group_backend(request: HttpRequest, user_profile: UserProfile,
user_group_id: int=REQ(validator=check_int),
delete: List[int]=REQ(validator=check_list(check_int), default=[]),
add: List[int]=REQ(validator=check_list(check_int), default=[]),
delete: Sequence[int]=REQ(validator=check_list(check_int), default=[]),
add: Sequence[int]=REQ(validator=check_list(check_int), default=[]),
) -> HttpResponse:
if not add and not delete:
return json_error(_('Nothing to do. Specify at least one of "add" or "delete".'))

View File

@ -1,6 +1,6 @@
# Webhooks for external integrations.
import time
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, Optional, Sequence, Tuple
from django.http import HttpRequest, HttpResponse
@ -49,7 +49,7 @@ def topic_and_body(payload: Dict[str, Any]) -> Tuple[str, str]:
topic = customer_id
body = None
def update_string(blacklist: List[str]=[]) -> str:
def update_string(blacklist: Sequence[str] = []) -> str:
assert('previous_attributes' in payload['data'])
previous_attributes = payload['data']['previous_attributes']
for attribute in blacklist:
@ -60,7 +60,7 @@ def topic_and_body(payload: Dict[str, Any]) -> Tuple[str, str]:
' is now ' + stringify(object_[attribute])
for attribute in sorted(previous_attributes.keys()))
def default_body(update_blacklist: List[str]=[]) -> str:
def default_body(update_blacklist: Sequence[str] = []) -> str:
body = '{resource} {verbed}'.format(
resource=linkified_id(object_['id']), verbed=event.replace('_', ' '))
if event == 'updated':