mirror of https://github.com/zulip/zulip.git
db: Fix types to accept psycopg2.sql.Composable queries, avoid Any.
Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
d0b40cd7a3
commit
cebac3f35a
|
@ -1,18 +1,21 @@
|
|||
import time
|
||||
from psycopg2.extensions import cursor, connection
|
||||
from psycopg2.sql import Composable
|
||||
|
||||
from typing import Callable, Optional, Iterable, Any, Dict, List, Union, TypeVar, \
|
||||
Mapping
|
||||
Mapping, Sequence
|
||||
|
||||
CursorObj = TypeVar('CursorObj', bound=cursor)
|
||||
ParamsT = Union[Iterable[Any], Mapping[str, Any]]
|
||||
Query = Union[str, Composable]
|
||||
Params = Union[Sequence[object], Mapping[str, object]]
|
||||
ParamsT = TypeVar('ParamsT')
|
||||
|
||||
# Similar to the tracking done in Django's CursorDebugWrapper, but done at the
|
||||
# psycopg2 cursor level so it works with SQLAlchemy.
|
||||
def wrapper_execute(self: CursorObj,
|
||||
action: Callable[[str, Optional[ParamsT]], CursorObj],
|
||||
sql: str,
|
||||
params: Optional[ParamsT]=()) -> CursorObj:
|
||||
action: Callable[[Query, ParamsT], CursorObj],
|
||||
sql: Query,
|
||||
params: ParamsT) -> CursorObj:
|
||||
start = time.time()
|
||||
try:
|
||||
return action(sql, params)
|
||||
|
@ -26,12 +29,12 @@ def wrapper_execute(self: CursorObj,
|
|||
class TimeTrackingCursor(cursor):
|
||||
"""A psycopg2 cursor class that tracks the time spent executing queries."""
|
||||
|
||||
def execute(self, query: str,
|
||||
vars: Optional[ParamsT]=None) -> 'TimeTrackingCursor':
|
||||
def execute(self, query: Query,
|
||||
vars: Optional[Params]=None) -> 'TimeTrackingCursor':
|
||||
return wrapper_execute(self, super().execute, query, vars)
|
||||
|
||||
def executemany(self, query: str,
|
||||
vars: Iterable[Any]) -> 'TimeTrackingCursor':
|
||||
def executemany(self, query: Query,
|
||||
vars: Iterable[Params]) -> 'TimeTrackingCursor':
|
||||
return wrapper_execute(self, super().executemany, query, vars)
|
||||
|
||||
class TimeTrackingConnection(connection):
|
||||
|
|
|
@ -17,7 +17,7 @@ from zerver.lib.actions import do_set_realm_property
|
|||
from zerver.lib.upload import S3UploadBackend, LocalUploadBackend
|
||||
from zerver.lib.avatar import avatar_url
|
||||
from zerver.lib.cache import get_cache_backend
|
||||
from zerver.lib.db import TimeTrackingCursor
|
||||
from zerver.lib.db import Params, ParamsT, Query, TimeTrackingCursor
|
||||
from zerver.lib import cache
|
||||
from zerver.tornado import event_queue
|
||||
from zerver.tornado.handlers import allocate_handler_id
|
||||
|
@ -147,9 +147,9 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
|
|||
queries: List[Dict[str, Union[str, bytes]]] = []
|
||||
|
||||
def wrapper_execute(self: TimeTrackingCursor,
|
||||
action: Callable[[str, Iterable[Any]], None],
|
||||
sql: str,
|
||||
params: Iterable[Any]=()) -> None:
|
||||
action: Callable[[str, ParamsT], None],
|
||||
sql: Query,
|
||||
params: ParamsT) -> None:
|
||||
cache = get_cache_backend(None)
|
||||
cache.clear()
|
||||
start = time.time()
|
||||
|
@ -158,7 +158,7 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
|
|||
finally:
|
||||
stop = time.time()
|
||||
duration = stop - start
|
||||
if include_savepoints or ('SAVEPOINT' not in sql):
|
||||
if include_savepoints or not isinstance(sql, str) or 'SAVEPOINT' not in sql:
|
||||
queries.append({
|
||||
'sql': self.mogrify(sql, params).decode('utf-8'),
|
||||
'time': "%.3f" % (duration,),
|
||||
|
@ -167,13 +167,13 @@ def queries_captured(include_savepoints: Optional[bool]=False) -> Generator[
|
|||
old_execute = TimeTrackingCursor.execute
|
||||
old_executemany = TimeTrackingCursor.executemany
|
||||
|
||||
def cursor_execute(self: TimeTrackingCursor, sql: str,
|
||||
params: Iterable[Any]=()) -> None:
|
||||
def cursor_execute(self: TimeTrackingCursor, sql: Query,
|
||||
params: Optional[Params]=None) -> None:
|
||||
return wrapper_execute(self, super(TimeTrackingCursor, self).execute, sql, params)
|
||||
TimeTrackingCursor.execute = cursor_execute # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167
|
||||
|
||||
def cursor_executemany(self: TimeTrackingCursor, sql: str,
|
||||
params: Iterable[Any]=()) -> None:
|
||||
def cursor_executemany(self: TimeTrackingCursor, sql: Query,
|
||||
params: Iterable[Params]) -> None:
|
||||
return wrapper_execute(self, super(TimeTrackingCursor, self).executemany, sql, params) # nocoverage -- doesn't actually get used in tests
|
||||
TimeTrackingCursor.executemany = cursor_executemany # type: ignore[assignment] # https://github.com/JukkaL/mypy/issues/1167
|
||||
|
||||
|
|
Loading…
Reference in New Issue