"""Service layer for exchange (Excel import/export) templates and tasks.

``ExchangeService`` persists templates, published template versions,
import/export tasks and per-row task results through a SQLAlchemy
``Session``. Uploaded template files go through the storage returned by
``get_exchange_storage``; workbook encoding/decoding is delegated to the
codecs in ``.excel``, which are imported lazily.
"""

from __future__ import annotations

import hashlib
from dataclasses import asdict
from datetime import datetime
from enum import Enum
from io import BytesIO
from typing import Any

from sqlalchemy import select
from sqlalchemy.orm import Session

from iti.exceptions import BizError

from .base import (
    ExchangeField,
    ExchangePlaceholder,
    ExchangePlan,
    ExchangeTaskKind,
    ExchangeTemplateBinding,
    ExchangeTemplateKind,
    ExchangeTemplateSnapshot,
)
from .models import (
    ExchangeTaskModel,
    ExchangeTaskRowModel,
    ExchangeTemplateModel,
    ExchangeTemplateVersionModel,
)
from .tasks import get_exchange_storage

# Sheet name used by ``export_rows`` when no explicit sheet name is given
# (and no plan supplies one).
_DEFAULT_EXPORT_SHEET = "Export"

# MIME types for spreadsheet suffixes a template upload may carry. Any other
# suffix falls back to the ``.xlsx`` type, which was the only type used before.
_XLSX_MIME_TYPE = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
_MIME_TYPES_BY_SUFFIX = {
    "xlsx": _XLSX_MIME_TYPE,
    "xlsm": "application/vnd.ms-excel.sheet.macroEnabled.12",
    "xls": "application/vnd.ms-excel",
    "csv": "text/csv",
}


class ExchangeService:
    """Persistence and file helpers for exchange templates, versions and tasks.

    Methods that write rows commit the session immediately. When a commit
    fails, the session is rolled back before the error is re-raised, so the
    same ``Session`` stays usable for the caller.
    """

    def __init__(self, app, db: Session) -> None:
        # ``app`` is only passed through to ``get_exchange_storage``.
        self.app = app
        self.db = db

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _commit(self, *instances: Any) -> None:
        """Commit the session, then refresh each of ``instances``.

        Raises:
            Whatever ``Session.commit`` raises, after rolling the session back.
        """
        try:
            self.db.commit()
        except Exception:
            # A failed commit leaves the session refusing further work until
            # ``rollback()`` is called; do it here so callers need not.
            self.db.rollback()
            raise
        for instance in instances:
            self.db.refresh(instance)

    def _download_file(self, file_key: str) -> bytes:
        """Return the full contents stored under ``file_key``."""
        storage = get_exchange_storage(self.app)
        with storage.download(file_key) as file_stream:
            return file_stream.read()

    def _find_version_model(
        self, template_id: str, version: str
    ) -> ExchangeTemplateVersionModel | None:
        """Return the version row for (``template_id``, ``version``), or ``None``."""
        return self.db.scalar(
            select(ExchangeTemplateVersionModel)
            .where(ExchangeTemplateVersionModel.template_id == template_id)
            .where(ExchangeTemplateVersionModel.version == version)
        )

    # ------------------------------------------------------------------ #
    # Templates and versions
    # ------------------------------------------------------------------ #

    def create_template(
        self,
        *,
        code: str,
        name: str,
        template_kind: ExchangeTemplateKind | str,
        entity: str,
        description: str | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTemplateModel:
        """Create, commit and return a new template row.

        Args:
            code: Template code; also used as a path segment of stored file keys.
            name: Display name.
            template_kind: Enum member or its raw string value.
            entity: Entity the template belongs to.
            description: Optional description.
            meta: Optional metadata; stored as ``{}`` when omitted.
        """
        template = ExchangeTemplateModel(
            code=code,
            name=name,
            template_kind=_enum_value(template_kind),
            entity=entity,
            description=description,
            meta=meta or {},
        )
        self.db.add(template)
        self._commit(template)
        return template

    def update_template(
        self,
        template_id: str,
        *,
        name: str | None = None,
        description: str | None = None,
        status: str | None = None,
        current_version: str | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTemplateModel:
        """Update the given fields of a template; ``None`` leaves a field unchanged.

        Raises:
            BizError: 404 when the template does not exist.
        """
        template = self.get_template_or_404(template_id)
        if name is not None:
            template.name = name
        if description is not None:
            template.description = description
        if status is not None:
            template.status = status
        if current_version is not None:
            template.current_version = current_version
        if meta is not None:
            template.meta = meta
        self._commit(template)
        return template

    def publish_version(
        self,
        *,
        template_id: str,
        version: str,
        bindings: list[ExchangeTemplateBinding] | None = None,
        fields: list[ExchangeField] | None = None,
        placeholders: list[ExchangePlaceholder] | None = None,
        file_content: bytes | None = None,
        file_name: str | None = None,
        meta: dict[str, Any] | None = None,
        make_current: bool = True,
    ) -> ExchangeTemplateVersionModel:
        """Persist a new version snapshot of a template.

        When ``file_content`` is given it is uploaded first, and its storage
        key and SHA-256 checksum are recorded on the version. Bindings, fields
        and placeholders are stored as JSON-compatible dicts. With
        ``make_current`` (the default), the template's ``current_version`` is
        set to ``version`` and its status to ``"published"``.

        Raises:
            BizError: 404 when the template does not exist.
        """
        template = self.get_template_or_404(template_id)
        file_key = None
        checksum = None
        if file_content is not None:
            # NOTE(review): the file is uploaded before the commit; if the
            # commit fails, the stored object is left orphaned.
            file_key = self.save_template_file(
                template=template,
                version=version,
                content=file_content,
                file_name=file_name,
            )
            checksum = hashlib.sha256(file_content).hexdigest()
        snapshot = ExchangeTemplateVersionModel(
            template_id=template.id,
            version=version,
            template_kind=template.template_kind,
            published_at=datetime.now(),
            file_key=file_key,
            checksum=checksum,
            bindings=_serialize_items(bindings),
            fields=_serialize_items(fields),
            placeholders=_serialize_items(placeholders),
            meta=meta or {},
        )
        self.db.add(snapshot)
        if make_current:
            template.current_version = version
            template.status = "published"
        self._commit(snapshot)
        return snapshot

    def build_template_file(self, version_id: str) -> bytes:
        """Return the template file bytes for a stored version.

        The uploaded file is returned when the version has one. Otherwise a
        workbook is generated from the version snapshot by the template codec.

        Raises:
            BizError: 404 when the version does not exist.
        """
        version = self.get_version_or_404(version_id)
        if version.file_key:
            return self._download_file(version.file_key)
        snapshot = self.snapshot_from_model(version)
        return _excel_template_codec().dump(snapshot)

    def build_plan_template_file(self, plan: ExchangePlan) -> bytes:
        """Return template file bytes for a resolved plan.

        If the plan refers to a stored version that has an uploaded file, that
        file is returned. Otherwise a workbook is generated from the plan.
        """
        if plan.version_id:
            version = self.get_snapshot_by_version_id(plan.version_id)
            if version is not None and version.file_key:
                return self._download_file(version.file_key)
        return _excel_template_codec().dump(plan)

    # ------------------------------------------------------------------ #
    # Row import / export
    # ------------------------------------------------------------------ #

    def export_rows(
        self,
        rows: list[dict[str, Any]],
        *,
        plan: ExchangePlan | None = None,
        fields: list[ExchangeField] | None = None,
        sheet_name: str | None = None,
    ) -> bytes:
        """Encode ``rows`` into a workbook and return its bytes.

        Precedence: ``plan`` if given, then ``fields`` if given. Otherwise
        columns are taken from the row keys, as the union over all rows in
        first-seen order, so a key present only in later rows still gets a
        column.
        """
        workbook_codec = _excel_workbook_codec()
        if plan is not None:
            # The plan-based export receives ``sheet_name`` as given
            # (possibly ``None``); no "Export" default is applied here.
            return workbook_codec.export_rows_with_plan(
                plan=plan,
                rows=rows,
                sheet_name=sheet_name,
            )
        target_sheet = sheet_name or _DEFAULT_EXPORT_SHEET
        if fields is not None:
            return workbook_codec.export_rows_with_template(
                fields=fields,
                rows=rows,
                sheet_name=target_sheet,
            )
        headers = list(dict.fromkeys(key for row in rows for key in row))
        return workbook_codec.export_rows(headers, rows, sheet_name=target_sheet)

    def import_rows(
        self,
        content: bytes,
        *,
        plan: ExchangePlan | None = None,
        fields: list[ExchangeField] | None = None,
    ) -> list[dict[str, Any]]:
        """Decode workbook ``content`` into a list of row dicts.

        Precedence: ``plan.fields`` (when the plan has any fields), then
        ``fields``, then the codec's plain ``import_rows``.
        """
        workbook_codec = _excel_workbook_codec()
        if plan is not None and plan.fields:
            return workbook_codec.import_rows_with_fields(content, fields=list(plan.fields))
        if fields is not None:
            return workbook_codec.import_rows_with_fields(content, fields=fields)
        return workbook_codec.import_rows(content)

    def save_template_file(
        self,
        *,
        template: ExchangeTemplateModel,
        version: str,
        content: bytes,
        file_name: str | None = None,
    ) -> str:
        """Upload a template file and return its storage key.

        The key has the form
        ``exchange/templates/<code>/<version>/<sha256>.<suffix>``. Code and
        version are sanitized so they cannot introduce extra path levels. The
        suffix comes from ``file_name`` (default ``xlsx``) and also selects
        the uploaded MIME type.
        """
        suffix = _safe_suffix(file_name or "template.xlsx")
        digest = hashlib.sha256(content).hexdigest()
        key = (
            f"exchange/templates/{_safe_key_segment(template.code)}/"
            f"{_safe_key_segment(version)}/{digest}.{suffix}"
        )
        storage = get_exchange_storage(self.app)
        storage.upload(BytesIO(content), key, _excel_mime_type(suffix))
        return key

    # ------------------------------------------------------------------ #
    # Tasks
    # ------------------------------------------------------------------ #

    def create_task(
        self,
        *,
        template_id: str | None = None,
        version_id: str | None = None,
        version: str | None = None,
        task_kind: ExchangeTaskKind | str,
        requested_by: str | None = None,
        storage_key: str | None = None,
        input_payload: dict[str, Any] | None = None,
        meta: dict[str, Any] | None = None,
    ) -> ExchangeTaskModel:
        """Create a pending import/export task, linking template and version.

        Version resolution:
          - ``version_id`` is looked up directly, and its template is inferred
            when ``template_id`` is omitted.
          - Otherwise, with a template, the explicit ``version`` string is used.
          - Failing that, the template's ``current_version`` is used, if set.

        Raises:
            BizError: 404 when a referenced template or version does not exist
                (including an explicit ``version`` string that matches no
                version of the template); 400 when ``version_id`` belongs to
                a different template than ``template_id``.
        """
        template = self.get_template_or_404(template_id) if template_id else None
        version_model = self.get_version_or_404(version_id) if version_id else None
        if (
            template is not None
            and version_model is not None
            and version_model.template_id != template.id
        ):
            raise BizError("模板版本不属于该模板", code=400)
        if template is None and version_model is not None:
            template = self.get_template_or_404(version_model.template_id)
        if template is not None and version_model is None:
            if version:
                version_model = self._find_version_model(template.id, version)
                if version_model is None:
                    # An explicitly requested version must exist; do not
                    # silently create a task bound to no version.
                    raise BizError("模板版本不存在", code=404)
            elif template.current_version:
                # Best-effort fallback: a dangling current_version leaves the
                # task without a version rather than failing.
                version_model = self._find_version_model(
                    template.id, template.current_version
                )
        task = ExchangeTaskModel(
            template_id=template.id if template is not None else None,
            template_version_id=version_model.id if version_model is not None else None,
            task_kind=_enum_value(task_kind),
            status="pending",
            requested_by=requested_by,
            storage_key=storage_key,
            input_payload=input_payload or {},
            meta=meta or {},
        )
        self.db.add(task)
        self._commit(task)
        return task

    # ------------------------------------------------------------------ #
    # Snapshots and plans
    # ------------------------------------------------------------------ #

    def get_snapshot(self, *, template_id: str, version: str) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot for (``template_id``, ``version``), or ``None``."""
        version_model = self._find_version_model(template_id, version)
        if version_model is None:
            return None
        return self.snapshot_from_model(version_model)

    def get_snapshot_by_version_id(self, version_id: str) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot for a version primary key, or ``None``."""
        version_model = self.db.get(ExchangeTemplateVersionModel, version_id)
        if version_model is None:
            return None
        return self.snapshot_from_model(version_model)

    def get_current_snapshot(self, template_id: str) -> ExchangeTemplateSnapshot | None:
        """Return the snapshot of a template's current version, or ``None``.

        ``None`` is returned when the template is missing, has no current
        version, or that version row does not exist.
        """
        template = self.db.get(ExchangeTemplateModel, template_id)
        if template is None or not template.current_version:
            return None
        return self.get_snapshot(template_id=template_id, version=template.current_version)

    def resolve_plan(
        self,
        *,
        template_kind: ExchangeTemplateKind | str,
        template_id: str | None = None,
        version_id: str | None = None,
        version: str | None = None,
        bindings: list[ExchangeTemplateBinding] | None = None,
        fields: list[ExchangeField] | None = None,
        placeholders: list[ExchangePlaceholder] | None = None,
        title: str | None = None,
        description: str | None = None,
        sheet_name: str | None = None,
        meta: dict[str, Any] | None = None,
        source: Any | None = None,
    ) -> ExchangePlan:
        """Resolve the exchange plan to use.

        Resolution order:
          1. Delegate everything to ``source.resolve_plan`` when ``source`` is given.
          2. The stored snapshot for ``version_id``.
          3. The stored snapshot for (``template_id``, ``version``).
          4. The template's current snapshot.
          5. An ad-hoc plan built from the arguments via ``ExchangePlan.from_mapping``.
        """
        if source is not None:
            return source.resolve_plan(
                template_kind=template_kind,
                template_id=template_id,
                version_id=version_id,
                version=version,
                bindings=bindings,
                fields=fields,
                placeholders=placeholders,
                title=title,
                description=description,
                sheet_name=sheet_name,
                meta=meta,
            )
        if version_id:
            snapshot = self.get_snapshot_by_version_id(version_id)
            if snapshot is not None:
                return snapshot.to_plan()
        if template_id and version:
            snapshot = self.get_snapshot(template_id=template_id, version=version)
            if snapshot is not None:
                return snapshot.to_plan()
        if template_id:
            current = self.get_current_snapshot(template_id)
            if current is not None:
                return current.to_plan()
        return ExchangePlan.from_mapping(
            template_kind=template_kind,
            template_id=template_id,
            version_id=version_id,
            version=version,
            bindings=bindings,
            fields=fields,
            placeholders=placeholders,
            title=title,
            description=description,
            sheet_name=sheet_name,
            meta=meta,
        )

    # ------------------------------------------------------------------ #
    # Task lifecycle
    # ------------------------------------------------------------------ #

    def mark_task_running(self, task_id: str) -> ExchangeTaskModel:
        """Set a task's status to ``"running"`` and stamp ``started_at``.

        Raises:
            BizError: 404 when the task does not exist.
        """
        task = self.get_task_or_404(task_id)
        task.status = "running"
        task.started_at = datetime.now()
        self._commit(task)
        return task

    def mark_task_finished(
        self,
        task_id: str,
        *,
        status: str = "success",
        message: str | None = None,
        result_payload: dict[str, Any] | None = None,
        success_count: int | None = None,
        failed_count: int | None = None,
    ) -> ExchangeTaskModel:
        """Record a task's final status, message and finish time.

        ``status`` and ``message`` are always written. Payload and counters
        are written only when given. ``failed_count`` is stored in both
        ``failed_count`` and ``error_count``.

        Raises:
            BizError: 404 when the task does not exist.
        """
        task = self.get_task_or_404(task_id)
        task.status = status
        task.message = message
        task.finished_at = datetime.now()
        if result_payload is not None:
            task.result_payload = result_payload
        if success_count is not None:
            task.success_count = success_count
        if failed_count is not None:
            task.failed_count = failed_count
            task.error_count = failed_count
        self._commit(task)
        return task

    def add_task_row(
        self,
        *,
        task_id: str,
        row_index: int,
        status: str,
        data: dict[str, Any] | None = None,
        message: str | None = None,
        result: dict[str, Any] | None = None,
    ) -> ExchangeTaskRowModel:
        """Persist one per-row result of a task and return it.

        The task itself is not looked up, so its existence is not checked here.
        """
        row = ExchangeTaskRowModel(
            task_id=task_id,
            row_index=row_index,
            status=status,
            data=data or {},
            message=message,
            result=result or {},
        )
        self.db.add(row)
        self._commit(row)
        return row

    # ------------------------------------------------------------------ #
    # Lookups
    # ------------------------------------------------------------------ #

    def get_template_or_404(self, template_id: str) -> ExchangeTemplateModel:
        """Return the template or raise ``BizError`` (code 404)."""
        template = self.db.get(ExchangeTemplateModel, template_id)
        if template is None:
            raise BizError("模板不存在", code=404)
        return template

    def get_version_or_404(self, version_id: str) -> ExchangeTemplateVersionModel:
        """Return the template version or raise ``BizError`` (code 404)."""
        version = self.db.get(ExchangeTemplateVersionModel, version_id)
        if version is None:
            raise BizError("模板版本不存在", code=404)
        return version

    def get_task_or_404(self, task_id: str) -> ExchangeTaskModel:
        """Return the import/export task or raise ``BizError`` (code 404)."""
        task = self.db.get(ExchangeTaskModel, task_id)
        if task is None:
            raise BizError("导入导出任务不存在", code=404)
        return task

    def list_templates(self) -> list[ExchangeTemplateModel]:
        """Return all templates ordered by entity, then code."""
        return list(
            self.db.scalars(
                select(ExchangeTemplateModel).order_by(
                    ExchangeTemplateModel.entity,
                    ExchangeTemplateModel.code,
                )
            )
        )

    def list_versions(self, template_id: str) -> list[ExchangeTemplateVersionModel]:
        """Return a template's versions ordered by their ``version`` value."""
        # NOTE(review): if ``version`` is a string column, this ordering is
        # lexicographic ("1.10" sorts before "1.9"); confirm that is intended.
        return list(
            self.db.scalars(
                select(ExchangeTemplateVersionModel)
                .where(ExchangeTemplateVersionModel.template_id == template_id)
                .order_by(ExchangeTemplateVersionModel.version)
            )
        )

    def list_tasks(self, template_id: str | None = None) -> list[ExchangeTaskModel]:
        """Return tasks newest first, optionally only those of one template."""
        statement = select(ExchangeTaskModel).order_by(ExchangeTaskModel.created_at.desc())
        if template_id is not None:
            statement = statement.where(ExchangeTaskModel.template_id == template_id)
        return list(self.db.scalars(statement))

    def list_task_rows(self, task_id: str) -> list[ExchangeTaskRowModel]:
        """Return a task's row results ordered by ``row_index``."""
        return list(
            self.db.scalars(
                select(ExchangeTaskRowModel)
                .where(ExchangeTaskRowModel.task_id == task_id)
                .order_by(ExchangeTaskRowModel.row_index)
            )
        )

    def snapshot_from_model(
        self, version: ExchangeTemplateVersionModel
    ) -> ExchangeTemplateSnapshot:
        """Convert a stored version row into an immutable snapshot.

        The JSON lists stored by ``publish_version`` are rebuilt into
        ``ExchangeTemplateBinding``/``ExchangeField``/``ExchangePlaceholder``
        tuples. A NULL list is treated as empty.
        """
        return ExchangeTemplateSnapshot(
            id=version.id,
            version=version.version,
            template_id=version.template_id,
            template_kind=ExchangeTemplateKind(version.template_kind),
            bindings=tuple(_binding_from_dict(item) for item in version.bindings or ()),
            published_at=version.published_at.isoformat() if version.published_at else None,
            file_key=version.file_key,
            checksum=version.checksum,
            fields=tuple(_field_from_dict(item) for item in version.fields or ()),
            placeholders=tuple(
                _placeholder_from_dict(item) for item in version.placeholders or ()
            ),
            meta=version.meta,
        )


def _serialize_items(items: list[Any] | None) -> list[Any]:
    """Convert a list of dataclass instances into JSON-compatible dicts."""
    return [_jsonable(asdict(item)) for item in items or []]


def _field_from_dict(value: dict[str, Any]) -> ExchangeField:
    """Rebuild an ``ExchangeField`` from its stored dict form.

    ``options`` is stored as a list of lists (tuples become lists in JSON) and
    is converted back to a tuple of tuples.
    """
    options = value.get("options") or ()
    return ExchangeField(
        key=value["key"],
        label=value["label"],
        placeholder=value.get("placeholder"),
        required=bool(value.get("required", False)),
        example=value.get("example"),
        width=value.get("width"),
        format=value.get("format"),
        source=value.get("source"),
        target=value.get("target"),
        options=tuple(tuple(item) for item in options),
        meta=value.get("meta") or {},
    )


def _placeholder_from_dict(value: dict[str, Any]) -> ExchangePlaceholder:
    """Rebuild an ``ExchangePlaceholder`` from its stored dict form."""
    return ExchangePlaceholder(
        key=value["key"],
        label=value["label"],
        description=value.get("description"),
        required=bool(value.get("required", False)),
        example=value.get("example"),
    )


def _binding_from_dict(value: dict[str, Any]) -> ExchangeTemplateBinding:
    """Rebuild an ``ExchangeTemplateBinding`` from its stored dict form."""
    return ExchangeTemplateBinding(
        entity=value["entity"],
        template_kind=ExchangeTemplateKind(value["template_kind"]),
        handler=value.get("handler"),
        description=value.get("description"),
        default_sheet_name=value.get("default_sheet_name"),
        default_file_name=value.get("default_file_name"),
        title=value.get("title"),
        meta=value.get("meta") or {},
    )


def _enum_value(value: Any) -> str:
    """Return ``value.value`` for enum-like objects, otherwise ``str(value)``."""
    return value.value if hasattr(value, "value") else str(value)


def _safe_suffix(file_name: str) -> str:
    """Return the lower-cased alphanumeric extension of ``file_name``.

    Falls back to ``"xlsx"`` when there is no extension or nothing
    alphanumeric remains after filtering.
    """
    if "." not in file_name:
        return "xlsx"
    suffix = file_name.rsplit(".", 1)[-1].lower()
    return "".join(ch for ch in suffix if ch.isalnum()) or "xlsx"


def _safe_key_segment(value: str) -> str:
    """Make ``value`` safe to embed as a single segment of a storage key.

    Characters other than letters, digits, ``-``, ``_`` and ``.`` (notably
    ``/`` and ``\\``) are replaced with ``_``. A segment that would be empty
    or consist only of dots (``.``, ``..``) becomes ``_``. Typical codes and
    versions such as ``"orders"`` or ``"1.0.0"`` pass through unchanged.
    """
    cleaned = "".join(ch if ch.isalnum() or ch in "-_." else "_" for ch in str(value))
    if not cleaned.strip("."):
        return "_"
    return cleaned


def _excel_mime_type(suffix: str = "xlsx") -> str:
    """Return the MIME type for a spreadsheet ``suffix`` (default: ``.xlsx``).

    Unknown suffixes fall back to the ``.xlsx`` MIME type.
    """
    return _MIME_TYPES_BY_SUFFIX.get(suffix.lower(), _XLSX_MIME_TYPE)


def _excel_template_codec():
    """Return a new ``ExcelTemplateCodec``; imported lazily from ``.excel``."""
    from .excel import ExcelTemplateCodec

    return ExcelTemplateCodec()


def _excel_workbook_codec():
    """Return a new ``ExcelWorkbookCodec``; imported lazily from ``.excel``."""
    from .excel import ExcelWorkbookCodec

    return ExcelWorkbookCodec()


def _jsonable(value: Any) -> Any:
    """Recursively convert enums to their values and tuples to lists.

    Dicts and lists are traversed; any other value is returned unchanged.
    """
    if isinstance(value, Enum):
        return value.value
    if isinstance(value, dict):
        return {key: _jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_jsonable(item) for item in value]
    return value