From a688e753dee6eb82b7be2f2333b2bdbecea62f85 Mon Sep 17 00:00:00 2001 From: Anders Kaseorg Date: Wed, 15 Nov 2023 13:25:00 -0800 Subject: [PATCH] test_helpers: Fix logging in cursor_executemany mock. Signed-off-by: Anders Kaseorg --- zerver/lib/db.py | 4 ++-- zerver/lib/test_helpers.py | 40 +++++++++++++++++++++----------------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/zerver/lib/db.py b/zerver/lib/db.py index b5a0eac3a0..df547dba3d 100644 --- a/zerver/lib/db.py +++ b/zerver/lib/db.py @@ -38,8 +38,8 @@ class TimeTrackingCursor(cursor): wrapper_execute(self, super().execute, query, vars) @override - def executemany(self, query: Query, vars: Iterable[Params]) -> None: # nocoverage - wrapper_execute(self, super().executemany, query, vars) + def executemany(self, query: Query, vars_list: Iterable[Params]) -> None: # nocoverage + wrapper_execute(self, super().executemany, query, vars_list) CursorT = TypeVar("CursorT", bound=cursor) diff --git a/zerver/lib/test_helpers.py b/zerver/lib/test_helpers.py index e3a8910d8f..234abea103 100644 --- a/zerver/lib/test_helpers.py +++ b/zerver/lib/test_helpers.py @@ -1,4 +1,5 @@ import collections +import itertools import os import re import sys @@ -45,7 +46,7 @@ from zerver.actions.user_settings import do_change_user_setting from zerver.lib import cache from zerver.lib.avatar import avatar_url from zerver.lib.cache import get_cache_backend -from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor +from zerver.lib.db import Params, Query, TimeTrackingCursor from zerver.lib.integrations import WEBHOOK_INTEGRATIONS from zerver.lib.per_request_cache import flush_per_request_caches from zerver.lib.rate_limiter import RateLimitedIPAddr, rules @@ -150,35 +151,38 @@ def queries_captured( queries: List[CapturedQuery] = [] - def wrapper_execute( - self: TimeTrackingCursor, - action: Callable[[Query, ParamsT], None], - sql: Query, - params: ParamsT, - ) -> None: + def cursor_execute(self: TimeTrackingCursor, sql: Query, vars: Optional[Params] = None) -> None: start = time.time() try: - return action(sql, params) + return super(TimeTrackingCursor, self).execute(sql, vars) finally: stop = time.time() duration = stop - start if include_savepoints or not isinstance(sql, str) or "SAVEPOINT" not in sql: queries.append( CapturedQuery( - sql=self.mogrify(sql, params).decode(), + sql=self.mogrify(sql, vars).decode(), time=f"{duration:.3f}", ) ) - def cursor_execute( - self: TimeTrackingCursor, sql: Query, params: Optional[Params] = None - ) -> None: - return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params) - - def cursor_executemany(self: TimeTrackingCursor, sql: Query, params: Iterable[Params]) -> None: - return wrapper_execute( - self, super(TimeTrackingCursor, self).executemany, sql, params - ) # nocoverage -- doesn't actually get used in tests + def cursor_executemany( + self: TimeTrackingCursor, sql: Query, vars_list: Iterable[Params] + ) -> None: # nocoverage -- doesn't actually get used in tests + vars_list, vars_list1 = itertools.tee(vars_list) + start = time.time() + try: + return super(TimeTrackingCursor, self).executemany(sql, vars_list) + finally: + stop = time.time() + duration = stop - start + queries.extend( + CapturedQuery( + sql=self.mogrify(sql, vars).decode(), + time=f"{duration:.3f}", + ) + for vars in vars_list1 + ) if not keep_cache_warm: cache = get_cache_backend(None)