typing: Use assertions for function arguments.

Utilize the assert_is_not_None helper to eliminate errors of
'Argument x to "Foo" has incompatible type "Optional[Bar]"...'
This commit is contained in:
PIG208 2021-07-25 22:31:12 +08:00 committed by Tim Abbott
parent 8a91d1c2b1
commit 7d1c475f69
9 changed files with 80 additions and 48 deletions

View File

@ -51,6 +51,7 @@ from zerver.lib.exceptions import InvitationError
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.timestamp import TimezoneNotUTCException, floor_to_day
from zerver.lib.topic import DB_TOPIC_NAME
from zerver.lib.utils import assert_is_not_none
from zerver.models import (
Client,
Huddle,
@ -1388,11 +1389,13 @@ class TestLoggingCountStats(AnalyticsTestCase):
assertInviteCountEquals(5)
# Revoking invite should not give you credit
do_revoke_user_invite(PreregistrationUser.objects.filter(realm=user.realm).first())
do_revoke_user_invite(
assert_is_not_none(PreregistrationUser.objects.filter(realm=user.realm).first())
)
assertInviteCountEquals(5)
# Resending invite should cost you
do_resend_user_invite_email(PreregistrationUser.objects.first())
do_resend_user_invite_email(assert_is_not_none(PreregistrationUser.objects.first()))
assertInviteCountEquals(6)
def test_messages_read_hour(self) -> None:
@ -1423,7 +1426,7 @@ class TestLoggingCountStats(AnalyticsTestCase):
self.send_stream_message(user1, stream.name)
self.send_stream_message(user1, stream.name)
do_mark_stream_messages_as_read(user2, stream.recipient_id)
do_mark_stream_messages_as_read(user2, assert_is_not_none(stream.recipient_id))
self.assertEqual(
3,
UserCount.objects.filter(property=read_count_property).aggregate(Sum("value"))[

View File

@ -28,6 +28,7 @@ from zerver.lib.actions import (
from zerver.lib.exceptions import JsonableError
from zerver.lib.realm_icon import realm_icon_url
from zerver.lib.subdomains import get_subdomain_from_hostname
from zerver.lib.utils import assert_is_not_none
from zerver.models import (
MultiuseInvite,
PreregistrationUser,
@ -119,24 +120,24 @@ def support(request: HttpRequest) -> HttpResponse:
if len(keys) != 2:
raise JsonableError(_("Invalid parameters"))
realm_id = request.POST.get("realm_id")
realm_id: str = assert_is_not_none(request.POST.get("realm_id"))
realm = Realm.objects.get(id=realm_id)
if request.POST.get("plan_type", None) is not None:
new_plan_type = int(request.POST.get("plan_type"))
new_plan_type = int(assert_is_not_none(request.POST.get("plan_type")))
current_plan_type = realm.plan_type
do_change_plan_type(realm, new_plan_type, acting_user=request.user)
msg = f"Plan type of {realm.string_id} changed from {get_plan_name(current_plan_type)} to {get_plan_name(new_plan_type)} "
context["success_message"] = msg
elif request.POST.get("discount", None) is not None:
new_discount = Decimal(request.POST.get("discount"))
new_discount = Decimal(assert_is_not_none(request.POST.get("discount")))
current_discount = get_discount_for_realm(realm) or 0
attach_discount_to_realm(realm, new_discount, acting_user=request.user)
context[
"success_message"
] = f"Discount of {realm.string_id} changed to {new_discount}% from {current_discount}%."
elif request.POST.get("new_subdomain", None) is not None:
new_subdomain = request.POST.get("new_subdomain")
new_subdomain: str = assert_is_not_none(request.POST.get("new_subdomain"))
old_subdomain = realm.string_id
try:
check_subdomain_available(new_subdomain)

View File

@ -5,7 +5,7 @@ import secrets
from datetime import datetime, timedelta
from decimal import Decimal
from functools import wraps
from typing import Callable, Dict, Generator, Optional, Tuple, TypeVar, cast
from typing import Any, Callable, Dict, Generator, Optional, Tuple, TypeVar, cast
import orjson
import stripe
@ -30,6 +30,7 @@ from zerver.lib.exceptions import JsonableError
from zerver.lib.logging_util import log_to_file
from zerver.lib.send_email import FromAddress, send_email_to_billing_admins_and_realm_owners
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import assert_is_not_none
from zerver.models import Realm, RealmAuditLog, UserProfile, get_system_bot
from zproject.config import get_secret
@ -516,7 +517,9 @@ def compute_plan_parameters(
if automanage_licenses:
next_invoice_date = add_months(billing_cycle_anchor, 1)
if free_trial:
period_end = billing_cycle_anchor + timedelta(days=settings.FREE_TRIAL_DAYS)
period_end = billing_cycle_anchor + timedelta(
days=assert_is_not_none(settings.FREE_TRIAL_DAYS)
)
next_invoice_date = period_end
return billing_cycle_anchor, next_invoice_date, period_end, price_per_license
@ -943,10 +946,12 @@ def estimate_annual_recurring_revenue_by_realm() -> Dict[str, int]: # nocoverag
def get_realms_to_default_discount_dict() -> Dict[str, Decimal]:
realms_to_default_discount = {}
realms_to_default_discount: Dict[str, Any] = {}
customers = Customer.objects.exclude(default_discount=None).exclude(default_discount=0)
for customer in customers:
realms_to_default_discount[customer.realm.string_id] = customer.default_discount
realms_to_default_discount[customer.realm.string_id] = assert_is_not_none(
customer.default_discount
)
return realms_to_default_discount

View File

@ -90,6 +90,7 @@ from zerver.lib.actions import (
)
from zerver.lib.test_classes import ZulipTestCase
from zerver.lib.timestamp import datetime_to_timestamp, timestamp_to_datetime
from zerver.lib.utils import assert_is_not_none
from zerver.models import (
Message,
Realm,
@ -532,7 +533,7 @@ class StripeTest(StripeTestCase):
# Check that we correctly created a Customer object in Stripe
stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id
assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
)
self.assertEqual(stripe_customer.default_source.id[:5], "card_")
self.assertTrue(stripe_customer_has_credit_card_as_default_source(stripe_customer))
@ -642,9 +643,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
self.assertEqual(
orjson.loads(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True)
.first()
assert_is_not_none(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True)
.first()
)
)["automanage_licenses"],
True,
)
@ -694,7 +697,7 @@ class StripeTest(StripeTestCase):
self.upgrade(invoice=True)
# Check that we correctly created a Customer in Stripe
stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id
assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
)
self.assertFalse(stripe_customer_has_credit_card_as_default_source(stripe_customer))
# It can take a second for Stripe to attach the source to the customer, and in
@ -781,9 +784,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
self.assertEqual(
orjson.loads(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True)
.first()
assert_is_not_none(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True)
.first()
)
)["automanage_licenses"],
False,
)
@ -834,7 +839,7 @@ class StripeTest(StripeTestCase):
self.upgrade()
stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id
assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
)
self.assertEqual(stripe_customer.default_source.id[:5], "card_")
self.assertEqual(stripe_customer.description, "zulip (Zulip Dev)")
@ -894,9 +899,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(audit_log_entries[3][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
self.assertEqual(
orjson.loads(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True)
.first()
assert_is_not_none(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True)
.first()
)
)["automanage_licenses"],
True,
)
@ -1040,7 +1047,7 @@ class StripeTest(StripeTestCase):
self.upgrade(invoice=True)
stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id
assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
)
self.assertEqual(stripe_customer.discount, None)
self.assertEqual(stripe_customer.email, user.delivery_email)
@ -1093,9 +1100,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(audit_log_entries[2][0], RealmAuditLog.REALM_PLAN_TYPE_CHANGED)
self.assertEqual(
orjson.loads(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True)
.first()
assert_is_not_none(
RealmAuditLog.objects.filter(event_type=RealmAuditLog.CUSTOMER_PLAN_CREATED)
.values_list("extra_data", flat=True)
.first()
)
)["automanage_licenses"],
False,
)
@ -1218,7 +1227,7 @@ class StripeTest(StripeTestCase):
self.upgrade()
customer = Customer.objects.first()
assert customer is not None
stripe_customer_id = customer.stripe_customer_id
stripe_customer_id: str = assert_is_not_none(customer.stripe_customer_id)
# Check that the Charge used the old quantity, not new_seat_count
[charge] = stripe.Charge.list(customer=stripe_customer_id)
self.assertEqual(8000 * self.seat_count, charge.amount)
@ -2037,9 +2046,10 @@ class StripeTest(StripeTestCase):
audit_log = RealmAuditLog.objects.get(
event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN
)
extra_data: str = assert_is_not_none(audit_log.extra_data)
self.assertEqual(audit_log.realm, user.realm)
self.assertEqual(orjson.loads(audit_log.extra_data)["monthly_plan_id"], monthly_plan.id)
self.assertEqual(orjson.loads(audit_log.extra_data)["annual_plan_id"], annual_plan.id)
self.assertEqual(orjson.loads(extra_data)["monthly_plan_id"], monthly_plan.id)
self.assertEqual(orjson.loads(extra_data)["annual_plan_id"], annual_plan.id)
invoice_plans_as_needed(self.next_month)
@ -2468,7 +2478,7 @@ class StripeTest(StripeTestCase):
self.assert_json_success(result)
invoice_plans_as_needed(self.next_year)
stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id
assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
)
[invoice, _] = stripe.Invoice.list(customer=stripe_customer.id)
invoice_params = {
@ -2518,7 +2528,7 @@ class StripeTest(StripeTestCase):
self.assert_json_success(result)
invoice_plans_as_needed(self.next_year + timedelta(days=365))
stripe_customer = stripe_get_customer(
Customer.objects.get(realm=user.realm).stripe_customer_id
assert_is_not_none(Customer.objects.get(realm=user.realm).stripe_customer_id)
)
[invoice, _, _] = stripe.Invoice.list(customer=stripe_customer.id)
invoice_params = {
@ -3460,7 +3470,7 @@ class InvoiceTest(StripeTestCase):
plan.invoicing_status = CustomerPlan.STARTED
plan.save(update_fields=["invoicing_status"])
with self.assertRaises(NotImplementedError):
invoice_plan(CustomerPlan.objects.first(), self.now)
invoice_plan(assert_is_not_none(CustomerPlan.objects.first()), self.now)
def test_invoice_plan_without_stripe_customer(self) -> None:
self.local_upgrade(self.seat_count, True, CustomerPlan.ANNUAL)

View File

@ -10,6 +10,7 @@ from django.utils.timezone import now as timezone_now
from sentry_sdk import capture_exception
from zerver.lib.logging_util import log_to_file
from zerver.lib.utils import assert_is_not_none
from zerver.models import (
Message,
Realm,
@ -63,12 +64,14 @@ def filter_by_subscription_history(
# check belongs in this inner loop, not the outer loop.
break
event_last_message_id = assert_is_not_none(log_entry.event_last_message_id)
if log_entry.event_type == RealmAuditLog.SUBSCRIPTION_DEACTIVATED:
# If the event shows the user was unsubscribed after
# event_last_message_id, we know they must have been
# subscribed immediately before the event.
for stream_message in stream_messages:
if stream_message["id"] <= log_entry.event_last_message_id:
if stream_message["id"] <= event_last_message_id:
store_user_message_to_insert(stream_message)
else:
break
@ -78,12 +81,12 @@ def filter_by_subscription_history(
):
initial_msg_count = len(stream_messages)
for i, stream_message in enumerate(stream_messages):
if stream_message["id"] > log_entry.event_last_message_id:
if stream_message["id"] > event_last_message_id:
stream_messages = stream_messages[i:]
break
final_msg_count = len(stream_messages)
if initial_msg_count == final_msg_count:
if stream_messages[-1]["id"] <= log_entry.event_last_message_id:
if stream_messages[-1]["id"] <= event_last_message_id:
stream_messages = []
else:
raise AssertionError(f"{log_entry.event_type} is not a subscription event.")
@ -172,7 +175,7 @@ def add_missing_messages(user_profile: UserProfile) -> None:
all_stream_subscription_logs: DefaultDict[int, List[RealmAuditLog]] = defaultdict(list)
for log in subscription_logs:
all_stream_subscription_logs[log.modified_stream_id].append(log)
all_stream_subscription_logs[assert_is_not_none(log.modified_stream_id)].append(log)
recipient_ids = []
for sub in all_stream_subs:

View File

@ -31,6 +31,7 @@ from PIL.Image import DecompressionBombError
from zerver.lib.avatar_hash import user_avatar_path
from zerver.lib.exceptions import ErrorCode, JsonableError
from zerver.lib.utils import assert_is_not_none
from zerver.models import Attachment, Message, Realm, RealmEmoji, UserProfile
DEFAULT_AVATAR_SIZE = 100
@ -729,7 +730,7 @@ class S3UploadBackend(ZulipUploadBackend):
def write_local_file(type: str, path: str, file_data: bytes) -> None:
file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path)
file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "wb") as f:
@ -737,13 +738,13 @@ def write_local_file(type: str, path: str, file_data: bytes) -> None:
def read_local_file(type: str, path: str) -> bytes:
file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path)
file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
with open(file_path, "rb") as f:
return f.read()
def delete_local_file(type: str, path: str) -> bool:
file_path = os.path.join(settings.LOCAL_UPLOADS_DIR, type, path)
file_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), type, path)
if os.path.isfile(file_path):
# This removes the file but the empty folders still remain.
os.remove(file_path)
@ -754,7 +755,7 @@ def delete_local_file(type: str, path: str) -> bool:
def get_local_file_path(path_id: str) -> Optional[str]:
local_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "files", path_id)
local_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "files", path_id)
if os.path.isfile(local_path):
return local_path
else:
@ -897,12 +898,16 @@ class LocalUploadBackend(ZulipUploadBackend):
file_path = user_avatar_path(user_profile)
output_path = os.path.join(
settings.LOCAL_UPLOADS_DIR, "avatars", file_path + file_extension
assert_is_not_none(settings.LOCAL_UPLOADS_DIR),
"avatars",
file_path + file_extension,
)
if os.path.isfile(output_path):
return
image_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", file_path + ".original")
image_path = os.path.join(
assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "avatars", file_path + ".original"
)
with open(image_path, "rb") as f:
image_data = f.read()
if is_medium:
@ -942,7 +947,7 @@ class LocalUploadBackend(ZulipUploadBackend):
secrets.token_urlsafe(18),
os.path.basename(tarball_path),
)
abs_path = os.path.join(settings.LOCAL_UPLOADS_DIR, "avatars", path)
abs_path = os.path.join(assert_is_not_none(settings.LOCAL_UPLOADS_DIR), "avatars", path)
os.makedirs(os.path.dirname(abs_path), exist_ok=True)
shutil.copy(tarball_path, abs_path)
public_url = realm.uri + "/user_avatars/" + path

View File

@ -3,6 +3,7 @@ from typing import Dict, List, Optional, Set
from zerver.lib.cache import cache_with_key, get_muting_users_cache_key
from zerver.lib.timestamp import datetime_to_timestamp
from zerver.lib.utils import assert_is_not_none
from zerver.models import MutedUser, UserProfile
@ -14,7 +15,7 @@ def get_user_mutes(user_profile: UserProfile) -> List[Dict[str, int]]:
return [
{
"id": row["muted_user_id"],
"timestamp": datetime_to_timestamp(row["date_muted"]),
"timestamp": datetime_to_timestamp(assert_is_not_none(row["date_muted"])),
}
for row in rows
]

View File

@ -13,6 +13,7 @@ from zerver.lib.exceptions import JsonableError
from zerver.lib.export import get_realm_exports_serialized
from zerver.lib.queue import queue_json_publish
from zerver.lib.response import json_success
from zerver.lib.utils import assert_is_not_none
from zerver.models import RealmAuditLog, UserProfile
@ -90,7 +91,7 @@ def delete_realm_export(request: HttpRequest, user: UserProfile, export_id: int)
except RealmAuditLog.DoesNotExist:
raise JsonableError(_("Invalid data export ID"))
export_data = orjson.loads(audit_log_entry.extra_data)
export_data = orjson.loads(assert_is_not_none(audit_log_entry.extra_data))
if "deleted_timestamp" in export_data:
raise JsonableError(_("Export already deleted"))
do_delete_realm_export(user, audit_log_entry)

View File

@ -71,6 +71,7 @@ from zerver.lib.topic import (
messages_for_topic,
)
from zerver.lib.types import Validator
from zerver.lib.utils import assert_is_not_none
from zerver.lib.validator import (
check_bool,
check_capped_string,
@ -726,7 +727,9 @@ def get_topics_backend(
if is_web_public_query:
realm = get_valid_realm_from_request(request)
stream = access_web_public_stream(stream_id, realm)
result = get_topic_history_for_public_stream(recipient_id=stream.recipient_id)
result = get_topic_history_for_public_stream(
recipient_id=assert_is_not_none(stream.recipient_id)
)
else:
assert user_profile is not None
@ -753,7 +756,7 @@ def delete_in_topic(
) -> HttpResponse:
(stream, sub) = access_stream_by_id(user_profile, stream_id)
messages = messages_for_topic(stream.recipient_id, topic_name)
messages = messages_for_topic(assert_is_not_none(stream.recipient_id), topic_name)
if not stream.is_history_public_to_subscribers():
# Don't allow the user to delete messages that they don't have access to.
deletable_message_ids = UserMessage.objects.filter(