replace jax.tree.leaves with drop in

This commit is contained in:
2026-01-20 17:33:02 +08:00
parent 4290ce6756
commit ac43eb6f27
4 changed files with 20 additions and 19 deletions

View File

@@ -3,7 +3,6 @@ from typing import Type, TypedDict, Literal, Dict, List, Tuple, Any, AsyncIterat
import tyro
from pydantic import BaseModel, Field
from loguru import logger
import jax
import os.path as osp
import commentjson
import glob