diff --git a/corporate/tests/test_remote_billing.py b/corporate/tests/test_remote_billing.py index b176b83a92..cb3afb8553 100644 --- a/corporate/tests/test_remote_billing.py +++ b/corporate/tests/test_remote_billing.py @@ -560,12 +560,15 @@ class RemoteBillingAuthenticationTest(RemoteRealmBillingTestCase): ) # Try the case where the identity dict is simultaneously expired. - with time_machine.travel( - now + timedelta(seconds=REMOTE_BILLING_SESSION_VALIDITY_SECONDS + 30), - tick=False, + with ( + time_machine.travel( + now + timedelta(seconds=REMOTE_BILLING_SESSION_VALIDITY_SECONDS + 30), + tick=False, + ), + self.assertLogs("django.request", "ERROR") as m, + self.assertRaises(AssertionError), ): - with self.assertLogs("django.request", "ERROR") as m, self.assertRaises(AssertionError): - self.client_get(final_url, subdomain="selfhosting") + self.client_get(final_url, subdomain="selfhosting") # The django.request log should be a traceback, mentioning the relevant # exceptions that occurred. self.assertIn( diff --git a/corporate/tests/test_stripe.py b/corporate/tests/test_stripe.py index 85b0826b05..958ecf0f7f 100644 --- a/corporate/tests/test_stripe.py +++ b/corporate/tests/test_stripe.py @@ -1415,9 +1415,11 @@ class StripeTest(StripeTestCase): self.assertFalse(Customer.objects.filter(realm=user.realm).exists()) # Require free trial users to add a credit card. - with time_machine.travel(self.now, tick=False): - with self.assertLogs("corporate.stripe", "WARNING"): - response = self.upgrade() + with ( + time_machine.travel(self.now, tick=False), + self.assertLogs("corporate.stripe", "WARNING"), + ): + response = self.upgrade() self.assert_json_error( response, "Please add a credit card before starting your free trial." ) @@ -1953,12 +1955,14 @@ class StripeTest(StripeTestCase): initial_upgrade_request ) # Change the seat count while the user is going through the upgrade flow - with patch("corporate.lib.stripe.get_latest_seat_count", return_value=new_seat_count): - with patch( + with ( + patch("corporate.lib.stripe.get_latest_seat_count", return_value=new_seat_count), + patch( "corporate.lib.stripe.RealmBillingSession.get_initial_upgrade_context", return_value=(_, context_when_upgrade_page_is_rendered), - ): - self.add_card_and_upgrade(hamlet) + ), + ): + self.add_card_and_upgrade(hamlet) customer = Customer.objects.first() assert customer is not None @@ -2072,11 +2076,13 @@ class StripeTest(StripeTestCase): hamlet = self.example_user("hamlet") self.login_user(hamlet) self.local_upgrade(self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False) - with self.assertLogs("corporate.stripe", "WARNING") as m: - with self.assertRaises(BillingError) as context: - self.local_upgrade( - self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False - ) + with ( + self.assertLogs("corporate.stripe", "WARNING") as m, + self.assertRaises(BillingError) as context, + ): + self.local_upgrade( + self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False + ) self.assertEqual( "subscribing with existing subscription", context.exception.error_description ) @@ -2197,14 +2203,16 @@ class StripeTest(StripeTestCase): else: del_args = [] upgrade_params["licenses"] = licenses - with patch("corporate.lib.stripe.BillingSession.process_initial_upgrade"): - with patch( + with ( + patch("corporate.lib.stripe.BillingSession.process_initial_upgrade"), + patch( "corporate.lib.stripe.BillingSession.create_stripe_invoice_and_charge", return_value="fake_stripe_invoice_id", - ): - response = self.upgrade( - invoice=invoice, talk_to_stripe=False, del_args=del_args, **upgrade_params - ) + ), + ): + response = self.upgrade( + invoice=invoice, talk_to_stripe=False, del_args=del_args, **upgrade_params + ) self.assert_json_success(response) # Autopay with licenses < seat count @@ -2911,18 +2919,20 @@ class StripeTest(StripeTestCase): assert plan is not None self.assertEqual(plan.licenses(), self.seat_count) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, - ) - stripe_customer_id = Customer.objects.get(realm=user.realm).id - new_plan = get_current_plan_by_realm(user.realm) - assert new_plan is not None - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, + ) + stripe_customer_id = Customer.objects.get(realm=user.realm).id + new_plan = get_current_plan_by_realm(user.realm) + assert new_plan is not None + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) plan.refresh_from_db() self.assertEqual(plan.licenses(), self.seat_count) self.assertEqual(plan.licenses_at_next_renewal(), None) @@ -3034,15 +3044,17 @@ class StripeTest(StripeTestCase): new_plan = get_current_plan_by_realm(user.realm) assert new_plan is not None - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, - ) - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, + ) + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) monthly_plan.refresh_from_db() self.assertEqual(monthly_plan.status, CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE) with time_machine.travel(self.now, tick=False): @@ -3062,9 +3074,11 @@ class StripeTest(StripeTestCase): (20, 20), ) - with time_machine.travel(self.next_month, tick=False): - with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25): - billing_session.update_license_ledger_if_needed(self.next_month) + with ( + time_machine.travel(self.next_month, tick=False), + patch("corporate.lib.stripe.get_latest_seat_count", return_value=25), + ): + billing_session.update_license_ledger_if_needed(self.next_month) self.assertEqual(LicenseLedger.objects.filter(plan=monthly_plan).count(), 2) customer = get_customer_by_realm(user.realm) assert customer is not None @@ -3230,17 +3244,19 @@ class StripeTest(StripeTestCase): stripe_customer_id = Customer.objects.get(realm=user.realm).id new_plan = get_current_plan_by_realm(user.realm) assert new_plan is not None - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, - ) - self.assertEqual( - m.output[0], - f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}", - ) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, + ) + self.assertEqual( + m.output[0], + f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}", + ) + self.assert_json_success(response) monthly_plan.refresh_from_db() self.assertEqual(monthly_plan.status, CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE) with time_machine.travel(self.now, tick=False): @@ -3343,15 +3359,17 @@ class StripeTest(StripeTestCase): assert new_plan is not None assert self.now is not None - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE}, - ) - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE}, + ) + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) annual_plan.refresh_from_db() self.assertEqual(annual_plan.status, CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE) with time_machine.travel(self.now, tick=False): @@ -3375,9 +3393,11 @@ class StripeTest(StripeTestCase): # additional licenses) but at the end of current billing cycle. self.assertEqual(annual_plan.next_invoice_date, self.next_month) assert annual_plan.next_invoice_date is not None - with time_machine.travel(annual_plan.next_invoice_date, tick=False): - with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25): - billing_session.update_license_ledger_if_needed(annual_plan.next_invoice_date) + with ( + time_machine.travel(annual_plan.next_invoice_date, tick=False), + patch("corporate.lib.stripe.get_latest_seat_count", return_value=25), + ): + billing_session.update_license_ledger_if_needed(annual_plan.next_invoice_date) annual_plan.refresh_from_db() self.assertEqual(annual_plan.status, CustomerPlan.SWITCH_TO_MONTHLY_AT_END_OF_CYCLE) @@ -3430,9 +3450,11 @@ class StripeTest(StripeTestCase): self.assertEqual(invoice_item2[key], value) # Check that we switch to monthly plan at the end of current billing cycle. - with time_machine.travel(self.next_year, tick=False): - with patch("corporate.lib.stripe.get_latest_seat_count", return_value=25): - billing_session.update_license_ledger_if_needed(self.next_year) + with ( + time_machine.travel(self.next_year, tick=False), + patch("corporate.lib.stripe.get_latest_seat_count", return_value=25), + ): + billing_session.update_license_ledger_if_needed(self.next_year) self.assertEqual(LicenseLedger.objects.filter(plan=annual_plan).count(), 3) customer = get_customer_by_realm(user.realm) assert customer is not None @@ -3513,30 +3535,34 @@ class StripeTest(StripeTestCase): self.local_upgrade( self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False ) - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, - ) - stripe_customer_id = Customer.objects.get(realm=user.realm).id - new_plan = get_current_plan_by_realm(user.realm) - assert new_plan is not None - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, + ) + stripe_customer_id = Customer.objects.get(realm=user.realm).id + new_plan = get_current_plan_by_realm(user.realm) + assert new_plan is not None + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) plan = CustomerPlan.objects.first() assert plan is not None self.assertEqual(plan.status, CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE) - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.ACTIVE}, - ) - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.ACTIVE}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.ACTIVE}, + ) + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.ACTIVE}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) plan = CustomerPlan.objects.first() assert plan is not None self.assertEqual(plan.status, CustomerPlan.ACTIVE) @@ -3587,55 +3613,54 @@ class StripeTest(StripeTestCase): self.login_user(user) free_trial_end_date = self.now + timedelta(days=60) - with self.settings(CLOUD_FREE_TRIAL_DAYS=60): - with time_machine.travel(self.now, tick=False): - self.add_card_and_upgrade(user, schedule="monthly") - plan = CustomerPlan.objects.get() - self.assertEqual(plan.next_invoice_date, free_trial_end_date) - self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) - self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL) + with self.settings(CLOUD_FREE_TRIAL_DAYS=60), time_machine.travel(self.now, tick=False): + self.add_card_and_upgrade(user, schedule="monthly") + plan = CustomerPlan.objects.get() + self.assertEqual(plan.next_invoice_date, free_trial_end_date) + self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) + self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL) - customer = get_customer_by_realm(user.realm) - assert customer is not None - result = self.client_billing_patch( - "/billing/plan", - { - "status": CustomerPlan.FREE_TRIAL, - "schedule": CustomerPlan.BILLING_SCHEDULE_ANNUAL, - }, - ) - self.assert_json_success(result) + customer = get_customer_by_realm(user.realm) + assert customer is not None + result = self.client_billing_patch( + "/billing/plan", + { + "status": CustomerPlan.FREE_TRIAL, + "schedule": CustomerPlan.BILLING_SCHEDULE_ANNUAL, + }, + ) + self.assert_json_success(result) - plan.refresh_from_db() - self.assertEqual(plan.status, CustomerPlan.ENDED) - self.assertIsNone(plan.next_invoice_date) + plan.refresh_from_db() + self.assertEqual(plan.status, CustomerPlan.ENDED) + self.assertIsNone(plan.next_invoice_date) - new_plan = CustomerPlan.objects.get( - customer=customer, - automanage_licenses=True, - price_per_license=8000, - fixed_price=None, - discount=None, - billing_cycle_anchor=self.now, - billing_schedule=CustomerPlan.BILLING_SCHEDULE_ANNUAL, - next_invoice_date=free_trial_end_date, - tier=CustomerPlan.TIER_CLOUD_STANDARD, - status=CustomerPlan.FREE_TRIAL, - charge_automatically=True, - ) - ledger_entry = LicenseLedger.objects.get( - plan=new_plan, - is_renewal=True, - event_time=self.now, - licenses=self.seat_count, - licenses_at_next_renewal=self.seat_count, - ) - self.assertEqual(new_plan.invoiced_through, ledger_entry) + new_plan = CustomerPlan.objects.get( + customer=customer, + automanage_licenses=True, + price_per_license=8000, + fixed_price=None, + discount=None, + billing_cycle_anchor=self.now, + billing_schedule=CustomerPlan.BILLING_SCHEDULE_ANNUAL, + next_invoice_date=free_trial_end_date, + tier=CustomerPlan.TIER_CLOUD_STANDARD, + status=CustomerPlan.FREE_TRIAL, + charge_automatically=True, + ) + ledger_entry = LicenseLedger.objects.get( + plan=new_plan, + is_renewal=True, + event_time=self.now, + licenses=self.seat_count, + licenses_at_next_renewal=self.seat_count, + ) + self.assertEqual(new_plan.invoiced_through, ledger_entry) - realm_audit_log = RealmAuditLog.objects.filter( - event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN - ).last() - assert realm_audit_log is not None + realm_audit_log = RealmAuditLog.objects.filter( + event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_MONTHLY_TO_ANNUAL_PLAN + ).last() + assert realm_audit_log is not None @mock_stripe() def test_switch_now_free_trial_from_annual_to_monthly(self, *mocks: Mock) -> None: @@ -3643,54 +3668,53 @@ class StripeTest(StripeTestCase): self.login_user(user) free_trial_end_date = self.now + timedelta(days=60) - with self.settings(CLOUD_FREE_TRIAL_DAYS=60): - with time_machine.travel(self.now, tick=False): - self.add_card_and_upgrade(user, schedule="annual") - plan = CustomerPlan.objects.get() - self.assertEqual(plan.next_invoice_date, free_trial_end_date) - self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) - self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL) + with self.settings(CLOUD_FREE_TRIAL_DAYS=60), time_machine.travel(self.now, tick=False): + self.add_card_and_upgrade(user, schedule="annual") + plan = CustomerPlan.objects.get() + self.assertEqual(plan.next_invoice_date, free_trial_end_date) + self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) + self.assertEqual(plan.status, CustomerPlan.FREE_TRIAL) - customer = get_customer_by_realm(user.realm) - assert customer is not None - result = self.client_billing_patch( - "/billing/plan", - { - "status": CustomerPlan.FREE_TRIAL, - "schedule": CustomerPlan.BILLING_SCHEDULE_MONTHLY, - }, - ) - self.assert_json_success(result) - plan.refresh_from_db() - self.assertEqual(plan.status, CustomerPlan.ENDED) - self.assertIsNone(plan.next_invoice_date) + customer = get_customer_by_realm(user.realm) + assert customer is not None + result = self.client_billing_patch( + "/billing/plan", + { + "status": CustomerPlan.FREE_TRIAL, + "schedule": CustomerPlan.BILLING_SCHEDULE_MONTHLY, + }, + ) + self.assert_json_success(result) + plan.refresh_from_db() + self.assertEqual(plan.status, CustomerPlan.ENDED) + self.assertIsNone(plan.next_invoice_date) - new_plan = CustomerPlan.objects.get( - customer=customer, - automanage_licenses=True, - price_per_license=800, - fixed_price=None, - discount=None, - billing_cycle_anchor=self.now, - billing_schedule=CustomerPlan.BILLING_SCHEDULE_MONTHLY, - next_invoice_date=free_trial_end_date, - tier=CustomerPlan.TIER_CLOUD_STANDARD, - status=CustomerPlan.FREE_TRIAL, - charge_automatically=True, - ) - ledger_entry = LicenseLedger.objects.get( - plan=new_plan, - is_renewal=True, - event_time=self.now, - licenses=self.seat_count, - licenses_at_next_renewal=self.seat_count, - ) - self.assertEqual(new_plan.invoiced_through, ledger_entry) + new_plan = CustomerPlan.objects.get( + customer=customer, + automanage_licenses=True, + price_per_license=800, + fixed_price=None, + discount=None, + billing_cycle_anchor=self.now, + billing_schedule=CustomerPlan.BILLING_SCHEDULE_MONTHLY, + next_invoice_date=free_trial_end_date, + tier=CustomerPlan.TIER_CLOUD_STANDARD, + status=CustomerPlan.FREE_TRIAL, + charge_automatically=True, + ) + ledger_entry = LicenseLedger.objects.get( + plan=new_plan, + is_renewal=True, + event_time=self.now, + licenses=self.seat_count, + licenses_at_next_renewal=self.seat_count, + ) + self.assertEqual(new_plan.invoiced_through, ledger_entry) - realm_audit_log = RealmAuditLog.objects.filter( - event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_ANNUAL_TO_MONTHLY_PLAN - ).last() - assert realm_audit_log is not None + realm_audit_log = RealmAuditLog.objects.filter( + event_type=RealmAuditLog.CUSTOMER_SWITCHED_FROM_ANNUAL_TO_MONTHLY_PLAN + ).last() + assert realm_audit_log is not None @mock_stripe() def test_end_free_trial(self, *mocks: Mock) -> None: @@ -3764,18 +3788,20 @@ class StripeTest(StripeTestCase): self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) # Schedule downgrade - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}, - ) - stripe_customer_id = Customer.objects.get(realm=user.realm).id - new_plan = get_current_plan_by_realm(user.realm) - assert new_plan is not None - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}, + ) + stripe_customer_id = Customer.objects.get(realm=user.realm).id + new_plan = get_current_plan_by_realm(user.realm) + assert new_plan is not None + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) plan.refresh_from_db() self.assertEqual(plan.next_invoice_date, free_trial_end_date) self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) @@ -3874,18 +3900,20 @@ class StripeTest(StripeTestCase): self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) # Schedule downgrade - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}, - ) - stripe_customer_id = Customer.objects.get(realm=user.realm).id - new_plan = get_current_plan_by_realm(user.realm) - assert new_plan is not None - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}, + ) + stripe_customer_id = Customer.objects.get(realm=user.realm).id + new_plan = get_current_plan_by_realm(user.realm) + assert new_plan is not None + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_FREE_TRIAL}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) plan.refresh_from_db() self.assertEqual(plan.next_invoice_date, free_trial_end_date) self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) @@ -3894,18 +3922,20 @@ class StripeTest(StripeTestCase): self.assertEqual(plan.licenses_at_next_renewal(), None) # Cancel downgrade - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.FREE_TRIAL}, - ) - stripe_customer_id = Customer.objects.get(realm=user.realm).id - new_plan = get_current_plan_by_realm(user.realm) - assert new_plan is not None - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.FREE_TRIAL}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.FREE_TRIAL}, + ) + stripe_customer_id = Customer.objects.get(realm=user.realm).id + new_plan = get_current_plan_by_realm(user.realm) + assert new_plan is not None + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {stripe_customer_id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.FREE_TRIAL}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) plan.refresh_from_db() self.assertEqual(plan.next_invoice_date, free_trial_end_date) self.assertEqual(get_realm("zulip").plan_type, Realm.PLAN_TYPE_STANDARD) @@ -3937,11 +3967,11 @@ class StripeTest(StripeTestCase): with ( self.assertRaises(BillingError) as context, self.assertLogs("corporate.stripe", "WARNING") as m, + time_machine.travel(self.now, tick=False), ): - with time_machine.travel(self.now, tick=False): - self.local_upgrade( - self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False - ) + self.local_upgrade( + self.seat_count, True, CustomerPlan.BILLING_SCHEDULE_ANNUAL, True, False + ) self.assertEqual( m.output[0], "WARNING:corporate.stripe:Upgrade of (with stripe_customer_id: cus_123) failed because of existing active plan.", @@ -4242,17 +4272,19 @@ class StripeTest(StripeTestCase): ) self.login_user(self.example_user("hamlet")) - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - result = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, - ) - self.assert_json_success(result) - self.assertRegex( - m.output[0], - r"INFO:corporate.stripe:Change plan status: Customer.id: \d*, CustomerPlan.id: \d*, status: 2", - ) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + result = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, + ) + self.assert_json_success(result) + self.assertRegex( + m.output[0], + r"INFO:corporate.stripe:Change plan status: Customer.id: \d*, CustomerPlan.id: \d*, status: 2", + ) with time_machine.travel(self.next_year, tick=False): result = self.client_billing_patch( @@ -4270,17 +4302,19 @@ class StripeTest(StripeTestCase): ) self.login_user(self.example_user("hamlet")) - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now, tick=False): - result = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, - ) - self.assert_json_success(result) - self.assertRegex( - m.output[0], - r"INFO:corporate.stripe:Change plan status: Customer.id: \d*, CustomerPlan.id: \d*, status: 4", - ) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now, tick=False), + ): + result = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.SWITCH_TO_ANNUAL_AT_END_OF_CYCLE}, + ) + self.assert_json_success(result) + self.assertRegex( + m.output[0], + r"INFO:corporate.stripe:Change plan status: Customer.id: \d*, CustomerPlan.id: \d*, status: 4", + ) with time_machine.travel(self.next_month, tick=False): result = self.client_billing_patch("/billing/plan", {}) @@ -5602,11 +5636,13 @@ class LicenseLedgerTest(StripeTestCase): self.assertEqual(plan.licenses(), self.seat_count + 3) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count + 3) - with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): - with self.assertRaises(AssertionError): - billing_session.update_license_ledger_for_manual_plan( - plan, self.now, licenses=self.seat_count - ) + with ( + patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count), + self.assertRaises(AssertionError), + ): + billing_session.update_license_ledger_for_manual_plan( + plan, self.now, licenses=self.seat_count + ) with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): billing_session.update_license_ledger_for_manual_plan( @@ -5615,11 +5651,13 @@ class LicenseLedgerTest(StripeTestCase): self.assertEqual(plan.licenses(), self.seat_count + 3) self.assertEqual(plan.licenses_at_next_renewal(), self.seat_count) - with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): - with self.assertRaises(AssertionError): - billing_session.update_license_ledger_for_manual_plan( - plan, self.now, licenses_at_next_renewal=self.seat_count - 1 - ) + with ( + patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count), + self.assertRaises(AssertionError), + ): + billing_session.update_license_ledger_for_manual_plan( + plan, self.now, licenses_at_next_renewal=self.seat_count - 1 + ) with patch("corporate.lib.stripe.get_latest_seat_count", return_value=self.seat_count): billing_session.update_license_ledger_for_manual_plan( @@ -6614,11 +6652,13 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase): # Same result even with free trial enabled for self hosted customers since we don't # offer free trial for business plan. - with self.settings(SELF_HOSTING_FREE_TRIAL_DAYS=30): - with time_machine.travel(self.now, tick=False): - result = self.client_get( - f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting" - ) + with ( + self.settings(SELF_HOSTING_FREE_TRIAL_DAYS=30), + time_machine.travel(self.now, tick=False), + ): + result = self.client_get( + f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting" + ) self.assert_in_success_response( [ @@ -6631,11 +6671,10 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase): ) # Check that cloud free trials don't affect self hosted customers. - with self.settings(CLOUD_FREE_TRIAL_DAYS=30): - with time_machine.travel(self.now, tick=False): - result = self.client_get( - f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting" - ) + with self.settings(CLOUD_FREE_TRIAL_DAYS=30), time_machine.travel(self.now, tick=False): + result = self.client_get( + f"{self.billing_session.billing_base_url}/upgrade/", subdomain="selfhosting" + ) self.assert_in_success_response( [ @@ -8018,15 +8057,17 @@ class TestRemoteRealmBillingFlow(StripeTestCase, RemoteRealmBillingTestCase): self.assertEqual(result["Location"], f"{billing_base_url}/billing/") # Downgrade - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now + timedelta(days=7), tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, - ) - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {business_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now + timedelta(days=7), tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, + ) + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {business_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) business_plan.refresh_from_db() self.assertEqual(business_plan.licenses_at_next_renewal(), None) @@ -8323,9 +8364,11 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase): # Same result even with free trial enabled for self hosted customers since we don't # offer free trial for business plan. - with self.settings(SELF_HOSTING_FREE_TRIAL_DAYS=30): - with time_machine.travel(self.now, tick=False): - result = self.client_get(f"{billing_base_url}/upgrade/", subdomain="selfhosting") + with ( + self.settings(SELF_HOSTING_FREE_TRIAL_DAYS=30), + time_machine.travel(self.now, tick=False), + ): + result = self.client_get(f"{billing_base_url}/upgrade/", subdomain="selfhosting") self.assert_in_success_response(["Add card", "Purchase Zulip Business"], result) @@ -8390,18 +8433,20 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase): self.assertEqual(result["Location"], f"{billing_base_url}/billing/") # Downgrade - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now + timedelta(days=7), tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, - ) - customer = Customer.objects.get(remote_server=self.remote_server) - new_plan = get_current_plan_by_customer(customer) - assert new_plan is not None - expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" - self.assertEqual(m.output[0], expected_log) - self.assert_json_success(response) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now + timedelta(days=7), tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}, + ) + customer = Customer.objects.get(remote_server=self.remote_server) + new_plan = get_current_plan_by_customer(customer) + assert new_plan is not None + expected_log = f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {new_plan.id}, status: {CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE}" + self.assertEqual(m.output[0], expected_log) + self.assert_json_success(response) self.assertEqual(new_plan.licenses_at_next_renewal(), None) @responses.activate @@ -8599,21 +8644,23 @@ class TestRemoteServerBillingFlow(StripeTestCase, RemoteServerTestCase): self.assertEqual(result["Location"], f"{billing_base_url}/billing/") # Downgrade - with self.assertLogs("corporate.stripe", "INFO") as m: - with time_machine.travel(self.now + timedelta(days=7), tick=False): - response = self.client_billing_patch( - "/billing/plan", - {"status": CustomerPlan.ACTIVE}, - ) - self.assert_json_success(response) - self.assertEqual( - m.output[0], - f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {new_customer_plan.id}, status: {CustomerPlan.ENDED}", - ) - self.assertEqual( - m.output[1], - f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {customer_plan.id}, status: {CustomerPlan.ACTIVE}", - ) + with ( + self.assertLogs("corporate.stripe", "INFO") as m, + time_machine.travel(self.now + timedelta(days=7), tick=False), + ): + response = self.client_billing_patch( + "/billing/plan", + {"status": CustomerPlan.ACTIVE}, + ) + self.assert_json_success(response) + self.assertEqual( + m.output[0], + f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {new_customer_plan.id}, status: {CustomerPlan.ENDED}", + ) + self.assertEqual( + m.output[1], + f"INFO:corporate.stripe:Change plan status: Customer.id: {customer.id}, CustomerPlan.id: {customer_plan.id}, status: {CustomerPlan.ACTIVE}", + ) @responses.activate @mock_stripe() diff --git a/pyproject.toml b/pyproject.toml index 0ff16923b0..479828fdd9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -186,7 +186,6 @@ ignore = [ "SIM103", # Return the condition directly "SIM108", # Use ternary operator `action = "[commented]" if action == "created" else f"{action} a [comment]"` instead of if-else-block "SIM114", # Combine `if` branches using logical `or` operator - "SIM117", # Use a single `with` statement with multiple contexts instead of nested `with` statements "SIM401", # Use `d.get(key, default)` instead of an `if` block "TCH001", # Move application import into a type-checking block "TCH002", # Move third-party import into a type-checking block diff --git a/tools/tail-ses b/tools/tail-ses index 6899dfbbea..645d1d630e 100755 --- a/tools/tail-ses +++ b/tools/tail-ses @@ -66,9 +66,11 @@ def main() -> None: args = parser.parse_args() sns_topic_arn = get_ses_arn(session, args) - with our_sqs_queue(session, sns_topic_arn) as (queue_arn, queue_url): - with our_sns_subscription(session, sns_topic_arn, queue_arn): - print_messages(session, queue_url) + with ( + our_sqs_queue(session, sns_topic_arn) as (queue_arn, queue_url), + our_sns_subscription(session, sns_topic_arn, queue_arn), + ): + print_messages(session, queue_url) def get_ses_arn(session: boto3.session.Session, args: argparse.Namespace) -> str: diff --git a/tools/test-backend b/tools/test-backend index b2ea702b47..1b3d521a2e 100755 --- a/tools/test-backend +++ b/tools/test-backend @@ -168,11 +168,13 @@ def get_failed_tests() -> list[str]: def block_internet() -> Iterator[None]: # Monkey-patching - responses library raises requests.ConnectionError when access to an unregistered URL # is attempted. We want to replace that with our own exception, so that it propagates all the way: - with mock.patch.object(responses, "ConnectionError", new=ZulipInternetBlockedError): + with ( + mock.patch.object(responses, "ConnectionError", new=ZulipInternetBlockedError), # We'll run all tests in this context manager. It'll cause an error to be raised (see above comment), # if any code attempts to access the internet. - with responses.RequestsMock(): - yield + responses.RequestsMock(), + ): + yield class ZulipInternetBlockedError(Exception): diff --git a/tools/test-locked-requirements b/tools/test-locked-requirements index 8377ab723c..64ccbf6662 100755 --- a/tools/test-locked-requirements +++ b/tools/test-locked-requirements @@ -19,14 +19,13 @@ CACHE_FILE = os.path.join(CACHE_DIR, "requirements_hashes") def print_diff(path_file1: str, path_file2: str) -> None: - with open(path_file1) as file1: - with open(path_file2) as file2: - diff = difflib.unified_diff( - file1.readlines(), - file2.readlines(), - fromfile=path_file1, - tofile=path_file2, - ) + with open(path_file1) as file1, open(path_file2) as file2: + diff = difflib.unified_diff( + file1.readlines(), + file2.readlines(), + fromfile=path_file1, + tofile=path_file2, + ) sys.stdout.writelines(diff) diff --git a/zerver/data_import/slack.py b/zerver/data_import/slack.py index d88eecbd19..da1e0dfd1e 100644 --- a/zerver/data_import/slack.py +++ b/zerver/data_import/slack.py @@ -1347,10 +1347,12 @@ def fetch_team_icons( ) resized_icon_output_path = os.path.join(output_dir, str(realm_id), "icon.png") - with open(resized_icon_output_path, "wb") as output_file: - with open(original_icon_output_path, "rb") as original_file: - resized_data = resize_logo(original_file.read()) - output_file.write(resized_data) + with ( + open(resized_icon_output_path, "wb") as output_file, + open(original_icon_output_path, "rb") as original_file, + ): + resized_data = resize_logo(original_file.read()) + output_file.write(resized_data) records.append( { "realm_id": realm_id, diff --git a/zerver/lib/context_managers.py b/zerver/lib/context_managers.py index cc159bc537..bbe99ae2ed 100644 --- a/zerver/lib/context_managers.py +++ b/zerver/lib/context_managers.py @@ -28,9 +28,8 @@ def lockfile(filename: str, shared: bool = False) -> Iterator[None]: If shared is True, use a LOCK_SH lock, otherwise LOCK_EX. The file is given by name and will be created if it does not exist.""" - with open(filename, "w") as lock: - with flock(lock, shared=shared): - yield + with open(filename, "w") as lock, flock(lock, shared=shared): + yield @contextmanager diff --git a/zerver/lib/send_email.py b/zerver/lib/send_email.py index 5ee054b1ab..e1f512aa58 100644 --- a/zerver/lib/send_email.py +++ b/zerver/lib/send_email.py @@ -548,15 +548,17 @@ def custom_email_sender( rendered_input = render_markdown_path(plain_text_template_path.replace("templates/", "")) # And then extend it with our standard email headers. - with open(html_template_path, "w") as f: - with open(markdown_email_base_template_path) as base_template: - # We use an ugly string substitution here, because we want to: - # 1. Only run Jinja once on the supplied content - # 2. Allow the supplied content to have jinja interpolation in it - # 3. Have that interpolation happen in the context of - # each individual email we send, so the contents can - # vary user-to-user - f.write(base_template.read().replace("{{ rendered_input }}", rendered_input)) + with ( + open(html_template_path, "w") as f, + open(markdown_email_base_template_path) as base_template, + ): + # We use an ugly string substitution here, because we want to: + # 1. Only run Jinja once on the supplied content + # 2. Allow the supplied content to have jinja interpolation in it + # 3. Have that interpolation happen in the context of + # each individual email we send, so the contents can + # vary user-to-user + f.write(base_template.read().replace("{{ rendered_input }}", rendered_input)) with open(subject_path, "w") as f: f.write(get_header(subject, parsed_email_template.get("subject"), "subject")) diff --git a/zerver/lib/test_classes.py b/zerver/lib/test_classes.py index 57d0f3c854..ca86b1fc7d 100644 --- a/zerver/lib/test_classes.py +++ b/zerver/lib/test_classes.py @@ -2018,14 +2018,16 @@ class ZulipTestCase(ZulipTestCaseMixin, TestCase): # Some code might call process_notification using keyword arguments, # so mypy doesn't allow assigning lst.append to process_notification # So explicitly change parameter name to 'notice' to work around this problem - with mock.patch("zerver.tornado.event_queue.process_notification", lst.append): + with ( + mock.patch("zerver.tornado.event_queue.process_notification", lst.append), # Some `send_event` calls need to be executed only after the current transaction # commits (using `on_commit` hooks). Because the transaction in Django tests never # commits (rather, gets rolled back after the test completes), such events would # never be sent in tests, and we would be unable to verify them. Hence, we use # this helper to make sure the `send_event` calls actually run. - with self.captureOnCommitCallbacks(execute=True): - yield lst + self.captureOnCommitCallbacks(execute=True), + ): + yield lst self.assert_length(lst, expected_num_events) diff --git a/zerver/lib/test_helpers.py b/zerver/lib/test_helpers.py index 1824f6ff55..6abf572d9e 100644 --- a/zerver/lib/test_helpers.py +++ b/zerver/lib/test_helpers.py @@ -71,9 +71,11 @@ class MockLDAP(fakeldap.MockLDAP): def stub_event_queue_user_events( event_queue_return: Any, user_events_return: Any ) -> Iterator[None]: - with mock.patch("zerver.lib.events.request_event_queue", return_value=event_queue_return): - with mock.patch("zerver.lib.events.get_user_events", return_value=user_events_return): - yield + with ( + mock.patch("zerver.lib.events.request_event_queue", return_value=event_queue_return), + mock.patch("zerver.lib.events.get_user_events", return_value=user_events_return), + ): + yield @contextmanager diff --git a/zerver/migrations/0553_copy_emoji_images.py b/zerver/migrations/0553_copy_emoji_images.py index ccfa296dc3..0d4762c1ee 100644 --- a/zerver/migrations/0553_copy_emoji_images.py +++ b/zerver/migrations/0553_copy_emoji_images.py @@ -186,9 +186,11 @@ def thumbnail_local_emoji(apps: StateApps) -> None: ) new_file_name = get_emoji_file_name("image/png", emoji.id) try: - with open(f"{settings.DEPLOY_ROOT}/static/images/bad-emoji.png", "rb") as f: - with open(f"{base_path}/{new_file_name}", "wb") as new_f: - new_f.write(f.read()) + with ( + open(f"{settings.DEPLOY_ROOT}/static/images/bad-emoji.png", "rb") as f, + open(f"{base_path}/{new_file_name}", "wb") as new_f, + ): + new_f.write(f.read()) emoji.deactivated = True emoji.is_animated = False emoji.file_name = new_file_name diff --git a/zerver/tests/test_auth_backends.py b/zerver/tests/test_auth_backends.py index 0818914bae..4449ee7af6 100644 --- a/zerver/tests/test_auth_backends.py +++ b/zerver/tests/test_auth_backends.py @@ -3415,14 +3415,16 @@ class AppleIdAuthBackendTest(AppleAuthMixin, SocialAuthBase): def test_id_token_verification_failure(self) -> None: account_data_dict = self.get_account_data_dict(email=self.email, name=self.name) - with self.assertLogs(self.logger_string, level="INFO") as m: - with mock.patch("jwt.decode", side_effect=PyJWTError): - result = self.social_auth_test( - account_data_dict, - expect_choose_email_screen=True, - subdomain="zulip", - is_signup=True, - ) + with ( + self.assertLogs(self.logger_string, level="INFO") as m, + mock.patch("jwt.decode", side_effect=PyJWTError), + ): + result = self.social_auth_test( + account_data_dict, + expect_choose_email_screen=True, + subdomain="zulip", + is_signup=True, + ) self.assertEqual(result.status_code, 302) self.assertEqual(result["Location"], "/login/") self.assertEqual( @@ -4583,9 +4585,11 @@ class GoogleAuthBackendTest(SocialAuthBase): "redirect_to": next, } user_profile = self.example_user("hamlet") - with mock.patch("zerver.views.auth.authenticate", return_value=user_profile): - with mock.patch("zerver.views.auth.do_login"): - result = self.get_log_into_subdomain(data) + with ( + mock.patch("zerver.views.auth.authenticate", return_value=user_profile), + mock.patch("zerver.views.auth.do_login"), + ): + result = self.get_log_into_subdomain(data) return result res = test_redirect_to_next_url() @@ -5666,49 +5670,55 @@ class TestZulipRemoteUserBackend(DesktopFlowTestingLib, ZulipTestCase): def test_login_failure_due_to_wrong_subdomain(self) -> None: email = self.example_email("hamlet") - with self.settings( - AUTHENTICATION_BACKENDS=( - "zproject.backends.ZulipRemoteUserBackend", - "zproject.backends.ZulipDummyBackend", - ) - ): - with mock.patch("zerver.views.auth.get_subdomain", return_value="acme"): - result = self.client_get( - "http://testserver:9080/accounts/login/sso/", REMOTE_USER=email - ) - self.assertEqual(result.status_code, 200) - self.assert_logged_in_user_id(None) - self.assert_in_response("You need an invitation to join this organization.", result) - - def test_login_failure_due_to_empty_subdomain(self) -> None: - email = self.example_email("hamlet") - with self.settings( - AUTHENTICATION_BACKENDS=( - "zproject.backends.ZulipRemoteUserBackend", - "zproject.backends.ZulipDummyBackend", - ) - ): - with mock.patch("zerver.views.auth.get_subdomain", return_value=""): - result = self.client_get( - "http://testserver:9080/accounts/login/sso/", REMOTE_USER=email - ) - self.assertEqual(result.status_code, 200) - self.assert_logged_in_user_id(None) - self.assert_in_response("You need an invitation to join this organization.", result) - - def test_login_success_under_subdomains(self) -> None: - user_profile = self.example_user("hamlet") - email = user_profile.delivery_email - with mock.patch("zerver.views.auth.get_subdomain", return_value="zulip"): - with self.settings( + with ( + self.settings( AUTHENTICATION_BACKENDS=( "zproject.backends.ZulipRemoteUserBackend", "zproject.backends.ZulipDummyBackend", ) - ): - result = self.client_get("/accounts/login/sso/", REMOTE_USER=email) - self.assertEqual(result.status_code, 302) - self.assert_logged_in_user_id(user_profile.id) + ), + mock.patch("zerver.views.auth.get_subdomain", return_value="acme"), + ): + result = self.client_get( + "http://testserver:9080/accounts/login/sso/", REMOTE_USER=email + ) + self.assertEqual(result.status_code, 200) + self.assert_logged_in_user_id(None) + self.assert_in_response("You need an invitation to join this organization.", result) + + def test_login_failure_due_to_empty_subdomain(self) -> None: + email = self.example_email("hamlet") + with ( + self.settings( + AUTHENTICATION_BACKENDS=( + "zproject.backends.ZulipRemoteUserBackend", + "zproject.backends.ZulipDummyBackend", + ) + ), + mock.patch("zerver.views.auth.get_subdomain", return_value=""), + ): + result = self.client_get( + "http://testserver:9080/accounts/login/sso/", REMOTE_USER=email + ) + self.assertEqual(result.status_code, 200) + self.assert_logged_in_user_id(None) + self.assert_in_response("You need an invitation to join this organization.", result) + + def test_login_success_under_subdomains(self) -> None: + user_profile = self.example_user("hamlet") + email = user_profile.delivery_email + with ( + mock.patch("zerver.views.auth.get_subdomain", return_value="zulip"), + self.settings( + AUTHENTICATION_BACKENDS=( + "zproject.backends.ZulipRemoteUserBackend", + "zproject.backends.ZulipDummyBackend", + ) + ), + ): + result = self.client_get("/accounts/login/sso/", REMOTE_USER=email) + self.assertEqual(result.status_code, 302) + self.assert_logged_in_user_id(user_profile.id) @override_settings(SEND_LOGIN_EMAILS=True) @override_settings( @@ -5974,30 +5984,34 @@ class TestJWTLogin(ZulipTestCase): def test_login_failure_due_to_wrong_subdomain(self) -> None: payload = {"email": "hamlet@zulip.com"} - with self.settings(JWT_AUTH_KEYS={"acme": {"key": "key", "algorithms": ["HS256"]}}): - with mock.patch("zerver.views.auth.get_realm_from_request", return_value=None): - key = settings.JWT_AUTH_KEYS["acme"]["key"] - [algorithm] = settings.JWT_AUTH_KEYS["acme"]["algorithms"] - web_token = jwt.encode(payload, key, algorithm) + with ( + self.settings(JWT_AUTH_KEYS={"acme": {"key": "key", "algorithms": ["HS256"]}}), + mock.patch("zerver.views.auth.get_realm_from_request", return_value=None), + ): + key = settings.JWT_AUTH_KEYS["acme"]["key"] + [algorithm] = settings.JWT_AUTH_KEYS["acme"]["algorithms"] + web_token = jwt.encode(payload, key, algorithm) - data = {"token": web_token} - result = self.client_post("/accounts/login/jwt/", data) - self.assert_json_error_contains(result, "Invalid subdomain", 404) - self.assert_logged_in_user_id(None) + data = {"token": web_token} + result = self.client_post("/accounts/login/jwt/", data) + self.assert_json_error_contains(result, "Invalid subdomain", 404) + self.assert_logged_in_user_id(None) def test_login_success_under_subdomains(self) -> None: payload = {"email": "hamlet@zulip.com"} - with self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key", "algorithms": ["HS256"]}}): - with mock.patch("zerver.views.auth.get_subdomain", return_value="zulip"): - key = settings.JWT_AUTH_KEYS["zulip"]["key"] - [algorithm] = settings.JWT_AUTH_KEYS["zulip"]["algorithms"] - web_token = jwt.encode(payload, key, algorithm) + with ( + self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key", "algorithms": ["HS256"]}}), + mock.patch("zerver.views.auth.get_subdomain", return_value="zulip"), + ): + key = settings.JWT_AUTH_KEYS["zulip"]["key"] + [algorithm] = settings.JWT_AUTH_KEYS["zulip"]["algorithms"] + web_token = jwt.encode(payload, key, algorithm) - data = {"token": web_token} - result = self.client_post("/accounts/login/jwt/", data) - self.assertEqual(result.status_code, 302) - user_profile = self.example_user("hamlet") - self.assert_logged_in_user_id(user_profile.id) + data = {"token": web_token} + result = self.client_post("/accounts/login/jwt/", data) + self.assertEqual(result.status_code, 302) + user_profile = self.example_user("hamlet") + self.assert_logged_in_user_id(user_profile.id) class DjangoToLDAPUsernameTests(ZulipTestCase): @@ -6046,9 +6060,8 @@ class DjangoToLDAPUsernameTests(ZulipTestCase): self.backend.django_to_ldap_username("aaron@zulip.com"), self.ldap_username("aaron") ) - with self.assertLogs(level="WARNING") as m: - with self.assertRaises(NoMatchingLDAPUserError): - self.backend.django_to_ldap_username("shared_email@zulip.com") + with self.assertLogs(level="WARNING") as m, self.assertRaises(NoMatchingLDAPUserError): + self.backend.django_to_ldap_username("shared_email@zulip.com") self.assertEqual( m.output, [ @@ -6641,9 +6654,11 @@ class TestZulipLDAPUserPopulator(ZulipLDAPTestCase): @override_settings(LDAP_EMAIL_ATTR="mail") def test_populate_user_returns_none(self) -> None: - with mock.patch.object(ZulipLDAPUser, "populate_user", return_value=None): - with self.assertRaises(PopulateUserLDAPError): - sync_user_from_ldap(self.example_user("hamlet"), mock.Mock()) + with ( + mock.patch.object(ZulipLDAPUser, "populate_user", return_value=None), + self.assertRaises(PopulateUserLDAPError), + ): + sync_user_from_ldap(self.example_user("hamlet"), mock.Mock()) def test_update_full_name(self) -> None: self.change_ldap_user_attr("hamlet", "cn", "New Name") @@ -6823,17 +6838,19 @@ class TestZulipLDAPUserPopulator(ZulipLDAPTestCase): self.change_ldap_user_attr("hamlet", "cn", "Second Hamlet") expected_call_args = [hamlet2, "Second Hamlet", None] - with self.settings(AUTH_LDAP_USER_ATTR_MAP={"full_name": "cn"}): - with mock.patch("zerver.actions.user_settings.do_change_full_name") as f: - self.perform_ldap_sync(hamlet2) - f.assert_called_once_with(*expected_call_args) + with ( + self.settings(AUTH_LDAP_USER_ATTR_MAP={"full_name": "cn"}), + mock.patch("zerver.actions.user_settings.do_change_full_name") as f, + ): + self.perform_ldap_sync(hamlet2) + f.assert_called_once_with(*expected_call_args) - # Get the updated model and make sure the full name is changed correctly: - hamlet2 = get_user_by_delivery_email(email, test_realm) - self.assertEqual(hamlet2.full_name, "Second Hamlet") - # Now get the original hamlet and make he still has his name unchanged: - hamlet = self.example_user("hamlet") - self.assertEqual(hamlet.full_name, "King Hamlet") + # Get the updated model and make sure the full name is changed correctly: + hamlet2 = get_user_by_delivery_email(email, test_realm) + self.assertEqual(hamlet2.full_name, "Second Hamlet") + # Now get the original hamlet and make he still has his name unchanged: + hamlet = self.example_user("hamlet") + self.assertEqual(hamlet.full_name, "King Hamlet") def test_user_not_found_in_ldap(self) -> None: with self.settings( @@ -7038,16 +7055,18 @@ class TestZulipLDAPUserPopulator(ZulipLDAPTestCase): }, ], ] - with self.settings( - AUTH_LDAP_USER_ATTR_MAP={ - "full_name": "cn", - "custom_profile_field__birthday": "birthDate", - "custom_profile_field__phone_number": "homePhone", - } + with ( + self.settings( + AUTH_LDAP_USER_ATTR_MAP={ + "full_name": "cn", + "custom_profile_field__birthday": "birthDate", + "custom_profile_field__phone_number": "homePhone", + } + ), + mock.patch("zproject.backends.do_update_user_custom_profile_data_if_changed") as f, ): - with mock.patch("zproject.backends.do_update_user_custom_profile_data_if_changed") as f: - self.perform_ldap_sync(self.example_user("hamlet")) - f.assert_called_once_with(*expected_call_args) + self.perform_ldap_sync(self.example_user("hamlet")) + f.assert_called_once_with(*expected_call_args) def test_update_custom_profile_field_not_present_in_ldap(self) -> None: hamlet = self.example_user("hamlet") @@ -7489,14 +7508,16 @@ class JWTFetchAPIKeyTest(ZulipTestCase): self.assert_json_error_contains(result, "Invalid subdomain", 404) def test_jwt_key_not_found_failure(self) -> None: - with self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key1", "algorithms": ["HS256"]}}): - with mock.patch( + with ( + self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key1", "algorithms": ["HS256"]}}), + mock.patch( "zerver.views.auth.get_realm_from_request", return_value=get_realm("zephyr") - ): - result = self.client_post("/api/v1/jwt/fetch_api_key") - self.assert_json_error_contains( - result, "JWT authentication is not enabled for this organization", 400 - ) + ), + ): + result = self.client_post("/api/v1/jwt/fetch_api_key") + self.assert_json_error_contains( + result, "JWT authentication is not enabled for this organization", 400 + ) def test_missing_jwt_payload_failure(self) -> None: with self.settings(JWT_AUTH_KEYS={"zulip": {"key": "key1", "algorithms": ["HS256"]}}): @@ -7709,12 +7730,12 @@ class LDAPGroupSyncTest(ZulipTestCase): ), self.assertLogs("django_auth_ldap", "WARN") as django_ldap_log, self.assertLogs("zulip.ldap", "DEBUG") as zulip_ldap_log, - ): - with self.assertRaisesRegex( + self.assertRaisesRegex( ZulipLDAPError, "search_s.*", - ): - sync_user_from_ldap(cordelia, mock.Mock()) + ), + ): + sync_user_from_ldap(cordelia, mock.Mock()) self.assertEqual( zulip_ldap_log.output, diff --git a/zerver/tests/test_decorators.py b/zerver/tests/test_decorators.py index 5201014ac7..73f6204d6d 100644 --- a/zerver/tests/test_decorators.py +++ b/zerver/tests/test_decorators.py @@ -165,11 +165,11 @@ class DecoratorTestCase(ZulipTestCase): # Start a valid request here request = HostRequestMock() request.POST["api_key"] = webhook_bot_api_key - with self.assertLogs(level="WARNING") as m: - with self.assertRaisesRegex( - JsonableError, "Account is not associated with this subdomain" - ): - api_result = my_webhook(request) + with ( + self.assertLogs(level="WARNING") as m, + self.assertRaisesRegex(JsonableError, "Account is not associated with this subdomain"), + ): + api_result = my_webhook(request) self.assertEqual( m.output, [ @@ -181,12 +181,12 @@ class DecoratorTestCase(ZulipTestCase): request = HostRequestMock() request.POST["api_key"] = webhook_bot_api_key - with self.assertLogs(level="WARNING") as m: - with self.assertRaisesRegex( - JsonableError, "Account is not associated with this subdomain" - ): - request.host = "acme." + settings.EXTERNAL_HOST - api_result = my_webhook(request) + with ( + self.assertLogs(level="WARNING") as m, + self.assertRaisesRegex(JsonableError, "Account is not associated with this subdomain"), + ): + request.host = "acme." + settings.EXTERNAL_HOST + api_result = my_webhook(request) self.assertEqual( m.output, [ @@ -203,11 +203,13 @@ class DecoratorTestCase(ZulipTestCase): request = HostRequestMock() request.host = "zulip.testserver" request.POST["api_key"] = webhook_bot_api_key - with self.assertLogs("zulip.zerver.webhooks", level="INFO") as log: - with self.assertRaisesRegex(Exception, "raised by webhook function"): - request._body = b"{}" - request.content_type = "application/json" - my_webhook_raises_exception(request) + with ( + self.assertLogs("zulip.zerver.webhooks", level="INFO") as log, + self.assertRaisesRegex(Exception, "raised by webhook function"), + ): + request._body = b"{}" + request.content_type = "application/json" + my_webhook_raises_exception(request) # Test when content_type is not application/json; exception raised # in the webhook function should be re-raised @@ -215,11 +217,13 @@ class DecoratorTestCase(ZulipTestCase): request = HostRequestMock() request.host = "zulip.testserver" request.POST["api_key"] = webhook_bot_api_key - with self.assertLogs("zulip.zerver.webhooks", level="INFO") as log: - with self.assertRaisesRegex(Exception, "raised by webhook function"): - request._body = b"notjson" - request.content_type = "text/plain" - my_webhook_raises_exception(request) + with ( + self.assertLogs("zulip.zerver.webhooks", level="INFO") as log, + self.assertRaisesRegex(Exception, "raised by webhook function"), + ): + request._body = b"notjson" + request.content_type = "text/plain" + my_webhook_raises_exception(request) # Test when content_type is application/json but request.body # is not valid JSON; invalid JSON should be logged and the @@ -227,12 +231,14 @@ class DecoratorTestCase(ZulipTestCase): request = HostRequestMock() request.host = "zulip.testserver" request.POST["api_key"] = webhook_bot_api_key - with self.assertLogs("zulip.zerver.webhooks", level="ERROR") as log: - with self.assertRaisesRegex(Exception, "raised by webhook function"): - request._body = b"invalidjson" - request.content_type = "application/json" - request.META["HTTP_X_CUSTOM_HEADER"] = "custom_value" - my_webhook_raises_exception(request) + with ( + self.assertLogs("zulip.zerver.webhooks", level="ERROR") as log, + self.assertRaisesRegex(Exception, "raised by webhook function"), + ): + request._body = b"invalidjson" + request.content_type = "application/json" + request.META["HTTP_X_CUSTOM_HEADER"] = "custom_value" + my_webhook_raises_exception(request) self.assertIn( self.logger_output("raised by webhook function\n", "error", "webhooks"), log.output[0] @@ -245,12 +251,14 @@ class DecoratorTestCase(ZulipTestCase): exception_msg = ( "The 'test_event' event isn't currently supported by the ClientName webhook; ignoring" ) - with self.assertLogs("zulip.zerver.webhooks.unsupported", level="ERROR") as log: - with self.assertRaisesRegex(UnsupportedWebhookEventTypeError, exception_msg): - request._body = b"invalidjson" - request.content_type = "application/json" - request.META["HTTP_X_CUSTOM_HEADER"] = "custom_value" - my_webhook_raises_exception_unsupported_event(request) + with ( + self.assertLogs("zulip.zerver.webhooks.unsupported", level="ERROR") as log, + self.assertRaisesRegex(UnsupportedWebhookEventTypeError, exception_msg), + ): + request._body = b"invalidjson" + request.content_type = "application/json" + request.META["HTTP_X_CUSTOM_HEADER"] = "custom_value" + my_webhook_raises_exception_unsupported_event(request) self.assertIn( self.logger_output(exception_msg, "error", "webhooks.unsupported"), log.output[0] @@ -259,9 +267,11 @@ class DecoratorTestCase(ZulipTestCase): request = HostRequestMock() request.host = "zulip.testserver" request.POST["api_key"] = webhook_bot_api_key - with self.settings(RATE_LIMITING=True): - with mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock: - api_result = orjson.loads(my_webhook(request).content).get("msg") + with ( + self.settings(RATE_LIMITING=True), + mock.patch("zerver.decorator.rate_limit_user") as rate_limit_mock, + ): + api_result = orjson.loads(my_webhook(request).content).get("msg") # Verify rate limiting was attempted. self.assertTrue(rate_limit_mock.called) @@ -389,9 +399,11 @@ class DecoratorLoggingTestCase(ZulipTestCase): request._body = b"{}" request.content_type = "text/plain" - with self.assertLogs("zulip.zerver.webhooks") as logger: - with self.assertRaisesRegex(Exception, "raised by webhook function"): - my_webhook_raises_exception(request) + with ( + self.assertLogs("zulip.zerver.webhooks") as logger, + self.assertRaisesRegex(Exception, "raised by webhook function"), + ): + my_webhook_raises_exception(request) self.assertIn("raised by webhook function", logger.output[0]) @@ -440,9 +452,11 @@ class DecoratorLoggingTestCase(ZulipTestCase): request._body = b"{}" request.content_type = "application/json" - with mock.patch("zerver.decorator.webhook_logger.exception") as mock_exception: - with self.assertRaisesRegex(Exception, "raised by a non-webhook view"): - non_webhook_view_raises_exception(request) + with ( + mock.patch("zerver.decorator.webhook_logger.exception") as mock_exception, + self.assertRaisesRegex(Exception, "raised by a non-webhook view"), + ): + non_webhook_view_raises_exception(request) self.assertFalse(mock_exception.called) @@ -964,15 +978,17 @@ class TestValidateApiKey(ZulipTestCase): def test_valid_api_key_if_user_is_on_wrong_subdomain(self) -> None: with self.settings(RUNNING_INSIDE_TORNADO=False): api_key = get_api_key(self.default_bot) - with self.assertLogs(level="WARNING") as m: - with self.assertRaisesRegex( + with ( + self.assertLogs(level="WARNING") as m, + self.assertRaisesRegex( JsonableError, "Account is not associated with this subdomain" - ): - validate_api_key( - HostRequestMock(host=settings.EXTERNAL_HOST), - self.default_bot.email, - api_key, - ) + ), + ): + validate_api_key( + HostRequestMock(host=settings.EXTERNAL_HOST), + self.default_bot.email, + api_key, + ) self.assertEqual( m.output, [ @@ -982,15 +998,17 @@ class TestValidateApiKey(ZulipTestCase): ], ) - with self.assertLogs(level="WARNING") as m: - with self.assertRaisesRegex( + with ( + self.assertLogs(level="WARNING") as m, + self.assertRaisesRegex( JsonableError, "Account is not associated with this subdomain" - ): - validate_api_key( - HostRequestMock(host="acme." + settings.EXTERNAL_HOST), - self.default_bot.email, - api_key, - ) + ), + ): + validate_api_key( + HostRequestMock(host="acme." + settings.EXTERNAL_HOST), + self.default_bot.email, + api_key, + ) self.assertEqual( m.output, [ diff --git a/zerver/tests/test_digest.py b/zerver/tests/test_digest.py index 0532f1b621..e74f15cf93 100644 --- a/zerver/tests/test_digest.py +++ b/zerver/tests/test_digest.py @@ -241,9 +241,8 @@ class TestDigestEmailMessages(ZulipTestCase): digest_user_ids = [user.id for user in digest_users] get_recent_topics.cache_clear() - with self.assert_database_query_count(16): - with self.assert_memcached_count(0): - bulk_handle_digest_email(digest_user_ids, cutoff) + with self.assert_database_query_count(16), self.assert_memcached_count(0): + bulk_handle_digest_email(digest_user_ids, cutoff) self.assert_length(digest_users, mock_send_future_email.call_count) @@ -441,9 +440,11 @@ class TestDigestEmailMessages(ZulipTestCase): tuesday = self.tuesday() cutoff = tuesday - timedelta(days=5) - with time_machine.travel(tuesday, tick=False): - with mock.patch("zerver.lib.digest.queue_digest_user_ids") as queue_mock: - enqueue_emails(cutoff) + with ( + time_machine.travel(tuesday, tick=False), + mock.patch("zerver.lib.digest.queue_digest_user_ids") as queue_mock, + ): + enqueue_emails(cutoff) queue_mock.assert_not_called() @override_settings(SEND_DIGEST_EMAILS=True) @@ -453,9 +454,11 @@ class TestDigestEmailMessages(ZulipTestCase): not_tuesday = datetime(year=2016, month=1, day=6, tzinfo=timezone.utc) cutoff = not_tuesday - timedelta(days=5) - with time_machine.travel(not_tuesday, tick=False): - with mock.patch("zerver.lib.digest.queue_digest_user_ids") as queue_mock: - enqueue_emails(cutoff) + with ( + time_machine.travel(not_tuesday, tick=False), + mock.patch("zerver.lib.digest.queue_digest_user_ids") as queue_mock, + ): + enqueue_emails(cutoff) queue_mock.assert_not_called() @override_settings(SEND_DIGEST_EMAILS=True) diff --git a/zerver/tests/test_embedded_bot_system.py b/zerver/tests/test_embedded_bot_system.py index 2277ff2ce3..0d02595aa5 100644 --- a/zerver/tests/test_embedded_bot_system.py +++ b/zerver/tests/test_embedded_bot_system.py @@ -72,18 +72,20 @@ class TestEmbeddedBotMessaging(ZulipTestCase): def test_embedded_bot_quit_exception(self) -> None: assert self.bot_profile is not None - with patch( - "zulip_bots.bots.helloworld.helloworld.HelloWorldHandler.handle_message", - side_effect=EmbeddedBotQuitError("I'm quitting!"), + with ( + patch( + "zulip_bots.bots.helloworld.helloworld.HelloWorldHandler.handle_message", + side_effect=EmbeddedBotQuitError("I'm quitting!"), + ), + self.assertLogs(level="WARNING") as m, ): - with self.assertLogs(level="WARNING") as m: - self.send_stream_message( - self.user_profile, - "Denmark", - content=f"@**{self.bot_profile.full_name}** foo", - topic_name="bar", - ) - self.assertEqual(m.output, ["WARNING:root:I'm quitting!"]) + self.send_stream_message( + self.user_profile, + "Denmark", + content=f"@**{self.bot_profile.full_name}** foo", + topic_name="bar", + ) + self.assertEqual(m.output, ["WARNING:root:I'm quitting!"]) class TestEmbeddedBotFailures(ZulipTestCase): diff --git a/zerver/tests/test_event_system.py b/zerver/tests/test_event_system.py index 5039bf0474..2d770963a1 100644 --- a/zerver/tests/test_event_system.py +++ b/zerver/tests/test_event_system.py @@ -86,12 +86,14 @@ class EventsEndpointTest(ZulipTestCase): test_event = dict(id=6, type=event_type, realm_emoji=empty_realm_emoji_dict) # Test that call is made to deal with a returning soft deactivated user. - with mock.patch("zerver.lib.events.reactivate_user_if_soft_deactivated") as fa: - with stub_event_queue_user_events(return_event_queue, return_user_events): - result = self.api_post( - user, "/api/v1/register", dict(event_types=orjson.dumps([event_type]).decode()) - ) - self.assertEqual(fa.call_count, 1) + with ( + mock.patch("zerver.lib.events.reactivate_user_if_soft_deactivated") as fa, + stub_event_queue_user_events(return_event_queue, return_user_events), + ): + result = self.api_post( + user, "/api/v1/register", dict(event_types=orjson.dumps([event_type]).decode()) + ) + self.assertEqual(fa.call_count, 1) with stub_event_queue_user_events(return_event_queue, return_user_events): result = self.api_post( @@ -1171,9 +1173,11 @@ class FetchQueriesTest(ZulipTestCase): # count in production. realm = get_realm_with_settings(realm_id=user.realm_id) - with self.assert_database_query_count(43): - with mock.patch("zerver.lib.events.always_want") as want_mock: - fetch_initial_state_data(user, realm=realm) + with ( + self.assert_database_query_count(43), + mock.patch("zerver.lib.events.always_want") as want_mock, + ): + fetch_initial_state_data(user, realm=realm) expected_counts = dict( alert_words=1, diff --git a/zerver/tests/test_events.py b/zerver/tests/test_events.py index 863b381b15..5e7306890e 100644 --- a/zerver/tests/test_events.py +++ b/zerver/tests/test_events.py @@ -1742,17 +1742,19 @@ class NormalActionsTest(BaseAction): cordelia.save() away_val = False - with self.settings(CAN_ACCESS_ALL_USERS_GROUP_LIMITS_PRESENCE=True): - with self.verify_action(num_events=0, state_change_expected=False) as events: - do_update_user_status( - user_profile=cordelia, - away=away_val, - status_text="out to lunch", - emoji_name="car", - emoji_code="1f697", - reaction_type=UserStatus.UNICODE_EMOJI, - client_id=client.id, - ) + with ( + self.settings(CAN_ACCESS_ALL_USERS_GROUP_LIMITS_PRESENCE=True), + self.verify_action(num_events=0, state_change_expected=False) as events, + ): + do_update_user_status( + user_profile=cordelia, + away=away_val, + status_text="out to lunch", + emoji_name="car", + emoji_code="1f697", + reaction_type=UserStatus.UNICODE_EMOJI, + client_id=client.id, + ) away_val = True with self.verify_action(num_events=1, state_change_expected=True) as events: @@ -2128,13 +2130,12 @@ class NormalActionsTest(BaseAction): {"Google": False, "Email": False, "GitHub": True, "LDAP": False, "Dev": True}, {"Google": False, "Email": True, "GitHub": True, "LDAP": True, "Dev": False}, ): - with fake_backends(): - with self.verify_action() as events: - do_set_realm_authentication_methods( - self.user_profile.realm, - auth_method_dict, - acting_user=None, - ) + with fake_backends(), self.verify_action() as events: + do_set_realm_authentication_methods( + self.user_profile.realm, + auth_method_dict, + acting_user=None, + ) check_realm_update_dict("events[0]", events[0]) @@ -2664,11 +2665,10 @@ class NormalActionsTest(BaseAction): def test_realm_emoji_events(self) -> None: author = self.example_user("iago") - with get_test_image_file("img.png") as img_file: - with self.verify_action() as events: - check_add_realm_emoji( - self.user_profile.realm, "my_emoji", author, img_file, "image/png" - ) + with get_test_image_file("img.png") as img_file, self.verify_action() as events: + check_add_realm_emoji( + self.user_profile.realm, "my_emoji", author, img_file, "image/png" + ) check_realm_emoji_update("events[0]", events[0]) @@ -3278,9 +3278,12 @@ class NormalActionsTest(BaseAction): "zerver.lib.export.do_export_realm", return_value=create_dummy_file("test-export.tar.gz"), ): - with stdout_suppressed(), self.assertLogs(level="INFO") as info_logs: - with self.verify_action(state_change_expected=True, num_events=3) as events: - self.client_post("/json/export/realm") + with ( + stdout_suppressed(), + self.assertLogs(level="INFO") as info_logs, + self.verify_action(state_change_expected=True, num_events=3) as events, + ): + self.client_post("/json/export/realm") self.assertTrue("INFO:root:Completed data export for zulip in" in info_logs.output[0]) # We get two realm_export events for this action, where the first @@ -3328,9 +3331,11 @@ class NormalActionsTest(BaseAction): mock.patch("zerver.lib.export.do_export_realm", side_effect=Exception("Some failure")), self.assertLogs(level="ERROR") as error_log, ): - with stdout_suppressed(): - with self.verify_action(state_change_expected=False, num_events=2) as events: - self.client_post("/json/export/realm") + with ( + stdout_suppressed(), + self.verify_action(state_change_expected=False, num_events=2) as events, + ): + self.client_post("/json/export/realm") # Log is of following format: "ERROR:root:Data export for zulip failed after 0.004499673843383789" # Where last floating number is time and will vary in each test hence the following assertion is diff --git a/zerver/tests/test_external.py b/zerver/tests/test_external.py index cba84caf10..d0a94c8ed2 100644 --- a/zerver/tests/test_external.py +++ b/zerver/tests/test_external.py @@ -298,12 +298,14 @@ class RateLimitTests(ZulipTestCase): # We need to reset the circuitbreaker before starting. We # patch the .opened property to be false, then call the # function, so it resets to closed. - with mock.patch("builtins.open", mock.mock_open(read_data=orjson.dumps(["1.2.3.4"]))): - with mock.patch( + with ( + mock.patch("builtins.open", mock.mock_open(read_data=orjson.dumps(["1.2.3.4"]))), + mock.patch( "circuitbreaker.CircuitBreaker.opened", new_callable=mock.PropertyMock - ) as mock_opened: - mock_opened.return_value = False - get_tor_ips() + ) as mock_opened, + ): + mock_opened.return_value = False + get_tor_ips() # Having closed it, it's now cached. Clear the cache. assert CircuitBreakerMonitor.get("get_tor_ips").closed @@ -354,13 +356,15 @@ class RateLimitTests(ZulipTestCase): # An empty list of IPs is treated as some error in parsing the # input, and as such should not be cached; rate-limiting # should work as normal, per-IP - with self.tor_mock(read_data=[]) as tor_open: - with self.assertLogs("zerver.lib.rate_limiter", level="WARNING"): - self.do_test_hit_ratelimits( - lambda: self.send_unauthed_api_request(REMOTE_ADDR="1.2.3.4") - ) - resp = self.send_unauthed_api_request(REMOTE_ADDR="5.6.7.8") - self.assertNotEqual(resp.status_code, 429) + with ( + self.tor_mock(read_data=[]) as tor_open, + self.assertLogs("zerver.lib.rate_limiter", level="WARNING"), + ): + self.do_test_hit_ratelimits( + lambda: self.send_unauthed_api_request(REMOTE_ADDR="1.2.3.4") + ) + resp = self.send_unauthed_api_request(REMOTE_ADDR="5.6.7.8") + self.assertNotEqual(resp.status_code, 429) # Was not cached, so tried to read twice before hitting the # circuit-breaker, and stopping trying @@ -372,15 +376,17 @@ class RateLimitTests(ZulipTestCase): for ip in ["1.2.3.4", "5.6.7.8", "tor-exit-node"]: RateLimitedIPAddr(ip, domain="api_by_ip").clear_history() - with self.tor_mock(side_effect=FileNotFoundError("File not found")) as tor_open: + with ( + self.tor_mock(side_effect=FileNotFoundError("File not found")) as tor_open, # If we cannot get a list of TOR exit nodes, then # rate-limiting works as normal, per-IP - with self.assertLogs("zerver.lib.rate_limiter", level="WARNING") as log_mock: - self.do_test_hit_ratelimits( - lambda: self.send_unauthed_api_request(REMOTE_ADDR="1.2.3.4") - ) - resp = self.send_unauthed_api_request(REMOTE_ADDR="5.6.7.8") - self.assertNotEqual(resp.status_code, 429) + self.assertLogs("zerver.lib.rate_limiter", level="WARNING") as log_mock, + ): + self.do_test_hit_ratelimits( + lambda: self.send_unauthed_api_request(REMOTE_ADDR="1.2.3.4") + ) + resp = self.send_unauthed_api_request(REMOTE_ADDR="5.6.7.8") + self.assertNotEqual(resp.status_code, 429) # Tries twice before hitting the circuit-breaker, and stopping trying tor_open.assert_has_calls( diff --git a/zerver/tests/test_home.py b/zerver/tests/test_home.py index 49de27421e..d2bbdd6a0b 100644 --- a/zerver/tests/test_home.py +++ b/zerver/tests/test_home.py @@ -261,10 +261,12 @@ class HomeTest(ZulipTestCase): self.client_post("/json/bots", bot_info) # Verify succeeds once logged-in - with self.assert_database_query_count(54): - with patch("zerver.lib.cache.cache_set") as cache_mock: - result = self._get_home_page(stream="Denmark") - self.check_rendered_logged_in_app(result) + with ( + self.assert_database_query_count(54), + patch("zerver.lib.cache.cache_set") as cache_mock, + ): + result = self._get_home_page(stream="Denmark") + self.check_rendered_logged_in_app(result) self.assertEqual( set(result["Cache-Control"].split(", ")), {"must-revalidate", "no-store", "no-cache"} ) @@ -312,10 +314,9 @@ class HomeTest(ZulipTestCase): self.login("hamlet") # Verify succeeds once logged-in - with queries_captured(): - with patch("zerver.lib.cache.cache_set"): - result = self._get_home_page(stream="Denmark") - self.check_rendered_logged_in_app(result) + with queries_captured(), patch("zerver.lib.cache.cache_set"): + result = self._get_home_page(stream="Denmark") + self.check_rendered_logged_in_app(result) page_params = self._get_page_params(result) self.assertCountEqual(page_params, self.expected_page_params_keys) @@ -565,11 +566,13 @@ class HomeTest(ZulipTestCase): def test_num_queries_for_realm_admin(self) -> None: # Verify number of queries for Realm admin isn't much higher than for normal users. self.login("iago") - with self.assert_database_query_count(54): - with patch("zerver.lib.cache.cache_set") as cache_mock: - result = self._get_home_page() - self.check_rendered_logged_in_app(result) - self.assert_length(cache_mock.call_args_list, 7) + with ( + self.assert_database_query_count(54), + patch("zerver.lib.cache.cache_set") as cache_mock, + ): + result = self._get_home_page() + self.check_rendered_logged_in_app(result) + self.assert_length(cache_mock.call_args_list, 7) def test_num_queries_with_streams(self) -> None: main_user = self.example_user("hamlet") diff --git a/zerver/tests/test_invite.py b/zerver/tests/test_invite.py index 61559ed40e..5b2417ed97 100644 --- a/zerver/tests/test_invite.py +++ b/zerver/tests/test_invite.py @@ -2547,9 +2547,11 @@ class MultiuseInviteTest(ZulipTestCase): email = self.nonreg_email("newuser") invite_link = "/join/invalid_key/" - with patch("zerver.views.registration.get_realm_from_request", return_value=self.realm): - with patch("zerver.views.registration.get_realm", return_value=self.realm): - self.check_user_able_to_register(email, invite_link) + with ( + patch("zerver.views.registration.get_realm_from_request", return_value=self.realm), + patch("zerver.views.registration.get_realm", return_value=self.realm), + ): + self.check_user_able_to_register(email, invite_link) def test_multiuse_link_with_specified_streams(self) -> None: name1 = "newuser" diff --git a/zerver/tests/test_link_embed.py b/zerver/tests/test_link_embed.py index 7cce457086..2a472882ac 100644 --- a/zerver/tests/test_link_embed.py +++ b/zerver/tests/test_link_embed.py @@ -438,11 +438,10 @@ class PreviewTestCase(ZulipTestCase): self.create_mock_response(original_url) self.create_mock_response(edited_url) - with self.settings(TEST_SUITE=False): - with self.assertLogs(level="INFO") as info_logs: - # Run the queue processor. This will simulate the event for original_url being - # processed after the message has been edited. - FetchLinksEmbedData().consume(event) + with self.settings(TEST_SUITE=False), self.assertLogs(level="INFO") as info_logs: + # Run the queue processor. This will simulate the event for original_url being + # processed after the message has been edited. + FetchLinksEmbedData().consume(event) self.assertTrue( "INFO:root:Time spent on get_link_embed_data for http://test.org/: " in info_logs.output[0] @@ -457,17 +456,16 @@ class PreviewTestCase(ZulipTestCase): self.assertTrue(responses.assert_call_count(edited_url, 0)) - with self.settings(TEST_SUITE=False): - with self.assertLogs(level="INFO") as info_logs: - # Now proceed with the original queue_json_publish and call the - # up-to-date event for edited_url. - queue_json_publish(*args, **kwargs) - msg = Message.objects.select_related("sender").get(id=msg_id) - assert msg.rendered_content is not None - self.assertIn( - f'The Rock', - msg.rendered_content, - ) + with self.settings(TEST_SUITE=False), self.assertLogs(level="INFO") as info_logs: + # Now proceed with the original queue_json_publish and call the + # up-to-date event for edited_url. + queue_json_publish(*args, **kwargs) + msg = Message.objects.select_related("sender").get(id=msg_id) + assert msg.rendered_content is not None + self.assertIn( + f'The Rock', + msg.rendered_content, + ) self.assertTrue( "INFO:root:Time spent on get_link_embed_data for http://edited.org/: " in info_logs.output[0] @@ -503,11 +501,10 @@ class PreviewTestCase(ZulipTestCase): # We do still fetch the URL, as we don't want to incur the # cost of locking the row while we do the HTTP fetches. self.create_mock_response(url) - with self.settings(TEST_SUITE=False): - with self.assertLogs(level="INFO") as info_logs: - # Run the queue processor. This will simulate the event for original_url being - # processed after the message has been deleted. - FetchLinksEmbedData().consume(event) + with self.settings(TEST_SUITE=False), self.assertLogs(level="INFO") as info_logs: + # Run the queue processor. This will simulate the event for original_url being + # processed after the message has been deleted. + FetchLinksEmbedData().consume(event) self.assertTrue( "INFO:root:Time spent on get_link_embed_data for http://test.org/: " in info_logs.output[0] @@ -852,24 +849,26 @@ class PreviewTestCase(ZulipTestCase): self.create_mock_response(url, body=ConnectionError()) - with mock.patch( - "zerver.lib.url_preview.preview.get_oembed_data", - side_effect=lambda *args, **kwargs: None, - ): - with mock.patch( + with ( + mock.patch( + "zerver.lib.url_preview.preview.get_oembed_data", + side_effect=lambda *args, **kwargs: None, + ), + mock.patch( "zerver.lib.url_preview.preview.valid_content_type", side_effect=lambda k: True - ): - with self.settings(TEST_SUITE=False): - with self.assertLogs(level="INFO") as info_logs: - FetchLinksEmbedData().consume(event) - self.assertTrue( - "INFO:root:Time spent on get_link_embed_data for http://test.org/: " - in info_logs.output[0] - ) + ), + self.settings(TEST_SUITE=False), + ): + with self.assertLogs(level="INFO") as info_logs: + FetchLinksEmbedData().consume(event) + self.assertTrue( + "INFO:root:Time spent on get_link_embed_data for http://test.org/: " + in info_logs.output[0] + ) - # This did not get cached -- hence the lack of [0] on the cache_get - cached_data = cache_get(preview_url_cache_key(url)) - self.assertIsNone(cached_data) + # This did not get cached -- hence the lack of [0] on the cache_get + cached_data = cache_get(preview_url_cache_key(url)) + self.assertIsNone(cached_data) msg.refresh_from_db() self.assertEqual( @@ -939,13 +938,15 @@ class PreviewTestCase(ZulipTestCase): ) self.create_mock_response(url) with self.settings(TEST_SUITE=False): - with self.assertLogs(level="INFO") as info_logs: - with mock.patch( + with ( + self.assertLogs(level="INFO") as info_logs, + mock.patch( "zerver.lib.url_preview.preview.get_oembed_data", lambda *args, **kwargs: mocked_data, - ): - FetchLinksEmbedData().consume(event) - cached_data = cache_get(preview_url_cache_key(url))[0] + ), + ): + FetchLinksEmbedData().consume(event) + cached_data = cache_get(preview_url_cache_key(url))[0] self.assertTrue( "INFO:root:Time spent on get_link_embed_data for http://test.org/: " in info_logs.output[0] @@ -979,12 +980,14 @@ class PreviewTestCase(ZulipTestCase): ) self.create_mock_response(url) with self.settings(TEST_SUITE=False): - with self.assertLogs(level="INFO") as info_logs: - with mock.patch( + with ( + self.assertLogs(level="INFO") as info_logs, + mock.patch( "zerver.worker.embed_links.url_preview.get_link_embed_data", lambda *args, **kwargs: mocked_data, - ): - FetchLinksEmbedData().consume(event) + ), + ): + FetchLinksEmbedData().consume(event) self.assertTrue( "INFO:root:Time spent on get_link_embed_data for https://www.youtube.com/watch?v=eSJTXC7Ixgg:" in info_logs.output[0] @@ -1017,12 +1020,14 @@ class PreviewTestCase(ZulipTestCase): ) self.create_mock_response(url) with self.settings(TEST_SUITE=False): - with self.assertLogs(level="INFO") as info_logs: - with mock.patch( + with ( + self.assertLogs(level="INFO") as info_logs, + mock.patch( "zerver.worker.embed_links.url_preview.get_link_embed_data", lambda *args, **kwargs: mocked_data, - ): - FetchLinksEmbedData().consume(event) + ), + ): + FetchLinksEmbedData().consume(event) self.assertTrue( "INFO:root:Time spent on get_link_embed_data for [YouTube link](https://www.youtube.com/watch?v=eSJTXC7Ixgg):" in info_logs.output[0] diff --git a/zerver/tests/test_management_commands.py b/zerver/tests/test_management_commands.py index a52689f39e..3218979073 100644 --- a/zerver/tests/test_management_commands.py +++ b/zerver/tests/test_management_commands.py @@ -29,11 +29,13 @@ from zerver.models.users import get_user_profile_by_email class TestCheckConfig(ZulipTestCase): def test_check_config(self) -> None: check_config() - with self.settings(REQUIRED_SETTINGS=[("asdf", "not asdf")]): - with self.assertRaisesRegex( + with ( + self.settings(REQUIRED_SETTINGS=[("asdf", "not asdf")]), + self.assertRaisesRegex( CommandError, "Error: You must set asdf in /etc/zulip/settings.py." - ): - check_config() + ), + ): + check_config() @override_settings(WARN_NO_EMAIL=True) def test_check_send_email(self) -> None: @@ -210,9 +212,8 @@ class TestCommandsCanStart(ZulipTestCase): def test_management_commands_show_help(self) -> None: with stdout_suppressed(): for command in self.commands: - with self.subTest(management_command=command): - with self.assertRaises(SystemExit): - call_command(command, "--help") + with self.subTest(management_command=command), self.assertRaises(SystemExit): + call_command(command, "--help") # zerver/management/commands/runtornado.py sets this to True; # we need to reset it here. See #3685 for details. settings.RUNNING_INSIDE_TORNADO = False diff --git a/zerver/tests/test_markdown.py b/zerver/tests/test_markdown.py index 39688ed029..9351711eff 100644 --- a/zerver/tests/test_markdown.py +++ b/zerver/tests/test_markdown.py @@ -1104,9 +1104,11 @@ class MarkdownTest(ZulipTestCase): ) def test_fetch_tweet_data_settings_validation(self) -> None: - with self.settings(TEST_SUITE=False, TWITTER_CONSUMER_KEY=None): - with self.assertRaises(NotImplementedError): - fetch_tweet_data("287977969287315459") + with ( + self.settings(TEST_SUITE=False, TWITTER_CONSUMER_KEY=None), + self.assertRaises(NotImplementedError), + ): + fetch_tweet_data("287977969287315459") def test_content_has_emoji(self) -> None: self.assertFalse(content_has_emoji_syntax("boring")) @@ -1710,9 +1712,11 @@ class MarkdownTest(ZulipTestCase): self.assertEqual(linkifiers_for_realm(realm.id), []) # Verify that our in-memory cache avoids round trips. - with self.assert_database_query_count(0, keep_cache_warm=True): - with self.assert_memcached_count(0): - self.assertEqual(linkifiers_for_realm(realm.id), []) + with ( + self.assert_database_query_count(0, keep_cache_warm=True), + self.assert_memcached_count(0), + ): + self.assertEqual(linkifiers_for_realm(realm.id), []) linkifier = RealmFilter(realm=realm, pattern=r"whatever", url_template="whatever") linkifier.save() @@ -1724,12 +1728,14 @@ class MarkdownTest(ZulipTestCase): ) # And the in-process cache works again. - with self.assert_database_query_count(0, keep_cache_warm=True): - with self.assert_memcached_count(0): - self.assertEqual( - linkifiers_for_realm(realm.id), - [{"id": linkifier.id, "pattern": "whatever", "url_template": "whatever"}], - ) + with ( + self.assert_database_query_count(0, keep_cache_warm=True), + self.assert_memcached_count(0), + ): + self.assertEqual( + linkifiers_for_realm(realm.id), + [{"id": linkifier.id, "pattern": "whatever", "url_template": "whatever"}], + ) def test_alert_words(self) -> None: user_profile = self.example_user("othello") @@ -3289,17 +3295,18 @@ class MarkdownApiTests(ZulipTestCase): class MarkdownErrorTests(ZulipTestCase): def test_markdown_error_handling(self) -> None: - with self.simulated_markdown_failure(): - with self.assertRaises(MarkdownRenderingError): - markdown_convert_wrapper("") + with self.simulated_markdown_failure(), self.assertRaises(MarkdownRenderingError): + markdown_convert_wrapper("") def test_send_message_errors(self) -> None: message = "whatever" - with self.simulated_markdown_failure(): + with ( + self.simulated_markdown_failure(), # We don't use assertRaisesRegex because it seems to not # handle i18n properly here on some systems. - with self.assertRaises(JsonableError): - self.send_stream_message(self.example_user("othello"), "Denmark", message) + self.assertRaises(JsonableError), + ): + self.send_stream_message(self.example_user("othello"), "Denmark", message) @override_settings(MAX_MESSAGE_LENGTH=10) def test_ultra_long_rendering(self) -> None: @@ -3310,9 +3317,9 @@ class MarkdownErrorTests(ZulipTestCase): with ( mock.patch("zerver.lib.markdown.unsafe_timeout", return_value=msg), mock.patch("zerver.lib.markdown.markdown_logger"), + self.assertRaises(MarkdownRenderingError), ): - with self.assertRaises(MarkdownRenderingError): - markdown_convert_wrapper(msg) + markdown_convert_wrapper(msg) def test_curl_code_block_validation(self) -> None: processor = SimulatedFencedBlockPreprocessor(Markdown()) diff --git a/zerver/tests/test_message_delete.py b/zerver/tests/test_message_delete.py index 7626ac8c65..89b2ffd5c9 100644 --- a/zerver/tests/test_message_delete.py +++ b/zerver/tests/test_message_delete.py @@ -301,12 +301,14 @@ class DeleteMessageTest(ZulipTestCase): self.send_stream_message(hamlet, "Denmark") message = self.get_last_message() - with self.capture_send_event_calls(expected_num_events=1): - with mock.patch("zerver.tornado.django_api.queue_json_publish") as m: - m.side_effect = AssertionError( - "Events should be sent only after the transaction commits." - ) - do_delete_messages(hamlet.realm, [message]) + with ( + self.capture_send_event_calls(expected_num_events=1), + mock.patch("zerver.tornado.django_api.queue_json_publish") as m, + ): + m.side_effect = AssertionError( + "Events should be sent only after the transaction commits." + ) + do_delete_messages(hamlet.realm, [message]) def test_delete_message_in_unsubscribed_private_stream(self) -> None: hamlet = self.example_user("hamlet") diff --git a/zerver/tests/test_message_edit_notifications.py b/zerver/tests/test_message_edit_notifications.py index 03bdec64be..3841513e1b 100644 --- a/zerver/tests/test_message_edit_notifications.py +++ b/zerver/tests/test_message_edit_notifications.py @@ -100,9 +100,11 @@ class EditMessageSideEffectsTest(ZulipTestCase): content=content, ) - with mock.patch("zerver.tornado.event_queue.maybe_enqueue_notifications") as m: - with self.captureOnCommitCallbacks(execute=True): - result = self.client_patch(url, request) + with ( + mock.patch("zerver.tornado.event_queue.maybe_enqueue_notifications") as m, + self.captureOnCommitCallbacks(execute=True), + ): + result = self.client_patch(url, request) cordelia = self.example_user("cordelia") cordelia_calls = [ diff --git a/zerver/tests/test_message_fetch.py b/zerver/tests/test_message_fetch.py index d66742554c..bd80a5dd5b 100644 --- a/zerver/tests/test_message_fetch.py +++ b/zerver/tests/test_message_fetch.py @@ -4203,14 +4203,13 @@ class GetOldMessagesTest(ZulipTestCase): request = HostRequestMock(query_params, user_profile) first_visible_message_id = first_unread_message_id + 2 - with first_visible_id_as(first_visible_message_id): - with queries_captured() as all_queries: - get_messages_backend( - request, - user_profile, - num_before=10, - num_after=10, - ) + with first_visible_id_as(first_visible_message_id), queries_captured() as all_queries: + get_messages_backend( + request, + user_profile, + num_before=10, + num_after=10, + ) queries = [q for q in all_queries if "/* get_messages */" in q.sql] self.assert_length(queries, 1) diff --git a/zerver/tests/test_message_send.py b/zerver/tests/test_message_send.py index ad60af2d66..b34fe6e44d 100644 --- a/zerver/tests/test_message_send.py +++ b/zerver/tests/test_message_send.py @@ -2118,9 +2118,11 @@ class StreamMessagesTest(ZulipTestCase): self.subscribe(cordelia, "test_stream") do_set_realm_property(cordelia.realm, "wildcard_mention_policy", 10, acting_user=None) content = "@**all** test wildcard mention" - with mock.patch("zerver.lib.message.num_subscribers_for_stream_id", return_value=16): - with self.assertRaisesRegex(AssertionError, "Invalid wildcard mention policy"): - self.send_stream_message(cordelia, "test_stream", content) + with ( + mock.patch("zerver.lib.message.num_subscribers_for_stream_id", return_value=16), + self.assertRaisesRegex(AssertionError, "Invalid wildcard mention policy"), + ): + self.send_stream_message(cordelia, "test_stream", content) def test_user_group_mention_restrictions(self) -> None: iago = self.example_user("iago") diff --git a/zerver/tests/test_outgoing_webhook_system.py b/zerver/tests/test_outgoing_webhook_system.py index 5b0eef331d..8cd98e2d02 100644 --- a/zerver/tests/test_outgoing_webhook_system.py +++ b/zerver/tests/test_outgoing_webhook_system.py @@ -630,13 +630,15 @@ class TestOutgoingWebhookMessaging(ZulipTestCase): "https://bot.example.com/", body=requests.exceptions.Timeout("Time is up!"), ) - with mock.patch( - "zerver.lib.outgoing_webhook.fail_with_message", side_effect=wrapped - ) as fail: - with self.assertLogs(level="INFO") as logs: - self.send_stream_message( - bot_owner, "Denmark", content=f"@**{bot.full_name}** foo", topic_name="bar" - ) + with ( + mock.patch( + "zerver.lib.outgoing_webhook.fail_with_message", side_effect=wrapped + ) as fail, + self.assertLogs(level="INFO") as logs, + ): + self.send_stream_message( + bot_owner, "Denmark", content=f"@**{bot.full_name}** foo", topic_name="bar" + ) self.assert_length(logs.output, 5) fail.assert_called_once() diff --git a/zerver/tests/test_push_notifications.py b/zerver/tests/test_push_notifications.py index 1548d6db50..5af28f4e78 100644 --- a/zerver/tests/test_push_notifications.py +++ b/zerver/tests/test_push_notifications.py @@ -1103,31 +1103,33 @@ class PushBouncerNotificationTest(BouncerTestCase): not_configured_warn_log, ) - with mock.patch( - "zerver.lib.push_notifications.uses_notification_bouncer", return_value=True + with ( + mock.patch( + "zerver.lib.push_notifications.uses_notification_bouncer", return_value=True + ), + mock.patch("zerver.lib.remote_server.send_to_push_bouncer") as m, ): - with mock.patch("zerver.lib.remote_server.send_to_push_bouncer") as m: - post_response = { - "realms": {realm.uuid: {"can_push": True, "expected_end_timestamp": None}} - } - get_response = { - "last_realm_count_id": 0, - "last_installation_count_id": 0, - "last_realmauditlog_id": 0, - } + post_response = { + "realms": {realm.uuid: {"can_push": True, "expected_end_timestamp": None}} + } + get_response = { + "last_realm_count_id": 0, + "last_installation_count_id": 0, + "last_realmauditlog_id": 0, + } - def mock_send_to_push_bouncer_response(method: str, *args: Any) -> dict[str, Any]: - if method == "POST": - return post_response - return get_response + def mock_send_to_push_bouncer_response(method: str, *args: Any) -> dict[str, Any]: + if method == "POST": + return post_response + return get_response - m.side_effect = mock_send_to_push_bouncer_response + m.side_effect = mock_send_to_push_bouncer_response - initialize_push_notifications() + initialize_push_notifications() - realm = get_realm("zulip") - self.assertTrue(realm.push_notifications_enabled) - self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) + realm = get_realm("zulip") + self.assertTrue(realm.push_notifications_enabled) + self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) @override_settings(PUSH_NOTIFICATION_BOUNCER_URL="https://push.zulip.org.example.com") @responses.activate @@ -2340,84 +2342,90 @@ class AnalyticsBouncerTest(BouncerTestCase): def test_realm_properties_after_send_analytics(self) -> None: self.add_mock_response() - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", return_value=None - ) as m: - with mock.patch( + with ( + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", return_value=None + ) as m, + mock.patch( "corporate.lib.stripe.RemoteServerBillingSession.current_count_for_billed_licenses", return_value=10, - ): - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, True) - self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) + ), + ): + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, True) + self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) - with mock.patch( - "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=None - ) as m: - with mock.patch( + with ( + mock.patch( + "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=None + ) as m, + mock.patch( "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", return_value=11, - ): - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, False) - self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) + ), + ): + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, False) + self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) dummy_customer = mock.MagicMock() - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", - return_value=dummy_customer, + with ( + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", + return_value=dummy_customer, + ), + mock.patch("corporate.lib.stripe.get_current_plan_by_customer", return_value=None) as m, ): - with mock.patch( - "corporate.lib.stripe.get_current_plan_by_customer", return_value=None - ) as m: - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, True) - self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, True) + self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) dummy_customer = mock.MagicMock() - with mock.patch( - "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer + with ( + mock.patch( + "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer + ), + mock.patch("corporate.lib.stripe.get_current_plan_by_customer", return_value=None) as m, + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", + return_value=11, + ), ): - with mock.patch( - "corporate.lib.stripe.get_current_plan_by_customer", return_value=None - ) as m: - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", - return_value=11, - ): - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, False) - self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, False) + self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) RemoteRealm.objects.filter(server=self.server).update( plan_type=RemoteRealm.PLAN_TYPE_COMMUNITY ) - with mock.patch( - "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer + with ( + mock.patch( + "zilencer.views.RemoteRealmBillingSession.get_customer", return_value=dummy_customer + ), + mock.patch("corporate.lib.stripe.get_current_plan_by_customer", return_value=None), + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses" + ) as m, ): - with mock.patch("corporate.lib.stripe.get_current_plan_by_customer", return_value=None): - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses" - ) as m: - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_not_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, True) - self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_not_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, True) + self.assertEqual(realm.push_notifications_enabled_end_timestamp, None) # Reset the plan type to test remaining cases. RemoteRealm.objects.filter(server=self.server).update( @@ -2427,118 +2435,122 @@ class AnalyticsBouncerTest(BouncerTestCase): dummy_customer_plan = mock.MagicMock() dummy_customer_plan.status = CustomerPlan.DOWNGRADE_AT_END_OF_CYCLE dummy_date = datetime(year=2023, month=12, day=3, tzinfo=timezone.utc) - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", - return_value=dummy_customer, - ): - with mock.patch( + with ( + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", + return_value=dummy_customer, + ), + mock.patch( "corporate.lib.stripe.get_current_plan_by_customer", return_value=dummy_customer_plan, - ): - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", - return_value=11, - ): - with ( - mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle", - return_value=dummy_date, - ) as m, - self.assertLogs("zulip.analytics", level="INFO") as info_log, - ): - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, True) - self.assertEqual( - realm.push_notifications_enabled_end_timestamp, - dummy_date, - ) - self.assertIn( - "INFO:zulip.analytics:Reported 0 records", - info_log.output[0], - ) + ), + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", + return_value=11, + ), + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle", + return_value=dummy_date, + ) as m, + self.assertLogs("zulip.analytics", level="INFO") as info_log, + ): + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, True) + self.assertEqual( + realm.push_notifications_enabled_end_timestamp, + dummy_date, + ) + self.assertIn( + "INFO:zulip.analytics:Reported 0 records", + info_log.output[0], + ) - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", - return_value=dummy_customer, - ): - with mock.patch( + with ( + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", + return_value=dummy_customer, + ), + mock.patch( "corporate.lib.stripe.get_current_plan_by_customer", return_value=dummy_customer_plan, - ): - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", - side_effect=MissingDataError, - ): - with ( - mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle", - return_value=dummy_date, - ) as m, - self.assertLogs("zulip.analytics", level="INFO") as info_log, - ): - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, True) - self.assertEqual( - realm.push_notifications_enabled_end_timestamp, - dummy_date, - ) - self.assertIn( - "INFO:zulip.analytics:Reported 0 records", - info_log.output[0], - ) + ), + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", + side_effect=MissingDataError, + ), + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_next_billing_cycle", + return_value=dummy_date, + ) as m, + self.assertLogs("zulip.analytics", level="INFO") as info_log, + ): + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, True) + self.assertEqual( + realm.push_notifications_enabled_end_timestamp, + dummy_date, + ) + self.assertIn( + "INFO:zulip.analytics:Reported 0 records", + info_log.output[0], + ) - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", - return_value=dummy_customer, - ): - with mock.patch( + with ( + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", + return_value=dummy_customer, + ), + mock.patch( "corporate.lib.stripe.get_current_plan_by_customer", return_value=dummy_customer_plan, - ): - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", - return_value=10, - ): - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, True) - self.assertEqual( - realm.push_notifications_enabled_end_timestamp, - None, - ) + ), + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.current_count_for_billed_licenses", + return_value=10, + ), + ): + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, True) + self.assertEqual( + realm.push_notifications_enabled_end_timestamp, + None, + ) dummy_customer_plan = mock.MagicMock() dummy_customer_plan.status = CustomerPlan.ACTIVE - with mock.patch( - "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", - return_value=dummy_customer, - ): - with mock.patch( + with ( + mock.patch( + "corporate.lib.stripe.RemoteRealmBillingSession.get_customer", + return_value=dummy_customer, + ), + mock.patch( "corporate.lib.stripe.get_current_plan_by_customer", return_value=dummy_customer_plan, - ): - with self.assertLogs("zulip.analytics", level="INFO") as info_log: - send_server_data_to_push_bouncer(consider_usage_statistics=False) - m.assert_called() - realms = Realm.objects.all() - for realm in realms: - self.assertEqual(realm.push_notifications_enabled, True) - self.assertEqual( - realm.push_notifications_enabled_end_timestamp, - None, - ) - self.assertIn( - "INFO:zulip.analytics:Reported 0 records", - info_log.output[0], - ) + ), + self.assertLogs("zulip.analytics", level="INFO") as info_log, + ): + send_server_data_to_push_bouncer(consider_usage_statistics=False) + m.assert_called() + realms = Realm.objects.all() + for realm in realms: + self.assertEqual(realm.push_notifications_enabled, True) + self.assertEqual( + realm.push_notifications_enabled_end_timestamp, + None, + ) + self.assertIn( + "INFO:zulip.analytics:Reported 0 records", + info_log.output[0], + ) # Remote realm is on an inactive plan. Remote server on active plan. # ACTIVE plan takes precedence. diff --git a/zerver/tests/test_queue_worker.py b/zerver/tests/test_queue_worker.py index 2cd9b03a73..a5cde9ab2e 100644 --- a/zerver/tests/test_queue_worker.py +++ b/zerver/tests/test_queue_worker.py @@ -377,13 +377,16 @@ class WorkerTest(ZulipTestCase): # If called after `expected_scheduled_timestamp`, it should process all emails. one_minute_overdue = expected_scheduled_timestamp + timedelta(seconds=60) - with time_machine.travel(one_minute_overdue, tick=True): - with send_mock as sm, self.assertLogs(level="INFO") as info_logs: - has_timeout = advance() - self.assertTrue(has_timeout) - self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0) - has_timeout = advance() - self.assertFalse(has_timeout) + with ( + time_machine.travel(one_minute_overdue, tick=True), + send_mock as sm, + self.assertLogs(level="INFO") as info_logs, + ): + has_timeout = advance() + self.assertTrue(has_timeout) + self.assertEqual(ScheduledMessageNotificationEmail.objects.count(), 0) + has_timeout = advance() + self.assertFalse(has_timeout) self.assertEqual( [ @@ -643,20 +646,22 @@ class WorkerTest(ZulipTestCase): self.assertEqual(mock_mirror_email.call_count, 4) # If RateLimiterLockingError is thrown, we rate-limit the new message: - with patch( - "zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit", - side_effect=RateLimiterLockingError, + with ( + patch( + "zerver.lib.rate_limiter.RedisRateLimiterBackend.incr_ratelimit", + side_effect=RateLimiterLockingError, + ), + self.assertLogs("zerver.lib.rate_limiter", "WARNING") as mock_warn, ): - with self.assertLogs("zerver.lib.rate_limiter", "WARNING") as mock_warn: - fake_client.enqueue("email_mirror", data[0]) - worker.start() - self.assertEqual(mock_mirror_email.call_count, 4) - self.assertEqual( - mock_warn.output, - [ - "WARNING:zerver.lib.rate_limiter:Deadlock trying to incr_ratelimit for RateLimitedRealmMirror:zulip" - ], - ) + fake_client.enqueue("email_mirror", data[0]) + worker.start() + self.assertEqual(mock_mirror_email.call_count, 4) + self.assertEqual( + mock_warn.output, + [ + "WARNING:zerver.lib.rate_limiter:Deadlock trying to incr_ratelimit for RateLimitedRealmMirror:zulip" + ], + ) self.assertEqual( warn_logs.output, [ diff --git a/zerver/tests/test_reactions.py b/zerver/tests/test_reactions.py index 44677bfbfc..f7afc23796 100644 --- a/zerver/tests/test_reactions.py +++ b/zerver/tests/test_reactions.py @@ -1054,12 +1054,14 @@ class ReactionAPIEventTest(EmojiReactionBase): "emoji_code": "1f354", "reaction_type": "unicode_emoji", } - with self.capture_send_event_calls(expected_num_events=1) as events: - with mock.patch("zerver.tornado.django_api.queue_json_publish") as m: - m.side_effect = AssertionError( - "Events should be sent only after the transaction commits!" - ) - self.api_post(reaction_sender, f"/api/v1/messages/{pm_id}/reactions", reaction_info) + with ( + self.capture_send_event_calls(expected_num_events=1) as events, + mock.patch("zerver.tornado.django_api.queue_json_publish") as m, + ): + m.side_effect = AssertionError( + "Events should be sent only after the transaction commits!" + ) + self.api_post(reaction_sender, f"/api/v1/messages/{pm_id}/reactions", reaction_info) event = events[0]["event"] event_user_ids = set(events[0]["users"]) @@ -1137,9 +1139,11 @@ class ReactionAPIEventTest(EmojiReactionBase): reaction_type="whatever", ) - with self.capture_send_event_calls(expected_num_events=1): - with mock.patch("zerver.tornado.django_api.queue_json_publish") as m: - m.side_effect = AssertionError( - "Events should be sent only after the transaction commits." - ) - notify_reaction_update(hamlet, message, reaction, "stuff") + with ( + self.capture_send_event_calls(expected_num_events=1), + mock.patch("zerver.tornado.django_api.queue_json_publish") as m, + ): + m.side_effect = AssertionError( + "Events should be sent only after the transaction commits." + ) + notify_reaction_update(hamlet, message, reaction, "stuff") diff --git a/zerver/tests/test_realm.py b/zerver/tests/test_realm.py index 3e91409a4c..85698c32bf 100644 --- a/zerver/tests/test_realm.py +++ b/zerver/tests/test_realm.py @@ -95,13 +95,14 @@ class RealmTest(ZulipTestCase): ) def test_realm_creation_on_special_subdomains_disallowed(self) -> None: - with self.settings(SOCIAL_AUTH_SUBDOMAIN="zulipauth"): - with self.assertRaises(AssertionError): - do_create_realm("zulipauth", "Test Realm") + with self.settings(SOCIAL_AUTH_SUBDOMAIN="zulipauth"), self.assertRaises(AssertionError): + do_create_realm("zulipauth", "Test Realm") - with self.settings(SELF_HOSTING_MANAGEMENT_SUBDOMAIN="zulipselfhosting"): - with self.assertRaises(AssertionError): - do_create_realm("zulipselfhosting", "Test Realm") + with ( + self.settings(SELF_HOSTING_MANAGEMENT_SUBDOMAIN="zulipselfhosting"), + self.assertRaises(AssertionError), + ): + do_create_realm("zulipselfhosting", "Test Realm") def test_permission_for_education_non_profit_organization(self) -> None: realm = do_create_realm( diff --git a/zerver/tests/test_realm_emoji.py b/zerver/tests/test_realm_emoji.py index 7f757c65d4..6f5a4dae63 100644 --- a/zerver/tests/test_realm_emoji.py +++ b/zerver/tests/test_realm_emoji.py @@ -315,9 +315,8 @@ class RealmEmojiTest(ZulipTestCase): def test_emoji_upload_file_size_error(self) -> None: self.login("iago") - with get_test_image_file("img.png") as fp: - with self.settings(MAX_EMOJI_FILE_SIZE_MIB=0): - result = self.client_post("/json/realm/emoji/my_emoji", {"file": fp}) + with get_test_image_file("img.png") as fp, self.settings(MAX_EMOJI_FILE_SIZE_MIB=0): + result = self.client_post("/json/realm/emoji/my_emoji", {"file": fp}) self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB") def test_emoji_upload_file_format_error(self) -> None: @@ -355,12 +354,14 @@ class RealmEmojiTest(ZulipTestCase): def test_failed_file_upload(self) -> None: self.login("iago") - with mock.patch( - "zerver.lib.upload.local.write_local_file", side_effect=BadImageError(msg="Broken") + with ( + mock.patch( + "zerver.lib.upload.local.write_local_file", side_effect=BadImageError(msg="Broken") + ), + get_test_image_file("img.png") as fp1, ): - with get_test_image_file("img.png") as fp1: - emoji_data = {"f1": fp1} - result = self.client_post("/json/realm/emoji/my_emoji", info=emoji_data) + emoji_data = {"f1": fp1} + result = self.client_post("/json/realm/emoji/my_emoji", info=emoji_data) self.assert_json_error(result, "Broken") def test_check_admin_realm_emoji(self) -> None: diff --git a/zerver/tests/test_realm_export.py b/zerver/tests/test_realm_export.py index eb4889d3c2..f7ac043e32 100644 --- a/zerver/tests/test_realm_export.py +++ b/zerver/tests/test_realm_export.py @@ -49,9 +49,9 @@ class RealmExportTest(ZulipTestCase): self.settings(LOCAL_UPLOADS_DIR=None), stdout_suppressed(), self.assertLogs(level="INFO") as info_logs, + self.captureOnCommitCallbacks(execute=True), ): - with self.captureOnCommitCallbacks(execute=True): - result = self.client_post("/json/export/realm") + result = self.client_post("/json/export/realm") self.assertTrue("INFO:root:Completed data export for zulip in " in info_logs.output[0]) self.assert_json_success(result) self.assertFalse(os.path.exists(tarball_path)) @@ -150,9 +150,12 @@ class RealmExportTest(ZulipTestCase): with patch( "zerver.lib.export.do_export_realm", side_effect=fake_export_realm ) as mock_export: - with stdout_suppressed(), self.assertLogs(level="INFO") as info_logs: - with self.captureOnCommitCallbacks(execute=True): - result = self.client_post("/json/export/realm") + with ( + stdout_suppressed(), + self.assertLogs(level="INFO") as info_logs, + self.captureOnCommitCallbacks(execute=True), + ): + result = self.client_post("/json/export/realm") self.assertTrue("INFO:root:Completed data export for zulip in " in info_logs.output[0]) mock_export.assert_called_once() data = self.assert_json_success(result) @@ -208,12 +211,15 @@ class RealmExportTest(ZulipTestCase): admin = self.example_user("iago") self.login_user(admin) - with patch( - "zerver.lib.export.do_export_realm", side_effect=Exception("failure") - ) as mock_export: - with stdout_suppressed(), self.assertLogs(level="INFO") as info_logs: - with self.captureOnCommitCallbacks(execute=True): - result = self.client_post("/json/export/realm") + with ( + patch( + "zerver.lib.export.do_export_realm", side_effect=Exception("failure") + ) as mock_export, + stdout_suppressed(), + self.assertLogs(level="INFO") as info_logs, + self.captureOnCommitCallbacks(execute=True), + ): + result = self.client_post("/json/export/realm") self.assertTrue( info_logs.output[0].startswith("ERROR:root:Data export for zulip failed after ") ) @@ -240,18 +246,20 @@ class RealmExportTest(ZulipTestCase): # If the queue worker sees the same export-id again, it aborts # instead of retrying - with patch("zerver.lib.export.do_export_realm") as mock_export: - with self.assertLogs(level="INFO") as info_logs: - queue_json_publish( - "deferred_work", - { - "type": "realm_export", - "time": 42, - "realm_id": admin.realm.id, - "user_profile_id": admin.id, - "id": export_id, - }, - ) + with ( + patch("zerver.lib.export.do_export_realm") as mock_export, + self.assertLogs(level="INFO") as info_logs, + ): + queue_json_publish( + "deferred_work", + { + "type": "realm_export", + "time": 42, + "realm_id": admin.realm.id, + "user_profile_id": admin.id, + "id": export_id, + }, + ) mock_export.assert_not_called() self.assertEqual( info_logs.output, diff --git a/zerver/tests/test_send_email.py b/zerver/tests/test_send_email.py index d274a3e70a..dc0ac5b6e4 100644 --- a/zerver/tests/test_send_email.py +++ b/zerver/tests/test_send_email.py @@ -132,15 +132,17 @@ class TestSendEmail(ZulipTestCase): for message, side_effect in errors.items(): with mock.patch.object(EmailBackend, "send_messages", side_effect=side_effect): - with self.assertLogs(logger=logger) as info_log: - with self.assertRaises(EmailNotDeliveredError): - send_email( - "zerver/emails/password_reset", - to_emails=[hamlet.email], - from_name=from_name, - from_address=FromAddress.NOREPLY, - language="en", - ) + with ( + self.assertLogs(logger=logger) as info_log, + self.assertRaises(EmailNotDeliveredError), + ): + send_email( + "zerver/emails/password_reset", + to_emails=[hamlet.email], + from_name=from_name, + from_address=FromAddress.NOREPLY, + language="en", + ) self.assert_length(info_log.records, 2) self.assertEqual( info_log.output[0], @@ -151,15 +153,17 @@ class TestSendEmail(ZulipTestCase): def test_send_email_config_error_logging(self) -> None: hamlet = self.example_user("hamlet") - with self.settings(EMAIL_HOST_USER="test", EMAIL_HOST_PASSWORD=None): - with self.assertLogs(logger=logger, level="ERROR") as error_log: - send_email( - "zerver/emails/password_reset", - to_emails=[hamlet.email], - from_name="From Name", - from_address=FromAddress.NOREPLY, - language="en", - ) + with ( + self.settings(EMAIL_HOST_USER="test", EMAIL_HOST_PASSWORD=None), + self.assertLogs(logger=logger, level="ERROR") as error_log, + ): + send_email( + "zerver/emails/password_reset", + to_emails=[hamlet.email], + from_name="From Name", + from_address=FromAddress.NOREPLY, + language="en", + ) self.assertEqual( error_log.output, diff --git a/zerver/tests/test_signup.py b/zerver/tests/test_signup.py index bd0e9cedf1..eb37841701 100644 --- a/zerver/tests/test_signup.py +++ b/zerver/tests/test_signup.py @@ -1050,9 +1050,12 @@ class LoginTest(ZulipTestCase): # seem to be any O(N) behavior. Some of the cache hits are related # to sending messages, such as getting the welcome bot, looking up # the alert words for a realm, etc. - with self.assert_database_query_count(94), self.assert_memcached_count(14): - with self.captureOnCommitCallbacks(execute=True): - self.register(self.nonreg_email("test"), "test") + with ( + self.assert_database_query_count(94), + self.assert_memcached_count(14), + self.captureOnCommitCallbacks(execute=True), + ): + self.register(self.nonreg_email("test"), "test") user_profile = self.nonreg_user("test") self.assert_logged_in_user_id(user_profile.id) @@ -2946,21 +2949,23 @@ class UserSignUpTest(ZulipTestCase): return_data = kwargs.get("return_data", {}) return_data["invalid_subdomain"] = True - with patch("zerver.views.registration.authenticate", side_effect=invalid_subdomain): - with self.assertLogs(level="ERROR") as m: - result = self.client_post( - "/accounts/register/", - { - "password": password, - "full_name": "New User", - "key": find_key_by_email(email), - "terms": True, - }, - ) - self.assertEqual( - m.output, - ["ERROR:root:Subdomain mismatch in registration zulip: newuser@zulip.com"], - ) + with ( + patch("zerver.views.registration.authenticate", side_effect=invalid_subdomain), + self.assertLogs(level="ERROR") as m, + ): + result = self.client_post( + "/accounts/register/", + { + "password": password, + "full_name": "New User", + "key": find_key_by_email(email), + "terms": True, + }, + ) + self.assertEqual( + m.output, + ["ERROR:root:Subdomain mismatch in registration zulip: newuser@zulip.com"], + ) self.assertEqual(result.status_code, 302) def test_signup_using_invalid_subdomain_preserves_state_of_form(self) -> None: diff --git a/zerver/tests/test_soft_deactivation.py b/zerver/tests/test_soft_deactivation.py index 62f764e747..30eb79ffca 100644 --- a/zerver/tests/test_soft_deactivation.py +++ b/zerver/tests/test_soft_deactivation.py @@ -273,9 +273,11 @@ class UserSoftDeactivationTests(ZulipTestCase): ).count() self.assertEqual(0, received_count) - with self.settings(AUTO_CATCH_UP_SOFT_DEACTIVATED_USERS=False): - with self.assertLogs(logger_string, level="INFO") as m: - users_deactivated = do_auto_soft_deactivate_users(-1, realm) + with ( + self.settings(AUTO_CATCH_UP_SOFT_DEACTIVATED_USERS=False), + self.assertLogs(logger_string, level="INFO") as m, + ): + users_deactivated = do_auto_soft_deactivate_users(-1, realm) self.assertEqual( m.output, [ diff --git a/zerver/tests/test_submessage.py b/zerver/tests/test_submessage.py index 274eea29c5..dbdeee263d 100644 --- a/zerver/tests/test_submessage.py +++ b/zerver/tests/test_submessage.py @@ -194,12 +194,14 @@ class TestBasics(ZulipTestCase): hamlet = self.example_user("hamlet") message_id = self.send_stream_message(hamlet, "Denmark") - with self.capture_send_event_calls(expected_num_events=1): - with mock.patch("zerver.tornado.django_api.queue_json_publish") as m: - m.side_effect = AssertionError( - "Events should be sent only after the transaction commits." - ) - do_add_submessage(hamlet.realm, hamlet.id, message_id, "whatever", "whatever") + with ( + self.capture_send_event_calls(expected_num_events=1), + mock.patch("zerver.tornado.django_api.queue_json_publish") as m, + ): + m.side_effect = AssertionError( + "Events should be sent only after the transaction commits." + ) + do_add_submessage(hamlet.realm, hamlet.id, message_id, "whatever", "whatever") def test_fetch_message_containing_submessages(self) -> None: cordelia = self.example_user("cordelia") diff --git a/zerver/tests/test_subs.py b/zerver/tests/test_subs.py index d0f75a3723..faaed28a1a 100644 --- a/zerver/tests/test_subs.py +++ b/zerver/tests/test_subs.py @@ -2607,16 +2607,18 @@ class StreamAdminTest(ZulipTestCase): for user in other_sub_users: self.subscribe(user, stream_name) - with self.assert_database_query_count(query_count): - with cache_tries_captured() as cache_tries: - with self.captureOnCommitCallbacks(execute=True): - result = self.client_delete( - "/json/users/me/subscriptions", - { - "subscriptions": orjson.dumps([stream_name]).decode(), - "principals": orjson.dumps(principals).decode(), - }, - ) + with ( + self.assert_database_query_count(query_count), + cache_tries_captured() as cache_tries, + self.captureOnCommitCallbacks(execute=True), + ): + result = self.client_delete( + "/json/users/me/subscriptions", + { + "subscriptions": orjson.dumps([stream_name]).decode(), + "principals": orjson.dumps(principals).decode(), + }, + ) if cache_count is not None: self.assert_length(cache_tries, cache_count) @@ -4744,13 +4746,15 @@ class SubscriptionAPITest(ZulipTestCase): user2 = self.example_user("iago") realm = get_realm("zulip") streams_to_sub = ["multi_user_stream"] - with self.capture_send_event_calls(expected_num_events=5) as events: - with self.assert_database_query_count(38): - self.common_subscribe_to_streams( - self.test_user, - streams_to_sub, - dict(principals=orjson.dumps([user1.id, user2.id]).decode()), - ) + with ( + self.capture_send_event_calls(expected_num_events=5) as events, + self.assert_database_query_count(38), + ): + self.common_subscribe_to_streams( + self.test_user, + streams_to_sub, + dict(principals=orjson.dumps([user1.id, user2.id]).decode()), + ) for ev in [x for x in events if x["event"]["type"] not in ("message", "stream")]: if ev["event"]["op"] == "add": @@ -4768,13 +4772,15 @@ class SubscriptionAPITest(ZulipTestCase): self.assertEqual(num_subscribers_for_stream_id(stream.id), 2) # Now add ourselves - with self.capture_send_event_calls(expected_num_events=2) as events: - with self.assert_database_query_count(14): - self.common_subscribe_to_streams( - self.test_user, - streams_to_sub, - dict(principals=orjson.dumps([self.test_user.id]).decode()), - ) + with ( + self.capture_send_event_calls(expected_num_events=2) as events, + self.assert_database_query_count(14), + ): + self.common_subscribe_to_streams( + self.test_user, + streams_to_sub, + dict(principals=orjson.dumps([self.test_user.id]).decode()), + ) add_event, add_peer_event = events self.assertEqual(add_event["event"]["type"], "subscription") @@ -5061,15 +5067,17 @@ class SubscriptionAPITest(ZulipTestCase): # Sends 3 peer-remove events, 2 unsubscribe events # and 2 stream delete events for private streams. - with self.assert_database_query_count(16): - with self.assert_memcached_count(3): - with self.capture_send_event_calls(expected_num_events=7) as events: - bulk_remove_subscriptions( - realm, - [user1, user2], - [stream1, stream2, stream3, private], - acting_user=None, - ) + with ( + self.assert_database_query_count(16), + self.assert_memcached_count(3), + self.capture_send_event_calls(expected_num_events=7) as events, + ): + bulk_remove_subscriptions( + realm, + [user1, user2], + [stream1, stream2, stream3, private], + acting_user=None, + ) peer_events = [e for e in events if e["event"].get("op") == "peer_remove"] stream_delete_events = [ @@ -5214,14 +5222,16 @@ class SubscriptionAPITest(ZulipTestCase): # The only known O(N) behavior here is that we call # principal_to_user_profile for each of our users, but it # should be cached. - with self.assert_database_query_count(21): - with self.assert_memcached_count(3): - with mock.patch("zerver.views.streams.send_messages_for_new_subscribers"): - self.common_subscribe_to_streams( - desdemona, - streams, - dict(principals=orjson.dumps(test_user_ids).decode()), - ) + with ( + self.assert_database_query_count(21), + self.assert_memcached_count(3), + mock.patch("zerver.views.streams.send_messages_for_new_subscribers"), + ): + self.common_subscribe_to_streams( + desdemona, + streams, + dict(principals=orjson.dumps(test_user_ids).decode()), + ) def test_subscriptions_add_for_principal(self) -> None: """ diff --git a/zerver/tests/test_typing.py b/zerver/tests/test_typing.py index 4526653939..b7cae1ecde 100644 --- a/zerver/tests/test_typing.py +++ b/zerver/tests/test_typing.py @@ -176,9 +176,11 @@ class TypingHappyPathTestDirectMessages(ZulipTestCase): op="start", ) - with self.assert_database_query_count(4): - with self.capture_send_event_calls(expected_num_events=1) as events: - result = self.api_post(sender, "/api/v1/typing", params) + with ( + self.assert_database_query_count(4), + self.capture_send_event_calls(expected_num_events=1) as events, + ): + result = self.api_post(sender, "/api/v1/typing", params) self.assert_json_success(result) self.assert_length(events, 1) @@ -212,9 +214,11 @@ class TypingHappyPathTestDirectMessages(ZulipTestCase): op="start", ) - with self.assert_database_query_count(5): - with self.capture_send_event_calls(expected_num_events=1) as events: - result = self.api_post(sender, "/api/v1/typing", params) + with ( + self.assert_database_query_count(5), + self.capture_send_event_calls(expected_num_events=1) as events, + ): + result = self.api_post(sender, "/api/v1/typing", params) self.assert_json_success(result) self.assert_length(events, 1) @@ -406,9 +410,11 @@ class TypingHappyPathTestStreams(ZulipTestCase): topic=topic_name, ) - with self.assert_database_query_count(6): - with self.capture_send_event_calls(expected_num_events=1) as events: - result = self.api_post(sender, "/api/v1/typing", params) + with ( + self.assert_database_query_count(6), + self.capture_send_event_calls(expected_num_events=1) as events, + ): + result = self.api_post(sender, "/api/v1/typing", params) self.assert_json_success(result) self.assert_length(events, 1) @@ -437,9 +443,11 @@ class TypingHappyPathTestStreams(ZulipTestCase): topic=topic_name, ) - with self.assert_database_query_count(6): - with self.capture_send_event_calls(expected_num_events=1) as events: - result = self.api_post(sender, "/api/v1/typing", params) + with ( + self.assert_database_query_count(6), + self.capture_send_event_calls(expected_num_events=1) as events, + ): + result = self.api_post(sender, "/api/v1/typing", params) self.assert_json_success(result) self.assert_length(events, 1) @@ -470,9 +478,11 @@ class TypingHappyPathTestStreams(ZulipTestCase): topic=topic_name, ) with self.settings(MAX_STREAM_SIZE_FOR_TYPING_NOTIFICATIONS=5): - with self.assert_database_query_count(5): - with self.capture_send_event_calls(expected_num_events=0) as events: - result = self.api_post(sender, "/api/v1/typing", params) + with ( + self.assert_database_query_count(5), + self.capture_send_event_calls(expected_num_events=0) as events, + ): + result = self.api_post(sender, "/api/v1/typing", params) self.assert_json_success(result) self.assert_length(events, 0) @@ -501,9 +511,11 @@ class TypingHappyPathTestStreams(ZulipTestCase): topic=topic_name, ) - with self.assert_database_query_count(6): - with self.capture_send_event_calls(expected_num_events=1) as events: - result = self.api_post(sender, "/api/v1/typing", params) + with ( + self.assert_database_query_count(6), + self.capture_send_event_calls(expected_num_events=1) as events, + ): + result = self.api_post(sender, "/api/v1/typing", params) self.assert_json_success(result) self.assert_length(events, 1) diff --git a/zerver/tests/test_upload.py b/zerver/tests/test_upload.py index 042fa7506c..01202c55f5 100644 --- a/zerver/tests/test_upload.py +++ b/zerver/tests/test_upload.py @@ -1390,9 +1390,11 @@ class AvatarTest(UploadSerializeMixin, ZulipTestCase): def test_avatar_upload_file_size_error(self) -> None: self.login("hamlet") - with get_test_image_file(self.correct_files[0][0]) as fp: - with self.settings(MAX_AVATAR_FILE_SIZE_MIB=0): - result = self.client_post("/json/users/me/avatar", {"file": fp}) + with ( + get_test_image_file(self.correct_files[0][0]) as fp, + self.settings(MAX_AVATAR_FILE_SIZE_MIB=0), + ): + result = self.client_post("/json/users/me/avatar", {"file": fp}) self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB") @@ -1537,9 +1539,11 @@ class RealmIconTest(UploadSerializeMixin, ZulipTestCase): def test_realm_icon_upload_file_size_error(self) -> None: self.login("iago") - with get_test_image_file(self.correct_files[0][0]) as fp: - with self.settings(MAX_ICON_FILE_SIZE_MIB=0): - result = self.client_post("/json/realm/icon", {"file": fp}) + with ( + get_test_image_file(self.correct_files[0][0]) as fp, + self.settings(MAX_ICON_FILE_SIZE_MIB=0), + ): + result = self.client_post("/json/realm/icon", {"file": fp}) self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB") @@ -1743,11 +1747,13 @@ class RealmLogoTest(UploadSerializeMixin, ZulipTestCase): def test_logo_upload_file_size_error(self) -> None: self.login("iago") - with get_test_image_file(self.correct_files[0][0]) as fp: - with self.settings(MAX_LOGO_FILE_SIZE_MIB=0): - result = self.client_post( - "/json/realm/logo", {"file": fp, "night": orjson.dumps(self.night).decode()} - ) + with ( + get_test_image_file(self.correct_files[0][0]) as fp, + self.settings(MAX_LOGO_FILE_SIZE_MIB=0), + ): + result = self.client_post( + "/json/realm/logo", {"file": fp, "night": orjson.dumps(self.night).decode()} + ) self.assert_json_error(result, "Uploaded file is larger than the allowed limit of 0 MiB") @@ -1766,53 +1772,63 @@ class EmojiTest(UploadSerializeMixin, ZulipTestCase): def test_non_image(self) -> None: """Non-image is not resized""" self.login("iago") - with get_test_image_file("text.txt") as f: - with patch("zerver.lib.upload.resize_emoji", return_value=(b"a", None)) as resize_mock: - result = self.client_post("/json/realm/emoji/new", {"f1": f}) - self.assert_json_error(result, "Invalid image format") - resize_mock.assert_not_called() + with ( + get_test_image_file("text.txt") as f, + patch("zerver.lib.upload.resize_emoji", return_value=(b"a", None)) as resize_mock, + ): + result = self.client_post("/json/realm/emoji/new", {"f1": f}) + self.assert_json_error(result, "Invalid image format") + resize_mock.assert_not_called() def test_upsupported_format(self) -> None: """Invalid format is not resized""" self.login("iago") - with get_test_image_file("img.bmp") as f: - with patch("zerver.lib.upload.resize_emoji", return_value=(b"a", None)) as resize_mock: - result = self.client_post("/json/realm/emoji/new", {"f1": f}) - self.assert_json_error(result, "Invalid image format") - resize_mock.assert_not_called() + with ( + get_test_image_file("img.bmp") as f, + patch("zerver.lib.upload.resize_emoji", return_value=(b"a", None)) as resize_mock, + ): + result = self.client_post("/json/realm/emoji/new", {"f1": f}) + self.assert_json_error(result, "Invalid image format") + resize_mock.assert_not_called() def test_upload_too_big_after_resize(self) -> None: """Non-animated image is too big after resizing""" self.login("iago") - with get_test_image_file("img.png") as f: - with patch( + with ( + get_test_image_file("img.png") as f, + patch( "zerver.lib.upload.resize_emoji", return_value=(b"a" * (200 * 1024), None) - ) as resize_mock: - result = self.client_post("/json/realm/emoji/new", {"f1": f}) - self.assert_json_error(result, "Image size exceeds limit") - resize_mock.assert_called_once() + ) as resize_mock, + ): + result = self.client_post("/json/realm/emoji/new", {"f1": f}) + self.assert_json_error(result, "Image size exceeds limit") + resize_mock.assert_called_once() def test_upload_big_after_animated_resize(self) -> None: """A big animated image is fine as long as the still is small""" self.login("iago") - with get_test_image_file("animated_img.gif") as f: - with patch( + with ( + get_test_image_file("animated_img.gif") as f, + patch( "zerver.lib.upload.resize_emoji", return_value=(b"a" * (200 * 1024), b"aaa") - ) as resize_mock: - result = self.client_post("/json/realm/emoji/new", {"f1": f}) - self.assert_json_success(result) - resize_mock.assert_called_once() + ) as resize_mock, + ): + result = self.client_post("/json/realm/emoji/new", {"f1": f}) + self.assert_json_success(result) + resize_mock.assert_called_once() def test_upload_too_big_after_animated_resize_still(self) -> None: """Still of animated image is too big after resizing""" self.login("iago") - with get_test_image_file("animated_img.gif") as f: - with patch( + with ( + get_test_image_file("animated_img.gif") as f, + patch( "zerver.lib.upload.resize_emoji", return_value=(b"aaa", b"a" * (200 * 1024)) - ) as resize_mock: - result = self.client_post("/json/realm/emoji/new", {"f1": f}) - self.assert_json_error(result, "Image size exceeds limit") - resize_mock.assert_called_once() + ) as resize_mock, + ): + result = self.client_post("/json/realm/emoji/new", {"f1": f}) + self.assert_json_error(result, "Image size exceeds limit") + resize_mock.assert_called_once() class SanitizeNameTests(ZulipTestCase): diff --git a/zerver/tests/test_user_groups.py b/zerver/tests/test_user_groups.py index a603bccbc0..1168541bce 100644 --- a/zerver/tests/test_user_groups.py +++ b/zerver/tests/test_user_groups.py @@ -1156,9 +1156,11 @@ class UserGroupAPITestCase(UserGroupTestCase): munge = lambda obj: orjson.dumps(obj).decode() params = dict(add=munge(new_user_ids)) - with mock.patch("zerver.views.user_groups.notify_for_user_group_subscription_changes"): - with self.assert_database_query_count(11): - result = self.client_post(f"/json/user_groups/{user_group.id}/members", info=params) + with ( + mock.patch("zerver.views.user_groups.notify_for_user_group_subscription_changes"), + self.assert_database_query_count(11), + ): + result = self.client_post(f"/json/user_groups/{user_group.id}/members", info=params) self.assert_json_success(result) with self.assert_database_query_count(1): diff --git a/zerver/tests/test_user_topics.py b/zerver/tests/test_user_topics.py index 6e27da51ab..364662c414 100644 --- a/zerver/tests/test_user_topics.py +++ b/zerver/tests/test_user_topics.py @@ -338,10 +338,12 @@ class MutedTopicsTests(ZulipTestCase): mock_date_muted = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp() - with self.capture_send_event_calls(expected_num_events=2) as events: - with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False): - result = self.api_post(user, url, data) - self.assert_json_success(result) + with ( + self.capture_send_event_calls(expected_num_events=2) as events, + time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False), + ): + result = self.api_post(user, url, data) + self.assert_json_success(result) self.assertTrue( topic_has_visibility_policy( @@ -404,10 +406,12 @@ class MutedTopicsTests(ZulipTestCase): mock_date_mute_removed = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp() - with self.capture_send_event_calls(expected_num_events=2) as events: - with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False): - result = self.api_post(user, url, data) - self.assert_json_success(result) + with ( + self.capture_send_event_calls(expected_num_events=2) as events, + time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False), + ): + result = self.api_post(user, url, data) + self.assert_json_success(result) self.assertFalse( topic_has_visibility_policy( @@ -553,10 +557,12 @@ class UnmutedTopicsTests(ZulipTestCase): mock_date_unmuted = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp() - with self.capture_send_event_calls(expected_num_events=2) as events: - with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False): - result = self.api_post(user, url, data) - self.assert_json_success(result) + with ( + self.capture_send_event_calls(expected_num_events=2) as events, + time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False), + ): + result = self.api_post(user, url, data) + self.assert_json_success(result) self.assertTrue( topic_has_visibility_policy( @@ -619,10 +625,12 @@ class UnmutedTopicsTests(ZulipTestCase): mock_date_unmute_removed = datetime(2020, 1, 1, tzinfo=timezone.utc).timestamp() - with self.capture_send_event_calls(expected_num_events=2) as events: - with time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False): - result = self.api_post(user, url, data) - self.assert_json_success(result) + with ( + self.capture_send_event_calls(expected_num_events=2) as events, + time_machine.travel(datetime(2020, 1, 1, tzinfo=timezone.utc), tick=False), + ): + result = self.api_post(user, url, data) + self.assert_json_success(result) self.assertFalse( topic_has_visibility_policy( diff --git a/zerver/tests/test_users.py b/zerver/tests/test_users.py index 22badd5f02..c828a72f0e 100644 --- a/zerver/tests/test_users.py +++ b/zerver/tests/test_users.py @@ -909,17 +909,19 @@ class QueryCountTest(ZulipTestCase): prereg_user = PreregistrationUser.objects.get(email="fred@zulip.com") - with self.assert_database_query_count(84): - with self.assert_memcached_count(19): - with self.capture_send_event_calls(expected_num_events=10) as events: - fred = do_create_user( - email="fred@zulip.com", - password="password", - realm=realm, - full_name="Fred Flintstone", - prereg_user=prereg_user, - acting_user=None, - ) + with ( + self.assert_database_query_count(84), + self.assert_memcached_count(19), + self.capture_send_event_calls(expected_num_events=10) as events, + ): + fred = do_create_user( + email="fred@zulip.com", + password="password", + realm=realm, + full_name="Fred Flintstone", + prereg_user=prereg_user, + acting_user=None, + ) peer_add_events = [event for event in events if event["event"].get("op") == "peer_add"] @@ -2404,9 +2406,8 @@ class GetProfileTest(ZulipTestCase): """ realm = get_realm("zulip") email = self.example_user("hamlet").email - with self.assert_database_query_count(1): - with simulated_empty_cache() as cache_queries: - user_profile = get_user(email, realm) + with self.assert_database_query_count(1), simulated_empty_cache() as cache_queries: + user_profile = get_user(email, realm) self.assert_length(cache_queries, 1) self.assertEqual(user_profile.email, email) diff --git a/zerver/webhooks/pivotal/tests.py b/zerver/webhooks/pivotal/tests.py index 6dfffd6002..b78d218ffd 100644 --- a/zerver/webhooks/pivotal/tests.py +++ b/zerver/webhooks/pivotal/tests.py @@ -210,11 +210,11 @@ Try again next time def test_bad_payload(self) -> None: bad = ("foo", None, "bar") - with self.assertRaisesRegex(AssertionError, "Unable to handle Pivotal payload"): - with mock.patch( - "zerver.webhooks.pivotal.view.api_pivotal_webhook_v3", return_value=bad - ): - self.check_webhook("accepted", expect_topic="foo") + with ( + self.assertRaisesRegex(AssertionError, "Unable to handle Pivotal payload"), + mock.patch("zerver.webhooks.pivotal.view.api_pivotal_webhook_v3", return_value=bad), + ): + self.check_webhook("accepted", expect_topic="foo") def test_bad_request(self) -> None: request = mock.MagicMock() @@ -226,9 +226,11 @@ Try again next time self.assertEqual(result[0], "#0: ") bad = orjson.loads(self.get_body("bad_kind")) - with self.assertRaisesRegex(UnsupportedWebhookEventTypeError, "'unknown_kind'.* supported"): - with mock.patch("zerver.webhooks.pivotal.view.orjson.loads", return_value=bad): - api_pivotal_webhook_v5(request, hamlet) + with ( + self.assertRaisesRegex(UnsupportedWebhookEventTypeError, "'unknown_kind'.* supported"), + mock.patch("zerver.webhooks.pivotal.view.orjson.loads", return_value=bad), + ): + api_pivotal_webhook_v5(request, hamlet) @override def get_body(self, fixture_name: str) -> str: diff --git a/zerver/worker/base.py b/zerver/worker/base.py index 665fb86c33..5d95a0d917 100644 --- a/zerver/worker/base.py +++ b/zerver/worker/base.py @@ -276,9 +276,8 @@ class QueueProcessingWorker(ABC): fn = os.path.join(settings.QUEUE_ERROR_DIR, fname) line = f"{time.asctime()}\t{orjson.dumps(events).decode()}\n" lock_fn = fn + ".lock" - with lockfile(lock_fn): - with open(fn, "a") as f: - f.write(line) + with lockfile(lock_fn), open(fn, "a") as f: + f.write(line) check_and_send_restart_signal() def setup(self) -> None: