You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

1138 lines
39 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from __future__ import annotations
import hashlib
import mimetypes
import os
import shutil
import tempfile
from datetime import datetime, timedelta
from io import BytesIO
from pathlib import Path
from typing import Dict, Optional, Union
from flask import current_app
from sqlalchemy import select, exists
from iti.applications.common.enums import StatusEnum
from iti.applications.common.enums.sys import StorageTypeEnum
from iti.applications.common.exceptions.biz_exp import BizException
from iti.applications.common.storage import StorageManager
from iti.applications.extensions import cache_simple, db
from iti.applications.models import SysFile, SysFileDirectory
class SysFileService:
# Prefix that _chunk_upload_cache_key() puts in front of an upload id to form its cache key.
CHUNK_UPLOAD_CACHE_PREFIX = "chunk_upload:"
# Name of the sub-directory (below the temp base directory) that holds chunk files.
CHUNK_TEMP_DIR = "chunk_uploads"
# ------------------------------------------------------------------
# 工具方法
# ------------------------------------------------------------------
@staticmethod
def _guess_mime_type(
filename: str, provided_mime: Optional[str] = None
) -> Optional[str]:
"""
推断文件 MIME 类型
Args:
filename: 文件名
provided_mime: 提供的 MIME 类型(优先使用)
Returns:
MIME 类型
"""
# 优先使用提供的 MIME 类型
if provided_mime:
return provided_mime
# 根据文件扩展名推断
mime_type, _ = mimetypes.guess_type(filename)
return mime_type
@staticmethod
def _get_backend_url() -> str:
    """
    Return the backend base URL used to build file routes.

    Returns:
        The ``BACKEND_URL`` SYSTEM setting (default
        ``http://localhost:5000``) with trailing slashes removed.
    """
    # Imported lazily, as in the rest of this class.
    from iti.applications.service.sys.sys_config import get_str

    configured = get_str("BACKEND_URL", type="SYSTEM", default="http://localhost:5000")
    return configured.rstrip("/")
@classmethod
def _get_chunk_temp_dir(cls) -> Path:
    """
    Return the root directory that holds all chunk-upload temp data.

    The base is the app's ``UPLOAD_TEMP_DIR`` config value, falling back to
    the system temp directory when that is unset or empty. The directory is
    created if it does not exist.

    Returns:
        Path of the chunk temp directory.
    """
    base_dir = current_app.config.get("UPLOAD_TEMP_DIR") or tempfile.gettempdir()
    target = Path(base_dir) / cls.CHUNK_TEMP_DIR
    target.mkdir(parents=True, exist_ok=True)
    return target
@classmethod
def _get_upload_temp_dir(cls, upload_id: str) -> Path:
    """
    Return (creating it if needed) the temp directory of one upload task.

    Args:
        upload_id: Upload task id; used as the directory name.

    Returns:
        Path of the task's temp directory.
    """
    task_dir = cls._get_chunk_temp_dir().joinpath(upload_id)
    task_dir.mkdir(parents=True, exist_ok=True)
    return task_dir
@classmethod
def _get_chunk_file_path(cls, upload_id: str, chunk_index: int) -> Path:
    """
    Return the path where a given chunk of an upload task is stored.

    The task directory is created as a side effect (via
    ``_get_upload_temp_dir``); the chunk file itself is not.

    Args:
        upload_id: Upload task id.
        chunk_index: Index of the chunk.

    Returns:
        Path of the chunk file, named ``chunk_<index>``.
    """
    return cls._get_upload_temp_dir(upload_id) / f"chunk_{chunk_index}"
@classmethod
def _cleanup_upload_temp_dir(cls, upload_id: str) -> None:
    """
    Remove the temp directory of an upload task (best effort).

    Failures are logged as warnings and never propagated.

    Args:
        upload_id: Upload task id.
    """
    task_dir = cls._get_chunk_temp_dir() / upload_id
    if not task_dir.exists():
        return
    try:
        shutil.rmtree(task_dir)
    except Exception as err:
        # Deliberately swallowed: cleanup must not break the caller.
        current_app.logger.warning(f"清理临时目录失败: {task_dir}, 错误: {err}")
@classmethod
def cleanup_expired_chunk_uploads(cls, days: int = 7) -> Dict[str, int]:
    """
    Remove chunk-upload temp directories that have not been modified recently.

    Intended to be called by a periodic job. A directory is considered
    expired when its modification time is older than ``days`` days.
    A failure on one directory is logged and does not stop the sweep.

    Args:
        days: Retention period in days (default 7).

    Returns:
        ``{"cleaned_dirs": int, "cleaned_size": int}`` - number of removed
        directories and number of bytes they contained.
    """
    cleaned_dirs = 0
    cleaned_size = 0
    chunk_temp_dir = cls._get_chunk_temp_dir()
    if not chunk_temp_dir.exists():
        return {"cleaned_dirs": cleaned_dirs, "cleaned_size": cleaned_size}

    threshold = datetime.now() - timedelta(days=days)

    try:
        upload_dirs = [entry for entry in chunk_temp_dir.iterdir() if entry.is_dir()]
    except Exception as e:
        current_app.logger.error(f"清理过期分片上传失败: {e}")
        return {"cleaned_dirs": cleaned_dirs, "cleaned_size": cleaned_size}

    for upload_dir in upload_dirs:
        # Everything that touches the file system for this directory sits in
        # one try block: a file disappearing mid-scan (concurrent upload or
        # merge) must only skip this directory, not abort the whole sweep.
        try:
            dir_mtime = datetime.fromtimestamp(upload_dir.stat().st_mtime)
            if dir_mtime >= threshold:
                continue
            dir_size = sum(
                f.stat().st_size for f in upload_dir.rglob("*") if f.is_file()
            )
            shutil.rmtree(upload_dir)
        except Exception as e:
            current_app.logger.warning(f"清理目录失败: {upload_dir}, 错误: {e}")
            continue

        cleaned_dirs += 1
        cleaned_size += dir_size
        current_app.logger.info(
            f"清理过期分片上传目录: {upload_dir.name}, 大小: {dir_size} 字节"
        )

    return {"cleaned_dirs": cleaned_dirs, "cleaned_size": cleaned_size}
# ------------------------------------------------------------------
# 普通上传
# ------------------------------------------------------------------
@classmethod
def upload_file(
    cls,
    file,
    directory_id: Optional[str] = None,
    metadata: Optional[Dict] = None,
    storage_type: Optional[str] = None,
) -> Dict:
    """
    Upload a file in a single request, with hash-based instant upload.

    Args:
        file: Uploaded file object. Must provide ``seek``, ``read`` and
            ``filename``; ``mimetype`` is used when present
            (presumably a werkzeug FileStorage - confirm against callers).
        directory_id: Target directory id; the default directory is used
            when omitted.
        metadata: Extra metadata stored with the record.
        storage_type: Storage backend name; resolved through
            ``_resolve_storage_type`` when omitted.

    Returns:
        ``{"file": SysFile, "instantUpload": bool}``
    """
    metadata = metadata or {}

    # Fall back to the default directory when none was given.
    if not directory_id:
        from .sys_file_directory import SysFileDirectoryService

        directory_id = SysFileDirectoryService.get_default_directory_id()

    # The whole payload is read into memory to compute its MD5.
    file.seek(0)
    file_bytes = file.read()
    file_hash = hashlib.md5(file_bytes).hexdigest()
    file.seek(0)

    storage_type_enum = cls._resolve_storage_type(storage_type)

    # Look for an enabled record with the same content IN THE SAME storage.
    # Filtering by storage type inside the query matters: matching on hash
    # alone can return a row from another backend first, miss the real
    # match and then produce a duplicate file_key below.
    existing = db.session.scalar(
        select(SysFile)
        .where(
            SysFile.file_hash == file_hash,
            SysFile.storage_type == storage_type_enum,
            SysFile.status == StatusEnum.ENABLED,
        )
        .limit(1)
    )
    if existing is not None:
        # Instant upload: re-point the existing record instead of storing again.
        existing.filename = file.filename
        existing.directory_id = directory_id
        existing.metadata_ = metadata or None
        try:
            db.session.commit()
        except Exception:
            db.session.rollback()
            raise
        return {"file": existing, "instantUpload": True}

    storage = StorageManager.get_storage(storage_type_enum)
    ext = os.path.splitext(file.filename or "")[1]

    # A missing filename must not reach mimetypes.guess_type().
    mime_type = cls._guess_mime_type(
        file.filename or "", getattr(file, "mimetype", None)
    )

    # The key is prefixed with the storage type (colon separated) so the same
    # content in different backends does not collide on the unique index.
    file_key = f"{storage_type_enum.value}:{datetime.now():%Y%m%d}/{file_hash}{ext}"
    upload_result = storage.upload(BytesIO(file_bytes), file_key, mime_type)

    new_file = SysFile(
        filename=file.filename,
        file_key=file_key,
        file_hash=file_hash,
        mime_type=mime_type,
        file_size=len(file_bytes),
        extension=ext,
        storage_type=storage_type_enum,
        # Adapter result is only kept for non-local backends.
        storage_info=upload_result
        if storage_type_enum != StorageTypeEnum.LOCAL
        else None,
        directory_id=directory_id,
        metadata_=metadata or None,
        status=StatusEnum.ENABLED,
    )
    db.session.add(new_file)
    try:
        db.session.commit()
    except Exception:
        # Leave the session usable for the caller's error handling.
        db.session.rollback()
        raise
    return {"file": new_file, "instantUpload": False}
# ------------------------------------------------------------------
# Chunked upload (custom protocol, modeled on hotgo)
# ------------------------------------------------------------------
@classmethod
def init_chunk_upload(
    cls,
    filename: str,
    file_size: int,
    file_hash: Optional[str] = None,
    chunk_size: int = 2 * 1024 * 1024,
    total_chunks: Optional[int] = None,
    directory_id: Optional[str] = None,
    metadata: Optional[Dict] = None,
    storage_type: Optional[str] = None,
) -> Dict:
    """
    Initialise a chunked upload task.

    Args:
        filename: File name.
        file_size: Total file size.
        file_hash: MD5 of the file; enables instant upload and resuming.
        chunk_size: Size of one chunk.
        total_chunks: Number of chunks; derived from ``file_size`` and
            ``chunk_size`` when omitted.
        directory_id: Target directory id (default directory when omitted).
        metadata: Extra metadata.
        storage_type: Storage backend name.

    Returns:
        On instant upload: ``{"instantUpload": True, "file": SysFile}``.
        Otherwise: ``{"instantUpload": False, "uploadId": str,
        "uploadedChunks": list}``.

    Raises:
        BizException: 400 when the chunk count has to be computed and
            ``chunk_size`` is not positive.
    """
    import uuid

    metadata = metadata or {}

    # Fall back to the default directory when none was given.
    if not directory_id:
        from .sys_file_directory import SysFileDirectoryService

        directory_id = SysFileDirectoryService.get_default_directory_id()

    resolved_storage_type = cls._resolve_storage_type(storage_type)

    # Instant-upload detection. The storage type is part of the query so a
    # same-hash record in another backend cannot hide a real match.
    if file_hash:
        existing = db.session.scalar(
            select(SysFile)
            .filter_by(
                file_hash=file_hash,
                storage_type=resolved_storage_type,
                status=StatusEnum.ENABLED,
            )
            .limit(1)
        )
        if existing is not None:
            existing.filename = filename
            existing.directory_id = directory_id
            existing.metadata_ = metadata or None
            db.session.commit()
            return {"instantUpload": True, "file": existing}

    if total_chunks is None:
        # Guard the ceiling division below against ZeroDivisionError.
        if chunk_size <= 0:
            raise BizException(f"分片大小无效: {chunk_size}", code=400)
        total_chunks = (file_size + chunk_size - 1) // chunk_size

    upload_id = str(uuid.uuid4())
    ext = os.path.splitext(filename or "")[1]
    chunk_temp_dir = cls._get_chunk_temp_dir()

    # Resume support: look for an unfinished task with the same file hash
    # by checking the cached task data of every existing task directory.
    if file_hash and chunk_temp_dir.exists():
        for existing_upload_dir in chunk_temp_dir.iterdir():
            if not existing_upload_dir.is_dir():
                continue
            existing_upload_id = existing_upload_dir.name
            cached_data = cache_simple.get(
                cls._chunk_upload_cache_key(existing_upload_id)
            )
            if cached_data and cached_data.get("file_hash") == file_hash:
                # Reuse the matching task's id.
                upload_id = existing_upload_id
                current_app.logger.info(
                    f"断点续传 - 复用上传任务: {upload_id}, file_hash: {file_hash}"
                )
                break

    # Collect the chunks that are already on disk for this task.
    existing_chunks = []
    upload_dir = chunk_temp_dir / upload_id
    if upload_dir.exists():
        for chunk_file in upload_dir.glob("chunk_*"):
            try:
                # The index is the part after the first underscore.
                chunk_index = int(chunk_file.name.split("_")[1])
            except (ValueError, IndexError):
                # e.g. a leftover "chunk_3.tmp"
                current_app.logger.warning(f"无效的分片文件名: {chunk_file.name}")
                continue
            if 0 <= chunk_index < total_chunks:
                existing_chunks.append(chunk_index)
        existing_chunks.sort()
        if existing_chunks:
            current_app.logger.info(
                f"断点续传 - upload_id: {upload_id}, "
                f"已上传分片: {len(existing_chunks)}/{total_chunks}, "
                f"分片列表: {existing_chunks}"
            )

    upload_data = {
        "upload_id": upload_id,
        "filename": filename,
        "file_size": file_size,
        "file_hash": file_hash,
        "chunk_size": chunk_size,
        "total_chunks": total_chunks,
        "uploaded_chunks": existing_chunks,  # chunk indexes found on disk at init time
        "storage_type": resolved_storage_type.value,  # stored as its string value
        "directory_id": directory_id,
        "metadata": metadata,
        "extension": ext,
        "created_at": datetime.now().isoformat(),
    }
    # Task data is kept for 7 days.
    cache_simple.set(
        cls._chunk_upload_cache_key(upload_id), upload_data, timeout=7 * 24 * 3600
    )

    return {
        "instantUpload": False,
        "uploadId": upload_id,
        "uploadedChunks": existing_chunks,
    }
@classmethod
def upload_chunk(
    cls,
    upload_id: str,
    chunk_index: int,
    chunk_data: bytes,
) -> Dict:
    """
    Store a single chunk of an upload task on disk.

    Args:
        upload_id: Upload task id.
        chunk_index: Zero-based chunk index.
        chunk_data: Raw chunk bytes.

    Returns:
        ``{"chunkIndex": int, "uploaded": bool}``

    Raises:
        BizException: 404 when the task is unknown or expired, 400 when the
            index is out of range, 500 when the chunk cannot be written.
    """
    import uuid

    upload_data = cache_simple.get(cls._chunk_upload_cache_key(upload_id))
    if not upload_data:
        raise BizException("上传任务不存在或已过期", code=404)

    total_chunks = upload_data.get("total_chunks", 0)
    if chunk_index < 0 or chunk_index >= total_chunks:
        raise BizException(f"分片索引无效: {chunk_index}", code=400)

    # Presence is decided by the file system, not the cache, to avoid
    # lost updates when chunks arrive concurrently.
    chunk_file_path = cls._get_chunk_file_path(upload_id, chunk_index)
    if chunk_file_path.exists():
        current_app.logger.debug(f"分片 {chunk_index} 文件已存在,跳过")
        return {"chunkIndex": chunk_index, "uploaded": True}

    # Write to a temp file unique to this request, then rename atomically.
    # A shared temp name would let two concurrent uploads of the same chunk
    # interleave their writes. The leading dot keeps the temp file out of
    # the "chunk_*" glob used when resuming.
    temp_file_path = chunk_file_path.with_name(
        f".{chunk_file_path.name}.{uuid.uuid4().hex}.tmp"
    )
    try:
        with open(temp_file_path, "wb") as f:
            f.write(chunk_data)
        temp_file_path.replace(chunk_file_path)
    except Exception as e:
        # Best-effort removal of the partial temp file.
        try:
            temp_file_path.unlink()
        except OSError:
            pass
        raise BizException(f"保存分片失败: {str(e)}", code=500) from e

    current_app.logger.debug(f"分片上传成功 - chunk_index: {chunk_index}")
    return {"chunkIndex": chunk_index, "uploaded": True}
@classmethod
def merge_chunks(
    cls,
    upload_id: str,
    file_hash: Optional[str] = None,
) -> SysFile:
    """
    Merge all chunks of an upload task into the final stored file.

    The chunks are streamed into a temporary file (never held in memory as
    a whole), hashed with MD5, uploaded to the task's storage backend and
    recorded as a ``SysFile``. Chunk files and the cached task data are
    removed only after the record has been committed; on any failure they
    are kept so the upload can be resumed.

    Args:
        upload_id: Upload task id.
        file_hash: Expected MD5 of the merged file; verified when given.

    Returns:
        The committed ``SysFile`` record.

    Raises:
        BizException: 404 for an unknown/expired task, 400 for missing
            chunks or a hash mismatch, 500 when a chunk vanishes during
            the merge.
    """
    cache_key = cls._chunk_upload_cache_key(upload_id)
    upload_data = cache_simple.get(cache_key)
    if not upload_data:
        raise BizException("上传任务不存在或已过期", code=404)

    total_chunks = upload_data.get("total_chunks", 0)

    # Chunk presence is read from the file system rather than the cache.
    chunk_paths = [
        cls._get_chunk_file_path(upload_id, index) for index in range(total_chunks)
    ]
    existing_chunks = [
        index for index, path in enumerate(chunk_paths) if path.exists()
    ]

    current_app.logger.info(
        f"合并分片 - upload_id: {upload_id}, "
        f"total_chunks: {total_chunks}, "
        f"existing_chunks_count: {len(existing_chunks)}, "
        f"existing_chunks: {existing_chunks}"
    )

    if len(existing_chunks) != total_chunks:
        missing_chunks = sorted(set(range(total_chunks)) - set(existing_chunks))
        current_app.logger.error(
            f"分片未上传完整 - upload_id: {upload_id}, 缺失分片: {missing_chunks}"
        )
        raise BizException(
            f"分片未上传完整: {len(existing_chunks)}/{total_chunks}, 缺失分片: {missing_chunks}",
            code=400,
        )

    BUFFER_SIZE = 8 * 1024 * 1024  # 8 MB read buffer
    md5_hash = hashlib.md5()
    total_size = 0
    temp_merged_file_path = None
    # Set only after a successful commit. The original keyed cleanup on the
    # record object, which is created BEFORE commit, so a failed commit still
    # wiped the chunks.
    committed = False

    try:
        # Stream every chunk into one temp file.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".tmp") as merged_file:
            temp_merged_file_path = merged_file.name
            current_app.logger.info(f"创建临时合并文件: {temp_merged_file_path}")

            for chunk_index, chunk_file_path in enumerate(chunk_paths):
                if not chunk_file_path.exists():
                    raise BizException(f"分片文件丢失: {chunk_index}", code=500)

                with open(chunk_file_path, "rb") as chunk_file:
                    while True:
                        buffer = chunk_file.read(BUFFER_SIZE)
                        if not buffer:
                            break
                        merged_file.write(buffer)
                        md5_hash.update(buffer)
                        total_size += len(buffer)

                # Log progress every 50 chunks and on the last one.
                if (chunk_index + 1) % 50 == 0 or chunk_index == total_chunks - 1:
                    current_app.logger.info(
                        f"合并进度: {chunk_index + 1}/{total_chunks} "
                        f"({(chunk_index + 1) / total_chunks * 100:.1f}%)"
                    )

            # Force the data to disk before it is re-opened for upload.
            merged_file.flush()
            os.fsync(merged_file.fileno())

        actual_hash = md5_hash.hexdigest()
        current_app.logger.info(
            f"分片合并完成 - 文件大小: {total_size} 字节, "
            f"MD5: {actual_hash}"
        )

        if file_hash and actual_hash != file_hash:
            raise BizException("文件哈希校验失败", code=400)

        # The cache holds the storage type as a string; convert back to the enum.
        storage_type_enum = cls._resolve_storage_type(
            upload_data.get("storage_type")
        )
        storage = StorageManager.get_storage(storage_type_enum)

        filename = upload_data.get("filename")
        ext = upload_data.get("extension", "")
        mime_type = cls._guess_mime_type(filename)
        file_key = (
            f"{storage_type_enum.value}:{datetime.now():%Y%m%d}/{actual_hash}{ext}"
        )

        # Stream the merged file to the storage backend.
        with open(temp_merged_file_path, "rb") as merged_stream:
            current_app.logger.info(f"开始上传到存储: {file_key}")
            upload_result = storage.upload(merged_stream, file_key, mime_type)
            current_app.logger.info(f"上传完成: {file_key}")

        metadata = upload_data.get("metadata", {})
        file_record = SysFile(
            filename=filename,
            file_key=file_key,
            file_hash=actual_hash,
            file_size=total_size,
            extension=ext,
            mime_type=mime_type,
            storage_type=storage_type_enum,
            storage_info=upload_result
            if storage_type_enum != StorageTypeEnum.LOCAL
            else None,
            directory_id=upload_data.get("directory_id"),
            metadata_=metadata or None,
            status=StatusEnum.ENABLED,
        )
        db.session.add(file_record)
        db.session.commit()
        committed = True
        return file_record

    except Exception as e:
        # Roll back so the session stays usable; never mask the original error.
        try:
            db.session.rollback()
        except Exception as rollback_error:
            current_app.logger.warning(f"回滚数据库事务失败: {rollback_error}")
        current_app.logger.error(
            f"合并分片失败 - upload_id: {upload_id}, 错误: {str(e)}"
        )
        raise

    finally:
        # The merged temp file is removed on success and on failure.
        if temp_merged_file_path and os.path.exists(temp_merged_file_path):
            try:
                os.remove(temp_merged_file_path)
                current_app.logger.info(f"清理临时合并文件: {temp_merged_file_path}")
            except Exception as e:
                current_app.logger.warning(f"清理临时合并文件失败: {e}")

        # Chunks and cached task data go away only after a successful commit;
        # on failure they are kept so the client can resume.
        if committed:
            try:
                cache_simple.delete(cache_key)
                cls._cleanup_upload_temp_dir(upload_id)
            except Exception as e:
                current_app.logger.warning(f"清理上传临时数据失败: {e}")
@classmethod
def get_chunk_upload_progress(cls, upload_id: str) -> Dict:
    """
    Report the progress of a chunked upload task.

    The uploaded chunks are read from the task's temp directory. The
    ``uploaded_chunks`` list in the cache is only written when the task is
    initialised and is never updated by ``upload_chunk``, so it cannot be
    used to measure progress.

    Args:
        upload_id: Upload task id.

    Returns:
        ``{"uploadId": str, "totalChunks": int, "uploadedChunks": list,
        "progress": float}`` where ``progress`` is a percentage rounded to
        two decimals.

    Raises:
        BizException: 404 when the task is unknown or expired.
    """
    upload_data = cache_simple.get(cls._chunk_upload_cache_key(upload_id))
    if not upload_data:
        raise BizException("上传任务不存在或已过期", code=404)

    total_chunks = upload_data.get("total_chunks", 0)

    # One directory listing instead of one existence check per chunk.
    # File names follow _get_chunk_file_path(): "chunk_<index>".
    upload_dir = cls._get_upload_temp_dir(upload_id)
    present_names = {entry.name for entry in upload_dir.iterdir()}
    uploaded_chunks = [
        index for index in range(total_chunks) if f"chunk_{index}" in present_names
    ]

    progress = (
        (len(uploaded_chunks) / total_chunks * 100) if total_chunks > 0 else 0
    )
    return {
        "uploadId": upload_id,
        "totalChunks": total_chunks,
        "uploadedChunks": uploaded_chunks,
        "progress": round(progress, 2),
    }
@classmethod
def abort_chunk_upload(cls, upload_id: str) -> None:
    """
    Cancel a chunked upload and discard its temporary data.

    Args:
        upload_id: Upload task id.

    Raises:
        BizException: 404 when the task is unknown or expired.
    """
    key = cls._chunk_upload_cache_key(upload_id)
    if not cache_simple.get(key):
        raise BizException("上传任务不存在或已过期", code=404)

    # Drop the cached task first, then the chunk files on disk.
    cache_simple.delete(key)
    cls._cleanup_upload_temp_dir(upload_id)
# ------------------------------------------------------------------
# 文件访问工具
# ------------------------------------------------------------------
@staticmethod
def get_file_by_id(file_id: str) -> SysFile:
    """
    Load an enabled file record by primary key.

    Args:
        file_id: File id.

    Returns:
        The matching ``SysFile``.

    Raises:
        BizException: 404 when no record exists or it is not enabled.
    """
    record = db.session.get(SysFile, file_id)
    if record and record.status == StatusEnum.ENABLED:
        return record
    raise BizException("文件不存在", code=404)
@classmethod
def get_file_url(cls, file_id: str, expires: int = 3600) -> str:
    """
    Return a URL through which the file can be accessed.

    Args:
        file_id: File id.
        expires: Expiry passed through to the storage adapter's ``get_url``.
            NOTE(review): the original note says 0 means "never expires" and
            that it only applies to object storage - not verifiable here.

    Returns:
        The backend download route for local storage, otherwise the URL
        produced by the storage adapter.
    """
    record = cls.get_file_by_id(file_id)
    adapter = StorageManager.get_storage(record.storage_type)

    if record.storage_type != StorageTypeEnum.LOCAL:
        # Non-local storage: direct URL from the adapter.
        return adapter.get_url(record.file_key, expires=expires)

    # Local storage is served through the backend download route.
    base = cls._get_backend_url()
    return f"{base}/file/{file_id}/download"
@classmethod
def get_preview_url(cls, file_id: str) -> str:
    """
    Return a preview URL for the file.

    Local files get the backend preview route; for every other storage
    type the adapter's ``get_preview_url`` is called with ``expires=3600``.

    Args:
        file_id: File id.

    Returns:
        Preview URL.
    """
    record = cls.get_file_by_id(file_id)

    if record.storage_type != StorageTypeEnum.LOCAL:
        adapter = StorageManager.get_storage(record.storage_type)
        return adapter.get_preview_url(record.file_key, expires=3600)

    base = cls._get_backend_url()
    return f"{base}/file/{file_id}/preview"
@classmethod
def get_thumbnail_url(
    cls, file_id: str, width: int = 200, height: int = 200, mode: str = "fit", include_params: bool = False
) -> Optional[str]:
    """
    Return a thumbnail URL for an image file.

    Strategy:
        - local storage: backend thumbnail route
        - Aliyun OSS: adapter thumbnail URL (or the plain file URL when
          ``include_params`` is False)
        - any other storage: backend thumbnail route

    Args:
        file_id: File id.
        width: Thumbnail width.
        height: Thumbnail height.
        mode: Resize mode (fit / fill / pad).
        include_params: When True the size and mode are included in the
            URL; when False (the default) the bare base URL is returned.

    Returns:
        Thumbnail URL, or None when the file is not an image.
    """
    record = cls.get_file_by_id(file_id)
    mime = record.mime_type
    if not (mime and mime.startswith("image/")):
        return None

    backend_url = cls._get_backend_url()

    def backend_route() -> str:
        # Backend thumbnail endpoint, optionally carrying the query string.
        base_url = f"{backend_url}/file/{file_id}/thumbnail"
        if include_params:
            return f"{base_url}?w={width}&h={height}&mode={mode}"
        return base_url

    if record.storage_type == StorageTypeEnum.LOCAL:
        return backend_route()

    adapter = StorageManager.get_storage(record.storage_type)

    if record.storage_type == StorageTypeEnum.ALIYUN_OSS:
        if not include_params:
            # Without parameters the plain file URL is returned.
            return adapter.get_url(record.file_key, expires=3600)
        return adapter.get_thumbnail_url(
            record.file_key, width=width, height=height, mode=mode, expires=3600
        )

    # Remaining storage types go through the backend thumbnail route.
    return backend_route()
@classmethod
def get_thumbnail(
    cls, file_id: str, width: int = 200, height: int = 200, mode: str = "fit"
) -> BytesIO:
    """
    Generate a JPEG thumbnail for an image file.

    The original is fetched through ``download_file`` whatever the storage
    type, flattened onto a white background and resized.

    Args:
        file_id: File id.
        width: Target width in pixels; must be positive.
        height: Target height in pixels; must be positive.
        mode: ``"fit"`` (default) keeps the aspect ratio inside the box,
            ``"fill"`` scales and centre-crops to exactly the box,
            ``"pad"`` keeps the aspect ratio and pads with white.
            Any other value behaves like ``"fit"``.

    Returns:
        In-memory stream positioned at the start of the JPEG data.

    Raises:
        BizException: 400 when the file is not an image or the size is not
            positive, 500 when Pillow is not installed.
    """
    file_obj = cls.get_file_by_id(file_id)

    if not file_obj.mime_type or not file_obj.mime_type.startswith("image/"):
        raise BizException("该文件类型不支持缩略图", code=400)

    # Zero or negative sizes make Pillow fail with an opaque error.
    if width <= 0 or height <= 0:
        raise BizException(f"缩略图尺寸无效: {width}x{height}", code=400)

    try:
        from PIL import Image, ImageOps
    except ImportError as e:
        raise BizException("需要安装 Pillow: pip install Pillow", code=500) from e

    _, file_stream = cls.download_file(file_id)
    img = Image.open(file_stream)

    # JPEG has no alpha channel: composite anything that may carry
    # transparency onto white. Converting to RGBA first makes the alpha of
    # "LA" and palette images count too (it was ignored for "LA" before).
    if img.mode in ("RGBA", "LA", "P"):
        rgba = img.convert("RGBA")
        background = Image.new("RGB", rgba.size, (255, 255, 255))
        background.paste(rgba, mask=rgba.split()[-1])
        img = background
    elif img.mode != "RGB":
        img = img.convert("RGB")

    resample = Image.Resampling.LANCZOS
    if mode == "fill":
        # Cover the box and crop around the centre. The previous
        # implementation cropped the centre of a 2x-size thumbnail, which
        # dropped most of the picture.
        img = ImageOps.fit(img, (width, height), method=resample)
    elif mode == "pad":
        img.thumbnail((width, height), resample)
        background = Image.new("RGB", (width, height), (255, 255, 255))
        offset = ((width - img.width) // 2, (height - img.height) // 2)
        background.paste(img, offset)
        img = background
    else:
        img.thumbnail((width, height), resample)

    output = BytesIO()
    img.save(output, format="JPEG", quality=85, optimize=True)
    output.seek(0)
    return output
@classmethod
def download_file(cls, file_id: str) -> tuple[SysFile, BytesIO]:
    """
    Fetch a file's record together with its content stream.

    Args:
        file_id: File id.

    Returns:
        ``(file record, stream returned by the storage adapter's download)``.
    """
    record = cls.get_file_by_id(file_id)
    adapter = StorageManager.get_storage(record.storage_type)
    return record, adapter.download(record.file_key)
# ------------------------------------------------------------------
# 回收站功能
# ------------------------------------------------------------------
@classmethod
def move_to_recycle(cls, file_id: str, user_id: Optional[str] = None) -> None:
    """
    Move a file to the recycle bin.

    When the ``FILE_RECYCLE_ENABLED`` SYSTEM setting is off, the file is
    permanently deleted instead.

    Args:
        file_id: File id.
        user_id: Id of the user performing the operation.
    """
    from iti.applications.service.sys.sys_config import get_bool

    recycle_enabled = get_bool("FILE_RECYCLE_ENABLED", type="SYSTEM", default=True)
    if not recycle_enabled:
        # No recycle bin: delete for real.
        cls.delete_file_permanently(file_id)
        return

    record = cls.get_file_by_id(file_id)
    record.is_deleted = True
    record.deleted_at = datetime.now()
    record.deleted_by = user_id
    record.status = StatusEnum.DISABLED
    db.session.commit()
@classmethod
def restore_from_recycle(cls, file_id: str) -> None:
    """
    Restore a file from the recycle bin.

    Clears the deletion markers and re-enables the record.

    Args:
        file_id: File id.

    Raises:
        BizException: 404 when no record with that id exists.
    """
    record = db.session.get(SysFile, file_id)
    if not record:
        raise BizException("文件不存在", code=404)

    record.is_deleted = False
    for marker in ("deleted_at", "deleted_by"):
        setattr(record, marker, None)
    record.status = StatusEnum.ENABLED
    db.session.commit()
@classmethod
def delete_file_permanently(cls, file_id: str) -> None:
    """
    Permanently delete a file: stored object first, then the DB record.

    A stored object that is already missing (``FileNotFoundError``) is
    tolerated.

    Args:
        file_id: File id.

    Raises:
        BizException: 404 when no record with that id exists.
    """
    record = db.session.get(SysFile, file_id)
    if not record:
        raise BizException("文件不存在", code=404)

    adapter = StorageManager.get_storage(record.storage_type)
    try:
        adapter.delete(record.file_key)
    except FileNotFoundError:
        # Already gone from storage; still remove the record.
        pass

    db.session.delete(record)
    db.session.commit()
@classmethod
def clear_recycle_bin(cls, days: int = 30) -> int:
    """
    Permanently delete recycle-bin files older than the retention period.

    A failure on one file is logged and rolled back so the remaining files
    can still be processed.

    Args:
        days: Retention period in days (default 30).

    Returns:
        Number of files actually deleted.
    """
    threshold = datetime.now() - timedelta(days=days)

    # Only ids are fetched: a rollback after a failed deletion expires
    # loaded objects, while plain ids stay valid.
    expired_ids = db.session.scalars(
        select(SysFile.id).where(
            SysFile.is_deleted.is_(True), SysFile.deleted_at < threshold
        )
    ).all()

    count = 0
    for expired_id in expired_ids:
        try:
            cls.delete_file_permanently(expired_id)
        except Exception as e:
            # Without a rollback a failed flush/commit leaves the session
            # unusable and every following deletion would fail as well.
            db.session.rollback()
            current_app.logger.warning(
                f"清理回收站文件失败: {expired_id}, 错误: {e}"
            )
            continue
        count += 1
    return count
# ------------------------------------------------------------------
# 分享功能
# ------------------------------------------------------------------
@classmethod
def create_share(
    cls,
    file_id: str,
    password: Optional[str] = None,
    expire_hours: Optional[int] = None,
) -> Dict:
    """
    Create (or replace) the share link of a file.

    Args:
        file_id: File id.
        password: Optional share password (stored on the record as given).
        expire_hours: Lifetime in hours; a falsy value means no expiry.

    Returns:
        ``{"share_code": str, "share_url": str, "password": Optional[str],
        "expire_at": Optional[str]}`` with ``expire_at`` in ISO format.

    Raises:
        BizException: 403 when the ``FILE_SHARE_ENABLED`` SYSTEM setting
            is off.
    """
    import secrets

    from iti.applications.service.sys.sys_config import get_bool

    if not get_bool("FILE_SHARE_ENABLED", type="SYSTEM", default=True):
        raise BizException("文件分享功能未启用", code=403)

    target = cls.get_file_by_id(file_id)
    code = secrets.token_urlsafe(8)
    expires_on = (
        datetime.now() + timedelta(hours=expire_hours) if expire_hours else None
    )

    target.share_code = code
    target.share_password = password
    target.share_expire_at = expires_on
    target.share_count = 0  # a new share starts with a fresh counter
    db.session.commit()

    base = cls._get_backend_url()
    return {
        "share_code": code,
        "share_url": f"{base}/file/share/{code}",
        "password": password,
        "expire_at": expires_on.isoformat() if expires_on else None,
    }
@classmethod
def cancel_share(cls, file_id: str) -> None:
    """
    Revoke the share link of a file.

    Clears the share code, password and expiry; the share counter is left
    unchanged.

    Args:
        file_id: File id.
    """
    record = cls.get_file_by_id(file_id)
    for field in ("share_code", "share_password", "share_expire_at"):
        setattr(record, field, None)
    db.session.commit()
@classmethod
def get_file_by_share_code(
    cls, share_code: str, password: Optional[str] = None
) -> SysFile:
    """
    Resolve a share code to its file and count the access.

    Args:
        share_code: Share code taken from the share URL.
        password: Password supplied by the visitor, if any.

    Returns:
        The shared ``SysFile``.

    Raises:
        BizException: 404 when the code is empty or matches no enabled
            file, 403 when the share has expired or the password is wrong.
    """
    import hmac

    # An empty/None code must never reach the query: filtering on
    # share_code=None would match any file that has no share at all.
    if not share_code:
        raise BizException("分享不存在或已失效", code=404)

    file_obj = db.session.scalar(
        select(SysFile).filter_by(
            share_code=share_code, status=StatusEnum.ENABLED
        )
    )
    if not file_obj:
        raise BizException("分享不存在或已失效", code=404)

    if file_obj.share_expire_at and file_obj.share_expire_at < datetime.now():
        raise BizException("分享已过期", code=403)

    if file_obj.share_password:
        # Constant-time comparison; compared as bytes because
        # compare_digest() rejects non-ASCII str arguments.
        expected = str(file_obj.share_password).encode("utf-8")
        supplied = str(password or "").encode("utf-8")
        if not hmac.compare_digest(expected, supplied):
            raise BizException("分享密码错误", code=403)

    # The counter may still be NULL on records that were never shared before.
    file_obj.share_count = (file_obj.share_count or 0) + 1
    db.session.commit()
    return file_obj
@staticmethod
def _chunk_upload_cache_key(upload_id: str) -> str:
"""分片上传任务缓存键"""
return f"{SysFileService.CHUNK_UPLOAD_CACHE_PREFIX}{upload_id}"
@staticmethod
def _resolve_storage_type(
    storage_type: Optional[str | StorageTypeEnum],
) -> StorageTypeEnum:
    """
    Normalise a storage type given as string, enum or nothing.

    Args:
        storage_type: Requested storage type. An enum member is returned
            unchanged; a falsy value is replaced by the app config value
            ``FILE_STORAGE["DEFAULT_STORAGE_TYPE"]`` (default ``"local"``).

    Returns:
        The matching ``StorageTypeEnum`` member; ``LOCAL`` (with a warning
        logged) when the value is not a valid member.
    """
    if isinstance(storage_type, StorageTypeEnum):
        return storage_type

    candidate = storage_type
    if not candidate:
        # Nothing requested: fall back to the configured default.
        storage_settings = current_app.config.get("FILE_STORAGE", {})
        candidate = storage_settings.get("DEFAULT_STORAGE_TYPE", "local")

    try:
        resolved = StorageTypeEnum(candidate)
    except ValueError:
        current_app.logger.warning(
            f"存储类型 '{candidate}' 无效,使用 'local' 作为后备"
        )
        resolved = StorageTypeEnum.LOCAL
    return resolved