"""Database engine/session configuration helpers.

``configure_db`` builds the module-level ``engine`` and ``SessionLocal``;
``get_db`` is a request-scoped dependency that reads a session factory from
``request.app.state.db_sessionmaker``; ``session_scope`` is a transactional
context manager over the module-level ``SessionLocal``; ``reset_db`` tears
the module-level state down again.
"""

from __future__ import annotations

from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

from fastapi import Request
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine, make_url
from sqlalchemy.orm import Session, sessionmaker
from sqlalchemy.pool import StaticPool

# Module-level engine and session factory; set by ``configure_db`` and
# cleared by ``reset_db``. ``None`` means "not configured".
engine: Engine | None = None
SessionLocal: sessionmaker[Session] | None = None


def configure_db(
    database_url: str,
    *,
    echo: bool = False,
    pool_pre_ping: bool = True,
) -> tuple[Engine, sessionmaker[Session]]:
    """Create the module-level engine and session factory.

    Args:
        database_url: SQLAlchemy database URL.
        echo: Passed to ``create_engine``; logs emitted SQL when true.
        pool_pre_ping: Passed to ``create_engine``; tests pooled connections
            before handing them out.

    Returns:
        The new ``(engine, SessionLocal)`` pair. The same objects are also
        stored in this module's ``engine`` and ``SessionLocal`` globals.

    Raises:
        sqlalchemy.exc.ArgumentError: If ``database_url`` cannot be parsed.
    """
    global engine, SessionLocal

    url = make_url(database_url)
    engine_kwargs: dict[str, Any] = {
        "echo": echo,
        "pool_pre_ping": pool_pre_ping,
        "future": True,
    }
    if url.get_backend_name() == "sqlite":
        # Permit a SQLite connection to be used from threads other than the
        # one that created it.
        engine_kwargs["connect_args"] = {"check_same_thread": False}
        # Recognise every spelling of an in-memory database (``sqlite://``,
        # ``sqlite+pysqlite://``, ``sqlite:///:memory:?...``), not just a few
        # literal strings. Each new connection to ``:memory:`` is a separate
        # empty database, so StaticPool pins a single shared connection.
        if url.database in (None, "", ":memory:"):
            engine_kwargs["poolclass"] = StaticPool

    engine = create_engine(url, **engine_kwargs)
    SessionLocal = sessionmaker(
        bind=engine,
        autoflush=False,
        expire_on_commit=False,
    )
    return engine, SessionLocal


def get_db(request: Request) -> Generator[Session, None, None]:
    """Yield a session from the app's session factory and close it afterwards.

    The factory is read from ``request.app.state.db_sessionmaker`` (not from
    this module's ``SessionLocal``). The session is never committed here;
    callers are responsible for committing their own work.

    Raises:
        RuntimeError: If the app has no ``db_sessionmaker`` configured.
    """
    factory = getattr(request.app.state, "db_sessionmaker", None)
    if factory is None:
        raise RuntimeError("database is not configured")
    db = factory()
    try:
        yield db
    finally:
        db.close()


@contextmanager
def session_scope() -> Generator[Session, None, None]:
    """Provide a transactional session from the module-level ``SessionLocal``.

    Commits when the ``with`` body completes normally, rolls back and
    re-raises if it raises an ``Exception``, and always closes the session.

    Raises:
        RuntimeError: If ``configure_db`` has not been called (or
            ``reset_db`` has cleared the configuration).
    """
    if SessionLocal is None:
        raise RuntimeError("database is not configured")
    db = SessionLocal()
    try:
        yield db
        db.commit()
    except Exception:
        db.rollback()
        raise
    finally:
        db.close()


def ping_database(db: Session) -> None:
    """Run ``SELECT 1`` on *db*; any database error propagates to the caller."""
    db.execute(text("SELECT 1"))


def reset_db() -> None:
    """Dispose the module-level engine (if any) and clear both globals."""
    global engine, SessionLocal
    if engine is not None:
        engine.dispose()
    engine = None
    SessionLocal = None