Skip to content

Auth Logic

Authentication and authorization business logic.

Token Management

token

Token endpoint implementation.

Attributes

BASE_64_URL_SAFE_PATTERN = '(?:[A-Za-z0-9\\-_]{4})*(?:[A-Za-z0-9\\-_]{2}==|[A-Za-z0-9\\-_]{3}=)?' module-attribute

LEGACY_EXCHANGE_PATTERN = f'Bearer diracx:legacy:({BASE_64_URL_SAFE_PATTERN})' module-attribute

Classes

Functions

get_oidc_token(grant_type, client_id, auth_db, config, settings, available_properties, device_code=None, code=None, redirect_uri=None, code_verifier=None, refresh_token=None) async

Token endpoint to retrieve the token at the end of a flow.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def get_oidc_token(
    grant_type: GrantType,
    client_id: str,
    auth_db: AuthDB,
    config: Config,
    settings: AuthSettings,
    available_properties: set[SecurityProperty],
    device_code: str | None = None,
    code: str | None = None,
    redirect_uri: str | None = None,
    code_verifier: str | None = None,
    refresh_token: str | None = None,
) -> tuple[AccessTokenPayload, RefreshTokenPayload | None]:
    """Token endpoint to retrieve the token at the end of a flow."""
    legacy_exchange = False
    include_refresh_token = True
    refresh_token_expire_minutes = None

    if grant_type == GrantType.device_code:
        assert device_code is not None
        oidc_token_info, scope = await get_oidc_token_info_from_device_flow(
            device_code, client_id, auth_db, settings
        )

    elif grant_type == GrantType.authorization_code:
        assert code is not None
        assert code_verifier is not None
        oidc_token_info, scope = await get_oidc_token_info_from_authorization_flow(
            code, client_id, redirect_uri, code_verifier, auth_db, settings
        )

    elif grant_type == GrantType.refresh_token:
        assert refresh_token is not None
        (
            oidc_token_info,
            scope,
            legacy_exchange,
            refresh_token_expire_minutes,
            include_refresh_token,
        ) = await get_oidc_token_info_from_refresh_flow(
            refresh_token, auth_db, settings
        )
    else:
        raise NotImplementedError(f"Grant type not implemented {grant_type}")

    # Get a TokenResponse to return to the user
    return await exchange_token(
        auth_db,
        scope,
        oidc_token_info,
        config,
        settings,
        available_properties,
        legacy_exchange=legacy_exchange,
        refresh_token_expire_minutes=refresh_token_expire_minutes,
        include_refresh_token=include_refresh_token,
    )

get_oidc_token_info_from_device_flow(device_code, client_id, auth_db, settings) async

Get OIDC token information from the device flow DB and check a few parameters before returning it.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def get_oidc_token_info_from_device_flow(
    device_code: str, client_id: str, auth_db: AuthDB, settings: AuthSettings
) -> tuple[dict, str]:
    """Get OIDC token information from the device flow DB and check a few parameters before returning it.

    :param device_code: device code issued when the device flow was initiated.
    :param client_id: client ID presented to the token endpoint; it must match
        the client ID stored with the flow.
    :param auth_db: authentication database.
    :param settings: auth settings (provides the device flow expiration).
    :return: ``(oidc_token_info, scope)`` as stored in the device flow entry.
    :raises ValueError: if ``client_id`` does not match the flow's client.
    :raises NotImplementedError: if the flow is not in the READY state.

    Errors raised by :func:`get_device_flow` (expired, pending or already used
    flow) propagate unchanged.
    """
    info = await get_device_flow(
        auth_db, device_code, settings.device_flow_expiration_seconds
    )

    if info["ClientID"] != client_id:
        raise ValueError("Bad client_id")

    oidc_token_info = info["IDToken"]
    scope = info["Scope"]

    # TODO: use HTTPException while still respecting the standard format
    # required by the RFC
    if info["Status"] != FlowStatus.READY:
        # That should never ever happen.
        # Read the same "Status" key as the check above: looking up a
        # lowercase "status" key turned this error into a KeyError.
        raise NotImplementedError(f"Unexpected flow status {info['Status']!r}")
    return (oidc_token_info, scope)

get_oidc_token_info_from_authorization_flow(code, client_id, redirect_uri, code_verifier, auth_db, settings) async

Get OIDC token information from the authorization flow DB and check a few parameters before returning it.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def get_oidc_token_info_from_authorization_flow(
    code: str,
    client_id: str | None,
    redirect_uri: str | None,
    code_verifier: str,
    auth_db: AuthDB,
    settings: AuthSettings,
) -> tuple[dict, str]:
    """Get OIDC token information from the authorization flow DB and check a few parameters before returning it.

    :param code: authorization code issued to the client.
    :param client_id: client ID presented to the token endpoint; must match the flow.
    :param redirect_uri: redirect URI presented to the token endpoint; must match the flow.
    :param code_verifier: PKCE verifier; its S256 challenge must match the one
        stored when the flow was initiated.
    :param auth_db: authentication database.
    :param settings: auth settings (provides the authorization flow expiration).
    :return: ``(oidc_token_info, scope)`` as stored in the authorization flow entry.
    :raises ValueError: on a redirect URI, client ID or code challenge mismatch,
        or a malformed ``code_verifier``.
    :raises NotImplementedError: if the flow is not in the READY state.
    """
    info = await get_authorization_flow(
        auth_db, code, settings.authorization_flow_expiration_seconds
    )
    if redirect_uri != info["RedirectURI"]:
        raise ValueError("Invalid redirect_uri")
    if client_id != info["ClientID"]:
        raise ValueError("Bad client_id")

    # Check the code_verifier: S256 challenge = unpadded base64url(sha256(verifier))
    try:
        code_challenge = (
            base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
            .decode()
            .strip("=")
        )
    except Exception as e:
        raise ValueError("Malformed code_verifier") from e

    if code_challenge != info["CodeChallenge"]:
        raise ValueError("Invalid code_challenge")

    oidc_token_info = info["IDToken"]
    scope = info["Scope"]

    # TODO: use HTTPException while still respecting the standard format
    # required by the RFC
    if info["Status"] != FlowStatus.READY:
        # That should never ever happen.
        # Read the same "Status" key as the check above: looking up a
        # lowercase "status" key turned this error into a KeyError.
        raise NotImplementedError(f"Unexpected flow status {info['Status']!r}")

    return (oidc_token_info, scope)

get_oidc_token_info_from_refresh_flow(refresh_token, auth_db, settings) async

Get OIDC token information from the refresh token DB and check a few parameters before returning it.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def get_oidc_token_info_from_refresh_flow(
    refresh_token: str, auth_db: AuthDB, settings: AuthSettings
) -> tuple[dict, str, bool, float, bool]:
    """Get OIDC token information from the refresh token DB and check a few parameters before returning it.

    :param refresh_token: DIRAC refresh token presented by the client.
    :param auth_db: authentication database.
    :param settings: auth settings used to verify the refresh token.
    :return: ``(oidc_token_info, scope, legacy_exchange, remaining_minutes,
        include_refresh_token)``. ``oidc_token_info`` only carries the ``sub``
        claim (with the VO prefix removed), ``remaining_minutes`` is the time
        left before the presented refresh token expires, and
        ``include_refresh_token`` is False for legacy-exchange tokens.
    :raises InvalidCredentialsError: if the refresh token has been revoked.
    """
    # Decode the refresh token to get the JWT ID
    jti, exp, legacy_exchange = await verify_dirac_refresh_token(
        refresh_token, settings
    )

    # Get some useful user information from the refresh token entry in the DB
    refresh_token_attributes = await auth_db.get_refresh_token(jti)

    sub = refresh_token_attributes["Sub"]
    is_revoked = refresh_token_attributes["Status"] == RefreshTokenStatus.REVOKED

    # Get the remaining time in minutes before the token expires
    remaining_minutes = (
        datetime.fromtimestamp(exp, timezone.utc) - datetime.now(timezone.utc)
    ).total_seconds() / 60

    # Check if the refresh token was obtained from the legacy_exchange endpoint
    if not legacy_exchange:
        include_refresh_token = True

        # Refresh token rotation: https://datatracker.ietf.org/doc/html/rfc6749#section-10.4
        # Check that the refresh token has not been already revoked
        # This might indicate that a potential attacker try to impersonate someone
        # In such case, all the refresh tokens bound to a given user (subject) should be revoked
        # Forcing the user to reauthenticate interactively through an authorization/device flow (recommended practice)
        if is_revoked:
            # Revoke all the user tokens from the subject
            await auth_db.revoke_user_refresh_tokens(sub)

            # Commit here, otherwise the revocation operation will not be taken into account
            # as we return an error to the user
            await auth_db.conn.commit()

            raise InvalidCredentialsError(
                "Revoked refresh token reused: potential attack detected. You must authenticate again"
            )

        # Part of the refresh token rotation mechanism:
        # Revoke the refresh token provided, a new one needs to be generated
        await auth_db.revoke_refresh_token(jti)
    else:
        # We bypass the refresh token rotation mechanism
        # and we don't want to generate a new refresh token
        include_refresh_token = False

        # Legacy-exchange tokens are reused by design, so reuse is not treated
        # as an attack. A revoked one must still be rejected, otherwise
        # revoking it (e.g. via revoke_refresh_token_by_jti) would have no effect.
        if is_revoked:
            raise InvalidCredentialsError("Revoked refresh token")

    # Build an ID token and get scope from the refresh token attributes received
    oidc_token_info = {
        # The sub attribute coming from the DB contains the VO name
        # We need to remove it as if it were coming from an ID token from an external IdP
        "sub": sub.split(":", 1)[1],
    }
    scope = refresh_token_attributes["Scope"]
    return (
        oidc_token_info,
        scope,
        legacy_exchange,
        remaining_minutes,
        include_refresh_token,
    )

perform_legacy_exchange(expected_api_key, preferred_username, scope, authorization, auth_db, available_properties, settings, config, expires_minutes=None) async

Endpoint used by legacy DIRAC to mint tokens for proxy -> token exchange.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def perform_legacy_exchange(
    expected_api_key: str,
    preferred_username: str,
    scope: str,
    authorization: str,
    auth_db: AuthDB,
    available_properties: set[SecurityProperty],
    settings: AuthSettings,
    config: Config,
    expires_minutes: float | None = None,
) -> tuple[AccessTokenPayload, RefreshTokenPayload | None]:
    """Endpoint used by legacy DIRAC to mint tokens for proxy -> token exchange.

    :param expected_api_key: hex SHA-256 digest the decoded API key must hash to.
    :param preferred_username: username of the user to mint tokens for.
    :param scope: requested OIDC scope.
    :param authorization: ``Authorization`` header value; must fully match
        ``LEGACY_EXCHANGE_PATTERN`` (``Bearer diracx:legacy:<urlsafe base64>``).
    :param expires_minutes: optional lifetime of the refresh token.
    :return: access and refresh token payloads from :func:`exchange_token`,
        with the refresh token flagged as ``legacy_exchange``.
    :raises ValueError: if the header, the scope or the username is invalid.
    :raises InvalidCredentialsError: if the API key does not match.
    """
    # Local import: only needed for the constant-time comparison below
    import hmac

    if match := re.fullmatch(LEGACY_EXCHANGE_PATTERN, authorization):
        raw_token = base64.urlsafe_b64decode(match.group(1))
    else:
        raise ValueError("Invalid authorization header")

    # The header is attacker-controlled: compare in constant time so response
    # timing does not reveal how much of the expected digest matched.
    token_digest = hashlib.sha256(raw_token).hexdigest()
    if not hmac.compare_digest(token_digest.encode(), expected_api_key.encode()):
        raise InvalidCredentialsError("Invalid credentials")

    try:
        parsed_scope = parse_and_validate_scope(scope, config, available_properties)
        vo_users = config.Registry[parsed_scope["vo"]]
        sub = vo_users.sub_from_preferred_username(preferred_username)
    except (KeyError, ValueError) as e:
        raise ValueError("Invalid scope or preferred_username") from e

    return await exchange_token(
        auth_db,
        scope,
        {"sub": sub, "preferred_username": preferred_username},
        config,
        settings,
        available_properties,
        refresh_token_expire_minutes=expires_minutes,
        legacy_exchange=True,
    )

exchange_token(auth_db, scope, oidc_token_info, config, settings, available_properties, *, refresh_token_expire_minutes=None, legacy_exchange=False, include_refresh_token=True) async

Method called to exchange the OIDC token for a DIRAC generated access token.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def exchange_token(
    auth_db: AuthDB,
    scope: str,
    oidc_token_info: dict,
    config: Config,
    settings: AuthSettings,
    available_properties: set[SecurityProperty],
    *,
    refresh_token_expire_minutes: float | None = None,
    legacy_exchange: bool = False,
    include_refresh_token: bool = True,
) -> tuple[AccessTokenPayload, RefreshTokenPayload | None]:
    """Method called to exchange the OIDC token for a DIRAC generated access token.

    The VO, group and properties are taken from ``scope``; the subject comes
    from ``oidc_token_info["sub"]`` and must be a registered user of the VO,
    a member of the requested group and allowed the requested properties.

    :param refresh_token_expire_minutes: refresh token lifetime; defaults to
        ``settings.refresh_token_expire_minutes``.
    :param legacy_exchange: recorded in the refresh token payload.
    :param include_refresh_token: when True, a refresh token entry is inserted
        in the DB and its payload returned.
    :return: ``(access_payload, refresh_payload_or_None)``.
    :raises NotImplementedError: if the subject is not a registered user of the VO.
    :raises PermissionError: if the user is not in the group or requests
        properties it is not allowed.
    """
    # DIRAC attributes requested through the OIDC scope
    scope_info = parse_and_validate_scope(scope, config, available_properties)
    vo = scope_info["vo"]
    dirac_group = scope_info["group"]
    properties = scope_info["properties"]
    vo_registry = config.Registry[vo]

    # Only users already registered in the VO are supported
    idp_sub = oidc_token_info["sub"]
    user_info = vo_registry.Users.get(idp_sub)
    if not user_info:
        raise NotImplementedError(
            "Dynamic registration of users is not yet implemented"
        )
    preferred_username = user_info.PreferedUsername

    if idp_sub not in vo_registry.Groups[dirac_group].Users:
        raise PermissionError(
            f"User is not a member of the requested group ({preferred_username}, {dirac_group})"
        )

    allowed_user_properties = get_allowed_user_properties(config, idp_sub, vo)
    if not properties.issubset(allowed_user_properties):
        raise PermissionError(
            f"{' '.join(properties - allowed_user_properties)} are not valid properties "
            f"for user {preferred_username}, available values: {' '.join(allowed_user_properties)}"
        )

    # DIRAC subjects are prefixed with the VO so they are unique across VOs
    dirac_sub = f"{vo}:{idp_sub}"

    refresh_payload: RefreshTokenPayload | None = None
    if include_refresh_token:
        # The DB entry keeps the user details needed to regenerate access tokens
        refresh_jti = await insert_refresh_token(
            auth_db=auth_db,
            subject=dirac_sub,
            scope=scope,
        )
        lifetime_minutes = (
            settings.refresh_token_expire_minutes
            if refresh_token_expire_minutes is None
            else refresh_token_expire_minutes
        )
        refresh_exp = uuid7_to_datetime(refresh_jti) + timedelta(
            minutes=lifetime_minutes
        )
        refresh_payload = {
            "jti": str(refresh_jti),
            "exp": refresh_exp,
            # Marks refresh tokens originating from the legacy_exchange endpoint
            "legacy_exchange": legacy_exchange,
            "dirac_policies": {},
        }

    # The access token only targets DIRAC services for now, so no audience
    # is set (nor checked)
    access_jti = uuid7()
    access_lifetime = timedelta(minutes=settings.access_token_expire_minutes)
    access_exp = uuid7_to_datetime(access_jti) + access_lifetime
    access_payload: AccessTokenPayload = {
        "sub": dirac_sub,
        "vo": vo,
        "iss": settings.token_issuer,
        "dirac_properties": list(properties),
        "jti": str(access_jti),
        "preferred_username": preferred_username,
        "dirac_group": dirac_group,
        "exp": access_exp,
        "dirac_policies": {},
    }

    return access_payload, refresh_payload

create_token(payload, settings)

Create a JWT token with the given payload and settings.

Source code in diracx-logic/src/diracx/logic/auth/token.py
def create_token(payload: TokenPayload, settings: AuthSettings) -> str:
    """Create a JWT token with the given payload and settings.

    The first key of the keystore JWKS whose ``key_ops`` contains ``"sign"``
    (``key_ops`` may be a single value or a list) signs the token, and its
    ``alg`` and ``kid`` are copied into the JWT header.

    :raises ValueError: if no key in the JWKS is usable for signing.
    """

    def _is_signing_key(candidate) -> bool:
        # key_ops may be absent, a single operation, or a list of operations
        ops = candidate.get("key_ops")
        if not ops:
            return False
        if not isinstance(ops, list):
            ops = [ops]
        return "sign" in ops

    signing_key = next(
        (key for key in settings.token_keystore.jwks.keys if _is_signing_key(key)),
        None,
    )
    if not signing_key:
        raise ValueError("No signing key found in JWKS")

    jwt_header = {"alg": signing_key.get("alg"), "kid": signing_key.get("kid")}
    return jwt.encode(
        header=jwt_header,
        claims=cast(Claims, payload),
        key=settings.token_keystore.jwks,
        algorithms=settings.token_allowed_algorithms,
    )

insert_refresh_token(auth_db, subject, scope) async

Insert a refresh token into the database and return the JWT ID.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def insert_refresh_token(
    auth_db: AuthDB,
    subject: str,
    scope: str,
) -> UUID:
    """Insert a refresh token into the database and return the JWT ID.

    A new UUIDv7 is generated as the JWT ID and stored together with the
    subject and scope.
    """
    new_jti = uuid7()
    await auth_db.insert_refresh_token(jti=new_jti, subject=subject, scope=scope)
    return new_jti

get_device_flow(auth_db, device_code, max_validity) async

Get the device flow from the DB and check a few parameters before returning it.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def get_device_flow(auth_db: AuthDB, device_code: str, max_validity: int):
    """Get the device flow from the DB and check few parameters before returning it."""
    res = await auth_db.get_device_flow(device_code)

    if res["CreationTime"].replace(tzinfo=timezone.utc) < substract_date(
        seconds=max_validity
    ):
        raise InvalidCredentialsError("Device code expired")

    if res["Status"] == FlowStatus.READY:
        await auth_db.update_device_flow_status(device_code, FlowStatus.DONE)
        return res

    if res["Status"] == FlowStatus.DONE:
        raise AuthorizationError("Code was already used")

    if res["Status"] == FlowStatus.PENDING:
        raise PendingAuthorizationError()

    raise AuthorizationError("Bad state in device flow")

get_authorization_flow(auth_db, code, max_validity) async

Get the authorization flow from the DB and check a few parameters before returning it.

Source code in diracx-logic/src/diracx/logic/auth/token.py
async def get_authorization_flow(auth_db: AuthDB, code: str, max_validity: int):
    """Get the authorization flow from the DB and check few parameters before returning it."""
    res = await auth_db.get_authorization_flow(code, max_validity)

    if res["Status"] == FlowStatus.READY:
        await auth_db.update_authorization_flow_status(code, FlowStatus.DONE)
        return res

    if res["Status"] == FlowStatus.DONE:
        raise AuthorizationError("Code was already used")

    raise AuthorizationError("Bad state in authorization flow")

Authorization Code Flow

authorize_code_flow

Authorization code flow.

Classes

Functions

initiate_authorization_flow(request_url, code_challenge, code_challenge_method, client_id, redirect_uri, scope, state, auth_db, config, settings, available_properties) async

Initiate the authorization flow.

Source code in diracx-logic/src/diracx/logic/auth/authorize_code_flow.py
async def initiate_authorization_flow(
    request_url: str,
    code_challenge: str,
    code_challenge_method: Literal["S256"],
    client_id: str,
    redirect_uri: str,
    scope: str,
    state: str,
    auth_db: AuthDB,
    config: Config,
    settings: AuthSettings,
    available_properties: set[SecurityProperty],
) -> str:
    """Initiate the authorization flow.

    Validates the client ID, redirect URI and scope, stores the flow (with its
    PKCE challenge) in the DB, and returns the URL of the VO's identity
    provider the user must be redirected to. The IdP calls back
    ``{request_url}/complete``.

    :raises ValueError: for an unknown ``client_id`` or ``redirect_uri``.
    """
    # Reject unknown clients and redirect URIs before touching the DB
    if settings.dirac_client_id != client_id:
        raise ValueError("Unrecognised client_id")
    if redirect_uri not in settings.allowed_redirects:
        raise ValueError("Unrecognised redirect_uri")

    parsed_scope = parse_and_validate_scope(scope, config, available_properties)

    flow_uuid = await auth_db.insert_authorization_flow(
        client_id, scope, code_challenge, code_challenge_method, redirect_uri
    )

    # Round-tripped (encrypted) through the IdP so the flow can be resumed
    iam_state = {
        "external_state": state,
        "uuid": flow_uuid,
        "grant_type": GrantType.authorization_code.value,
    }
    return await initiate_authorization_flow_with_iam(
        config,
        parsed_scope["vo"],
        f"{request_url}/complete",
        iam_state,
        settings.state_key.fernet,
    )

complete_authorization_flow(code, state, request_url, auth_db, config, settings) async

Complete the authorization flow.

Source code in diracx-logic/src/diracx/logic/auth/authorize_code_flow.py
async def complete_authorization_flow(
    code: str,
    state: str,
    request_url: str,
    auth_db: AuthDB,
    config: Config,
    settings: AuthSettings,
) -> str:
    """Complete the authorization flow.

    Decrypts the state round-tripped through the IdP, exchanges the IdP
    ``code`` for an ID token, stores it with the authorization flow, and
    returns the client's redirect URL carrying the DIRAC authorization code
    and the client's original ``state``.

    :param code: authorization code issued by the IdP.
    :param state: encrypted state produced when the flow was initiated.
    :param request_url: URL of this callback (used as redirect URI with the IdP).
    :return: the URL to redirect the user to.
    """
    # Local import: only needed to build the client redirect URL
    from urllib.parse import urlencode

    # Decrypt the state to access user details
    decrypted_state = decrypt_state(state, settings.state_key.fernet)
    assert decrypted_state["grant_type"] == GrantType.authorization_code

    # Get the ID token from the IAM
    id_token = await get_token_from_iam(
        config,
        decrypted_state["vo"],
        code,
        decrypted_state,
        request_url,
    )

    # Store the ID token and get the DIRAC code to hand back to the client
    dirac_code, redirect_uri = await auth_db.authorization_flow_insert_id_token(
        decrypted_state["uuid"],
        id_token,
        settings.authorization_flow_expiration_seconds,
    )

    # The external state is chosen by the client: percent-encode the
    # parameters so it cannot break or inject into the query string, and
    # keep any query already present on the redirect URI (RFC 6749 §3.1.2).
    query = urlencode({"code": dirac_code, "state": decrypted_state["external_state"]})
    separator = "&" if "?" in redirect_uri else "?"
    return f"{redirect_uri}{separator}{query}"

Device Flow

device_flow

Device flow.

Classes

Functions

initiate_device_flow(client_id, scope, verification_uri, auth_db, config, available_properties, settings) async

Initiate the device flow against the DIRAC authorization server.

Source code in diracx-logic/src/diracx/logic/auth/device_flow.py
async def initiate_device_flow(
    client_id: str,
    scope: str,
    verification_uri: str,
    auth_db: AuthDB,
    config: Config,
    available_properties: set[SecurityProperty],
    settings: AuthSettings,
) -> InitiateDeviceFlowResponse:
    """Initiate the device flow against the DIRAC authorization server.

    Validates the client and the scope, stores a new device flow, and returns
    the user code, device code and verification URIs to show to the user.

    :raises ValueError: for an unknown ``client_id``.
    """
    if settings.dirac_client_id != client_id:
        raise ValueError("Unrecognised client ID")

    # Validation only: the parsed scope itself is not needed here
    parse_and_validate_scope(scope, config, available_properties)

    user_code, device_code = await auth_db.insert_device_flow(client_id, scope)
    complete_uri = f"{verification_uri}?user_code={user_code}"

    response: InitiateDeviceFlowResponse = {
        "user_code": user_code,
        "device_code": device_code,
        "verification_uri_complete": complete_uri,
        "verification_uri": verification_uri,
        "expires_in": settings.device_flow_expiration_seconds,
    }
    return response

do_device_flow(request_url, auth_db, user_code, config, available_properties, settings) async

This is called as the verification URI for the device flow.

Source code in diracx-logic/src/diracx/logic/auth/device_flow.py
async def do_device_flow(
    request_url: str,
    auth_db: AuthDB,
    user_code: str,
    config: Config,
    available_properties: set[SecurityProperty],
    settings: AuthSettings,
) -> str:
    """This is called as the verification URI for the device flow.

    Checks that ``user_code`` belongs to a valid device flow, then returns the
    URL of the VO's identity provider to redirect the user to; the IdP calls
    back ``{request_url}/complete``.
    """
    # Make sure the user_code actually exists (and get the flow's scope)
    scope = await auth_db.device_flow_validate_user_code(
        user_code, settings.device_flow_expiration_seconds
    )
    vo = parse_and_validate_scope(scope, config, available_properties)["vo"]

    return await initiate_authorization_flow_with_iam(
        config,
        vo,
        f"{request_url}/complete",
        {
            "grant_type": GrantType.device_code.value,
            "user_code": user_code,
        },
        settings.state_key.fernet,
    )

finish_device_flow(request_url, code, state, auth_db, config, settings) async

This is the URL called back by IAM/CheckIn after the authorization flow was granted.

Source code in diracx-logic/src/diracx/logic/auth/device_flow.py
async def finish_device_flow(
    request_url: str,
    code: str,
    state: str,
    auth_db: AuthDB,
    config: Config,
    settings: AuthSettings,
):
    """This is the URL called back by IAM/CheckIn after the authorization
    flow was granted.

    Decrypts the state, exchanges the IdP ``code`` for an ID token and stores
    it with the device flow identified by the state's ``user_code``.
    """
    iam_state = decrypt_state(state, settings.state_key.fernet)
    assert iam_state["grant_type"] == GrantType.device_code

    vo = iam_state["vo"]
    id_token = await get_token_from_iam(config, vo, code, iam_state, request_url)
    await auth_db.device_flow_insert_id_token(
        iam_state["user_code"], id_token, settings.device_flow_expiration_seconds
    )

User Management

management

This module contains the auth management functions.

Classes

Functions

get_refresh_tokens(auth_db, subject) async

Get all refresh tokens bound to a given subject. If there is no subject, then all the refresh tokens are retrieved.

Source code in diracx-logic/src/diracx/logic/auth/management.py
async def get_refresh_tokens(
    auth_db: AuthDB,
    subject: str | None,
) -> list:
    """Get all refresh tokens bound to a given subject.

    If there is no subject, then all the refresh tokens are retrieved.
    """
    tokens = await auth_db.get_user_refresh_tokens(subject)
    return tokens

revoke_refresh_token_by_jti(auth_db, subject, jti) async

Revoke a refresh token. If a subject is provided, then the refresh token must be owned by that subject.

Source code in diracx-logic/src/diracx/logic/auth/management.py
async def revoke_refresh_token_by_jti(
    auth_db: AuthDB,
    subject: str | None,
    jti: UUID,
) -> str:
    """Revoke a refresh token.

    If a subject is provided, then the refresh token must be owned by that
    subject; otherwise any refresh token can be revoked.

    :return: a confirmation message.
    :raises PermissionError: if ``subject`` does not own the token.
    """
    token_row = await auth_db.get_refresh_token(jti)

    # Ownership is only enforced when a subject is given
    if subject:
        if subject != token_row["Sub"]:
            raise PermissionError("Cannot revoke a refresh token owned by someone else")

    await auth_db.revoke_refresh_token(jti)
    return f"Refresh token {jti} revoked"

revoke_refresh_token_by_refresh_token(auth_db, subject, token, token_type_hint, client_id, settings) async

Revoke a refresh token following RFC7009.

Source code in diracx-logic/src/diracx/logic/auth/management.py
async def revoke_refresh_token_by_refresh_token(
    auth_db: AuthDB,
    subject: str | None,
    token: str,
    token_type_hint: str | None,
    client_id: str,
    settings: AuthSettings,
) -> str:
    """Revoke a refresh token following RFC7009.

    :raises ValueError: ``"unsupported_token_type"`` when the hint says the
        token is an access token (only refresh tokens are revocable here).
    :raises InvalidCredentialsError: for an unknown ``client_id``.
    """
    hinted_access_token = (
        token_type_hint and token_type_hint == TokenTypeHint.access_token
    )
    if hinted_access_token:
        raise ValueError("unsupported_token_type")

    if settings.dirac_client_id != client_id:
        raise InvalidCredentialsError("Unrecognised client_id")

    # Only the JWT ID is needed to revoke the token
    token_jti, _exp, _legacy = await verify_dirac_refresh_token(token, settings)
    return await revoke_refresh_token_by_jti(
        auth_db=auth_db, subject=subject, jti=token_jti
    )

Well Known Endpoints

well_known

Classes

Functions

get_openid_configuration(token_endpoint, userinfo_endpoint, authorization_endpoint, device_authorization_endpoint, revoke_refresh_token_endpoint, jwks_endpoint, config, settings) async

OpenID Connect discovery endpoint.

Source code in diracx-logic/src/diracx/logic/auth/well_known.py
async def get_openid_configuration(
    token_endpoint: str,
    userinfo_endpoint: str,
    authorization_endpoint: str,
    device_authorization_endpoint: str,
    revoke_refresh_token_endpoint: str,
    jwks_endpoint: str,
    config: Config,
    settings: AuthSettings,
) -> OpenIDConfiguration:
    """OpenID Connect discovery endpoint.

    Builds the discovery document from the given endpoint URLs. Supported
    scopes are one ``vo:<vo>`` per VO, one ``group:<group>`` per group of each
    VO and one ``property:<property>`` per available security property.
    """
    scopes_supported: list[str] = []
    for vo, vo_info in config.Registry.items():
        scopes_supported.append(f"vo:{vo}")
        # Distinct loop name: these are group names, not VO names
        scopes_supported.extend(f"group:{group}" for group in vo_info.Groups)
    scopes_supported.extend(f"property:{p}" for p in settings.available_properties)

    return {
        "issuer": settings.token_issuer,
        "token_endpoint": token_endpoint,
        "userinfo_endpoint": userinfo_endpoint,
        "authorization_endpoint": authorization_endpoint,
        "device_authorization_endpoint": device_authorization_endpoint,
        "revocation_endpoint": revoke_refresh_token_endpoint,
        "jwks_uri": jwks_endpoint,
        "grant_types_supported": [
            "authorization_code",
            "urn:ietf:params:oauth:grant-type:device_code",
            # The token endpoint also implements the refresh token grant
            "refresh_token",
        ],
        "scopes_supported": scopes_supported,
        "response_types_supported": ["code"],
        "token_endpoint_auth_signing_alg_values_supported": settings.token_allowed_algorithms,
        "token_endpoint_auth_methods_supported": ["none"],
        "code_challenge_methods_supported": ["S256"],
    }

get_jwks(settings) async

Get the JWKs (public keys).

Source code in diracx-logic/src/diracx/logic/auth/well_known.py
async def get_jwks(settings: AuthSettings) -> dict:
    """Get the JWKs (public keys).

    The keystore's key set is exported with ``private=False`` so that only
    public key material is published.
    """
    keyset = settings.token_keystore.jwks
    return keyset.as_dict(private=False)  # type: ignore[return-value]

get_installation_metadata(config) async

Get metadata about the dirac installation.

Source code in diracx-logic/src/diracx/logic/auth/well_known.py
async def get_installation_metadata(
    config: Config,
) -> Metadata:
    """Get metadata about the dirac installation.

    For every VO of the registry this reports its groups (with sorted
    properties), its support contacts and its default group.
    """
    metadata: Metadata = {"virtual_organizations": {}}
    vo_section = metadata["virtual_organizations"]

    for vo_name, vo_info in config.Registry.items():
        groups: dict[str, GroupInfo] = {}
        for group_name, group_info in vo_info.Groups.items():
            groups[group_name] = {"properties": sorted(group_info.Properties)}

        support = vo_info.Support
        vo_section[vo_name] = {
            "groups": groups,
            "support": {
                "message": support.Message,
                "webpage": support.Webpage,
                "email": support.Email,
            },
            "default_group": vo_info.DefaultGroup,
        }

    return metadata

Auth Utilities

utils

Classes

ScopeInfoDict

Bases: TypedDict

Source code in diracx-logic/src/diracx/logic/auth/utils.py
class ScopeInfoDict(TypedDict):
    """Parsed components of an OIDC scope string.

    NOTE(review): presumably the return type of ``parse_and_validate_scope``,
    whose callers read the "vo", "group" and "properties" keys; confirm.
    """

    # DIRAC group name
    group: str
    # Security property names
    properties: set[str]
    # Virtual organisation name
    vo: str
Attributes
group instance-attribute
properties instance-attribute
vo instance-attribute

Functions

get_server_metadata(url) async

Get the server metadata from the IAM.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
async def get_server_metadata(url: str):
    """Get the server metadata from the IAM.

    The parsed JSON document is cached per URL in ``_server_metadata_cache``;
    only cache misses trigger an HTTP request.

    :raises NotImplementedError: if the server does not answer with HTTP 200.
    """
    cached = _server_metadata_cache.get(url)
    if cached is not None:
        return cached

    async with httpx.AsyncClient() as client:
        response = await client.get(url)
        if response.status_code != 200:
            # TODO: Better error handling
            raise NotImplementedError(response)
        metadata = response.json()
        _server_metadata_cache[url] = metadata
    return metadata

encrypt_state(state_dict, cipher_suite)

Encrypt the state dict and return it as a string.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
def encrypt_state(state_dict: dict[str, str], cipher_suite: Fernet) -> str:
    """Encrypt the state dict and return it as a string.

    The dict is JSON-serialised, URL-safe base64 encoded, then encrypted with
    ``cipher_suite``; :func:`decrypt_state` reverses these steps.
    """
    serialized = json.dumps(state_dict).encode()
    encoded = base64.urlsafe_b64encode(serialized)
    token = cipher_suite.encrypt(encoded)
    return token.decode()

decrypt_state(state, cipher_suite)

Decrypt the state string and return it as a dict.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
def decrypt_state(state: str, cipher_suite: Fernet) -> dict[str, str]:
    """Decrypt the state string and return it as a dict.

    Inverse of :func:`encrypt_state`: decrypt, URL-safe base64 decode, then
    parse the JSON.

    :raises AuthorizationError: if any of these steps fails.
    """
    try:
        plaintext = cipher_suite.decrypt(state.encode())
        payload = base64.urlsafe_b64decode(plaintext).decode()
        return json.loads(payload)
    except Exception as e:
        raise AuthorizationError("Invalid state") from e

fetch_jwk_set(url) async

Fetch the JWK set from the IAM.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
async def fetch_jwk_set(url: str):
    """Fetch the JWK set from the IAM.

    ``url`` is the server metadata URL; the key set is downloaded from the
    metadata's ``jwks_uri``.

    :raises RuntimeError: if the metadata has no ``jwks_uri``.
    :raises NotImplementedError: if the JWKS request does not return HTTP 200.
    """
    metadata = await get_server_metadata(url)
    jwks_uri = metadata.get("jwks_uri")
    if not jwks_uri:
        raise RuntimeError('Missing "jwks_uri" in metadata')

    async with httpx.AsyncClient() as client:
        response = await client.get(jwks_uri)
        if response.status_code != 200:
            # TODO: Better error handling
            raise NotImplementedError(response)
        raw_key_set = response.json()

    return KeySet.import_key_set(raw_key_set)

parse_id_token(config, vo, raw_id_token) async

Parse and validate the ID token from IAM.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
async def parse_id_token(config, vo, raw_id_token: str):
    """Parse and validate the ID token from IAM.

    The signature is checked against the IdP's JWK set using the algorithms
    advertised in its metadata (``RS256`` if none are advertised). The
    ``iss``, ``aud`` (the VO's client ID), ``exp``, ``iat`` and ``sub``
    claims are then validated.

    :return: the token claims.
    """
    idp = config.Registry[vo].IdP
    metadata_url = idp.server_metadata_url

    server_metadata = await get_server_metadata(metadata_url)
    allowed_algorithms = server_metadata.get(
        "id_token_signing_alg_values_supported", ["RS256"]
    )
    key_set = await fetch_jwk_set(metadata_url)

    decoded = jwt.decode(raw_id_token, key=key_set, algorithms=allowed_algorithms)

    claims_registry = JWTClaimsRegistry(
        iss={"essential": True, "value": server_metadata["issuer"]},
        # The audience is required and is the client ID of the application
        # https://openid.net/specs/openid-connect-core-1_0.html#IDToken
        aud={"essential": True, "values": [idp.ClientID]},
        exp={"essential": True},
        iat={"essential": True},
        sub={"essential": True},
    )
    claims_registry.validate(decoded.claims)
    return decoded.claims

initiate_authorization_flow_with_iam(config, vo, redirect_uri, state, cipher_suite) async

Initiate the authorization flow with the IAM. Return the URL to redirect the user to.

The state dict is encrypted and passed to the IAM. It is then decrypted when the user is redirected back to the redirect_uri.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
async def initiate_authorization_flow_with_iam(
    config, vo: str, redirect_uri: str, state: dict[str, str], cipher_suite: Fernet
):
    """Initiate the authorization flow with the IAM. Return the URL to redirect the user to.

    The state dict is encrypted and passed to the IAM.
    It is then decrypted when the user is redirected back to the redirect_uri.

    A fresh PKCE ``code_verifier`` is generated on every call. It is stored,
    together with ``vo``, inside the encrypted state so that it can be
    recovered when the IAM redirects the user back.

    :param config: DiracX configuration; ``config.Registry[vo].IdP`` supplies the
        server-metadata URL and the OAuth client ID.
    :param vo: the VO whose identity provider is used.
    :param redirect_uri: where the IAM should send the user after authentication.
    :param state: flow details to carry through the round trip.
    :param cipher_suite: Fernet instance used to encrypt the state.
    :return: the IAM authorization URL with every query parameter percent-encoded.
    """
    from urllib.parse import quote, urlencode

    # code_verifier: https://www.rfc-editor.org/rfc/rfc7636#section-4.1
    code_verifier = secrets.token_hex()

    # code_challenge: https://www.rfc-editor.org/rfc/rfc7636#section-4.2
    code_challenge = (
        base64.urlsafe_b64encode(hashlib.sha256(code_verifier.encode()).digest())
        .decode()
        .replace("=", "")
    )

    idp_config = config.Registry[vo].IdP
    server_metadata = await get_server_metadata(idp_config.server_metadata_url)

    # The authorization endpoint is taken from the IdP's .well-known metadata
    authorization_endpoint = server_metadata["authorization_endpoint"]

    # Encrypt the state and pass it to the IAM
    # Needed to retrieve the original flow details when the user is redirected back to the redirect_uri
    encrypted_state = encrypt_state(
        state | {"vo": vo, "code_verifier": code_verifier}, cipher_suite
    )

    # Values such as redirect_uri may contain '&', '=', '#', ... which must be
    # percent-encoded or they would corrupt the query string. quote (rather than
    # the default quote_plus) keeps the space in the scope encoded as %20.
    query = urlencode(
        {
            "response_type": "code",
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "client_id": idp_config.ClientID,
            "redirect_uri": redirect_uri,
            "scope": "openid profile",
            "state": encrypted_state,
        },
        quote_via=quote,
    )
    # The advertised endpoint may already carry a query string of its own
    separator = "&" if "?" in authorization_endpoint else "?"
    return f"{authorization_endpoint}{separator}{query}"

get_token_from_iam(config, vo, code, state, redirect_uri) async

Exchange the authorization code for tokens at the IAM, sending the PKCE code verifier stored in the state. Return the validated ID token claims.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
async def get_token_from_iam(
    config, vo: str, code: str, state: dict[str, str], redirect_uri: str
) -> dict[str, str]:
    """Get the token from the IAM using the code and state. Return the ID token.

    Performs the token request of the authorization-code grant, sending the
    PKCE ``code_verifier`` that was stored in ``state`` when the flow was
    initiated. The returned ID token is then validated with ``parse_id_token``.

    :param config: DiracX configuration; ``config.Registry[vo].IdP`` supplies the
        server-metadata URL and the OAuth client ID.
    :param vo: the VO whose IdP is contacted.
    :param code: the authorization code returned by the IAM.
    :param state: the decrypted flow state; must contain ``code_verifier``.
    :param redirect_uri: the redirect URI of the flow (OAuth 2.0 requires it to
        match the one sent in the authorization request).
    :return: the validated ID token claims. NOTE(review): claims such as
        ``exp``/``iat`` are numeric, so ``dict[str, str]`` is narrower than the
        real return type.
    :raises IAMServerError: if the IAM answers with a 5xx status.
    :raises IAMClientError: if the IAM answers with a 4xx status.
    :raises KeyError: if the token response contains no ``id_token``.
    """
    idp_config = config.Registry[vo].IdP
    server_metadata = await get_server_metadata(idp_config.server_metadata_url)

    # The token endpoint is taken from the IdP's .well-known metadata
    token_endpoint = server_metadata["token_endpoint"]

    data = {
        "grant_type": GrantType.authorization_code.value,
        "client_id": idp_config.ClientID,
        "code": code,
        "code_verifier": state["code_verifier"],
        "redirect_uri": redirect_uri,
    }

    async with httpx.AsyncClient() as c:
        res = await c.post(token_endpoint, data=data)
        if res.status_code >= 500:
            raise IAMServerError("Failed to contact IAM server")
        elif res.status_code >= 400:
            raise IAMClientError("Failed to contact IAM server")

    raw_id_token = res.json()["id_token"]
    # Extract the payload and verify it; validation errors propagate to the caller
    return await parse_id_token(config=config, vo=vo, raw_id_token=raw_id_token)

read_token(payload, jwks, allowed_algorithms, claims_requests=None)

Source code in diracx-logic/src/diracx/logic/auth/utils.py
def read_token(
    payload: str,
    jwks: KeySet,
    allowed_algorithms: list[str],
    claims_requests: JWTClaimsRegistry | None = None,
) -> Claims:
    """Decode a signed JWT and validate its claims.

    :param payload: the encoded token.
    :param jwks: key set used to verify the token signature.
    :param allowed_algorithms: signing algorithms accepted for this token.
    :param claims_requests: claim requirements to enforce; a default
        ``JWTClaimsRegistry()`` is used when none (or a falsy value) is given.
    :return: the token's claims once validation has passed.
    """
    registry = claims_requests if claims_requests else JWTClaimsRegistry()
    decoded = jwt.decode(payload, key=jwks, algorithms=allowed_algorithms)
    registry.validate(decoded.claims)
    return decoded.claims

verify_dirac_refresh_token(refresh_token, settings) async

Verify a DiracX refresh token and return its token ID (`jti`), expiry time (`exp`), and `legacy_exchange` claim.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
async def verify_dirac_refresh_token(
    refresh_token: str,
    settings: AuthSettings,
) -> tuple[UUID, float, bool]:
    """Verify a DiracX refresh token and return the fields needed to use it.

    The token is decoded with ``read_token`` against
    ``settings.token_keystore.jwks``, accepting only
    ``settings.token_allowed_algorithms``. No claim requirements are added
    beyond ``read_token``'s default registry. The function performs no awaits
    itself.

    :param refresh_token: the encoded refresh token (JWT).
    :param settings: auth settings providing the key set and allowed algorithms.
    :return: ``(jti, exp, legacy_exchange)``: the token ID as a ``UUID``, the
        expiry timestamp as a ``float``, and the token's ``legacy_exchange`` claim.
    :raises KeyError: if the ``jti``, ``exp`` or ``legacy_exchange`` claim is missing.
    :raises ValueError: if ``jti`` is not a valid UUID string.
    """
    claims = read_token(
        refresh_token, settings.token_keystore.jwks, settings.token_allowed_algorithms
    )

    return (
        UUID(claims["jti"]),
        float(claims["exp"]),
        claims["legacy_exchange"],
    )

get_allowed_user_properties(config, sub, vo)

Retrieve all properties of groups a user is registered in.

Source code in diracx-logic/src/diracx/logic/auth/utils.py
def get_allowed_user_properties(config: Config, sub, vo: str) -> set[SecurityProperty]:
    """Return the union of the properties of every group in *vo* that lists *sub*.

    :param config: DiracX configuration; ``config.Registry[vo].Groups`` is read.
    :param sub: the user identifier looked up in each group's ``Users``.
    :param vo: the VO whose groups are inspected.
    :return: all properties granted by the user's groups (empty if none match).
    """
    vo_groups = config.Registry[vo].Groups
    granted: set[SecurityProperty] = set()
    for group_name in vo_groups:
        group_config = vo_groups[group_name]
        if sub not in group_config.Users:
            continue
        granted.update(group_config.Properties)
    return granted

parse_and_validate_scope(scope, config, available_properties)

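Check that:
  • exactly one VO is requested, and it is known to this installation
  • at most one group is requested (the VO's default group is used otherwise)
  • the group belongs to the VO
  • all properties are known

Return a dict with the VO, group, and properties.

:raises: ValueError if the scope isn't valid.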

Source code in diracx-logic/src/diracx/logic/auth/utils.py
def parse_and_validate_scope(
    scope: str, config: Config, available_properties: set[SecurityProperty]
) -> ScopeInfoDict:
    """Parse an OAuth scope string and validate it against the configuration.

    The scope is a whitespace-separated list of ``vo:<name>``,
    ``group:<name>`` and ``property:<name>`` tokens. Checks:
        * Exactly one VO, which must be known to this installation
        * At most one group (the VO's DefaultGroup is used when none is given)
        * group belongs to VO
        * properties (the requested ones plus the group's) are known
    return dict with vo, group and properties.

    :raises:
        * ValueError in case the scope isn't valid
    """
    # split() rather than split(" "): stray or repeated whitespace must not
    # produce empty tokens that would be reported as unrecognised scopes.
    # Sorting makes the lists quoted in error messages deterministic.
    tokens = sorted(set(scope.split()))

    groups = []
    properties = []
    vos = []
    unrecognised = []
    # Loop variable deliberately not named `scope`, to avoid shadowing the parameter
    for token in tokens:
        if token.startswith("group:"):
            groups.append(token.split(":", 1)[1])
        elif token.startswith("property:"):
            properties.append(token.split(":", 1)[1])
        elif token.startswith("vo:"):
            vos.append(token.split(":", 1)[1])
        else:
            unrecognised.append(token)
    if unrecognised:
        raise ValueError(f"Unrecognised scopes: {unrecognised}")

    if not vos:
        available_vo_scopes = [repr(f"vo:{vo}") for vo in config.Registry]
        raise ValueError(
            f"No vo scope requested, available values: {' '.join(available_vo_scopes)}"
        )
    elif len(vos) > 1:
        raise ValueError(f"Only one vo is allowed but got {vos}")
    else:
        vo = vos[0]
        if vo not in config.Registry:
            raise ValueError(f"VO {vo} is not known to this installation")

    if not groups:
        # TODO: Handle multiple groups correctly
        group = config.Registry[vo].DefaultGroup
    elif len(groups) > 1:
        raise ValueError(f"Only one DIRAC group allowed but got {groups}")
    else:
        group = groups[0]
        if group not in config.Registry[vo].Groups:
            raise ValueError(f"{group} not in {vo} groups")

    # The group's own properties are always granted in addition to any requested ones
    allowed_properties = config.Registry[vo].Groups[group].Properties
    properties.extend(str(p) for p in allowed_properties)

    unknown_properties = set(properties) - set(available_properties)
    if unknown_properties:
        raise ValueError(f"{unknown_properties} are not valid properties")

    return {
        "group": group,
        "properties": set(properties),
        "vo": vo,
    }