from typing import List, Dict, Any, Optional from dataclasses import dataclass @dataclass(frozen=True) class TreeKeyConfig: """集中管理树结构相关的键名。 id_key: 节点ID字段名 parent_key: 父节点字段名 children_key: 子节点集合字段名 """ id_key: str = "id" parent_key: str = "parent_id" children_key: str = "children" # 兼容旧命名(保留但内部使用 TreeKeyConfig) default_key_config = TreeKeyConfig() # 统一数据访问函数 def _get_value(item, key): return item.get(key) if isinstance(item, dict) else getattr(item, key, None) def _set_value(item, key, value): if isinstance(item, dict): item[key] = value else: setattr(item, key, value) def build_tree_from_list( data_list: list, key_config: TreeKeyConfig = default_key_config, ): """ 从对象列表构建树结构(优化版 - 两遍遍历算法) 修复了原算法的问题: - 原算法依赖数据顺序,当子节点在父节点之前时无法正确建立关系 - 优化后使用两遍遍历,确保所有节点都能正确建立父子关系 Args: data_list: ORM对象列表或字典列表 key_config: 键配置,指定父级ID和子级字段名 Returns: list: 树形结构数据 """ if not data_list: return [] # 第一遍:构建所有节点的映射并初始化children字段 data_map = {} for item in data_list: item_id = _get_value(item, key_config.id_key) _set_value(item, key_config.children_key, []) data_map[item_id] = item # 第二遍:建立父子关系 tree = [] for item in data_list: parent_id = _get_value(item, key_config.parent_key) if parent_id is None: # 根节点 tree.append(item) else: # 子节点,查找父节点 parent = data_map.get(parent_id) if parent: children = _get_value(parent, key_config.children_key) children.append(item) # 注意:如果找不到父节点,说明数据有问题,这里选择忽略 # 可以根据需要添加日志记录或异常处理 return tree def flatten_tree( tree_list: List[Dict[str, Any]], *, key_config: TreeKeyConfig = default_key_config, include_children: bool = False, ) -> List[Dict[str, Any]]: """ 将树形结构扁平化为列表 Args: tree_list: 树形结构数据 children_key: 子节点字段名 include_children: 是否包含children字段 Returns: list: 扁平化后的列表 """ result = [] def flatten_node(node): # 创建节点副本 node_copy = node.copy() # 如果不包含children字段,则移除 if not include_children and key_config.children_key in node_copy: del node_copy[key_config.children_key] result.append(node_copy) # 递归处理子节点 if key_config.children_key in node and node[key_config.children_key]: for child in node[key_config.children_key]: flatten_node(child) for node in tree_list: flatten_node(node) return result def find_node_by_id( tree_list: List[Dict[str, Any]], target_id: str, *, key_config: TreeKeyConfig = default_key_config, ) -> Optional[Dict[str, Any]]: """ 在树形结构中根据ID查找节点 Args: tree_list: 树形结构数据 target_id: 目标ID id_key: ID字段名 children_key: 子节点字段名 Returns: dict: 找到的节点,未找到返回None """ def search_node(node): if node.get(key_config.id_key) == target_id: return node if key_config.children_key in node and node[key_config.children_key]: for child in node[key_config.children_key]: result = search_node(child) if result: return result return None for node in tree_list: result = search_node(node) if result: return result return None def get_node_path( tree_list: List[Dict[str, Any]], target_id: str, *, key_config: TreeKeyConfig = default_key_config, ) -> List[Dict[str, Any]]: """ 获取从根节点到目标节点的路径 Args: tree_list: 树形结构数据 target_id: 目标ID id_key: ID字段名 children_key: 子节点字段名 Returns: list: 路径节点列表(从根到目标) """ def find_path(node, path): current_path = path + [node] if node.get(key_config.id_key) == target_id: return current_path if key_config.children_key in node and node[key_config.children_key]: for child in node[key_config.children_key]: result = find_path(child, current_path) if result: return result return None for node in tree_list: result = find_path(node, []) if result: return result return [] def filter_tree_by_condition( tree_list: List[Dict[str, Any]], condition_func: callable, *, key_config: TreeKeyConfig = default_key_config, ) -> List[Dict[str, Any]]: """ 根据条件过滤树形结构 Args: tree_list: 树形结构数据 condition_func: 过滤条件函数,接收节点参数,返回bool children_key: 子节点字段名 Returns: list: 过滤后的树形结构 """ def filter_node(node): # 检查当前节点是否满足条件 if condition_func(node): # 创建节点副本 filtered_node = node.copy() # 递归过滤子节点 if key_config.children_key in node and node[key_config.children_key]: filtered_children = [] for child in node[key_config.children_key]: filtered_child = filter_node(child) if filtered_child: filtered_children.append(filtered_child) if filtered_children: filtered_node[key_config.children_key] = filtered_children else: filtered_node[key_config.children_key] = [] return filtered_node # 如果当前节点不满足条件,检查子节点 if key_config.children_key in node and node[key_config.children_key]: filtered_children = [] for child in node[key_config.children_key]: filtered_child = filter_node(child) if filtered_child: filtered_children.append(filtered_child) if filtered_children: # 如果子节点有满足条件的,创建父节点 parent_node = node.copy() parent_node[key_config.children_key] = filtered_children return parent_node return None result = [] for node in tree_list: filtered_node = filter_node(node) if filtered_node: result.append(filtered_node) return result def get_tree_depth( tree_list: List[Dict[str, Any]], *, key_config: TreeKeyConfig = default_key_config, ) -> int: """ 获取树的最大深度 Args: tree_list: 树形结构数据 children_key: 子节点字段名 Returns: int: 树的最大深度 """ def get_node_depth(node, current_depth): max_depth = current_depth if key_config.children_key in node and node[key_config.children_key]: for child in node[key_config.children_key]: child_depth = get_node_depth(child, current_depth + 1) max_depth = max(max_depth, child_depth) return max_depth max_depth = 0 for node in tree_list: depth = get_node_depth(node, 1) max_depth = max(max_depth, depth) return max_depth