diff --git a/lang_agent/utils.py b/lang_agent/utils.py index 3f4cce0..87f65f8 100644 --- a/lang_agent/utils.py +++ b/lang_agent/utils.py @@ -32,7 +32,9 @@ def tree_leaves(tree): while stack: node = stack.pop() if isinstance(node, dict): - stack.extend(reversed(node.values())) + # JAX sorts dict keys alphabetically + sorted_values = [node[k] for k in sorted(node.keys())] + stack.extend(reversed(sorted_values)) elif isinstance(node, (list, tuple)): stack.extend(reversed(node)) else: