This commit is contained in:
2026-01-20 17:53:32 +08:00
parent 0839c9bcee
commit afe94fd9f6

View File

@@ -32,7 +32,9 @@ def tree_leaves(tree):
while stack: while stack:
node = stack.pop() node = stack.pop()
if isinstance(node, dict): if isinstance(node, dict):
stack.extend(reversed(node.values())) # JAX sorts dict keys alphabetically
sorted_values = [node[k] for k in sorted(node.keys())]
stack.extend(reversed(sorted_values))
elif isinstance(node, (list, tuple)): elif isinstance(node, (list, tuple)):
stack.extend(reversed(node)) stack.extend(reversed(node))
else: else: