diff --git a/lang_agent/eval/validator.py b/lang_agent/eval/validator.py index d120f54..92855c7 100644 --- a/lang_agent/eval/validator.py +++ b/lang_agent/eval/validator.py @@ -3,7 +3,7 @@ from typing import Type, Callable, List import tyro import random -from lang_agent.config import LLMKeyConfig +from lang_agent.config import LLMNodeConfig from lang_agent.pipeline import Pipeline, PipelineConfig from langchain.chat_models import init_chat_model @@ -11,7 +11,7 @@ from langchain_core.messages import BaseMessage, ToolMessage @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass -class ValidatorConfig(LLMKeyConfig): +class ValidatorConfig(LLMNodeConfig): _target: Type = field(default_factory=lambda:Validator) diff --git a/lang_agent/graphs/dual_path.py b/lang_agent/graphs/dual_path.py index 2c415ad..673f4b5 100644 --- a/lang_agent/graphs/dual_path.py +++ b/lang_agent/graphs/dual_path.py @@ -6,7 +6,7 @@ from loguru import logger from langchain.chat_models import init_chat_model -from lang_agent.config import LLMKeyConfig +from lang_agent.config import LLMNodeConfig from lang_agent.base import GraphBase from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.prompt_store import build_prompt_store @@ -48,7 +48,7 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜 用户在描述梦境的时候用紫色。""" @dataclass -class DualConfig(LLMKeyConfig): +class DualConfig(LLMNodeConfig): _target: Type = field(default_factory=lambda:Dual) pipeline_id: Optional[str] = None diff --git a/lang_agent/graphs/legacy_xiaoai_demo.py b/lang_agent/graphs/legacy_xiaoai_demo.py index 59d06ab..7962641 100644 --- a/lang_agent/graphs/legacy_xiaoai_demo.py +++ b/lang_agent/graphs/legacy_xiaoai_demo.py @@ -8,7 +8,7 @@ import time from langchain.chat_models import init_chat_model -from lang_agent.config import LLMKeyConfig +from lang_agent.config import LLMNodeConfig from lang_agent.base import GraphBase from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.graphs.graph_states import State @@ -45,7 +45,7 @@ TOOL_SYS_PROMPT = """You are a helpful helper and will use the self_led_control @dataclass -class XiaoAiConfig(LLMKeyConfig): +class XiaoAiConfig(LLMNodeConfig): _target: Type = field(default_factory=lambda:XiaoAi) tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) diff --git a/lang_agent/graphs/react.py b/lang_agent/graphs/react.py index 0622eb1..9cc5ed8 100644 --- a/lang_agent/graphs/react.py +++ b/lang_agent/graphs/react.py @@ -4,7 +4,7 @@ import tyro import os.path as osp from loguru import logger -from lang_agent.config import LLMKeyConfig +from lang_agent.config import LLMNodeConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.prompt_store import build_prompt_store from lang_agent.base import GraphBase @@ -20,7 +20,7 @@ from langgraph.graph import StateGraph, START, END # NOTE: maybe make this into a base_graph_config? @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass -class ReactGraphConfig(LLMKeyConfig): +class ReactGraphConfig(LLMNodeConfig): _target: Type = field(default_factory=lambda: ReactGraph) sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "prompts", "blueberry.txt") diff --git a/lang_agent/graphs/routing.py b/lang_agent/graphs/routing.py index a264600..a2d8415 100644 --- a/lang_agent/graphs/routing.py +++ b/lang_agent/graphs/routing.py @@ -8,7 +8,7 @@ import commentjson import glob import time -from lang_agent.config import LLMKeyConfig +from lang_agent.config import LLMNodeConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.prompt_store import build_prompt_store from lang_agent.base import GraphBase, ToolNodeBase @@ -27,7 +27,7 @@ from langgraph.checkpoint.memory import MemorySaver @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass -class RoutingConfig(LLMKeyConfig): +class RoutingConfig(LLMNodeConfig): _target: Type = field(default_factory=lambda: RoutingGraph) llm_name: str = "qwen-plus" diff --git a/lang_agent/graphs/tool_nodes.py b/lang_agent/graphs/tool_nodes.py index 422563e..1905605 100644 --- a/lang_agent/graphs/tool_nodes.py +++ b/lang_agent/graphs/tool_nodes.py @@ -6,7 +6,7 @@ import time import asyncio from loguru import logger -from lang_agent.config import InstantiateConfig, LLMKeyConfig +from lang_agent.config import InstantiateConfig, LLMNodeConfig from lang_agent.components.tool_manager import ToolManager from lang_agent.components.prompt_store import build_prompt_store from lang_agent.components.reit_llm import ReitLLM @@ -85,7 +85,7 @@ class ToolNode(ToolNodeBase): @dataclass -class ChattyToolNodeConfig(LLMKeyConfig, ToolNodeConfig): +class ChattyToolNodeConfig(LLMNodeConfig, ToolNodeConfig): _target: Type = field(default_factory=lambda: ChattyToolNode) chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt") diff --git a/lang_agent/graphs/vision_routing.py b/lang_agent/graphs/vision_routing.py index 9c0ff17..6f02a1e 100644 --- a/lang_agent/graphs/vision_routing.py +++ b/lang_agent/graphs/vision_routing.py @@ -12,7 +12,7 @@ import base64 import json from loguru import logger -from lang_agent.config import LLMKeyConfig +from lang_agent.config import LLMNodeConfig from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig from lang_agent.components.prompt_store import build_prompt_store from lang_agent.base import GraphBase, ToolNodeBase @@ -90,7 +90,7 @@ class VisionRoutingState(TypedDict): @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass -class VisionRoutingConfig(LLMKeyConfig): +class VisionRoutingConfig(LLMNodeConfig): _target: Type = field(default_factory=lambda: VisionRoutingGraph) tool_llm_name: str = "qwen-flash" diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 11618e0..9783cf3 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -13,7 +13,7 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage from langchain.agents import create_agent from langgraph.checkpoint.memory import MemorySaver -from lang_agent.config import LLMKeyConfig +from lang_agent.config import LLMNodeConfig from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig from lang_agent.base import GraphBase from lang_agent.components import conv_store @@ -52,7 +52,7 @@ DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共 @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass -class PipelineConfig(LLMKeyConfig): +class PipelineConfig(LLMNodeConfig): _target: Type = field(default_factory=lambda: Pipeline) config_f: str = None