You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
3.6 KiB
Python
125 lines
3.6 KiB
Python
from flask_sqlalchemy import SQLAlchemy
|
|
from flask_sqlalchemy.query import Query as BaseQuery
|
|
from flask_marshmallow import Marshmallow
|
|
import datetime
|
|
import os
|
|
from marshmallow import fields
|
|
from marshmallow.validate import (
|
|
URL,
|
|
Email,
|
|
Range,
|
|
Length,
|
|
Equal,
|
|
Regexp,
|
|
Predicate,
|
|
NoneOf,
|
|
OneOf,
|
|
ContainsOnly,
|
|
)
|
|
from iti.applications.common.utils import fail
|
|
from sqlalchemy import MetaData
|
|
|
|
# Localized (Chinese) default messages for marshmallow's built-in validators.
# These are class attributes, so patching them here changes the messages for
# every schema in the process.
URL.default_message = "无效的链接"
Email.default_message = "无效的邮箱地址"

# Range messages are formatted twice by marshmallow releases that support
# ``min_inclusive``/``max_inclusive``: first in ``Range.__init__`` with only
# ``min_op``/``max_op`` (the comparison words below), then when the error is
# raised with ``min``/``max``. The bound placeholders must therefore be
# escaped as ``{{min}}``/``{{max}}``; an unescaped ``{min}`` makes every
# ``Range(...)`` construction fail with KeyError.
# NOTE(review): assumes such a marshmallow 3.x release is installed — confirm
# against the pinned version.
Range.message_gte = "大于或等于"
Range.message_gt = "大于"
Range.message_lte = "小于或等于"
Range.message_lt = "小于"
Range.message_min = "必须{min_op}{{min}}"
Range.message_max = "必须{max_op}{{max}}"
Range.message_all = "必须{min_op}{{min}}且{max_op}{{max}}"

Length.message_min = "长度不得小于{min}位"
Length.message_max = "长度不得大于{max}位"
Length.message_all = "长度不能超过{min}和{max}这个范围"
Length.message_equal = "长度必须等于{equal}位"

Equal.default_message = "必须等于{other}"
Regexp.default_message = "非法输入"
Predicate.default_message = "非法输入"
NoneOf.default_message = "非法输入"
OneOf.default_message = "无效的选择"
ContainsOnly.default_message = "一个或多个无效的选择"
|
|
|
|
# Localized default error messages for marshmallow fields.
#
# Each mapping is merged into the class's existing defaults rather than
# replacing them. Marshmallow looks up every error key a field raises in these
# dicts (via ``Field.make_error``). A missing key makes it raise AssertionError
# instead of a ValidationError. Replacing the dicts outright would drop keys
# such as String's "invalid_utf8" or Number's "too_large".
fields.Field.default_error_messages = {
    **fields.Field.default_error_messages,
    "required": "缺少必要数据",
    "null": "数据不能为空",
    "validator_failed": "非法数据",
}

fields.Str.default_error_messages = {
    **fields.Str.default_error_messages,
    "invalid": "不是合法文本",
}
fields.Int.default_error_messages = {
    **fields.Int.default_error_messages,
    "invalid": "不是合法整数",
}
fields.Number.default_error_messages = {
    **fields.Number.default_error_messages,
    "invalid": "不是合法数字",
}
fields.Boolean.default_error_messages = {
    **fields.Boolean.default_error_messages,
    "invalid": "不是合法布尔值",
}
|
|
|
|
|
|
class Query(BaseQuery):
    """Query class adding soft-delete helpers and marshmallow serialization.

    The soft-delete helpers expect the queried model to have a ``deleted_at``
    column. ``None`` means "not deleted"; a timestamp marks the row as
    soft-deleted.
    """

    def soft_delete(self):
        """Soft-delete every row matched by this query.

        Issues a bulk UPDATE that sets ``deleted_at`` to the current local time
        (a naive ``datetime.datetime.now()``).

        Returns:
            The value returned by ``Query.update`` (the matched row count).
        """
        return self.update({"deleted_at": datetime.datetime.now()})

    def logic_all(self):
        """Return all rows of this query that are not soft-deleted.

        Rows are kept only when ``deleted_at`` is ``None`` (rendered as
        ``IS NULL``).
        """
        return self.filter_by(deleted_at=None).all()

    def all_json(self, schema: type):
        """Serialize all rows of this query with a marshmallow schema.

        Args:
            schema: A marshmallow ``Schema`` subclass — the class itself, not
                an instance; it is instantiated here with ``many=True``.

        Returns:
            The output of ``schema(many=True).dump(...)`` for all rows.
        """
        # The annotation is a plain ``type``; evaluating something like
        # ``Marshmallow().Schema`` here would build a throwaway extension
        # instance every time the class is defined.
        return schema(many=True).dump(self.all())
|
|
|
|
|
|
# Name templates SQLAlchemy applies to unnamed indexes and constraints:
# ix = index, uq = unique, ck = check, fk = foreign key, pk = primary key.
# NOTE(review): the SQLAlchemy/Flask-SQLAlchemy docs usually use
# "%(constraint_name)s" for "ck". "%(column_0_name)s" needs the CHECK
# constraint to be tied to a column. Confirm before adding table-level
# CHECKs. Changing the template would rename existing constraints.
naming_convention = dict(
    ix="ix_%(column_0_label)s",
    uq="uq_%(table_name)s_%(column_0_name)s",
    ck="ck_%(table_name)s_%(column_0_name)s",
    fk="fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
    pk="pk_%(table_name)s",
)

# Shared Flask-SQLAlchemy extension. Model queries use the project ``Query``
# class, and the metadata applies the naming templates above.
db = SQLAlchemy(
    metadata=MetaData(naming_convention=naming_convention),
    query_class=Query,
)

# Shared Flask-Marshmallow extension; bound to an app via ``ma.init_app``.
ma = Marshmallow()
|
|
|
|
|
|
def init_db(app) -> None:
    """Bind the database and marshmallow extensions to a Flask app.

    Also registers the SQLAlchemy error handler. When running inside the
    Werkzeug reloader's child process, it also checks that the database is
    reachable.

    Args:
        app: The Flask application instance.

    Raises:
        SystemExit: If the connectivity check fails; the message is printed to
            stderr and the process exits with status 1.
    """
    db.init_app(app)
    ma.init_app(app)

    # Register the JSON error handler for SQLAlchemy errors.
    _handle_db_error(app)

    # Werkzeug's reloader sets WERKZEUG_RUN_MAIN="true" in the child process
    # it spawns, so the check runs once there. Otherwise it is skipped.
    if os.environ.get("WERKZEUG_RUN_MAIN") == "true":
        with app.app_context():
            try:
                # Use the connection as a context manager so it is released
                # immediately instead of lingering until garbage collection.
                with db.engine.connect():
                    pass
            except Exception as e:
                # Raising SystemExit directly avoids depending on the
                # ``site``-provided ``exit`` builtin (absent under ``python -S``).
                raise SystemExit(f"数据库连接失败: {e}") from e
|
|
|
|
|
|
def _handle_db_error(app):
    """Register a Flask error handler that reports SQLAlchemy errors via ``fail``.

    Every ``SQLAlchemyError`` reaching Flask's error handling is logged with
    its traceback. It is then answered with ``fail("数据库错误", code=500, data=...)``.

    If the app config ``SQLALCHEMY_SHOW_ERROR_DETAILS`` is truthy, ``data``
    also carries the exception's args, SQL statement and bound parameters.
    That exposes query text and values, so keep it off outside development.

    Args:
        app: The Flask application to register the handler on.
    """
    from sqlalchemy.exc import SQLAlchemyError

    # Read once at registration time; later config changes are not seen.
    show_error_details = app.config.get("SQLALCHEMY_SHOW_ERROR_DETAILS", False)

    @app.errorhandler(SQLAlchemyError)
    def handle_sqlalchemy_db_error(error):
        """Log ``error`` and build the failure response for it."""
        # exc_info attaches the traceback, which the bare message lost.
        app.logger.error("数据库错误: %s", error, exc_info=error)
        data = {
            # SQLAlchemyError defines ``code`` as a class attribute (None
            # unless the error type has a documented error code). Checking for
            # its existence never falls back, so fall back on its value.
            "code": getattr(error, "code", None) or 500,
        }
        if show_error_details:
            data["args"] = getattr(error, "args", None)
            # ``statement``/``params`` exist only on statement-level errors;
            # other SQLAlchemyErrors report None.
            data["statement"] = getattr(error, "statement", None)
            data["params"] = getattr(error, "params", None)
        return fail(
            "数据库错误",
            code=500,
            data=data,
        )
|