billing: Create get_discount_for_realm function.

This commit is contained in:
Vishnu Ks 2019-03-06 12:01:56 +00:00 committed by Tim Abbott
parent 9671ed7bab
commit ca1276961d
2 changed files with 14 additions and 1 deletions

View File

@ -450,6 +450,12 @@ def invoice_plans_as_needed(event_time: datetime) -> None:
def attach_discount_to_realm(realm: Realm, discount: Decimal) -> None: def attach_discount_to_realm(realm: Realm, discount: Decimal) -> None:
Customer.objects.update_or_create(realm=realm, defaults={'default_discount': discount}) Customer.objects.update_or_create(realm=realm, defaults={'default_discount': discount})
def get_discount_for_realm(realm: Realm) -> Optional[Decimal]:
customer = Customer.objects.filter(realm=realm).first()
if customer is not None:
return customer.default_discount
return None
def process_downgrade(user: UserProfile) -> None: # nocoverage def process_downgrade(user: UserProfile) -> None: # nocoverage
pass pass

View File

@ -30,7 +30,7 @@ from corporate.lib.stripe import catch_stripe_errors, attach_discount_to_realm,
compute_plan_parameters, update_or_create_stripe_customer, \ compute_plan_parameters, update_or_create_stripe_customer, \
process_initial_upgrade, add_plan_renewal_to_license_ledger_if_needed, \ process_initial_upgrade, add_plan_renewal_to_license_ledger_if_needed, \
update_license_ledger_if_needed, update_license_ledger_for_automanaged_plan, \ update_license_ledger_if_needed, update_license_ledger_for_automanaged_plan, \
invoice_plan, invoice_plans_as_needed invoice_plan, invoice_plans_as_needed, get_discount_for_realm
from corporate.models import Customer, CustomerPlan, LicenseLedger from corporate.models import Customer, CustomerPlan, LicenseLedger
CallableT = TypeVar('CallableT', bound=Callable[..., Any]) CallableT = TypeVar('CallableT', bound=Callable[..., Any])
@ -807,6 +807,13 @@ class StripeTest(StripeTestCase):
[item.amount for item in stripe_invoice.lines]) [item.amount for item in stripe_invoice.lines])
plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25)) plan = CustomerPlan.objects.get(price_per_license=6000, discount=Decimal(25))
def test_get_discount_for_realm(self) -> None:
user = self.example_user('hamlet')
self.assertEqual(get_discount_for_realm(user.realm), None)
attach_discount_to_realm(user.realm, Decimal(85))
self.assertEqual(get_discount_for_realm(user.realm), 85)
@mock_stripe() @mock_stripe()
def test_replace_payment_source(self, *mocks: Mock) -> None: def test_replace_payment_source(self, *mocks: Mock) -> None:
user = self.example_user("hamlet") user = self.example_user("hamlet")