
SQL Utilities

Utilities for SQL database operations.

Base Classes

base

Attributes

logger = logging.getLogger(__name__) module-attribute

Classes

SQLDBError

Bases: Exception

Source code in diracx-db/src/diracx/db/sql/utils/base.py
class SQLDBError(Exception):
    pass

SQLDBUnavailableError

Bases: DBUnavailableError, SQLDBError

Used whenever we encounter a problem with the DB connection.

Source code in diracx-db/src/diracx/db/sql/utils/base.py
class SQLDBUnavailableError(DBUnavailableError, SQLDBError):
    """Used whenever we encounter a problem with the DB connection."""

BaseSQLDB

This should be the base class of all the SQL DiracX DBs.

The details covered here should be handled automatically by the service and task machinery of DiracX and this documentation exists for informational purposes.

The available databases are discovered by calling BaseSQLDB.available_urls. This method returns a mapping of database names to connection URLs. The available databases are determined by the diracx.dbs.sql entrypoint in the pyproject.toml file and the connection URLs are taken from the environment variables of the form DIRACX_DB_URL_<db-name>.

If extensions to DiracX are being used, there can be multiple implementations of the same database. To list the available implementations, use BaseSQLDB.available_implementations(db_name). The first entry in this list is the preferred implementation, and it can be instantiated by passing a URL previously obtained from BaseSQLDB.available_urls to its __init__ method.

To control the lifetime of the SQLAlchemy engine used for connecting to the database, which includes the connection pool, the BaseSQLDB.engine_context asynchronous context manager should be entered. When inside this context manager, the engine can be accessed with BaseSQLDB.engine.

Once the engine context has been entered, the DB object can be used as an asynchronous context manager to open transactions. If an exception is raised, the transaction is rolled back automatically; if the block exits without an exception, the transaction is committed automatically. When inside this context manager, the DB connection can be accessed with BaseSQLDB.conn.

For example:

db_name = ...
url = BaseSQLDB.available_urls()[db_name]
MyDBClass = BaseSQLDB.available_implementations(db_name)[0]

db = MyDBClass(url)
async with db.engine_context():
    async with db:
        # Do something in the first transaction
        # Commit will be called automatically
        ...

    async with db:
        # This transaction will be rolled back due to the exception
        raise Exception(...)
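The `ContextVar` stored in `BaseSQLDB.__init__` is what lets a single DB object serve concurrent requests: every `asyncio` task runs in its own copy of the context, so a connection set by one task is invisible to another. The sketch below illustrates that mechanism with a hypothetical `FakeDB` class and string placeholders instead of real connections; it is not DiracX code.

```python
from __future__ import annotations

import asyncio
from contextvars import ContextVar


class FakeDB:
    """Hypothetical stand-in for BaseSQLDB that hands out fake connections."""

    def __init__(self) -> None:
        # One slot per context, like BaseSQLDB._conn
        self._conn: ContextVar[str | None] = ContextVar("_conn", default=None)
        self._opened = 0

    @property
    def conn(self) -> str:
        conn = self._conn.get()
        if conn is None:
            raise RuntimeError("FakeDB was used before entering")
        return conn

    async def __aenter__(self) -> FakeDB:
        assert self._conn.get() is None, "FakeDB context cannot be nested"
        self._opened += 1
        self._conn.set(f"conn-{self._opened}")
        return self

    async def __aexit__(self, exc_type, exc, tb) -> None:
        self._conn.set(None)


async def handle_request(db: FakeDB) -> str:
    async with db:
        await asyncio.sleep(0)  # let the other task run while this one is inside
        return db.conn


async def main() -> list[str]:
    db = FakeDB()
    # gather() wraps each coroutine in a Task, and each Task gets a copy of the context
    return await asyncio.gather(handle_request(db), handle_request(db))
```

Running `asyncio.run(main())` gives each request its own placeholder connection. With a plain instance attribute instead of a `ContextVar`, the second task would see the first task's connection and trip the nesting assertion.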
Source code in diracx-db/src/diracx/db/sql/utils/base.py
class BaseSQLDB(metaclass=ABCMeta):
    """This should be the base class of all the SQL DiracX DBs.

    The details covered here should be handled automatically by the service and
    task machinery of DiracX and this documentation exists for informational
    purposes.

    The available databases are discovered by calling `BaseSQLDB.available_urls`.
    This method returns a mapping of database names to connection URLs. The
    available databases are determined by the `diracx.dbs.sql` entrypoint in the
    `pyproject.toml` file and the connection URLs are taken from the environment
    variables of the form `DIRACX_DB_URL_<db-name>`.

    If extensions to DiracX are being used, there can be multiple implementations
    of the same database. To list the available implementations use
    `BaseSQLDB.available_implementations(db_name)`. The first entry in this list
    will be the preferred implementation and it can be initialized by calling
    its `__init__` method with a URL previously obtained from
    `BaseSQLDB.available_urls`.

    To control the lifetime of the SQLAlchemy engine used for connecting to the
    database, which includes the connection pool, the `BaseSQLDB.engine_context`
    asynchronous context manager should be entered. When inside this context
    manager, the engine can be accessed with `BaseSQLDB.engine`.

    Upon entering, the DB class can then be used as an asynchronous context
    manager to enter transactions. If an exception is raised the transaction is
    rolled back automatically. If the inner context exits peacefully, the
    transaction is committed automatically. When inside this context manager,
    the DB connection can be accessed with `BaseSQLDB.conn`.

    For example:

    ```python
    db_name = ...
    url = BaseSQLDB.available_urls()[db_name]
    MyDBClass = BaseSQLDB.available_implementations(db_name)[0]

    db = MyDBClass(url)
    async with db.engine_context():
        async with db:
            # Do something in the first transaction
            # Commit will be called automatically
            ...

        async with db:
            # This transaction will be rolled back due to the exception
            raise Exception(...)
    ```
    """

    # engine: AsyncEngine
    # TODO: Make metadata an abstract property
    metadata: MetaData

    def __init__(self, db_url: str) -> None:
        # We use a ContextVar to make sure that self._conn
        # is specific to each context, and avoid parallel
        # route executions to overlap
        self._conn: ContextVar[AsyncConnection | None] = ContextVar(
            "_conn", default=None
        )
        self._db_url = db_url
        self._engine: AsyncEngine | None = None

    @classmethod
    def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]:
        """Return the available implementations of the DB in reverse priority order."""
        db_classes: list[type[BaseSQLDB]] = [
            entry_point.load()
            for entry_point in select_from_extension(
                group="diracx.dbs.sql", name=db_name
            )
        ]
        if not db_classes:
            raise NotImplementedError(f"Could not find any matches for {db_name=}")
        return db_classes

    @classmethod
    def available_urls(cls) -> dict[str, str]:
        """Return a dict of available database urls.

        The list of available URLs is determined by environment variables
        prefixed with ``DIRACX_DB_URL_{DB_NAME}``.
        """
        db_urls: dict[str, str] = {}
        for entry_point in select_from_extension(group="diracx.dbs.sql"):
            db_name = entry_point.name
            var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}"
            if var_name in os.environ:
                try:
                    db_url = os.environ[var_name]
                    if db_url == "sqlite+aiosqlite:///:memory:":
                        db_urls[db_name] = db_url
                    # pydantic does not allow for underscore in scheme
                    # so we do a special case
                    elif "_" in db_url.split(":")[0]:
                        # Validate the URL with a fake schema, and then store
                        # the original one
                        scheme_id = db_url.find(":")
                        fake_url = (
                            db_url[:scheme_id].replace("_", "-") + db_url[scheme_id:]
                        )
                        TypeAdapter(SqlalchemyDsn).validate_python(fake_url)
                        db_urls[db_name] = db_url

                    else:
                        db_urls[db_name] = str(
                            TypeAdapter(SqlalchemyDsn).validate_python(db_url)
                        )
                except Exception:
                    logger.error("Error loading URL for %s", db_name)
                    raise
        return db_urls

    @classmethod
    async def post_create(cls, conn: AsyncConnection) -> None:
        """Execute actions after the schema has been created."""
        return

    @classmethod
    def transaction(cls) -> Self:
        raise NotImplementedError("This should never be called")

    @property
    def engine(self) -> AsyncEngine:
        """The engine to use for database operations.

        It is normally not necessary to use the engine directly, unless you are
        doing something special, like writing a test fixture that gives you a db.

        Requires that the engine_context has been entered.
        """
        assert self._engine is not None, "engine_context must be entered"
        return self._engine

    @contextlib.asynccontextmanager
    async def engine_context(self) -> AsyncIterator[None]:
        """Context manager to manage the engine lifecycle.

        This is called once at the application startup (see ``lifetime_functions``).
        """
        assert self._engine is None, "engine_context cannot be nested"

        # Set the pool_recycle to 30mn
        # That should prevent the problem of MySQL expiring connection
        # after 60mn by default
        engine = create_async_engine(self._db_url, pool_recycle=60 * 30)
        self._engine = engine
        try:
            yield
        finally:
            self._engine = None
            await engine.dispose()

    @property
    def conn(self) -> AsyncConnection:
        if self._conn.get() is None:
            raise RuntimeError(f"{self.__class__} was used before entering")
        return cast(AsyncConnection, self._conn.get())

    async def __aenter__(self) -> Self:
        """Create a connection.

        This is called by the Dependency mechanism (see ``db_transaction``).
        It will create a new connection/transaction for each route call.
        """
        assert self._conn.get() is None, "BaseSQLDB context cannot be nested"
        try:
            self._conn.set(await self.engine.connect().__aenter__())
        except Exception as e:
            raise SQLDBUnavailableError(
                f"Cannot connect to {self.__class__.__name__}"
            ) from e

        return self

    async def __aexit__(self, exc_type, exc, tb):
        """This is called when exiting a route.

        If there was no exception, the changes in the DB are committed.
        Otherwise, they are rolled back.
        """
        if exc_type is None:
            await self._conn.get().commit()
        await self._conn.get().__aexit__(exc_type, exc, tb)
        self._conn.set(None)

    async def ping(self):
        """Check whether the connection to the DB is still working.

        We could enable ``pre_ping`` on the engine, but it would then run before
        every query.
        """
        try:
            await self.conn.scalar(select(1))
        except OperationalError as e:
            raise SQLDBUnavailableError("Cannot ping the DB") from e

    async def _search(
        self,
        table: Any,
        parameters: list[str] | None,
        search: list[SearchSpec],
        sorts: list[SortSpec],
        *,
        distinct: bool = False,
        per_page: int = 100,
        page: int | None = None,
    ) -> tuple[int, list[dict[str, Any]]]:
        """Search for elements in a table."""
        # Find which columns to select
        columns = _get_columns(table.__table__, parameters)

        stmt = select(*columns)

        stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
        stmt = apply_sort_constraints(table.__table__.columns.__getitem__, stmt, sorts)

        if distinct:
            stmt = stmt.distinct()

        # Calculate total count before applying pagination
        total_count_subquery = stmt.alias()
        total_count_stmt = select(func.count()).select_from(total_count_subquery)
        total = (await self.conn.execute(total_count_stmt)).scalar_one()

        # Apply pagination
        if page is not None:
            if page < 1:
                raise InvalidQueryError("Page must be a positive integer")
            if per_page < 1:
                raise InvalidQueryError("Per page must be a positive integer")
            stmt = stmt.offset((page - 1) * per_page).limit(per_page)

        # Execute the query
        return total, [
            dict(row._mapping) async for row in (await self.conn.stream(stmt))
        ]

    async def _summary(
        self, table: Any, group_by: list[str], search: list[SearchSpec]
    ) -> list[dict[str, str | int]]:
        """Get a summary of the elements of a table."""
        columns = _get_columns(table.__table__, group_by)

        pk_columns = list(table.__table__.primary_key.columns)
        if not pk_columns:
            raise ValueError(
                "Model has no primary key and no count_column was provided."
            )
        count_col = pk_columns[0]

        stmt = select(*columns, func.count(count_col).label("count"))
        stmt = apply_search_filters(table.__table__.columns.__getitem__, stmt, search)
        stmt = stmt.group_by(*columns)

        # Execute the query
        return [
            dict(row._mapping)
            async for row in (await self.conn.stream(stmt))
            if row.count > 0  # type: ignore
        ]
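`_search` counts the matching rows before it applies pagination, then converts the 1-based `page` into an SQL `OFFSET` of `(page - 1) * per_page` with a `LIMIT` of `per_page`. The sketch below reproduces that arithmetic on an in-memory list; `paginate` is a hypothetical helper, and it raises `ValueError` where `_search` raises `InvalidQueryError`.

```python
from __future__ import annotations


def paginate(rows: list, *, page: int | None, per_page: int = 100) -> tuple[int, list]:
    # The total is taken before pagination, like the COUNT(*) over the subquery
    total = len(rows)
    if page is None:
        # No page requested: every row is returned and per_page is ignored
        return total, rows
    if page < 1:
        raise ValueError("Page must be a positive integer")
    if per_page < 1:
        raise ValueError("Per page must be a positive integer")
    offset = (page - 1) * per_page  # OFFSET
    return total, rows[offset : offset + per_page]  # LIMIT per_page
```

For example, `paginate(list(range(250)), page=3, per_page=100)` returns a total of 250 together with only the last 50 rows, so a client can compute the page count from the first response.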
Attributes
metadata instance-attribute
engine property

The engine to use for database operations.

It is normally not necessary to use the engine directly, unless you are doing something special, like writing a test fixture that gives you a db.

Requires that the engine_context has been entered.

conn property
Functions
available_implementations(db_name) classmethod

Return the available implementations of the DB in reverse priority order.

available_urls() classmethod

Return a dict of available database urls.

The list of available URLs is determined by environment variables of the form DIRACX_DB_URL_{DB_NAME}, where DB_NAME is the upper-cased entry-point name.
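The naming rule and the underscore workaround in `available_urls` can be tried in isolation. The sketch below mirrors both without pydantic: `env_var_name`, `validation_url` and `discover_urls` are hypothetical helpers, the database names and URLs are illustrative, and no actual URL validation is performed (the real method passes the URL to `TypeAdapter(SqlalchemyDsn)`).

```python
from __future__ import annotations

from collections.abc import Iterable, Mapping


def env_var_name(db_name: str) -> str:
    # Same rule as available_urls: fixed prefix plus the upper-cased entry-point name
    return f"DIRACX_DB_URL_{db_name.upper()}"


def validation_url(db_url: str) -> str:
    """Return the URL that would be validated in place of db_url.

    pydantic rejects '_' in a URL scheme, so only the scheme has its
    underscores replaced with '-'; the original URL is what gets stored.
    """
    scheme, sep, rest = db_url.partition(":")
    if "_" not in scheme:
        return db_url
    return scheme.replace("_", "-") + sep + rest


def discover_urls(db_names: Iterable[str], environ: Mapping[str, str]) -> dict[str, str]:
    # Databases without a matching environment variable are simply skipped
    return {
        name: environ[env_var_name(name)]
        for name in db_names
        if env_var_name(name) in environ
    }
```

Passing `os.environ` as `environ` would approximate the real lookup; taking a mapping argument keeps the sketch testable.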

post_create(conn) async classmethod

Execute actions after the schema has been created.

transaction() classmethod
engine_context() async

Context manager to manage the engine lifecycle.

This is called once at the application startup (see lifetime_functions).

ping() async

Check whether the connection to the DB is still working.

We could enable pre_ping on the engine, but it would then run before every query.


Functions

find_time_resolution(value)

Source code in diracx-db/src/diracx/db/sql/utils/base.py
def find_time_resolution(value):
    if isinstance(value, datetime):
        return None, value
    if match := re.fullmatch(
        r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{1,6}Z?)?)?)?)?)?)?", value
    ):
        if match.group(6):
            precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6"
        elif match.group(5):
            precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5"
        elif match.group(3):
            precision, pattern = "HOUR", r"\1-\2-\3 \4"
        elif match.group(2):
            precision, pattern = "DAY", r"\1-\2-\3"
        elif match.group(1):
            precision, pattern = "MONTH", r"\1-\2"
        else:
            precision, pattern = "YEAR", r"\1"
        return (
            precision,
            re.sub(
                r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{1,6})?Z?$",
                pattern,
                value,
            ),
        )

    raise InvalidQueryError(f"Cannot parse {value=}")
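`find_time_resolution` has no docstring, so its precision-detection step is restated below as a standalone function. It reuses the same regular expression: each nested group adds one more date or time component, and the deepest group that matched decides the precision. `precision_of` is a hypothetical name, and it raises `ValueError` where the original raises `InvalidQueryError`.

```python
import re

# Same pattern as find_time_resolution; each nested group adds one component
_PARTIAL_TS = re.compile(
    r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{1,6}Z?)?)?)?)?)?)?"
)


def precision_of(value: str) -> str:
    match = _PARTIAL_TS.fullmatch(value)
    if match is None:
        raise ValueError(f"Cannot parse {value=}")
    # Check the deepest group first; the first one that matched wins
    for group, precision in (
        (6, "SECOND"),
        (5, "MINUTE"),
        (3, "HOUR"),
        (2, "DAY"),
        (1, "MONTH"),
    ):
        if match.group(group):
            return precision
    return "YEAR"
```

A precision other than `None` is what makes `apply_search_filters` wrap the column in `date_trunc`; a `datetime` value skips this step because `find_time_resolution` returns `None` as its resolution.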

apply_search_filters(column_mapping, stmt, search)

Source code in diracx-db/src/diracx/db/sql/utils/base.py
def apply_search_filters(column_mapping, stmt, search):
    for query in search:
        try:
            column = column_mapping(query["parameter"])
        except KeyError as e:
            raise InvalidQueryError(f"Unknown column {query['parameter']}") from e

        if isinstance(column.type, (DateTime, SmarterDateTime)):
            if "value" in query and isinstance(query["value"], str):
                resolution, value = find_time_resolution(query["value"])
                if resolution:
                    column = date_trunc(column, time_resolution=resolution)
                query["value"] = value

            if query.get("values"):
                resolutions, values = zip(
                    *map(find_time_resolution, query.get("values"))
                )
                if len(set(resolutions)) != 1:
                    raise InvalidQueryError(
                        f"Cannot mix different time resolutions in {query=}"
                    )
                if resolution := resolutions[0]:
                    column = date_trunc(column, time_resolution=resolution)
                query["values"] = values

        if query["operator"] == "eq":
            expr = column == query["value"]
        elif query["operator"] == "neq":
            expr = column != query["value"]
        elif query["operator"] == "gt":
            expr = column > query["value"]
        elif query["operator"] == "lt":
            expr = column < query["value"]
        elif query["operator"] == "in":
            expr = column.in_(query["values"])
        elif query["operator"] == "not in":
            expr = column.notin_(query["values"])
        elif query["operator"] == "like":
            expr = column.like(query["value"])
        elif query["operator"] == "ilike":
            expr = column.ilike(query["value"])
        elif query["operator"] == "not like":
            expr = column.not_like(query["value"])
        elif query["operator"] == "regex":
            # We check the regex validity here
            try:
                re.compile(query["value"])
            except re.error as e:
                raise InvalidQueryError(f"Invalid regex {query['value']}") from e
            expr = column.regexp_match(query["value"])
        else:
            raise InvalidQueryError(f"Unknown filter {query=}")
        stmt = stmt.where(expr)
    return stmt
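The shape of a search specification and the way several filters combine can be illustrated without a database. Each entry names a `parameter`, an `operator`, and either a `value` or a list of `values`; because every filter becomes its own `.where()` call, SQLAlchemy joins them with `AND`. The sketch below is a hypothetical in-memory analogue (`row_matches`, with made-up column names) covering the comparison, membership and regex operators. It leaves out `like`/`ilike`/`not like` and the date truncation step, and `re.search` only approximates each backend's regex semantics.

```python
import operator
import re
from typing import Any

# Scalar comparisons, keyed by the operator names apply_search_filters accepts
_SCALAR_OPS = {"eq": operator.eq, "neq": operator.ne, "gt": operator.gt, "lt": operator.lt}


def row_matches(row: dict[str, Any], search: list[dict[str, Any]]) -> bool:
    for query in search:
        if query["parameter"] not in row:
            raise KeyError(f"Unknown column {query['parameter']}")
        value = row[query["parameter"]]
        op = query["operator"]
        if op in _SCALAR_OPS:
            ok = _SCALAR_OPS[op](value, query["value"])
        elif op == "in":
            ok = value in query["values"]
        elif op == "not in":
            ok = value not in query["values"]
        elif op == "regex":
            ok = re.search(query["value"], value) is not None
        else:
            raise ValueError(f"Unknown filter {query=}")
        if not ok:
            return False  # every filter must hold, as with chained .where() calls
    return True
```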

apply_sort_constraints(column_mapping, stmt, sorts)

Source code in diracx-db/src/diracx/db/sql/utils/base.py
def apply_sort_constraints(column_mapping, stmt, sorts):
    sort_columns = []
    for sort in sorts or []:
        try:
            column = column_mapping(sort["parameter"])
        except KeyError as e:
            raise InvalidQueryError(
                f"Cannot sort by {sort['parameter']}: unknown column"
            ) from e
        sorted_column = None
        if sort["direction"] == SortDirection.ASC:
            sorted_column = column.asc()
        elif sort["direction"] == SortDirection.DESC:
            sorted_column = column.desc()
        else:
            raise InvalidQueryError(f"Unknown sort {sort['direction']=}")
        sort_columns.append(sorted_column)
    if sort_columns:
        stmt = stmt.order_by(*sort_columns)
    return stmt
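`apply_sort_constraints` passes all sort columns to a single `order_by`, so the first entry is the primary ordering key and later entries only break ties. A hypothetical in-memory equivalent can rely on Python's stable sort: sorting by the last key first and the first key last produces the same order. The plain `"asc"`/`"desc"` strings stand in for `SortDirection.ASC`/`SortDirection.DESC` and are an assumption, and database-specific NULL ordering is ignored.

```python
from __future__ import annotations

from typing import Any


def sort_rows(
    rows: list[dict[str, Any]], sorts: list[dict[str, str]] | None
) -> list[dict[str, Any]]:
    result = list(rows)
    # Least significant key first; each stable pass preserves earlier tie-breaks
    for sort in reversed(sorts or []):
        direction = sort["direction"]
        if direction not in ("asc", "desc"):
            raise ValueError(f"Unknown sort {direction=}")
        result.sort(key=lambda row: row[sort["parameter"]], reverse=direction == "desc")
    return result
```

Python documents that `reverse=True` keeps the sort stable, which is what makes the multi-pass approach equivalent to one `ORDER BY` with several columns.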

uuid7_to_datetime(uuid)

Convert a UUIDv7 to a datetime.

Source code in diracx-db/src/diracx/db/sql/utils/base.py
def uuid7_to_datetime(uuid: UUID | StdUUID | str) -> datetime:
    """Convert a UUIDv7 to a datetime."""
    if isinstance(uuid, StdUUID):
        # Convert stdlib UUID to uuid_utils.UUID
        uuid = UUID(str(uuid))
    elif not isinstance(uuid, UUID):
        # Convert string or other types to uuid_utils.UUID
        uuid = UUID(uuid)
    if uuid.version != 7:
        raise ValueError(f"UUID {uuid} is not a UUIDv7")
    return datetime.fromtimestamp(uuid.timestamp / 1000.0, tz=timezone.utc)

uuid7_from_datetime(dt, *, randomize=True)

Generate a UUIDv7 corresponding to the given datetime.

If randomize is True, the standard uuid7 function is used, so the lowest 62 bits are random. If randomize is False, the result is the lowest possible UUIDv7 for the given datetime.

Source code in diracx-db/src/diracx/db/sql/utils/base.py
def uuid7_from_datetime(dt: datetime, *, randomize: bool = True) -> UUID:
    """Generate a UUIDv7 corresponding to the given datetime.

    If randomize is True, the standard uuid7 function is used resulting in the
    lowest 62-bits being random. If randomize is False, the UUIDv7 will be the
    lowest possible UUIDv7 for the given datetime.
    """
    timestamp = dt.timestamp()
    if randomize:
        uuid = uuid7(int(timestamp), int((timestamp % 1) * 1e9))
    else:
        time_high = int(timestamp * 1000) >> 16
        time_low = int(timestamp * 1000) & 0xFFFF
        uuid = UUID.from_fields((time_high, time_low, 0x7000, 0x80, 0, 0))
    return uuid
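The non-random branch of `uuid7_from_datetime` relies on the UUIDv7 layout: the top 48 bits hold the Unix timestamp in milliseconds, followed by the version nibble `7` and the RFC 4122 variant bits. The sketch below builds the same lowest UUIDv7 with the standard-library `uuid` module rather than `uuid_utils`, and reads the timestamp back from the top 48 bits. The helper names are hypothetical, and the millisecond count is computed with exact `timedelta` arithmetic instead of the float multiplication used above.

```python
import uuid
from datetime import datetime, timedelta, timezone

EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def lowest_uuid7(dt: datetime) -> uuid.UUID:
    # dt must be timezone-aware; the result is floored to whole milliseconds
    ms = (dt - EPOCH) // timedelta(milliseconds=1)
    # Same fields as uuid7_from_datetime(randomize=False): the upper 32 bits of
    # the 48-bit timestamp fill time_low, the lower 16 bits fill time_mid,
    # 0x7000 sets version 7 with zero random bits, 0x80 sets the variant
    return uuid.UUID(fields=(ms >> 16, ms & 0xFFFF, 0x7000, 0x80, 0, 0))


def uuid7_datetime(u: uuid.UUID) -> datetime:
    if u.version != 7:
        raise ValueError(f"UUID {u} is not a UUIDv7")
    ms = u.int >> 80  # the first 48 of the 128 bits
    return EPOCH + timedelta(milliseconds=ms)
```

Since any UUIDv7 created in the same millisecond or later has a timestamp prefix at least as large, a value like this can act as a lower bound when comparing time-ordered UUID columns.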

Functions

functions

Classes

utcnow

Bases: FunctionElement

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
class utcnow(expression.FunctionElement):  # noqa: N801
    type: TypeEngine = DateTime()
    inherit_cache: bool = True
Attributes
type = DateTime() class-attribute instance-attribute
inherit_cache = True class-attribute instance-attribute

date_trunc

Bases: FunctionElement

Sqlalchemy function to truncate a date to a given resolution.

Primarily used to query a date at a specific resolution, for example:

select * from table where date_trunc('day', date_column) = '2021-01-01'
select * from table where date_trunc('year', date_column) = '2021'
select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'
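On SQLite, `date_trunc` compiles to `strftime` with one of the patterns listed in `sqlite_date_trunc`, so the truncated column is a formatted string compared against the literal. Python's `datetime.strftime` understands the same directives, which makes it easy to preview what each resolution keeps. `truncate_like_sqlite` is a hypothetical helper for illustration; MySQL uses `%i` rather than `%M` for minutes, and PostgreSQL's `date_trunc` returns a timestamp instead of a string.

```python
from datetime import datetime

# Patterns copied from sqlite_date_trunc
SQLITE_PATTERNS = {
    "SECOND": "%Y-%m-%d %H:%M:%S",
    "MINUTE": "%Y-%m-%d %H:%M",
    "HOUR": "%Y-%m-%d %H",
    "DAY": "%Y-%m-%d",
    "MONTH": "%Y-%m",
    "YEAR": "%Y",
}


def truncate_like_sqlite(dt: datetime, resolution: str) -> str:
    # Format the datetime the way SQLite's strftime would for this resolution
    return dt.strftime(SQLITE_PATTERNS[resolution])
```

For `datetime(2021, 1, 1, 12, 34, 56)`, the `DAY` resolution gives `'2021-01-01'` and `MINUTE` gives `'2021-01-01 12:34'`, the same shape that `find_time_resolution` normalizes search values into.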
Source code in diracx-db/src/diracx/db/sql/utils/functions.py
class date_trunc(expression.FunctionElement):  # noqa: N801
    """Sqlalchemy function to truncate a date to a given resolution.

    Primarily used to query a date at a specific resolution, for example:

        select * from table where date_trunc('day', date_column) = '2021-01-01'
        select * from table where date_trunc('year', date_column) = '2021'
        select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00'
    """

    type = DateTime()
    # Cache does not work as intended with time resolution values, so we disable it
    inherit_cache = False

    def __init__(self, *args, time_resolution, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self._time_resolution = time_resolution
Attributes
type = DateTime() class-attribute instance-attribute
inherit_cache = False class-attribute instance-attribute

days_since

Bases: FunctionElement

Sqlalchemy function to get the number of days since a given date.

Primarily used to query for dates by their age in days, for example:

select * from table where days_since(date_column) = 0
select * from table where days_since(date_column) = 1
Source code in diracx-db/src/diracx/db/sql/utils/functions.py
class days_since(expression.FunctionElement):  # noqa: N801
    """Sqlalchemy function to get the number of days since a given date.

    Primarily used to query for dates by their age in days, for example:

        select * from table where days_since(date_column) = 0
        select * from table where days_since(date_column) = 1
    """

    type = DateTime()
    inherit_cache = False

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
Attributes
type = DateTime() class-attribute instance-attribute
inherit_cache = False class-attribute instance-attribute

Functions

pg_utcnow(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(utcnow, "postgresql")
def pg_utcnow(element, compiler, **kw) -> str:
    return "TIMEZONE('utc', CURRENT_TIMESTAMP)"

ms_utcnow(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(utcnow, "mssql")
def ms_utcnow(element, compiler, **kw) -> str:
    return "GETUTCDATE()"

mysql_utcnow(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(utcnow, "mysql")
def mysql_utcnow(element, compiler, **kw) -> str:
    return "(UTC_TIMESTAMP)"

sqlite_utcnow(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(utcnow, "sqlite")
def sqlite_utcnow(element, compiler, **kw) -> str:
    return "DATETIME('now')"

pg_date_trunc(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(date_trunc, "postgresql")
def pg_date_trunc(element, compiler, **kw):
    res = {
        "SECOND": "second",
        "MINUTE": "minute",
        "HOUR": "hour",
        "DAY": "day",
        "MONTH": "month",
        "YEAR": "year",
    }[element._time_resolution]
    return f"date_trunc('{res}', {compiler.process(element.clauses)})"

mysql_date_trunc(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(date_trunc, "mysql")
def mysql_date_trunc(element, compiler, **kw):
    pattern = {
        "SECOND": "%Y-%m-%d %H:%i:%S",
        "MINUTE": "%Y-%m-%d %H:%i",
        "HOUR": "%Y-%m-%d %H",
        "DAY": "%Y-%m-%d",
        "MONTH": "%Y-%m",
        "YEAR": "%Y",
    }[element._time_resolution]

    (dt_col,) = list(element.clauses)
    return compiler.process(func.date_format(dt_col, pattern))

sqlite_date_trunc(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(date_trunc, "sqlite")
def sqlite_date_trunc(element, compiler, **kw):
    pattern = {
        "SECOND": "%Y-%m-%d %H:%M:%S",
        "MINUTE": "%Y-%m-%d %H:%M",
        "HOUR": "%Y-%m-%d %H",
        "DAY": "%Y-%m-%d",
        "MONTH": "%Y-%m",
        "YEAR": "%Y",
    }[element._time_resolution]
    (dt_col,) = list(element.clauses)
    return compiler.process(
        func.strftime(
            pattern,
            dt_col,
        )
    )

pg_days_since(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(days_since, "postgresql")
def pg_days_since(element, compiler, **kw):
    return f"EXTRACT(DAY FROM (now() - {compiler.process(element.clauses)}))"

mysql_days_since(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(days_since, "mysql")
def mysql_days_since(element, compiler, **kw):
    return f"DATEDIFF(NOW(), {compiler.process(element.clauses)})"

sqlite_days_since(element, compiler, **kw)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
@compiles(days_since, "sqlite")
def sqlite_days_since(element, compiler, **kw):
    return f"julianday('now') - julianday({compiler.process(element.clauses)})"

substract_date(**kwargs)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
def substract_date(**kwargs: float) -> datetime:
    return datetime.now(tz=timezone.utc) - timedelta(**kwargs)

hash(code)

Source code in diracx-db/src/diracx/db/sql/utils/functions.py
def hash(code: str):
    return hashlib.sha256(code.encode()).hexdigest()

Types

types

Attributes

Column = partial(RawColumn, nullable=False) module-attribute

NullColumn = partial(RawColumn, nullable=True) module-attribute

DateNowColumn = partial(Column, type_=(DateTime(timezone=True)), server_default=(utcnow())) module-attribute

Classes

EnumBackedBool

Bases: TypeDecorator

Maps an EnumBackedBool() column to True/False in Python.

Source code in diracx-db/src/diracx/db/sql/utils/types.py
class EnumBackedBool(types.TypeDecorator):
    """Maps an ``EnumBackedBool()`` column to True/False in Python."""

    impl = types.Enum("True", "False", name="enum_backed_bool")
    cache_ok = True

    def process_bind_param(self, value, dialect) -> str:
        if value is True:
            return "True"
        elif value is False:
            return "False"
        else:
            raise NotImplementedError(value, dialect)

    def process_result_value(self, value, dialect) -> bool:
        if value == "True":
            return True
        elif value == "False":
            return False
        else:
            raise NotImplementedError(f"Unknown {value=}")
Attributes
impl = types.Enum('True', 'False', name='enum_backed_bool') class-attribute instance-attribute
cache_ok = True class-attribute instance-attribute
Functions
process_bind_param(value, dialect)
process_result_value(value, dialect)

SmarterDateTime

Bases: TypeDecorator

A DateTime type that also accepts ISO8601 strings.

Handles converting timezone-aware datetime objects into naive form and back when needed.
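For dialects where the column is stored naive, the round trip amounts to converting to the storage timezone and dropping `tzinfo` on the way in, then re-attaching it and converting to the requested timezone on the way out. The sketch below is one way to express that with the standard library only; `to_storage`, `from_storage` and `CEST` are hypothetical, it uses fixed-offset `timezone` objects where the class uses `ZoneInfo`, and it is not the class's actual conversion code.

```python
from __future__ import annotations

from datetime import datetime, timedelta, timezone, tzinfo

CEST = timezone(timedelta(hours=2))  # example fixed offset for illustration


def to_storage(value: datetime | str, stored_tz: tzinfo) -> datetime:
    if isinstance(value, str):
        value = datetime.fromisoformat(value)  # ISO 8601 strings are accepted too
    if value.tzinfo is None:
        # Mirrors process_bind_param: naive input is ambiguous, so refuse it
        raise ValueError(f"Provided timestamp {value=} has no tzinfo")
    # Express the instant in the storage timezone, then drop tzinfo for a naive column
    return value.astimezone(stored_tz).replace(tzinfo=None)


def from_storage(stored: datetime, stored_tz: tzinfo, returned_tz: tzinfo) -> datetime:
    # Re-attach the storage timezone, then convert to the timezone the caller wants
    return stored.replace(tzinfo=stored_tz).astimezone(returned_tz)
```

With `stored_tz=timezone.utc`, a 14:00 `CEST` timestamp is written as a naive 12:00 and read back as the same instant, in whichever timezone `returned_tz` names.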

Source code in diracx-db/src/diracx/db/sql/utils/types.py
class SmarterDateTime(types.TypeDecorator):
    """A DateTime type that also accepts ISO8601 strings.

    Also converts timezone-aware datetime objects to naive form for storage,
    and back again on retrieval, when the dialect stores naive timestamps.

    """

    impl = DateTime()
    cache_ok = True

    def __init__(
        self,
        stored_tz: ZoneInfo | None = ZoneInfo("UTC"),
        returned_tz: ZoneInfo = ZoneInfo("UTC"),
        stored_naive_sqlite=True,
        stored_naive_mysql=True,
        stored_naive_postgres=False,  # Forces timezone-awareness
    ):
        self._stored_naive_dialect = {
            "sqlite": stored_naive_sqlite,
            "mysql": stored_naive_mysql,
            # SQLAlchemy's PostgreSQL dialect reports dialect.name == "postgresql"
            "postgresql": stored_naive_postgres,
        }
        self._stored_tz: ZoneInfo | None = stored_tz  # None = Local timezone
        self._returned_tz: ZoneInfo = returned_tz

    def _stored_naive(self, dialect):
        if dialect.name not in self._stored_naive_dialect:
            raise NotImplementedError(dialect.name)
        return self._stored_naive_dialect.get(dialect.name)

    def process_bind_param(self, value, dialect):
        if value is None:
            return None

        if isinstance(value, str):
            try:
                value: datetime = datetime.fromisoformat(value)
            except ValueError as err:
                raise ValueError(f"Unable to parse datetime string: {value}") from err

        if not isinstance(value, datetime):
            raise ValueError(f"Expected datetime or ISO8601 string, but got {value!r}")

        if not value.tzinfo:
            raise ValueError(
                f"Provided timestamp {value=} has no tzinfo -"
                " this is problematic and may cause inconsistencies in stored timestamps.\n"
                " Please always work with tz-aware datetimes / attach tzinfo to your datetime objects:"
                " e.g. datetime.now(tz=timezone.utc) or use datetime_obj.astimezone() with no arguments if you need to "
                "attach the local timezone to a local naive timestamp."
            )

        # Check that we need to convert the timezone to match self._stored_tz timezone:
        if self._stored_naive(dialect):
            # if self._stored_tz is None, we use our local/system timezone.
            stored_tz = self._stored_tz

            # astimezone converts to the stored timezone (local timezone if None)
            # replace strips the TZ info --> naive datetime object
            value = value.astimezone(tz=stored_tz).replace(tzinfo=None)

        return value

    def process_result_value(self, value, dialect):
        if value is None:
            return None
        if not isinstance(value, datetime):
            raise NotImplementedError(f"{value=} not a datetime object")

        if self._stored_naive(dialect):
            # Here we add back the tzinfo to the naive timestamp
            # from the DB to make it aware again.
            if value.tzinfo is None:
                # we are definitely given a naive timestamp, so handle it.
                # add back the timezone info if stored_tz is set
                if self._stored_tz:
                    value = value.replace(tzinfo=self._stored_tz)
                else:
                    # if stored as a local time, add back the system timezone info...
                    value = value.astimezone()
            else:
                raise ValueError(
                    f"stored_naive is True for {dialect.name=}, but the database engine returned "
                    "a tz-aware datetime. You need to check the SQLAlchemy model is consistent with the DB schema."
                )

        # finally, convert the datetime according to the "returned_tz"
        value = value.astimezone(self._returned_tz)

        # phew...
        return value
Attributes
impl = DateTime() class-attribute instance-attribute
cache_ok = True class-attribute instance-attribute
Functions
process_bind_param(value, dialect)
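The input validation at the top of `process_bind_param` can be exercised in isolation. `normalize_input` below is a hypothetical helper that copies those checks (None passthrough, ISO 8601 parsing, rejection of naive datetimes) without any SQLAlchemy dependency; it stops before the timezone conversion step.

```python
from datetime import datetime, timezone


def normalize_input(value):
    """Hypothetical helper mirroring the checks in process_bind_param."""
    if value is None:
        return None
    if isinstance(value, str):
        try:
            value = datetime.fromisoformat(value)
        except ValueError as err:
            raise ValueError(f"Unable to parse datetime string: {value}") from err
    if not isinstance(value, datetime):
        raise ValueError(f"Expected datetime or ISO8601 string, but got {value!r}")
    if value.tzinfo is None:
        # Naive timestamps are refused instead of being guessed at.
        raise ValueError(f"Provided timestamp {value=} has no tzinfo")
    return value


print(normalize_input("2024-05-01T12:30:00+00:00"))  # 2024-05-01 12:30:00+00:00
print(normalize_input(datetime(2024, 5, 1, tzinfo=timezone.utc)))  # 2024-05-01 00:00:00+00:00

# A naive ISO string, an unparsable string and a non-datetime are all rejected.
for bad in ("2024-05-01T12:30:00", "not a date", 1714566600):
    try:
        normalize_input(bad)
    except ValueError as err:
        print(type(bad).__name__, "rejected:", err)
```

On Python versions before 3.11, `datetime.fromisoformat` does not accept a trailing `Z`, so a string such as `2024-05-01T12:30:00Z` would fail to parse there; writing the offset as `+00:00` works from Python 3.7 onwards.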
process_result_value(value, dialect)
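How the result side depends on the dialect can be sketched the same way. `process_result` and the `STORED_NAIVE` table are hypothetical simplifications of `_stored_naive` and `process_result_value` (the non-datetime check is omitted). The dictionary keys follow SQLAlchemy's dialect names, in which the PostgreSQL dialect is called `"postgresql"`, and fixed-offset zones again stand in for `ZoneInfo`.

```python
from datetime import datetime, timedelta, timezone

UTC = timezone.utc
PLUS_TWO = timezone(timedelta(hours=2))

# Per-dialect "stored naive" flags, mirroring SmarterDateTime's defaults.
STORED_NAIVE = {"sqlite": True, "mysql": True, "postgresql": False}


def process_result(value, dialect_name, stored_tz=UTC, returned_tz=UTC):
    """Hypothetical simplification of SmarterDateTime.process_result_value."""
    if value is None:
        return None
    if dialect_name not in STORED_NAIVE:
        raise NotImplementedError(dialect_name)
    if STORED_NAIVE[dialect_name]:
        if value.tzinfo is not None:
            # Model and schema disagree: naive storage was expected.
            raise ValueError("naive storage expected, got an aware datetime")
        # Re-attach the storage zone (system local time if stored_tz is None).
        value = value.replace(tzinfo=stored_tz) if stored_tz else value.astimezone()
    # Finally express the instant in the requested zone.
    return value.astimezone(returned_tz)


naive = datetime(2024, 5, 1, 12, 30)
aware = naive.replace(tzinfo=UTC)
print(process_result(naive, "sqlite", returned_tz=PLUS_TWO))  # 2024-05-01 14:30:00+02:00
print(process_result(aware, "postgresql", returned_tz=PLUS_TWO))  # 2024-05-01 14:30:00+02:00
```

Both calls describe the same instant: on SQLite the zone is re-attached from configuration, while on PostgreSQL the driver already returns an aware value that only needs converting.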

Functions

EnumColumn(name, enum_type, **kwargs)

Source code in diracx-db/src/diracx/db/sql/utils/types.py
def EnumColumn(name, enum_type, **kwargs):  # noqa: N802
    # Non-native enum: stored as a VARCHAR(16) string rather than a database ENUM type.
    return Column(name, Enum(enum_type, native_enum=False, length=16), **kwargs)
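A rough standard-library analogue of what an EnumColumn stores, assuming a hypothetical `Color` enum and `items` table. With `native_enum=False`, SQLAlchemy's `Enum` uses a plain VARCHAR (here of length 16) instead of a database ENUM type and, by default, persists the enum member names rather than their values; the sketch reproduces that by hand with `sqlite3`.

```python
import enum
import sqlite3


class Color(enum.Enum):  # hypothetical enum, for illustration only
    RED = "r"
    GREEN = "g"


conn = sqlite3.connect(":memory:")
# VARCHAR(16) plays the role of Enum(..., native_enum=False, length=16).
# SQLite does not enforce the declared VARCHAR length.
conn.execute("CREATE TABLE items (id INTEGER PRIMARY KEY, color VARCHAR(16))")

# Bind side: persist the member *name*, as SQLAlchemy's Enum does by default.
conn.execute("INSERT INTO items (color) VALUES (?)", (Color.GREEN.name,))

# Result side: look the stored name back up on the enum class.
(raw,) = conn.execute("SELECT color FROM items").fetchone()
color = Color[raw]
print(raw, color)  # GREEN Color.GREEN
```

Because names are stored, renaming an enum member changes what ends up in the database, while changing a member's value does not; a member name longer than 16 characters would also exceed the declared length on backends that enforce it.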