diff --git a/zerver/lib/test_runner.py b/zerver/lib/test_runner.py index c3e567c1bd..af966f4131 100644 --- a/zerver/lib/test_runner.py +++ b/zerver/lib/test_runner.py @@ -5,6 +5,7 @@ from typing import Any, Callable, Iterable, List, Optional, Set, Tuple from django.test import TestCase from django.test.runner import DiscoverRunner from django.test.signals import template_rendered +from unittest import loader # type: ignore # Mypy cannot pick this up. from zerver.lib.cache import bounce_key_prefix_for_testing from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection @@ -19,6 +20,9 @@ import time import traceback import unittest +if False: + from unittest.result import TextTestResult + def slow(slowness_reason): # type: (str) -> Callable[[Callable], Callable] ''' @@ -146,7 +150,28 @@ def run_test(test): test._post_teardown() return failed +class TestSuite(unittest.TestSuite): + def run(self, result, debug=False): + # type: (TextTestResult, Optional[bool]) -> TextTestResult + for test in self: # type: ignore # Mypy cannot recognize this but this is correct. Taken from unittest. + result.startTest(test) + # The attributes __unittest_skip__ and __unittest_skip_why__ are undocumented + if hasattr(test, '__unittest_skip__') and test.__unittest_skip__: # type: ignore + print('Skipping', full_test_name(test), "(%s)" % (test.__unittest_skip_why__,)) # type: ignore + elif run_test(test): + if result.failfast: + break + result.stopTest(test) + + return result + +class TestLoader(loader.TestLoader): + suiteClass = TestSuite + class Runner(DiscoverRunner): + test_suite = TestSuite + test_loader = TestLoader() + def __init__(self, *args, **kwargs): # type: (*Any, **Any) -> None DiscoverRunner.__init__(self, *args, **kwargs) @@ -174,19 +199,6 @@ class Runner(DiscoverRunner): # type: () -> Set[str] return self.shallow_tested_templates - def run_suite(self, suite, **kwargs): - # type: (Iterable[TestCase], **Any) -> bool - failed = False - for test in suite: - # The attributes __unittest_skip__ and __unittest_skip_why__ are undocumented - if hasattr(test, '__unittest_skip__') and test.__unittest_skip__: - print('Skipping', full_test_name(test), "(%s)" % (test.__unittest_skip_why__,)) - elif run_test(test): - failed = True - if self.failfast: - return failed - return failed - def run_tests(self, test_labels, extra_tests=None, full_suite=False, **kwargs): # type: (List[str], Optional[List[TestCase]], bool, **Any) -> bool @@ -207,8 +219,9 @@ class Runner(DiscoverRunner): # run a single test and getting an SA connection causes data from # a Django connection to be rolled back mid-test. get_sqlalchemy_connection() - failed = self.run_suite(suite) + result = self.run_suite(suite) self.teardown_test_environment() + failed = self.suite_result(suite, result) if not failed: write_instrumentation_reports(full_suite=full_suite) return failed