You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
2.9 KiB
Python
105 lines
2.9 KiB
Python
from __future__ import annotations
|
|
|
|
import threading
|
|
import time
|
|
import traceback
|
|
import uuid
|
|
from collections.abc import Callable
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
|
|
@dataclass(frozen=True)
class TaskDefinition:
    """Immutable description of a task that can be registered and triggered.

    Attributes:
        name: Identifier the task is registered and triggered under.
        handler: Zero-argument callable invoked when the task is triggered.
        schedule: Optional schedule string; its format is not defined here.
        description: Optional human-readable summary of the task.
    """

    name: str
    handler: Callable[[], Any]
    schedule: str | None = field(default=None)
    description: str | None = field(default=None)
|
|
|
|
|
|
@dataclass
class TaskRun:
    """Mutable record of a single task execution attempt.

    Attributes:
        id: Run identifier (``TaskRegistry.trigger`` uses a uuid4 hex string).
        task_name: Name of the task this run belongs to.
        status: Lifecycle state; ``TaskRegistry.trigger`` uses ``"running"``,
            ``"success"``, ``"failed"`` and ``"skipped"``.
        started_at: Start time in epoch seconds (``time.time()``).
        finished_at: End time in epoch seconds, or ``None`` while unfinished.
        result: The handler's return value, if any.
        error: Failure traceback or skip reason, ``None`` when not set.
    """

    id: str
    task_name: str
    status: str
    started_at: float
    finished_at: float | None = field(default=None)
    result: Any = field(default=None)
    error: str | None = field(default=None)
|
|
|
|
|
|
@dataclass
class TaskRegistry:
    """In-memory registry of task definitions and the history of their runs.

    Tasks are executed synchronously by :meth:`trigger` on the calling thread.
    ``_lock`` guards ``tasks``, ``runs`` and ``_running`` so concurrent callers
    cannot register the same name twice or run the same task twice at once.
    The task handler itself is always invoked *outside* the lock, so a handler
    may safely call back into the registry.
    """

    # Registered task definitions, keyed by task name.
    tasks: dict[str, TaskDefinition] = field(default_factory=dict)
    # Every recorded run (including skipped ones), keyed by run id.
    # NOTE(review): this grows without bound; consider pruning old runs.
    runs: dict[str, TaskRun] = field(default_factory=dict)
    # Names of tasks whose handler is currently executing.
    _running: set[str] = field(default_factory=set)
    _lock: threading.Lock = field(default_factory=threading.Lock)

    def register(
        self,
        *,
        name: str,
        handler: Callable[[], Any],
        schedule: str | None = None,
        description: str | None = None,
    ) -> TaskDefinition:
        """Register a new task under ``name``.

        Args:
            name: Unique, non-empty task name.
            handler: Zero-argument callable executed by :meth:`trigger`.
            schedule: Optional schedule string, stored as given.
            description: Optional human-readable description.

        Returns:
            The stored :class:`TaskDefinition`.

        Raises:
            ValueError: If ``name`` is empty or already registered.
        """
        if not name:
            raise ValueError("task name is required")

        # Check-and-insert atomically so two threads cannot both register
        # the same name.
        with self._lock:
            if name in self.tasks:
                raise ValueError(f"task already registered: {name}")
            task = TaskDefinition(
                name=name,
                handler=handler,
                schedule=schedule,
                description=description,
            )
            self.tasks[name] = task
        return task

    def trigger(self, name: str) -> TaskRun:
        """Run task ``name`` synchronously and record the outcome.

        If the task is already running (on another thread, or re-entrantly
        from its own handler), it is not executed again. A run with status
        ``"skipped"`` is recorded and returned instead.

        If the handler raises, the run is marked ``"failed"`` with the
        formatted traceback in ``error``. Ordinary ``Exception`` subclasses
        are swallowed. Other ``BaseException`` subclasses (e.g.
        ``KeyboardInterrupt``, ``SystemExit``) are recorded the same way and
        then re-raised.

        Args:
            name: Name of a previously registered task.

        Returns:
            The :class:`TaskRun` recorded for this call.

        Raises:
            KeyError: If no task is registered under ``name``.
        """
        with self._lock:
            task = self.tasks.get(name)
            if task is None:
                raise KeyError(f"task not registered: {name}")

            if name in self._running:
                # One timestamp: a skipped run starts and ends at the same instant.
                now = time.time()
                skipped = TaskRun(
                    id=uuid.uuid4().hex,
                    task_name=name,
                    status="skipped",
                    started_at=now,
                    finished_at=now,
                    error="task already running",
                )
                self.runs[skipped.id] = skipped
                return skipped

            run = TaskRun(
                id=uuid.uuid4().hex,
                task_name=name,
                status="running",
                started_at=time.time(),
            )
            self.runs[run.id] = run
            # Add the name last. Nothing may raise between here and the
            # ``try`` below, or the name would stay in ``_running`` forever.
            self._running.add(name)

        try:
            run.result = task.handler()
            run.status = "success"
        except BaseException as exc:
            run.error = traceback.format_exc()
            run.status = "failed"
            # Without this, an interrupt would leave the run stuck in
            # "running". Record it, then let interpreter-level exceptions
            # keep unwinding.
            if not isinstance(exc, Exception):
                raise
        finally:
            run.finished_at = time.time()
            with self._lock:
                self._running.discard(name)

        return run

    def list_runs(self, task_name: str | None = None) -> list[TaskRun]:
        """Return recorded runs, newest ``started_at`` first.

        Args:
            task_name: If given, only runs of this task are returned.

        Returns:
            A new list of :class:`TaskRun` objects (the records themselves are
            shared, not copied).
        """
        # Snapshot under the lock: ``trigger`` inserts into ``runs`` from other
        # threads, and iterating a dict that changes size mid-iteration can
        # raise RuntimeError.
        with self._lock:
            snapshot = list(self.runs.values())
        if task_name is not None:
            snapshot = [run for run in snapshot if run.task_name == task_name]
        return sorted(snapshot, key=lambda run: run.started_at, reverse=True)
|
|
|
|
|
|
# Module-level default registry instance shared by importers of this module.
task_registry: TaskRegistry = TaskRegistry()
|