from __future__ import annotations
"""Classes used to access the STELAR API.
"""
from typing import TYPE_CHECKING, Optional
from .proxy import EntityNotFound, Proxy, ProxyCursor, ProxyList, ProxyOperationError
from .utils import client_for
if TYPE_CHECKING:
from .client import Client
APIContext = Proxy | Client | ProxyCursor | ProxyList
[docs]
class api_context:
    """Resolve the context (client and optional proxy) for API calls.

    The context can be built from a proxy, from a proxy cursor or proxy
    list, or directly from a client. After construction these attributes
    are available:

    - ``proxy``: the proxy the context was built from, otherwise None
    - ``client``: the client used to perform the requests
    - ``proxy_id``: the id of the proxy, otherwise None
    - ``proxy_type``: the proxy class (taken from the proxy, cursor or
      list), or None when the context was built from a client

    Raises
    ------
    TypeError
        If `arg` is not one of the supported context types.
    """

    def __init__(self, arg: APIContext):
        # Client is imported at module level only for type checking,
        # so it is imported here for the runtime isinstance test.
        from .client import Client

        if isinstance(arg, Proxy):
            self.proxy = arg
            self.client = client_for(self.proxy)
            self.proxy_id = self.proxy.proxy_id
            self.proxy_type = type(self.proxy)
        elif isinstance(arg, (ProxyCursor, ProxyList)):
            self.proxy = None
            self.client = arg.client
            self.proxy_id = None
            self.proxy_type = arg.proxy_type
        elif isinstance(arg, Client):
            self.proxy = None
            self.client = arg
            self.proxy_id = None
            self.proxy_type = None
        else:
            # Without this branch the instance would be left without any
            # attributes and fail later with an obscure AttributeError.
            raise TypeError(
                "api_context requires a Proxy, ProxyCursor, ProxyList or Client, "
                f"got {type(arg).__name__}"
            )
[docs]
class api_model:
    """Attribute-style description of an API entity type.

    Every keyword given to the constructor becomes an attribute.
    The attributes ``members`` and ``search`` always exist and
    default to an empty list and to None respectively.
    """

    @staticmethod
    def from_value(value):
        """Convert nested dicts into api_model objects.

        Dicts become api_model instances, lists are converted element
        by element, and any other value is returned unchanged.
        """
        if isinstance(value, list):
            return [api_model.from_value(item) for item in value]
        if not isinstance(value, dict):
            return value
        converted = {key: api_model.from_value(val) for key, val in value.items()}
        return api_model(**converted)

    def __init__(self, **fields):
        # Defaults come first so that explicitly given fields override them.
        defaults = {"members": [], "search": None}
        for attr, val in {**defaults, **fields}.items():
            setattr(self, attr, val)

    def get_method(self, op, mm: "api_model | None" = None):
        """Return the API method name for operation `op`.

        The name is ``<name>_<op>``; when a member model `mm` is given,
        it is ``<name>_<op>_<mm.name>``.
        """
        suffix = "" if mm is None else f"_{mm.name}"
        return f"{self.name}_{op}{suffix}"
# Registry of the entity models, keyed by proxy class name
# (api_call_base.get_call looks a model up via ``proxy_type.__name__``).
#
# Fields of each entry:
#   name             used in single-entity endpoints, e.g. ``v2/{name}/{id}``
#   collection_name  used in collection endpoints, e.g. ``v2/{collection_name}``
#   members          (optional) registry keys of the entity types that can be
#                    members of this entity
#   search           (optional) which search generator applies to the entity
api_models = {
    "Dataset": {
        "name": "dataset",
        "collection_name": "datasets",
        "search": "solr_search",
    },
    "Group": {
        "name": "group",
        "collection_name": "groups",
        "members": ["Dataset", "Workflow", "Tool", "Group", "User"],
    },
    "ImageRegistryToken": {
        "name": "image_registry_token",
        "collection_name": "image_registry_token",
    },
    "License": {
        "name": "license",
        "collection_name": "licenses",
    },
    "Organization": {
        "name": "organization",
        "collection_name": "organizations",
        "members": ["Dataset", "Workflow", "Tool", "Group", "User"],
    },
    "Policy": {
        "name": "policy",
        "collection_name": "policy",
    },
    "Process": {
        "name": "process",
        "collection_name": "processes",
        "members": ["Task"],
        "search": "solr_search",
    },
    "Resource": {
        "name": "resource",
        "collection_name": "resources",
        "search": "resource_search",
    },
    "Tag": {
        "name": "tag",
        "collection_name": "tags",
    },
    "Task": {
        "name": "task",
        "collection_name": "tasks",
    },
    "Tool": {
        "name": "tool",
        "collection_name": "tools",
        "search": "solr_search",
    },
    "User": {
        "name": "user",
        "collection_name": "users",
    },
    "Vocabulary": {
        "name": "vocabulary",
        "collection_name": "vocabularies",
    },
    "Workflow": {
        "name": "workflow",
        "collection_name": "workflows",
        "search": "solr_search",
    },
}
# First pass: replace every plain dict with an api_model object.
# Only existing keys are reassigned, so iterating the dict meanwhile is safe.
for m in api_models:
    api_models[m] = api_model.from_value(api_models[m])
# Second pass: resolve the member names into the api_model objects themselves.
# This needs the first pass to be complete, since members may refer to any entry.
for m in api_models:
    api_models[m].members = [api_models[mm] for mm in api_models[m].members]
# Generic operations; one ``<name>_<op>`` method per entity model is
# generated for each of these and installed on api_call_base.
OPERATIONS = [
    "create",
    "show",
    "update",
    "patch",
    "delete",
    "list",
    "fetch",
    "purge",
    "search",
]
# Membership operations; one ``<name>_<op>_<member name>`` method is
# generated per (entity model, member model) pair.
MEMBER_OPERATIONS = ["add", "remove", "list_members"]
# The values accepted in the "search" field of an entity model.
# NOTE(review): not referenced in the visible part of this file.
SEARCH_OPERATIONS = ["solr_search", "resource_search"]
[docs]
class api_call_base(api_context):
"""Access the STELAR API using a client or a proxy.
This is the base class for api_call, defining the generic methods
for all types of entities.
"""
def __init__(self, arg: APIContext):
super().__init__(arg)
[docs]
def request(
self, method: str, endpoint: str, params: dict = None, *, json=None, **kwargs
):
if json is None:
if kwargs:
json = dict(kwargs)
else:
json = json | kwargs
# This may raise requests exceptions
resp = self.client.api_request(method, endpoint, params=params, json=json)
jsout = resp.json()
match jsout:
case {"success": True, "result": result}:
return result
case {"success": False, "error": error}:
match resp.status_code:
case 404:
try:
api_name = api_models[self.proxy_type.__name__].name
entity_purged = error["detail"]["entity"] == api_name
except Exception:
entity_purged = False
raise EntityNotFound(
self.proxy_type,
self.proxy_id,
f"{method} {endpoint}",
purged=entity_purged,
)
case _:
raise ProxyOperationError(
self.proxy_type,
self.proxy_id,
f"{method} {endpoint}",
error,
)
case _:
raise RuntimeError(
"Unexpected response from the server",
method,
endpoint,
params,
json,
jsout,
resp.status_code,
resp,
)
[docs]
def get_call(self, proxy_type, op, member_type=None):
m = api_models[proxy_type.__name__]
if member_type is None:
call_name = m.get_method(op)
else:
mm = api_models[member_type.__name__]
call_name = m.get_method(op, mm)
return getattr(self, call_name)
# Populate the api_call class with the STELAR API endpoints
[docs]
def generate_list(model: api_model):
    """Build the generic 'list' method: ``GET v2/{collection_name}``."""

    def gen_list(self, limit=None, offset=None):
        return self.request(
            "GET",
            f"v2/{model.collection_name}",
            {"limit": limit, "offset": offset},
        )

    return gen_list
[docs]
def generate_fetch(model: api_model):
    """Build the generic 'fetch' method: ``GET v2/{collection_name}.fetch``."""

    def gen_fetch(self, limit=None, offset=None):
        return self.request(
            "GET",
            f"v2/{model.collection_name}.fetch",
            {"limit": limit, "offset": offset},
        )

    return gen_fetch
[docs]
def generate_show(model: api_model):
    """Build the generic 'show' method: ``GET v2/{name}/{id}``."""

    def gen_show(self, id):
        return self.request("GET", f"v2/{model.name}/{id}")

    return gen_show
[docs]
def generate_create(model: api_model):
    """Build the generic 'create' method: ``POST v2/{name}``.

    The keyword arguments of the generated method form the JSON body.
    """

    def gen_create(self, **kwargs):
        return self.request("POST", f"v2/{model.name}", json=kwargs)

    return gen_create
[docs]
def generate_update(model: api_model):
    """Build the generic 'update' method: ``PUT v2/{name}/{id}``.

    The keyword arguments of the generated method form the JSON body.
    """

    def gen_update(self, id, **kwargs):
        return self.request("PUT", f"v2/{model.name}/{id}", json=kwargs)

    return gen_update
[docs]
def generate_patch(model: api_model):
    """Build the generic 'patch' method: ``PATCH v2/{name}/{id}``.

    The keyword arguments of the generated method form the JSON body.
    """

    def gen_patch(self, id, **kwargs):
        return self.request("PATCH", f"v2/{model.name}/{id}", json=kwargs)

    return gen_patch
[docs]
def generate_delete(model: api_model):
    """Build the generic 'delete' method: ``DELETE v2/{name}/{id}``."""

    def gen_delete(self, id):
        return self.request("DELETE", f"v2/{model.name}/{id}")

    return gen_delete
[docs]
def generate_purge(model: api_model):
    """Build the generic 'purge' method: ``DELETE v2/{name}/{id}?purge=true``."""

    def gen_purge(self, id):
        # The purge flag is part of the endpoint string itself.
        return self.request("DELETE", f"v2/{model.name}/{id}?purge=true")

    return gen_purge
[docs]
def generate_add(model: api_model, mm: api_model):
    """Build the 'add member' method: ``POST v2/{name}/{id}/{mm.name}/{member_id}``.

    The capacity (possibly None) is sent in the JSON body.
    """

    def gen_add(self, id, member_id, capacity=None):
        body = {"capacity": capacity}
        return self.request(
            "POST", f"v2/{model.name}/{id}/{mm.name}/{member_id}", json=body
        )

    return gen_add
[docs]
def generate_remove(model: api_model, mm: api_model):
    """Build the 'remove member' method: ``DELETE v2/{name}/{id}/{mm.name}/{member_id}``."""

    def gen_remove(self, id, member_id):
        return self.request("DELETE", f"v2/{model.name}/{id}/{mm.name}/{member_id}")

    return gen_remove
[docs]
def generate_list_members(model: api_model, mm: api_model):
    """Build the 'list members' method: ``GET v2/{name}/{id}/{mm.collection_name}``.

    When a capacity is given, ``?capacity=<capacity>`` is appended to
    the endpoint.
    """

    def gen_list_members(self, id, capacity=None):
        query = "" if capacity is None else f"?capacity={capacity}"
        return self.request("GET", f"v2/{model.name}/{id}/{mm.collection_name}{query}")

    return gen_list_members
[docs]
def generate_solr_search(model: api_model):
    """Build the 'search' method for models whose search is 'solr_search'.

    The generated method posts the query specification as the JSON body
    to ``v2/search/{collection_name}``.
    """

    def gen_search(self, query_spec):
        return self.request(
            "POST", f"v2/search/{model.collection_name}", json=query_spec
        )

    return gen_search
[docs]
def generate_resource_search(model: api_model):
    """Build the 'search' method for models whose search is 'resource_search'.

    The generated method posts the query specification as the JSON body
    to ``v2/search/{collection_name}``.
    """

    def gen_search(self, query_spec):
        return self.request(
            "POST", f"v2/search/{model.collection_name}", json=query_spec
        )

    return gen_search
[docs]
def generate_unimplemented(model: api_model, op, mm=None):
    """Build a placeholder method for an operation the entity does not support.

    The generated method accepts any arguments and always raises
    NotImplementedError carrying the model name, the operation and the
    call arguments.

    Parameters
    ----------
    model : api_model
        The entity model the operation was requested for.
    op : str
        The name of the unsupported operation.
    mm : api_model, optional
        The member model; accepted for symmetry with the member
        generators but not used.
    """

    def gen_unimplemented(self, *args, **kwargs):
        # Report the name of this model instance; the api_model class
        # itself has no 'name' attribute.
        raise NotImplementedError(model.name, op, args, kwargs)

    return gen_unimplemented
# Instrument api_call_base with the generated generic methods.
# Every generic method is installed unconditionally on api_call_base;
# specialized implementations override them in the api_call subclass.
# Factories of the generic entity operations, by operation name.
# The "search" operation is handled separately, since its factory
# depends on the search kind of the model.
_ENTITY_GENERATORS = {
    "create": generate_create,
    "show": generate_show,
    "update": generate_update,
    "patch": generate_patch,
    "delete": generate_delete,
    "list": generate_list,
    "fetch": generate_fetch,
    "purge": generate_purge,
}
# Factories of the search operation, by search kind of the model.
_SEARCH_GENERATORS = {
    "solr_search": generate_solr_search,
    "resource_search": generate_resource_search,
}
# Factories of the membership operations, by operation name.
_MEMBER_GENERATORS = {
    "add": generate_add,
    "remove": generate_remove,
    "list_members": generate_list_members,
}

for ptname in api_models:
    model = api_models[ptname]

    # One '<name>_<op>' method per generic operation.
    for op in OPERATIONS:
        call_name = model.get_method(op)
        if op == "search":
            if model.search in _SEARCH_GENERATORS:
                gcall = _SEARCH_GENERATORS[model.search](model)
            else:
                # Models without a search kind get a placeholder.
                gcall = generate_unimplemented(model, op)
        elif op in _ENTITY_GENERATORS:
            gcall = _ENTITY_GENERATORS[op](model)
        gcall.__name__ = call_name
        gcall.__qualname__ = f"api_call_base.{call_name}"
        setattr(api_call_base, call_name, gcall)

    # One '<name>_<op>_<member name>' method per member model and
    # membership operation.
    for mm in model.members:
        for op in MEMBER_OPERATIONS:
            call_name = model.get_method(op, mm)
            if op in _MEMBER_GENERATORS:
                gcall = _MEMBER_GENERATORS[op](model, mm)
            gcall.__name__ = call_name
            gcall.__qualname__ = f"api_call_base.{call_name}"
            setattr(api_call_base, call_name, gcall)
[docs]
class api_call(api_call_base):
    """Class that exposes the STELAR API for a given entity.

    `api_call(proxy).foo(...)`
    returns the 'result' of the STELAR API response on success.
    On a failed response it raises EntityNotFound when the status
    code is 404 and ProxyOperationError otherwise.

    `api_call(client).foo(...)`
    does the same.

    The methods defined here complement or override the generic
    methods that are installed on api_call_base.
    """

    def __init__(self, arg: APIContext):
        # The context (client, proxy, proxy id and type) is resolved
        # by the base classes.
        super().__init__(arg)
# def tag_list(self, vocabulary_id: str = None):
# raise NotImplementedError("tag_list")
[docs]
def user_fetch(self, limit: int = None, offset: int = None):
    """Fetch users (``GET v1/users/``).

    Parameters
    ----------
    limit : int, optional
        Sent as the 'limit' query parameter.
    offset : int, optional
        Sent as the 'offset' query parameter.

    Returns
    -------
    The result of the API response.
    """
    paging = {"limit": limit, "offset": offset}
    return self.request("GET", "v1/users/", params=paging)
[docs]
def user_list(self, limit: int = None, offset: int = None):
    """List the usernames of the users.

    Parameters
    ----------
    limit : int, optional
        The maximum number of users to return.
    offset : int, optional
        The offset for pagination.

    Returns
    -------
    list
        The 'username' of every user returned by `user_fetch`.
    """
    # Forward the pagination arguments; they used to be silently dropped.
    return [u["username"] for u in self.user_fetch(limit, offset)]
[docs]
def user_show(self, id: str):
    """Show the user with the given id (``GET v1/users/{id}``)."""
    endpoint = f"v1/users/{id}"
    return self.request("GET", endpoint)
[docs]
def user_delete(self, id):
    """Delete the user with the given id (``DELETE v1/users/{id}``)."""
    endpoint = f"v1/users/{id}"
    return self.request("DELETE", endpoint)
[docs]
def user_create(self, **kwargs):
    """Create a user (``POST v1/users``).

    The keyword arguments form the JSON body of the request.
    """
    verb, endpoint = "POST", "v1/users"
    return self.request(verb, endpoint, json=kwargs)
[docs]
def user_update(self, id, **kwargs):
    """Not supported for users: always raises NotImplementedError."""
    raise NotImplementedError()
[docs]
def user_patch(self, id, **kwargs):
    """Patch the user with the given id (``PATCH v1/users/{id}``).

    The keyword arguments form the JSON body of the request.
    """
    endpoint = f"v1/users/{id}"
    return self.request("PATCH", endpoint, json=kwargs)
[docs]
def user_purge(self, id):
    """Not supported for users: always raises NotImplementedError."""
    raise NotImplementedError()
[docs]
def roles_fetch(self, limit: int = None, offset: int = None):
    """
    Fetch roles.

    All the roles are requested from the API (``GET v1/users/roles``);
    the pagination is applied locally on the returned list.

    Parameters
    ----------
    limit : int, optional
        The maximum number of roles to return.
    offset : int, optional
        The offset for pagination.

    Returns
    -------
    list
        The roles in the requested window.
    """
    roles = self.request("GET", "v1/users/roles")
    if limit is None and offset is None:
        # No pagination requested: hand back the result as it is.
        return roles
    if limit is None:
        return roles[offset:]
    stop = limit if offset is None else offset + limit
    return roles[offset:stop]
[docs]
def user_add_role(self, user_id: str, role: str):
    """
    Add a role to a user (``POST v1/users/{user_id}/roles/{role}``).

    Parameters
    ----------
    user_id : str
        The ID of the user to whom the role will be added.
    role : str
        The role to be added to the user.

    Returns
    -------
    dict
        A dictionary containing the current user state.
    """
    endpoint = f"v1/users/{user_id}/roles/{role}"
    return self.request("POST", endpoint)
[docs]
def user_remove_role(self, user_id: str, role: str):
    """
    Remove a role from a user (``DELETE v1/users/{user_id}/roles/{role}``).

    Parameters
    ----------
    user_id : str
        The ID of the user from whom the role will be removed.
    role : str
        The role to be removed from the user.

    Returns
    -------
    dict
        A dictionary containing the current user state.
    """
    endpoint = f"v1/users/{user_id}/roles/{role}"
    return self.request("DELETE", endpoint)
[docs]
def user_add_roles(self, user_id: str, roles: list[str]):
    """
    Add multiple roles to a user (``POST v1/users/{user_id}/roles``).

    Parameters
    ----------
    user_id : str
        The ID of the user to whom the roles will be added.
    roles : list[str]
        A list of roles to be added to the user.

    Returns
    -------
    dict
        A dictionary containing the current user state.
    """
    body = {"roles": roles}
    return self.request("POST", f"v1/users/{user_id}/roles", json=body)
[docs]
def user_set_roles(self, user_id: str, roles: list[str]):
    """
    Set roles for a user, replacing any existing roles
    (``PATCH v1/users/{user_id}/roles``).

    Parameters
    ----------
    user_id : str
        The ID of the user whose roles will be set.
    roles : list[str]
        A list of roles to be set for the user.

    Returns
    -------
    dict
        A dictionary containing the current user state.
    """
    body = {"roles": roles}
    return self.request("PATCH", f"v1/users/{user_id}/roles", json=body)
[docs]
def entity_search(self, proxy_type: str, query_spec: dict):
    """
    Search for entities of a given type (``POST v2/search/{collection_name}``).

    Parameters
    ----------
    proxy_type : str
        The key of the entity model in `api_models` (e.g. "Dataset").
        A proxy class is accepted as well; its ``__name__`` is then
        used as the key.
    query_spec : dict
        The query specification, sent as the JSON body.

    Returns
    -------
    The result of the API response.

    Raises
    ------
    KeyError
        If there is no entity model for `proxy_type`.
    """
    # The collection name belongs to the model of the requested type;
    # the api_models registry itself is a dict and has no such attribute.
    key = proxy_type if isinstance(proxy_type, str) else proxy_type.__name__
    entity_type = api_models[key].collection_name
    return self.request("POST", f"v2/search/{entity_type}", json=query_spec)
[docs]
def dataset_export_zenodo(self, dataset_id: str) -> dict:
    """
    Export a dataset to Zenodo (``GET v2/export/zenodo/{dataset_id}``).

    Parameters
    ----------
    dataset_id : str
        The ID of the dataset to export.

    Returns
    -------
    dict
        A dictionary containing the export message, ready to be sent to zenodo.
    """
    endpoint = f"v2/export/zenodo/{dataset_id}"
    return self.request("GET", endpoint)
#
#
# Handling tasks
#
#
[docs]
def task_post_job_output(
    self, task_id: str, signature: str, output_spec: dict
) -> dict:
    """Post the output specification for a job in a task.

    The specification is sent as the JSON body of
    ``POST v2/task/{task_id}/{signature}/output``.
    """
    endpoint = f"v2/task/{task_id}/{signature}/output"
    return self.request("POST", endpoint, json=output_spec)
[docs]
def task_show_jobs(self, task_id: str):
    """Show the jobs associated with a task (``GET v2/task/{task_id}/jobs``)."""
    endpoint = f"v2/task/{task_id}/jobs"
    return self.request("GET", endpoint)
[docs]
def task_show_logs(self, task_id: str):
    """Show the logs associated with a task (``GET v2/task/{task_id}/logs``)."""
    endpoint = f"v2/task/{task_id}/logs"
    return self.request("GET", endpoint)
[docs]
def task_signature(self, task_id: str) -> dict:
    """Get the signature of a task (``GET v2/task/{task_id}/signature``)."""
    endpoint = f"v2/task/{task_id}/signature"
    return self.request("GET", endpoint)
[docs]
def task_list(self, limit: int = None, offset: int = None, state: str = None):
    """List tasks (``GET v2/tasks``).

    Only the arguments that are not None are sent as query parameters.
    """
    candidates = (("limit", limit), ("offset", offset), ("state", state))
    query = {key: val for key, val in candidates if val is not None}
    return self.request("GET", "v2/tasks", params=query)
#
# Policies
#
[docs]
def policy_fetch(self, limit: int = None, offset: int = None):
    """
    Fetch policies.

    All the policies are requested from the API (``GET v1/auth/policy``);
    the pagination is applied locally on the 'policies' list of the result.

    Parameters
    ----------
    limit : int, optional
        The maximum number of policies to return.
    offset : int, optional
        The offset for pagination.

    Returns
    -------
    list
        A list of dictionaries containing policy information.
    """
    result = self.request("GET", "v1/auth/policy")
    policies = result["policies"]
    if limit is None and offset is None:
        # No pagination requested: hand back the list as it is.
        return policies
    if limit is None:
        return policies[offset:]
    stop = limit if offset is None else offset + limit
    return policies[offset:stop]
[docs]
def policy_list(self, limit: int = None, offset: int = None):
    """
    List the UUIDs of the policies.

    Parameters
    ----------
    limit : int, optional
        The maximum number of policies to return.
    offset : int, optional
        The offset for pagination.

    Returns
    -------
    list
        The 'policy_uuid' of every policy returned by `policy_fetch`.
    """
    policies = self.policy_fetch(limit, offset)
    return [policy["policy_uuid"] for policy in policies]
[docs]
def policy_show(self, eid: str):
    """
    Show a specific policy (``GET v1/auth/policy/{eid}``).

    Parameters
    ----------
    eid : str
        The UUID of the policy to show.

    Returns
    -------
    dict
        A dictionary containing the policy information.
    """
    endpoint = f"v1/auth/policy/{eid}"
    return self.request("GET", endpoint)
[docs]
def policy_create(self, policy_yaml: str | bytes):
    """
    Create a new policy.

    Parameters
    ----------
    policy_yaml : str | bytes
        The policy specification, sent as the body of the request with
        content type 'application/x-yaml'.

    Returns
    -------
    dict
        A dictionary containing the created policy information.

    Raises
    ------
    RuntimeError
        If the status code of the response is not in the 2xx range.
        The error carries the response body: decoded from JSON when
        possible, otherwise the raw content.
    """
    # We need to use the client.request method in order to send
    # the policy data as a string or bytes.
    headers = {
        "Content-Type": "application/x-yaml",
    }
    response = self.client.request(
        "POST", "api/v1/auth/policy", headers=headers, data=policy_yaml
    )
    if response.status_code in range(200, 300):
        return response.json()["result"]

    # An error response is not guaranteed to carry JSON; a decoding
    # failure must not mask the actual error.
    try:
        details = response.json()
    except ValueError:
        details = response.content
    raise RuntimeError(
        "Unexpected response trying to create policy",
        policy_yaml,
        response.status_code,
        details,
    )
[docs]
def policy_spec(self, policy_uuid: str) -> bytes:
    """
    Get the specification of a policy.

    Parameters
    ----------
    policy_uuid : str
        The UUID of the policy to get the specification for.

    Returns
    -------
    bytes
        The content of the response.

    Raises
    ------
    RuntimeError
        If the status code of the response is not in the 2xx range.
    """
    response = self.client.GET("v1/auth/policy/representation", policy_uuid)
    if response.status_code not in range(200, 300):
        raise RuntimeError(
            "Unexpected response trying to get policy spec",
            policy_uuid,
            response.status_code,
            response,
        )
    return response.content
##
#
# Image Registry Tokens
#
[docs]
def image_registry_token_list(self, limit: int = None, offset: int = None):
    """
    List image registry tokens.

    All the tokens are requested from the API
    (``GET v2/registry/credentials``); the pagination is applied
    locally on the keys of the result.

    Parameters
    ----------
    limit : int, optional
        The maximum number of tokens to return.
    offset : int, optional
        The offset for pagination.

    Returns
    -------
    list
        The keys of the result, in the requested window.
    """
    credentials = self.request("GET", "v2/registry/credentials")
    tokens = [*credentials.keys()]
    if limit is None and offset is None:
        return tokens
    if limit is None:
        return tokens[offset:]
    stop = limit if offset is None else offset + limit
    return tokens[offset:stop]
[docs]
def image_registry_token_create(self, title: str, expiration: str | None = None):
    """
    Create a new image registry token (``POST v2/registry/credentials``).

    Parameters
    ----------
    title : str
        The title of the token.
    expiration : str | None, optional
        The expiration date of the token in ISO format. Defaults to None.
        Currently this argument is not sent to the API.

    Returns
    -------
    dict
        A dictionary containing the created token information.
    """
    # TODO: Fix the stelarapi to accept expiration as a string
    payload = dict(title=title)
    return self.request("POST", "v2/registry/credentials", json=payload)
[docs]
def image_registry_token_show(self, id: str):
    """
    Show an image registry token (``GET v2/registry/credentials/{id}``).

    Parameters
    ----------
    id : str
        The UUID of the token to show.

    Returns
    -------
    dict
        A dictionary containing the token information.
    """
    endpoint = f"v2/registry/credentials/{id}"
    return self.request("GET", endpoint)
[docs]
def image_registry_token_delete(self, uuid: str):
    """
    Delete an image registry token (``DELETE v2/registry/credentials/{uuid}``).

    Parameters
    ----------
    uuid : str
        The UUID of the token to delete.

    Returns
    -------
    dict
        A dictionary containing the deletion result.
    """
    endpoint = f"v2/registry/credentials/{uuid}"
    return self.request("DELETE", endpoint)
#
#
# Relationships
#
#
[docs]
def relationships_fetch(
    self,
    subject_id: str,
    rel: Optional[str] = None,
    object_id: Optional[str] = None,
    /,
):
    """
    Show relationships for a subject.

    The endpoint is ``v2/relationships/`` followed by the subject id and
    then by the relationship type and the object id, for each of the two
    that is not None. All arguments are positional-only.
    """
    optional_parts = [str(part) for part in (rel, object_id) if part is not None]
    path = "/".join([str(subject_id), *optional_parts])
    return self.request("GET", f"v2/relationships/{path}")
[docs]
def relationship_show(self, subject_id: str, rel: str, object_id: str):
    """
    Show a specific relationship.

    Parameters
    ----------
    subject_id : str
        The ID of the subject of the relationship.
    rel : str
        The type of the relationship.
    object_id : str
        The ID of the object of the relationship.

    Returns
    -------
    dict
        The first relationship returned by `relationships_fetch`.

    Raises
    ------
    EntityNotFound
        If `relationships_fetch` returns nothing.
    """
    matches = self.relationships_fetch(subject_id, rel, object_id)
    if matches:
        return matches[0]
    raise EntityNotFound(
        "Relationship",
        (subject_id, rel, object_id),
        "fetch",
    )
[docs]
def relationship_create(
    self,
    subject_id: str,
    rel: str,
    object_id: str,
    comment: Optional[str] = None,
):
    """
    Create a new relationship
    (``POST v2/relationship/{subject_id}/{rel}/{object_id}``).

    Parameters
    ----------
    subject_id : str
        The ID of the subject of the relationship.
    rel : str
        The type of the relationship.
    object_id : str
        The ID of the object of the relationship.
    comment : str, optional
        A comment for the relationship; sent in the JSON body
        (also when it is None).

    Returns
    -------
    dict
        A dictionary containing the created relationship information.
    """
    endpoint = f"v2/relationship/{subject_id}/{rel}/{object_id}"
    return self.request("POST", endpoint, json={"comment": comment})
[docs]
def relationship_update(
    self,
    subject_id: str,
    rel: str,
    object_id: str,
    comment: Optional[str] = None,
):
    """
    Update the comment of a relationship
    (``PUT v2/relationship/{subject_id}/{rel}/{object_id}``).

    Parameters
    ----------
    subject_id : str
        The ID of the subject of the relationship.
    rel : str
        The type of the relationship.
    object_id : str
        The ID of the object of the relationship.
    comment : str, optional
        A comment for the relationship; sent in the JSON body
        (also when it is None).

    Returns
    -------
    dict
        The result of the API response.
    """
    endpoint = f"v2/relationship/{subject_id}/{rel}/{object_id}"
    return self.request("PUT", endpoint, json={"comment": comment})
[docs]
def relationship_delete(self, subject_id: str, rel: str, object_id: str):
    """
    Delete a relationship
    (``DELETE v2/relationship/{subject_id}/{rel}/{object_id}``).

    The result of the API response is discarded; nothing is returned.

    Parameters
    ----------
    subject_id : str
        The ID of the subject of the relationship.
    rel : str
        The type of the relationship.
    object_id : str
        The ID of the object of the relationship.
    """
    endpoint = f"v2/relationship/{subject_id}/{rel}/{object_id}"
    self.request("DELETE", endpoint)
[docs]
def resource_lineage(self, resource_id: str, forward: bool) -> dict:
    """
    Get the forward or the backward lineage of a resource.

    Parameters
    ----------
    resource_id : str
        The ID of the resource to get the lineage for.
    forward : bool
        If True, get the forward lineage; if False, get the backward lineage.

    Returns
    -------
    dict
        A dictionary containing the lineage information.
    """
    # The API has a different endpoint for forward lineage
    # This is a temporary workaround until the API is fixed
    # to use the same endpoint for both forward and backward lineage.
    direction = "cardinality" if forward else "lineage"
    return self.request("GET", f"v2/resource/{resource_id}/{direction}")