This commit is contained in:
2026-01-20 17:53:32 +08:00
parent 0839c9bcee
commit afe94fd9f6

View File

@@ -32,7 +32,9 @@ def tree_leaves(tree):
while stack:
node = stack.pop()
if isinstance(node, dict):
stack.extend(reversed(node.values()))
# JAX sorts dict keys alphabetically
sorted_values = [node[k] for k in sorted(node.keys())]
stack.extend(reversed(sorted_values))
elif isinstance(node, (list, tuple)):
stack.extend(reversed(node))
else: