jax drop in replacement

This commit is contained in:
2026-01-20 17:33:13 +08:00
parent ac43eb6f27
commit b7fce5d973

View File

@@ -21,6 +21,26 @@ def make_llm(model="qwen-plus",
return llm
def tree_leaves(tree):
"""
Extracts all leaf values from a nested structure (dict, list, tuple).
Drop-in replacement for jax.tree.leaves.
"""
leaves = []
stack = [tree]
while stack:
node = stack.pop()
if isinstance(node, dict):
stack.extend(reversed(node.values()))
elif isinstance(node, (list, tuple)):
stack.extend(reversed(node))
else:
leaves.append(node)
return leaves
NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
def words_only(text):
"""