from concurrent.futures import ThreadPoolExecutor import threading import asyncio import inspect from functools import wraps from .event_middleware import EventMiddleware from iti.applications.common import setup_logger logger = setup_logger(__name__) class EventBus: """ 事件总线 """ def __init__(self, max_workers: int = 10): self._handlers: dict[str, list[dict]] = {} self._middlewares: list[EventMiddleware] = [] self._executor = ThreadPoolExecutor( thread_name_prefix="EventBusExecutor", max_workers=max_workers ) self._lock = threading.Lock() self._stats = { "errors": 0, "events_emitted": 0, "events_processed": 0, } def init_app(self, app): """ 初始化事件总线 """ self._app = app def on(self, event_name: str, order: int = 0, async_mode: bool = False): """ 注册事件处理器(装饰器形式) """ def decorator(func): self._register_handler(event_name, func, order, async_mode) return func return decorator def register_handler( self, event_name: str, handler_func, order: int = 0, async_mode: bool = False ): """ 手动注册事件处理器 Args: event_name: 事件名称 handler_func: 处理器函数 order: 执行顺序(数字越小越先执行) async_mode: 是否为异步模式 """ self._register_handler(event_name, handler_func, order, async_mode) def _register_handler( self, event_name: str, handler_func, order: int = 0, async_mode: bool = False ): """ 内部方法:注册事件处理器 """ with self._lock: if event_name not in self._handlers: self._handlers[event_name] = [] # 分析函数签名并缓存 sig_info = self._analyze_signature(handler_func) # 包装上下文 wrapped_func = self._wrap_context(handler_func, async_mode) # 添加处理器 self._handlers[event_name].append( { "func": wrapped_func, "orginal_func": handler_func, "order": order, "async_mode": async_mode, "name": handler_func.__name__, "sig_info": sig_info, # 缓存签名信息 } ) # 排序(按order升序) self._handlers[event_name].sort(key=lambda x: x["order"]) def _analyze_signature(self, func): """ 分析函数签名,返回参数接受信息 Returns: dict: 包含以下信息: - max_positional: 最大位置参数数量(不含self) - accepts_var_positional: 是否接受*args - accepts_var_keyword: 是否接受**kwargs """ try: sig = inspect.signature(func) params = list(sig.parameters.values()) # 检查是否接受可变参数 accepts_var_positional = any( p.kind == inspect.Parameter.VAR_POSITIONAL for p in params ) accepts_var_keyword = any( p.kind == inspect.Parameter.VAR_KEYWORD for p in params ) # 计算最大位置参数数量(排除 self/cls) positional_params = [ p for p in params if p.kind in ( inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD, ) and p.name not in ("self", "cls") ] max_positional = len(positional_params) return { "max_positional": max_positional, "accepts_var_positional": accepts_var_positional, "accepts_var_keyword": accepts_var_keyword, } except Exception as e: logger.warning(f"无法分析函数签名 {func.__name__}: {e},将使用默认行为") # 如果无法分析签名,返回保守的默认值(接受所有参数) return { "max_positional": float("inf"), "accepts_var_positional": True, "accepts_var_keyword": True, } def _adapt_args(self, sig_info, args, kwargs): """ 根据函数签名信息适配参数 Args: sig_info: 函数签名信息 args: 原始位置参数 kwargs: 原始关键字参数 Returns: tuple: (adapted_args, adapted_kwargs) """ # 如果接受可变位置参数,直接返回所有参数 if sig_info["accepts_var_positional"]: adapted_args = args else: # 只传递函数能接受的参数数量 max_pos = sig_info["max_positional"] adapted_args = args[:max_pos] if max_pos != float("inf") else args # 如果接受可变关键字参数,直接返回所有 kwargs if sig_info["accepts_var_keyword"]: adapted_kwargs = kwargs else: # 这里可以进一步过滤 kwargs,但通常不需要 # 因为多余的 kwargs 会在调用时报错 adapted_kwargs = kwargs return adapted_args, adapted_kwargs def _auto_merge_orm_objects(self, items): """ 自动将 ORM 对象 merge 到当前线程的 session Args: items: 参数列表或字典值 Returns: 处理后的参数(ORM 对象已 merge) """ from sqlalchemy import inspect as sa_inspect from sqlalchemy.orm.exc import UnmappedInstanceError result = [] for item in items: # 检测是否是 SQLAlchemy ORM 对象 if hasattr(item, "__table__"): try: from iti.applications.extensions import db # 检查对象状态 state = sa_inspect(item) # 如果对象是 persistent 或 detached,merge 到新 session if state.persistent or state.detached: merged = db.session.merge(item, load=False) result.append(merged) else: # transient 或其他状态,直接传递 result.append(item) except (UnmappedInstanceError, Exception) as e: logger.warning(f"merge ORM 对象失败: {e},使用原对象") result.append(item) elif isinstance(item, (list, tuple)): # 递归处理集合 merged_items = self._auto_merge_orm_objects(item) result.append(type(item)(merged_items)) elif isinstance(item, dict): # 递归处理字典 result.append( {k: self._auto_merge_orm_objects([v])[0] for k, v in item.items()} ) else: # 基础类型,直接传递 result.append(item) return result def _wrap_context(self, func, async_mode: bool = False): """ 包装上下文 注意: - sync_mode: 在当前上下文中同步执行,需要 app_context 包装 - async_mode: 在线程池中异步执行,app_context 在 _run_async_handler 中统一处理 """ if async_mode: # async 模式:不包装 app_context,只转换为 async 函数(用于 asyncio.run) if asyncio.iscoroutinefunction(func): # 已经是 async 函数 async def async_wrapper(*args, **kwargs): return await func(*args, **kwargs) else: # 普通函数,包装为 async(不 await) async def async_wrapper(*args, **kwargs): return func(*args, **kwargs) return async_wrapper else: # sync 模式:包装 app_context @self._with_context def sync_wrapper(*args, **kwargs): return func(*args, **kwargs) return sync_wrapper def _with_context(self, func): """ 包装同步上下文 """ @wraps(func) def wrapper(*args, **kwargs): with self._app.app_context(): return func(*args, **kwargs) return wrapper def emit(self, event_name: str, *args, **kwargs): """ 发布事件 """ try: with self._lock: self._stats["events_emitted"] += 1 if event_name not in self._handlers: return # 执行中间件 processed_args = args processed_kwargs = kwargs for middleware in self._middlewares: try: processed_args, processed_kwargs = middleware( event_name, processed_args, processed_kwargs ) except Exception as e: logger.error( f"事件中间件错误: {e}, 事件: {event_name}, args: {processed_args}, kwargs: {processed_kwargs}", exc_info=True, ) # 分离同步、异步事件,分别发布 sync_handlers = [ h for h in self._handlers[event_name] if not h["async_mode"] ] async_handlers = [h for h in self._handlers[event_name] if h["async_mode"]] # 先发布异步事件 if async_handlers: for handler in async_handlers: try: # 根据函数签名适配参数 adapted_args, adapted_kwargs = self._adapt_args( handler["sig_info"], processed_args, processed_kwargs ) self._executor.submit( self._run_async_handler, handler["func"], adapted_args, adapted_kwargs, ) logger.info( f"异步事件处理器: {handler['name']} 已执行, args: {adapted_args}, kwargs: {adapted_kwargs}" ) except Exception as e: logger.error( f"异步事件处理器错误: {e}, 事件: {event_name}, args: {processed_args}, kwargs: {processed_kwargs}", exc_info=True, ) # 再发布同步事件 if sync_handlers: for handler in sync_handlers: try: # 根据函数签名适配参数 adapted_args, adapted_kwargs = self._adapt_args( handler["sig_info"], processed_args, processed_kwargs ) handler["func"](*adapted_args, **adapted_kwargs) except Exception as e: logger.error( f"同步事件处理器错误: {e}, 事件: {event_name}, args: {processed_args}, kwargs: {processed_kwargs}", exc_info=True, ) except Exception as e: with self._lock: self._stats["errors"] += 1 logger.error( f"事件发布失败: {e}, 事件: {event_name}, args: {args}, kwargs: {kwargs}", exc_info=True, ) def _run_async_handler(self, func, args, kwargs): """ 在线程池中运行异步处理器 自动将 ORM 对象 merge 到当前线程的 session """ try: with self._app.app_context(): # 自动 merge ORM 对象到当前线程的 session merged_args = self._auto_merge_orm_objects(args) merged_kwargs = { k: self._auto_merge_orm_objects([v])[0] for k, v in kwargs.items() } asyncio.run(func(*merged_args, **merged_kwargs)) except Exception as e: logger.error( f"异步事件处理器错误: {e}, args: {args}, kwargs: {kwargs}", exc_info=True, ) def get_stats(self): """ 获取统计信息 """ with self._lock: return self._stats def get_handlers(self, event_name: str): """ 获取事件处理器 """ return self._handlers[event_name] def clear_handlers(self, event_name: str): """ 清除指定事件的所有处理器 """ with self._lock: if event_name in self._handlers: del self._handlers[event_name] def remove_handler(self, event_name: str, handler_name: str): """ 移除指定名称的处理器 Args: event_name: 事件名称 handler_name: 处理器函数名称 """ with self._lock: if event_name in self._handlers: self._handlers[event_name] = [ handler for handler in self._handlers[event_name] if handler["name"] != handler_name ] # 如果该事件没有处理器了,删除事件 if not self._handlers[event_name]: del self._handlers[event_name] def get_handler_count(self, event_name: str = None): """ 获取处理器数量 Args: event_name: 事件名称,如果为None则返回所有事件的处理器总数 Returns: int: 处理器数量 """ with self._lock: if event_name: return len(self._handlers.get(event_name, [])) else: return sum(len(handlers) for handlers in self._handlers.values()) def list_handlers(self, event_name: str = None): """ 列出所有处理器信息 Args: event_name: 事件名称,如果为None则列出所有事件 Returns: dict: 处理器信息字典 """ with self._lock: if event_name: return { event_name: [ { "name": handler["name"], "order": handler["order"], "async_mode": handler["async_mode"], "max_positional": handler["sig_info"]["max_positional"], "accepts_var_args": handler["sig_info"][ "accepts_var_positional" ], } for handler in self._handlers.get(event_name, []) ] } else: return { event: [ { "name": handler["name"], "order": handler["order"], "async_mode": handler["async_mode"], "max_positional": handler["sig_info"]["max_positional"], "accepts_var_args": handler["sig_info"][ "accepts_var_positional" ], } for handler in handlers ] for event, handlers in self._handlers.items() } def middleware(self, middleware: EventMiddleware): """ 注册中间件 """ self._middlewares.append(middleware) return middleware def shutdown(self): """ 关闭事件总线 """ self._executor.shutdown(wait=False)