db: Fix types to accept psycopg2.sql.Composable queries, avoid Any.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
Anders Kaseorg 2020-05-03 17:36:15 -07:00 committed by Tim Abbott
parent d0b40cd7a3
commit cebac3f35a
2 changed files with 21 additions and 18 deletions

View File

@ -1,18 +1,21 @@
import time
from psycopg2.extensions import cursor, connection
from psycopg2.sql import Composable
from typing import Callable, Optional, Iterable, Any, Dict, List, Union, TypeVar, \
Mapping
Mapping, Sequence
CursorObj = TypeVar('CursorObj', bound=cursor)
ParamsT = Union[Iterable[Any], Mapping[str, Any]]
Query = Union[str, Composable]
Params = Union[Sequence[object], Mapping[str, object]]
ParamsT = TypeVar('ParamsT')
# Similar to the tracking done in Django's CursorDebugWrapper, but done at the
# psycopg2 cursor level so it works with SQLAlchemy.
def wrapper_execute(self: CursorObj,
action: Callable[[str, Optional[ParamsT]], CursorObj],
sql: str,
params: Optional[ParamsT]=()) -> CursorObj:
action: Callable[[Query, ParamsT], CursorObj],
sql: Query,
params: ParamsT) -> CursorObj:
start = time.time()
try:
return action(sql, params)
@ -26,12 +29,12 @@ def wrapper_execute(self: CursorObj,
class TimeTrackingCursor(cursor):
"""A psycopg2 cursor class that tracks the time spent executing queries."""
def execute(self, query: str,
vars: Optional[ParamsT]=None) -> 'TimeTrackingCursor':
def execute(self, query: Query,
vars: Optional[Params]=None) -> 'TimeTrackingCursor':
return wrapper_execute(self, super().execute, query, vars)
def executemany(self, query: str,
vars: Iterable[Any]) -> 'TimeTrackingCursor':
def executemany(self, query: Query,
vars: Iterable[Params]) -> 'TimeTrackingCursor':
return wrapper_execute(self, super().executemany, query, vars)
class TimeTrackingConnection(connection):

View File

@ -17,7 +17,7 @@ from zerver.lib.actions import do_set_realm_property
from zerver.lib.upload import S3UploadBackend, LocalUploadBackend
from zerver.lib.avatar import avatar_url
from zerver.lib.cache import get_cache_backend
from zerver.lib.db import TimeTrackingCursor
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
from zerver.lib import cache
from zerver.tornado import event_queue
from zerver.tornado.handlers import allocate_handler_id
@ -147,9 +147,9 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
queries: List[Dict[str, Union[str, bytes]]] = []
def wrapper_execute(self: TimeTrackingCursor,
action: Callable[[str, Iterable[Any]], None],
sql: str,
params: Iterable[Any]=()) -> None:
action: Callable[[str, ParamsT], None],
sql: Query,
params: ParamsT) -> None:
cache = get_cache_backend(None)
cache.clear()
start = time.time()
@ -158,7 +158,7 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
finally:
stop = time.time()
duration = stop - start
if include_savepoints or ('SAVEPOINT' not in sql):
if include_savepoints or not isinstance(sql, str) or 'SAVEPOINT' not in sql:
queries.append({
'sql': self.mogrify(sql, params).decode('utf-8'),
'time': "%.3f" % (duration,),
@ -167,13 +167,13 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
old_execute = TimeTrackingCursor.execute
old_executemany = TimeTrackingCursor.executemany
def cursor_execute(self: TimeTrackingCursor, sql: str,
params: Iterable[Any]=()) -> None:
def cursor_execute(self: TimeTrackingCursor, sql: Query,
params: Optional[Params]=None) -> None:
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)
TimeTrackingCursor.execute = cursor_execute # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167
def cursor_executemany(self: TimeTrackingCursor, sql: str,
params: Iterable[Any]=()) -> None:
def cursor_executemany(self: TimeTrackingCursor, sql: Query,
params: Iterable[Params]) -> None:
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # nocoverage -- doesn't actually get used in tests
TimeTrackingCursor.executemany = cursor_executemany # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167