jax drop in replacement
This commit is contained in:
@@ -21,6 +21,26 @@ def make_llm(model="qwen-plus",
|
|||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
def tree_leaves(tree):
|
||||||
|
"""
|
||||||
|
Extracts all leaf values from a nested structure (dict, list, tuple).
|
||||||
|
Drop-in replacement for jax.tree.leaves.
|
||||||
|
"""
|
||||||
|
leaves = []
|
||||||
|
stack = [tree]
|
||||||
|
|
||||||
|
while stack:
|
||||||
|
node = stack.pop()
|
||||||
|
if isinstance(node, dict):
|
||||||
|
stack.extend(reversed(node.values()))
|
||||||
|
elif isinstance(node, (list, tuple)):
|
||||||
|
stack.extend(reversed(node))
|
||||||
|
else:
|
||||||
|
leaves.append(node)
|
||||||
|
|
||||||
|
return leaves
|
||||||
|
|
||||||
|
|
||||||
NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
|
NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
|
||||||
def words_only(text):
|
def words_only(text):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user