Source code for stelar.client.tasks

"""Basic support for tasks in the client."""

from __future__ import annotations

import time
from datetime import datetime
from functools import cached_property
from typing import TYPE_CHECKING
from uuid import UUID

from .api_call import api_call
from .dataset import Dataset
from .generic import GenericCursor, GenericProxy, GenericProxyList
from .proxy import (
    DateField,
    Id,
    Property,
    ProxySynclist,
    Reference,
    StrField,
    derived_property,
)
from .proxy.property import DictProperty
from .utils import client_for

if TYPE_CHECKING:
    # Type definitions for
    from .workflows import Tool

    ToolSpec = Tool | UUID | str | None
    ImageSpec = str | None
    ExistingDataset = Dataset | UUID | str


[docs] class Task(GenericProxy): id = Id(entity_name="id") start_date = Property(validator=DateField) end_date = Property(validator=DateField(nullable=True)) exec_state = Property(validator=StrField) creator = Property(validator=StrField, entity_name="creator") process = Reference("Process", entity_name="process_id", trigger_sync=True) messages = Property(validator=StrField, updatable=False, optional=True) # TODO: parameters are dict[str, json] parameters = DictProperty(str, str, updatable=False, optional=True) metrics = DictProperty(str, str, updatable=False, optional=True) inputs = DictProperty(str, list[str], updatable=False, optional=True) outputs = DictProperty(str, dict, updatable=False, optional=True) tool = Property(validator=StrField, optional=True) image = Property(validator=StrField, optional=True) # TODO: return the tool proxy # via a method tags = DictProperty(str, str, updatable=False, optional=True) @derived_property def is_external(self, entity) -> bool: """Check if this task is local (i.e., not using a remote tool).""" return "image" not in entity @derived_property def done(self, entity) -> bool: """Check if this task is done (i.e., has an end date).""" return entity["end_date"] is not None
[docs] def sync_state(self): self.proxy_sync() """Check if this task is done (i.e., has an end date).""" return self.exec_state
[docs] def wait(self, timeout: float = 5.0, polling_interval: float = 1.0): """Wait for the task to finish. Args: timeout (float): The maximum time to wait in seconds. Defaults to 5.0. polling_interval (float): The interval between checks in seconds. Defaults to 1.0. """ while self.sync_state() not in ("succeeded", "failed") and timeout > 0: time.sleep(polling_interval) timeout -= polling_interval return self.exec_state in ("succeeded", "failed")
@derived_property def runtime(self, entity) -> float | None: """Calculate the runtime of the task in seconds. Returns: float: The runtime in seconds, or None if the task is not done. """ start_date = datetime.fromisoformat(entity["start_date"]) end_date = ( datetime.fromisoformat(entity["end_date"]) if entity["end_date"] is not None else datetime.now() ) return (end_date - start_date).total_seconds() @cached_property def signature(self) -> str: """Fetch the signature for this task. The signature is a unique identifier for the task, which allows calling methods `job_input()` and `exit_job()` to implement an external task, even from a connection that is not authenticated. """ return api_call(self).task_signature(str(self.id))["signature"] @cached_property def job_input(self, signature=None) -> dict: """Return the job input for this local task. The job input is a dictionary that contains the inputs and parameters of the task, formatted for use in a job execution context. """ if not self.is_external: raise ValueError("Cannot get job input for non-external tasks") if signature is None: signature = self.signature if signature is None: raise ValueError( "No signature, cannot get job input for a task without a signature" ) return api_call(self).task_job_input(task_id=str(self.id), signature=signature)
[docs] def exit_job( self, *, metrics={}, output={}, message="The task is finished", status="succeeded", signature=None, ): """Post the job output for this external task. The job output contains the outputs and metrics of the task, the final message and a task status. Args: metrics (dict): A dictionary of metrics to post. output (dict): A dictionary of outputs to post. message (str): A message to post with the job output. status (str): The status of the task, e.g., "succeeded" or "failed". signature (str, optional): The signature of the task. If not provided, the signature will be fetched. Raises: ValueError: If the task is not external or if the POST request fails. """ if not self.is_external: raise ValueError("Cannot post job output for non-external tasks") if signature is None: signature = self.signature if signature is None: raise ValueError( "No signature, cannot post job output for a task without a signature" ) output_spec = { "metrics": metrics, "output": output, "message": message, "status": status, } psl = ProxySynclist([self]) api_call(self).task_post_job_output(str(self.id), signature, output_spec) psl.on_update_all(self) psl.sync()
[docs] def fail( self, message: str, *, metrics={}, output={}, ): """Mark this external task as failed with a message. This method sets the task's execution state to 'failed' and updates the messages property with the provided message. """ return self.exit_job( metrics=metrics, output=output, message=message, status="failed" )
[docs] def abort(self, message: str = None): """Abort this task. This method sets the task's execution state to 'failed' and updates the messages property with the provided message. """ if message is None: c = client_for(self) message = f"Task was aborted by user {c.users.current_user.username} at {datetime.now().isoformat()}" output_spec = { "message": message, "status": "failed", } psl = ProxySynclist([self]) api_call(self).task_post_job_output(str(self.id), self.signature, output_spec) psl.sync()
[docs] def jobs(self): """Return a dictionary with info about Kubernetes jobs for this task.""" response = client_for(self).GET("v2/task", self.id, "jobs") response.raise_for_status() return response.json()["result"]
[docs] def logs(self): """Return the logs for this task.""" response = client_for(self).GET("v2/task", self.id, "logs") response.raise_for_status() return response.json()["result"]
[docs] def printlog(self, *, file=None, flush=False): for job, log in self.logs().items(): print(job, ":", file=file, flush=flush) print(log, file=file, flush=flush)
[docs] class TaskCursor(GenericCursor): """A cursor for a collection of STELAR tasks.""" def __init__(self, api): super().__init__(api, Task)
[docs] def fetch(self, **kwargs): raise NotImplementedError("TaskCursor does not support fetch operations.")
[docs] def fetch_list( self, *, state="created", limit: int = 100, offset: int = 0 ) -> list[Task]: """Fetch a list of tasks in a specific state. Args: state (str): The state of the tasks to fetch. Defaults to 'running'. **kwargs: Additional keyword arguments to pass to the API call. Returns: list[Task]: A list of Task objects in the specified state. """ ac = api_call(self) tasks = ac.task_list(state=state, limit=limit, offset=offset) flat_tasks = [] if tasks: # Note : tasks is a dict with keys for each state, or None!! for s in ("created", "running", "succeeded", "failed"): flat_tasks.extend(tasks.get(s, [])) return GenericProxyList(flat_tasks, client_for(self), self.proxy_type)
[docs] def created(self, *, limit: int = 100, offset: int = 0) -> list[Task]: """Fetch a list of tasks in the 'created' state. Args: limit (int): The maximum number of tasks to fetch. Defaults to 100. offset (int): The offset for pagination. Defaults to 0. Returns: list[Task]: A list of Task objects in the 'created' state. """ return self.fetch_list(state="created", limit=limit, offset=offset)
[docs] def running(self, *, limit: int = 100, offset: int = 0) -> list[Task]: """Fetch a list of tasks in the 'running' state. Args: limit (int): The maximum number of tasks to fetch. Defaults to 100. offset (int): The offset for pagination. Defaults to 0. Returns: list[Task]: A list of Task objects in the 'running' state. """ return self.fetch_list(state="running", limit=limit, offset=offset)
[docs] def succeeded(self, *, limit: int = 100, offset: int = 0) -> list[Task]: """Fetch a list of tasks in the 'succeeded' state. Args: limit (int): The maximum number of tasks to fetch. Defaults to 100. offset (int): The offset for pagination. Defaults to 0. Returns: list[Task]: A list of Task objects in the 'succeeded' state. """ return self.fetch_list(state="succeeded", limit=limit, offset=offset)
[docs] def failed(self, *, limit: int = 100, offset: int = 0) -> list[Task]: """Fetch a list of tasks in the 'failed' state. Args: limit (int): The maximum number of tasks to fetch. Defaults to 100. offset (int): The offset for pagination. Defaults to 0. Returns: list[Task]: A list of Task objects in the 'failed' state. """ return self.fetch_list(state="failed", limit=limit, offset=offset)
[docs] def create(self, process_id: UUID | str, task_spec: TaskSpec, secrets=None) -> Task: """Create a new Task using the provided TaskSpec. Args: process_id (UUID): The ID of the process to which the task belongs. task_spec (TaskSpec): The specification for the task to create. secrets (dict, optional): Secrets to pass to the task. Returns: Task: The created task. """ json_spec = task_spec.spec() json_spec["process_id"] = str(process_id) if secrets is not None: json_spec["secrets"] = secrets ac = api_call(self) resp = ac.task_create(**json_spec) task_id = resp["id"] task = self.fetch_proxy(UUID(task_id)) psl = ProxySynclist() psl.on_create_proxy(task) psl.sync() return task