mirror of https://github.com/zulip/zulip.git
sqlalchemy_utils: Make get_sqlalchemy_connection a context manager.
Although our NonClosingPool prevents the SQLAlchemy connection from closing the underlying Django connection, we still want to properly dispose of the associated SQLAlchemy structures. Signed-off-by: Anders Kaseorg <anders@zulip.com>
This commit is contained in:
parent
8e5ae4e829
commit
29330c180a
|
@ -1,7 +1,9 @@
|
|||
from typing import Any, Optional
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, Optional
|
||||
|
||||
import sqlalchemy
|
||||
from django.db import connection
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
|
||||
from zerver.lib.db import TimeTrackingConnection
|
||||
|
||||
|
@ -26,10 +28,11 @@ class NonClosingPool(sqlalchemy.pool.NullPool):
|
|||
)
|
||||
|
||||
|
||||
sqlalchemy_engine: Optional[Any] = None
|
||||
sqlalchemy_engine: Optional[Engine] = None
|
||||
|
||||
|
||||
def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:
|
||||
@contextmanager
|
||||
def get_sqlalchemy_connection() -> Iterator[Connection]:
|
||||
global sqlalchemy_engine
|
||||
if sqlalchemy_engine is None:
|
||||
|
||||
|
@ -43,6 +46,5 @@ def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:
|
|||
poolclass=NonClosingPool,
|
||||
pool_reset_on_return=None,
|
||||
)
|
||||
sa_connection = sqlalchemy_engine.connect()
|
||||
sa_connection.execution_options(autocommit=False)
|
||||
return sa_connection
|
||||
with sqlalchemy_engine.connect().execution_options(autocommit=False) as sa_connection:
|
||||
yield sa_connection
|
||||
|
|
|
@ -430,8 +430,8 @@ class Runner(DiscoverRunner):
|
|||
# We have to do the next line to avoid flaky scenarios where we
|
||||
# run a single test and getting an SA connection causes data from
|
||||
# a Django connection to be rolled back mid-test.
|
||||
get_sqlalchemy_connection()
|
||||
result = self.run_suite(suite)
|
||||
with get_sqlalchemy_connection():
|
||||
result = self.run_suite(suite)
|
||||
self.teardown_test_environment()
|
||||
failed = self.suite_result(suite, result)
|
||||
if not failed:
|
||||
|
|
|
@ -63,13 +63,15 @@ from zerver.views.message_fetch import (
|
|||
|
||||
|
||||
def get_sqlalchemy_sql(query: ClauseElement) -> str:
|
||||
dialect = get_sqlalchemy_connection().dialect
|
||||
with get_sqlalchemy_connection() as conn:
|
||||
dialect = conn.dialect
|
||||
comp = query.compile(dialect=dialect)
|
||||
return str(comp)
|
||||
|
||||
|
||||
def get_sqlalchemy_query_params(query: ClauseElement) -> Dict[str, object]:
|
||||
dialect = get_sqlalchemy_connection().dialect
|
||||
with get_sqlalchemy_connection() as conn:
|
||||
dialect = conn.dialect
|
||||
comp = query.compile(dialect=dialect)
|
||||
return comp.params
|
||||
|
||||
|
@ -3015,15 +3017,14 @@ class GetOldMessagesTest(ZulipTestCase):
|
|||
extra_message_id = self.send_stream_message(cordelia, "England")
|
||||
self.send_personal_message(cordelia, hamlet)
|
||||
|
||||
sa_conn = get_sqlalchemy_connection()
|
||||
|
||||
user_profile = hamlet
|
||||
|
||||
anchor = find_first_unread_anchor(
|
||||
sa_conn=sa_conn,
|
||||
user_profile=user_profile,
|
||||
narrow=[],
|
||||
)
|
||||
with get_sqlalchemy_connection() as sa_conn:
|
||||
anchor = find_first_unread_anchor(
|
||||
sa_conn=sa_conn,
|
||||
user_profile=user_profile,
|
||||
narrow=[],
|
||||
)
|
||||
self.assertEqual(anchor, first_message_id)
|
||||
|
||||
# With the same data setup, we now want to test that a reasonable
|
||||
|
|
|
@ -1048,44 +1048,45 @@ def get_messages_backend(
|
|||
assert log_data is not None
|
||||
log_data["extra"] = "[{}]".format(",".join(verbose_operators))
|
||||
|
||||
sa_conn = get_sqlalchemy_connection()
|
||||
with get_sqlalchemy_connection() as sa_conn:
|
||||
if anchor is None:
|
||||
# `anchor=None` corresponds to the anchor="first_unread" parameter.
|
||||
anchor = find_first_unread_anchor(
|
||||
sa_conn,
|
||||
user_profile,
|
||||
narrow,
|
||||
)
|
||||
|
||||
if anchor is None:
|
||||
# `anchor=None` corresponds to the anchor="first_unread" parameter.
|
||||
anchor = find_first_unread_anchor(
|
||||
sa_conn,
|
||||
user_profile,
|
||||
narrow,
|
||||
anchored_to_left = anchor == 0
|
||||
|
||||
# Set value that will be used to short circuit the after_query
|
||||
# altogether and avoid needless conditions in the before_query.
|
||||
anchored_to_right = anchor >= LARGER_THAN_MAX_MESSAGE_ID
|
||||
if anchored_to_right:
|
||||
num_after = 0
|
||||
|
||||
first_visible_message_id = get_first_visible_message_id(realm)
|
||||
|
||||
query = limit_query_to_range(
|
||||
query=query,
|
||||
num_before=num_before,
|
||||
num_after=num_after,
|
||||
anchor=anchor,
|
||||
anchored_to_left=anchored_to_left,
|
||||
anchored_to_right=anchored_to_right,
|
||||
id_col=inner_msg_id_col,
|
||||
first_visible_message_id=first_visible_message_id,
|
||||
)
|
||||
|
||||
anchored_to_left = anchor == 0
|
||||
|
||||
# Set value that will be used to short circuit the after_query
|
||||
# altogether and avoid needless conditions in the before_query.
|
||||
anchored_to_right = anchor >= LARGER_THAN_MAX_MESSAGE_ID
|
||||
if anchored_to_right:
|
||||
num_after = 0
|
||||
|
||||
first_visible_message_id = get_first_visible_message_id(realm)
|
||||
|
||||
query = limit_query_to_range(
|
||||
query=query,
|
||||
num_before=num_before,
|
||||
num_after=num_after,
|
||||
anchor=anchor,
|
||||
anchored_to_left=anchored_to_left,
|
||||
anchored_to_right=anchored_to_right,
|
||||
id_col=inner_msg_id_col,
|
||||
first_visible_message_id=first_visible_message_id,
|
||||
)
|
||||
|
||||
main_query = query.subquery()
|
||||
query = (
|
||||
select(*main_query.c).select_from(main_query).order_by(column("message_id", Integer).asc())
|
||||
)
|
||||
# This is a hack to tag the query we use for testing
|
||||
query = query.prefix_with("/* get_messages */")
|
||||
rows = list(sa_conn.execute(query).fetchall())
|
||||
main_query = query.subquery()
|
||||
query = (
|
||||
select(*main_query.c)
|
||||
.select_from(main_query)
|
||||
.order_by(column("message_id", Integer).asc())
|
||||
)
|
||||
# This is a hack to tag the query we use for testing
|
||||
query = query.prefix_with("/* get_messages */")
|
||||
rows = list(sa_conn.execute(query).fetchall())
|
||||
|
||||
query_info = post_process_limited_query(
|
||||
rows=rows,
|
||||
|
@ -1357,23 +1358,22 @@ def messages_in_narrow_backend(
|
|||
for term in narrow:
|
||||
query = builder.add_term(query, term)
|
||||
|
||||
sa_conn = get_sqlalchemy_connection()
|
||||
|
||||
search_fields = {}
|
||||
for row in sa_conn.execute(query).fetchall():
|
||||
message_id = row._mapping["message_id"]
|
||||
topic_name = row._mapping[DB_TOPIC_NAME]
|
||||
rendered_content = row._mapping["rendered_content"]
|
||||
if "content_matches" in row._mapping:
|
||||
content_matches = row._mapping["content_matches"]
|
||||
topic_matches = row._mapping["topic_matches"]
|
||||
else:
|
||||
content_matches = topic_matches = []
|
||||
search_fields[str(message_id)] = get_search_fields(
|
||||
rendered_content,
|
||||
topic_name,
|
||||
content_matches,
|
||||
topic_matches,
|
||||
)
|
||||
with get_sqlalchemy_connection() as sa_conn:
|
||||
for row in sa_conn.execute(query).fetchall():
|
||||
message_id = row._mapping["message_id"]
|
||||
topic_name = row._mapping[DB_TOPIC_NAME]
|
||||
rendered_content = row._mapping["rendered_content"]
|
||||
if "content_matches" in row._mapping:
|
||||
content_matches = row._mapping["content_matches"]
|
||||
topic_matches = row._mapping["topic_matches"]
|
||||
else:
|
||||
content_matches = topic_matches = []
|
||||
search_fields[str(message_id)] = get_search_fields(
|
||||
rendered_content,
|
||||
topic_name,
|
||||
content_matches,
|
||||
topic_matches,
|
||||
)
|
||||
|
||||
return json_success(request, data={"messages": search_fields})
|
||||
|
|
Loading…
Reference in New Issue