|
|
from typing import List, Dict, Any, Optional
|
|
|
from dataclasses import dataclass
|
|
|
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
|
class TreeKeyConfig:
|
|
|
"""集中管理树结构相关的键名。
|
|
|
|
|
|
id_key: 节点ID字段名
|
|
|
parent_key: 父节点字段名
|
|
|
children_key: 子节点集合字段名
|
|
|
"""
|
|
|
|
|
|
id_key: str = "id"
|
|
|
parent_key: str = "parent_id"
|
|
|
children_key: str = "children"
|
|
|
|
|
|
|
|
|
# 兼容旧命名(保留但内部使用 TreeKeyConfig)
|
|
|
default_key_config = TreeKeyConfig()
|
|
|
|
|
|
|
|
|
# 统一数据访问函数
|
|
|
def _get_value(item, key):
|
|
|
return item.get(key) if isinstance(item, dict) else getattr(item, key, None)
|
|
|
|
|
|
|
|
|
def _set_value(item, key, value):
|
|
|
if isinstance(item, dict):
|
|
|
item[key] = value
|
|
|
else:
|
|
|
setattr(item, key, value)
|
|
|
|
|
|
|
|
|
def build_tree_from_list(
|
|
|
data_list: list,
|
|
|
key_config: TreeKeyConfig = default_key_config,
|
|
|
):
|
|
|
"""
|
|
|
从对象列表构建树结构(优化版 - 两遍遍历算法)
|
|
|
|
|
|
修复了原算法的问题:
|
|
|
- 原算法依赖数据顺序,当子节点在父节点之前时无法正确建立关系
|
|
|
- 优化后使用两遍遍历,确保所有节点都能正确建立父子关系
|
|
|
|
|
|
Args:
|
|
|
data_list: ORM对象列表或字典列表
|
|
|
key_config: 键配置,指定父级ID和子级字段名
|
|
|
|
|
|
Returns:
|
|
|
list: 树形结构数据
|
|
|
"""
|
|
|
|
|
|
if not data_list:
|
|
|
return []
|
|
|
|
|
|
# 第一遍:构建所有节点的映射并初始化children字段
|
|
|
data_map = {}
|
|
|
for item in data_list:
|
|
|
item_id = _get_value(item, key_config.id_key)
|
|
|
_set_value(item, key_config.children_key, [])
|
|
|
data_map[item_id] = item
|
|
|
|
|
|
# 第二遍:建立父子关系
|
|
|
tree = []
|
|
|
for item in data_list:
|
|
|
parent_id = _get_value(item, key_config.parent_key)
|
|
|
|
|
|
if parent_id is None:
|
|
|
# 根节点
|
|
|
tree.append(item)
|
|
|
else:
|
|
|
# 子节点,查找父节点
|
|
|
parent = data_map.get(parent_id)
|
|
|
if parent:
|
|
|
children = _get_value(parent, key_config.children_key)
|
|
|
children.append(item)
|
|
|
# 注意:如果找不到父节点,说明数据有问题,这里选择忽略
|
|
|
# 可以根据需要添加日志记录或异常处理
|
|
|
|
|
|
return tree
|
|
|
|
|
|
|
|
|
def flatten_tree(
|
|
|
tree_list: List[Dict[str, Any]],
|
|
|
*,
|
|
|
key_config: TreeKeyConfig = default_key_config,
|
|
|
include_children: bool = False,
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
将树形结构扁平化为列表
|
|
|
|
|
|
Args:
|
|
|
tree_list: 树形结构数据
|
|
|
children_key: 子节点字段名
|
|
|
include_children: 是否包含children字段
|
|
|
|
|
|
Returns:
|
|
|
list: 扁平化后的列表
|
|
|
"""
|
|
|
result = []
|
|
|
|
|
|
def flatten_node(node):
|
|
|
# 创建节点副本
|
|
|
node_copy = node.copy()
|
|
|
|
|
|
# 如果不包含children字段,则移除
|
|
|
if not include_children and key_config.children_key in node_copy:
|
|
|
del node_copy[key_config.children_key]
|
|
|
|
|
|
result.append(node_copy)
|
|
|
|
|
|
# 递归处理子节点
|
|
|
if key_config.children_key in node and node[key_config.children_key]:
|
|
|
for child in node[key_config.children_key]:
|
|
|
flatten_node(child)
|
|
|
|
|
|
for node in tree_list:
|
|
|
flatten_node(node)
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
def find_node_by_id(
|
|
|
tree_list: List[Dict[str, Any]],
|
|
|
target_id: str,
|
|
|
*,
|
|
|
key_config: TreeKeyConfig = default_key_config,
|
|
|
) -> Optional[Dict[str, Any]]:
|
|
|
"""
|
|
|
在树形结构中根据ID查找节点
|
|
|
|
|
|
Args:
|
|
|
tree_list: 树形结构数据
|
|
|
target_id: 目标ID
|
|
|
id_key: ID字段名
|
|
|
children_key: 子节点字段名
|
|
|
|
|
|
Returns:
|
|
|
dict: 找到的节点,未找到返回None
|
|
|
"""
|
|
|
|
|
|
def search_node(node):
|
|
|
if node.get(key_config.id_key) == target_id:
|
|
|
return node
|
|
|
|
|
|
if key_config.children_key in node and node[key_config.children_key]:
|
|
|
for child in node[key_config.children_key]:
|
|
|
result = search_node(child)
|
|
|
if result:
|
|
|
return result
|
|
|
|
|
|
return None
|
|
|
|
|
|
for node in tree_list:
|
|
|
result = search_node(node)
|
|
|
if result:
|
|
|
return result
|
|
|
|
|
|
return None
|
|
|
|
|
|
|
|
|
def get_node_path(
|
|
|
tree_list: List[Dict[str, Any]],
|
|
|
target_id: str,
|
|
|
*,
|
|
|
key_config: TreeKeyConfig = default_key_config,
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
获取从根节点到目标节点的路径
|
|
|
|
|
|
Args:
|
|
|
tree_list: 树形结构数据
|
|
|
target_id: 目标ID
|
|
|
id_key: ID字段名
|
|
|
children_key: 子节点字段名
|
|
|
|
|
|
Returns:
|
|
|
list: 路径节点列表(从根到目标)
|
|
|
"""
|
|
|
|
|
|
def find_path(node, path):
|
|
|
current_path = path + [node]
|
|
|
|
|
|
if node.get(key_config.id_key) == target_id:
|
|
|
return current_path
|
|
|
|
|
|
if key_config.children_key in node and node[key_config.children_key]:
|
|
|
for child in node[key_config.children_key]:
|
|
|
result = find_path(child, current_path)
|
|
|
if result:
|
|
|
return result
|
|
|
|
|
|
return None
|
|
|
|
|
|
for node in tree_list:
|
|
|
result = find_path(node, [])
|
|
|
if result:
|
|
|
return result
|
|
|
|
|
|
return []
|
|
|
|
|
|
|
|
|
def filter_tree_by_condition(
|
|
|
tree_list: List[Dict[str, Any]],
|
|
|
condition_func: callable,
|
|
|
*,
|
|
|
key_config: TreeKeyConfig = default_key_config,
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
根据条件过滤树形结构
|
|
|
|
|
|
Args:
|
|
|
tree_list: 树形结构数据
|
|
|
condition_func: 过滤条件函数,接收节点参数,返回bool
|
|
|
children_key: 子节点字段名
|
|
|
|
|
|
Returns:
|
|
|
list: 过滤后的树形结构
|
|
|
"""
|
|
|
|
|
|
def filter_node(node):
|
|
|
# 检查当前节点是否满足条件
|
|
|
if condition_func(node):
|
|
|
# 创建节点副本
|
|
|
filtered_node = node.copy()
|
|
|
|
|
|
# 递归过滤子节点
|
|
|
if key_config.children_key in node and node[key_config.children_key]:
|
|
|
filtered_children = []
|
|
|
for child in node[key_config.children_key]:
|
|
|
filtered_child = filter_node(child)
|
|
|
if filtered_child:
|
|
|
filtered_children.append(filtered_child)
|
|
|
|
|
|
if filtered_children:
|
|
|
filtered_node[key_config.children_key] = filtered_children
|
|
|
else:
|
|
|
filtered_node[key_config.children_key] = []
|
|
|
|
|
|
return filtered_node
|
|
|
|
|
|
# 如果当前节点不满足条件,检查子节点
|
|
|
if key_config.children_key in node and node[key_config.children_key]:
|
|
|
filtered_children = []
|
|
|
for child in node[key_config.children_key]:
|
|
|
filtered_child = filter_node(child)
|
|
|
if filtered_child:
|
|
|
filtered_children.append(filtered_child)
|
|
|
|
|
|
if filtered_children:
|
|
|
# 如果子节点有满足条件的,创建父节点
|
|
|
parent_node = node.copy()
|
|
|
parent_node[key_config.children_key] = filtered_children
|
|
|
return parent_node
|
|
|
|
|
|
return None
|
|
|
|
|
|
result = []
|
|
|
for node in tree_list:
|
|
|
filtered_node = filter_node(node)
|
|
|
if filtered_node:
|
|
|
result.append(filtered_node)
|
|
|
|
|
|
return result
|
|
|
|
|
|
|
|
|
def get_tree_depth(
|
|
|
tree_list: List[Dict[str, Any]],
|
|
|
*,
|
|
|
key_config: TreeKeyConfig = default_key_config,
|
|
|
) -> int:
|
|
|
"""
|
|
|
获取树的最大深度
|
|
|
|
|
|
Args:
|
|
|
tree_list: 树形结构数据
|
|
|
children_key: 子节点字段名
|
|
|
|
|
|
Returns:
|
|
|
int: 树的最大深度
|
|
|
"""
|
|
|
|
|
|
def get_node_depth(node, current_depth):
|
|
|
max_depth = current_depth
|
|
|
|
|
|
if key_config.children_key in node and node[key_config.children_key]:
|
|
|
for child in node[key_config.children_key]:
|
|
|
child_depth = get_node_depth(child, current_depth + 1)
|
|
|
max_depth = max(max_depth, child_depth)
|
|
|
|
|
|
return max_depth
|
|
|
|
|
|
max_depth = 0
|
|
|
for node in tree_list:
|
|
|
depth = get_node_depth(node, 1)
|
|
|
max_depth = max(max_depth, depth)
|
|
|
|
|
|
return max_depth
|