2022-02-10 04:59:48 +01:00
|
|
|
from contextlib import contextmanager
|
|
|
|
from typing import Iterator, Optional
|
2018-05-15 22:35:23 +02:00
|
|
|
|
2020-06-11 00:54:34 +02:00
|
|
|
import sqlalchemy
|
2016-07-19 08:12:35 +02:00
|
|
|
from django.db import connection
|
2022-02-10 04:59:48 +01:00
|
|
|
from sqlalchemy.engine import Connection, Engine
|
2023-10-12 19:43:45 +02:00
|
|
|
from typing_extensions import override
|
2020-06-11 00:54:34 +02:00
|
|
|
|
2016-09-10 21:08:37 +02:00
|
|
|
from zerver.lib.db import TimeTrackingConnection
|
2016-07-19 08:12:35 +02:00
|
|
|
|
|
|
|
|
|
|
|
# This is a Pool that doesn't close connections. Therefore it can be used with
|
|
|
|
# existing Django database connections.
|
|
|
|
class NonClosingPool(sqlalchemy.pool.NullPool):
    """A NullPool whose connection-return hook does nothing.

    NullPool normally closes a DB-API connection when it is handed back.
    This subclass skips that step, so a connection owned elsewhere (here,
    Django's database connection) stays open after SQLAlchemy is done
    with it.
    """

    @override
    def status(self) -> str:
        # Short human-readable description of this pool.
        pool_description = "NonClosingPool"
        return pool_description

    def _do_return_conn(self, conn: sqlalchemy.engine.base.Connection) -> None:
        """Deliberately leave the returned connection open (no-op)."""
|
|
|
|
|
2021-02-12 08:19:30 +01:00
|
|
|
|
2022-02-10 04:59:48 +01:00
|
|
|
# Engine shared by all callers of get_sqlalchemy_connection(). It starts as
# None and is created on the first call, then reused on later calls.
sqlalchemy_engine: Optional[Engine] = None
|
2021-02-12 08:19:30 +01:00
|
|
|
|
|
|
|
|
2022-02-10 04:59:48 +01:00
|
|
|
def _django_dbapi_connection() -> TimeTrackingConnection:
    """Return Django's underlying DB-API connection, opening it first if needed."""
    connection.ensure_connection()
    return connection.connection


@contextmanager
def get_sqlalchemy_connection() -> Iterator[Connection]:
    """Yield a SQLAlchemy Connection that runs on Django's database connection.

    On the first call, a SQLAlchemy engine is built and stored in the
    module-level ``sqlalchemy_engine`` global. Every later call reuses it.

    The engine's ``creator`` hands SQLAlchemy the DB-API connection Django
    already holds. ``NonClosingPool`` keeps SQLAlchemy from closing that
    connection. ``pool_reset_on_return=None`` keeps it from resetting the
    connection (e.g. rolling back) when it is returned to the pool.
    """
    global sqlalchemy_engine
    if sqlalchemy_engine is None:
        sqlalchemy_engine = sqlalchemy.create_engine(
            "postgresql://",
            creator=_django_dbapi_connection,
            poolclass=NonClosingPool,
            pool_reset_on_return=None,
        )

    sa_connection = sqlalchemy_engine.connect().execution_options(autocommit=False)
    with sa_connection as active_connection:
        yield active_connection
|