You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
iTi-Flask/iti/applications/extensions/eventbus/event_bus.py

478 lines
16 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from concurrent.futures import ThreadPoolExecutor
import threading
import asyncio
import inspect
from functools import wraps
from .event_middleware import EventMiddleware
from iti.applications.common import setup_logger
logger = setup_logger(__name__)
class EventBus:
"""
事件总线
"""
def __init__(self, max_workers: int = 10):
self._handlers: dict[str, list[dict]] = {}
self._middlewares: list[EventMiddleware] = []
self._executor = ThreadPoolExecutor(
thread_name_prefix="EventBusExecutor", max_workers=max_workers
)
self._lock = threading.Lock()
self._stats = {
"errors": 0,
"events_emitted": 0,
"events_processed": 0,
}
def init_app(self, app):
"""
初始化事件总线
"""
self._app = app
def on(self, event_name: str, order: int = 0, async_mode: bool = False):
"""
注册事件处理器(装饰器形式)
"""
def decorator(func):
self._register_handler(event_name, func, order, async_mode)
return func
return decorator
def register_handler(
self, event_name: str, handler_func, order: int = 0, async_mode: bool = False
):
"""
手动注册事件处理器
Args:
event_name: 事件名称
handler_func: 处理器函数
order: 执行顺序(数字越小越先执行)
async_mode: 是否为异步模式
"""
self._register_handler(event_name, handler_func, order, async_mode)
def _register_handler(
self, event_name: str, handler_func, order: int = 0, async_mode: bool = False
):
"""
内部方法:注册事件处理器
"""
with self._lock:
if event_name not in self._handlers:
self._handlers[event_name] = []
# 分析函数签名并缓存
sig_info = self._analyze_signature(handler_func)
# 包装上下文
wrapped_func = self._wrap_context(handler_func, async_mode)
# 添加处理器
self._handlers[event_name].append(
{
"func": wrapped_func,
"orginal_func": handler_func,
"order": order,
"async_mode": async_mode,
"name": handler_func.__name__,
"sig_info": sig_info, # 缓存签名信息
}
)
# 排序(按order升序)
self._handlers[event_name].sort(key=lambda x: x["order"])
def _analyze_signature(self, func):
"""
分析函数签名,返回参数接受信息
Returns:
dict: 包含以下信息:
- max_positional: 最大位置参数数量不含self
- accepts_var_positional: 是否接受*args
- accepts_var_keyword: 是否接受**kwargs
"""
try:
sig = inspect.signature(func)
params = list(sig.parameters.values())
# 检查是否接受可变参数
accepts_var_positional = any(
p.kind == inspect.Parameter.VAR_POSITIONAL for p in params
)
accepts_var_keyword = any(
p.kind == inspect.Parameter.VAR_KEYWORD for p in params
)
# 计算最大位置参数数量(排除 self/cls
positional_params = [
p
for p in params
if p.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
)
and p.name not in ("self", "cls")
]
max_positional = len(positional_params)
return {
"max_positional": max_positional,
"accepts_var_positional": accepts_var_positional,
"accepts_var_keyword": accepts_var_keyword,
}
except Exception as e:
logger.warning(f"无法分析函数签名 {func.__name__}: {e},将使用默认行为")
# 如果无法分析签名,返回保守的默认值(接受所有参数)
return {
"max_positional": float("inf"),
"accepts_var_positional": True,
"accepts_var_keyword": True,
}
def _adapt_args(self, sig_info, args, kwargs):
"""
根据函数签名信息适配参数
Args:
sig_info: 函数签名信息
args: 原始位置参数
kwargs: 原始关键字参数
Returns:
tuple: (adapted_args, adapted_kwargs)
"""
# 如果接受可变位置参数,直接返回所有参数
if sig_info["accepts_var_positional"]:
adapted_args = args
else:
# 只传递函数能接受的参数数量
max_pos = sig_info["max_positional"]
adapted_args = args[:max_pos] if max_pos != float("inf") else args
# 如果接受可变关键字参数,直接返回所有 kwargs
if sig_info["accepts_var_keyword"]:
adapted_kwargs = kwargs
else:
# 这里可以进一步过滤 kwargs但通常不需要
# 因为多余的 kwargs 会在调用时报错
adapted_kwargs = kwargs
return adapted_args, adapted_kwargs
def _auto_merge_orm_objects(self, items):
"""
自动将 ORM 对象 merge 到当前线程的 session
Args:
items: 参数列表或字典值
Returns:
处理后的参数ORM 对象已 merge
"""
from sqlalchemy import inspect as sa_inspect
from sqlalchemy.orm.exc import UnmappedInstanceError
from werkzeug.local import LocalProxy
result = []
for item in items:
# 跳过 LocalProxy 对象(如 current_user避免触发 JWT 上下文检查
if isinstance(item, LocalProxy):
result.append(item)
continue
# 检测是否是 SQLAlchemy ORM 对象
if hasattr(item, "__table__"):
try:
from iti.applications.extensions import db
# 检查对象状态
state = sa_inspect(item)
# 如果对象是 persistent 或 detachedmerge 到新 session
if state.persistent or state.detached:
merged = db.session.merge(item, load=False)
result.append(merged)
else:
# transient 或其他状态,直接传递
result.append(item)
except (UnmappedInstanceError, Exception) as e:
logger.warning(f"merge ORM 对象失败: {e},使用原对象")
result.append(item)
elif isinstance(item, (list, tuple)):
# 递归处理集合
merged_items = self._auto_merge_orm_objects(item)
result.append(type(item)(merged_items))
elif isinstance(item, dict):
# 递归处理字典
result.append(
{k: self._auto_merge_orm_objects([v])[0] for k, v in item.items()}
)
else:
# 基础类型,直接传递
result.append(item)
return result
def _wrap_context(self, func, async_mode: bool = False):
"""
包装上下文
注意:
- sync_mode: 在当前上下文中同步执行,需要 app_context 包装
- async_mode: 在线程池中异步执行app_context 在 _run_async_handler 中统一处理
"""
if async_mode:
# async 模式:不包装 app_context只转换为 async 函数(用于 asyncio.run
if asyncio.iscoroutinefunction(func):
# 已经是 async 函数
async def async_wrapper(*args, **kwargs):
return await func(*args, **kwargs)
else:
# 普通函数,包装为 async不 await
async def async_wrapper(*args, **kwargs):
return func(*args, **kwargs)
return async_wrapper
else:
# sync 模式:包装 app_context
@self._with_context
def sync_wrapper(*args, **kwargs):
return func(*args, **kwargs)
return sync_wrapper
def _with_context(self, func):
"""
包装同步上下文
"""
@wraps(func)
def wrapper(*args, **kwargs):
with self._app.app_context():
return func(*args, **kwargs)
return wrapper
def emit(self, event_name: str, *args, **kwargs):
"""
发布事件
"""
try:
with self._lock:
self._stats["events_emitted"] += 1
if event_name not in self._handlers:
return
# 执行中间件
processed_args = args
processed_kwargs = kwargs
for middleware in self._middlewares:
try:
processed_args, processed_kwargs = middleware(
event_name, processed_args, processed_kwargs
)
except Exception as e:
logger.error(
f"事件中间件错误: {e}, 事件: {event_name}, args: {processed_args}, kwargs: {processed_kwargs}",
exc_info=True,
)
# 分离同步、异步事件,分别发布
sync_handlers = [
h for h in self._handlers[event_name] if not h["async_mode"]
]
async_handlers = [h for h in self._handlers[event_name] if h["async_mode"]]
# 先发布异步事件
if async_handlers:
for handler in async_handlers:
try:
# 根据函数签名适配参数
adapted_args, adapted_kwargs = self._adapt_args(
handler["sig_info"], processed_args, processed_kwargs
)
self._executor.submit(
self._run_async_handler,
handler["func"],
adapted_args,
adapted_kwargs,
)
logger.info(
f"异步事件处理器: {handler['name']} 已执行, args: {adapted_args}, kwargs: {adapted_kwargs}"
)
except Exception as e:
logger.error(
f"异步事件处理器错误: {e}, 事件: {event_name}, args: {processed_args}, kwargs: {processed_kwargs}",
exc_info=True,
)
# 再发布同步事件
if sync_handlers:
for handler in sync_handlers:
try:
# 根据函数签名适配参数
adapted_args, adapted_kwargs = self._adapt_args(
handler["sig_info"], processed_args, processed_kwargs
)
handler["func"](*adapted_args, **adapted_kwargs)
except Exception as e:
logger.error(
f"同步事件处理器错误: {e}, 事件: {event_name}, args: {processed_args}, kwargs: {processed_kwargs}",
exc_info=True,
)
except Exception as e:
with self._lock:
self._stats["errors"] += 1
logger.error(
f"事件发布失败: {e}, 事件: {event_name}, args: {args}, kwargs: {kwargs}",
exc_info=True,
)
def _run_async_handler(self, func, args, kwargs):
"""
在线程池中运行异步处理器
自动将 ORM 对象 merge 到当前线程的 session
"""
try:
with self._app.app_context():
# 自动 merge ORM 对象到当前线程的 session
merged_args = self._auto_merge_orm_objects(args)
merged_kwargs = {
k: self._auto_merge_orm_objects([v])[0] for k, v in kwargs.items()
}
asyncio.run(func(*merged_args, **merged_kwargs))
except Exception as e:
logger.error(
f"异步事件处理器错误: {e}, args: {args}, kwargs: {kwargs}",
exc_info=True,
)
def get_stats(self):
"""
获取统计信息
"""
with self._lock:
return self._stats
def get_handlers(self, event_name: str):
"""
获取事件处理器
"""
return self._handlers[event_name]
def clear_handlers(self, event_name: str):
"""
清除指定事件的所有处理器
"""
with self._lock:
if event_name in self._handlers:
del self._handlers[event_name]
def remove_handler(self, event_name: str, handler_name: str):
"""
移除指定名称的处理器
Args:
event_name: 事件名称
handler_name: 处理器函数名称
"""
with self._lock:
if event_name in self._handlers:
self._handlers[event_name] = [
handler
for handler in self._handlers[event_name]
if handler["name"] != handler_name
]
# 如果该事件没有处理器了,删除事件
if not self._handlers[event_name]:
del self._handlers[event_name]
def get_handler_count(self, event_name: str = None):
"""
获取处理器数量
Args:
event_name: 事件名称如果为None则返回所有事件的处理器总数
Returns:
int: 处理器数量
"""
with self._lock:
if event_name:
return len(self._handlers.get(event_name, []))
else:
return sum(len(handlers) for handlers in self._handlers.values())
def list_handlers(self, event_name: str = None):
"""
列出所有处理器信息
Args:
event_name: 事件名称如果为None则列出所有事件
Returns:
dict: 处理器信息字典
"""
with self._lock:
if event_name:
return {
event_name: [
{
"name": handler["name"],
"order": handler["order"],
"async_mode": handler["async_mode"],
"max_positional": handler["sig_info"]["max_positional"],
"accepts_var_args": handler["sig_info"][
"accepts_var_positional"
],
}
for handler in self._handlers.get(event_name, [])
]
}
else:
return {
event: [
{
"name": handler["name"],
"order": handler["order"],
"async_mode": handler["async_mode"],
"max_positional": handler["sig_info"]["max_positional"],
"accepts_var_args": handler["sig_info"][
"accepts_var_positional"
],
}
for handler in handlers
]
for event, handlers in self._handlers.items()
}
def middleware(self, middleware: EventMiddleware):
"""
注册中间件
"""
self._middlewares.append(middleware)
return middleware
def shutdown(self):
"""
关闭事件总线
"""
self._executor.shutdown(wait=False)