bug fix
This commit is contained in:
@@ -32,7 +32,9 @@ def tree_leaves(tree):
|
|||||||
while stack:
|
while stack:
|
||||||
node = stack.pop()
|
node = stack.pop()
|
||||||
if isinstance(node, dict):
|
if isinstance(node, dict):
|
||||||
stack.extend(reversed(node.values()))
|
# JAX sorts dict keys alphabetically
|
||||||
|
sorted_values = [node[k] for k in sorted(node.keys())]
|
||||||
|
stack.extend(reversed(sorted_values))
|
||||||
elif isinstance(node, (list, tuple)):
|
elif isinstance(node, (list, tuple)):
|
||||||
stack.extend(reversed(node))
|
stack.extend(reversed(node))
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user