From 7d1c475f69ebd01430244f0b09b04012bb687fda Mon Sep 17 00:00:00 2001 From: PIG208 <359101898@qq.com> Date: Sun, 25 Jul 2021 22:31:12 +0800 Subject: [PATCH] typing: Use assertions for function arguments. Utilize the assert_is_not_None helper to eliminate errors of 'Argument x to "Foo" has incompatible type "Optional[Bar]"...' --- analytics/tests/test_counts.py | 9 ++++-- analytics/views/support.py | 9 +++--- corporate/lib/stripe.py | 13 +++++--- corporate/tests/test_stripe.py | 54 +++++++++++++++++++-------------- zerver/lib/soft_deactivation.py | 11 ++++--- zerver/lib/upload.py | 19 +++++++----- zerver/lib/user_mutes.py | 3 +- zerver/views/realm_export.py | 3 +- zerver/views/streams.py | 7 +++-- 9 files changed, 80 insertions(+), 48 deletions(-) diff --git a/analytics/tests/test_counts.py b/analytics/tests/test_counts.py index 3184dc7ddc..e31565a8a1 100644 --- a/analytics/tests/test_counts.py +++ b/analytics/tests/test_counts.py @@ -51,6 +51,7 @@ from zerver.lib.exceptions import InvitationError from zerver.lib.test_classes import ZulipTestCase from zerver.lib.timestamp import TimezoneNotUTCException, floor_to_day from zerver.lib.topic import DB_TOPIC_NAME +from zerver.lib.utils import assert_is_not_none from zerver.models import ( Client, Huddle, @@ -1388,11 +1389,13 @@ class TestLoggingCountStats(AnalyticsTestCase): assertInviteCountEquals(5) # Revoking invite should not give you credit - do_revoke_user_invite(PreregistrationUser.objects.filter(realm=user.realm).first()) + do_revoke_user_invite( + assert_is_not_none(PreregistrationUser.objects.filter(realm=user.realm).first()) + ) assertInviteCountEquals(5) # Resending invite should cost you - do_resend_user_invite_email(PreregistrationUser.objects.first()) + do_resend_user_invite_email(assert_is_not_none(PreregistrationUser.objects.first())) assertInviteCountEquals(6) def test_messages_read_hour(self) -> None: @@ -1423,7 +1426,7 @@ class TestLoggingCountStats(AnalyticsTestCase): self.send_stream_message(user1, stream.name) self.send_stream_message(user1, stream.name) - do_mark_stream_messages_as_read(user2, stream.recipient_id) + do_mark_stream_messages_as_read(user2, assert_is_not_none(stream.recipient_id)) self.assertEqual( 3, UserCount.objects.filter(property=read_count_property).aggregate(Sum("value"))[ diff --git a/analytics/views/support.py b/analytics/views/support.py index 0847d41f31..a88c693696 100644 --- a/analytics/views/support.py +++ b/analytics/views/support.py @@ -28,6 +28,7 @@ from zerver.lib.actions import ( from zerver.lib.exceptions import JsonableError from zerver.lib.realm_icon import realm_icon_url from zerver.lib.subdomains import get_subdomain_from_hostname +from zerver.lib.utils import assert_is_not_none from zerver.models import ( MultiuseInvite, PreregistrationUser, @@ -119,24 +120,24 @@ def support(request: HttpRequest) -> HttpResponse: if len(keys) != 2: raise JsonableError(_("Invalid parameters")) - realm_id = request.POST.get("realm_id") + realm_id: str = assert_is_not_none(request.POST.get("realm_id")) realm = Realm.objects.get(id=realm_id) if request.POST.get("plan_type", None) is not None: - new_plan_type = int(request.POST.get("plan_type")) + new_plan_type = int(assert_is_not_none(request.POST.get("plan_type"))) current_plan_type = realm.plan_type do_change_plan_type(realm, new_plan_type, acting_user=request.user) msg = f"Plan type of {realm.string_id} changed from {get_plan_name(current_plan_type)} to {get_plan_name(new_plan_type)} " context["success_message"] = msg elif request.POST.get("discount", None) is not None: - new_discount = Decimal(request.POST.get("discount")) + new_discount = Decimal(assert_is_not_none(request.POST.get("discount"))) current_discount = get_discount_for_realm(realm) or 0 attach_discount_to_realm(realm, new_discount, acting_user=request.user) context[ "success_message" ] = f"Discount of {realm.string_id} changed to {new_discount}% from {current_discount}%." elif request.POST.get("new_subdomain", None) is not None: - new_subdomain = request.POST.get("new_subdomain") + new_subdomain: str = assert_is_not_none(request.POST.get("new_subdomain")) old_subdomain = realm.string_id try: check_subdomain_available(new_subdomain) diff --git a/corporate/lib/stripe.py b/corporate/lib/stripe.py index 83836166dc..087078d0f5 100644 --- a/corporate/lib/stripe.py +++ b/corporate/lib/stripe.py @@ -5,7 +5,7 @@ import secrets from datetime import datetime, timedelta from decimal import Decimal from functools import wraps -from typing import Callable, Dict, Generator, Optional, Tuple, TypeVar, cast +from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, cast import orjson import stripe @@ -30,6 +30,7 @@ from zerver.lib.exceptions import JsonableError from zerver.lib.logging_util import log_to_file from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime +from zerver.lib.utils import assert_is_not_none from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot from zproject.config import get_secret @@ -516,7 +517,9 @@ def compute_plan_parameters( if automanage_licenses: next_invoice_date = add_months(billing_cycle_anchor, 1) if free_trial: - period_end = billing_cycle_anchor + timedelta(days=settings.FREE_TRIAL_DAYS) + period_end = billing_cycle_anchor + timedelta( + days=assert_is_not_none(settings.FREE_TRIAL_DAYS) + ) next_invoice_date = period_end return billing_cycle_anchor, next_invoice_date, period_end, price_per_license @@ -943,10 +946,12 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverag def get_realms_to_default_discount_dict() -> Dict[str, Decimal]: - realms_to_default_discount = {} + realms_to_default_discount: Dict[str, Any] = {} customers = Customer.objects.exclude(default_discount=None).exclude(default_discount=0) for customer in customers: - realms_to_default_discount[customer.realm.string_id] = customer.default_discount + realms_to_default_discount[customer.realm.string_id] = assert_is_not_none( + customer.default_discount + ) return realms_to_default_discount diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 63f50139ca..0d63c04e53 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -90,6 +90,7 @@ from zerver.lib.actions import ( ) from zerver.lib.test_classes import ZulipTestCase from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime +from zerver.lib.utils import assert_is_not_none from zerver.models import ( Message, Realm, @@ -532,7 +533,7 @@ class StripeTest(StripeTestCase): # Check that we correctly created a Customer object in Stripe stripe_customer = stripe_get_customer( - Customer.objects.get(realm=user.realm).stripe_customer_id + assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id) ) self.assertEqual(stripe_customer.default_source.id[:5], "card_") self.assertTrue(stripe_customer_has_credit_card_as_default_source(stripe_customer)) @@ -642,9 +643,11 @@ class StripeTest(StripeTestCase): self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED) self.assertEqual( orjson.loads( - RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) - .values_list("extra_data", flat=True) - .first() + assert_is_not_none( + RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) + .values_list("extra_data", flat=True) + .first() + ) )["automanage_licenses"], True, ) @@ -694,7 +697,7 @@ class StripeTest(StripeTestCase): self.upgrade(invoice=True) # Check that we correctly created a Customer in Stripe stripe_customer = stripe_get_customer( - Customer.objects.get(realm=user.realm).stripe_customer_id + assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id) ) self.assertFalse(stripe_customer_has_credit_card_as_default_source(stripe_customer)) # It can take a second for Stripe to attach the source to the customer, and in @@ -781,9 +784,11 @@ class StripeTest(StripeTestCase): self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED) self.assertEqual( orjson.loads( - RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) - .values_list("extra_data", flat=True) - .first() + assert_is_not_none( + RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) + .values_list("extra_data", flat=True) + .first() + ) )["automanage_licenses"], False, ) @@ -834,7 +839,7 @@ class StripeTest(StripeTestCase): self.upgrade() stripe_customer = stripe_get_customer( - Customer.objects.get(realm=user.realm).stripe_customer_id + assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id) ) self.assertEqual(stripe_customer.default_source.id[:5], "card_") self.assertEqual(stripe_customer.description, "zulip (Zulip Dev)") @@ -894,9 +899,11 @@ class StripeTest(StripeTestCase): self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED) self.assertEqual( orjson.loads( - RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) - .values_list("extra_data", flat=True) - .first() + assert_is_not_none( + RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) + .values_list("extra_data", flat=True) + .first() + ) )["automanage_licenses"], True, ) @@ -1040,7 +1047,7 @@ class StripeTest(StripeTestCase): self.upgrade(invoice=True) stripe_customer = stripe_get_customer( - Customer.objects.get(realm=user.realm).stripe_customer_id + assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id) ) self.assertEqual(stripe_customer.discount, None) self.assertEqual(stripe_customer.email, user.delivery_email) @@ -1093,9 +1100,11 @@ class StripeTest(StripeTestCase): self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED) self.assertEqual( orjson.loads( - RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) - .values_list("extra_data", flat=True) - .first() + assert_is_not_none( + RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED) + .values_list("extra_data", flat=True) + .first() + ) )["automanage_licenses"], False, ) @@ -1218,7 +1227,7 @@ class StripeTest(StripeTestCase): self.upgrade() customer = Customer.objects.first() assert customer is not None - stripe_customer_id = customer.stripe_customer_id + stripe_customer_id: str = assert_is_not_none(customer.stripe_customer_id) # Check that the Charge used the old quantity, not new_seat_count [charge] = stripe.Charge.list(customer=stripe_customer_id) self.assertEqual(8000 * self.seat_count, charge.amount) @@ -2037,9 +2046,10 @@ class StripeTest(StripeTestCase): audit_log = RealmAuditLog.objects.get( event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN ) + extra_data: str = assert_is_not_none(audit_log.extra_data) self.assertEqual(audit_log.realm, user.realm) - self.assertEqual(orjson.loads(audit_log.extra_data)["monthly_plan_id"], monthly_plan.id) - self.assertEqual(orjson.loads(audit_log.extra_data)["annual_plan_id"], annual_plan.id) + self.assertEqual(orjson.loads(extra_data)["monthly_plan_id"], monthly_plan.id) + self.assertEqual(orjson.loads(extra_data)["annual_plan_id"], annual_plan.id) invoice_plans_as_needed(self.next_month) @@ -2468,7 +2478,7 @@ class StripeTest(StripeTestCase): self.assert_json_success(result) invoice_plans_as_needed(self.next_year) stripe_customer = stripe_get_customer( - Customer.objects.get(realm=user.realm).stripe_customer_id + assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id) ) [invoice, _] = stripe.Invoice.list(customer=stripe_customer.id) invoice_params = { @@ -2518,7 +2528,7 @@ class StripeTest(StripeTestCase): self.assert_json_success(result) invoice_plans_as_needed(self.next_year + timedelta(days=365)) stripe_customer = stripe_get_customer( - Customer.objects.get(realm=user.realm).stripe_customer_id + assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id) ) [invoice, _, _] = stripe.Invoice.list(customer=stripe_customer.id) invoice_params = { @@ -3460,7 +3470,7 @@ class InvoiceTest(StripeTestCase): plan.invoicing_status = CustomerPlan.STARTED plan.save(update_fields=["invoicing_status"]) with self.assertRaises(NotImplementedError): - invoice_plan(CustomerPlan.objects.first(), self.now) + invoice_plan(assert_is_not_none(CustomerPlan.objects.first()), self.now) def test_invoice_plan_without_stripe_customer(self) -> None: self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL) diff --git a/zerver/lib/soft_deactivation.py b/zerver/lib/soft_deactivation.py index 733b6c5f23..8a724dd5f4 100644 --- a/zerver/lib/soft_deactivation.py +++ b/zerver/lib/soft_deactivation.py @@ -10,6 +10,7 @@ from django.utils.timezone import now as timezone_now from sentry_sdk import capture_exception from zerver.lib.logging_util import log_to_file +from zerver.lib.utils import assert_is_not_none from zerver.models import ( Message, Realm, @@ -63,12 +64,14 @@ def filter_by_subscription_history( # check belongs in this inner loop, not the outer loop. break + event_last_message_id = assert_is_not_none(log_entry.event_last_message_id) + if log_entry.event_type == RealmAuditLog.SUBSCRIPTION_DEACTIVATED: # If the event shows the user was unsubscribed after # event_last_message_id, we know they must have been # subscribed immediately before the event. for stream_message in stream_messages: - if stream_message["id"] <= log_entry.event_last_message_id: + if stream_message["id"] <= event_last_message_id: store_user_message_to_insert(stream_message) else: break @@ -78,12 +81,12 @@ def filter_by_subscription_history( ): initial_msg_count = len(stream_messages) for i, stream_message in enumerate(stream_messages): - if stream_message["id"] > log_entry.event_last_message_id: + if stream_message["id"] > event_last_message_id: stream_messages = stream_messages[i:] break final_msg_count = len(stream_messages) if initial_msg_count == final_msg_count: - if stream_messages[-1]["id"] <= log_entry.event_last_message_id: + if stream_messages[-1]["id"] <= event_last_message_id: stream_messages = [] else: raise AssertionError(f"{log_entry.event_type} is not a subscription event.") @@ -172,7 +175,7 @@ def add_missing_messages(user_profile: UserProfile) -> None: all_stream_subscription_logs: DefaultDict[int, List[RealmAuditLog]] = defaultdict(list) for log in subscription_logs: - all_stream_subscription_logs[log.modified_stream_id].append(log) + all_stream_subscription_logs[assert_is_not_none(log.modified_stream_id)].append(log) recipient_ids = [] for sub in all_stream_subs: diff --git a/zerver/lib/upload.py b/zerver/lib/upload.py index 601f2b08e5..88122486c9 100644 --- a/zerver/lib/upload.py +++ b/zerver/lib/upload.py @@ -31,6 +31,7 @@ from PIL.Image import DecompressionBombError from zerver.lib.avatar_hash import user_avatar_path from zerver.lib.exceptions import ErrorCode, JsonableError +from zerver.lib.utils import assert_is_not_none from zerver.models import Attachment, Message, Realm, RealmEmoji, UserProfile DEFAULT_AVATAR_SIZE = 100 @@ -729,7 +730,7 @@ class S3UploadBackend(ZulipUploadBackend): def write_local_file(type: str, path: str, file_data: bytes) -> None: - file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path) + file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path) os.makedirs(os.path.dirname(file_path), exist_ok=True) with open(file_path, "wb") as f: @@ -737,13 +738,13 @@ def write_local_file(type: str, path: str, file_data: bytes) -> None: def read_local_file(type: str, path: str) -> bytes: - file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path) + file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path) with open(file_path, "rb") as f: return f.read() def delete_local_file(type: str, path: str) -> bool: - file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path) + file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path) if os.path.isfile(file_path): # This removes the file but the empty folders still remain. os.remove(file_path) @@ -754,7 +755,7 @@ def delete_local_file(type: str, path: str) -> bool: def get_local_file_path(path_id: str) -> Optional[str]: - local_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "files", path_id) + local_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "files", path_id) if os.path.isfile(local_path): return local_path else: @@ -897,12 +898,16 @@ class LocalUploadBackend(ZulipUploadBackend): file_path = user_avatar_path(user_profile) output_path = os.path.join( - settings.LOCAL_UPLOADS_DIR, "avatars", file_path + file_extension + assert_is_not_none(settings.LOCAL_UPLOADS_DIR), + "avatars", + file_path + file_extension, ) if os.path.isfile(output_path): return - image_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", file_path + ".original") + image_path = os.path.join( + assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "avatars", file_path + ".original" + ) with open(image_path, "rb") as f: image_data = f.read() if is_medium: @@ -942,7 +947,7 @@ class LocalUploadBackend(ZulipUploadBackend): secrets.token_urlsafe(18), os.path.basename(tarball_path), ) - abs_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", path) + abs_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "avatars", path) os.makedirs(os.path.dirname(abs_path), exist_ok=True) shutil.copy(tarball_path, abs_path) public_url = realm.uri + "/user_avatars/" + path diff --git a/zerver/lib/user_mutes.py b/zerver/lib/user_mutes.py index 71389db704..07548e08fc 100644 --- a/zerver/lib/user_mutes.py +++ b/zerver/lib/user_mutes.py @@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Set from zerver.lib.cache import cache_with_key, get_muting_users_cache_key from zerver.lib.timestamp import datetime_to_timestamp +from zerver.lib.utils import assert_is_not_none from zerver.models import MutedUser, UserProfile @@ -14,7 +15,7 @@ def get_user_mutes(user_profile: UserProfile) -> List[Dict[str, int]]: return [ { "id": row["muted_user_id"], - "timestamp": datetime_to_timestamp(row["date_muted"]), + "timestamp": datetime_to_timestamp(assert_is_not_none(row["date_muted"])), } for row in rows ] diff --git a/zerver/views/realm_export.py b/zerver/views/realm_export.py index 2f95dba0c6..50118c31fa 100644 --- a/zerver/views/realm_export.py +++ b/zerver/views/realm_export.py @@ -13,6 +13,7 @@ from zerver.lib.exceptions import JsonableError from zerver.lib.export import get_realm_exports_serialized from zerver.lib.queue import queue_json_publish from zerver.lib.response import json_success +from zerver.lib.utils import assert_is_not_none from zerver.models import RealmAuditLog, UserProfile @@ -90,7 +91,7 @@ def delete_realm_export(request: HttpRequest, user: UserProfile, export_id: int) except RealmAuditLog.DoesNotExist: raise JsonableError(_("Invalid data export ID")) - export_data = orjson.loads(audit_log_entry.extra_data) + export_data = orjson.loads(assert_is_not_none(audit_log_entry.extra_data)) if "deleted_timestamp" in export_data: raise JsonableError(_("Export already deleted")) do_delete_realm_export(user, audit_log_entry) diff --git a/zerver/views/streams.py b/zerver/views/streams.py index a5b11b8f3c..9a1725bc43 100644 --- a/zerver/views/streams.py +++ b/zerver/views/streams.py @@ -71,6 +71,7 @@ from zerver.lib.topic import ( messages_for_topic, ) from zerver.lib.types import Validator +from zerver.lib.utils import assert_is_not_none from zerver.lib.validator import ( check_bool, check_capped_string, @@ -726,7 +727,9 @@ def get_topics_backend( if is_web_public_query: realm = get_valid_realm_from_request(request) stream = access_web_public_stream(stream_id, realm) - result = get_topic_history_for_public_stream(recipient_id=stream.recipient_id) + result = get_topic_history_for_public_stream( + recipient_id=assert_is_not_none(stream.recipient_id) + ) else: assert user_profile is not None @@ -753,7 +756,7 @@ def delete_in_topic( ) -> HttpResponse: (stream, sub) = access_stream_by_id(user_profile, stream_id) - messages = messages_for_topic(stream.recipient_id, topic_name) + messages = messages_for_topic(assert_is_not_none(stream.recipient_id), topic_name) if not stream.is_history_public_to_subscribers(): # Don't allow the user to delete messages that they don't have access to. deletable_message_ids = UserMessage.objects.filter(