
Jobs Logic

Job-related business logic including submission, querying, status management, and sandboxes.

Job Submission

submission

Attributes

logger = logging.getLogger(__name__) module-attribute

MAX_PARAMETRIC_JOBS = 20 module-attribute

Classes

JobSubmissionSpec pydantic-model

Bases: BaseModel

Fields:

Source code in diracx-logic/src/diracx/logic/jobs/submission.py
class JobSubmissionSpec(BaseModel):
    jdl: str
    owner: str
    owner_group: str
    initial_status: str
    initial_minor_status: str
    vo: str
Attributes
jdl pydantic-field
owner pydantic-field
owner_group pydantic-field
initial_status pydantic-field
initial_minor_status pydantic-field
vo pydantic-field
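
As a quick illustration, a JobSubmissionSpec is a plain pydantic model and can be built directly. This is a minimal sketch: the import paths are assumed from the source locations shown on this page, and the owner, group and VO values are hypothetical.

from diracx.core.models import JobStatus  # assumed import path
from diracx.logic.jobs.submission import JobSubmissionSpec  # assumed import path

spec = JobSubmissionSpec(
    jdl='[Executable = "echo"; Arguments = "hello world";]',
    owner="some_user",            # hypothetical owner
    owner_group="some_group",     # hypothetical group
    initial_status=JobStatus.RECEIVED,
    initial_minor_status="Job accepted",
    vo="some_vo",                 # hypothetical VO
)
print(spec.model_dump())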

Functions

submit_jdl_jobs(job_definitions, job_db, job_logging_db, user_info, config) async

Submit a list of JDLs to the JobDB.

Source code in diracx-logic/src/diracx/logic/jobs/submission.py
async def submit_jdl_jobs(
    job_definitions: list[str],
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    user_info: UserInfo,
    config: Config,
) -> list[InsertedJob]:
    """Submit a list of JDLs to the JobDB."""
    # TODO: that needs to go in the legacy adapter (Does it ? Because bulk submission is not supported there)
    for i in range(len(job_definitions)):
        job_definition = job_definitions[i].strip()
        if not (job_definition.startswith("[") and job_definition.endswith("]")):
            job_definition = f"[{job_definition}]"
        job_definitions[i] = job_definition

    if len(job_definitions) == 1:
        # Check if the job is a parametric one
        job_class_ad = ClassAd(job_definitions[0])
        result = getParameterVectorLength(job_class_ad)
        if not result["OK"]:
            # FIXME dont do this
            print("Issue with getParameterVectorLength", result["Message"])
            return result
        n_jobs = result["Value"]
        parametric_job = False
        if n_jobs is not None and n_jobs > 0:
            # if we are here, then jobDesc was the description of a parametric job. So we start unpacking
            parametric_job = True
            result = generateParametricJobs(job_class_ad)
            if not result["OK"]:
                # FIXME why?
                return result
            job_desc_list = result["Value"]
        else:
            # if we are here, then jobDesc was the description of a single job.
            job_desc_list = job_definitions
    else:
        # if we are here, then jobDesc is a list of JDLs
        # we need to check that none of them is a parametric
        for job_definition in job_definitions:
            res = getParameterVectorLength(ClassAd(job_definition))
            if not res["OK"]:
                raise ValueError(res["Message"])

            if res["Value"]:
                raise ValueError("You cannot submit parametric jobs in a bulk fashion")

        job_desc_list = job_definitions
        # parametric_job = True
        parametric_job = False

    # TODO: make the max number of jobs configurable in the CS
    if len(job_desc_list) > MAX_PARAMETRIC_JOBS:
        raise ValueError(
            f"Normal user cannot submit more than {MAX_PARAMETRIC_JOBS} jobs at once"
        )

    result = []

    if parametric_job:
        initial_status = JobStatus.SUBMITTING
        initial_minor_status = "Bulk transaction confirmation"
    else:
        initial_status = JobStatus.RECEIVED
        initial_minor_status = "Job accepted"

    try:
        submitted_job_ids = await create_jdl_jobs(
            [
                JobSubmissionSpec(
                    jdl=jdl,
                    owner=user_info.preferred_username,
                    owner_group=user_info.dirac_group,
                    initial_status=initial_status,
                    initial_minor_status=initial_minor_status,
                    vo=user_info.vo,
                )
                for jdl in job_desc_list
            ],
            job_db=job_db,
            config=config,
        )
    except ExceptionGroup as e:
        logging.exception("JDL syntax error occurred during job submission")
        raise ValueError("JDL syntax error") from e

    logging.debug(
        f'Jobs added to the JobDB", "{submitted_job_ids} for {user_info.preferred_username}/{user_info.dirac_group}'
    )

    job_created_time = datetime.now(timezone.utc)
    await job_logging_db.insert_records(
        [
            JobLoggingRecord(
                job_id=int(job_id),
                status=initial_status,
                minor_status=initial_minor_status,
                application_status="Unknown",
                date=job_created_time,
                source="JobManager",
            )
            for job_id in submitted_job_ids
        ]
    )

    # if not parametric_job:
    #     self.__sendJobsToOptimizationMind(submitted_job_ids)

    return [
        InsertedJob(
            JobID=job_id,
            Status=initial_status,
            MinorStatus=initial_minor_status,
            TimeStamp=job_created_time,
        )
        for job_id in submitted_job_ids
    ]
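
To show how this coroutine is typically driven, here is a minimal usage sketch. The JobDB, JobLoggingDB, UserInfo and Config instances are assumed to be provided by the caller (in diracx they are normally injected by the router layer), and the import path is assumed from the source location above.

from diracx.logic.jobs.submission import submit_jdl_jobs  # assumed import path

async def submit_hello_world(job_db, job_logging_db, user_info, config):
    jdl = """
        Executable = "echo";
        Arguments = "hello world";
    """
    inserted = await submit_jdl_jobs(
        [jdl],  # enclosing brackets are added automatically if missing
        job_db=job_db,
        job_logging_db=job_logging_db,
        user_info=user_info,
        config=config,
    )
    # Each InsertedJob carries JobID, Status, MinorStatus and TimeStamp
    return [job.JobID for job in inserted]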

create_jdl_jobs(jobs, job_db, config) async

Create jobs from JDLs and insert them into the DB.

Source code in diracx-logic/src/diracx/logic/jobs/submission.py
async def create_jdl_jobs(jobs: list[JobSubmissionSpec], job_db: JobDB, config: Config):
    """Create jobs from JDLs and insert them into the DB."""
    jobs_to_insert = {}
    jdls_to_update = {}
    inputdata_to_insert = {}
    original_jdls = []

    # generate the jobIDs first
    # TODO: should ForgivingTaskGroup be used?
    async with asyncio.TaskGroup() as tg:
        for job in jobs:
            original_jdl = deepcopy(job.jdl)
            job_manifest = returnValueOrRaise(
                checkAndAddOwner(
                    original_jdl,
                    job.owner,
                    job.owner_group,
                    job_manifest_config=make_job_manifest_config(config, job.vo),
                )
            )

            # Fix possible lack of brackets
            if original_jdl.strip()[0] != "[":
                original_jdl = f"[{original_jdl}]"

            original_jdls.append(
                (
                    original_jdl,
                    job_manifest,
                    tg.create_task(job_db.create_job(compressJDL(original_jdl))),
                )
            )

    async with asyncio.TaskGroup() as tg:
        for job, (original_jdl, job_manifest_, job_id_task) in zip(jobs, original_jdls):
            job_id = job_id_task.result()
            job_attrs = {
                "JobID": job_id,
                "LastUpdateTime": datetime.now(tz=timezone.utc),
                "SubmissionTime": datetime.now(tz=timezone.utc),
                "Owner": job.owner,
                "OwnerGroup": job.owner_group,
                "VO": job.vo,
            }

            job_manifest_.setOption("JobID", job_id)

            # 2.- Check JDL and Prepare DIRAC JDL
            job_jdl = job_manifest_.dumpAsJDL()

            # Replace the JobID placeholder if any
            if job_jdl.find("%j") != -1:
                job_jdl = job_jdl.replace("%j", str(job_id))

            class_ad_job = ClassAd(job_jdl)

            class_ad_req = ClassAd("[]")
            if not class_ad_job.isOK():
                # Rollback the entire transaction
                logging.exception(f"Error in JDL syntax for job JDL: {original_jdl}")
                raise ValueError(f"Error in JDL syntax for job JDL: {original_jdl}")
            # TODO: check if that is actually true
            if class_ad_job.lookupAttribute("Parameters"):
                raise NotImplementedError("Parameters in the JDL are not supported")

            # TODO is this even needed?
            class_ad_job.insertAttributeInt("JobID", job_id)

            await check_and_prepare_job(
                job_id,
                class_ad_job,
                class_ad_req,
                job.owner,
                job.owner_group,
                job_attrs,
                job.vo,
                job_db,
                config,
            )
            job_jdl = createJDLWithInitialStatus(
                class_ad_job,
                class_ad_req,
                job_db.jdl_2_db_parameters,
                job_attrs,
                job.initial_status,
                job.initial_minor_status,
                modern=True,
            )

            jobs_to_insert[job_id] = job_attrs
            jdls_to_update[job_id] = compressJDL(job_jdl)

            if class_ad_job.lookupAttribute("InputData"):
                input_data = class_ad_job.getListFromExpression("InputData")
                inputdata_to_insert[job_id] = [lfn for lfn in input_data if lfn]

        tg.create_task(job_db.update_job_jdls(jdls_to_update))
        tg.create_task(job_db.insert_job_attributes(jobs_to_insert))

        if inputdata_to_insert:
            tg.create_task(job_db.insert_input_data(inputdata_to_insert))

    return list(jobs_to_insert.keys())

Job Query

query

Attributes

logger = logging.getLogger(__name__) module-attribute

MAX_PER_PAGE = 10000 module-attribute

Classes

Functions

search(config, job_db, job_parameters_db, job_logging_db, preferred_username, vo, page=1, per_page=100, body=None) async

Retrieve information about jobs.

Source code in diracx-logic/src/diracx/logic/jobs/query.py
async def search(
    config: Config,
    job_db: JobDB,
    job_parameters_db: JobParametersDB,
    job_logging_db: JobLoggingDB,
    preferred_username: str | None,
    vo: str,
    page: int = 1,
    per_page: int = 100,
    body: SearchParams | None = None,
) -> tuple[int, list[dict[str, Any]]]:
    """Retrieve information about jobs."""
    # Apply a limit to per_page to prevent abuse of the API
    if per_page > MAX_PER_PAGE:
        per_page = MAX_PER_PAGE

    if body is None:
        body = SearchParams()

    if query_logging_info := ("LoggingInfo" in (body.parameters or [])):
        if body.parameters:
            body.parameters.remove("LoggingInfo")
            if not body.parameters:
                body.parameters = None
            else:
                body.parameters = ["JobID"] + (body.parameters or [])

    # TODO: Apply all the job policy stuff properly using user_info
    global_jobs_info = config.Operations[vo].Services.JobMonitoring.GlobalJobsInfo
    if not global_jobs_info and preferred_username:
        body.search.append(
            {
                "parameter": "Owner",
                "operator": ScalarSearchOperator.EQUAL,
                # TODO-385: https://github.com/DIRACGrid/diracx/issues/385
                # The value should be user_info.sub,
                # but since we historically rely on the preferred_username
                # we will keep using the preferred_username for now.
                "value": preferred_username,
            }
        )

    total, jobs = await job_db.search(
        body.parameters,
        body.search,
        body.sort,
        distinct=body.distinct,
        page=page,
        per_page=per_page,
    )

    if query_logging_info:
        job_logging_info = await job_logging_db.get_records(
            [job["JobID"] for job in jobs]
        )
        for job in jobs:
            job.update({"LoggingInfo": job_logging_info[job["JobID"]]})

    return total, jobs
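
A hedged usage sketch follows. SearchParams, ScalarSearchOperator and JobStatus are assumed to come from diracx.core.models and to expose the fields used in the implementation above; the database objects are assumed to be injected by the caller. Requesting "LoggingInfo" as a parameter triggers the extra lookup in the JobLoggingDB.

from diracx.core.models import JobStatus, ScalarSearchOperator, SearchParams  # assumed import path
from diracx.logic.jobs.query import search  # assumed import path

async def find_failed_jobs(config, job_db, job_parameters_db, job_logging_db):
    body = SearchParams(
        parameters=["JobID", "Status", "LoggingInfo"],
        search=[
            {
                "parameter": "Status",
                "operator": ScalarSearchOperator.EQUAL,
                "value": JobStatus.FAILED,
            }
        ],
    )
    total, jobs = await search(
        config,
        job_db,
        job_parameters_db,
        job_logging_db,
        preferred_username="some_user",  # hypothetical user
        vo="some_vo",                    # hypothetical VO
        page=1,
        per_page=50,
        body=body,
    )
    return total, jobs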

summary(config, job_db, preferred_username, vo, body) async

Show information suitable for plotting.

Source code in diracx-logic/src/diracx/logic/jobs/query.py
async def summary(
    config: Config,
    job_db: JobDB,
    preferred_username: str | None,
    vo: str,
    body: SummaryParams,
):
    """Show information suitable for plotting."""
    global_jobs_info = config.Operations[vo].Services.JobMonitoring.GlobalJobsInfo
    if not global_jobs_info and preferred_username:
        body.search.append(
            {
                "parameter": "Owner",
                "operator": ScalarSearchOperator.EQUAL,
                # TODO-385: https://github.com/DIRACGrid/diracx/issues/385
                # The value should be user_info.sub,
                # but since we historically rely on the preferred_username
                # we will keep using the preferred_username for now.
                "value": preferred_username,
            }
        )
    return await job_db.summary(body.grouping, body.search)
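
And a sketch for summary, grouping the visible jobs by status. SummaryParams is assumed to be a pydantic model with the grouping and search fields used above; the VO name is hypothetical.

from diracx.core.models import SummaryParams  # assumed import path
from diracx.logic.jobs.query import summary   # assumed import path

async def jobs_per_status(config, job_db):
    body = SummaryParams(grouping=["Status"], search=[])
    # Passing preferred_username=None skips the per-user Owner filter
    return await summary(config, job_db, preferred_username=None, vo="some_vo", body=body)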

Job Status

status

Attributes

logger = logging.getLogger(__name__) module-attribute

JOB_ATTRIBUTES_ALIASES = {field.alias: field_name for field_name, field in JobAttributes.model_fields.items() if field.alias} module-attribute

JOB_PARAMETERS_ALIASES = {field.alias: field_name for field_name, field in JobParameters.model_fields.items() if field.alias} module-attribute

Classes

Functions

remove_jobs(job_ids, config, job_db, job_logging_db, sandbox_metadata_db, task_queue_db) async

Fully remove a list of jobs from the WMS databases.

Source code in diracx-logic/src/diracx/logic/jobs/status.py
async def remove_jobs(
    job_ids: list[int],
    config: Config,
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    sandbox_metadata_db: SandboxMetadataDB,
    task_queue_db: TaskQueueDB,
):
    """Fully remove a list of jobs from the WMS databases."""
    # Remove the staging task from the StorageManager
    # TODO: this was not done in the JobManagerHandler, but it was done in the kill method
    # I think it should be done here too
    # TODO: implement StorageManagerClient
    # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id]))

    # TODO: this was also  not done in the JobManagerHandler, but it was done in the JobCleaningAgent
    # I think it should be done here as well
    await sandbox_metadata_db.unassign_sandboxes_to_jobs(job_ids)

    # Remove the job from TaskQueueDB
    await remove_jobs_from_task_queue(job_ids, config, task_queue_db)

    # Remove the job from JobLoggingDB
    await job_logging_db.delete_records(job_ids)

    # Remove the job from JobDB
    await job_db.delete_jobs(job_ids)

set_job_statuses(status_changes, config, job_db, job_logging_db, task_queue_db, job_parameters_db, force=False, additional_attributes={}) async

Set various status fields for the jobs specified by their job IDs. Only the last status is stored in the JobDB, while all the status logging information is updated in the JobLoggingDB. Each status dict maps a datetime to a status information dictionary.

:raises: JobNotFound if the job is not found in one of the DBs

Source code in diracx-logic/src/diracx/logic/jobs/status.py
async def set_job_statuses(
    status_changes: dict[int, dict[datetime, JobStatusUpdate]],
    config: Config,
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    task_queue_db: TaskQueueDB,
    job_parameters_db: JobParametersDB,
    force: bool = False,
    additional_attributes: dict[int, dict[str, str]] = {},
) -> SetJobStatusReturn:
    """Set various status fields for job specified by its jobId.
    Set only the last status in the JobDB, updating all the status
    logging information in the JobLoggingDB. The status dict has datetime
    as a key and status information dictionary as values.

    :raises: JobNotFound if the job is not found in one of the DBs
    """
    # check that the datetime contains timezone info
    for job_id, status in status_changes.items():
        for dt in status:
            if dt.tzinfo is None:
                raise ValueError(
                    f"Timestamp {dt} is not timezone aware for job {job_id}"
                )

    failed: dict[int, Any] = {}
    deletable_killable_jobs = set()
    job_attribute_updates: dict[int, dict[str, str]] = {}
    skipped_job_attribute_updates: set[int] = set()
    job_logging_updates: list[JobLoggingRecord] = []
    status_dicts: dict[int, dict[datetime, dict[str, str]]] = defaultdict(dict)

    # transform JobStateUpdate objects into dicts
    status_dicts = {
        job_id: {
            key: {k: v for k, v in value.model_dump().items() if v is not None}
            for key, value in status.items()
        }
        for job_id, status in status_changes.items()
    }

    # search all jobs at once
    _, results = await job_db.search(
        parameters=["Status", "StartExecTime", "EndExecTime", "JobID", "VO"],
        search=[
            {
                "parameter": "JobID",
                "operator": VectorSearchOperator.IN,
                "values": list(set(status_changes.keys())),
            }
        ],
        sorts=[],
    )
    if not results:
        return SetJobStatusReturn(
            success={},
            failed={
                int(job_id): {"detail": "Not found"} for job_id in status_changes.keys()
            },
        )

    found_jobs = set(int(res["JobID"]) for res in results)
    failed.update(
        {
            int(nf_job_id): {"detail": "Not found"}
            for nf_job_id in set(status_changes.keys()) - found_jobs
        }
    )
    # Get the latest time stamps of major status updates
    wms_time_stamps = await job_logging_db.get_wms_time_stamps(found_jobs)

    for res in results:
        job_id = int(res["JobID"])
        current_status = res["Status"]
        start_time = res["StartExecTime"]
        end_time = res["EndExecTime"]

        # If the current status is Stalled and we get an update, it should probably be "Running"
        if current_status == JobStatus.STALLED:
            current_status = JobStatus.RUNNING

        #####################################################################################################
        status_dict = status_dicts[job_id]
        # This is more precise than "LastTime". time_stamps is a sorted list of tuples...
        # time_stamps = sorted((float(t), s) for s, t in wms_time_stamps[job_id].items())
        first_status = min(
            wms_time_stamps[job_id].items(), key=lambda x: x[1], default=("", 0)
        )[0]
        last_time = max(wms_time_stamps[job_id].values())

        # Get chronological order of new updates
        update_times = sorted(status_dict)

        new_start_time, new_end_time = getStartAndEndTime(
            start_time,
            end_time,
            update_times,
            # Use a type ignore hint here as it exists solely to use the DIRAC API
            defaultdict(lambda x=first_status: x),  # type: ignore[misc]
            status_dict,
        )

        job_data: dict[str, str] = {}
        new_status: str | None = None
        if update_times[-1] >= last_time:
            new_status, new_minor, new_application = (
                returnValueOrRaise(  # TODO: Catch this
                    getNewStatus(
                        job_id,
                        update_times,
                        last_time,
                        status_dict,
                        current_status,
                        force,
                        MagicMock(),  # FIXME
                    )
                )
            )

            if new_status:
                job_data.update(additional_attributes.get(job_id, {}))
                job_data["Status"] = new_status
                job_data["LastUpdateTime"] = str(datetime.now(timezone.utc))
            if new_minor:
                job_data["MinorStatus"] = new_minor
            if new_application:
                job_data["ApplicationStatus"] = new_application

            await job_parameters_db.upsert(res["VO"], job_id, {"Status": new_status})

        for upd_time in update_times:
            source = status_dict[upd_time]["Source"]
            if source.startswith("Job") or source == "Heartbeat":
                job_data["HeartBeatTime"] = str(upd_time)

        if not start_time and new_start_time:
            job_data["StartExecTime"] = new_start_time

        if not end_time and new_end_time:
            job_data["EndExecTime"] = new_end_time

        #####################################################################################################
        # delete or kill job, if we transition to DELETED or KILLED state
        if new_status in [JobStatus.DELETED, JobStatus.KILLED]:
            deletable_killable_jobs.add(job_id)

        # Update database tables
        if job_data:
            job_attribute_updates[job_id] = job_data
        else:
            skipped_job_attribute_updates.add(job_id)

        for upd_time in update_times:
            s_dict = status_dict[upd_time]
            job_logging_updates.append(
                JobLoggingRecord(
                    job_id=job_id,
                    status=s_dict.get("Status", "idem"),
                    minor_status=s_dict.get("MinorStatus", "idem"),
                    application_status=s_dict.get("ApplicationStatus", "idem"),
                    date=upd_time,
                    source=s_dict.get("Source", "Unknown"),
                )
            )

    if job_attribute_updates:
        await job_db.set_job_attributes(job_attribute_updates)

    await remove_jobs_from_task_queue(
        list(deletable_killable_jobs),
        config,
        task_queue_db,
    )

    # TODO: implement StorageManagerClient
    # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids))

    if deletable_killable_jobs:
        await job_db.set_job_commands(
            [(job_id, "Kill", "") for job_id in deletable_killable_jobs]
        )

    await job_logging_db.insert_records(job_logging_updates)

    return SetJobStatusReturn(
        success=job_attribute_updates | {j: {} for j in skipped_job_attribute_updates},
        failed=failed,
    )
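
Because the status_changes mapping is nested and must use timezone-aware timestamps, a minimal sketch of building it is shown below. JobStatusUpdate and JobStatus are the names used in the implementation above; their import path, and the "Application" minor status, are assumptions.

from datetime import datetime, timezone

from diracx.core.models import JobStatus, JobStatusUpdate  # assumed import path
from diracx.logic.jobs.status import set_job_statuses      # assumed import path

async def mark_job_running(
    job_id, config, job_db, job_logging_db, task_queue_db, job_parameters_db
):
    status_changes = {
        job_id: {
            # The timestamp must be timezone aware, otherwise a ValueError is raised
            datetime.now(timezone.utc): JobStatusUpdate(
                Status=JobStatus.RUNNING,
                MinorStatus="Application",  # hypothetical minor status
                Source="JobWrapper",        # sources starting with "Job" also update HeartBeatTime
            )
        }
    }
    result = await set_job_statuses(
        status_changes,
        config=config,
        job_db=job_db,
        job_logging_db=job_logging_db,
        task_queue_db=task_queue_db,
        job_parameters_db=job_parameters_db,
    )
    # result.success maps job IDs to the applied attribute updates,
    # result.failed contains entries such as {"detail": "Not found"}.
    return result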

reschedule_jobs(job_ids, config, job_db, job_logging_db, task_queue_db, job_parameters_db, reset_jobs=False) async

Reschedule the given jobs.

Source code in diracx-logic/src/diracx/logic/jobs/status.py
async def reschedule_jobs(
    job_ids: list[int],
    config: Config,
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    task_queue_db: TaskQueueDB,
    job_parameters_db: JobParametersDB,
    reset_jobs: bool = False,
):
    """Reschedule given job."""
    failed = {}
    status_changes = {}
    attribute_changes: defaultdict[int, dict[str, str]] = defaultdict(dict)
    jdl_changes = {}

    _, results = await job_db.search(
        parameters=[
            "Status",
            "MinorStatus",
            "VerifiedFlag",
            "RescheduleCounter",
            "Owner",
            "OwnerGroup",
            "JobID",
            "VO",
        ],
        search=[
            VectorSearchSpec(
                parameter="JobID", operator=VectorSearchOperator.IN, values=job_ids
            )
        ],
        sorts=[],
    )
    if not results:
        for job_id in job_ids:
            failed[job_id] = {"detail": "Not found"}

    jobs_to_resched = {}

    for job_attrs in results or []:
        job_id = int(job_attrs["JobID"])

        if "VerifiedFlag" not in job_attrs:
            failed[job_id] = {"detail": "Not found: No verified flag"}
            # Noop
            continue

        if not job_attrs["VerifiedFlag"]:
            failed[job_id] = {
                "detail": (
                    f"VerifiedFlag is False: Status {job_attrs['Status']}, "
                    f"Minor Status: {job_attrs['MinorStatus']}"
                )
            }
            # Noop
            continue

        if reset_jobs:
            job_attrs["RescheduleCounter"] = 0
        else:
            job_attrs["RescheduleCounter"] = int(job_attrs["RescheduleCounter"]) + 1

        reschedule_max = config.Operations[
            job_attrs["VO"]
        ].Services.JobScheduling.MaxRescheduling

        if job_attrs["RescheduleCounter"] > reschedule_max:
            status_changes[job_id] = {
                datetime.now(tz=timezone.utc): JobStatusUpdate(
                    Status=JobStatus.FAILED,
                    MinorStatus=JobMinorStatus.MAX_RESCHEDULING,
                    ApplicationStatus="Unknown",
                )
            }
            failed[job_id] = {
                "detail": f"Maximum number of reschedules exceeded ({reschedule_max})"
            }
            continue
        jobs_to_resched[job_id] = job_attrs

    surviving_job_ids = set(jobs_to_resched.keys())

    # TODO: get the job parameters from JobMonitoringClient
    # result = JobMonitoringClient().getJobParameters(jobID)
    # if result["OK"]:
    #     parDict = result["Value"]
    #     for key, value in parDict.get(jobID, {}).items():
    #         result = self.setAtticJobParameter(jobID, key, value, rescheduleCounter - 1)
    #         if not result["OK"]:
    #             break

    # TODO: IF we keep JobParameters and OptimizerParameters: Delete job in those tables.
    # await self.delete_job_parameters(job_id)
    # await self.delete_job_optimizer_parameters(job_id)

    def parse_jdl(job_id: int, job_jdl: str):
        if not job_jdl.strip().startswith("["):
            job_jdl = f"[{job_jdl}]"
        class_ad_job = ClassAd(job_jdl)
        class_ad_job.insertAttributeInt("JobID", job_id)
        return class_ad_job

    job_jdls = {
        jobid: parse_jdl(jobid, extractJDL(jdl))
        for jobid, jdl in (
            (await job_db.get_job_jdls(surviving_job_ids, original=True)).items()
        )
    }

    for job_id, job_attrs in jobs_to_resched.items():
        class_ad_job = job_jdls[job_id]
        class_ad_req = ClassAd("[]")
        try:
            await check_and_prepare_job(
                job_id,
                class_ad_job,
                class_ad_req,
                job_attrs["Owner"],
                job_attrs["OwnerGroup"],
                {"RescheduleCounter": job_attrs["RescheduleCounter"]},
                job_attrs["VO"],
                job_db,
                config,
            )
        except SErrorException as e:
            failed[job_id] = {"detail": str(e)}
            # surviving_job_ids.remove(job_id)
            continue

        priority = class_ad_job.getAttributeInt("Priority")
        if priority is None:
            priority = 0

        site_list = class_ad_job.getListFromExpression("Site")
        if not site_list:
            site = "ANY"
        elif len(site_list) > 1:
            site = "Multiple"
        else:
            site = site_list[0]

        req_jdl = class_ad_req.asJDL()
        class_ad_job.insertAttributeInt("JobRequirements", req_jdl)
        job_jdl = class_ad_job.asJDL()
        # Replace the JobID placeholder if any
        job_jdl = job_jdl.replace("%j", str(job_id))

        additional_attrs = {
            "Site": site,
            "UserPriority": priority,
            "RescheduleTime": datetime.now(tz=timezone.utc),
            "RescheduleCounter": job_attrs["RescheduleCounter"],
        }

        # set new JDL
        jdl_changes[job_id] = compressJDL(job_jdl)

        # set new status
        status_changes[job_id] = {
            datetime.now(tz=timezone.utc): JobStatusUpdate(
                Status=JobStatus.RECEIVED,
                MinorStatus=JobMinorStatus.RESCHEDULED,
                ApplicationStatus="Unknown",
            )
        }
        # set new attributes
        attribute_changes[job_id].update(additional_attrs)

    success = {}
    if surviving_job_ids:
        set_job_status_result = await set_job_statuses(
            status_changes=status_changes,
            config=config,
            job_db=job_db,
            job_logging_db=job_logging_db,
            task_queue_db=task_queue_db,
            job_parameters_db=job_parameters_db,
            additional_attributes=attribute_changes,
        )

        await job_db.update_job_jdls(jdl_changes)

        for job_id, set_status_result in set_job_status_result.success.items():
            if job_id in failed:
                continue

            jdl = job_jdls.get(job_id, None)
            if jdl:
                jdl = jdl.asJDL()

            success[job_id] = {
                "InputData": jdl,
                **attribute_changes[job_id],
                **set_status_result.model_dump(),
            }

    return {"failed": failed, "success": success}

remove_jobs_from_task_queue(job_ids, config, task_queue_db) async

Remove the given jobs from the TaskQueueDB.

Source code in diracx-logic/src/diracx/logic/jobs/status.py
async def remove_jobs_from_task_queue(
    job_ids: list[int],
    config: Config,
    task_queue_db: TaskQueueDB,
):
    """Remove the job from TaskQueueDB."""
    await task_queue_db.remove_jobs(job_ids)

    tq_infos = await task_queue_db.get_tq_infos_for_jobs(job_ids)
    for tq_id, owner, owner_group, vo in tq_infos:
        # TODO: move to Celery

        # If the task queue is not empty, do not remove it
        if not await task_queue_db.is_task_queue_empty(tq_id):
            continue

        await task_queue_db.delete_task_queue(tq_id)

        # Recalculate shares for the owner group
        await recalculate_tq_shares_for_entity(
            owner, owner_group, vo, config, task_queue_db
        )

set_job_parameters_or_attributes(updates, job_db, job_parameters_db) async

Set job parameters or attributes for a list of jobs.

Source code in diracx-logic/src/diracx/logic/jobs/status.py
async def set_job_parameters_or_attributes(
    updates: dict[int, JobMetaData],
    job_db: JobDB,
    job_parameters_db: JobParametersDB,
):
    """Set job parameters or attributes for a list of jobs."""
    # Those dicts create a mapping of job_id -> {attribute_name: value}
    attr_updates: dict[int, dict[str, Any]] = {}
    param_updates: dict[int, dict[str, Any]] = {}

    for job_id, metadata in updates.items():
        attr_updates[job_id] = {}
        param_updates[job_id] = {}
        for pname, pvalue in metadata.model_dump(
            by_alias=True, exclude_none=True
        ).items():
            # An argument can be a job attribute and/or a job parameter

            # Check if the argument is a valid job attribute (using alias)
            if pname in JOB_ATTRIBUTES_ALIASES:
                attr_updates[job_id][pname] = pvalue

            # Check if the argument is a valid job parameter (using alias)
            if pname in JOB_PARAMETERS_ALIASES:
                param_updates[job_id][pname] = pvalue

            # If the field is not in either known aliases, default to treating it as a parameter
            # This allows for more flexible metadata handling
            elif pname not in JOB_ATTRIBUTES_ALIASES:
                param_updates[job_id][pname] = pvalue

    # Bulk set job attributes if required
    attr_updates = {k: v for k, v in attr_updates.items() if v}
    if attr_updates:
        await job_db.set_job_attributes(attr_updates)

    # Bulk set job parameters if required
    await _insert_parameters(param_updates, job_parameters_db, job_db)

add_heartbeat(data, config, job_db, job_logging_db, task_queue_db, job_parameters_db) async

Send a heartbeat (sign of life) for the given jobs.

Source code in diracx-logic/src/diracx/logic/jobs/status.py
async def add_heartbeat(
    data: dict[int, HeartbeatData],
    config: Config,
    job_db: JobDB,
    job_logging_db: JobLoggingDB,
    task_queue_db: TaskQueueDB,
    job_parameters_db: JobParametersDB,
) -> None:
    """Send a heart beat sign of life for a job jobID."""
    # Find the current status of the jobs
    search_query: VectorSearchSpec = {
        "parameter": "JobID",
        "operator": VectorSearchOperator.IN,
        "values": list(data),
    }
    _, results = await job_db.search(
        parameters=["Status", "JobID"], search=[search_query], sorts=[]
    )
    if len(results) != len(data):
        raise ValueError(f"Failed to lookup job IDs: {data.keys()=} {results=}")
    status_changes = {
        int(result["JobID"]): {
            datetime.now(timezone.utc): JobStatusUpdate(
                Status=JobStatus.RUNNING,
                Source="Heartbeat",
            )
        }
        for result in results
        if result["Status"] in [JobStatus.MATCHED, JobStatus.STALLED]
    }

    async with TaskGroup() as tg:
        if status_changes:
            tg.create_task(
                set_job_statuses(
                    status_changes=status_changes,
                    config=config,
                    job_db=job_db,
                    job_logging_db=job_logging_db,
                    task_queue_db=task_queue_db,
                    job_parameters_db=job_parameters_db,
                )
            )

        if other_ids := set(data) - set(status_changes):
            # If there are no status changes, we still need to update the heartbeat time
            heartbeat_updates = {
                job_id: {"HeartBeatTime": utcnow()} for job_id in other_ids
            }
            tg.create_task(job_db.set_job_attributes(heartbeat_updates))

        os_data_by_job_id: defaultdict[int, dict[str, Any]] = defaultdict(dict)
        for job_id, job_data in data.items():
            sql_data = {}
            for key, value in job_data.model_dump(exclude_defaults=True).items():
                if key in job_db.heartbeat_fields:
                    sql_data[key] = value
                else:
                    os_data_by_job_id[job_id][key] = value

            if sql_data:
                tg.create_task(job_db.add_heartbeat_data(job_id, sql_data))

        await _insert_parameters(os_data_by_job_id, job_parameters_db, job_db)

get_job_commands(job_ids, job_db) async

Get the pending job commands for a list of job IDs.

This function automatically marks the commands as "Sent" in the database.

Source code in diracx-logic/src/diracx/logic/jobs/status.py
async def get_job_commands(job_ids: Iterable[int], job_db: JobDB) -> list[JobCommand]:
    """Get the pending job commands for a list of job IDs.

    This function automatically marks the commands as "Sent" in the database.
    """
    return await job_db.get_job_commands(job_ids)

Job Sandboxes

sandboxes

Attributes

MAX_SANDBOX_SIZE_BYTES = 100 * 1024 * 1024 module-attribute

SANDBOX_PFN_REGEX = '^(:?SB:[A-Za-z]+\\|)?/S3/[a-z0-9\\.\\-]{3,63}(?:/[^/]+){3}/[a-z0-9]{3,10}:[0-9a-f]{64}\\.[a-z0-9\\.]+$' module-attribute

logger = logging.getLogger(__name__) module-attribute

Classes

Functions

initiate_sandbox_upload(user_info, sandbox_info, sandbox_metadata_db, settings) async

Get the PFN for the given sandbox, initiate an upload as required.

If the sandbox already exists in the database then the PFN is returned and there is no "url" field in the response.

If the sandbox does not exist in the database then the "url" and "fields" should be used to upload the sandbox to the storage backend.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def initiate_sandbox_upload(
    user_info: UserInfo,
    sandbox_info: SandboxInfo,
    sandbox_metadata_db: SandboxMetadataDB,
    settings: SandboxStoreSettings,
) -> SandboxUploadResponse:
    """Get the PFN for the given sandbox, initiate an upload as required.

    If the sandbox already exists in the database then the PFN is returned
    and there is no "url" field in the response.

    If the sandbox does not exist in the database then the "url" and "fields"
    should be used to upload the sandbox to the storage backend.
    """
    pfn = sandbox_metadata_db.get_pfn(settings.bucket_name, user_info, sandbox_info)

    # TODO: This test should come first, but if we do
    # the access policy will crash for not having been called
    # so we need to find a way to acknowledge that

    if sandbox_info.size > MAX_SANDBOX_SIZE_BYTES:
        raise ValueError(
            f"Sandbox too large, maximum allowed is {MAX_SANDBOX_SIZE_BYTES} bytes"
        )
    full_pfn = f"SB:{settings.se_name}|{pfn}"

    try:
        exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned(
            pfn, settings.se_name
        )
    except SandboxNotFoundError:
        # The sandbox doesn't exist in the database
        pass
    else:
        # As sandboxes are registered in the DB before uploading to the storage
        # backend we can't rely on their existence in the database to determine if
        # they have been uploaded. Instead we check if the sandbox has been
        # assigned to a job. If it has then we know it has been uploaded and we
        # can avoid communicating with the storage backend.
        if exists_and_assigned or await s3_object_exists(
            settings.s3_client, settings.bucket_name, pfn_to_key(pfn)
        ):
            await sandbox_metadata_db.update_sandbox_last_access_time(
                settings.se_name, pfn
            )
            return SandboxUploadResponse(pfn=full_pfn)

    upload_info = await generate_presigned_upload(
        settings.s3_client,
        settings.bucket_name,
        pfn_to_key(pfn),
        sandbox_info.checksum_algorithm,
        sandbox_info.checksum,
        sandbox_info.size,
        settings.url_validity_seconds,
    )
    await insert_sandbox(
        sandbox_metadata_db, settings.se_name, user_info, pfn, sandbox_info.size
    )

    return SandboxUploadResponse(**upload_info, pfn=full_pfn)
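
From a client's point of view the upload flow is response driven: when the returned object has no URL the sandbox is already stored, otherwise the presigned form must be POSTed to the storage backend. The sketch below assumes SandboxUploadResponse exposes optional url and fields attributes as described above; the httpx usage is illustrative and any HTTP client would do.

import httpx

from diracx.logic.jobs.sandboxes import initiate_sandbox_upload  # assumed import path

async def upload_sandbox(user_info, sandbox_info, sandbox_metadata_db, settings, tarball_path):
    resp = await initiate_sandbox_upload(
        user_info, sandbox_info, sandbox_metadata_db, settings
    )
    if not resp.url:
        # The sandbox is already known and uploaded: reuse the returned PFN.
        return resp.pfn

    # Presigned S3 POST upload: send the returned form fields together with the file.
    with open(tarball_path, "rb") as fh:
        async with httpx.AsyncClient() as client:
            r = await client.post(resp.url, data=resp.fields, files={"file": fh})
            r.raise_for_status()
    return resp.pfn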

get_sandbox_file(pfn, sandbox_metadata_db, settings) async

Get a presigned URL to download a sandbox file.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def get_sandbox_file(
    pfn: str,
    sandbox_metadata_db: SandboxMetadataDB,
    settings: SandboxStoreSettings,
) -> SandboxDownloadResponse:
    """Get a presigned URL to download a sandbox file."""
    short_pfn = pfn.split("|", 1)[-1]

    await sandbox_metadata_db.update_sandbox_last_access_time(
        settings.se_name, short_pfn
    )

    # TODO: Support by name and by job id?
    presigned_url = await settings.s3_client.generate_presigned_url(
        ClientMethod="get_object",
        Params={"Bucket": settings.bucket_name, "Key": pfn_to_key(short_pfn)},
        ExpiresIn=settings.url_validity_seconds,
    )
    return SandboxDownloadResponse(
        url=presigned_url, expires_in=settings.url_validity_seconds
    )
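
The download side is a plain HTTP GET of the presigned URL before it expires; again the httpx usage is just an illustration.

import httpx

from diracx.logic.jobs.sandboxes import get_sandbox_file  # assumed import path

async def download_sandbox(pfn, sandbox_metadata_db, settings, destination):
    resp = await get_sandbox_file(pfn, sandbox_metadata_db, settings)
    # resp.expires_in gives the validity of the presigned URL in seconds
    async with httpx.AsyncClient() as client:
        r = await client.get(resp.url)
        r.raise_for_status()
        with open(destination, "wb") as fh:
            fh.write(r.content)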

get_job_sandboxes(job_id, sandbox_metadata_db) async

Get the input and output sandboxes of the given job.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def get_job_sandboxes(
    job_id: int,
    sandbox_metadata_db: SandboxMetadataDB,
) -> dict[str, list[Any]]:
    """Get input and output sandboxes of given job."""
    input_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job(
        job_id, SandboxType.Input
    )
    output_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job(
        job_id, SandboxType.Output
    )
    return {SandboxType.Input: input_sb, SandboxType.Output: output_sb}

get_job_sandbox(job_id, sandbox_metadata_db, sandbox_type) async

Get the input or output sandbox of the given job.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def get_job_sandbox(
    job_id: int,
    sandbox_metadata_db: SandboxMetadataDB,
    sandbox_type: Literal["input", "output"],
) -> list[Any]:
    """Get input or output sandbox of given job."""
    return await sandbox_metadata_db.get_sandbox_assigned_to_job(
        job_id, SandboxType(sandbox_type.capitalize())
    )

assign_sandbox_to_job(job_id, pfn, sandbox_metadata_db, settings) async

Map the given PFN as an output sandbox of the job.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def assign_sandbox_to_job(
    job_id: int,
    pfn: str,
    sandbox_metadata_db: SandboxMetadataDB,
    settings: SandboxStoreSettings,
):
    """Map the pfn as output sandbox to job."""
    short_pfn = pfn.split("|", 1)[-1]
    await sandbox_metadata_db.assign_sandbox_to_jobs(
        jobs_ids=[job_id],
        pfn=short_pfn,
        sb_type=SandboxType.Output,
        se_name=settings.se_name,
    )

unassign_jobs_sandboxes(jobs_ids, sandbox_metadata_db) async

Delete the sandbox mappings for a list of jobs.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def unassign_jobs_sandboxes(
    jobs_ids: list[int],
    sandbox_metadata_db: SandboxMetadataDB,
):
    """Delete bulk jobs sandbox mapping."""
    await sandbox_metadata_db.unassign_sandboxes_to_jobs(jobs_ids)

pfn_to_key(pfn)

Convert a PFN to a key for S3.

This removes the leading "/S3/<bucket_name>" from the PFN.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
def pfn_to_key(pfn: str) -> str:
    """Convert a PFN to a key for S3.

    This removes the leading "/S3/<bucket_name>" from the PFN.
    """
    return "/".join(pfn.split("/")[3:])

insert_sandbox(sandbox_metadata_db, se_name, user, pfn, size) async

Add a new sandbox in SandboxMetadataDB.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def insert_sandbox(
    sandbox_metadata_db: SandboxMetadataDB,
    se_name: str,
    user: UserInfo,
    pfn: str,
    size: int,
) -> None:
    """Add a new sandbox in SandboxMetadataDB."""
    # TODO: Follow https://github.com/DIRACGrid/diracx/issues/49
    owner_id = await sandbox_metadata_db.get_owner_id(user)
    if owner_id is None:
        owner_id = await sandbox_metadata_db.insert_owner(user)

    try:
        await sandbox_metadata_db.insert_sandbox(owner_id, se_name, pfn, size)
    except SandboxAlreadyInsertedError:
        await sandbox_metadata_db.update_sandbox_last_access_time(se_name, pfn)

clean_sandboxes(sandbox_metadata_db, settings, *, limit=10000, max_concurrent_batches=10) async

Delete sandboxes that are not assigned to any job.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def clean_sandboxes(
    sandbox_metadata_db: SandboxMetadataDB,
    settings: SandboxStoreSettings,
    *,
    limit: int = 10_000,
    max_concurrent_batches: int = 10,
) -> int:
    """Delete sandboxes that are not assigned to any job."""
    semaphore = asyncio.Semaphore(max_concurrent_batches)
    n_deleted = 0
    async with (
        sandbox_metadata_db.delete_unused_sandboxes(limit=limit) as generator,
        asyncio.TaskGroup() as tg,
    ):
        async for batch in batched_async(generator, 500):
            objects: list[S3Object] = [{"Key": pfn_to_key(pfn)} for pfn in batch]
            if logger.isEnabledFor(logging.DEBUG):
                for pfn in batch:
                    logger.debug("Deleting sandbox %s from S3", pfn)
            tg.create_task(delete_batch_and_log(settings, objects, semaphore))
            n_deleted += len(objects)
    return n_deleted

delete_batch_and_log(settings, objects, semaphore) async

Helper function to delete a batch of objects and log the result.

Source code in diracx-logic/src/diracx/logic/jobs/sandboxes.py
async def delete_batch_and_log(
    settings: SandboxStoreSettings,
    objects: list[S3Object],
    semaphore: asyncio.Semaphore,
) -> None:
    """Helper function to delete a batch of objects and log the result."""
    async with semaphore:
        await s3_bulk_delete_with_retry(
            settings.s3_client, settings.bucket_name, objects
        )
        logger.info("Deleted %d sandboxes from %s", len(objects), settings.bucket_name)

Job Utilities

utils

Classes

Functions

make_job_manifest_config(config, vo)

Create job manifest configuration for DIRACCommon functions from diracx config.

Source code in diracx-logic/src/diracx/logic/jobs/utils.py
def make_job_manifest_config(config: Config, vo: str):
    """Create job manifest configuration for DIRACCommon functions from diracx config."""
    job_desc = config.Operations[vo].JobDescription

    return {
        "defaultForGroup": {
            "CPUTime": job_desc.DefaultCPUTime,
            "Priority": job_desc.DefaultPriority,
        },
        "minForGroup": {
            "CPUTime": job_desc.MinCPUTime,
            "Priority": job_desc.MinPriority,
        },
        "maxForGroup": {
            "CPUTime": job_desc.MaxCPUTime,
            "Priority": job_desc.MaxPriority,
        },
        "allowedJobTypesForGroup": job_desc.AllowedJobTypes,
        "maxInputData": job_desc.MaxInputData,
    }
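
The returned mapping is what submission.py passes as job_manifest_config to checkAndAddOwner. Below is a small sketch of inspecting it for a hypothetical VO; the import path is assumed from the source location above.

from diracx.logic.jobs.utils import make_job_manifest_config  # assumed import path

def show_manifest_limits(config, vo="some_vo"):  # hypothetical VO name
    manifest_config = make_job_manifest_config(config, vo)
    # Keys: defaultForGroup, minForGroup, maxForGroup, allowedJobTypesForGroup, maxInputData
    print(manifest_config["defaultForGroup"]["CPUTime"])
    print(manifest_config["maxForGroup"]["Priority"])
    return manifest_config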

make_check_and_prepare_job_config(config, vo)

Create checkAndPrepareJob configuration for DIRACCommon functions from diracx config.

Source code in diracx-logic/src/diracx/logic/jobs/utils.py
def make_check_and_prepare_job_config(config: Config, vo: str):
    """Create checkAndPrepareJob configuration for DIRACCommon functions from diracx config."""
    ops = config.Operations[vo]
    return {
        "inputDataPolicyForVO": ops.InputDataPolicy.InputDataModule,
        "softwareDistModuleForVO": ops.SoftwareDistModule,
        "defaultCPUTimeForOwnerGroup": ops.JobDescription.DefaultCPUTime,
        "getDIRACPlatform": partial(find_compatible_platforms, config=config),
    }

check_and_prepare_job(job_id, class_ad_job, class_ad_req, owner, owner_group, job_attrs, vo, job_db, config) async

Check the consistency of the submitted JDL, set some defaults, and prepare the sub-JDL with the job requirements.

Source code in diracx-logic/src/diracx/logic/jobs/utils.py
async def check_and_prepare_job(
    job_id: int,
    class_ad_job: ClassAd,
    class_ad_req: ClassAd,
    owner: str,
    owner_group: str,
    job_attrs: dict,
    vo: str,
    job_db: JobDB,
    config: Config,
):
    """Check Consistency of Submitted JDL and set some defaults
    Prepare subJDL with Job Requirements.
    """
    # Create configuration dict for DIRACCommon function from diracx config
    dirac_config = make_check_and_prepare_job_config(config, vo)

    ret_val = checkAndPrepareJob(
        job_id,
        class_ad_job,
        class_ad_req,
        owner,
        owner_group,
        job_attrs,
        vo,
        config=dirac_config,
    )

    if not ret_val["OK"]:
        if cmpError(ret_val, EWMSSUBM):
            await job_db.set_job_attributes({job_id: job_attrs})

        returnValueOrRaise(ret_val)