# iTi-Flask/iti/exchange/service.py
# (741 lines, 26 KiB, Python)
from __future__ import annotations
import hashlib
from dataclasses import asdict
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import Any
from sqlalchemy import select, update
from sqlalchemy.orm import Session
from iti.exceptions import BizError
from .base import (
ExchangeOperation,
ExchangeScope,
ExchangeTaskContext,
ExchangeTaskResult,
ExchangeTemplateLayout,
ExchangeTemplatePlan,
ExchangeTemplateSnapshot,
ExchangeVariable,
)
from .models import (
ExchangeTaskModel,
ExchangeTaskRowModel,
ExchangeTemplateModel,
ExchangeTemplateVersionModel,
)
from .registry import get_exchange_registry
from .tasks import get_exchange_storage
class ExchangeService:
    """Application service for import/export ("exchange") templates, versions and tasks.

    Persists through the SQLAlchemy ``Session`` passed in as ``db`` and looks up
    the exchange registry and file storage from ``app`` on demand.
    """

    def __init__(self, app, db: Session) -> None:
        # ``app`` is only forwarded to get_exchange_registry / get_exchange_storage.
        self.app = app
        self.db = db

    # ------------------------------------------------------------------
    # Templates
    # ------------------------------------------------------------------

    def generate_template_code(
        self,
        *,
        biz_domain: str,
        biz_obj: str,
        operation: ExchangeOperation | str,
    ) -> str:
        """Return the default template code for a business scope.

        The code is whatever ``ExchangeScope.code()`` yields for the scope built
        from ``biz_domain``, ``biz_obj`` and ``operation``.
        """
        return ExchangeScope.from_mapping(
            biz_domain=biz_domain,
            biz_obj=biz_obj,
            operation=operation,
        ).code()

    def create_template(
        self,
        *,
        name: str,
        biz_domain: str,
        biz_obj: str,
        operation: ExchangeOperation | str,
        code: str | None = None,
        description: str | None = None,
        layout: ExchangeTemplateLayout | dict[str, Any] | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTemplateModel:
        """Create, commit and return a new template.

        Args:
            name: Display name.
            biz_domain, biz_obj, operation: Business scope; stored as normalized
                by ``ExchangeScope.from_mapping``.
            code: Template code; defaults to the scope's ``code()``.
            description: Optional description.
            layout: Layout object or mapping; ``None`` means a default layout.
            meta: Free-form metadata; defaults to ``{}``.
        """
        scope = ExchangeScope.from_mapping(
            biz_domain=biz_domain,
            biz_obj=biz_obj,
            operation=operation,
        )
        template = ExchangeTemplateModel(
            code=code or scope.code(),
            name=name,
            biz_domain=scope.biz_domain,
            biz_obj=scope.biz_obj,
            operation=_operation_value(scope.operation),
            description=description,
            layout=asdict(_coerce_layout(layout)),
            meta=meta or {},
        )
        self.db.add(template)
        self.db.commit()
        self.db.refresh(template)
        return template

    def update_template(
        self,
        template_id: str,
        *,
        name: str | None = None,
        description: str | None = None,
        status: str | None = None,
        current_version: str | None = None,
        layout: ExchangeTemplateLayout | dict[str, Any] | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTemplateModel:
        """Partially update a template and commit.

        Only arguments that are not ``None`` are applied, so a field cannot be
        cleared to ``None`` through this method. ``current_version`` is stored
        as given and is not checked against existing versions.

        Raises:
            BizError: 404 if the template does not exist.
        """
        template = self.get_template_or_404(template_id)
        if name is not None:
            template.name = name
        if description is not None:
            template.description = description
        if status is not None:
            template.status = status
        if current_version is not None:
            template.current_version = current_version
        if layout is not None:
            template.layout = asdict(_coerce_layout(layout))
        if meta is not None:
            template.meta = meta
        self.db.commit()
        self.db.refresh(template)
        return template

    def delete_template(self, template_id: str) -> None:
        """Delete a template and the stored files of all of its versions.

        Tasks that reference the template, or any of its versions, are kept
        but have ``template_id`` / ``template_version_id`` cleared.

        Stored files are removed only after the database commit succeeds. A
        failed commit therefore never leaves versions pointing at deleted files.
        If a storage deletion fails after the commit, the error propagates and
        the remaining files are left orphaned.

        Raises:
            BizError: 404 if the template does not exist.
        """
        template = self.get_template_or_404(template_id)
        versions = list(template.versions)
        version_ids = [version.id for version in versions]
        # Capture keys now: the ORM instances are expired/deleted after commit.
        file_keys = [version.file_key for version in versions if version.file_key]
        if version_ids:
            self.db.execute(
                update(ExchangeTaskModel)
                .where(ExchangeTaskModel.template_version_id.in_(version_ids))
                .values(template_version_id=None)
            )
        self.db.execute(
            update(ExchangeTaskModel)
            .where(ExchangeTaskModel.template_id == template.id)
            .values(template_id=None, template_version_id=None)
        )
        self.db.delete(template)
        self.db.commit()
        if file_keys:
            storage = get_exchange_storage(self.app)
            for file_key in file_keys:
                storage.delete(file_key)

    # ------------------------------------------------------------------
    # Versions and template files
    # ------------------------------------------------------------------

    def publish_version(
        self,
        *,
        template_id: str,
        version: str,
        variables: list[ExchangeVariable] | None = None,
        layout: ExchangeTemplateLayout | dict[str, Any] | None = None,
        file_content: bytes | None = None,
        file_name: str | None = None,
        meta: dict[str, Any] | None = None,
        make_current: bool = True,
    ) -> ExchangeTemplateVersionModel:
        """Record a new version snapshot of a template and commit.

        Args:
            template_id: Owning template.
            version: Version label stored on the snapshot.
            variables: Column variables. When ``None``, they are taken from the
                registered spec for the template's scope, or ``[]`` if no spec
                is registered.
            layout: Layout for this version; defaults to the template's layout.
            file_content: Optional uploaded template workbook. When given, it
                is stored via ``save_template_file`` and its SHA-256 hex digest
                is recorded as ``checksum``.
            file_name: Original file name, used only to pick the stored suffix.
            meta: Free-form metadata; defaults to ``{}``.
            make_current: When true, the template's ``current_version`` is set
                to ``version`` and its status to ``"published"``.

        Raises:
            BizError: 404 if the template does not exist.

        NOTE(review): the file is uploaded before the DB commit. This method
        does not check whether ``version`` already exists for the template.
        """
        template = self.get_template_or_404(template_id)
        resolved_layout = _coerce_layout(layout if layout is not None else template.layout)
        resolved_variables = variables
        if resolved_variables is None:
            spec = get_exchange_registry(self.app).get_spec(
                biz_domain=template.biz_domain,
                biz_obj=template.biz_obj,
                operation=template.operation,
            )
            resolved_variables = list(spec.variables) if spec is not None else []
        file_key = None
        checksum = None
        if file_content is not None:
            file_key = self.save_template_file(
                template=template,
                version=version,
                content=file_content,
                file_name=file_name,
            )
            checksum = hashlib.sha256(file_content).hexdigest()
        snapshot = ExchangeTemplateVersionModel(
            template_id=template.id,
            version=version,
            biz_domain=template.biz_domain,
            biz_obj=template.biz_obj,
            operation=template.operation,
            published_at=datetime.now(),
            file_key=file_key,
            checksum=checksum,
            layout=asdict(resolved_layout),
            # Enums/tuples are flattened so the list is JSON-friendly.
            variables=[_jsonable(asdict(item)) for item in resolved_variables],
            meta=meta or {},
        )
        self.db.add(snapshot)
        if make_current:
            template.current_version = version
            template.status = "published"
        self.db.commit()
        self.db.refresh(snapshot)
        return snapshot

    def delete_version(self, version_id: str) -> None:
        """Delete one template version and its stored file.

        Tasks pinned to the version have ``template_version_id`` cleared. If
        the deleted version was the template's current one, the template moves
        to the newest remaining version (by ``created_at``, then
        ``updated_at``). If no other version exists, ``current_version`` becomes
        ``None`` and status ``"draft"``.

        The stored file is removed only after the database commit succeeds.

        Raises:
            BizError: 404 if the version or its template does not exist.
        """
        version = self.get_version_or_404(version_id)
        template = self.get_template_or_404(version.template_id)
        # Read everything needed from ``version`` before it is deleted/expired.
        file_key = version.file_key
        is_current = template.current_version == version.version
        next_current = None
        if is_current:
            # Avoid flushing pending session state during the fallback lookup.
            with self.db.no_autoflush:
                next_current = self.db.scalar(
                    select(ExchangeTemplateVersionModel)
                    .where(ExchangeTemplateVersionModel.template_id == template.id)
                    .where(ExchangeTemplateVersionModel.id != version.id)
                    .order_by(
                        ExchangeTemplateVersionModel.created_at.desc(),
                        ExchangeTemplateVersionModel.updated_at.desc(),
                    )
                )
        self.db.execute(
            update(ExchangeTaskModel)
            .where(ExchangeTaskModel.template_version_id == version.id)
            .values(template_version_id=None)
        )
        self.db.delete(version)
        if is_current:
            template.current_version = next_current.version if next_current is not None else None
            if next_current is None:
                template.status = "draft"
        self.db.commit()
        if file_key:
            get_exchange_storage(self.app).delete(file_key)

    def build_template_file(self, version_id: str) -> bytes:
        """Return the template workbook bytes for a version.

        Returns the uploaded file when the version has a ``file_key``.
        Otherwise generates a workbook from the version snapshot with
        ``ExcelTemplateCodec.dump``.

        Raises:
            BizError: 404 if the version (or its template) does not exist.
        """
        version = self.get_version_or_404(version_id)
        if version.file_key:
            storage = get_exchange_storage(self.app)
            with storage.download(version.file_key) as file_stream:
                return file_stream.read()
        return _excel_template_codec().dump(self.snapshot_from_model(version))

    def build_plan_template_file(self, plan: ExchangeTemplatePlan) -> bytes:
        """Return template workbook bytes for a resolved plan.

        If ``plan.version_id`` names a version that has a stored file, that
        file is returned. Otherwise the plan itself is dumped. Unlike
        ``build_template_file``, a missing version is not an error here.
        """
        if plan.version_id:
            version = self.db.get(ExchangeTemplateVersionModel, plan.version_id)
            if version is not None and version.file_key:
                storage = get_exchange_storage(self.app)
                with storage.download(version.file_key) as file_stream:
                    return file_stream.read()
        return _excel_template_codec().dump(plan)

    def save_template_file(
        self,
        *,
        template: ExchangeTemplateModel,
        version: str,
        content: bytes,
        file_name: str | None = None,
    ) -> str:
        """Upload a template workbook and return its storage key.

        The key has the form
        ``exchange/templates/{template.code}/{version}/{sha256}.{suffix}``. The
        suffix is derived from ``file_name`` (default ``xlsx``), and the upload
        always uses the xlsx MIME type.

        NOTE(review): ``template.code`` and ``version`` are embedded verbatim;
        confirm they cannot contain ``/`` or ``..``.
        """
        suffix = _safe_suffix(file_name or "template.xlsx")
        digest = hashlib.sha256(content).hexdigest()
        key = f"exchange/templates/{template.code}/{version}/{digest}.{suffix}"
        storage = get_exchange_storage(self.app)
        storage.upload(BytesIO(content), key, _excel_mime_type())
        return key

    # ------------------------------------------------------------------
    # Row import / export
    # ------------------------------------------------------------------

    def export_rows(
        self,
        rows: list[dict[str, Any]],
        *,
        plan: ExchangeTemplatePlan | None = None,
        variables: list[ExchangeVariable] | None = None,
        sheet_name: str | None = None,
    ) -> bytes:
        """Serialize ``rows`` to workbook bytes.

        Precedence:

        1. ``plan``: columns and layout from the plan. ``sheet_name`` is passed
           through as given, possibly ``None``.
        2. ``variables``: columns from the variables.
        3. Otherwise the headers are the keys of the first row.

        Cases 2 and 3 default ``sheet_name`` to ``"Export"``.
        """
        workbook_codec = _excel_workbook_codec()
        if plan is not None:
            return workbook_codec.export_rows_with_plan(
                plan=plan,
                rows=rows,
                sheet_name=sheet_name,
            )
        if variables is not None:
            return workbook_codec.export_rows_with_variables(
                variables=variables,
                rows=rows,
                sheet_name=sheet_name or "Export",
            )
        if not rows:
            return workbook_codec.export_rows([], [], sheet_name=sheet_name or "Export")
        headers = list(rows[0].keys())
        return workbook_codec.export_rows(headers, rows, sheet_name=sheet_name or "Export")

    def import_rows(
        self,
        content: bytes,
        *,
        plan: ExchangeTemplatePlan | None = None,
        variables: list[ExchangeVariable] | None = None,
    ) -> list[dict[str, Any]]:
        """Parse workbook bytes into a list of row dicts.

        If ``plan`` has variables, those variables and the plan layout's
        ``header_row`` / ``data_start_row`` are used. Otherwise ``variables``
        are used with the codec's default rows when given. With neither, a
        plain import is performed.
        """
        workbook_codec = _excel_workbook_codec()
        if plan is not None and plan.variables:
            return workbook_codec.import_rows_with_variables(
                content,
                variables=list(plan.variables),
                header_row=plan.layout.header_row,
                data_start_row=plan.layout.data_start_row,
            )
        if variables is not None:
            return workbook_codec.import_rows_with_variables(content, variables=variables)
        return workbook_codec.import_rows(content)

    # ------------------------------------------------------------------
    # Tasks
    # ------------------------------------------------------------------

    def create_task(
        self,
        *,
        biz_domain: str,
        biz_obj: str,
        operation: ExchangeOperation | str,
        template_id: str | None = None,
        version_id: str | None = None,
        version: str | None = None,
        requested_by: str | None = None,
        storage_key: str | None = None,
        input_payload: dict[str, Any] | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTaskModel:
        """Create and commit a ``"pending"`` task.

        The plan is resolved via ``resolve_plan``. The task records that plan's
        ``template_id``, ``version_id`` and scope.
        """
        plan = self.resolve_plan(
            biz_domain=biz_domain,
            biz_obj=biz_obj,
            operation=operation,
            template_id=template_id,
            version_id=version_id,
            version=version,
        )
        task = ExchangeTaskModel(
            template_id=plan.template_id,
            template_version_id=plan.version_id,
            biz_domain=plan.scope.biz_domain,
            biz_obj=plan.scope.biz_obj,
            operation=_operation_value(plan.scope.operation),
            status="pending",
            requested_by=requested_by,
            storage_key=storage_key,
            input_payload=input_payload or {},
            meta=meta or {},
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        return task

    def run_task(self, task_id: str) -> ExchangeTaskModel:
        """Execute a task through its registered scope handler.

        Steps:

        1. The task is marked ``"running"``.
        2. The plan is the snapshot of the task's pinned version, if it still
           exists. Otherwise it is re-resolved from the task's scope and
           ``template_id``.
        3. The scope handler is called with an ``ExchangeTaskContext``. A dict
           result is converted to ``ExchangeTaskResult``.
        4. The task is marked ``"success"`` with the handler's counts and
           payload.

        Any exception after step 1 rolls back the session and marks the task
        ``"failed"`` with ``str(exc)`` as its message. The finished task is
        returned instead of the exception being raised.

        Raises:
            BizError: 404 if the task does not exist (raised before the handler
                runs).
        """
        task = self.mark_task_running(task_id)
        try:
            snapshot = (
                self.get_snapshot_by_version_id(task.template_version_id)
                if task.template_version_id
                else None
            )
            plan = snapshot or self.resolve_plan(
                biz_domain=task.biz_domain,
                biz_obj=task.biz_obj,
                operation=task.operation,
                template_id=task.template_id,
            )
            handler = get_exchange_registry(self.app).get_scope_handler(
                biz_domain=task.biz_domain,
                biz_obj=task.biz_obj,
                operation=task.operation,
            )
            if handler is None:
                raise BizError("导入导出处理器未注册", code=404)
            result = handler(
                ExchangeTaskContext(
                    task_id=task.id,
                    plan=plan,
                    snapshot=snapshot,
                    storage_key=task.storage_key,
                    payload=task.input_payload,
                    requested_by=task.requested_by,
                )
            )
            if isinstance(result, dict):
                result = ExchangeTaskResult(**result)
            return self.mark_task_finished(
                task_id,
                status="success",
                message=result.message,
                result_payload=result.result_payload,
                success_count=result.success_count,
                failed_count=result.failed_count,
            )
        except Exception as exc:
            # The failure may have left the session unusable (e.g. a failed
            # flush) or holding half-done handler changes. Roll back so only
            # the failure status gets committed.
            self.db.rollback()
            return self.mark_task_finished(
                task_id,
                status="failed",
                message=str(exc),
            )

    # ------------------------------------------------------------------
    # Snapshots and plan resolution
    # ------------------------------------------------------------------

    def get_snapshot(self, *, template_id: str, version: str) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot for ``version`` of a template, or ``None`` if absent."""
        version_model = self.db.scalar(
            select(ExchangeTemplateVersionModel)
            .where(ExchangeTemplateVersionModel.template_id == template_id)
            .where(ExchangeTemplateVersionModel.version == version)
        )
        if version_model is None:
            return None
        return self.snapshot_from_model(version_model)

    def get_snapshot_by_version_id(self, version_id: str | None) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot for a version id, or ``None`` if the id is ``None``/unknown."""
        if version_id is None:
            return None
        version_model = self.db.get(ExchangeTemplateVersionModel, version_id)
        if version_model is None:
            return None
        return self.snapshot_from_model(version_model)

    def get_current_snapshot(self, template_id: str) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot of a template's ``current_version``, if it has one."""
        template = self.db.get(ExchangeTemplateModel, template_id)
        if template is None or not template.current_version:
            return None
        return self.get_snapshot(template_id=template_id, version=template.current_version)

    def resolve_plan(
        self,
        *,
        biz_domain: str,
        biz_obj: str,
        operation: ExchangeOperation | str,
        template_id: str | None = None,
        version_id: str | None = None,
        version: str | None = None,
        code: str | None = None,
        name: str | None = None,
        description: str | None = None,
        layout: ExchangeTemplateLayout | dict[str, Any] | None = None,
        variables: list[ExchangeVariable] | None = None,
        source: Any | None = None,
    ) -> ExchangeTemplatePlan:
        """Resolve the template plan to use for a scope.

        Resolution order (the first hit wins):

        1. ``source``: delegate entirely to ``source.resolve_plan(...)``.
        2. ``version_id``: that version's snapshot, if found.
        3. ``template_id`` + ``version``: that version's snapshot, if found.
        4. ``template_id``: the current snapshot; otherwise a plan built from
           the template row. This step raises a 404 ``BizError`` if the
           template is missing.
        5. Template looked up by scope (and ``code`` if given): its current
           snapshot, or a plan built from it.
        6. The registered spec for the scope (``spec.to_plan()``).
        7. A plan assembled from the remaining arguments.
        """
        if source is not None:
            return source.resolve_plan(
                biz_domain=biz_domain,
                biz_obj=biz_obj,
                operation=operation,
                template_id=template_id,
                version_id=version_id,
                version=version,
                code=code,
                name=name,
                description=description,
                layout=layout,
                variables=variables,
            )
        if version_id:
            snapshot = self.get_snapshot_by_version_id(version_id)
            if snapshot is not None:
                return snapshot
        if template_id and version:
            snapshot = self.get_snapshot(template_id=template_id, version=version)
            if snapshot is not None:
                return snapshot
        if template_id:
            current = self.get_current_snapshot(template_id)
            if current is not None:
                return current
            template = self.get_template_or_404(template_id)
            return self.plan_from_template_model(template)
        template = self.get_template_by_scope(
            biz_domain=biz_domain,
            biz_obj=biz_obj,
            operation=operation,
            code=code,
        )
        if template is not None:
            current = self.get_current_snapshot(template.id)
            if current is not None:
                return current
            return self.plan_from_template_model(template)
        spec = get_exchange_registry(self.app).get_spec(
            biz_domain=biz_domain,
            biz_obj=biz_obj,
            operation=_operation_value(operation),
        )
        if spec is not None:
            return spec.to_plan()
        return ExchangeTemplatePlan.from_mapping(
            biz_domain=biz_domain,
            biz_obj=biz_obj,
            operation=operation,
            code=code,
            name=name,
            description=description,
            layout=layout,
            variables=variables,
        )

    # ------------------------------------------------------------------
    # Task state
    # ------------------------------------------------------------------

    def mark_task_running(self, task_id: str) -> ExchangeTaskModel:
        """Set a task's status to ``"running"``, stamp ``started_at`` and commit.

        Raises:
            BizError: 404 if the task does not exist.
        """
        task = self.get_task_or_404(task_id)
        task.status = "running"
        task.started_at = datetime.now()
        self.db.commit()
        self.db.refresh(task)
        return task

    def mark_task_finished(
        self,
        task_id: str,
        *,
        status: str = "success",
        message: str | None = None,
        result_payload: dict[str, Any] | None = None,
        success_count: int | None = None,
        failed_count: int | None = None,
    ) -> ExchangeTaskModel:
        """Record a task's final state and commit.

        ``status``, ``message`` and ``finished_at`` are always written.
        ``message`` is overwritten even when ``None``. The payload and counters
        are written only when not ``None``. ``error_count`` mirrors
        ``failed_count``.

        Raises:
            BizError: 404 if the task does not exist.
        """
        task = self.get_task_or_404(task_id)
        task.status = status
        task.message = message
        task.finished_at = datetime.now()
        if result_payload is not None:
            task.result_payload = result_payload
        if success_count is not None:
            task.success_count = success_count
        if failed_count is not None:
            task.failed_count = failed_count
            task.error_count = failed_count
        self.db.commit()
        self.db.refresh(task)
        return task

    def add_task_row(
        self,
        *,
        task_id: str,
        row_index: int,
        status: str,
        data: dict[str, Any] | None = None,
        message: str | None = None,
        result: dict[str, Any] | None = None,
    ) -> ExchangeTaskRowModel:
        """Persist (and commit) one per-row outcome for a task.

        ``data`` and ``result`` default to ``{}``. The task id is not validated
        here.
        """
        row = ExchangeTaskRowModel(
            task_id=task_id,
            row_index=row_index,
            status=status,
            data=data or {},
            message=message,
            result=result or {},
        )
        self.db.add(row)
        self.db.commit()
        self.db.refresh(row)
        return row

    # ------------------------------------------------------------------
    # Lookups
    # ------------------------------------------------------------------

    def get_template_or_404(self, template_id: str) -> ExchangeTemplateModel:
        """Return a template by id or raise ``BizError`` (code 404)."""
        template = self.db.get(ExchangeTemplateModel, template_id)
        if template is None:
            raise BizError("模板不存在", code=404)
        return template

    def get_template_by_scope(
        self,
        *,
        biz_domain: str,
        biz_obj: str,
        operation: ExchangeOperation | str,
        code: str | None = None,
    ) -> ExchangeTemplateModel | None:
        """Return the newest template for a scope (optionally a specific code), or ``None``.

        "Newest" means the latest ``created_at``, with ``updated_at`` as the
        tie-break.
        """
        statement = (
            select(ExchangeTemplateModel)
            .where(ExchangeTemplateModel.biz_domain == biz_domain)
            .where(ExchangeTemplateModel.biz_obj == biz_obj)
            .where(ExchangeTemplateModel.operation == _operation_value(operation))
        )
        if code is not None:
            statement = statement.where(ExchangeTemplateModel.code == code)
        return self.db.scalar(
            statement.order_by(
                ExchangeTemplateModel.created_at.desc(),
                ExchangeTemplateModel.updated_at.desc(),
            )
        )

    def get_version_or_404(self, version_id: str) -> ExchangeTemplateVersionModel:
        """Return a template version by id or raise ``BizError`` (code 404)."""
        version = self.db.get(ExchangeTemplateVersionModel, version_id)
        if version is None:
            raise BizError("模板版本不存在", code=404)
        return version

    def get_task_or_404(self, task_id: str) -> ExchangeTaskModel:
        """Return a task by id or raise ``BizError`` (code 404)."""
        task = self.db.get(ExchangeTaskModel, task_id)
        if task is None:
            raise BizError("导入导出任务不存在", code=404)
        return task

    def list_templates(
        self,
        *,
        biz_domain: str | None = None,
        biz_obj: str | None = None,
        operation: ExchangeOperation | str | None = None,
    ) -> list[ExchangeTemplateModel]:
        """List templates, optionally filtered by scope parts.

        The result is ordered by domain, object, operation and code.
        ``operation`` may be an ``ExchangeOperation`` or its string value.
        """
        statement = select(ExchangeTemplateModel)
        if biz_domain is not None:
            statement = statement.where(ExchangeTemplateModel.biz_domain == biz_domain)
        if biz_obj is not None:
            statement = statement.where(ExchangeTemplateModel.biz_obj == biz_obj)
        if operation is not None:
            # Same normalization as get_template_by_scope.
            statement = statement.where(
                ExchangeTemplateModel.operation == _operation_value(operation)
            )
        return list(
            self.db.scalars(
                statement.order_by(
                    ExchangeTemplateModel.biz_domain,
                    ExchangeTemplateModel.biz_obj,
                    ExchangeTemplateModel.operation,
                    ExchangeTemplateModel.code,
                )
            )
        )

    def list_versions(self, template_id: str) -> list[ExchangeTemplateVersionModel]:
        """List a template's versions, newest first."""
        return list(
            self.db.scalars(
                select(ExchangeTemplateVersionModel)
                .where(ExchangeTemplateVersionModel.template_id == template_id)
                .order_by(
                    ExchangeTemplateVersionModel.created_at.desc(),
                    ExchangeTemplateVersionModel.updated_at.desc(),
                )
            )
        )

    def list_tasks(self, template_id: str | None = None) -> list[ExchangeTaskModel]:
        """List tasks newest first, optionally only those of one template."""
        statement = select(ExchangeTaskModel).order_by(ExchangeTaskModel.created_at.desc())
        if template_id is not None:
            statement = statement.where(ExchangeTaskModel.template_id == template_id)
        return list(self.db.scalars(statement))

    def list_task_rows(self, task_id: str) -> list[ExchangeTaskRowModel]:
        """List a task's row results ordered by ``row_index``."""
        return list(
            self.db.scalars(
                select(ExchangeTaskRowModel)
                .where(ExchangeTaskRowModel.task_id == task_id)
                .order_by(ExchangeTaskRowModel.row_index)
            )
        )

    # ------------------------------------------------------------------
    # Model -> plan conversion
    # ------------------------------------------------------------------

    def plan_from_template_model(self, template: ExchangeTemplateModel) -> ExchangeTemplatePlan:
        """Build a plan from a template row that has no usable version snapshot.

        Variables come from the registered spec for the template's scope, or
        are empty if no spec is registered. The layout comes from the template
        row.
        """
        spec = get_exchange_registry(self.app).get_spec(
            biz_domain=template.biz_domain,
            biz_obj=template.biz_obj,
            operation=template.operation,
        )
        variables = tuple(spec.variables) if spec is not None else ()
        return ExchangeTemplatePlan(
            scope=ExchangeScope.from_mapping(
                biz_domain=template.biz_domain,
                biz_obj=template.biz_obj,
                operation=template.operation,
            ),
            code=template.code,
            name=template.name,
            description=template.description,
            template_id=template.id,
            version=template.current_version,
            layout=_coerce_layout(template.layout),
            variables=variables,
        )

    def snapshot_from_model(
        self, version: ExchangeTemplateVersionModel
    ) -> ExchangeTemplateSnapshot:
        """Convert a stored version row into an ``ExchangeTemplateSnapshot``.

        Scope, layout and variables come from the version row. Code, name and
        description come from the owning template. A NULL ``variables`` column
        yields an empty tuple.

        Raises:
            BizError: 404 if the owning template no longer exists.
        """
        template = self.get_template_or_404(version.template_id)
        return ExchangeTemplateSnapshot(
            scope=ExchangeScope.from_mapping(
                biz_domain=version.biz_domain,
                biz_obj=version.biz_obj,
                operation=version.operation,
            ),
            code=template.code,
            name=template.name,
            description=template.description,
            template_id=version.template_id,
            version_id=version.id,
            version=version.version,
            layout=_coerce_layout(version.layout),
            variables=tuple(_variable_from_dict(item) for item in version.variables or ()),
            published_at=version.published_at.isoformat() if version.published_at else None,
            file_key=version.file_key,
            checksum=version.checksum,
        )
def _variable_from_dict(value: dict[str, Any]) -> ExchangeVariable:
    """Rebuild an ``ExchangeVariable`` from its stored dict form.

    ``key`` and ``label`` are mandatory (``KeyError`` if missing). ``header``,
    ``description`` and ``example`` default to ``None``, and ``required`` is
    coerced to ``bool`` (default ``False``). No other keys are read.
    """
    key = value["key"]
    label = value["label"]
    lookup = value.get
    return ExchangeVariable(
        key=key,
        label=label,
        header=lookup("header"),
        description=lookup("description"),
        required=bool(lookup("required", False)),
        example=lookup("example"),
    )
def _coerce_layout(value: ExchangeTemplateLayout | dict[str, Any] | None) -> ExchangeTemplateLayout:
    """Normalize a layout argument into an ``ExchangeTemplateLayout``.

    * A layout instance is returned unchanged.
    * ``None`` yields a default layout.
    * A mapping is read key by key:

      - ``title_row`` defaults to 1 when the key is absent (an explicit
        ``None`` is kept).
      - ``header_row`` is ``int()``-converted, and any falsy value becomes 2.
      - ``data_start_row`` is passed through as stored.
    """
    if isinstance(value, ExchangeTemplateLayout):
        return value
    if value is None:
        return ExchangeTemplateLayout()
    header_row = int(value.get("header_row") or 2)
    return ExchangeTemplateLayout(
        title=value.get("title"),
        sheet_name=value.get("sheet_name"),
        title_row=value.get("title_row", 1),
        header_row=header_row,
        data_start_row=value.get("data_start_row"),
    )
def _operation_value(value: ExchangeOperation | str) -> str:
return value.value if hasattr(value, "value") else str(value)
def _safe_suffix(file_name: str) -> str:
if "." not in file_name:
return "xlsx"
suffix = file_name.rsplit(".", 1)[-1].lower()
return "".join(ch for ch in suffix if ch.isalnum()) or "xlsx"
def _excel_mime_type() -> str:
return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
def _excel_template_codec():
    """Return a new ``ExcelTemplateCodec``.

    The import is deferred to call time, presumably to keep the Excel
    dependency out of module import. TODO confirm.
    """
    from .excel import ExcelTemplateCodec as codec_cls

    codec = codec_cls()
    return codec
def _excel_workbook_codec():
    """Return a new ``ExcelWorkbookCodec``.

    The import is deferred to call time, presumably to keep the Excel
    dependency out of module import. TODO confirm.
    """
    from .excel import ExcelWorkbookCodec as codec_cls

    codec = codec_cls()
    return codec
def _jsonable(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
if isinstance(value, dict):
return {key: _jsonable(item) for key, item in value.items()}
if isinstance(value, list):
return [_jsonable(item) for item in value]
if isinstance(value, tuple):
return [_jsonable(item) for item in value]
return value