You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
iTi-Flask/iti/exchange/service.py

514 lines
18 KiB
Python

from __future__ import annotations
import hashlib
from dataclasses import asdict
from datetime import datetime
from io import BytesIO
from enum import Enum
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import Session
from iti.exceptions import BizError
from .base import (
ExchangeField,
ExchangePlaceholder,
ExchangeTemplateBinding,
ExchangeTemplateKind,
ExchangePlan,
ExchangeTemplateSnapshot,
ExchangeTaskKind,
)
from .excel import ExcelTemplateCodec, ExcelWorkbookCodec
from .models import (
ExchangeTaskModel,
ExchangeTaskRowModel,
ExchangeTemplateModel,
ExchangeTemplateVersionModel,
)
from .tasks import get_exchange_storage
class ExchangeService:
    """Service layer for exchange (Excel import/export) templates, versions and tasks.

    Persists template, template-version, task and task-row models through the
    given SQLAlchemy session, and stores template workbooks through the storage
    returned by ``get_exchange_storage(app)``. Every method that changes the
    database commits the session before returning.
    """

    def __init__(self, app, db: Session) -> None:
        # ``app`` is used only to look up the exchange file storage.
        self.app = app
        self.db = db

    # ------------------------------------------------------------------
    # Templates and versions
    # ------------------------------------------------------------------

    def create_template(
        self,
        *,
        code: str,
        name: str,
        template_kind: ExchangeTemplateKind | str,
        entity: str,
        description: str | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTemplateModel:
        """Create, commit and return a new template row.

        ``template_kind`` may be an ``ExchangeTemplateKind`` member or its raw
        value; the raw value is what gets stored.
        """
        template = ExchangeTemplateModel(
            code=code,
            name=name,
            template_kind=_enum_value(template_kind),
            entity=entity,
            description=description,
            meta=meta or {},
        )
        self.db.add(template)
        self.db.commit()
        self.db.refresh(template)
        return template

    def update_template(
        self,
        template_id: str,
        *,
        name: str | None = None,
        description: str | None = None,
        status: str | None = None,
        current_version: str | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTemplateModel:
        """Apply a partial update to a template and commit it.

        Only arguments that are not ``None`` are written, so a field cannot be
        reset to ``None`` through this method.
        NOTE(review): ``current_version`` is not checked against the versions
        actually published for this template.

        Raises:
            BizError: 404 if the template does not exist.
        """
        template = self.get_template_or_404(template_id)
        if name is not None:
            template.name = name
        if description is not None:
            template.description = description
        if status is not None:
            template.status = status
        if current_version is not None:
            template.current_version = current_version
        if meta is not None:
            template.meta = meta
        self.db.commit()
        self.db.refresh(template)
        return template

    def publish_version(
        self,
        *,
        template_id: str,
        version: str,
        bindings: list[ExchangeTemplateBinding] | None = None,
        fields: list[ExchangeField] | None = None,
        placeholders: list[ExchangePlaceholder] | None = None,
        file_content: bytes | None = None,
        file_name: str | None = None,
        meta: dict[str, Any] | None = None,
        make_current: bool = True,
    ) -> ExchangeTemplateVersionModel:
        """Store a new version snapshot of a template and commit it.

        If ``file_content`` is given, it is uploaded via ``save_template_file``.
        Its storage key and SHA-256 checksum are recorded on the version.
        Bindings, fields and placeholders are stored as JSON-friendly dicts.
        With ``make_current`` (the default), the template's ``current_version``
        is set to ``version`` and its status to ``"published"``.

        Raises:
            BizError: 404 if the template does not exist; 400 if ``version``
                already exists for this template.
        """
        template = self.get_template_or_404(template_id)
        # Reject duplicates before uploading anything. This leaves no orphaned
        # file behind, and get_snapshot() (which takes the first matching row)
        # never has to choose between two rows for the same version.
        if self._find_version_model(template_id=template.id, version=version) is not None:
            raise BizError("模板版本已存在", code=400)
        file_key = None
        checksum = None
        if file_content is not None:
            file_key = self.save_template_file(
                template=template,
                version=version,
                content=file_content,
                file_name=file_name,
            )
            checksum = hashlib.sha256(file_content).hexdigest()
        snapshot = ExchangeTemplateVersionModel(
            template_id=template.id,
            version=version,
            template_kind=template.template_kind,
            published_at=datetime.now(),
            file_key=file_key,
            checksum=checksum,
            bindings=[_jsonable(asdict(item)) for item in bindings or []],
            fields=[_jsonable(asdict(item)) for item in fields or []],
            placeholders=[_jsonable(asdict(item)) for item in placeholders or []],
            meta=meta or {},
        )
        self.db.add(snapshot)
        if make_current:
            template.current_version = version
            template.status = "published"
        self.db.commit()
        self.db.refresh(snapshot)
        return snapshot

    def _download_file(self, file_key: str) -> bytes:
        """Return the full content of a file held in the exchange storage."""
        storage = get_exchange_storage(self.app)
        with storage.download(file_key) as file_stream:
            return file_stream.read()

    def build_template_file(self, version_id: str) -> bytes:
        """Return the workbook bytes for a stored template version.

        An uploaded file is returned verbatim. Otherwise a workbook is
        generated from the version snapshot with ``ExcelTemplateCodec``.

        Raises:
            BizError: 404 if the version does not exist.
        """
        version = self.get_version_or_404(version_id)
        if version.file_key:
            return self._download_file(version.file_key)
        snapshot = self.snapshot_from_model(version)
        return ExcelTemplateCodec().dump(snapshot)

    def build_plan_template_file(self, plan: ExchangePlan) -> bytes:
        """Return the workbook bytes for a resolved plan.

        If the plan points at a stored version that has an uploaded file,
        that file is returned verbatim. Otherwise the workbook is generated
        from the plan itself.
        """
        if plan.version_id:
            # Only file_key is needed, so read the row directly instead of
            # building a full snapshot.
            version = self.db.get(ExchangeTemplateVersionModel, plan.version_id)
            if version is not None and version.file_key:
                return self._download_file(version.file_key)
        return ExcelTemplateCodec().dump(plan)

    # ------------------------------------------------------------------
    # Row import / export
    # ------------------------------------------------------------------

    def export_rows(
        self,
        rows: list[dict[str, Any]],
        *,
        plan: ExchangePlan | None = None,
        fields: list[ExchangeField] | None = None,
        sheet_name: str | None = None,
    ) -> bytes:
        """Serialize ``rows`` into workbook bytes.

        Precedence: ``plan`` > ``fields`` > plain export. A plain export takes
        its headers from the first row's keys. The sheet name defaults to
        ``"Export"``, except on the plan path, where ``sheet_name`` is passed
        through unchanged.
        """
        workbook_codec = ExcelWorkbookCodec()
        if plan is not None:
            return workbook_codec.export_rows_with_plan(
                plan=plan,
                rows=rows,
                sheet_name=sheet_name,
            )
        resolved_sheet = sheet_name or "Export"
        if fields is not None:
            return workbook_codec.export_rows_with_template(
                fields=fields,
                rows=rows,
                sheet_name=resolved_sheet,
            )
        # NOTE(review): keys that appear only in later rows get no column.
        headers = list(rows[0].keys()) if rows else []
        return workbook_codec.export_rows(headers, rows, sheet_name=resolved_sheet)

    def import_rows(
        self,
        content: bytes,
        *,
        plan: ExchangePlan | None = None,
        fields: list[ExchangeField] | None = None,
    ) -> list[dict[str, Any]]:
        """Parse workbook ``content`` into a list of row dicts.

        Field definitions come from ``plan.fields`` when it is non-empty,
        otherwise from ``fields``. With neither, the codec's plain import
        is used.
        """
        workbook_codec = ExcelWorkbookCodec()
        if plan is not None and plan.fields:
            return workbook_codec.import_rows_with_fields(content, fields=list(plan.fields))
        if fields is not None:
            return workbook_codec.import_rows_with_fields(content, fields=fields)
        return workbook_codec.import_rows(content)

    def save_template_file(
        self,
        *,
        template: ExchangeTemplateModel,
        version: str,
        content: bytes,
        file_name: str | None = None,
    ) -> str:
        """Upload a template workbook and return its storage key.

        The key is ``exchange/templates/<code>/<version>/<sha256>.<suffix>``.
        The suffix is derived from ``file_name`` (default ``xlsx``), so
        identical content for the same template version maps to the same key.
        NOTE(review): the upload always uses the .xlsx MIME type, whatever
        the suffix.
        """
        suffix = _safe_suffix(file_name or "template.xlsx")
        key = f"exchange/templates/{template.code}/{version}/{hashlib.sha256(content).hexdigest()}.{suffix}"
        storage = get_exchange_storage(self.app)
        storage.upload(BytesIO(content), key, _excel_mime_type())
        return key

    # ------------------------------------------------------------------
    # Tasks
    # ------------------------------------------------------------------

    def create_task(
        self,
        *,
        template_id: str | None = None,
        version_id: str | None = None,
        version: str | None = None,
        task_kind: ExchangeTaskKind | str,
        requested_by: str | None = None,
        storage_key: str | None = None,
        input_payload: dict[str, Any] | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTaskModel:
        """Create a ``pending`` task, resolving its template and version.

        Resolution rules:

        * ``template_id`` and ``version_id`` together must refer to the same
          template.
        * ``version_id`` alone: the template is taken from the version.
        * A template without ``version_id``: the version is looked up by
          ``version`` if given, else by the template's ``current_version``.
          A missing ``current_version`` row leaves the task without a version.

        Raises:
            BizError: 404 if a referenced template or version does not exist
                (including an explicitly requested ``version``); 400 if
                ``version_id`` belongs to a different template.
        """
        template = self.get_template_or_404(template_id) if template_id else None
        version_model = self.get_version_or_404(version_id) if version_id else None
        if template is not None and version_model is not None and version_model.template_id != template.id:
            raise BizError("模板版本不属于该模板", code=400)
        if template is None and version_model is not None:
            template = self.get_template_or_404(version_model.template_id)
        if template is not None and version_model is None:
            if version:
                version_model = self._find_version_model(template_id=template.id, version=version)
                if version_model is None:
                    # An explicitly requested version must exist, matching the
                    # 404 raised for an unknown version_id.
                    raise BizError("模板版本不存在", code=404)
            elif template.current_version:
                version_model = self._find_version_model(
                    template_id=template.id,
                    version=template.current_version,
                )
        task = ExchangeTaskModel(
            template_id=template.id if template is not None else None,
            template_version_id=version_model.id if version_model is not None else None,
            task_kind=_enum_value(task_kind),
            status="pending",
            requested_by=requested_by,
            storage_key=storage_key,
            input_payload=input_payload or {},
            meta=meta or {},
        )
        self.db.add(task)
        self.db.commit()
        self.db.refresh(task)
        return task

    # ------------------------------------------------------------------
    # Snapshots and plan resolution
    # ------------------------------------------------------------------

    def _find_version_model(
        self, *, template_id: str, version: str
    ) -> ExchangeTemplateVersionModel | None:
        """Return the version row for ``(template_id, version)``, or ``None``."""
        return self.db.scalar(
            select(ExchangeTemplateVersionModel)
            .where(ExchangeTemplateVersionModel.template_id == template_id)
            .where(ExchangeTemplateVersionModel.version == version)
        )

    def get_snapshot(self, *, template_id: str, version: str) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot for a template's named version, or ``None``."""
        version_model = self._find_version_model(template_id=template_id, version=version)
        if version_model is None:
            return None
        return self.snapshot_from_model(version_model)

    def get_snapshot_by_version_id(self, version_id: str) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot for a version primary key, or ``None``."""
        version_model = self.db.get(ExchangeTemplateVersionModel, version_id)
        if version_model is None:
            return None
        return self.snapshot_from_model(version_model)

    def get_current_snapshot(self, template_id: str) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot of the template's ``current_version``, or ``None``."""
        template = self.db.get(ExchangeTemplateModel, template_id)
        if template is None or not template.current_version:
            return None
        return self.get_snapshot(template_id=template_id, version=template.current_version)

    def resolve_plan(
        self,
        *,
        template_kind: ExchangeTemplateKind | str,
        template_id: str | None = None,
        version_id: str | None = None,
        version: str | None = None,
        bindings: list[ExchangeTemplateBinding] | None = None,
        fields: list[ExchangeField] | None = None,
        placeholders: list[ExchangePlaceholder] | None = None,
        title: str | None = None,
        description: str | None = None,
        sheet_name: str | None = None,
        meta: dict[str, Any] | None = None,
        source: Any | None = None,
    ) -> ExchangePlan:
        """Resolve an ``ExchangePlan`` from stored versions or the given arguments.

        Precedence, first match wins:

        1. ``source.resolve_plan(...)`` when ``source`` is given.
        2. The snapshot for ``version_id``.
        3. The snapshot for ``(template_id, version)``.
        4. The current snapshot of ``template_id``.
        5. ``ExchangePlan.from_mapping`` built from the explicit arguments.

        Lookups that find nothing fall through to the next step.
        """
        if source is not None:
            return source.resolve_plan(
                template_kind=template_kind,
                template_id=template_id,
                version_id=version_id,
                version=version,
                bindings=bindings,
                fields=fields,
                placeholders=placeholders,
                title=title,
                description=description,
                sheet_name=sheet_name,
                meta=meta,
            )
        if version_id:
            snapshot = self.get_snapshot_by_version_id(version_id)
            if snapshot is not None:
                return snapshot.to_plan()
        if template_id and version:
            snapshot = self.get_snapshot(template_id=template_id, version=version)
            if snapshot is not None:
                return snapshot.to_plan()
        if template_id:
            current = self.get_current_snapshot(template_id)
            if current is not None:
                return current.to_plan()
        return ExchangePlan.from_mapping(
            template_kind=template_kind,
            template_id=template_id,
            version_id=version_id,
            version=version,
            bindings=bindings,
            fields=fields,
            placeholders=placeholders,
            title=title,
            description=description,
            sheet_name=sheet_name,
            meta=meta,
        )

    # ------------------------------------------------------------------
    # Task lifecycle
    # ------------------------------------------------------------------

    def mark_task_running(self, task_id: str) -> ExchangeTaskModel:
        """Set the task status to ``"running"``, stamp ``started_at`` and commit.

        Raises:
            BizError: 404 if the task does not exist.
        """
        task = self.get_task_or_404(task_id)
        task.status = "running"
        task.started_at = datetime.now()
        self.db.commit()
        self.db.refresh(task)
        return task

    def mark_task_finished(
        self,
        task_id: str,
        *,
        status: str = "success",
        message: str | None = None,
        result_payload: dict[str, Any] | None = None,
        success_count: int | None = None,
        failed_count: int | None = None,
    ) -> ExchangeTaskModel:
        """Record a task's final status and counters, then commit.

        ``status`` and ``message`` are always overwritten, so passing no
        message clears any previous one. The payload and counters are written
        only when provided. ``failed_count`` is written to both the
        ``failed_count`` and ``error_count`` attributes.

        Raises:
            BizError: 404 if the task does not exist.
        """
        task = self.get_task_or_404(task_id)
        task.status = status
        task.message = message
        task.finished_at = datetime.now()
        if result_payload is not None:
            task.result_payload = result_payload
        if success_count is not None:
            task.success_count = success_count
        if failed_count is not None:
            task.failed_count = failed_count
            task.error_count = failed_count
        self.db.commit()
        self.db.refresh(task)
        return task

    def add_task_row(
        self,
        *,
        task_id: str,
        row_index: int,
        status: str,
        data: dict[str, Any] | None = None,
        message: str | None = None,
        result: dict[str, Any] | None = None,
    ) -> ExchangeTaskRowModel:
        """Insert and commit one per-row result for a task.

        ``task_id`` is not validated here.
        """
        row = ExchangeTaskRowModel(
            task_id=task_id,
            row_index=row_index,
            status=status,
            data=data or {},
            message=message,
            result=result or {},
        )
        self.db.add(row)
        self.db.commit()
        self.db.refresh(row)
        return row

    # ------------------------------------------------------------------
    # Lookups
    # ------------------------------------------------------------------

    def get_template_or_404(self, template_id: str) -> ExchangeTemplateModel:
        """Return the template row or raise ``BizError`` with code 404."""
        template = self.db.get(ExchangeTemplateModel, template_id)
        if template is None:
            raise BizError("模板不存在", code=404)
        return template

    def get_version_or_404(self, version_id: str) -> ExchangeTemplateVersionModel:
        """Return the template-version row or raise ``BizError`` with code 404."""
        version = self.db.get(ExchangeTemplateVersionModel, version_id)
        if version is None:
            raise BizError("模板版本不存在", code=404)
        return version

    def get_task_or_404(self, task_id: str) -> ExchangeTaskModel:
        """Return the task row or raise ``BizError`` with code 404."""
        task = self.db.get(ExchangeTaskModel, task_id)
        if task is None:
            raise BizError("导入导出任务不存在", code=404)
        return task

    def list_templates(self) -> list[ExchangeTemplateModel]:
        """Return all templates ordered by ``entity``, then ``code``."""
        return list(
            self.db.scalars(
                select(ExchangeTemplateModel).order_by(
                    ExchangeTemplateModel.entity,
                    ExchangeTemplateModel.code,
                )
            )
        )

    def list_versions(self, template_id: str) -> list[ExchangeTemplateVersionModel]:
        """Return the template's versions ordered by the ``version`` column.

        NOTE(review): if ``version`` is a string column, this ordering is
        lexical (e.g. "1.10" sorts before "1.9"); confirm whether ordering by
        ``published_at`` was intended.
        """
        return list(
            self.db.scalars(
                select(ExchangeTemplateVersionModel)
                .where(ExchangeTemplateVersionModel.template_id == template_id)
                .order_by(ExchangeTemplateVersionModel.version)
            )
        )

    def list_tasks(self, template_id: str | None = None) -> list[ExchangeTaskModel]:
        """Return tasks newest first, optionally filtered by template."""
        statement = select(ExchangeTaskModel).order_by(ExchangeTaskModel.created_at.desc())
        if template_id is not None:
            statement = statement.where(ExchangeTaskModel.template_id == template_id)
        return list(self.db.scalars(statement))

    def list_task_rows(self, task_id: str) -> list[ExchangeTaskRowModel]:
        """Return a task's row results ordered by ``row_index``."""
        return list(
            self.db.scalars(
                select(ExchangeTaskRowModel)
                .where(ExchangeTaskRowModel.task_id == task_id)
                .order_by(ExchangeTaskRowModel.row_index)
            )
        )

    def snapshot_from_model(
        self, version: ExchangeTemplateVersionModel
    ) -> ExchangeTemplateSnapshot:
        """Convert a stored version row into an immutable ``ExchangeTemplateSnapshot``.

        The JSON columns (bindings, fields, placeholders, meta) are treated as
        empty when they are NULL, so legacy or partially migrated rows do not
        raise ``TypeError``.
        """
        return ExchangeTemplateSnapshot(
            id=version.id,
            version=version.version,
            template_id=version.template_id,
            template_kind=ExchangeTemplateKind(version.template_kind),
            bindings=tuple(_binding_from_dict(item) for item in version.bindings or ()),
            published_at=version.published_at.isoformat() if version.published_at else None,
            file_key=version.file_key,
            checksum=version.checksum,
            fields=tuple(_field_from_dict(item) for item in version.fields or ()),
            placeholders=tuple(_placeholder_from_dict(item) for item in version.placeholders or ()),
            meta=version.meta if version.meta is not None else {},
        )
def _field_from_dict(value: dict[str, Any]) -> ExchangeField:
    """Rebuild an ``ExchangeField`` from the dict stored on a template version.

    ``key`` and ``label`` are required (``KeyError`` if absent). ``required``
    is coerced to ``bool``. ``options`` becomes a tuple of tuples, and a
    missing or empty ``options``/``meta`` falls back to ``()``/``{}``.
    """
    key = value["key"]
    label = value["label"]
    # Options are stored as JSON lists; convert each entry back to a tuple.
    option_tuples = tuple(tuple(option) for option in (value.get("options") or ()))
    lookup = value.get
    return ExchangeField(
        key=key,
        label=label,
        placeholder=lookup("placeholder"),
        required=bool(lookup("required", False)),
        example=lookup("example"),
        width=lookup("width"),
        format=lookup("format"),
        source=lookup("source"),
        target=lookup("target"),
        options=option_tuples,
        meta=lookup("meta") or {},
    )
def _placeholder_from_dict(value: dict[str, Any]) -> ExchangePlaceholder:
    """Rebuild an ``ExchangePlaceholder`` from its stored dict form.

    ``key`` and ``label`` are mandatory. ``required`` is coerced to ``bool``
    and defaults to ``False``.
    """
    key = value["key"]
    label = value["label"]
    lookup = value.get
    return ExchangePlaceholder(
        key=key,
        label=label,
        description=lookup("description"),
        required=bool(lookup("required", False)),
        example=lookup("example"),
    )
def _binding_from_dict(value: dict[str, Any]) -> ExchangeTemplateBinding:
    """Rebuild an ``ExchangeTemplateBinding`` from its stored dict form.

    ``entity`` and ``template_kind`` are mandatory. ``template_kind`` is
    converted back to ``ExchangeTemplateKind`` (``ValueError`` on an unknown
    value), and a missing or empty ``meta`` becomes ``{}``.
    """
    entity = value["entity"]
    kind = ExchangeTemplateKind(value["template_kind"])
    lookup = value.get
    return ExchangeTemplateBinding(
        entity=entity,
        template_kind=kind,
        handler=lookup("handler"),
        description=lookup("description"),
        default_sheet_name=lookup("default_sheet_name"),
        default_file_name=lookup("default_file_name"),
        title=lookup("title"),
        meta=lookup("meta") or {},
    )
def _enum_value(value: Any) -> str:
return value.value if hasattr(value, "value") else str(value)
def _safe_suffix(file_name: str) -> str:
if "." not in file_name:
return "xlsx"
suffix = file_name.rsplit(".", 1)[-1].lower()
return "".join(ch for ch in suffix if ch.isalnum()) or "xlsx"
def _excel_mime_type() -> str:
return "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
def _jsonable(value: Any) -> Any:
if isinstance(value, Enum):
return value.value
if isinstance(value, dict):
return {key: _jsonable(item) for key, item in value.items()}
if isinstance(value, list):
return [_jsonable(item) for item in value]
if isinstance(value, tuple):
return [_jsonable(item) for item in value]
return value