sqlalchemy_utils: Make get_sqlalchemy_connection a context manager.

Although our NonClosingPool prevents the SQLAlchemy connection from
closing the underlying Django connection, we still want to properly
dispose of the associated SQLAlchemy structures.

Signed-off-by: Anders Kaseorg <anders@zulip.com>
Authored by Anders Kaseorg on 2022-02-09 19:59:48 -08:00; committed by Tim Abbott
parent 8e5ae4e829
commit 29330c180a
4 changed files with 72 additions and 69 deletions
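For orientation, the change in calling convention looks roughly like this (an illustrative sketch, not part of the diff; the import path and the query object are assumptions):

    # Assumed import path, based on the commit title prefix "sqlalchemy_utils:".
    from zerver.lib.sqlalchemy_utils import get_sqlalchemy_connection

    # Old pattern: callers received a Connection and simply held onto it.
    sa_conn = get_sqlalchemy_connection()
    rows = sa_conn.execute(query).fetchall()

    # New pattern: callers scope the Connection to a with block, so the
    # associated SQLAlchemy structures are disposed of on exit, while
    # NonClosingPool still keeps the underlying Django connection open.
    with get_sqlalchemy_connection() as sa_conn:
        rows = sa_conn.execute(query).fetchall()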

@@ -1,7 +1,9 @@
-from typing import Any, Optional
+from contextlib import contextmanager
+from typing import Iterator, Optional
 
 import sqlalchemy
 from django.db import connection
+from sqlalchemy.engine import Connection, Engine
 
 from zerver.lib.db import TimeTrackingConnection
@@ -26,10 +28,11 @@ class NonClosingPool(sqlalchemy.pool.NullPool):
         )
 
-sqlalchemy_engine: Optional[Any] = None
+sqlalchemy_engine: Optional[Engine] = None
 
 
-def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:
+@contextmanager
+def get_sqlalchemy_connection() -> Iterator[Connection]:
     global sqlalchemy_engine
     if sqlalchemy_engine is None:
@@ -43,6 +46,5 @@ def get_sqlalchemy_connection() -> sqlalchemy.engine.base.Connection:
             poolclass=NonClosingPool,
            pool_reset_on_return=None,
         )
-    sa_connection = sqlalchemy_engine.connect()
-    sa_connection.execution_options(autocommit=False)
-    return sa_connection
+    with sqlalchemy_engine.connect().execution_options(autocommit=False) as sa_connection:
+        yield sa_connection
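For readers less familiar with contextlib, the new body is roughly equivalent to the sketch below (the helper name connection_scope is hypothetical and only for illustration). SQLAlchemy's Connection is itself a context manager, so leaving the with block always calls close(), which releases the SQLAlchemy-side structures back to the pool; NonClosingPool in turn prevents the underlying Django connection from being closed.

    from contextlib import contextmanager
    from typing import Iterator

    from sqlalchemy.engine import Connection, Engine

    @contextmanager
    def connection_scope(engine: Engine) -> Iterator[Connection]:
        # Rough expansion of the with/yield above: close() runs even if
        # the caller's with block raises an exception.
        conn = engine.connect().execution_options(autocommit=False)
        try:
            yield conn
        finally:
            conn.close()

The rest of the commit updates every call site to the with form accordingly.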

@@ -430,8 +430,8 @@ class Runner(DiscoverRunner):
         # We have to do the next line to avoid flaky scenarios where we
         # run a single test and getting an SA connection causes data from
         # a Django connection to be rolled back mid-test.
-        get_sqlalchemy_connection()
-        result = self.run_suite(suite)
+        with get_sqlalchemy_connection():
+            result = self.run_suite(suite)
         self.teardown_test_environment()
         failed = self.suite_result(suite, result)
         if not failed:

@@ -63,13 +63,15 @@ from zerver.views.message_fetch import (
 def get_sqlalchemy_sql(query: ClauseElement) -> str:
-    dialect = get_sqlalchemy_connection().dialect
+    with get_sqlalchemy_connection() as conn:
+        dialect = conn.dialect
     comp = query.compile(dialect=dialect)
     return str(comp)
 
 
 def get_sqlalchemy_query_params(query: ClauseElement) -> Dict[str, object]:
-    dialect = get_sqlalchemy_connection().dialect
+    with get_sqlalchemy_connection() as conn:
+        dialect = conn.dialect
     comp = query.compile(dialect=dialect)
     return comp.params
@@ -3015,15 +3017,14 @@ class GetOldMessagesTest(ZulipTestCase):
         extra_message_id = self.send_stream_message(cordelia, "England")
         self.send_personal_message(cordelia, hamlet)
 
-        sa_conn = get_sqlalchemy_connection()
-
         user_profile = hamlet
 
-        anchor = find_first_unread_anchor(
-            sa_conn=sa_conn,
-            user_profile=user_profile,
-            narrow=[],
-        )
+        with get_sqlalchemy_connection() as sa_conn:
+            anchor = find_first_unread_anchor(
+                sa_conn=sa_conn,
+                user_profile=user_profile,
+                narrow=[],
+            )
         self.assertEqual(anchor, first_message_id)
 
         # With the same data setup, we now want to test that a reasonable

@@ -1048,44 +1048,45 @@ def get_messages_backend(
         assert log_data is not None
         log_data["extra"] = "[{}]".format(",".join(verbose_operators))
 
-    sa_conn = get_sqlalchemy_connection()
+    with get_sqlalchemy_connection() as sa_conn:
+        if anchor is None:
+            # `anchor=None` corresponds to the anchor="first_unread" parameter.
+            anchor = find_first_unread_anchor(
+                sa_conn,
+                user_profile,
+                narrow,
+            )
 
-    if anchor is None:
-        # `anchor=None` corresponds to the anchor="first_unread" parameter.
-        anchor = find_first_unread_anchor(
-            sa_conn,
-            user_profile,
-            narrow,
-        )
-
-    anchored_to_left = anchor == 0
-
-    # Set value that will be used to short circuit the after_query
-    # altogether and avoid needless conditions in the before_query.
-    anchored_to_right = anchor >= LARGER_THAN_MAX_MESSAGE_ID
-    if anchored_to_right:
-        num_after = 0
-
-    first_visible_message_id = get_first_visible_message_id(realm)
-    query = limit_query_to_range(
-        query=query,
-        num_before=num_before,
-        num_after=num_after,
-        anchor=anchor,
-        anchored_to_left=anchored_to_left,
-        anchored_to_right=anchored_to_right,
-        id_col=inner_msg_id_col,
-        first_visible_message_id=first_visible_message_id,
-    )
+        anchored_to_left = anchor == 0
+
+        # Set value that will be used to short circuit the after_query
+        # altogether and avoid needless conditions in the before_query.
+        anchored_to_right = anchor >= LARGER_THAN_MAX_MESSAGE_ID
+        if anchored_to_right:
+            num_after = 0
+
+        first_visible_message_id = get_first_visible_message_id(realm)
+        query = limit_query_to_range(
+            query=query,
+            num_before=num_before,
+            num_after=num_after,
+            anchor=anchor,
+            anchored_to_left=anchored_to_left,
+            anchored_to_right=anchored_to_right,
+            id_col=inner_msg_id_col,
+            first_visible_message_id=first_visible_message_id,
+        )
 
-    main_query = query.subquery()
-    query = (
-        select(*main_query.c).select_from(main_query).order_by(column("message_id", Integer).asc())
-    )
-    # This is a hack to tag the query we use for testing
-    query = query.prefix_with("/* get_messages */")
-    rows = list(sa_conn.execute(query).fetchall())
+        main_query = query.subquery()
+        query = (
+            select(*main_query.c)
+            .select_from(main_query)
+            .order_by(column("message_id", Integer).asc())
+        )
+        # This is a hack to tag the query we use for testing
+        query = query.prefix_with("/* get_messages */")
+        rows = list(sa_conn.execute(query).fetchall())
 
     query_info = post_process_limited_query(
         rows=rows,
@@ -1357,23 +1358,22 @@ def messages_in_narrow_backend(
     for term in narrow:
         query = builder.add_term(query, term)
 
-    sa_conn = get_sqlalchemy_connection()
-
     search_fields = {}
-    for row in sa_conn.execute(query).fetchall():
-        message_id = row._mapping["message_id"]
-        topic_name = row._mapping[DB_TOPIC_NAME]
-        rendered_content = row._mapping["rendered_content"]
-        if "content_matches" in row._mapping:
-            content_matches = row._mapping["content_matches"]
-            topic_matches = row._mapping["topic_matches"]
-        else:
-            content_matches = topic_matches = []
-        search_fields[str(message_id)] = get_search_fields(
-            rendered_content,
-            topic_name,
-            content_matches,
-            topic_matches,
-        )
+    with get_sqlalchemy_connection() as sa_conn:
+        for row in sa_conn.execute(query).fetchall():
+            message_id = row._mapping["message_id"]
+            topic_name = row._mapping[DB_TOPIC_NAME]
+            rendered_content = row._mapping["rendered_content"]
+            if "content_matches" in row._mapping:
+                content_matches = row._mapping["content_matches"]
+                topic_matches = row._mapping["topic_matches"]
+            else:
+                content_matches = topic_matches = []
+            search_fields[str(message_id)] = get_search_fields(
+                rendered_content,
+                topic_name,
+                content_matches,
+                topic_matches,
+            )
 
     return json_success(request, data={"messages": search_fields})