You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
iTi-Flask/iti/common/tree.py

300 lines
7.9 KiB
Python

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

from typing import List, Dict, Any, Optional
from dataclasses import dataclass
@dataclass(frozen=True)
class TreeKeyConfig:
"""集中管理树结构相关的键名。
id_key: 节点ID字段名
parent_key: 父节点字段名
children_key: 子节点集合字段名
"""
id_key: str = "id"
parent_key: str = "parent_id"
children_key: str = "children"
# 兼容旧命名(保留但内部使用 TreeKeyConfig
default_key_config = TreeKeyConfig()
# 统一数据访问函数
def _get_value(item, key):
return item.get(key) if isinstance(item, dict) else getattr(item, key, None)
def _set_value(item, key, value):
if isinstance(item, dict):
item[key] = value
else:
setattr(item, key, value)
def build_tree_from_list(
data_list: list,
key_config: TreeKeyConfig = default_key_config,
):
"""
从对象列表构建树结构(优化版 - 两遍遍历算法)
修复了原算法的问题:
- 原算法依赖数据顺序,当子节点在父节点之前时无法正确建立关系
- 优化后使用两遍遍历,确保所有节点都能正确建立父子关系
Args:
data_list: ORM对象列表或字典列表
key_config: 键配置指定父级ID和子级字段名
Returns:
list: 树形结构数据
"""
if not data_list:
return []
# 第一遍构建所有节点的映射并初始化children字段
data_map = {}
for item in data_list:
item_id = _get_value(item, key_config.id_key)
_set_value(item, key_config.children_key, [])
data_map[item_id] = item
# 第二遍:建立父子关系
tree = []
for item in data_list:
parent_id = _get_value(item, key_config.parent_key)
if parent_id is None:
# 根节点
tree.append(item)
else:
# 子节点,查找父节点
parent = data_map.get(parent_id)
if parent:
children = _get_value(parent, key_config.children_key)
children.append(item)
# 注意:如果找不到父节点,说明数据有问题,这里选择忽略
# 可以根据需要添加日志记录或异常处理
return tree
def flatten_tree(
tree_list: List[Dict[str, Any]],
*,
key_config: TreeKeyConfig = default_key_config,
include_children: bool = False,
) -> List[Dict[str, Any]]:
"""
将树形结构扁平化为列表
Args:
tree_list: 树形结构数据
children_key: 子节点字段名
include_children: 是否包含children字段
Returns:
list: 扁平化后的列表
"""
result = []
def flatten_node(node):
# 创建节点副本
node_copy = node.copy()
# 如果不包含children字段则移除
if not include_children and key_config.children_key in node_copy:
del node_copy[key_config.children_key]
result.append(node_copy)
# 递归处理子节点
if key_config.children_key in node and node[key_config.children_key]:
for child in node[key_config.children_key]:
flatten_node(child)
for node in tree_list:
flatten_node(node)
return result
def find_node_by_id(
tree_list: List[Dict[str, Any]],
target_id: str,
*,
key_config: TreeKeyConfig = default_key_config,
) -> Optional[Dict[str, Any]]:
"""
在树形结构中根据ID查找节点
Args:
tree_list: 树形结构数据
target_id: 目标ID
id_key: ID字段名
children_key: 子节点字段名
Returns:
dict: 找到的节点未找到返回None
"""
def search_node(node):
if node.get(key_config.id_key) == target_id:
return node
if key_config.children_key in node and node[key_config.children_key]:
for child in node[key_config.children_key]:
result = search_node(child)
if result:
return result
return None
for node in tree_list:
result = search_node(node)
if result:
return result
return None
def get_node_path(
tree_list: List[Dict[str, Any]],
target_id: str,
*,
key_config: TreeKeyConfig = default_key_config,
) -> List[Dict[str, Any]]:
"""
获取从根节点到目标节点的路径
Args:
tree_list: 树形结构数据
target_id: 目标ID
id_key: ID字段名
children_key: 子节点字段名
Returns:
list: 路径节点列表(从根到目标)
"""
def find_path(node, path):
current_path = path + [node]
if node.get(key_config.id_key) == target_id:
return current_path
if key_config.children_key in node and node[key_config.children_key]:
for child in node[key_config.children_key]:
result = find_path(child, current_path)
if result:
return result
return None
for node in tree_list:
result = find_path(node, [])
if result:
return result
return []
def filter_tree_by_condition(
tree_list: List[Dict[str, Any]],
condition_func: callable,
*,
key_config: TreeKeyConfig = default_key_config,
) -> List[Dict[str, Any]]:
"""
根据条件过滤树形结构
Args:
tree_list: 树形结构数据
condition_func: 过滤条件函数接收节点参数返回bool
children_key: 子节点字段名
Returns:
list: 过滤后的树形结构
"""
def filter_node(node):
# 检查当前节点是否满足条件
if condition_func(node):
# 创建节点副本
filtered_node = node.copy()
# 递归过滤子节点
if key_config.children_key in node and node[key_config.children_key]:
filtered_children = []
for child in node[key_config.children_key]:
filtered_child = filter_node(child)
if filtered_child:
filtered_children.append(filtered_child)
if filtered_children:
filtered_node[key_config.children_key] = filtered_children
else:
filtered_node[key_config.children_key] = []
return filtered_node
# 如果当前节点不满足条件,检查子节点
if key_config.children_key in node and node[key_config.children_key]:
filtered_children = []
for child in node[key_config.children_key]:
filtered_child = filter_node(child)
if filtered_child:
filtered_children.append(filtered_child)
if filtered_children:
# 如果子节点有满足条件的,创建父节点
parent_node = node.copy()
parent_node[key_config.children_key] = filtered_children
return parent_node
return None
result = []
for node in tree_list:
filtered_node = filter_node(node)
if filtered_node:
result.append(filtered_node)
return result
def get_tree_depth(
tree_list: List[Dict[str, Any]],
*,
key_config: TreeKeyConfig = default_key_config,
) -> int:
"""
获取树的最大深度
Args:
tree_list: 树形结构数据
children_key: 子节点字段名
Returns:
int: 树的最大深度
"""
def get_node_depth(node, current_depth):
max_depth = current_depth
if key_config.children_key in node and node[key_config.children_key]:
for child in node[key_config.children_key]:
child_depth = get_node_depth(child, current_depth + 1)
max_depth = max(max_depth, child_depth)
return max_depth
max_depth = 0
for node in tree_list:
depth = get_node_depth(node, 1)
max_depth = max(max_depth, depth)
return max_depth