This commit is contained in:
jijiahao
2025-11-06 18:46:36 +08:00

View File

@@ -28,7 +28,7 @@ from langgraph.checkpoint.memory import MemorySaver
class RoutingConfig(KeyConfig): class RoutingConfig(KeyConfig):
_target: Type = field(default_factory=lambda: RoutingGraph) _target: Type = field(default_factory=lambda: RoutingGraph)
llm_name: str = "qwen-flash" llm_name: str = "qwen-plus"
"""name of llm""" """name of llm"""
llm_provider:str = "openai" llm_provider:str = "openai"
@@ -97,12 +97,24 @@ class RoutingGraph(GraphBase):
state = self.workflow.invoke({"inp": nargs}) state = self.workflow.invoke({"inp": nargs})
msg_list = jax.tree.leaves(state) msg_list = jax.tree.leaves(state)
for e in msg_list:
if isinstance(e, BaseMessage):
e.pretty_print()
if as_raw: if as_raw:
return msg_list return msg_list
return msg_list[-1].content return msg_list[-1].content
def _validate_input(self, *nargs, **kwargs): def _validate_input(self, *nargs, **kwargs):
print("\033[93m====================INPUT MESSAGES=============================\033[0m")
for e in nargs[0]["messages"]:
if isinstance(e, BaseMessage):
e.pretty_print()
print("\033[93m====================END INPUT MESSAGES=============================\033[0m")
print(f"\033[93 model used: {self.config.llm_name}\033[0m")
assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message" assert len(nargs[0]["messages"]) >= 2, "need at least 1 system and 1 human message"
assert len(kwargs) == 0, "due to inp assumptions" assert len(kwargs) == 0, "due to inp assumptions"