diff --git a/lang_agent/utils.py b/lang_agent/utils.py index d5ce27e..3f4cce0 100644 --- a/lang_agent/utils.py +++ b/lang_agent/utils.py @@ -21,6 +21,26 @@ def make_llm(model="qwen-plus", return llm +def tree_leaves(tree): + """ + Extracts all leaf values from a nested structure (dict, list, tuple). + Drop-in replacement for jax.tree.leaves. + """ + leaves = [] + stack = [tree] + + while stack: + node = stack.pop() + if isinstance(node, dict): + stack.extend(reversed(node.values())) + elif isinstance(node, (list, tuple)): + stack.extend(reversed(node)) + else: + leaves.append(node) + + return leaves + + NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]') def words_only(text): """