ruff: Fix SIM117 Use a single `with` statement with multiple contexts.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2024-07-14 11:30:42 -07:00 committed by Tim Abbott
parent b0f144327d
commit b96feb34f6
47 changed files with 1380 additions and 1141 deletions

View File

@ -560,12 +560,15 @@ class RemoteBillingAuthenticationTest(RemoteRealmBillingTestCase):
) )
# Try the case where the identity dict is simultaneously expired. # Try the case where the identity dict is simultaneously expired.
with time_machine.travel( with (
now + timedelta(seconds=REMOTE_BILLING_SESSION_VALIDITY_SECONDS + 30), time_machine.travel(
tick=False, now + timedelta(seconds=REMOTE_BILLING_SESSION_VALIDITY_SECONDS + 30),
tick=False,
),
self.assertLogs("django.request", "ERROR") as m,
self.assertRaises(AssertionError),
): ):
with self.assertLogs("django.request", "ERROR") as m, self.assertRaises(AssertionError): self.client_get(final_url, subdomain="selfhosting")
self.client_get(final_url, subdomain="selfhosting")
# The django.request log should be a traceback, mentioning the relevant # The django.request log should be a traceback, mentioning the relevant
# exceptions that occurred. # exceptions that occurred.
self.assertIn( self.assertIn(

View File

@ -1415,9 +1415,11 @@ class StripeTest(StripeTestCase):
self.assertFalse(Customer.objects.filter(realm=user.realm).exists()) self.assertFalse(Customer.objects.filter(realm=user.realm).exists())
# Require free trial users to add a credit card. # Require free trial users to add a credit card.
with time_machine.travel(self.now, tick=False): with (
with self.assertLogs("corporate.stripe", "WARNING"): time_machine.travel(self.now, tick=False),
response = self.upgrade() self.assertLogs("corporate.stripe", "WARNING"),
):
response = self.upgrade()
self.assert_json_error( self.assert_json_error(
response, "Please add a credit card before starting your free trial." response, "Please add a credit card before starting your free trial."
) )
@ -1953,12 +1955,14 @@ class StripeTest(StripeTestCase):
initial_upgrade_request initial_upgrade_request
) )
# Change the seat count while the user is going through the upgrade flow # Change the seat count while the user is going through the upgrade flow
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=new_seat_count): with (
with patch( patch("corporate.lib.stripe.get_latest_seat_count", return_value=new_seat_count),
patch(
"corporate.lib.stripe.RealmBillingSession.get_initial_upgrade_context", "corporate.lib.stripe.RealmBillingSession.get_initial_upgrade_context",
return_value=(_, context_when_upgrade_page_is_rendered), return_value=(_, context_when_upgrade_page_is_rendered),
): ),
self.add_card_and_upgrade(hamlet) ):
self.add_card_and_upgrade(hamlet)
customer = Customer.objects.first() customer = Customer.objects.first()
assert customer is not None assert customer is not None
@ -2072,11 +2076,13 @@ class StripeTest(StripeTestCase):
hamlet = self.example_user("hamlet") hamlet = self.example_user("hamlet")
self.login_user(hamlet) self.login_user(hamlet)
self.local_upgrade(self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False) self.local_upgrade(self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False)
with self.assertLogs("corporate.stripe", "WARNING") as m: with (
with self.assertRaises(BillingError) as context: self.assertLogs("corporate.stripe", "WARNING") as m,
self.local_upgrade( self.assertRaises(BillingError) as context,
self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False ):
) self.local_upgrade(
self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False
)
self.assertEqual( self.assertEqual(
"subscribing with existing subscription", context.exception.error_description "subscribing with existing subscription", context.exception.error_description
) )
@ -2197,14 +2203,16 @@ class StripeTest(StripeTestCase):
else: else:
del_args = [] del_args = []
upgrade_params["licenses"] = licenses upgrade_params["licenses"] = licenses
with patch("corporate.lib.stripe.BillingSession.process_initial_upgrade"): with (
with patch( patch("corporate.lib.stripe.BillingSession.process_initial_upgrade"),
patch(
"corporate.lib.stripe.BillingSession.create_stripe_invoice_and_charge", "corporate.lib.stripe.BillingSession.create_stripe_invoice_and_charge",
return_value="fake_stripe_invoice_id", return_value="fake_stripe_invoice_id",
): ),
response = self.upgrade( ):
invoice=invoice, talk_to_stripe=False, del_args=del_args, **upgrade_params response = self.upgrade(
) invoice=invoice, talk_to_stripe=False, del_args=del_args, **upgrade_params
)
self.assert_json_success(response) self.assert_json_success(response)
# Autopay with licenses < seat count # Autopay with licenses < seat count
@ -2911,18 +2919,20 @@ class StripeTest(StripeTestCase):
assert plan is not None assert plan is not None
self.assertEqual(plan.licenses(), self.seat_count) self.assertEqual(plan.licenses(), self.seat_count)
self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count)
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, response = self.client_billing_patch(
) "/billing/plan",
stripe_customer_id = Customer.objects.get(realm=user.realm).id {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE},
new_plan = get_current_plan_by_realm(user.realm) )
assert new_plan is not None stripe_customer_id = Customer.objects.get(realm=user.realm).id
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" new_plan = get_current_plan_by_realm(user.realm)
self.assertEqual(m.output[0], expected_log) assert new_plan is not None
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
plan.refresh_from_db() plan.refresh_from_db()
self.assertEqual(plan.licenses(), self.seat_count) self.assertEqual(plan.licenses(), self.seat_count)
self.assertEqual(plan.licenses_at_next_renewal(), None) self.assertEqual(plan.licenses_at_next_renewal(), None)
@ -3034,15 +3044,17 @@ class StripeTest(StripeTestCase):
new_plan = get_current_plan_by_realm(user.realm) new_plan = get_current_plan_by_realm(user.realm)
assert new_plan is not None assert new_plan is not None
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, response = self.client_billing_patch(
) "/billing/plan",
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}" {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE},
self.assertEqual(m.output[0], expected_log) )
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
monthly_plan.refresh_from_db() monthly_plan.refresh_from_db()
self.assertEqual(monthly_plan.status, CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE) self.assertEqual(monthly_plan.status, CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE)
with time_machine.travel(self.now, tick=False): with time_machine.travel(self.now, tick=False):
@ -3062,9 +3074,11 @@ class StripeTest(StripeTestCase):
(20, 20), (20, 20),
) )
with time_machine.travel(self.next_month, tick=False): with (
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25): time_machine.travel(self.next_month, tick=False),
billing_session.update_license_ledger_if_needed(self.next_month) patch("corporate.lib.stripe.get_latest_seat_count", return_value=25),
):
billing_session.update_license_ledger_if_needed(self.next_month)
self.assertEqual(LicenseLedger.objects.filter(plan=monthly_plan).count(), 2) self.assertEqual(LicenseLedger.objects.filter(plan=monthly_plan).count(), 2)
customer = get_customer_by_realm(user.realm) customer = get_customer_by_realm(user.realm)
assert customer is not None assert customer is not None
@ -3230,17 +3244,19 @@ class StripeTest(StripeTestCase):
stripe_customer_id = Customer.objects.get(realm=user.realm).id stripe_customer_id = Customer.objects.get(realm=user.realm).id
new_plan = get_current_plan_by_realm(user.realm) new_plan = get_current_plan_by_realm(user.realm)
assert new_plan is not None assert new_plan is not None
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, response = self.client_billing_patch(
) "/billing/plan",
self.assertEqual( {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE},
m.output[0], )
f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}", self.assertEqual(
) m.output[0],
self.assert_json_success(response) f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}",
)
self.assert_json_success(response)
monthly_plan.refresh_from_db() monthly_plan.refresh_from_db()
self.assertEqual(monthly_plan.status, CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE) self.assertEqual(monthly_plan.status, CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE)
with time_machine.travel(self.now, tick=False): with time_machine.travel(self.now, tick=False):
@ -3343,15 +3359,17 @@ class StripeTest(StripeTestCase):
assert new_plan is not None assert new_plan is not None
assert self.now is not None assert self.now is not None
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE}, response = self.client_billing_patch(
) "/billing/plan",
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE}" {"status": CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE},
self.assertEqual(m.output[0], expected_log) )
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
annual_plan.refresh_from_db() annual_plan.refresh_from_db()
self.assertEqual(annual_plan.status, CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE) self.assertEqual(annual_plan.status, CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE)
with time_machine.travel(self.now, tick=False): with time_machine.travel(self.now, tick=False):
@ -3375,9 +3393,11 @@ class StripeTest(StripeTestCase):
# additional licenses) but at the end of current billing cycle. # additional licenses) but at the end of current billing cycle.
self.assertEqual(annual_plan.next_invoice_date, self.next_month) self.assertEqual(annual_plan.next_invoice_date, self.next_month)
assert annual_plan.next_invoice_date is not None assert annual_plan.next_invoice_date is not None
with time_machine.travel(annual_plan.next_invoice_date, tick=False): with (
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25): time_machine.travel(annual_plan.next_invoice_date, tick=False),
billing_session.update_license_ledger_if_needed(annual_plan.next_invoice_date) patch("corporate.lib.stripe.get_latest_seat_count", return_value=25),
):
billing_session.update_license_ledger_if_needed(annual_plan.next_invoice_date)
annual_plan.refresh_from_db() annual_plan.refresh_from_db()
self.assertEqual(annual_plan.status, CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE) self.assertEqual(annual_plan.status, CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE)
@ -3430,9 +3450,11 @@ class StripeTest(StripeTestCase):
self.assertEqual(invoice_item2[key], value) self.assertEqual(invoice_item2[key], value)
# Check that we switch to monthly plan at the end of current billing cycle. # Check that we switch to monthly plan at the end of current billing cycle.
with time_machine.travel(self.next_year, tick=False): with (
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25): time_machine.travel(self.next_year, tick=False),
billing_session.update_license_ledger_if_needed(self.next_year) patch("corporate.lib.stripe.get_latest_seat_count", return_value=25),
):
billing_session.update_license_ledger_if_needed(self.next_year)
self.assertEqual(LicenseLedger.objects.filter(plan=annual_plan).count(), 3) self.assertEqual(LicenseLedger.objects.filter(plan=annual_plan).count(), 3)
customer = get_customer_by_realm(user.realm) customer = get_customer_by_realm(user.realm)
assert customer is not None assert customer is not None
@ -3513,30 +3535,34 @@ class StripeTest(StripeTestCase):
self.local_upgrade( self.local_upgrade(
self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False
) )
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, response = self.client_billing_patch(
) "/billing/plan",
stripe_customer_id = Customer.objects.get(realm=user.realm).id {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE},
new_plan = get_current_plan_by_realm(user.realm) )
assert new_plan is not None stripe_customer_id = Customer.objects.get(realm=user.realm).id
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" new_plan = get_current_plan_by_realm(user.realm)
self.assertEqual(m.output[0], expected_log) assert new_plan is not None
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
plan = CustomerPlan.objects.first() plan = CustomerPlan.objects.first()
assert plan is not None assert plan is not None
self.assertEqual(plan.status, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE) self.assertEqual(plan.status, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE)
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.ACTIVE}, response = self.client_billing_patch(
) "/billing/plan",
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.ACTIVE}" {"status": CustomerPlan.ACTIVE},
self.assertEqual(m.output[0], expected_log) )
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.ACTIVE}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
plan = CustomerPlan.objects.first() plan = CustomerPlan.objects.first()
assert plan is not None assert plan is not None
self.assertEqual(plan.status, CustomerPlan.ACTIVE) self.assertEqual(plan.status, CustomerPlan.ACTIVE)
@ -3587,55 +3613,54 @@ class StripeTest(StripeTestCase):
self.login_user(user) self.login_user(user)
free_trial_end_date = self.now + timedelta(days=60) free_trial_end_date = self.now + timedelta(days=60)
with self.settings(CLOUD_FREE_TRIAL_DAYS=60): with self.settings(CLOUD_FREE_TRIAL_DAYS=60), time_machine.travel(self.now, tick=False):
with time_machine.travel(self.now, tick=False): self.add_card_and_upgrade(user, schedule="monthly")
self.add_card_and_upgrade(user, schedule="monthly") plan = CustomerPlan.objects.get()
plan = CustomerPlan.objects.get() self.assertEqual(plan.next_invoice_date, free_trial_end_date)
self.assertEqual(plan.next_invoice_date, free_trial_end_date) self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD)
self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL)
self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL)
customer = get_customer_by_realm(user.realm) customer = get_customer_by_realm(user.realm)
assert customer is not None assert customer is not None
result = self.client_billing_patch( result = self.client_billing_patch(
"/billing/plan", "/billing/plan",
{ {
"status": CustomerPlan.FREE_TRIAL, "status": CustomerPlan.FREE_TRIAL,
"schedule": CustomerPlan.BILLING_SCHEDULE_ANNUAL, "schedule": CustomerPlan.BILLING_SCHEDULE_ANNUAL,
}, },
) )
self.assert_json_success(result) self.assert_json_success(result)
plan.refresh_from_db() plan.refresh_from_db()
self.assertEqual(plan.status, CustomerPlan.ENDED) self.assertEqual(plan.status, CustomerPlan.ENDED)
self.assertIsNone(plan.next_invoice_date) self.assertIsNone(plan.next_invoice_date)
new_plan = CustomerPlan.objects.get( new_plan = CustomerPlan.objects.get(
customer=customer, customer=customer,
automanage_licenses=True, automanage_licenses=True,
price_per_license=8000, price_per_license=8000,
fixed_price=None, fixed_price=None,
discount=None, discount=None,
billing_cycle_anchor=self.now, billing_cycle_anchor=self.now,
billing_schedule=CustomerPlan.BILLING_SCHEDULE_ANNUAL, billing_schedule=CustomerPlan.BILLING_SCHEDULE_ANNUAL,
next_invoice_date=free_trial_end_date, next_invoice_date=free_trial_end_date,
tier=CustomerPlan.TIER_CLOUD_STANDARD, tier=CustomerPlan.TIER_CLOUD_STANDARD,
status=CustomerPlan.FREE_TRIAL, status=CustomerPlan.FREE_TRIAL,
charge_automatically=True, charge_automatically=True,
) )
ledger_entry = LicenseLedger.objects.get( ledger_entry = LicenseLedger.objects.get(
plan=new_plan, plan=new_plan,
is_renewal=True, is_renewal=True,
event_time=self.now, event_time=self.now,
licenses=self.seat_count, licenses=self.seat_count,
licenses_at_next_renewal=self.seat_count, licenses_at_next_renewal=self.seat_count,
) )
self.assertEqual(new_plan.invoiced_through, ledger_entry) self.assertEqual(new_plan.invoiced_through, ledger_entry)
realm_audit_log = RealmAuditLog.objects.filter( realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN
).last() ).last()
assert realm_audit_log is not None assert realm_audit_log is not None
@mock_stripe() @mock_stripe()
def test_switch_now_free_trial_from_annual_to_monthly(self, *mocks: Mock) -> None: def test_switch_now_free_trial_from_annual_to_monthly(self, *mocks: Mock) -> None:
@ -3643,54 +3668,53 @@ class StripeTest(StripeTestCase):
self.login_user(user) self.login_user(user)
free_trial_end_date = self.now + timedelta(days=60) free_trial_end_date = self.now + timedelta(days=60)
with self.settings(CLOUD_FREE_TRIAL_DAYS=60): with self.settings(CLOUD_FREE_TRIAL_DAYS=60), time_machine.travel(self.now, tick=False):
with time_machine.travel(self.now, tick=False): self.add_card_and_upgrade(user, schedule="annual")
self.add_card_and_upgrade(user, schedule="annual") plan = CustomerPlan.objects.get()
plan = CustomerPlan.objects.get() self.assertEqual(plan.next_invoice_date, free_trial_end_date)
self.assertEqual(plan.next_invoice_date, free_trial_end_date) self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD)
self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL)
self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL)
customer = get_customer_by_realm(user.realm) customer = get_customer_by_realm(user.realm)
assert customer is not None assert customer is not None
result = self.client_billing_patch( result = self.client_billing_patch(
"/billing/plan", "/billing/plan",
{ {
"status": CustomerPlan.FREE_TRIAL, "status": CustomerPlan.FREE_TRIAL,
"schedule": CustomerPlan.BILLING_SCHEDULE_MONTHLY, "schedule": CustomerPlan.BILLING_SCHEDULE_MONTHLY,
}, },
) )
self.assert_json_success(result) self.assert_json_success(result)
plan.refresh_from_db() plan.refresh_from_db()
self.assertEqual(plan.status, CustomerPlan.ENDED) self.assertEqual(plan.status, CustomerPlan.ENDED)
self.assertIsNone(plan.next_invoice_date) self.assertIsNone(plan.next_invoice_date)
new_plan = CustomerPlan.objects.get( new_plan = CustomerPlan.objects.get(
customer=customer, customer=customer,
automanage_licenses=True, automanage_licenses=True,
price_per_license=800, price_per_license=800,
fixed_price=None, fixed_price=None,
discount=None, discount=None,
billing_cycle_anchor=self.now, billing_cycle_anchor=self.now,
billing_schedule=CustomerPlan.BILLING_SCHEDULE_MONTHLY, billing_schedule=CustomerPlan.BILLING_SCHEDULE_MONTHLY,
next_invoice_date=free_trial_end_date, next_invoice_date=free_trial_end_date,
tier=CustomerPlan.TIER_CLOUD_STANDARD, tier=CustomerPlan.TIER_CLOUD_STANDARD,
status=CustomerPlan.FREE_TRIAL, status=CustomerPlan.FREE_TRIAL,
charge_automatically=True, charge_automatically=True,
) )
ledger_entry = LicenseLedger.objects.get( ledger_entry = LicenseLedger.objects.get(
plan=new_plan, plan=new_plan,
is_renewal=True, is_renewal=True,
event_time=self.now, event_time=self.now,
licenses=self.seat_count, licenses=self.seat_count,
licenses_at_next_renewal=self.seat_count, licenses_at_next_renewal=self.seat_count,
) )
self.assertEqual(new_plan.invoiced_through, ledger_entry) self.assertEqual(new_plan.invoiced_through, ledger_entry)
realm_audit_log = RealmAuditLog.objects.filter( realm_audit_log = RealmAuditLog.objects.filter(
event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_ANNUAL_TO_MONTHLY_PLAN event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_ANNUAL_TO_MONTHLY_PLAN
).last() ).last()
assert realm_audit_log is not None assert realm_audit_log is not None
@mock_stripe() @mock_stripe()
def test_end_free_trial(self, *mocks: Mock) -> None: def test_end_free_trial(self, *mocks: Mock) -> None:
@ -3764,18 +3788,20 @@ class StripeTest(StripeTestCase):
self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count)
# Schedule downgrade # Schedule downgrade
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}, response = self.client_billing_patch(
) "/billing/plan",
stripe_customer_id = Customer.objects.get(realm=user.realm).id {"status": CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL},
new_plan = get_current_plan_by_realm(user.realm) )
assert new_plan is not None stripe_customer_id = Customer.objects.get(realm=user.realm).id
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}" new_plan = get_current_plan_by_realm(user.realm)
self.assertEqual(m.output[0], expected_log) assert new_plan is not None
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
plan.refresh_from_db() plan.refresh_from_db()
self.assertEqual(plan.next_invoice_date, free_trial_end_date) self.assertEqual(plan.next_invoice_date, free_trial_end_date)
self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD)
@ -3874,18 +3900,20 @@ class StripeTest(StripeTestCase):
self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count)
# Schedule downgrade # Schedule downgrade
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}, response = self.client_billing_patch(
) "/billing/plan",
stripe_customer_id = Customer.objects.get(realm=user.realm).id {"status": CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL},
new_plan = get_current_plan_by_realm(user.realm) )
assert new_plan is not None stripe_customer_id = Customer.objects.get(realm=user.realm).id
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}" new_plan = get_current_plan_by_realm(user.realm)
self.assertEqual(m.output[0], expected_log) assert new_plan is not None
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
plan.refresh_from_db() plan.refresh_from_db()
self.assertEqual(plan.next_invoice_date, free_trial_end_date) self.assertEqual(plan.next_invoice_date, free_trial_end_date)
self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD)
@ -3894,18 +3922,20 @@ class StripeTest(StripeTestCase):
self.assertEqual(plan.licenses_at_next_renewal(), None) self.assertEqual(plan.licenses_at_next_renewal(), None)
# Cancel downgrade # Cancel downgrade
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.FREE_TRIAL}, response = self.client_billing_patch(
) "/billing/plan",
stripe_customer_id = Customer.objects.get(realm=user.realm).id {"status": CustomerPlan.FREE_TRIAL},
new_plan = get_current_plan_by_realm(user.realm) )
assert new_plan is not None stripe_customer_id = Customer.objects.get(realm=user.realm).id
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.FREE_TRIAL}" new_plan = get_current_plan_by_realm(user.realm)
self.assertEqual(m.output[0], expected_log) assert new_plan is not None
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.FREE_TRIAL}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
plan.refresh_from_db() plan.refresh_from_db()
self.assertEqual(plan.next_invoice_date, free_trial_end_date) self.assertEqual(plan.next_invoice_date, free_trial_end_date)
self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD)
@ -3937,11 +3967,11 @@ class StripeTest(StripeTestCase):
with ( with (
self.assertRaises(BillingError) as context, self.assertRaises(BillingError) as context,
self.assertLogs("corporate.stripe", "WARNING") as m, self.assertLogs("corporate.stripe", "WARNING") as m,
time_machine.travel(self.now, tick=False),
): ):
with time_machine.travel(self.now, tick=False): self.local_upgrade(
self.local_upgrade( self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False
self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False )
)
self.assertEqual( self.assertEqual(
m.output[0], m.output[0],
"WARNING:corporate.stripe:Upgrade of <Realm: zulip 2> (with stripe_customer_id: cus_123) failed because of existing active plan.", "WARNING:corporate.stripe:Upgrade of <Realm: zulip 2> (with stripe_customer_id: cus_123) failed because of existing active plan.",
@ -4242,17 +4272,19 @@ class StripeTest(StripeTestCase):
) )
self.login_user(self.example_user("hamlet")) self.login_user(self.example_user("hamlet"))
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
result = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, result = self.client_billing_patch(
) "/billing/plan",
self.assert_json_success(result) {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE},
self.assertRegex( )
m.output[0], self.assert_json_success(result)
r"INFO:corporate.stripe:Change plan status: Customer.id: \d*, CustomerPlan.id: \d*, status: 2", self.assertRegex(
) m.output[0],
r"INFO:corporate.stripe:Change plan status: Customer.id: \d*, CustomerPlan.id: \d*, status: 2",
)
with time_machine.travel(self.next_year, tick=False): with time_machine.travel(self.next_year, tick=False):
result = self.client_billing_patch( result = self.client_billing_patch(
@ -4270,17 +4302,19 @@ class StripeTest(StripeTestCase):
) )
self.login_user(self.example_user("hamlet")) self.login_user(self.example_user("hamlet"))
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now, tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
result = self.client_billing_patch( time_machine.travel(self.now, tick=False),
"/billing/plan", ):
{"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, result = self.client_billing_patch(
) "/billing/plan",
self.assert_json_success(result) {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE},
self.assertRegex( )
m.output[0], self.assert_json_success(result)
r"INFO:corporate.stripe:Change plan status: Customer.id: \d*, CustomerPlan.id: \d*, status: 4", self.assertRegex(
) m.output[0],
r"INFO:corporate.stripe:Change plan status: Customer.id: \d*, CustomerPlan.id: \d*, status: 4",
)
with time_machine.travel(self.next_month, tick=False): with time_machine.travel(self.next_month, tick=False):
result = self.client_billing_patch("/billing/plan", {}) result = self.client_billing_patch("/billing/plan", {})
@ -5602,11 +5636,13 @@ class LicenseLedgerTest(StripeTestCase):
self.assertEqual(plan.licenses(), self.seat_count + 3) self.assertEqual(plan.licenses(), self.seat_count + 3)
self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count + 3) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count + 3)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): with (
with self.assertRaises(AssertionError): patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count),
billing_session.update_license_ledger_for_manual_plan( self.assertRaises(AssertionError),
plan, self.now, licenses=self.seat_count ):
) billing_session.update_license_ledger_for_manual_plan(
plan, self.now, licenses=self.seat_count
)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count):
billing_session.update_license_ledger_for_manual_plan( billing_session.update_license_ledger_for_manual_plan(
@ -5615,11 +5651,13 @@ class LicenseLedgerTest(StripeTestCase):
self.assertEqual(plan.licenses(), self.seat_count + 3) self.assertEqual(plan.licenses(), self.seat_count + 3)
self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): with (
with self.assertRaises(AssertionError): patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count),
billing_session.update_license_ledger_for_manual_plan( self.assertRaises(AssertionError),
plan, self.now, licenses_at_next_renewal=self.seat_count - 1 ):
) billing_session.update_license_ledger_for_manual_plan(
plan, self.now, licenses_at_next_renewal=self.seat_count - 1
)
with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count):
billing_session.update_license_ledger_for_manual_plan( billing_session.update_license_ledger_for_manual_plan(
@ -6614,11 +6652,13 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase):
# Same result even with free trial enabled for self hosted customers since we don't # Same result even with free trial enabled for self hosted customers since we don't
# offer free trial for business plan. # offer free trial for business plan.
with self.settings(SELF_HOSTING_FREE_TRIAL_DAYS=30): with (
with time_machine.travel(self.now, tick=False): self.settings(SELF_HOSTING_FREE_TRIAL_DAYS=30),
result = self.client_get( time_machine.travel(self.now, tick=False),
f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting" ):
) result = self.client_get(
f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting"
)
self.assert_in_success_response( self.assert_in_success_response(
[ [
@ -6631,11 +6671,10 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase):
) )
# Check that cloud free trials don't affect self hosted customers. # Check that cloud free trials don't affect self hosted customers.
with self.settings(CLOUD_FREE_TRIAL_DAYS=30): with self.settings(CLOUD_FREE_TRIAL_DAYS=30), time_machine.travel(self.now, tick=False):
with time_machine.travel(self.now, tick=False): result = self.client_get(
result = self.client_get( f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting"
f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting" )
)
self.assert_in_success_response( self.assert_in_success_response(
[ [
@ -8018,15 +8057,17 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase):
self.assertEqual(result["Location"], f"{billing_base_url}/billing/") self.assertEqual(result["Location"], f"{billing_base_url}/billing/")
# Downgrade # Downgrade
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now + timedelta(days=7), tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now + timedelta(days=7), tick=False),
"/billing/plan", ):
{"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, response = self.client_billing_patch(
) "/billing/plan",
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {business_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE},
self.assertEqual(m.output[0], expected_log) )
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {business_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
business_plan.refresh_from_db() business_plan.refresh_from_db()
self.assertEqual(business_plan.licenses_at_next_renewal(), None) self.assertEqual(business_plan.licenses_at_next_renewal(), None)
@ -8323,9 +8364,11 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase):
# Same result even with free trial enabled for self hosted customers since we don't # Same result even with free trial enabled for self hosted customers since we don't
# offer free trial for business plan. # offer free trial for business plan.
with self.settings(SELF_HOSTING_FREE_TRIAL_DAYS=30): with (
with time_machine.travel(self.now, tick=False): self.settings(SELF_HOSTING_FREE_TRIAL_DAYS=30),
result = self.client_get(f"{billing_base_url}/upgrade/", subdomain="selfhosting") time_machine.travel(self.now, tick=False),
):
result = self.client_get(f"{billing_base_url}/upgrade/", subdomain="selfhosting")
self.assert_in_success_response(["Add card", "Purchase Zulip Business"], result) self.assert_in_success_response(["Add card", "Purchase Zulip Business"], result)
@ -8390,18 +8433,20 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase):
self.assertEqual(result["Location"], f"{billing_base_url}/billing/") self.assertEqual(result["Location"], f"{billing_base_url}/billing/")
# Downgrade # Downgrade
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now + timedelta(days=7), tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now + timedelta(days=7), tick=False),
"/billing/plan", ):
{"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, response = self.client_billing_patch(
) "/billing/plan",
customer = Customer.objects.get(remote_server=self.remote_server) {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE},
new_plan = get_current_plan_by_customer(customer) )
assert new_plan is not None customer = Customer.objects.get(remote_server=self.remote_server)
expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" new_plan = get_current_plan_by_customer(customer)
self.assertEqual(m.output[0], expected_log) assert new_plan is not None
self.assert_json_success(response) expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}"
self.assertEqual(m.output[0], expected_log)
self.assert_json_success(response)
self.assertEqual(new_plan.licenses_at_next_renewal(), None) self.assertEqual(new_plan.licenses_at_next_renewal(), None)
@responses.activate @responses.activate
@ -8599,21 +8644,23 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase):
self.assertEqual(result["Location"], f"{billing_base_url}/billing/") self.assertEqual(result["Location"], f"{billing_base_url}/billing/")
# Downgrade # Downgrade
with self.assertLogs("corporate.stripe", "INFO") as m: with (
with time_machine.travel(self.now + timedelta(days=7), tick=False): self.assertLogs("corporate.stripe", "INFO") as m,
response = self.client_billing_patch( time_machine.travel(self.now + timedelta(days=7), tick=False),
"/billing/plan", ):
{"status": CustomerPlan.ACTIVE}, response = self.client_billing_patch(
) "/billing/plan",
self.assert_json_success(response) {"status": CustomerPlan.ACTIVE},
self.assertEqual( )
m.output[0], self.assert_json_success(response)
f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {new_customer_plan.id}, status: {CustomerPlan.ENDED}", self.assertEqual(
) m.output[0],
self.assertEqual( f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {new_customer_plan.id}, status: {CustomerPlan.ENDED}",
m.output[1], )
f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {customer_plan.id}, status: {CustomerPlan.ACTIVE}", self.assertEqual(
) m.output[1],
f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {customer_plan.id}, status: {CustomerPlan.ACTIVE}",
)
@responses.activate @responses.activate
@mock_stripe() @mock_stripe()

View File

@ -186,7 +186,6 @@ ignore = [
"SIM103", # Return the condition directly "SIM103", # Return the condition directly
"SIM108", # Use ternary operator `action = "[commented]" if action == "created" else f"{action} a [comment]"` instead of if-else-block "SIM108", # Use ternary operator `action = "[commented]" if action == "created" else f"{action} a [comment]"` instead of if-else-block
"SIM114", # Combine `if` branches using logical `or` operator "SIM114", # Combine `if` branches using logical `or` operator
"SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements
"SIM401", # Use `d.get(key, default)` instead of an `if` block "SIM401", # Use `d.get(key, default)` instead of an `if` block
"TCH001", # Move application import into a type-checking block "TCH001", # Move application import into a type-checking block
"TCH002", # Move third-party import into a type-checking block "TCH002", # Move third-party import into a type-checking block

View File

@ -66,9 +66,11 @@ def main() -> None:
args = parser.parse_args() args = parser.parse_args()
sns_topic_arn = get_ses_arn(session, args) sns_topic_arn = get_ses_arn(session, args)
with our_sqs_queue(session, sns_topic_arn) as (queue_arn, queue_url): with (
with our_sns_subscription(session, sns_topic_arn, queue_arn): our_sqs_queue(session, sns_topic_arn) as (queue_arn, queue_url),
print_messages(session, queue_url) our_sns_subscription(session, sns_topic_arn, queue_arn),
):
print_messages(session, queue_url)
def get_ses_arn(session: boto3.session.Session, args: argparse.Namespace) -> str: def get_ses_arn(session: boto3.session.Session, args: argparse.Namespace) -> str:

View File

@ -168,11 +168,13 @@ def get_failed_tests() -> list[str]:
def block_internet() -> Iterator[None]: def block_internet() -> Iterator[None]:
# Monkey-patching - responses library raises requests.ConnectionError when access to an unregistered URL # Monkey-patching - responses library raises requests.ConnectionError when access to an unregistered URL
# is attempted. We want to replace that with our own exception, so that it propagates all the way: # is attempted. We want to replace that with our own exception, so that it propagates all the way:
with mock.patch.object(responses, "ConnectionError", new=ZulipInternetBlockedError): with (
mock.patch.object(responses, "ConnectionError", new=ZulipInternetBlockedError),
# We'll run all tests in this context manager. It'll cause an error to be raised (see above comment), # We'll run all tests in this context manager. It'll cause an error to be raised (see above comment),
# if any code attempts to access the internet. # if any code attempts to access the internet.
with responses.RequestsMock(): responses.RequestsMock(),
yield ):
yield
class ZulipInternetBlockedError(Exception): class ZulipInternetBlockedError(Exception):

View File

@ -19,14 +19,13 @@ CACHE_FILE = os.path.join(CACHE_DIR, "requirements_hashes")
def print_diff(path_file1: str, path_file2: str) -> None: def print_diff(path_file1: str, path_file2: str) -> None:
with open(path_file1) as file1: with open(path_file1) as file1, open(path_file2) as file2:
with open(path_file2) as file2: diff = difflib.unified_diff(
diff = difflib.unified_diff( file1.readlines(),
file1.readlines(), file2.readlines(),
file2.readlines(), fromfile=path_file1,
fromfile=path_file1, tofile=path_file2,
tofile=path_file2, )
)
sys.stdout.writelines(diff) sys.stdout.writelines(diff)

View File

@ -1347,10 +1347,12 @@ def fetch_team_icons(
) )
resized_icon_output_path = os.path.join(output_dir, str(realm_id), "icon.png") resized_icon_output_path = os.path.join(output_dir, str(realm_id), "icon.png")
with open(resized_icon_output_path, "wb") as output_file: with (
with open(original_icon_output_path, "rb") as original_file: open(resized_icon_output_path, "wb") as output_file,
resized_data = resize_logo(original_file.read()) open(original_icon_output_path, "rb") as original_file,
output_file.write(resized_data) ):
resized_data = resize_logo(original_file.read())
output_file.write(resized_data)
records.append( records.append(
{ {
"realm_id": realm_id, "realm_id": realm_id,

View File

@ -28,9 +28,8 @@ def lockfile(filename: str, shared: bool = False) -> Iterator[None]:
If shared is True, use a LOCK_SH lock, otherwise LOCK_EX. If shared is True, use a LOCK_SH lock, otherwise LOCK_EX.
The file is given by name and will be created if it does not exist.""" The file is given by name and will be created if it does not exist."""
with open(filename, "w") as lock: with open(filename, "w") as lock, flock(lock, shared=shared):
with flock(lock, shared=shared): yield
yield
@contextmanager @contextmanager

View File

@ -548,15 +548,17 @@ def custom_email_sender(
rendered_input = render_markdown_path(plain_text_template_path.replace("templates/", "")) rendered_input = render_markdown_path(plain_text_template_path.replace("templates/", ""))
# And then extend it with our standard email headers. # And then extend it with our standard email headers.
with open(html_template_path, "w") as f: with (
with open(markdown_email_base_template_path) as base_template: open(html_template_path, "w") as f,
# We use an ugly string substitution here, because we want to: open(markdown_email_base_template_path) as base_template,
# 1. Only run Jinja once on the supplied content ):
# 2. Allow the supplied content to have jinja interpolation in it # We use an ugly string substitution here, because we want to:
# 3. Have that interpolation happen in the context of # 1. Only run Jinja once on the supplied content
# each individual email we send, so the contents can # 2. Allow the supplied content to have jinja interpolation in it
# vary user-to-user # 3. Have that interpolation happen in the context of
f.write(base_template.read().replace("{{ rendered_input }}", rendered_input)) # each individual email we send, so the contents can
# vary user-to-user
f.write(base_template.read().replace("{{ rendered_input }}", rendered_input))
with open(subject_path, "w") as f: with open(subject_path, "w") as f:
f.write(get_header(subject, parsed_email_template.get("subject"), "subject")) f.write(get_header(subject, parsed_email_template.get("subject"), "subject"))

View File

@ -2018,14 +2018,16 @@ class ZulipTestCase(ZulipTestCaseMixin, TestCase):
# Some code might call process_notification using keyword arguments, # Some code might call process_notification using keyword arguments,
# so mypy doesn't allow assigning lst.append to process_notification # so mypy doesn't allow assigning lst.append to process_notification
# So explicitly change parameter name to 'notice' to work around this problem # So explicitly change parameter name to 'notice' to work around this problem
with mock.patch("zerver.tornado.event_queue.process_notification", lst.append): with (
mock.patch("zerver.tornado.event_queue.process_notification", lst.append),
# Some `send_event` calls need to be executed only after the current transaction # Some `send_event` calls need to be executed only after the current transaction
# commits (using `on_commit` hooks). Because the transaction in Django tests never # commits (using `on_commit` hooks). Because the transaction in Django tests never
# commits (rather, gets rolled back after the test completes), such events would # commits (rather, gets rolled back after the test completes), such events would
# never be sent in tests, and we would be unable to verify them. Hence, we use # never be sent in tests, and we would be unable to verify them. Hence, we use
# this helper to make sure the `send_event` calls actually run. # this helper to make sure the `send_event` calls actually run.
with self.captureOnCommitCallbacks(execute=True): self.captureOnCommitCallbacks(execute=True),
yield lst ):
yield lst
self.assert_length(lst, expected_num_events) self.assert_length(lst, expected_num_events)

View File

@ -71,9 +71,11 @@ class MockLDAP(fakeldap.MockLDAP):
def stub_event_queue_user_events( def stub_event_queue_user_events(
event_queue_return: Any, user_events_return: Any event_queue_return: Any, user_events_return: Any
) -> Iterator[None]: ) -> Iterator[None]:
with mock.patch("zerver.lib.events.request_event_queue", return_value=event_queue_return): with (
with mock.patch("zerver.lib.events.get_user_events", return_value=user_events_return): mock.patch("zerver.lib.events.request_event_queue", return_value=event_queue_return),
yield mock.patch("zerver.lib.events.get_user_events", return_value=user_events_return),
):
yield
@contextmanager @contextmanager

View File

@ -186,9 +186,11 @@ def thumbnail_local_emoji(apps: StateApps) -> None:
) )
new_file_name = get_emoji_file_name("image/png", emoji.id) new_file_name = get_emoji_file_name("image/png", emoji.id)
try: try:
with open(f"{settings.DEPLOY_ROOT}/static/images/bad-emoji.png", "rb") as f: with (
with open(f"{base_path}/{new_file_name}", "wb") as new_f: open(f"{settings.DEPLOY_ROOT}/static/images/bad-emoji.png", "rb") as f,
new_f.write(f.read()) open(f"{base_path}/{new_file_name}", "wb") as new_f,
):
new_f.write(f.read())
emoji.deactivated = True emoji.deactivated = True
emoji.is_animated = False emoji.is_animated = False
emoji.file_name = new_file_name emoji.file_name = new_file_name

View File

@ -3415,14 +3415,16 @@ class AppleIdAuthBackendTest(AppleAuthMixin, SocialAuthBase):
def test_id_token_verification_failure(self) -> None: def test_id_token_verification_failure(self) -> None:
account_data_dict = self.get_account_data_dict(email=self.email, name=self.name) account_data_dict = self.get_account_data_dict(email=self.email, name=self.name)
with self.assertLogs(self.logger_string, level="INFO") as m: with (
with mock.patch("jwt.decode", side_effect=PyJWTError): self.assertLogs(self.logger_string, level="INFO") as m,
result = self.social_auth_test( mock.patch("jwt.decode", side_effect=PyJWTError),
account_data_dict, ):
expect_choose_email_screen=True, result = self.social_auth_test(
subdomain="zulip", account_data_dict,
is_signup=True, expect_choose_email_screen=True,
) subdomain="zulip",
is_signup=True,
)
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
self.assertEqual(result["Location"], "/login/") self.assertEqual(result["Location"], "/login/")
self.assertEqual( self.assertEqual(
@ -4583,9 +4585,11 @@ class GoogleAuthBackendTest(SocialAuthBase):
"redirect_to": next, "redirect_to": next,
} }
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
with mock.patch("zerver.views.auth.authenticate", return_value=user_profile): with (
with mock.patch("zerver.views.auth.do_login"): mock.patch("zerver.views.auth.authenticate", return_value=user_profile),
result = self.get_log_into_subdomain(data) mock.patch("zerver.views.auth.do_login"),
):
result = self.get_log_into_subdomain(data)
return result return result
res = test_redirect_to_next_url() res = test_redirect_to_next_url()
@ -5666,49 +5670,55 @@ class TestZulipRemoteUserBackend(DesktopFlowTestingLib, ZulipTestCase):
def test_login_failure_due_to_wrong_subdomain(self) -> None: def test_login_failure_due_to_wrong_subdomain(self) -> None:
email = self.example_email("hamlet") email = self.example_email("hamlet")
with self.settings( with (
AUTHENTICATION_BACKENDS=( self.settings(
"zproject.backends.ZulipRemoteUserBackend",
"zproject.backends.ZulipDummyBackend",
)
):
with mock.patch("zerver.views.auth.get_subdomain", return_value="acme"):
result = self.client_get(
"http://testserver:9080/accounts/login/sso/", REMOTE_USER=email
)
self.assertEqual(result.status_code, 200)
self.assert_logged_in_user_id(None)
self.assert_in_response("You need an invitation to join this organization.", result)
def test_login_failure_due_to_empty_subdomain(self) -> None:
email = self.example_email("hamlet")
with self.settings(
AUTHENTICATION_BACKENDS=(
"zproject.backends.ZulipRemoteUserBackend",
"zproject.backends.ZulipDummyBackend",
)
):
with mock.patch("zerver.views.auth.get_subdomain", return_value=""):
result = self.client_get(
"http://testserver:9080/accounts/login/sso/", REMOTE_USER=email
)
self.assertEqual(result.status_code, 200)
self.assert_logged_in_user_id(None)
self.assert_in_response("You need an invitation to join this organization.", result)
def test_login_success_under_subdomains(self) -> None:
user_profile = self.example_user("hamlet")
email = user_profile.delivery_email
with mock.patch("zerver.views.auth.get_subdomain", return_value="zulip"):
with self.settings(
AUTHENTICATION_BACKENDS=( AUTHENTICATION_BACKENDS=(
"zproject.backends.ZulipRemoteUserBackend", "zproject.backends.ZulipRemoteUserBackend",
"zproject.backends.ZulipDummyBackend", "zproject.backends.ZulipDummyBackend",
) )
): ),
result = self.client_get("/accounts/login/sso/", REMOTE_USER=email) mock.patch("zerver.views.auth.get_subdomain", return_value="acme"),
self.assertEqual(result.status_code, 302) ):
self.assert_logged_in_user_id(user_profile.id) result = self.client_get(
"http://testserver:9080/accounts/login/sso/", REMOTE_USER=email
)
self.assertEqual(result.status_code, 200)
self.assert_logged_in_user_id(None)
self.assert_in_response("You need an invitation to join this organization.", result)
def test_login_failure_due_to_empty_subdomain(self) -> None:
email = self.example_email("hamlet")
with (
self.settings(
AUTHENTICATION_BACKENDS=(
"zproject.backends.ZulipRemoteUserBackend",
"zproject.backends.ZulipDummyBackend",
)
),
mock.patch("zerver.views.auth.get_subdomain", return_value=""),
):
result = self.client_get(
"http://testserver:9080/accounts/login/sso/", REMOTE_USER=email
)
self.assertEqual(result.status_code, 200)
self.assert_logged_in_user_id(None)
self.assert_in_response("You need an invitation to join this organization.", result)
def test_login_success_under_subdomains(self) -> None:
user_profile = self.example_user("hamlet")
email = user_profile.delivery_email
with (
mock.patch("zerver.views.auth.get_subdomain", return_value="zulip"),
self.settings(
AUTHENTICATION_BACKENDS=(
"zproject.backends.ZulipRemoteUserBackend",
"zproject.backends.ZulipDummyBackend",
)
),
):
result = self.client_get("/accounts/login/sso/", REMOTE_USER=email)
self.assertEqual(result.status_code, 302)
self.assert_logged_in_user_id(user_profile.id)
@override_settings(SEND_LOGIN_EMAILS=True) @override_settings(SEND_LOGIN_EMAILS=True)
@override_settings( @override_settings(
@ -5974,30 +5984,34 @@ class TestJWTLogin(ZulipTestCase):
def test_login_failure_due_to_wrong_subdomain(self) -> None: def test_login_failure_due_to_wrong_subdomain(self) -> None:
payload = {"email": "hamlet@zulip.com"} payload = {"email": "hamlet@zulip.com"}
with self.settings(JWT_AUTH_KEYS={"acme": {"key": "key", "algorithms": ["HS256"]}}): with (
with mock.patch("zerver.views.auth.get_realm_from_request", return_value=None): self.settings(JWT_AUTH_KEYS={"acme": {"key": "key", "algorithms": ["HS256"]}}),
key = settings.JWT_AUTH_KEYS["acme"]["key"] mock.patch("zerver.views.auth.get_realm_from_request", return_value=None),
[algorithm] = settings.JWT_AUTH_KEYS["acme"]["algorithms"] ):
web_token = jwt.encode(payload, key, algorithm) key = settings.JWT_AUTH_KEYS["acme"]["key"]
[algorithm] = settings.JWT_AUTH_KEYS["acme"]["algorithms"]
web_token = jwt.encode(payload, key, algorithm)
data = {"token": web_token} data = {"token": web_token}
result = self.client_post("/accounts/login/jwt/", data) result = self.client_post("/accounts/login/jwt/", data)
self.assert_json_error_contains(result, "Invalid subdomain", 404) self.assert_json_error_contains(result, "Invalid subdomain", 404)
self.assert_logged_in_user_id(None) self.assert_logged_in_user_id(None)
def test_login_success_under_subdomains(self) -> None: def test_login_success_under_subdomains(self) -> None:
payload = {"email": "hamlet@zulip.com"} payload = {"email": "hamlet@zulip.com"}
with self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key", "algorithms": ["HS256"]}}): with (
with mock.patch("zerver.views.auth.get_subdomain", return_value="zulip"): self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key", "algorithms": ["HS256"]}}),
key = settings.JWT_AUTH_KEYS["zulip"]["key"] mock.patch("zerver.views.auth.get_subdomain", return_value="zulip"),
[algorithm] = settings.JWT_AUTH_KEYS["zulip"]["algorithms"] ):
web_token = jwt.encode(payload, key, algorithm) key = settings.JWT_AUTH_KEYS["zulip"]["key"]
[algorithm] = settings.JWT_AUTH_KEYS["zulip"]["algorithms"]
web_token = jwt.encode(payload, key, algorithm)
data = {"token": web_token} data = {"token": web_token}
result = self.client_post("/accounts/login/jwt/", data) result = self.client_post("/accounts/login/jwt/", data)
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
user_profile = self.example_user("hamlet") user_profile = self.example_user("hamlet")
self.assert_logged_in_user_id(user_profile.id) self.assert_logged_in_user_id(user_profile.id)
class DjangoToLDAPUsernameTests(ZulipTestCase): class DjangoToLDAPUsernameTests(ZulipTestCase):
@ -6046,9 +6060,8 @@ class DjangoToLDAPUsernameTests(ZulipTestCase):
self.backend.django_to_ldap_username("aaron@zulip.com"), self.ldap_username("aaron") self.backend.django_to_ldap_username("aaron@zulip.com"), self.ldap_username("aaron")
) )
with self.assertLogs(level="WARNING") as m: with self.assertLogs(level="WARNING") as m, self.assertRaises(NoMatchingLDAPUserError):
with self.assertRaises(NoMatchingLDAPUserError): self.backend.django_to_ldap_username("shared_email@zulip.com")
self.backend.django_to_ldap_username("shared_email@zulip.com")
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -6641,9 +6654,11 @@ class TestZulipLDAPUserPopulator(ZulipLDAPTestCase):
@override_settings(LDAP_EMAIL_ATTR="mail") @override_settings(LDAP_EMAIL_ATTR="mail")
def test_populate_user_returns_none(self) -> None: def test_populate_user_returns_none(self) -> None:
with mock.patch.object(ZulipLDAPUser, "populate_user", return_value=None): with (
with self.assertRaises(PopulateUserLDAPError): mock.patch.object(ZulipLDAPUser, "populate_user", return_value=None),
sync_user_from_ldap(self.example_user("hamlet"), mock.Mock()) self.assertRaises(PopulateUserLDAPError),
):
sync_user_from_ldap(self.example_user("hamlet"), mock.Mock())
def test_update_full_name(self) -> None: def test_update_full_name(self) -> None:
self.change_ldap_user_attr("hamlet", "cn", "New Name") self.change_ldap_user_attr("hamlet", "cn", "New Name")
@ -6823,17 +6838,19 @@ class TestZulipLDAPUserPopulator(ZulipLDAPTestCase):
self.change_ldap_user_attr("hamlet", "cn", "Second Hamlet") self.change_ldap_user_attr("hamlet", "cn", "Second Hamlet")
expected_call_args = [hamlet2, "Second Hamlet", None] expected_call_args = [hamlet2, "Second Hamlet", None]
with self.settings(AUTH_LDAP_USER_ATTR_MAP={"full_name": "cn"}): with (
with mock.patch("zerver.actions.user_settings.do_change_full_name") as f: self.settings(AUTH_LDAP_USER_ATTR_MAP={"full_name": "cn"}),
self.perform_ldap_sync(hamlet2) mock.patch("zerver.actions.user_settings.do_change_full_name") as f,
f.assert_called_once_with(*expected_call_args) ):
self.perform_ldap_sync(hamlet2)
f.assert_called_once_with(*expected_call_args)
# Get the updated model and make sure the full name is changed correctly: # Get the updated model and make sure the full name is changed correctly:
hamlet2 = get_user_by_delivery_email(email, test_realm) hamlet2 = get_user_by_delivery_email(email, test_realm)
self.assertEqual(hamlet2.full_name, "Second Hamlet") self.assertEqual(hamlet2.full_name, "Second Hamlet")
# Now get the original hamlet and make he still has his name unchanged: # Now get the original hamlet and make he still has his name unchanged:
hamlet = self.example_user("hamlet") hamlet = self.example_user("hamlet")
self.assertEqual(hamlet.full_name, "King Hamlet") self.assertEqual(hamlet.full_name, "King Hamlet")
def test_user_not_found_in_ldap(self) -> None: def test_user_not_found_in_ldap(self) -> None:
with self.settings( with self.settings(
@ -7038,16 +7055,18 @@ class TestZulipLDAPUserPopulator(ZulipLDAPTestCase):
}, },
], ],
] ]
with self.settings( with (
AUTH_LDAP_USER_ATTR_MAP={ self.settings(
"full_name": "cn", AUTH_LDAP_USER_ATTR_MAP={
"custom_profile_field__birthday": "birthDate", "full_name": "cn",
"custom_profile_field__phone_number": "homePhone", "custom_profile_field__birthday": "birthDate",
} "custom_profile_field__phone_number": "homePhone",
}
),
mock.patch("zproject.backends.do_update_user_custom_profile_data_if_changed") as f,
): ):
with mock.patch("zproject.backends.do_update_user_custom_profile_data_if_changed") as f: self.perform_ldap_sync(self.example_user("hamlet"))
self.perform_ldap_sync(self.example_user("hamlet")) f.assert_called_once_with(*expected_call_args)
f.assert_called_once_with(*expected_call_args)
def test_update_custom_profile_field_not_present_in_ldap(self) -> None: def test_update_custom_profile_field_not_present_in_ldap(self) -> None:
hamlet = self.example_user("hamlet") hamlet = self.example_user("hamlet")
@ -7489,14 +7508,16 @@ class JWTFetchAPIKeyTest(ZulipTestCase):
self.assert_json_error_contains(result, "Invalid subdomain", 404) self.assert_json_error_contains(result, "Invalid subdomain", 404)
def test_jwt_key_not_found_failure(self) -> None: def test_jwt_key_not_found_failure(self) -> None:
with self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key1", "algorithms": ["HS256"]}}): with (
with mock.patch( self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key1", "algorithms": ["HS256"]}}),
mock.patch(
"zerver.views.auth.get_realm_from_request", return_value=get_realm("zephyr") "zerver.views.auth.get_realm_from_request", return_value=get_realm("zephyr")
): ),
result = self.client_post("/api/v1/jwt/fetch_api_key") ):
self.assert_json_error_contains( result = self.client_post("/api/v1/jwt/fetch_api_key")
result, "JWT authentication is not enabled for this organization", 400 self.assert_json_error_contains(
) result, "JWT authentication is not enabled for this organization", 400
)
def test_missing_jwt_payload_failure(self) -> None: def test_missing_jwt_payload_failure(self) -> None:
with self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key1", "algorithms": ["HS256"]}}): with self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key1", "algorithms": ["HS256"]}}):
@ -7709,12 +7730,12 @@ class LDAPGroupSyncTest(ZulipTestCase):
), ),
self.assertLogs("django_auth_ldap", "WARN") as django_ldap_log, self.assertLogs("django_auth_ldap", "WARN") as django_ldap_log,
self.assertLogs("zulip.ldap", "DEBUG") as zulip_ldap_log, self.assertLogs("zulip.ldap", "DEBUG") as zulip_ldap_log,
): self.assertRaisesRegex(
with self.assertRaisesRegex(
ZulipLDAPError, ZulipLDAPError,
"search_s.*", "search_s.*",
): ),
sync_user_from_ldap(cordelia, mock.Mock()) ):
sync_user_from_ldap(cordelia, mock.Mock())
self.assertEqual( self.assertEqual(
zulip_ldap_log.output, zulip_ldap_log.output,

View File

@ -165,11 +165,11 @@ class DecoratorTestCase(ZulipTestCase):
# Start a valid request here # Start a valid request here
request = HostRequestMock() request = HostRequestMock()
request.POST["api_key"] = webhook_bot_api_key request.POST["api_key"] = webhook_bot_api_key
with self.assertLogs(level="WARNING") as m: with (
with self.assertRaisesRegex( self.assertLogs(level="WARNING") as m,
JsonableError, "Account is not associated with this subdomain" self.assertRaisesRegex(JsonableError, "Account is not associated with this subdomain"),
): ):
api_result = my_webhook(request) api_result = my_webhook(request)
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -181,12 +181,12 @@ class DecoratorTestCase(ZulipTestCase):
request = HostRequestMock() request = HostRequestMock()
request.POST["api_key"] = webhook_bot_api_key request.POST["api_key"] = webhook_bot_api_key
with self.assertLogs(level="WARNING") as m: with (
with self.assertRaisesRegex( self.assertLogs(level="WARNING") as m,
JsonableError, "Account is not associated with this subdomain" self.assertRaisesRegex(JsonableError, "Account is not associated with this subdomain"),
): ):
request.host = "acme." + settings.EXTERNAL_HOST request.host = "acme." + settings.EXTERNAL_HOST
api_result = my_webhook(request) api_result = my_webhook(request)
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -203,11 +203,13 @@ class DecoratorTestCase(ZulipTestCase):
request = HostRequestMock() request = HostRequestMock()
request.host = "zulip.testserver" request.host = "zulip.testserver"
request.POST["api_key"] = webhook_bot_api_key request.POST["api_key"] = webhook_bot_api_key
with self.assertLogs("zulip.zerver.webhooks", level="INFO") as log: with (
with self.assertRaisesRegex(Exception, "raised by webhook function"): self.assertLogs("zulip.zerver.webhooks", level="INFO") as log,
request._body = b"{}" self.assertRaisesRegex(Exception, "raised by webhook function"),
request.content_type = "application/json" ):
my_webhook_raises_exception(request) request._body = b"{}"
request.content_type = "application/json"
my_webhook_raises_exception(request)
# Test when content_type is not application/json; exception raised # Test when content_type is not application/json; exception raised
# in the webhook function should be re-raised # in the webhook function should be re-raised
@ -215,11 +217,13 @@ class DecoratorTestCase(ZulipTestCase):
request = HostRequestMock() request = HostRequestMock()
request.host = "zulip.testserver" request.host = "zulip.testserver"
request.POST["api_key"] = webhook_bot_api_key request.POST["api_key"] = webhook_bot_api_key
with self.assertLogs("zulip.zerver.webhooks", level="INFO") as log: with (
with self.assertRaisesRegex(Exception, "raised by webhook function"): self.assertLogs("zulip.zerver.webhooks", level="INFO") as log,
request._body = b"notjson" self.assertRaisesRegex(Exception, "raised by webhook function"),
request.content_type = "text/plain" ):
my_webhook_raises_exception(request) request._body = b"notjson"
request.content_type = "text/plain"
my_webhook_raises_exception(request)
# Test when content_type is application/json but request.body # Test when content_type is application/json but request.body
# is not valid JSON; invalid JSON should be logged and the # is not valid JSON; invalid JSON should be logged and the
@ -227,12 +231,14 @@ class DecoratorTestCase(ZulipTestCase):
request = HostRequestMock() request = HostRequestMock()
request.host = "zulip.testserver" request.host = "zulip.testserver"
request.POST["api_key"] = webhook_bot_api_key request.POST["api_key"] = webhook_bot_api_key
with self.assertLogs("zulip.zerver.webhooks", level="ERROR") as log: with (
with self.assertRaisesRegex(Exception, "raised by webhook function"): self.assertLogs("zulip.zerver.webhooks", level="ERROR") as log,
request._body = b"invalidjson" self.assertRaisesRegex(Exception, "raised by webhook function"),
request.content_type = "application/json" ):
request.META["HTTP_X_CUSTOM_HEADER"] = "custom_value" request._body = b"invalidjson"
my_webhook_raises_exception(request) request.content_type = "application/json"
request.META["HTTP_X_CUSTOM_HEADER"] = "custom_value"
my_webhook_raises_exception(request)
self.assertIn( self.assertIn(
self.logger_output("raised by webhook function\n", "error", "webhooks"), log.output[0] self.logger_output("raised by webhook function\n", "error", "webhooks"), log.output[0]
@ -245,12 +251,14 @@ class DecoratorTestCase(ZulipTestCase):
exception_msg = ( exception_msg = (
"The 'test_event' event isn't currently supported by the ClientName webhook; ignoring" "The 'test_event' event isn't currently supported by the ClientName webhook; ignoring"
) )
with self.assertLogs("zulip.zerver.webhooks.unsupported", level="ERROR") as log: with (
with self.assertRaisesRegex(UnsupportedWebhookEventTypeError, exception_msg): self.assertLogs("zulip.zerver.webhooks.unsupported", level="ERROR") as log,
request._body = b"invalidjson" self.assertRaisesRegex(UnsupportedWebhookEventTypeError, exception_msg),
request.content_type = "application/json" ):
request.META["HTTP_X_CUSTOM_HEADER"] = "custom_value" request._body = b"invalidjson"
my_webhook_raises_exception_unsupported_event(request) request.content_type = "application/json"
request.META["HTTP_X_CUSTOM_HEADER"] = "custom_value"
my_webhook_raises_exception_unsupported_event(request)
self.assertIn( self.assertIn(
self.logger_output(exception_msg, "error", "webhooks.unsupported"), log.output[0] self.logger_output(exception_msg, "error", "webhooks.unsupported"), log.output[0]
@ -259,9 +267,11 @@ class DecoratorTestCase(ZulipTestCase):
request = HostRequestMock() request = HostRequestMock()
request.host = "zulip.testserver" request.host = "zulip.testserver"
request.POST["api_key"] = webhook_bot_api_key request.POST["api_key"] = webhook_bot_api_key
with self.settings(RATE_LIMITING=True): with (
with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: self.settings(RATE_LIMITING=True),
api_result = orjson.loads(my_webhook(request).content).get("msg") mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock,
):
api_result = orjson.loads(my_webhook(request).content).get("msg")
# Verify rate limiting was attempted. # Verify rate limiting was attempted.
self.assertTrue(rate_limit_mock.called) self.assertTrue(rate_limit_mock.called)
@ -389,9 +399,11 @@ class DecoratorLoggingTestCase(ZulipTestCase):
request._body = b"{}" request._body = b"{}"
request.content_type = "text/plain" request.content_type = "text/plain"
with self.assertLogs("zulip.zerver.webhooks") as logger: with (
with self.assertRaisesRegex(Exception, "raised by webhook function"): self.assertLogs("zulip.zerver.webhooks") as logger,
my_webhook_raises_exception(request) self.assertRaisesRegex(Exception, "raised by webhook function"),
):
my_webhook_raises_exception(request)
self.assertIn("raised by webhook function", logger.output[0]) self.assertIn("raised by webhook function", logger.output[0])
@ -440,9 +452,11 @@ class DecoratorLoggingTestCase(ZulipTestCase):
request._body = b"{}" request._body = b"{}"
request.content_type = "application/json" request.content_type = "application/json"
with mock.patch("zerver.decorator.webhook_logger.exception") as mock_exception: with (
with self.assertRaisesRegex(Exception, "raised by a non-webhook view"): mock.patch("zerver.decorator.webhook_logger.exception") as mock_exception,
non_webhook_view_raises_exception(request) self.assertRaisesRegex(Exception, "raised by a non-webhook view"),
):
non_webhook_view_raises_exception(request)
self.assertFalse(mock_exception.called) self.assertFalse(mock_exception.called)
@ -964,15 +978,17 @@ class TestValidateApiKey(ZulipTestCase):
def test_valid_api_key_if_user_is_on_wrong_subdomain(self) -> None: def test_valid_api_key_if_user_is_on_wrong_subdomain(self) -> None:
with self.settings(RUNNING_INSIDE_TORNADO=False): with self.settings(RUNNING_INSIDE_TORNADO=False):
api_key = get_api_key(self.default_bot) api_key = get_api_key(self.default_bot)
with self.assertLogs(level="WARNING") as m: with (
with self.assertRaisesRegex( self.assertLogs(level="WARNING") as m,
self.assertRaisesRegex(
JsonableError, "Account is not associated with this subdomain" JsonableError, "Account is not associated with this subdomain"
): ),
validate_api_key( ):
HostRequestMock(host=settings.EXTERNAL_HOST), validate_api_key(
self.default_bot.email, HostRequestMock(host=settings.EXTERNAL_HOST),
api_key, self.default_bot.email,
) api_key,
)
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [
@ -982,15 +998,17 @@ class TestValidateApiKey(ZulipTestCase):
], ],
) )
with self.assertLogs(level="WARNING") as m: with (
with self.assertRaisesRegex( self.assertLogs(level="WARNING") as m,
self.assertRaisesRegex(
JsonableError, "Account is not associated with this subdomain" JsonableError, "Account is not associated with this subdomain"
): ),
validate_api_key( ):
HostRequestMock(host="acme." + settings.EXTERNAL_HOST), validate_api_key(
self.default_bot.email, HostRequestMock(host="acme." + settings.EXTERNAL_HOST),
api_key, self.default_bot.email,
) api_key,
)
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [

View File

@ -241,9 +241,8 @@ class TestDigestEmailMessages(ZulipTestCase):
digest_user_ids = [user.id for user in digest_users] digest_user_ids = [user.id for user in digest_users]
get_recent_topics.cache_clear() get_recent_topics.cache_clear()
with self.assert_database_query_count(16): with self.assert_database_query_count(16), self.assert_memcached_count(0):
with self.assert_memcached_count(0): bulk_handle_digest_email(digest_user_ids, cutoff)
bulk_handle_digest_email(digest_user_ids, cutoff)
self.assert_length(digest_users, mock_send_future_email.call_count) self.assert_length(digest_users, mock_send_future_email.call_count)
@ -441,9 +440,11 @@ class TestDigestEmailMessages(ZulipTestCase):
tuesday = self.tuesday() tuesday = self.tuesday()
cutoff = tuesday - timedelta(days=5) cutoff = tuesday - timedelta(days=5)
with time_machine.travel(tuesday, tick=False): with (
with mock.patch("zerver.lib.digest.queue_digest_user_ids") as queue_mock: time_machine.travel(tuesday, tick=False),
enqueue_emails(cutoff) mock.patch("zerver.lib.digest.queue_digest_user_ids") as queue_mock,
):
enqueue_emails(cutoff)
queue_mock.assert_not_called() queue_mock.assert_not_called()
@override_settings(SEND_DIGEST_EMAILS=True) @override_settings(SEND_DIGEST_EMAILS=True)
@ -453,9 +454,11 @@ class TestDigestEmailMessages(ZulipTestCase):
not_tuesday = datetime(year=2016, month=1, day=6, tzinfo=timezone.utc) not_tuesday = datetime(year=2016, month=1, day=6, tzinfo=timezone.utc)
cutoff = not_tuesday - timedelta(days=5) cutoff = not_tuesday - timedelta(days=5)
with time_machine.travel(not_tuesday, tick=False): with (
with mock.patch("zerver.lib.digest.queue_digest_user_ids") as queue_mock: time_machine.travel(not_tuesday, tick=False),
enqueue_emails(cutoff) mock.patch("zerver.lib.digest.queue_digest_user_ids") as queue_mock,
):
enqueue_emails(cutoff)
queue_mock.assert_not_called() queue_mock.assert_not_called()
@override_settings(SEND_DIGEST_EMAILS=True) @override_settings(SEND_DIGEST_EMAILS=True)

View File

@ -72,18 +72,20 @@ class TestEmbeddedBotMessaging(ZulipTestCase):
def test_embedded_bot_quit_exception(self) -> None: def test_embedded_bot_quit_exception(self) -> None:
assert self.bot_profile is not None assert self.bot_profile is not None
with patch( with (
"zulip_bots.bots.helloworld.helloworld.HelloWorldHandler.handle_message", patch(
side_effect=EmbeddedBotQuitError("I'm quitting!"), "zulip_bots.bots.helloworld.helloworld.HelloWorldHandler.handle_message",
side_effect=EmbeddedBotQuitError("I'm quitting!"),
),
self.assertLogs(level="WARNING") as m,
): ):
with self.assertLogs(level="WARNING") as m: self.send_stream_message(
self.send_stream_message( self.user_profile,
self.user_profile, "Denmark",
"Denmark", content=f"@**{self.bot_profile.full_name}** foo",
content=f"@**{self.bot_profile.full_name}** foo", topic_name="bar",
topic_name="bar", )
) self.assertEqual(m.output, ["WARNING:root:I'm quitting!"])
self.assertEqual(m.output, ["WARNING:root:I'm quitting!"])
class TestEmbeddedBotFailures(ZulipTestCase): class TestEmbeddedBotFailures(ZulipTestCase):

View File

@ -86,12 +86,14 @@ class EventsEndpointTest(ZulipTestCase):
test_event = dict(id=6, type=event_type, realm_emoji=empty_realm_emoji_dict) test_event = dict(id=6, type=event_type, realm_emoji=empty_realm_emoji_dict)
# Test that call is made to deal with a returning soft deactivated user. # Test that call is made to deal with a returning soft deactivated user.
with mock.patch("zerver.lib.events.reactivate_user_if_soft_deactivated") as fa: with (
with stub_event_queue_user_events(return_event_queue, return_user_events): mock.patch("zerver.lib.events.reactivate_user_if_soft_deactivated") as fa,
result = self.api_post( stub_event_queue_user_events(return_event_queue, return_user_events),
user, "/api/v1/register", dict(event_types=orjson.dumps([event_type]).decode()) ):
) result = self.api_post(
self.assertEqual(fa.call_count, 1) user, "/api/v1/register", dict(event_types=orjson.dumps([event_type]).decode())
)
self.assertEqual(fa.call_count, 1)
with stub_event_queue_user_events(return_event_queue, return_user_events): with stub_event_queue_user_events(return_event_queue, return_user_events):
result = self.api_post( result = self.api_post(
@ -1171,9 +1173,11 @@ class FetchQueriesTest(ZulipTestCase):
# count in production. # count in production.
realm = get_realm_with_settings(realm_id=user.realm_id) realm = get_realm_with_settings(realm_id=user.realm_id)
with self.assert_database_query_count(43): with (
with mock.patch("zerver.lib.events.always_want") as want_mock: self.assert_database_query_count(43),
fetch_initial_state_data(user, realm=realm) mock.patch("zerver.lib.events.always_want") as want_mock,
):
fetch_initial_state_data(user, realm=realm)
expected_counts = dict( expected_counts = dict(
alert_words=1, alert_words=1,

View File

@ -1742,17 +1742,19 @@ class NormalActionsTest(BaseAction):
cordelia.save() cordelia.save()
away_val = False away_val = False
with self.settings(CAN_ACCESS_ALL_USERS_GROUP_LIMITS_PRESENCE=True): with (
with self.verify_action(num_events=0, state_change_expected=False) as events: self.settings(CAN_ACCESS_ALL_USERS_GROUP_LIMITS_PRESENCE=True),
do_update_user_status( self.verify_action(num_events=0, state_change_expected=False) as events,
user_profile=cordelia, ):
away=away_val, do_update_user_status(
status_text="out to lunch", user_profile=cordelia,
emoji_name="car", away=away_val,
emoji_code="1f697", status_text="out to lunch",
reaction_type=UserStatus.UNICODE_EMOJI, emoji_name="car",
client_id=client.id, emoji_code="1f697",
) reaction_type=UserStatus.UNICODE_EMOJI,
client_id=client.id,
)
away_val = True away_val = True
with self.verify_action(num_events=1, state_change_expected=True) as events: with self.verify_action(num_events=1, state_change_expected=True) as events:
@ -2128,13 +2130,12 @@ class NormalActionsTest(BaseAction):
{"Google": False, "Email": False, "GitHub": True, "LDAP": False, "Dev": True}, {"Google": False, "Email": False, "GitHub": True, "LDAP": False, "Dev": True},
{"Google": False, "Email": True, "GitHub": True, "LDAP": True, "Dev": False}, {"Google": False, "Email": True, "GitHub": True, "LDAP": True, "Dev": False},
): ):
with fake_backends(): with fake_backends(), self.verify_action() as events:
with self.verify_action() as events: do_set_realm_authentication_methods(
do_set_realm_authentication_methods( self.user_profile.realm,
self.user_profile.realm, auth_method_dict,
auth_method_dict, acting_user=None,
acting_user=None, )
)
check_realm_update_dict("events[0]", events[0]) check_realm_update_dict("events[0]", events[0])
@ -2664,11 +2665,10 @@ class NormalActionsTest(BaseAction):
def test_realm_emoji_events(self) -> None: def test_realm_emoji_events(self) -> None:
author = self.example_user("iago") author = self.example_user("iago")
with get_test_image_file("img.png") as img_file: with get_test_image_file("img.png") as img_file, self.verify_action() as events:
with self.verify_action() as events: check_add_realm_emoji(
check_add_realm_emoji( self.user_profile.realm, "my_emoji", author, img_file, "image/png"
self.user_profile.realm, "my_emoji", author, img_file, "image/png" )
)
check_realm_emoji_update("events[0]", events[0]) check_realm_emoji_update("events[0]", events[0])
@ -3278,9 +3278,12 @@ class NormalActionsTest(BaseAction):
"zerver.lib.export.do_export_realm", "zerver.lib.export.do_export_realm",
return_value=create_dummy_file("test-export.tar.gz"), return_value=create_dummy_file("test-export.tar.gz"),
): ):
with stdout_suppressed(), self.assertLogs(level="INFO") as info_logs: with (
with self.verify_action(state_change_expected=True, num_events=3) as events: stdout_suppressed(),
self.client_post("/json/export/realm") self.assertLogs(level="INFO") as info_logs,
self.verify_action(state_change_expected=True, num_events=3) as events,
):
self.client_post("/json/export/realm")
self.assertTrue("INFO:root:Completed data export for zulip in" in info_logs.output[0]) self.assertTrue("INFO:root:Completed data export for zulip in" in info_logs.output[0])
# We get two realm_export events for this action, where the first # We get two realm_export events for this action, where the first
@ -3328,9 +3331,11 @@ class NormalActionsTest(BaseAction):
mock.patch("zerver.lib.export.do_export_realm", side_effect=Exception("Some failure")), mock.patch("zerver.lib.export.do_export_realm", side_effect=Exception("Some failure")),
self.assertLogs(level="ERROR") as error_log, self.assertLogs(level="ERROR") as error_log,
): ):
with stdout_suppressed(): with (
with self.verify_action(state_change_expected=False, num_events=2) as events: stdout_suppressed(),
self.client_post("/json/export/realm") self.verify_action(state_change_expected=False, num_events=2) as events,
):
self.client_post("/json/export/realm")
# Log is of following format: "ERROR:root:Data export for zulip failed after 0.004499673843383789" # Log is of following format: "ERROR:root:Data export for zulip failed after 0.004499673843383789"
# Where last floating number is time and will vary in each test hence the following assertion is # Where last floating number is time and will vary in each test hence the following assertion is

View File

@ -298,12 +298,14 @@ class RateLimitTests(ZulipTestCase):
# We need to reset the circuitbreaker before starting. We # We need to reset the circuitbreaker before starting. We
# patch the .opened property to be false, then call the # patch the .opened property to be false, then call the
# function, so it resets to closed. # function, so it resets to closed.
with mock.patch("builtins.open", mock.mock_open(read_data=orjson.dumps(["1.2.3.4"]))): with (
with mock.patch( mock.patch("builtins.open", mock.mock_open(read_data=orjson.dumps(["1.2.3.4"]))),
mock.patch(
"circuitbreaker.CircuitBreaker.opened", new_callable=mock.PropertyMock "circuitbreaker.CircuitBreaker.opened", new_callable=mock.PropertyMock
) as mock_opened: ) as mock_opened,
mock_opened.return_value = False ):
get_tor_ips() mock_opened.return_value = False
get_tor_ips()
# Having closed it, it's now cached. Clear the cache. # Having closed it, it's now cached. Clear the cache.
assert CircuitBreakerMonitor.get("get_tor_ips").closed assert CircuitBreakerMonitor.get("get_tor_ips").closed
@ -354,13 +356,15 @@ class RateLimitTests(ZulipTestCase):
# An empty list of IPs is treated as some error in parsing the # An empty list of IPs is treated as some error in parsing the
# input, and as such should not be cached; rate-limiting # input, and as such should not be cached; rate-limiting
# should work as normal, per-IP # should work as normal, per-IP
with self.tor_mock(read_data=[]) as tor_open: with (
with self.assertLogs("zerver.lib.rate_limiter", level="WARNING"): self.tor_mock(read_data=[]) as tor_open,
self.do_test_hit_ratelimits( self.assertLogs("zerver.lib.rate_limiter", level="WARNING"),
lambda: self.send_unauthed_api_request(REMOTE_ADDR="1.2.3.4") ):
) self.do_test_hit_ratelimits(
resp = self.send_unauthed_api_request(REMOTE_ADDR="5.6.7.8") lambda: self.send_unauthed_api_request(REMOTE_ADDR="1.2.3.4")
self.assertNotEqual(resp.status_code, 429) )
resp = self.send_unauthed_api_request(REMOTE_ADDR="5.6.7.8")
self.assertNotEqual(resp.status_code, 429)
# Was not cached, so tried to read twice before hitting the # Was not cached, so tried to read twice before hitting the
# circuit-breaker, and stopping trying # circuit-breaker, and stopping trying
@ -372,15 +376,17 @@ class RateLimitTests(ZulipTestCase):
for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]: for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]:
RateLimitedIPAddr(ip, domain="api_by_ip").clear_history() RateLimitedIPAddr(ip, domain="api_by_ip").clear_history()
with self.tor_mock(side_effect=FileNotFoundError("File not found")) as tor_open: with (
self.tor_mock(side_effect=FileNotFoundError("File not found")) as tor_open,
# If we cannot get a list of TOR exit nodes, then # If we cannot get a list of TOR exit nodes, then
# rate-limiting works as normal, per-IP # rate-limiting works as normal, per-IP
with self.assertLogs("zerver.lib.rate_limiter", level="WARNING") as log_mock: self.assertLogs("zerver.lib.rate_limiter", level="WARNING") as log_mock,
self.do_test_hit_ratelimits( ):
lambda: self.send_unauthed_api_request(REMOTE_ADDR="1.2.3.4") self.do_test_hit_ratelimits(
) lambda: self.send_unauthed_api_request(REMOTE_ADDR="1.2.3.4")
resp = self.send_unauthed_api_request(REMOTE_ADDR="5.6.7.8") )
self.assertNotEqual(resp.status_code, 429) resp = self.send_unauthed_api_request(REMOTE_ADDR="5.6.7.8")
self.assertNotEqual(resp.status_code, 429)
# Tries twice before hitting the circuit-breaker, and stopping trying # Tries twice before hitting the circuit-breaker, and stopping trying
tor_open.assert_has_calls( tor_open.assert_has_calls(

View File

@ -261,10 +261,12 @@ class HomeTest(ZulipTestCase):
self.client_post("/json/bots", bot_info) self.client_post("/json/bots", bot_info)
# Verify succeeds once logged-in # Verify succeeds once logged-in
with self.assert_database_query_count(54): with (
with patch("zerver.lib.cache.cache_set") as cache_mock: self.assert_database_query_count(54),
result = self._get_home_page(stream="Denmark") patch("zerver.lib.cache.cache_set") as cache_mock,
self.check_rendered_logged_in_app(result) ):
result = self._get_home_page(stream="Denmark")
self.check_rendered_logged_in_app(result)
self.assertEqual( self.assertEqual(
set(result["Cache-Control"].split(", ")), {"must-revalidate", "no-store", "no-cache"} set(result["Cache-Control"].split(", ")), {"must-revalidate", "no-store", "no-cache"}
) )
@ -312,10 +314,9 @@ class HomeTest(ZulipTestCase):
self.login("hamlet") self.login("hamlet")
# Verify succeeds once logged-in # Verify succeeds once logged-in
with queries_captured(): with queries_captured(), patch("zerver.lib.cache.cache_set"):
with patch("zerver.lib.cache.cache_set"): result = self._get_home_page(stream="Denmark")
result = self._get_home_page(stream="Denmark") self.check_rendered_logged_in_app(result)
self.check_rendered_logged_in_app(result)
page_params = self._get_page_params(result) page_params = self._get_page_params(result)
self.assertCountEqual(page_params, self.expected_page_params_keys) self.assertCountEqual(page_params, self.expected_page_params_keys)
@ -565,11 +566,13 @@ class HomeTest(ZulipTestCase):
def test_num_queries_for_realm_admin(self) -> None: def test_num_queries_for_realm_admin(self) -> None:
# Verify number of queries for Realm admin isn't much higher than for normal users. # Verify number of queries for Realm admin isn't much higher than for normal users.
self.login("iago") self.login("iago")
with self.assert_database_query_count(54): with (
with patch("zerver.lib.cache.cache_set") as cache_mock: self.assert_database_query_count(54),
result = self._get_home_page() patch("zerver.lib.cache.cache_set") as cache_mock,
self.check_rendered_logged_in_app(result) ):
self.assert_length(cache_mock.call_args_list, 7) result = self._get_home_page()
self.check_rendered_logged_in_app(result)
self.assert_length(cache_mock.call_args_list, 7)
def test_num_queries_with_streams(self) -> None: def test_num_queries_with_streams(self) -> None:
main_user = self.example_user("hamlet") main_user = self.example_user("hamlet")

View File

@ -2547,9 +2547,11 @@ class MultiuseInviteTest(ZulipTestCase):
email = self.nonreg_email("newuser") email = self.nonreg_email("newuser")
invite_link = "/join/invalid_key/" invite_link = "/join/invalid_key/"
with patch("zerver.views.registration.get_realm_from_request", return_value=self.realm): with (
with patch("zerver.views.registration.get_realm", return_value=self.realm): patch("zerver.views.registration.get_realm_from_request", return_value=self.realm),
self.check_user_able_to_register(email, invite_link) patch("zerver.views.registration.get_realm", return_value=self.realm),
):
self.check_user_able_to_register(email, invite_link)
def test_multiuse_link_with_specified_streams(self) -> None: def test_multiuse_link_with_specified_streams(self) -> None:
name1 = "newuser" name1 = "newuser"

View File

@ -438,11 +438,10 @@ class PreviewTestCase(ZulipTestCase):
self.create_mock_response(original_url) self.create_mock_response(original_url)
self.create_mock_response(edited_url) self.create_mock_response(edited_url)
with self.settings(TEST_SUITE=False): with self.settings(TEST_SUITE=False), self.assertLogs(level="INFO") as info_logs:
with self.assertLogs(level="INFO") as info_logs: # Run the queue processor. This will simulate the event for original_url being
# Run the queue processor. This will simulate the event for original_url being # processed after the message has been edited.
# processed after the message has been edited. FetchLinksEmbedData().consume(event)
FetchLinksEmbedData().consume(event)
self.assertTrue( self.assertTrue(
"INFO:root:Time spent on get_link_embed_data for http://test.org/: " "INFO:root:Time spent on get_link_embed_data for http://test.org/: "
in info_logs.output[0] in info_logs.output[0]
@ -457,17 +456,16 @@ class PreviewTestCase(ZulipTestCase):
self.assertTrue(responses.assert_call_count(edited_url, 0)) self.assertTrue(responses.assert_call_count(edited_url, 0))
with self.settings(TEST_SUITE=False): with self.settings(TEST_SUITE=False), self.assertLogs(level="INFO") as info_logs:
with self.assertLogs(level="INFO") as info_logs: # Now proceed with the original queue_json_publish and call the
# Now proceed with the original queue_json_publish and call the # up-to-date event for edited_url.
# up-to-date event for edited_url. queue_json_publish(*args, **kwargs)
queue_json_publish(*args, **kwargs) msg = Message.objects.select_related("sender").get(id=msg_id)
msg = Message.objects.select_related("sender").get(id=msg_id) assert msg.rendered_content is not None
assert msg.rendered_content is not None self.assertIn(
self.assertIn( f'<a href="{edited_url}" title="The Rock">The Rock</a>',
f'<a href="{edited_url}" title="The Rock">The Rock</a>', msg.rendered_content,
msg.rendered_content, )
)
self.assertTrue( self.assertTrue(
"INFO:root:Time spent on get_link_embed_data for http://edited.org/: " "INFO:root:Time spent on get_link_embed_data for http://edited.org/: "
in info_logs.output[0] in info_logs.output[0]
@ -503,11 +501,10 @@ class PreviewTestCase(ZulipTestCase):
# We do still fetch the URL, as we don't want to incur the # We do still fetch the URL, as we don't want to incur the
# cost of locking the row while we do the HTTP fetches. # cost of locking the row while we do the HTTP fetches.
self.create_mock_response(url) self.create_mock_response(url)
with self.settings(TEST_SUITE=False): with self.settings(TEST_SUITE=False), self.assertLogs(level="INFO") as info_logs:
with self.assertLogs(level="INFO") as info_logs: # Run the queue processor. This will simulate the event for original_url being
# Run the queue processor. This will simulate the event for original_url being # processed after the message has been deleted.
# processed after the message has been deleted. FetchLinksEmbedData().consume(event)
FetchLinksEmbedData().consume(event)
self.assertTrue( self.assertTrue(
"INFO:root:Time spent on get_link_embed_data for http://test.org/: " "INFO:root:Time spent on get_link_embed_data for http://test.org/: "
in info_logs.output[0] in info_logs.output[0]
@ -852,24 +849,26 @@ class PreviewTestCase(ZulipTestCase):
self.create_mock_response(url, body=ConnectionError()) self.create_mock_response(url, body=ConnectionError())
with mock.patch( with (
"zerver.lib.url_preview.preview.get_oembed_data", mock.patch(
side_effect=lambda *args, **kwargs: None, "zerver.lib.url_preview.preview.get_oembed_data",
): side_effect=lambda *args, **kwargs: None,
with mock.patch( ),
mock.patch(
"zerver.lib.url_preview.preview.valid_content_type", side_effect=lambda k: True "zerver.lib.url_preview.preview.valid_content_type", side_effect=lambda k: True
): ),
with self.settings(TEST_SUITE=False): self.settings(TEST_SUITE=False),
with self.assertLogs(level="INFO") as info_logs: ):
FetchLinksEmbedData().consume(event) with self.assertLogs(level="INFO") as info_logs:
self.assertTrue( FetchLinksEmbedData().consume(event)
"INFO:root:Time spent on get_link_embed_data for http://test.org/: " self.assertTrue(
in info_logs.output[0] "INFO:root:Time spent on get_link_embed_data for http://test.org/: "
) in info_logs.output[0]
)
# This did not get cached -- hence the lack of [0] on the cache_get # This did not get cached -- hence the lack of [0] on the cache_get
cached_data = cache_get(preview_url_cache_key(url)) cached_data = cache_get(preview_url_cache_key(url))
self.assertIsNone(cached_data) self.assertIsNone(cached_data)
msg.refresh_from_db() msg.refresh_from_db()
self.assertEqual( self.assertEqual(
@ -939,13 +938,15 @@ class PreviewTestCase(ZulipTestCase):
) )
self.create_mock_response(url) self.create_mock_response(url)
with self.settings(TEST_SUITE=False): with self.settings(TEST_SUITE=False):
with self.assertLogs(level="INFO") as info_logs: with (
with mock.patch( self.assertLogs(level="INFO") as info_logs,
mock.patch(
"zerver.lib.url_preview.preview.get_oembed_data", "zerver.lib.url_preview.preview.get_oembed_data",
lambda *args, **kwargs: mocked_data, lambda *args, **kwargs: mocked_data,
): ),
FetchLinksEmbedData().consume(event) ):
cached_data = cache_get(preview_url_cache_key(url))[0] FetchLinksEmbedData().consume(event)
cached_data = cache_get(preview_url_cache_key(url))[0]
self.assertTrue( self.assertTrue(
"INFO:root:Time spent on get_link_embed_data for http://test.org/: " "INFO:root:Time spent on get_link_embed_data for http://test.org/: "
in info_logs.output[0] in info_logs.output[0]
@ -979,12 +980,14 @@ class PreviewTestCase(ZulipTestCase):
) )
self.create_mock_response(url) self.create_mock_response(url)
with self.settings(TEST_SUITE=False): with self.settings(TEST_SUITE=False):
with self.assertLogs(level="INFO") as info_logs: with (
with mock.patch( self.assertLogs(level="INFO") as info_logs,
mock.patch(
"zerver.worker.embed_links.url_preview.get_link_embed_data", "zerver.worker.embed_links.url_preview.get_link_embed_data",
lambda *args, **kwargs: mocked_data, lambda *args, **kwargs: mocked_data,
): ),
FetchLinksEmbedData().consume(event) ):
FetchLinksEmbedData().consume(event)
self.assertTrue( self.assertTrue(
"INFO:root:Time spent on get_link_embed_data for https://www.youtube.com/watch?v=eSJTXC7Ixgg:" "INFO:root:Time spent on get_link_embed_data for https://www.youtube.com/watch?v=eSJTXC7Ixgg:"
in info_logs.output[0] in info_logs.output[0]
@ -1017,12 +1020,14 @@ class PreviewTestCase(ZulipTestCase):
) )
self.create_mock_response(url) self.create_mock_response(url)
with self.settings(TEST_SUITE=False): with self.settings(TEST_SUITE=False):
with self.assertLogs(level="INFO") as info_logs: with (
with mock.patch( self.assertLogs(level="INFO") as info_logs,
mock.patch(
"zerver.worker.embed_links.url_preview.get_link_embed_data", "zerver.worker.embed_links.url_preview.get_link_embed_data",
lambda *args, **kwargs: mocked_data, lambda *args, **kwargs: mocked_data,
): ),
FetchLinksEmbedData().consume(event) ):
FetchLinksEmbedData().consume(event)
self.assertTrue( self.assertTrue(
"INFO:root:Time spent on get_link_embed_data for [YouTube link](https://www.youtube.com/watch?v=eSJTXC7Ixgg):" "INFO:root:Time spent on get_link_embed_data for [YouTube link](https://www.youtube.com/watch?v=eSJTXC7Ixgg):"
in info_logs.output[0] in info_logs.output[0]

View File

@ -29,11 +29,13 @@ from zerver.models.users import get_user_profile_by_email
class TestCheckConfig(ZulipTestCase): class TestCheckConfig(ZulipTestCase):
def test_check_config(self) -> None: def test_check_config(self) -> None:
check_config() check_config()
with self.settings(REQUIRED_SETTINGS=[("asdf", "not asdf")]): with (
with self.assertRaisesRegex( self.settings(REQUIRED_SETTINGS=[("asdf", "not asdf")]),
self.assertRaisesRegex(
CommandError, "Error: You must set asdf in /etc/zulip/settings.py." CommandError, "Error: You must set asdf in /etc/zulip/settings.py."
): ),
check_config() ):
check_config()
@override_settings(WARN_NO_EMAIL=True) @override_settings(WARN_NO_EMAIL=True)
def test_check_send_email(self) -> None: def test_check_send_email(self) -> None:
@ -210,9 +212,8 @@ class TestCommandsCanStart(ZulipTestCase):
def test_management_commands_show_help(self) -> None: def test_management_commands_show_help(self) -> None:
with stdout_suppressed(): with stdout_suppressed():
for command in self.commands: for command in self.commands:
with self.subTest(management_command=command): with self.subTest(management_command=command), self.assertRaises(SystemExit):
with self.assertRaises(SystemExit): call_command(command, "--help")
call_command(command, "--help")
# zerver/management/commands/runtornado.py sets this to True; # zerver/management/commands/runtornado.py sets this to True;
# we need to reset it here. See #3685 for details. # we need to reset it here. See #3685 for details.
settings.RUNNING_INSIDE_TORNADO = False settings.RUNNING_INSIDE_TORNADO = False

View File

@ -1104,9 +1104,11 @@ class MarkdownTest(ZulipTestCase):
) )
def test_fetch_tweet_data_settings_validation(self) -> None: def test_fetch_tweet_data_settings_validation(self) -> None:
with self.settings(TEST_SUITE=False, TWITTER_CONSUMER_KEY=None): with (
with self.assertRaises(NotImplementedError): self.settings(TEST_SUITE=False, TWITTER_CONSUMER_KEY=None),
fetch_tweet_data("287977969287315459") self.assertRaises(NotImplementedError),
):
fetch_tweet_data("287977969287315459")
def test_content_has_emoji(self) -> None: def test_content_has_emoji(self) -> None:
self.assertFalse(content_has_emoji_syntax("boring")) self.assertFalse(content_has_emoji_syntax("boring"))
@ -1710,9 +1712,11 @@ class MarkdownTest(ZulipTestCase):
self.assertEqual(linkifiers_for_realm(realm.id), []) self.assertEqual(linkifiers_for_realm(realm.id), [])
# Verify that our in-memory cache avoids round trips. # Verify that our in-memory cache avoids round trips.
with self.assert_database_query_count(0, keep_cache_warm=True): with (
with self.assert_memcached_count(0): self.assert_database_query_count(0, keep_cache_warm=True),
self.assertEqual(linkifiers_for_realm(realm.id), []) self.assert_memcached_count(0),
):
self.assertEqual(linkifiers_for_realm(realm.id), [])
linkifier = RealmFilter(realm=realm, pattern=r"whatever", url_template="whatever") linkifier = RealmFilter(realm=realm, pattern=r"whatever", url_template="whatever")
linkifier.save() linkifier.save()
@ -1724,12 +1728,14 @@ class MarkdownTest(ZulipTestCase):
) )
# And the in-process cache works again. # And the in-process cache works again.
with self.assert_database_query_count(0, keep_cache_warm=True): with (
with self.assert_memcached_count(0): self.assert_database_query_count(0, keep_cache_warm=True),
self.assertEqual( self.assert_memcached_count(0),
linkifiers_for_realm(realm.id), ):
[{"id": linkifier.id, "pattern": "whatever", "url_template": "whatever"}], self.assertEqual(
) linkifiers_for_realm(realm.id),
[{"id": linkifier.id, "pattern": "whatever", "url_template": "whatever"}],
)
def test_alert_words(self) -> None: def test_alert_words(self) -> None:
user_profile = self.example_user("othello") user_profile = self.example_user("othello")
@ -3289,17 +3295,18 @@ class MarkdownApiTests(ZulipTestCase):
class MarkdownErrorTests(ZulipTestCase): class MarkdownErrorTests(ZulipTestCase):
def test_markdown_error_handling(self) -> None: def test_markdown_error_handling(self) -> None:
with self.simulated_markdown_failure(): with self.simulated_markdown_failure(), self.assertRaises(MarkdownRenderingError):
with self.assertRaises(MarkdownRenderingError): markdown_convert_wrapper("")
markdown_convert_wrapper("")
def test_send_message_errors(self) -> None: def test_send_message_errors(self) -> None:
message = "whatever" message = "whatever"
with self.simulated_markdown_failure(): with (
self.simulated_markdown_failure(),
# We don't use assertRaisesRegex because it seems to not # We don't use assertRaisesRegex because it seems to not
# handle i18n properly here on some systems. # handle i18n properly here on some systems.
with self.assertRaises(JsonableError): self.assertRaises(JsonableError),
self.send_stream_message(self.example_user("othello"), "Denmark", message) ):
self.send_stream_message(self.example_user("othello"), "Denmark", message)
@override_settings(MAX_MESSAGE_LENGTH=10) @override_settings(MAX_MESSAGE_LENGTH=10)
def test_ultra_long_rendering(self) -> None: def test_ultra_long_rendering(self) -> None:
@ -3310,9 +3317,9 @@ class MarkdownErrorTests(ZulipTestCase):
with ( with (
mock.patch("zerver.lib.markdown.unsafe_timeout", return_value=msg), mock.patch("zerver.lib.markdown.unsafe_timeout", return_value=msg),
mock.patch("zerver.lib.markdown.markdown_logger"), mock.patch("zerver.lib.markdown.markdown_logger"),
self.assertRaises(MarkdownRenderingError),
): ):
with self.assertRaises(MarkdownRenderingError): markdown_convert_wrapper(msg)
markdown_convert_wrapper(msg)
def test_curl_code_block_validation(self) -> None: def test_curl_code_block_validation(self) -> None:
processor = SimulatedFencedBlockPreprocessor(Markdown()) processor = SimulatedFencedBlockPreprocessor(Markdown())

View File

@ -301,12 +301,14 @@ class DeleteMessageTest(ZulipTestCase):
self.send_stream_message(hamlet, "Denmark") self.send_stream_message(hamlet, "Denmark")
message = self.get_last_message() message = self.get_last_message()
with self.capture_send_event_calls(expected_num_events=1): with (
with mock.patch("zerver.tornado.django_api.queue_json_publish") as m: self.capture_send_event_calls(expected_num_events=1),
m.side_effect = AssertionError( mock.patch("zerver.tornado.django_api.queue_json_publish") as m,
"Events should be sent only after the transaction commits." ):
) m.side_effect = AssertionError(
do_delete_messages(hamlet.realm, [message]) "Events should be sent only after the transaction commits."
)
do_delete_messages(hamlet.realm, [message])
def test_delete_message_in_unsubscribed_private_stream(self) -> None: def test_delete_message_in_unsubscribed_private_stream(self) -> None:
hamlet = self.example_user("hamlet") hamlet = self.example_user("hamlet")

View File

@ -100,9 +100,11 @@ class EditMessageSideEffectsTest(ZulipTestCase):
content=content, content=content,
) )
with mock.patch("zerver.tornado.event_queue.maybe_enqueue_notifications") as m: with (
with self.captureOnCommitCallbacks(execute=True): mock.patch("zerver.tornado.event_queue.maybe_enqueue_notifications") as m,
result = self.client_patch(url, request) self.captureOnCommitCallbacks(execute=True),
):
result = self.client_patch(url, request)
cordelia = self.example_user("cordelia") cordelia = self.example_user("cordelia")
cordelia_calls = [ cordelia_calls = [

View File

@ -4203,14 +4203,13 @@ class GetOldMessagesTest(ZulipTestCase):
request = HostRequestMock(query_params, user_profile) request = HostRequestMock(query_params, user_profile)
first_visible_message_id = first_unread_message_id + 2 first_visible_message_id = first_unread_message_id + 2
with first_visible_id_as(first_visible_message_id): with first_visible_id_as(first_visible_message_id), queries_captured() as all_queries:
with queries_captured() as all_queries: get_messages_backend(
get_messages_backend( request,
request, user_profile,
user_profile, num_before=10,
num_before=10, num_after=10,
num_after=10, )
)
queries = [q for q in all_queries if "/* get_messages */" in q.sql] queries = [q for q in all_queries if "/* get_messages */" in q.sql]
self.assert_length(queries, 1) self.assert_length(queries, 1)

View File

@ -2118,9 +2118,11 @@ class StreamMessagesTest(ZulipTestCase):
self.subscribe(cordelia, "test_stream") self.subscribe(cordelia, "test_stream")
do_set_realm_property(cordelia.realm, "wildcard_mention_policy", 10, acting_user=None) do_set_realm_property(cordelia.realm, "wildcard_mention_policy", 10, acting_user=None)
content = "@**all** test wildcard mention" content = "@**all** test wildcard mention"
with mock.patch("zerver.lib.message.num_subscribers_for_stream_id", return_value=16): with (
with self.assertRaisesRegex(AssertionError, "Invalid wildcard mention policy"): mock.patch("zerver.lib.message.num_subscribers_for_stream_id", return_value=16),
self.send_stream_message(cordelia, "test_stream", content) self.assertRaisesRegex(AssertionError, "Invalid wildcard mention policy"),
):
self.send_stream_message(cordelia, "test_stream", content)
def test_user_group_mention_restrictions(self) -> None: def test_user_group_mention_restrictions(self) -> None:
iago = self.example_user("iago") iago = self.example_user("iago")

View File

@ -630,13 +630,15 @@ class TestOutgoingWebhookMessaging(ZulipTestCase):
"https://bot.example.com/", "https://bot.example.com/",
body=requests.exceptions.Timeout("Time is up!"), body=requests.exceptions.Timeout("Time is up!"),
) )
with mock.patch( with (
"zerver.lib.outgoing_webhook.fail_with_message", side_effect=wrapped mock.patch(
) as fail: "zerver.lib.outgoing_webhook.fail_with_message", side_effect=wrapped
with self.assertLogs(level="INFO") as logs: ) as fail,
self.send_stream_message( self.assertLogs(level="INFO") as logs,
bot_owner, "Denmark", content=f"@**{bot.full_name}** foo", topic_name="bar" ):
) self.send_stream_message(
bot_owner, "Denmark", content=f"@**{bot.full_name}** foo", topic_name="bar"
)
self.assert_length(logs.output, 5) self.assert_length(logs.output, 5)
fail.assert_called_once() fail.assert_called_once()

View File

@ -1103,31 +1103,33 @@ class PushBouncerNotificationTest(BouncerTestCase):
not_configured_warn_log, not_configured_warn_log,
) )
with mock.patch( with (
"zerver.lib.push_notifications.uses_notification_bouncer", return_value=True mock.patch(
"zerver.lib.push_notifications.uses_notification_bouncer", return_value=True
),
mock.patch("zerver.lib.remote_server.send_to_push_bouncer") as m,
): ):
with mock.patch("zerver.lib.remote_server.send_to_push_bouncer") as m: post_response = {
post_response = { "realms": {realm.uuid: {"can_push": True, "expected_end_timestamp": None}}
"realms": {realm.uuid: {"can_push": True, "expected_end_timestamp": None}} }
} get_response = {
get_response = { "last_realm_count_id": 0,
"last_realm_count_id": 0, "last_installation_count_id": 0,
"last_installation_count_id": 0, "last_realmauditlog_id": 0,
"last_realmauditlog_id": 0, }
}
def mock_send_to_push_bouncer_response(method: str, *args: Any) -> dict[str, Any]: def mock_send_to_push_bouncer_response(method: str, *args: Any) -> dict[str, Any]:
if method == "POST": if method == "POST":
return post_response return post_response
return get_response return get_response
m.side_effect = mock_send_to_push_bouncer_response m.side_effect = mock_send_to_push_bouncer_response
initialize_push_notifications() initialize_push_notifications()
realm = get_realm("zulip") realm = get_realm("zulip")
self.assertTrue(realm.push_notifications_enabled) self.assertTrue(realm.push_notifications_enabled)
self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
@override_settings(PUSH_NOTIFICATION_BOUNCER_URL="https://push.zulip.org.example.com") @override_settings(PUSH_NOTIFICATION_BOUNCER_URL="https://push.zulip.org.example.com")
@responses.activate @responses.activate
@ -2340,84 +2342,90 @@ class AnalyticsBouncerTest(BouncerTestCase):
def test_realm_properties_after_send_analytics(self) -> None: def test_realm_properties_after_send_analytics(self) -> None:
self.add_mock_response() self.add_mock_response()
with mock.patch( with (
"corporate.lib.stripe.RemoteRealmBillingSession.get_customer", return_value=None mock.patch(
) as m: "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", return_value=None
with mock.patch( ) as m,
mock.patch(
"corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses",
return_value=10, return_value=10,
): ),
send_server_data_to_push_bouncer(consider_usage_statistics=False) ):
m.assert_called() send_server_data_to_push_bouncer(consider_usage_statistics=False)
realms = Realm.objects.all() m.assert_called()
for realm in realms: realms = Realm.objects.all()
self.assertEqual(realm.push_notifications_enabled, True) for realm in realms:
self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) self.assertEqual(realm.push_notifications_enabled, True)
self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
with mock.patch( with (
"zilencer.views.RemoteRealmBillingSession.get_customer", return_value=None mock.patch(
) as m: "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=None
with mock.patch( ) as m,
mock.patch(
"corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses",
return_value=11, return_value=11,
): ),
send_server_data_to_push_bouncer(consider_usage_statistics=False) ):
m.assert_called() send_server_data_to_push_bouncer(consider_usage_statistics=False)
realms = Realm.objects.all() m.assert_called()
for realm in realms: realms = Realm.objects.all()
self.assertEqual(realm.push_notifications_enabled, False) for realm in realms:
self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) self.assertEqual(realm.push_notifications_enabled, False)
self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
dummy_customer = mock.MagicMock() dummy_customer = mock.MagicMock()
with mock.patch( with (
"corporate.lib.stripe.RemoteRealmBillingSession.get_customer", mock.patch(
return_value=dummy_customer, "corporate.lib.stripe.RemoteRealmBillingSession.get_customer",
return_value=dummy_customer,
),
mock.patch("corporate.lib.stripe.get_current_plan_by_customer", return_value=None) as m,
): ):
with mock.patch( send_server_data_to_push_bouncer(consider_usage_statistics=False)
"corporate.lib.stripe.get_current_plan_by_customer", return_value=None m.assert_called()
) as m: realms = Realm.objects.all()
send_server_data_to_push_bouncer(consider_usage_statistics=False) for realm in realms:
m.assert_called() self.assertEqual(realm.push_notifications_enabled, True)
realms = Realm.objects.all() self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
for realm in realms:
self.assertEqual(realm.push_notifications_enabled, True)
self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
dummy_customer = mock.MagicMock() dummy_customer = mock.MagicMock()
with mock.patch( with (
"zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer mock.patch(
"zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer
),
mock.patch("corporate.lib.stripe.get_current_plan_by_customer", return_value=None) as m,
mock.patch(
"corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses",
return_value=11,
),
): ):
with mock.patch( send_server_data_to_push_bouncer(consider_usage_statistics=False)
"corporate.lib.stripe.get_current_plan_by_customer", return_value=None m.assert_called()
) as m: realms = Realm.objects.all()
with mock.patch( for realm in realms:
"corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", self.assertEqual(realm.push_notifications_enabled, False)
return_value=11, self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
):
send_server_data_to_push_bouncer(consider_usage_statistics=False)
m.assert_called()
realms = Realm.objects.all()
for realm in realms:
self.assertEqual(realm.push_notifications_enabled, False)
self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
RemoteRealm.objects.filter(server=self.server).update( RemoteRealm.objects.filter(server=self.server).update(
plan_type=RemoteRealm.PLAN_TYPE_COMMUNITY plan_type=RemoteRealm.PLAN_TYPE_COMMUNITY
) )
with mock.patch( with (
"zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer mock.patch(
"zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer
),
mock.patch("corporate.lib.stripe.get_current_plan_by_customer", return_value=None),
mock.patch(
"corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses"
) as m,
): ):
with mock.patch("corporate.lib.stripe.get_current_plan_by_customer", return_value=None): send_server_data_to_push_bouncer(consider_usage_statistics=False)
with mock.patch( m.assert_not_called()
"corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses" realms = Realm.objects.all()
) as m: for realm in realms:
send_server_data_to_push_bouncer(consider_usage_statistics=False) self.assertEqual(realm.push_notifications_enabled, True)
m.assert_not_called() self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
realms = Realm.objects.all()
for realm in realms:
self.assertEqual(realm.push_notifications_enabled, True)
self.assertEqual(realm.push_notifications_enabled_end_timestamp, None)
# Reset the plan type to test remaining cases. # Reset the plan type to test remaining cases.
RemoteRealm.objects.filter(server=self.server).update( RemoteRealm.objects.filter(server=self.server).update(
@ -2427,118 +2435,122 @@ class AnalyticsBouncerTest(BouncerTestCase):
dummy_customer_plan = mock.MagicMock() dummy_customer_plan = mock.MagicMock()
dummy_customer_plan.status = CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE dummy_customer_plan.status = CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE
dummy_date = datetime(year=2023, month=12, day=3, tzinfo=timezone.utc) dummy_date = datetime(year=2023, month=12, day=3, tzinfo=timezone.utc)
with mock.patch( with (
"corporate.lib.stripe.RemoteRealmBillingSession.get_customer", mock.patch(
return_value=dummy_customer, "corporate.lib.stripe.RemoteRealmBillingSession.get_customer",
): return_value=dummy_customer,
with mock.patch( ),
mock.patch(
"corporate.lib.stripe.get_current_plan_by_customer", "corporate.lib.stripe.get_current_plan_by_customer",
return_value=dummy_customer_plan, return_value=dummy_customer_plan,
): ),
with mock.patch( mock.patch(
"corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses",
return_value=11, return_value=11,
): ),
with ( mock.patch(
mock.patch( "corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle",
"corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle", return_value=dummy_date,
return_value=dummy_date, ) as m,
) as m, self.assertLogs("zulip.analytics", level="INFO") as info_log,
self.assertLogs("zulip.analytics", level="INFO") as info_log, ):
): send_server_data_to_push_bouncer(consider_usage_statistics=False)
send_server_data_to_push_bouncer(consider_usage_statistics=False) m.assert_called()
m.assert_called() realms = Realm.objects.all()
realms = Realm.objects.all() for realm in realms:
for realm in realms: self.assertEqual(realm.push_notifications_enabled, True)
self.assertEqual(realm.push_notifications_enabled, True) self.assertEqual(
self.assertEqual( realm.push_notifications_enabled_end_timestamp,
realm.push_notifications_enabled_end_timestamp, dummy_date,
dummy_date, )
) self.assertIn(
self.assertIn( "INFO:zulip.analytics:Reported 0 records",
"INFO:zulip.analytics:Reported 0 records", info_log.output[0],
info_log.output[0], )
)
with mock.patch( with (
"corporate.lib.stripe.RemoteRealmBillingSession.get_customer", mock.patch(
return_value=dummy_customer, "corporate.lib.stripe.RemoteRealmBillingSession.get_customer",
): return_value=dummy_customer,
with mock.patch( ),
mock.patch(
"corporate.lib.stripe.get_current_plan_by_customer", "corporate.lib.stripe.get_current_plan_by_customer",
return_value=dummy_customer_plan, return_value=dummy_customer_plan,
): ),
with mock.patch( mock.patch(
"corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses",
side_effect=MissingDataError, side_effect=MissingDataError,
): ),
with ( mock.patch(
mock.patch( "corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle",
"corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle", return_value=dummy_date,
return_value=dummy_date, ) as m,
) as m, self.assertLogs("zulip.analytics", level="INFO") as info_log,
self.assertLogs("zulip.analytics", level="INFO") as info_log, ):
): send_server_data_to_push_bouncer(consider_usage_statistics=False)
send_server_data_to_push_bouncer(consider_usage_statistics=False) m.assert_called()
m.assert_called() realms = Realm.objects.all()
realms = Realm.objects.all() for realm in realms:
for realm in realms: self.assertEqual(realm.push_notifications_enabled, True)
self.assertEqual(realm.push_notifications_enabled, True) self.assertEqual(
self.assertEqual( realm.push_notifications_enabled_end_timestamp,
realm.push_notifications_enabled_end_timestamp, dummy_date,
dummy_date, )
) self.assertIn(
self.assertIn( "INFO:zulip.analytics:Reported 0 records",
"INFO:zulip.analytics:Reported 0 records", info_log.output[0],
info_log.output[0], )
)
with mock.patch( with (
"corporate.lib.stripe.RemoteRealmBillingSession.get_customer", mock.patch(
return_value=dummy_customer, "corporate.lib.stripe.RemoteRealmBillingSession.get_customer",
): return_value=dummy_customer,
with mock.patch( ),
mock.patch(
"corporate.lib.stripe.get_current_plan_by_customer", "corporate.lib.stripe.get_current_plan_by_customer",
return_value=dummy_customer_plan, return_value=dummy_customer_plan,
): ),
with mock.patch( mock.patch(
"corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses",
return_value=10, return_value=10,
): ),
send_server_data_to_push_bouncer(consider_usage_statistics=False) ):
m.assert_called() send_server_data_to_push_bouncer(consider_usage_statistics=False)
realms = Realm.objects.all() m.assert_called()
for realm in realms: realms = Realm.objects.all()
self.assertEqual(realm.push_notifications_enabled, True) for realm in realms:
self.assertEqual( self.assertEqual(realm.push_notifications_enabled, True)
realm.push_notifications_enabled_end_timestamp, self.assertEqual(
None, realm.push_notifications_enabled_end_timestamp,
) None,
)
dummy_customer_plan = mock.MagicMock() dummy_customer_plan = mock.MagicMock()
dummy_customer_plan.status = CustomerPlan.ACTIVE dummy_customer_plan.status = CustomerPlan.ACTIVE
with mock.patch( with (
"corporate.lib.stripe.RemoteRealmBillingSession.get_customer", mock.patch(
return_value=dummy_customer, "corporate.lib.stripe.RemoteRealmBillingSession.get_customer",
): return_value=dummy_customer,
with mock.patch( ),
mock.patch(
"corporate.lib.stripe.get_current_plan_by_customer", "corporate.lib.stripe.get_current_plan_by_customer",
return_value=dummy_customer_plan, return_value=dummy_customer_plan,
): ),
with self.assertLogs("zulip.analytics", level="INFO") as info_log: self.assertLogs("zulip.analytics", level="INFO") as info_log,
send_server_data_to_push_bouncer(consider_usage_statistics=False) ):
m.assert_called() send_server_data_to_push_bouncer(consider_usage_statistics=False)
realms = Realm.objects.all() m.assert_called()
for realm in realms: realms = Realm.objects.all()
self.assertEqual(realm.push_notifications_enabled, True) for realm in realms:
self.assertEqual( self.assertEqual(realm.push_notifications_enabled, True)
realm.push_notifications_enabled_end_timestamp, self.assertEqual(
None, realm.push_notifications_enabled_end_timestamp,
) None,
self.assertIn( )
"INFO:zulip.analytics:Reported 0 records", self.assertIn(
info_log.output[0], "INFO:zulip.analytics:Reported 0 records",
) info_log.output[0],
)
# Remote realm is on an inactive plan. Remote server on active plan. # Remote realm is on an inactive plan. Remote server on active plan.
# ACTIVE plan takes precedence. # ACTIVE plan takes precedence.

View File

@ -377,13 +377,16 @@ class WorkerTest(ZulipTestCase):
# If called after `expected_scheduled_timestamp`, it should process all emails. # If called after `expected_scheduled_timestamp`, it should process all emails.
one_minute_overdue = expected_scheduled_timestamp + timedelta(seconds=60) one_minute_overdue = expected_scheduled_timestamp + timedelta(seconds=60)
with time_machine.travel(one_minute_overdue, tick=True): with (
with send_mock as sm, self.assertLogs(level="INFO") as info_logs: time_machine.travel(one_minute_overdue, tick=True),
has_timeout = advance() send_mock as sm,
self.assertTrue(has_timeout) self.assertLogs(level="INFO") as info_logs,
self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0) ):
has_timeout = advance() has_timeout = advance()
self.assertFalse(has_timeout) self.assertTrue(has_timeout)
self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0)
has_timeout = advance()
self.assertFalse(has_timeout)
self.assertEqual( self.assertEqual(
[ [
@ -643,20 +646,22 @@ class WorkerTest(ZulipTestCase):
self.assertEqual(mock_mirror_email.call_count, 4) self.assertEqual(mock_mirror_email.call_count, 4)
# If RateLimiterLockingError is thrown, we rate-limit the new message: # If RateLimiterLockingError is thrown, we rate-limit the new message:
with patch( with (
"zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit", patch(
side_effect=RateLimiterLockingError, "zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit",
side_effect=RateLimiterLockingError,
),
self.assertLogs("zerver.lib.rate_limiter", "WARNING") as mock_warn,
): ):
with self.assertLogs("zerver.lib.rate_limiter", "WARNING") as mock_warn: fake_client.enqueue("email_mirror", data[0])
fake_client.enqueue("email_mirror", data[0]) worker.start()
worker.start() self.assertEqual(mock_mirror_email.call_count, 4)
self.assertEqual(mock_mirror_email.call_count, 4) self.assertEqual(
self.assertEqual( mock_warn.output,
mock_warn.output, [
[ "WARNING:zerver.lib.rate_limiter:Deadlock trying to incr_ratelimit for RateLimitedRealmMirror:zulip"
"WARNING:zerver.lib.rate_limiter:Deadlock trying to incr_ratelimit for RateLimitedRealmMirror:zulip" ],
], )
)
self.assertEqual( self.assertEqual(
warn_logs.output, warn_logs.output,
[ [

View File

@ -1054,12 +1054,14 @@ class ReactionAPIEventTest(EmojiReactionBase):
"emoji_code": "1f354", "emoji_code": "1f354",
"reaction_type": "unicode_emoji", "reaction_type": "unicode_emoji",
} }
with self.capture_send_event_calls(expected_num_events=1) as events: with (
with mock.patch("zerver.tornado.django_api.queue_json_publish") as m: self.capture_send_event_calls(expected_num_events=1) as events,
m.side_effect = AssertionError( mock.patch("zerver.tornado.django_api.queue_json_publish") as m,
"Events should be sent only after the transaction commits!" ):
) m.side_effect = AssertionError(
self.api_post(reaction_sender, f"/api/v1/messages/{pm_id}/reactions", reaction_info) "Events should be sent only after the transaction commits!"
)
self.api_post(reaction_sender, f"/api/v1/messages/{pm_id}/reactions", reaction_info)
event = events[0]["event"] event = events[0]["event"]
event_user_ids = set(events[0]["users"]) event_user_ids = set(events[0]["users"])
@ -1137,9 +1139,11 @@ class ReactionAPIEventTest(EmojiReactionBase):
reaction_type="whatever", reaction_type="whatever",
) )
with self.capture_send_event_calls(expected_num_events=1): with (
with mock.patch("zerver.tornado.django_api.queue_json_publish") as m: self.capture_send_event_calls(expected_num_events=1),
m.side_effect = AssertionError( mock.patch("zerver.tornado.django_api.queue_json_publish") as m,
"Events should be sent only after the transaction commits." ):
) m.side_effect = AssertionError(
notify_reaction_update(hamlet, message, reaction, "stuff") "Events should be sent only after the transaction commits."
)
notify_reaction_update(hamlet, message, reaction, "stuff")

View File

@ -95,13 +95,14 @@ class RealmTest(ZulipTestCase):
) )
def test_realm_creation_on_special_subdomains_disallowed(self) -> None: def test_realm_creation_on_special_subdomains_disallowed(self) -> None:
with self.settings(SOCIAL_AUTH_SUBDOMAIN="zulipauth"): with self.settings(SOCIAL_AUTH_SUBDOMAIN="zulipauth"), self.assertRaises(AssertionError):
with self.assertRaises(AssertionError): do_create_realm("zulipauth", "Test Realm")
do_create_realm("zulipauth", "Test Realm")
with self.settings(SELF_HOSTING_MANAGEMENT_SUBDOMAIN="zulipselfhosting"): with (
with self.assertRaises(AssertionError): self.settings(SELF_HOSTING_MANAGEMENT_SUBDOMAIN="zulipselfhosting"),
do_create_realm("zulipselfhosting", "Test Realm") self.assertRaises(AssertionError),
):
do_create_realm("zulipselfhosting", "Test Realm")
def test_permission_for_education_non_profit_organization(self) -> None: def test_permission_for_education_non_profit_organization(self) -> None:
realm = do_create_realm( realm = do_create_realm(

View File

@ -315,9 +315,8 @@ class RealmEmojiTest(ZulipTestCase):
def test_emoji_upload_file_size_error(self) -> None: def test_emoji_upload_file_size_error(self) -> None:
self.login("iago") self.login("iago")
with get_test_image_file("img.png") as fp: with get_test_image_file("img.png") as fp, self.settings(MAX_EMOJI_FILE_SIZE_MIB=0):
with self.settings(MAX_EMOJI_FILE_SIZE_MIB=0): result = self.client_post("/json/realm/emoji/my_emoji", {"file": fp})
result = self.client_post("/json/realm/emoji/my_emoji", {"file": fp})
self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB") self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB")
def test_emoji_upload_file_format_error(self) -> None: def test_emoji_upload_file_format_error(self) -> None:
@ -355,12 +354,14 @@ class RealmEmojiTest(ZulipTestCase):
def test_failed_file_upload(self) -> None: def test_failed_file_upload(self) -> None:
self.login("iago") self.login("iago")
with mock.patch( with (
"zerver.lib.upload.local.write_local_file", side_effect=BadImageError(msg="Broken") mock.patch(
"zerver.lib.upload.local.write_local_file", side_effect=BadImageError(msg="Broken")
),
get_test_image_file("img.png") as fp1,
): ):
with get_test_image_file("img.png") as fp1: emoji_data = {"f1": fp1}
emoji_data = {"f1": fp1} result = self.client_post("/json/realm/emoji/my_emoji", info=emoji_data)
result = self.client_post("/json/realm/emoji/my_emoji", info=emoji_data)
self.assert_json_error(result, "Broken") self.assert_json_error(result, "Broken")
def test_check_admin_realm_emoji(self) -> None: def test_check_admin_realm_emoji(self) -> None:

View File

@ -49,9 +49,9 @@ class RealmExportTest(ZulipTestCase):
self.settings(LOCAL_UPLOADS_DIR=None), self.settings(LOCAL_UPLOADS_DIR=None),
stdout_suppressed(), stdout_suppressed(),
self.assertLogs(level="INFO") as info_logs, self.assertLogs(level="INFO") as info_logs,
self.captureOnCommitCallbacks(execute=True),
): ):
with self.captureOnCommitCallbacks(execute=True): result = self.client_post("/json/export/realm")
result = self.client_post("/json/export/realm")
self.assertTrue("INFO:root:Completed data export for zulip in " in info_logs.output[0]) self.assertTrue("INFO:root:Completed data export for zulip in " in info_logs.output[0])
self.assert_json_success(result) self.assert_json_success(result)
self.assertFalse(os.path.exists(tarball_path)) self.assertFalse(os.path.exists(tarball_path))
@ -150,9 +150,12 @@ class RealmExportTest(ZulipTestCase):
with patch( with patch(
"zerver.lib.export.do_export_realm", side_effect=fake_export_realm "zerver.lib.export.do_export_realm", side_effect=fake_export_realm
) as mock_export: ) as mock_export:
with stdout_suppressed(), self.assertLogs(level="INFO") as info_logs: with (
with self.captureOnCommitCallbacks(execute=True): stdout_suppressed(),
result = self.client_post("/json/export/realm") self.assertLogs(level="INFO") as info_logs,
self.captureOnCommitCallbacks(execute=True),
):
result = self.client_post("/json/export/realm")
self.assertTrue("INFO:root:Completed data export for zulip in " in info_logs.output[0]) self.assertTrue("INFO:root:Completed data export for zulip in " in info_logs.output[0])
mock_export.assert_called_once() mock_export.assert_called_once()
data = self.assert_json_success(result) data = self.assert_json_success(result)
@ -208,12 +211,15 @@ class RealmExportTest(ZulipTestCase):
admin = self.example_user("iago") admin = self.example_user("iago")
self.login_user(admin) self.login_user(admin)
with patch( with (
"zerver.lib.export.do_export_realm", side_effect=Exception("failure") patch(
) as mock_export: "zerver.lib.export.do_export_realm", side_effect=Exception("failure")
with stdout_suppressed(), self.assertLogs(level="INFO") as info_logs: ) as mock_export,
with self.captureOnCommitCallbacks(execute=True): stdout_suppressed(),
result = self.client_post("/json/export/realm") self.assertLogs(level="INFO") as info_logs,
self.captureOnCommitCallbacks(execute=True),
):
result = self.client_post("/json/export/realm")
self.assertTrue( self.assertTrue(
info_logs.output[0].startswith("ERROR:root:Data export for zulip failed after ") info_logs.output[0].startswith("ERROR:root:Data export for zulip failed after ")
) )
@ -240,18 +246,20 @@ class RealmExportTest(ZulipTestCase):
# If the queue worker sees the same export-id again, it aborts # If the queue worker sees the same export-id again, it aborts
# instead of retrying # instead of retrying
with patch("zerver.lib.export.do_export_realm") as mock_export: with (
with self.assertLogs(level="INFO") as info_logs: patch("zerver.lib.export.do_export_realm") as mock_export,
queue_json_publish( self.assertLogs(level="INFO") as info_logs,
"deferred_work", ):
{ queue_json_publish(
"type": "realm_export", "deferred_work",
"time": 42, {
"realm_id": admin.realm.id, "type": "realm_export",
"user_profile_id": admin.id, "time": 42,
"id": export_id, "realm_id": admin.realm.id,
}, "user_profile_id": admin.id,
) "id": export_id,
},
)
mock_export.assert_not_called() mock_export.assert_not_called()
self.assertEqual( self.assertEqual(
info_logs.output, info_logs.output,

View File

@ -132,15 +132,17 @@ class TestSendEmail(ZulipTestCase):
for message, side_effect in errors.items(): for message, side_effect in errors.items():
with mock.patch.object(EmailBackend, "send_messages", side_effect=side_effect): with mock.patch.object(EmailBackend, "send_messages", side_effect=side_effect):
with self.assertLogs(logger=logger) as info_log: with (
with self.assertRaises(EmailNotDeliveredError): self.assertLogs(logger=logger) as info_log,
send_email( self.assertRaises(EmailNotDeliveredError),
"zerver/emails/password_reset", ):
to_emails=[hamlet.email], send_email(
from_name=from_name, "zerver/emails/password_reset",
from_address=FromAddress.NOREPLY, to_emails=[hamlet.email],
language="en", from_name=from_name,
) from_address=FromAddress.NOREPLY,
language="en",
)
self.assert_length(info_log.records, 2) self.assert_length(info_log.records, 2)
self.assertEqual( self.assertEqual(
info_log.output[0], info_log.output[0],
@ -151,15 +153,17 @@ class TestSendEmail(ZulipTestCase):
def test_send_email_config_error_logging(self) -> None: def test_send_email_config_error_logging(self) -> None:
hamlet = self.example_user("hamlet") hamlet = self.example_user("hamlet")
with self.settings(EMAIL_HOST_USER="test", EMAIL_HOST_PASSWORD=None): with (
with self.assertLogs(logger=logger, level="ERROR") as error_log: self.settings(EMAIL_HOST_USER="test", EMAIL_HOST_PASSWORD=None),
send_email( self.assertLogs(logger=logger, level="ERROR") as error_log,
"zerver/emails/password_reset", ):
to_emails=[hamlet.email], send_email(
from_name="From Name", "zerver/emails/password_reset",
from_address=FromAddress.NOREPLY, to_emails=[hamlet.email],
language="en", from_name="From Name",
) from_address=FromAddress.NOREPLY,
language="en",
)
self.assertEqual( self.assertEqual(
error_log.output, error_log.output,

View File

@ -1050,9 +1050,12 @@ class LoginTest(ZulipTestCase):
# seem to be any O(N) behavior. Some of the cache hits are related # seem to be any O(N) behavior. Some of the cache hits are related
# to sending messages, such as getting the welcome bot, looking up # to sending messages, such as getting the welcome bot, looking up
# the alert words for a realm, etc. # the alert words for a realm, etc.
with self.assert_database_query_count(94), self.assert_memcached_count(14): with (
with self.captureOnCommitCallbacks(execute=True): self.assert_database_query_count(94),
self.register(self.nonreg_email("test"), "test") self.assert_memcached_count(14),
self.captureOnCommitCallbacks(execute=True),
):
self.register(self.nonreg_email("test"), "test")
user_profile = self.nonreg_user("test") user_profile = self.nonreg_user("test")
self.assert_logged_in_user_id(user_profile.id) self.assert_logged_in_user_id(user_profile.id)
@ -2946,21 +2949,23 @@ class UserSignUpTest(ZulipTestCase):
return_data = kwargs.get("return_data", {}) return_data = kwargs.get("return_data", {})
return_data["invalid_subdomain"] = True return_data["invalid_subdomain"] = True
with patch("zerver.views.registration.authenticate", side_effect=invalid_subdomain): with (
with self.assertLogs(level="ERROR") as m: patch("zerver.views.registration.authenticate", side_effect=invalid_subdomain),
result = self.client_post( self.assertLogs(level="ERROR") as m,
"/accounts/register/", ):
{ result = self.client_post(
"password": password, "/accounts/register/",
"full_name": "New User", {
"key": find_key_by_email(email), "password": password,
"terms": True, "full_name": "New User",
}, "key": find_key_by_email(email),
) "terms": True,
self.assertEqual( },
m.output, )
["ERROR:root:Subdomain mismatch in registration zulip: newuser@zulip.com"], self.assertEqual(
) m.output,
["ERROR:root:Subdomain mismatch in registration zulip: newuser@zulip.com"],
)
self.assertEqual(result.status_code, 302) self.assertEqual(result.status_code, 302)
def test_signup_using_invalid_subdomain_preserves_state_of_form(self) -> None: def test_signup_using_invalid_subdomain_preserves_state_of_form(self) -> None:

View File

@ -273,9 +273,11 @@ class UserSoftDeactivationTests(ZulipTestCase):
).count() ).count()
self.assertEqual(0, received_count) self.assertEqual(0, received_count)
with self.settings(AUTO_CATCH_UP_SOFT_DEACTIVATED_USERS=False): with (
with self.assertLogs(logger_string, level="INFO") as m: self.settings(AUTO_CATCH_UP_SOFT_DEACTIVATED_USERS=False),
users_deactivated = do_auto_soft_deactivate_users(-1, realm) self.assertLogs(logger_string, level="INFO") as m,
):
users_deactivated = do_auto_soft_deactivate_users(-1, realm)
self.assertEqual( self.assertEqual(
m.output, m.output,
[ [

View File

@ -194,12 +194,14 @@ class TestBasics(ZulipTestCase):
hamlet = self.example_user("hamlet") hamlet = self.example_user("hamlet")
message_id = self.send_stream_message(hamlet, "Denmark") message_id = self.send_stream_message(hamlet, "Denmark")
with self.capture_send_event_calls(expected_num_events=1): with (
with mock.patch("zerver.tornado.django_api.queue_json_publish") as m: self.capture_send_event_calls(expected_num_events=1),
m.side_effect = AssertionError( mock.patch("zerver.tornado.django_api.queue_json_publish") as m,
"Events should be sent only after the transaction commits." ):
) m.side_effect = AssertionError(
do_add_submessage(hamlet.realm, hamlet.id, message_id, "whatever", "whatever") "Events should be sent only after the transaction commits."
)
do_add_submessage(hamlet.realm, hamlet.id, message_id, "whatever", "whatever")
def test_fetch_message_containing_submessages(self) -> None: def test_fetch_message_containing_submessages(self) -> None:
cordelia = self.example_user("cordelia") cordelia = self.example_user("cordelia")

View File

@ -2607,16 +2607,18 @@ class StreamAdminTest(ZulipTestCase):
for user in other_sub_users: for user in other_sub_users:
self.subscribe(user, stream_name) self.subscribe(user, stream_name)
with self.assert_database_query_count(query_count): with (
with cache_tries_captured() as cache_tries: self.assert_database_query_count(query_count),
with self.captureOnCommitCallbacks(execute=True): cache_tries_captured() as cache_tries,
result = self.client_delete( self.captureOnCommitCallbacks(execute=True),
"/json/users/me/subscriptions", ):
{ result = self.client_delete(
"subscriptions": orjson.dumps([stream_name]).decode(), "/json/users/me/subscriptions",
"principals": orjson.dumps(principals).decode(), {
}, "subscriptions": orjson.dumps([stream_name]).decode(),
) "principals": orjson.dumps(principals).decode(),
},
)
if cache_count is not None: if cache_count is not None:
self.assert_length(cache_tries, cache_count) self.assert_length(cache_tries, cache_count)
@ -4744,13 +4746,15 @@ class SubscriptionAPITest(ZulipTestCase):
user2 = self.example_user("iago") user2 = self.example_user("iago")
realm = get_realm("zulip") realm = get_realm("zulip")
streams_to_sub = ["multi_user_stream"] streams_to_sub = ["multi_user_stream"]
with self.capture_send_event_calls(expected_num_events=5) as events: with (
with self.assert_database_query_count(38): self.capture_send_event_calls(expected_num_events=5) as events,
self.common_subscribe_to_streams( self.assert_database_query_count(38),
self.test_user, ):
streams_to_sub, self.common_subscribe_to_streams(
dict(principals=orjson.dumps([user1.id, user2.id]).decode()), self.test_user,
) streams_to_sub,
dict(principals=orjson.dumps([user1.id, user2.id]).decode()),
)
for ev in [x for x in events if x["event"]["type"] not in ("message", "stream")]: for ev in [x for x in events if x["event"]["type"] not in ("message", "stream")]:
if ev["event"]["op"] == "add": if ev["event"]["op"] == "add":
@ -4768,13 +4772,15 @@ class SubscriptionAPITest(ZulipTestCase):
self.assertEqual(num_subscribers_for_stream_id(stream.id), 2) self.assertEqual(num_subscribers_for_stream_id(stream.id), 2)
# Now add ourselves # Now add ourselves
with self.capture_send_event_calls(expected_num_events=2) as events: with (
with self.assert_database_query_count(14): self.capture_send_event_calls(expected_num_events=2) as events,
self.common_subscribe_to_streams( self.assert_database_query_count(14),
self.test_user, ):
streams_to_sub, self.common_subscribe_to_streams(
dict(principals=orjson.dumps([self.test_user.id]).decode()), self.test_user,
) streams_to_sub,
dict(principals=orjson.dumps([self.test_user.id]).decode()),
)
add_event, add_peer_event = events add_event, add_peer_event = events
self.assertEqual(add_event["event"]["type"], "subscription") self.assertEqual(add_event["event"]["type"], "subscription")
@ -5061,15 +5067,17 @@ class SubscriptionAPITest(ZulipTestCase):
# Sends 3 peer-remove events, 2 unsubscribe events # Sends 3 peer-remove events, 2 unsubscribe events
# and 2 stream delete events for private streams. # and 2 stream delete events for private streams.
with self.assert_database_query_count(16): with (
with self.assert_memcached_count(3): self.assert_database_query_count(16),
with self.capture_send_event_calls(expected_num_events=7) as events: self.assert_memcached_count(3),
bulk_remove_subscriptions( self.capture_send_event_calls(expected_num_events=7) as events,
realm, ):
[user1, user2], bulk_remove_subscriptions(
[stream1, stream2, stream3, private], realm,
acting_user=None, [user1, user2],
) [stream1, stream2, stream3, private],
acting_user=None,
)
peer_events = [e for e in events if e["event"].get("op") == "peer_remove"] peer_events = [e for e in events if e["event"].get("op") == "peer_remove"]
stream_delete_events = [ stream_delete_events = [
@ -5214,14 +5222,16 @@ class SubscriptionAPITest(ZulipTestCase):
# The only known O(N) behavior here is that we call # The only known O(N) behavior here is that we call
# principal_to_user_profile for each of our users, but it # principal_to_user_profile for each of our users, but it
# should be cached. # should be cached.
with self.assert_database_query_count(21): with (
with self.assert_memcached_count(3): self.assert_database_query_count(21),
with mock.patch("zerver.views.streams.send_messages_for_new_subscribers"): self.assert_memcached_count(3),
self.common_subscribe_to_streams( mock.patch("zerver.views.streams.send_messages_for_new_subscribers"),
desdemona, ):
streams, self.common_subscribe_to_streams(
dict(principals=orjson.dumps(test_user_ids).decode()), desdemona,
) streams,
dict(principals=orjson.dumps(test_user_ids).decode()),
)
def test_subscriptions_add_for_principal(self) -> None: def test_subscriptions_add_for_principal(self) -> None:
""" """

View File

@ -176,9 +176,11 @@ class TypingHappyPathTestDirectMessages(ZulipTestCase):
op="start", op="start",
) )
with self.assert_database_query_count(4): with (
with self.capture_send_event_calls(expected_num_events=1) as events: self.assert_database_query_count(4),
result = self.api_post(sender, "/api/v1/typing", params) self.capture_send_event_calls(expected_num_events=1) as events,
):
result = self.api_post(sender, "/api/v1/typing", params)
self.assert_json_success(result) self.assert_json_success(result)
self.assert_length(events, 1) self.assert_length(events, 1)
@ -212,9 +214,11 @@ class TypingHappyPathTestDirectMessages(ZulipTestCase):
op="start", op="start",
) )
with self.assert_database_query_count(5): with (
with self.capture_send_event_calls(expected_num_events=1) as events: self.assert_database_query_count(5),
result = self.api_post(sender, "/api/v1/typing", params) self.capture_send_event_calls(expected_num_events=1) as events,
):
result = self.api_post(sender, "/api/v1/typing", params)
self.assert_json_success(result) self.assert_json_success(result)
self.assert_length(events, 1) self.assert_length(events, 1)
@ -406,9 +410,11 @@ class TypingHappyPathTestStreams(ZulipTestCase):
topic=topic_name, topic=topic_name,
) )
with self.assert_database_query_count(6): with (
with self.capture_send_event_calls(expected_num_events=1) as events: self.assert_database_query_count(6),
result = self.api_post(sender, "/api/v1/typing", params) self.capture_send_event_calls(expected_num_events=1) as events,
):
result = self.api_post(sender, "/api/v1/typing", params)
self.assert_json_success(result) self.assert_json_success(result)
self.assert_length(events, 1) self.assert_length(events, 1)
@ -437,9 +443,11 @@ class TypingHappyPathTestStreams(ZulipTestCase):
topic=topic_name, topic=topic_name,
) )
with self.assert_database_query_count(6): with (
with self.capture_send_event_calls(expected_num_events=1) as events: self.assert_database_query_count(6),
result = self.api_post(sender, "/api/v1/typing", params) self.capture_send_event_calls(expected_num_events=1) as events,
):
result = self.api_post(sender, "/api/v1/typing", params)
self.assert_json_success(result) self.assert_json_success(result)
self.assert_length(events, 1) self.assert_length(events, 1)
@ -470,9 +478,11 @@ class TypingHappyPathTestStreams(ZulipTestCase):
topic=topic_name, topic=topic_name,
) )
with self.settings(MAX_STREAM_SIZE_FOR_TYPING_NOTIFICATIONS=5): with self.settings(MAX_STREAM_SIZE_FOR_TYPING_NOTIFICATIONS=5):
with self.assert_database_query_count(5): with (
with self.capture_send_event_calls(expected_num_events=0) as events: self.assert_database_query_count(5),
result = self.api_post(sender, "/api/v1/typing", params) self.capture_send_event_calls(expected_num_events=0) as events,
):
result = self.api_post(sender, "/api/v1/typing", params)
self.assert_json_success(result) self.assert_json_success(result)
self.assert_length(events, 0) self.assert_length(events, 0)
@ -501,9 +511,11 @@ class TypingHappyPathTestStreams(ZulipTestCase):
topic=topic_name, topic=topic_name,
) )
with self.assert_database_query_count(6): with (
with self.capture_send_event_calls(expected_num_events=1) as events: self.assert_database_query_count(6),
result = self.api_post(sender, "/api/v1/typing", params) self.capture_send_event_calls(expected_num_events=1) as events,
):
result = self.api_post(sender, "/api/v1/typing", params)
self.assert_json_success(result) self.assert_json_success(result)
self.assert_length(events, 1) self.assert_length(events, 1)

View File

@ -1390,9 +1390,11 @@ class AvatarTest(UploadSerializeMixin, ZulipTestCase):
def test_avatar_upload_file_size_error(self) -> None: def test_avatar_upload_file_size_error(self) -> None:
self.login("hamlet") self.login("hamlet")
with get_test_image_file(self.correct_files[0][0]) as fp: with (
with self.settings(MAX_AVATAR_FILE_SIZE_MIB=0): get_test_image_file(self.correct_files[0][0]) as fp,
result = self.client_post("/json/users/me/avatar", {"file": fp}) self.settings(MAX_AVATAR_FILE_SIZE_MIB=0),
):
result = self.client_post("/json/users/me/avatar", {"file": fp})
self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB") self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB")
@ -1537,9 +1539,11 @@ class RealmIconTest(UploadSerializeMixin, ZulipTestCase):
def test_realm_icon_upload_file_size_error(self) -> None: def test_realm_icon_upload_file_size_error(self) -> None:
self.login("iago") self.login("iago")
with get_test_image_file(self.correct_files[0][0]) as fp: with (
with self.settings(MAX_ICON_FILE_SIZE_MIB=0): get_test_image_file(self.correct_files[0][0]) as fp,
result = self.client_post("/json/realm/icon", {"file": fp}) self.settings(MAX_ICON_FILE_SIZE_MIB=0),
):
result = self.client_post("/json/realm/icon", {"file": fp})
self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB") self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB")
@ -1743,11 +1747,13 @@ class RealmLogoTest(UploadSerializeMixin, ZulipTestCase):
def test_logo_upload_file_size_error(self) -> None: def test_logo_upload_file_size_error(self) -> None:
self.login("iago") self.login("iago")
with get_test_image_file(self.correct_files[0][0]) as fp: with (
with self.settings(MAX_LOGO_FILE_SIZE_MIB=0): get_test_image_file(self.correct_files[0][0]) as fp,
result = self.client_post( self.settings(MAX_LOGO_FILE_SIZE_MIB=0),
"/json/realm/logo", {"file": fp, "night": orjson.dumps(self.night).decode()} ):
) result = self.client_post(
"/json/realm/logo", {"file": fp, "night": orjson.dumps(self.night).decode()}
)
self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB") self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB")
@ -1766,53 +1772,63 @@ class EmojiTest(UploadSerializeMixin, ZulipTestCase):
def test_non_image(self) -> None: def test_non_image(self) -> None:
"""Non-image is not resized""" """Non-image is not resized"""
self.login("iago") self.login("iago")
with get_test_image_file("text.txt") as f: with (
with patch("zerver.lib.upload.resize_emoji", return_value=(b"a", None)) as resize_mock: get_test_image_file("text.txt") as f,
result = self.client_post("/json/realm/emoji/new", {"f1": f}) patch("zerver.lib.upload.resize_emoji", return_value=(b"a", None)) as resize_mock,
self.assert_json_error(result, "Invalid image format") ):
resize_mock.assert_not_called() result = self.client_post("/json/realm/emoji/new", {"f1": f})
self.assert_json_error(result, "Invalid image format")
resize_mock.assert_not_called()
def test_upsupported_format(self) -> None: def test_upsupported_format(self) -> None:
"""Invalid format is not resized""" """Invalid format is not resized"""
self.login("iago") self.login("iago")
with get_test_image_file("img.bmp") as f: with (
with patch("zerver.lib.upload.resize_emoji", return_value=(b"a", None)) as resize_mock: get_test_image_file("img.bmp") as f,
result = self.client_post("/json/realm/emoji/new", {"f1": f}) patch("zerver.lib.upload.resize_emoji", return_value=(b"a", None)) as resize_mock,
self.assert_json_error(result, "Invalid image format") ):
resize_mock.assert_not_called() result = self.client_post("/json/realm/emoji/new", {"f1": f})
self.assert_json_error(result, "Invalid image format")
resize_mock.assert_not_called()
def test_upload_too_big_after_resize(self) -> None: def test_upload_too_big_after_resize(self) -> None:
"""Non-animated image is too big after resizing""" """Non-animated image is too big after resizing"""
self.login("iago") self.login("iago")
with get_test_image_file("img.png") as f: with (
with patch( get_test_image_file("img.png") as f,
patch(
"zerver.lib.upload.resize_emoji", return_value=(b"a" * (200 * 1024), None) "zerver.lib.upload.resize_emoji", return_value=(b"a" * (200 * 1024), None)
) as resize_mock: ) as resize_mock,
result = self.client_post("/json/realm/emoji/new", {"f1": f}) ):
self.assert_json_error(result, "Image size exceeds limit") result = self.client_post("/json/realm/emoji/new", {"f1": f})
resize_mock.assert_called_once() self.assert_json_error(result, "Image size exceeds limit")
resize_mock.assert_called_once()
def test_upload_big_after_animated_resize(self) -> None: def test_upload_big_after_animated_resize(self) -> None:
"""A big animated image is fine as long as the still is small""" """A big animated image is fine as long as the still is small"""
self.login("iago") self.login("iago")
with get_test_image_file("animated_img.gif") as f: with (
with patch( get_test_image_file("animated_img.gif") as f,
patch(
"zerver.lib.upload.resize_emoji", return_value=(b"a" * (200 * 1024), b"aaa") "zerver.lib.upload.resize_emoji", return_value=(b"a" * (200 * 1024), b"aaa")
) as resize_mock: ) as resize_mock,
result = self.client_post("/json/realm/emoji/new", {"f1": f}) ):
self.assert_json_success(result) result = self.client_post("/json/realm/emoji/new", {"f1": f})
resize_mock.assert_called_once() self.assert_json_success(result)
resize_mock.assert_called_once()
def test_upload_too_big_after_animated_resize_still(self) -> None: def test_upload_too_big_after_animated_resize_still(self) -> None:
"""Still of animated image is too big after resizing""" """Still of animated image is too big after resizing"""
self.login("iago") self.login("iago")
with get_test_image_file("animated_img.gif") as f: with (
with patch( get_test_image_file("animated_img.gif") as f,
patch(
"zerver.lib.upload.resize_emoji", return_value=(b"aaa", b"a" * (200 * 1024)) "zerver.lib.upload.resize_emoji", return_value=(b"aaa", b"a" * (200 * 1024))
) as resize_mock: ) as resize_mock,
result = self.client_post("/json/realm/emoji/new", {"f1": f}) ):
self.assert_json_error(result, "Image size exceeds limit") result = self.client_post("/json/realm/emoji/new", {"f1": f})
resize_mock.assert_called_once() self.assert_json_error(result, "Image size exceeds limit")
resize_mock.assert_called_once()
class SanitizeNameTests(ZulipTestCase): class SanitizeNameTests(ZulipTestCase):

View File

@ -1156,9 +1156,11 @@ class UserGroupAPITestCase(UserGroupTestCase):
munge = lambda obj: orjson.dumps(obj).decode() munge = lambda obj: orjson.dumps(obj).decode()
params = dict(add=munge(new_user_ids)) params = dict(add=munge(new_user_ids))
with mock.patch("zerver.views.user_groups.notify_for_user_group_subscription_changes"): with (
with self.assert_database_query_count(11): mock.patch("zerver.views.user_groups.notify_for_user_group_subscription_changes"),
result = self.client_post(f"/json/user_groups/{user_group.id}/members", info=params) self.assert_database_query_count(11),
):
result = self.client_post(f"/json/user_groups/{user_group.id}/members", info=params)
self.assert_json_success(result) self.assert_json_success(result)
with self.assert_database_query_count(1): with self.assert_database_query_count(1):

View File

@ -338,10 +338,12 @@ class MutedTopicsTests(ZulipTestCase):
mock_date_muted = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp() mock_date_muted = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp()
with self.capture_send_event_calls(expected_num_events=2) as events: with (
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False): self.capture_send_event_calls(expected_num_events=2) as events,
result = self.api_post(user, url, data) time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False),
self.assert_json_success(result) ):
result = self.api_post(user, url, data)
self.assert_json_success(result)
self.assertTrue( self.assertTrue(
topic_has_visibility_policy( topic_has_visibility_policy(
@ -404,10 +406,12 @@ class MutedTopicsTests(ZulipTestCase):
mock_date_mute_removed = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp() mock_date_mute_removed = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp()
with self.capture_send_event_calls(expected_num_events=2) as events: with (
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False): self.capture_send_event_calls(expected_num_events=2) as events,
result = self.api_post(user, url, data) time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False),
self.assert_json_success(result) ):
result = self.api_post(user, url, data)
self.assert_json_success(result)
self.assertFalse( self.assertFalse(
topic_has_visibility_policy( topic_has_visibility_policy(
@ -553,10 +557,12 @@ class UnmutedTopicsTests(ZulipTestCase):
mock_date_unmuted = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp() mock_date_unmuted = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp()
with self.capture_send_event_calls(expected_num_events=2) as events: with (
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False): self.capture_send_event_calls(expected_num_events=2) as events,
result = self.api_post(user, url, data) time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False),
self.assert_json_success(result) ):
result = self.api_post(user, url, data)
self.assert_json_success(result)
self.assertTrue( self.assertTrue(
topic_has_visibility_policy( topic_has_visibility_policy(
@ -619,10 +625,12 @@ class UnmutedTopicsTests(ZulipTestCase):
mock_date_unmute_removed = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp() mock_date_unmute_removed = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp()
with self.capture_send_event_calls(expected_num_events=2) as events: with (
with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False): self.capture_send_event_calls(expected_num_events=2) as events,
result = self.api_post(user, url, data) time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False),
self.assert_json_success(result) ):
result = self.api_post(user, url, data)
self.assert_json_success(result)
self.assertFalse( self.assertFalse(
topic_has_visibility_policy( topic_has_visibility_policy(

View File

@ -909,17 +909,19 @@ class QueryCountTest(ZulipTestCase):
prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com") prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com")
with self.assert_database_query_count(84): with (
with self.assert_memcached_count(19): self.assert_database_query_count(84),
with self.capture_send_event_calls(expected_num_events=10) as events: self.assert_memcached_count(19),
fred = do_create_user( self.capture_send_event_calls(expected_num_events=10) as events,
email="fred@zulip.com", ):
password="password", fred = do_create_user(
realm=realm, email="fred@zulip.com",
full_name="Fred Flintstone", password="password",
prereg_user=prereg_user, realm=realm,
acting_user=None, full_name="Fred Flintstone",
) prereg_user=prereg_user,
acting_user=None,
)
peer_add_events = [event for event in events if event["event"].get("op") == "peer_add"] peer_add_events = [event for event in events if event["event"].get("op") == "peer_add"]
@ -2404,9 +2406,8 @@ class GetProfileTest(ZulipTestCase):
""" """
realm = get_realm("zulip") realm = get_realm("zulip")
email = self.example_user("hamlet").email email = self.example_user("hamlet").email
with self.assert_database_query_count(1): with self.assert_database_query_count(1), simulated_empty_cache() as cache_queries:
with simulated_empty_cache() as cache_queries: user_profile = get_user(email, realm)
user_profile = get_user(email, realm)
self.assert_length(cache_queries, 1) self.assert_length(cache_queries, 1)
self.assertEqual(user_profile.email, email) self.assertEqual(user_profile.email, email)

View File

@ -210,11 +210,11 @@ Try again next time
def test_bad_payload(self) -> None: def test_bad_payload(self) -> None:
bad = ("foo", None, "bar") bad = ("foo", None, "bar")
with self.assertRaisesRegex(AssertionError, "Unable to handle Pivotal payload"): with (
with mock.patch( self.assertRaisesRegex(AssertionError, "Unable to handle Pivotal payload"),
"zerver.webhooks.pivotal.view.api_pivotal_webhook_v3", return_value=bad mock.patch("zerver.webhooks.pivotal.view.api_pivotal_webhook_v3", return_value=bad),
): ):
self.check_webhook("accepted", expect_topic="foo") self.check_webhook("accepted", expect_topic="foo")
def test_bad_request(self) -> None: def test_bad_request(self) -> None:
request = mock.MagicMock() request = mock.MagicMock()
@ -226,9 +226,11 @@ Try again next time
self.assertEqual(result[0], "#0: ") self.assertEqual(result[0], "#0: ")
bad = orjson.loads(self.get_body("bad_kind")) bad = orjson.loads(self.get_body("bad_kind"))
with self.assertRaisesRegex(UnsupportedWebhookEventTypeError, "'unknown_kind'.* supported"): with (
with mock.patch("zerver.webhooks.pivotal.view.orjson.loads", return_value=bad): self.assertRaisesRegex(UnsupportedWebhookEventTypeError, "'unknown_kind'.* supported"),
api_pivotal_webhook_v5(request, hamlet) mock.patch("zerver.webhooks.pivotal.view.orjson.loads", return_value=bad),
):
api_pivotal_webhook_v5(request, hamlet)
@override @override
def get_body(self, fixture_name: str) -> str: def get_body(self, fixture_name: str) -> str:

View File

@ -276,9 +276,8 @@ class QueueProcessingWorker(ABC):
fn = os.path.join(settings.QUEUE_ERROR_DIR, fname) fn = os.path.join(settings.QUEUE_ERROR_DIR, fname)
line = f"{time.asctime()}\t{orjson.dumps(events).decode()}\n" line = f"{time.asctime()}\t{orjson.dumps(events).decode()}\n"
lock_fn = fn + ".lock" lock_fn = fn + ".lock"
with lockfile(lock_fn): with lockfile(lock_fn), open(fn, "a") as f:
with open(fn, "a") as f: f.write(line)
f.write(line)
check_and_send_restart_signal() check_and_send_restart_signal()
def setup(self) -> None: def setup(self) -> None: