Utilities

Core utility functions and helpers.

utils

Attributes

EXPIRES_GRACE_SECONDS = 15 module-attribute

T = TypeVar('T') module-attribute

Classes

TwoLevelCache

A two-level caching system with soft and hard time-to-live (TTL) expiration.

This cache implements a two-tier caching mechanism that allows cached values to be refreshed in the background. Entries younger than the soft TTL are returned directly; entries past the soft TTL but within the hard TTL are still returned while a refresh runs in a worker thread, which reduces latency while keeping data reasonably fresh.

Attributes:

Name Type Description
soft_cache TTLCache

A cache with a shorter TTL for quick access.

hard_cache TTLCache

A cache with a longer TTL as a fallback.

locks defaultdict

Thread-safe locks for each cache key.

futures dict

Stores ongoing asynchronous population tasks.

pool ThreadPoolExecutor

Thread pool for executing cache population tasks.

Parameters:

Name Type Description Default
soft_ttl int

Time-to-live in seconds for the soft cache.

required
hard_ttl int

Time-to-live in seconds for the hard cache.

required
max_workers int

Maximum number of workers in the thread pool.

10
max_items int

Maximum number of items in the cache.

1000000
Example

>>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300)
>>> def populate_func():
...     return "cached_value"
>>> value = cache.get("key", populate_func)

Source code in diracx-core/src/diracx/core/utils.py
class TwoLevelCache:
    """A two-level caching system with soft and hard time-to-live (TTL) expiration.

    This cache implements a two-tier caching mechanism that allows cached values to
    be refreshed in the background. Entries younger than the soft TTL are returned
    directly; entries past the soft TTL but within the hard TTL are still returned
    while a refresh runs in a worker thread, which reduces latency while keeping
    data reasonably fresh.

    Attributes:
        soft_cache (TTLCache): A cache with a shorter TTL for quick access.
        hard_cache (TTLCache): A cache with a longer TTL as a fallback.
        locks (defaultdict): Thread-safe locks for each cache key.
        futures (dict): Stores ongoing asynchronous population tasks.
        pool (ThreadPoolExecutor): Thread pool for executing cache population tasks.

    Args:
        soft_ttl (int): Time-to-live in seconds for the soft cache.
        hard_ttl (int): Time-to-live in seconds for the hard cache.
        max_workers (int): Maximum number of workers in the thread pool.
        max_items (int): Maximum number of items in the cache.

    Example:
        >>> cache = TwoLevelCache(soft_ttl=60, hard_ttl=300)
        >>> def populate_func():
        ...     return "cached_value"
        >>> value = cache.get("key", populate_func)

    """

    def __init__(
        self,
        soft_ttl: int,
        hard_ttl: int,
        *,
        max_workers: int = 10,
        max_items: int = 1_000_000,
    ):
        """Initialize the TwoLevelCache with specified TTLs."""
        self.soft_cache: Cache = TTLCache(max_items, soft_ttl)
        self.hard_cache: Cache = TTLCache(max_items, hard_ttl)
        self.locks: defaultdict[str, threading.Lock] = defaultdict(threading.Lock)
        self.futures: dict[str, Future] = {}
        self.pool = ThreadPoolExecutor(max_workers=max_workers)

    def get(self, key: str, populate_func: Callable[[], T], blocking: bool = True) -> T:
        """Retrieve a value from the cache, populating it if necessary.

        This method first checks the soft cache for the key. If not found,
        it checks the hard cache while initiating a background refresh.
        If the key is not in either cache, it waits for the populate_func
        to complete and stores the result in both caches.

        Locks are used to ensure there is never more than one concurrent
        population task for a given key.

        Args:
            key (str): The cache key to retrieve or populate.
            populate_func (Callable[[], Any]): A function to call to populate the cache
                                               if the key is not found.

            blocking (bool): If True, wait for the cache to be populated if the key is not
                            found. If False, raise NotReadyError if the key is not ready.

        Returns:
            Any: The cached value associated with the key.

        Note:
            This method is thread-safe and handles concurrent requests for the same key.

        """
        if result := self.soft_cache.get(key):
            return result
        if self.locks[key].acquire(blocking=blocking):
            try:
                if key not in self.futures:
                    self.futures[key] = self.pool.submit(self._work, key, populate_func)
                if result := self.hard_cache.get(key):
                    # The soft cache will be updated by _work so we can fill the soft
                    # cache to avoid later requests needing to acquire the lock.
                    self.soft_cache[key] = result
                    return result
                future = self.futures[key]
            finally:
                self.locks[key].release()
            if blocking:
                # It is critical that ``future`` is waited for outside of the lock
                # as _work acquires the lock before filling the caches. This also
                # means we can guarantee that the future has not yet been removed
                # from the futures dict.
                wait([future])
                return self.hard_cache[key]

        # If the lock is not acquired we're in a non-blocking mode, try to get the
        # value from the hard cache. If it's not there, raise NotReadyError.
        if result := self.hard_cache.get(key):
            return result
        raise NotReadyError(f"Cache key {key} is not ready yet.")

    def _work(self, key: str, populate_func: Callable[[], Any]) -> None:
        """Internal method to execute the populate_func and update caches.

        This method is intended to be run in a separate thread. It calls the
        populate_func, stores the result in both caches, and cleans up the
        associated future.

        Args:
            key (str): The cache key to populate.
            populate_func (Callable[[], Any]): The function to call to get the value.

        Note:
            This method is not intended to be called directly by users of the class.

        """
        result = populate_func()
        with self.locks[key]:
            self.futures.pop(key)
            self.hard_cache[key] = result
            self.soft_cache[key] = result

    def clear(self):
        """Clear all caches and reset the thread pool."""
        self.pool.shutdown(wait=True)
        self.pool = ThreadPoolExecutor(max_workers=self.pool._max_workers)
        self.soft_cache.clear()
        self.hard_cache.clear()
        self.futures.clear()
        self.locks.clear()
Attributes
soft_cache = TTLCache(max_items, soft_ttl) instance-attribute
hard_cache = TTLCache(max_items, hard_ttl) instance-attribute
locks = defaultdict(threading.Lock) instance-attribute
futures = {} instance-attribute
pool = ThreadPoolExecutor(max_workers=max_workers) instance-attribute
Functions
get(key, populate_func, blocking=True)

Retrieve a value from the cache, populating it if necessary.

This method first checks the soft cache for the key. If not found, it checks the hard cache while initiating a background refresh. If the key is not in either cache, it waits for the populate_func to complete and stores the result in both caches.

Locks are used to ensure there is never more than one concurrent population task for a given key.

Parameters:

Name Type Description Default
key str

The cache key to retrieve or populate.

required
populate_func Callable[[], Any]

A function to call to populate the cache if the key is not found.

required
blocking bool

If True, wait for the cache to be populated if the key is not found. If False, raise NotReadyError if the key is not ready.

True

Returns:

Name Type Description
Any T

The cached value associated with the key.

Note

This method is thread-safe and handles concurrent requests for the same key.

Source code in diracx-core/src/diracx/core/utils.py
def get(self, key: str, populate_func: Callable[[], T], blocking: bool = True) -> T:
    """Retrieve a value from the cache, populating it if necessary.

    This method first checks the soft cache for the key. If not found,
    it checks the hard cache while initiating a background refresh.
    If the key is not in either cache, it waits for the populate_func
    to complete and stores the result in both caches.

    Locks are used to ensure there is never more than one concurrent
    population task for a given key.

    Args:
        key (str): The cache key to retrieve or populate.
        populate_func (Callable[[], Any]): A function to call to populate the cache
                                           if the key is not found.

        blocking (bool): If True, wait for the cache to be populated if the key is not
                        found. If False, raise NotReadyError if the key is not ready.

    Returns:
        Any: The cached value associated with the key.

    Note:
        This method is thread-safe and handles concurrent requests for the same key.

    """
    if result := self.soft_cache.get(key):
        return result
    if self.locks[key].acquire(blocking=blocking):
        try:
            if key not in self.futures:
                self.futures[key] = self.pool.submit(self._work, key, populate_func)
            if result := self.hard_cache.get(key):
                # The soft cache will be updated by _work so we can fill the soft
                # cache to avoid later requests needing to acquire the lock.
                self.soft_cache[key] = result
                return result
            future = self.futures[key]
        finally:
            self.locks[key].release()
        if blocking:
            # It is critical that ``future`` is waited for outside of the lock
            # as _work acquires the lock before filling the caches. This also
            # means we can guarantee that the future has not yet been removed
            # from the futures dict.
            wait([future])
            return self.hard_cache[key]

    # If the lock is not acquired we're in a non-blocking mode, try to get the
    # value from the hard cache. If it's not there, raise NotReadyError.
    if result := self.hard_cache.get(key):
        return result
    raise NotReadyError(f"Cache key {key} is not ready yet.")
clear()

Clear all caches and reset the thread pool.

Source code in diracx-core/src/diracx/core/utils.py
def clear(self):
    """Clear all caches and reset the thread pool."""
    self.pool.shutdown(wait=True)
    self.pool = ThreadPoolExecutor(max_workers=self.pool._max_workers)
    self.soft_cache.clear()
    self.hard_cache.clear()
    self.futures.clear()
    self.locks.clear()
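
The soft/hard TTL behaviour of get can be illustrated with a much smaller stand-in. The sketch below is not the DiracX implementation: TinyTwoLevelCache, slow_lookup and the timing values are hypothetical names and numbers invented for this example. It uses a single lock and timestamps instead of per-key locks and TTLCache, and it omits error handling for a failing populate function.

```python
import threading
import time
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable


class TinyTwoLevelCache:
    """Simplified illustration of soft/hard TTLs; not the DiracX class."""

    def __init__(self, soft_ttl: float, hard_ttl: float) -> None:
        self.soft_ttl = soft_ttl
        self.hard_ttl = hard_ttl
        self._entries: dict[str, tuple[float, Any]] = {}  # key -> (stored_at, value)
        self._futures: dict[str, Future] = {}
        self._lock = threading.Lock()
        self._pool = ThreadPoolExecutor(max_workers=2)

    def _refresh(self, key: str, populate_func: Callable[[], Any]) -> None:
        value = populate_func()  # computed without holding the lock
        with self._lock:
            self._entries[key] = (time.monotonic(), value)
            self._futures.pop(key, None)

    def get(self, key: str, populate_func: Callable[[], Any]) -> Any:
        with self._lock:
            entry = self._entries.get(key)
            age = time.monotonic() - entry[0] if entry is not None else float("inf")
            if entry is not None and age < self.soft_ttl:
                return entry[1]  # fresh: no refresh needed
            if key not in self._futures:  # at most one refresh per key
                self._futures[key] = self._pool.submit(self._refresh, key, populate_func)
            future = self._futures[key]
            if entry is not None and age < self.hard_ttl:
                return entry[1]  # stale but usable: refresh runs in the background
        future.result()  # wait outside the lock, since _refresh needs the lock
        with self._lock:
            return self._entries[key][1]


calls = 0


def slow_lookup() -> str:
    global calls
    calls += 1
    return f"value-{calls}"


cache = TinyTwoLevelCache(soft_ttl=0.2, hard_ttl=5.0)
first = cache.get("key", slow_lookup)   # miss: blocks until populated
second = cache.get("key", slow_lookup)  # within the soft TTL: no new call
time.sleep(0.3)
stale = cache.get("key", slow_lookup)   # past the soft TTL: old value, refresh starts
```

As in the source above, the wait on the future happens after the lock is released; waiting while holding it would deadlock, because the worker needs the same lock to store its result.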

Functions

recursive_merge(base, override)

recursive_merge(base: T, override: None) -> T
recursive_merge(base: None, override: T) -> T
recursive_merge(base: T, override: T) -> T

Recursively merge dictionaries; values in override take precedence.

  • If both base and override are dicts, merge keys recursively.
  • Otherwise, return override if it is not None; fallback to base.
Source code in diracx-core/src/diracx/core/utils.py
def recursive_merge(base: Any, override: Any) -> Any:
    """Recursively merge dictionaries; values in ``override`` take precedence.

    - If both ``base`` and ``override`` are dicts, merge keys recursively.
    - Otherwise, return ``override`` if it is not ``None``; fallback to ``base``.
    """
    if isinstance(base, dict) and isinstance(override, dict):
        merged: dict[str, Any] = {}
        for key, base_val in base.items():
            if key in override:
                merged[key] = recursive_merge(base_val, override[key])
            else:
                merged[key] = base_val
        for key, override_val in override.items():
            if key not in merged:
                merged[key] = override_val
        return merged
    return override if override is not None else base
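
A short usage sketch may make the precedence rules concrete. The function body is copied from the source above so that the example runs on its own; the defaults and site dictionaries are hypothetical configuration fragments invented for illustration.

```python
from typing import Any


def recursive_merge(base: Any, override: Any) -> Any:
    """Copy of the function above, repeated so the example is self-contained."""
    if isinstance(base, dict) and isinstance(override, dict):
        merged: dict[str, Any] = {}
        for key, base_val in base.items():
            if key in override:
                merged[key] = recursive_merge(base_val, override[key])
            else:
                merged[key] = base_val
        for key, override_val in override.items():
            if key not in merged:
                merged[key] = override_val
        return merged
    return override if override is not None else base


# Hypothetical configuration fragments.
defaults = {"db": {"host": "localhost", "port": 5432}, "debug": False}
site = {"db": {"host": "db.example.org", "port": None}, "tags": ["prod"]}

merged = recursive_merge(defaults, site)
# "host" is overridden, "port" falls back to the base value because the
# override is None, "debug" is kept, and "tags" is added from the override.
```

Non-dict values such as lists are replaced wholesale rather than combined, since only dict pairs take the recursive branch.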

dotenv_files_from_environment(prefix)

Get the sorted list of .env files to use for configuration.

Source code in diracx-core/src/diracx/core/utils.py
def dotenv_files_from_environment(prefix: str) -> list[str]:
    """Get the sorted list of .env files to use for configuration."""
    env_files = {}
    for key, value in os.environ.items():
        if match := re.fullmatch(rf"{prefix}(?:_(\d+))?", key):
            env_files[int(match.group(1) or -1)] = value
    return [v for _, v in sorted(env_files.items())]
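
The ordering produced by this function is easiest to see with a few variables set. In the sketch below the function is copied from the source above; the EXAMPLE_DOTENV prefix and the file paths are hypothetical values chosen for this example.

```python
import os
import re


def dotenv_files_from_environment(prefix: str) -> list[str]:
    """Copy of the function above, repeated so the example is self-contained."""
    env_files = {}
    for key, value in os.environ.items():
        if match := re.fullmatch(rf"{prefix}(?:_(\d+))?", key):
            env_files[int(match.group(1) or -1)] = value
    return [v for _, v in sorted(env_files.items())]


# Hypothetical variable names and paths.
os.environ["EXAMPLE_DOTENV"] = "/etc/app/base.env"
os.environ["EXAMPLE_DOTENV_10"] = "/etc/app/late.env"
os.environ["EXAMPLE_DOTENV_2"] = "/etc/app/early.env"
os.environ["EXAMPLE_DOTENV_OTHER"] = "/etc/app/ignored.env"  # suffix is not numeric

files = dotenv_files_from_environment("EXAMPLE_DOTENV")
# The bare prefix sorts first (index -1), then suffixes in numeric order,
# so _2 comes before _10.
```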

serialize_credentials(token_response)

Serialize DiracX client credentials to a string.

This function is kept separate from write_credentials so that DIRAC can serialize credentials for inclusion in the proxy file.

Source code in diracx-core/src/diracx/core/utils.py
def serialize_credentials(token_response: TokenResponse) -> str:
    """Serialize DiracX client credentials to a string.

    This function is kept separate from write_credentials so that DIRAC can
    serialize credentials for inclusion in the proxy file.
    """
    expires = datetime.now(tz=timezone.utc) + timedelta(
        seconds=token_response.expires_in - EXPIRES_GRACE_SECONDS
    )
    credential_data = {
        "access_token": token_response.access_token,
        "refresh_token": token_response.refresh_token,
        "expires_on": int(datetime.timestamp(expires)),
    }
    return json.dumps(credential_data)
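
The effect of EXPIRES_GRACE_SECONDS on the stored expiry can be checked with a small stand-in. FakeTokenResponse below is a hypothetical dataclass carrying only the fields the function reads; it is not the real TokenResponse model, and the token strings are invented.

```python
import json
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Optional

EXPIRES_GRACE_SECONDS = 15


@dataclass
class FakeTokenResponse:
    """Hypothetical stand-in for TokenResponse (only the fields used here)."""

    access_token: str
    expires_in: int
    refresh_token: Optional[str] = None


def serialize_credentials(token_response: FakeTokenResponse) -> str:
    """Same logic as the function above, typed against the stand-in."""
    expires = datetime.now(tz=timezone.utc) + timedelta(
        seconds=token_response.expires_in - EXPIRES_GRACE_SECONDS
    )
    credential_data = {
        "access_token": token_response.access_token,
        "refresh_token": token_response.refresh_token,
        "expires_on": int(datetime.timestamp(expires)),
    }
    return json.dumps(credential_data)


token = FakeTokenResponse(access_token="abc", expires_in=3600, refresh_token="xyz")
before = int(datetime.now(tz=timezone.utc).timestamp())
payload = json.loads(serialize_credentials(token))

# The stored expiry is an absolute Unix timestamp, about 3600 - 15 seconds ahead.
remaining = payload["expires_on"] - before
```

Storing an absolute expires_on rather than the relative expires_in is what lets read_credentials recompute the remaining lifetime at read time.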

read_credentials(location)

Read credentials from a file.

Source code in diracx-core/src/diracx/core/utils.py
def read_credentials(location: Path) -> TokenResponse:
    """Read credentials from a file."""
    from diracx.core.preferences import get_diracx_preferences

    credentials_path = location or get_diracx_preferences().credentials_path
    try:
        with open(credentials_path, "r") as f:
            # Lock the file to prevent other processes from writing to it at the same time
            fcntl.flock(f, fcntl.LOCK_SH)
            # Read the credentials from the file
            try:
                credentials = json.load(f)
            finally:
                # Release the lock
                fcntl.flock(f, fcntl.LOCK_UN)
    except (FileNotFoundError, json.JSONDecodeError) as e:
        raise RuntimeError(f"Error reading credentials: {e}") from e

    return TokenResponse(
        access_token=credentials["access_token"],
        expires_in=credentials["expires_on"]
        - int(datetime.now(tz=timezone.utc).timestamp()),
        token_type="Bearer",  # noqa: S106
        refresh_token=credentials.get("refresh_token"),
    )

write_credentials(token_response, *, location=None)

Write the received credentials to location or, if it is not given, to the credentials_path from the DiracX preferences.

Source code in diracx-core/src/diracx/core/utils.py
def write_credentials(token_response: TokenResponse, *, location: Path | None = None):
    """Write the received credentials to a file.

    The file is ``location`` if given, otherwise the ``credentials_path`` from
    the DiracX preferences.
    """
    from diracx.core.preferences import get_diracx_preferences

    credentials_path = location or get_diracx_preferences().credentials_path
    credentials_path.parent.mkdir(parents=True, exist_ok=True)

    # Open a file and set the permissions to 0o600
    file_descriptor = os.open(
        path=credentials_path,
        flags=os.O_WRONLY | os.O_CREAT | os.O_TRUNC,
        mode=stat.S_IRUSR | stat.S_IWUSR,
    )

    with open(file_descriptor, "w") as f:
        # Lock the file to prevent other processes from writing to it at the same time
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            # Write the credentials to the file
            f.write(serialize_credentials(token_response))
            f.flush()
            os.fsync(f.fileno())
        finally:
            # Release the lock
            fcntl.flock(f, fcntl.LOCK_UN)
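
The file handling used by write_credentials and read_credentials can be tried in isolation. This is a minimal sketch for Unix-like systems only, since it relies on fcntl; write_locked and read_locked are hypothetical helper names, and the payload is a made-up example rather than real credentials.

```python
import fcntl
import json
import os
import stat
import tempfile
from pathlib import Path


def write_locked(path: Path, text: str) -> None:
    """Create the file with owner-only permissions and write under an exclusive lock."""
    fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, stat.S_IRUSR | stat.S_IWUSR)
    with open(fd, "w") as f:
        fcntl.flock(f, fcntl.LOCK_EX)
        try:
            f.write(text)
            f.flush()
            os.fsync(f.fileno())  # push the data to disk before unlocking
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


def read_locked(path: Path) -> str:
    """Read the file under a shared lock, so concurrent readers are allowed."""
    with open(path, "r") as f:
        fcntl.flock(f, fcntl.LOCK_SH)
        try:
            return f.read()
        finally:
            fcntl.flock(f, fcntl.LOCK_UN)


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "credentials.json"
    write_locked(path, json.dumps({"access_token": "abc", "expires_on": 0}))
    mode = stat.S_IMODE(path.stat().st_mode)  # permission bits of the new file
    loaded = json.loads(read_locked(path))
```

The locks taken by flock are advisory: they only coordinate processes that also call flock on the same file.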

batched_async(iterable, n, *, strict=False) async

Yield successive n-sized chunks from an async iterable.

Parameters:

Name Type Description Default
iterable async iterable

The input async iterable to be batched.

required
n int

The size of each batch.

required
strict bool

If True, raises ValueError for incomplete batches.

False

Yields:

Name Type Description
tuple AsyncIterable[tuple[T, ...]]

A tuple containing the next n elements from the iterable.

Raises:

Type Description
ValueError

If strict is True and the last batch is not of size n.

Example

>>> async for batch in batched_async(aiter("ABCDEFG"), 3):
...     print(batch)
('A', 'B', 'C')
('D', 'E', 'F')
('G',)
>>> async for batch in batched_async(aiter("ABCDEFG"), 3, strict=True):
...     print(batch)
ValueError: batched(): incomplete batch

Source code in diracx-core/src/diracx/core/utils.py
async def batched_async(
    iterable: AsyncIterable[T], n: int, *, strict: bool = False
) -> AsyncIterable[tuple[T, ...]]:
    """Yield successive n-sized chunks from an async iterable.

    Args:
        iterable (async iterable): The input async iterable to be batched.
        n (int): The size of each batch.
        strict (bool): If True, raises ValueError for incomplete batches.

    Yields:
        tuple: A tuple containing the next n elements from the iterable.

    Raises:
        ValueError: If strict is True and the last batch is not of size n.

    Example:
        >>> async for batch in batched_async(aiter("ABCDEFG"), 3):
        ...     print(batch)
        ('A', 'B', 'C')
        ('D', 'E', 'F')
        ('G',)
        >>> async for batch in batched_async(aiter("ABCDEFG"), 3, strict=True):
        ...     print(batch)
        ValueError: batched(): incomplete batch

    """
    if n < 1:
        raise ValueError("n must be at least one")
    batch = []
    async for item in iterable:
        batch.append(item)
        if len(batch) == n:
            yield tuple(batch)
            batch = []
    if batch:
        if strict and len(batch) != n:
            raise ValueError("batched(): incomplete batch")
        yield tuple(batch)
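
A runnable variant of the docstring example is sketched below. A plain str is not an async iterable, so the sketch feeds the function from a small async generator; letters and collect are hypothetical helpers written for this example, and the function body is copied from the source above.

```python
import asyncio
from typing import AsyncIterable, AsyncIterator, TypeVar

T = TypeVar("T")


async def batched_async(
    iterable: AsyncIterable[T], n: int, *, strict: bool = False
) -> AsyncIterable[tuple[T, ...]]:
    """Copy of the function above, repeated so the example is self-contained."""
    if n < 1:
        raise ValueError("n must be at least one")
    batch = []
    async for item in iterable:
        batch.append(item)
        if len(batch) == n:
            yield tuple(batch)
            batch = []
    if batch:
        if strict and len(batch) != n:
            raise ValueError("batched(): incomplete batch")
        yield tuple(batch)


async def letters(text: str) -> AsyncIterator[str]:
    """Hypothetical async source that yields one character at a time."""
    for char in text:
        yield char


async def collect(strict: bool = False) -> list[tuple[str, ...]]:
    """Gather every batch into a list so the result can be inspected."""
    return [b async for b in batched_async(letters("ABCDEFG"), 3, strict=strict)]


batches = asyncio.run(collect())
# batches == [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]
```

With strict=True the two complete batches are still yielded first, and the ValueError is raised only when the trailing incomplete batch is reached.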