use LLMNodeConfig
This commit is contained in:
@@ -3,7 +3,7 @@ from typing import Type, Callable, List
|
|||||||
import tyro
|
import tyro
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
@@ -11,7 +11,7 @@ from langchain_core.messages import BaseMessage, ToolMessage
|
|||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class ValidatorConfig(LLMKeyConfig):
|
class ValidatorConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda:Validator)
|
_target: Type = field(default_factory=lambda:Validator)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
@@ -48,7 +48,7 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜
|
|||||||
用户在描述梦境的时候用紫色。"""
|
用户在描述梦境的时候用紫色。"""
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DualConfig(LLMKeyConfig):
|
class DualConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda:Dual)
|
_target: Type = field(default_factory=lambda:Dual)
|
||||||
|
|
||||||
pipeline_id: Optional[str] = None
|
pipeline_id: Optional[str] = None
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import time
|
|||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.graphs.graph_states import State
|
from lang_agent.graphs.graph_states import State
|
||||||
@@ -45,7 +45,7 @@ TOOL_SYS_PROMPT = """You are a helpful helper and will use the self_led_control
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class XiaoAiConfig(LLMKeyConfig):
|
class XiaoAiConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda:XiaoAi)
|
_target: Type = field(default_factory=lambda:XiaoAi)
|
||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import tyro
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
@@ -20,7 +20,7 @@ from langgraph.graph import StateGraph, START, END
|
|||||||
# NOTE: maybe make this into a base_graph_config?
|
# NOTE: maybe make this into a base_graph_config?
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReactGraphConfig(LLMKeyConfig):
|
class ReactGraphConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: ReactGraph)
|
_target: Type = field(default_factory=lambda: ReactGraph)
|
||||||
|
|
||||||
sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "prompts", "blueberry.txt")
|
sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "prompts", "blueberry.txt")
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import commentjson
|
|||||||
import glob
|
import glob
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
from lang_agent.base import GraphBase, ToolNodeBase
|
from lang_agent.base import GraphBase, ToolNodeBase
|
||||||
@@ -27,7 +27,7 @@ from langgraph.checkpoint.memory import MemorySaver
|
|||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class RoutingConfig(LLMKeyConfig):
|
class RoutingConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: RoutingGraph)
|
_target: Type = field(default_factory=lambda: RoutingGraph)
|
||||||
|
|
||||||
llm_name: str = "qwen-plus"
|
llm_name: str = "qwen-plus"
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig, LLMKeyConfig
|
from lang_agent.config import InstantiateConfig, LLMNodeConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager
|
from lang_agent.components.tool_manager import ToolManager
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
from lang_agent.components.reit_llm import ReitLLM
|
from lang_agent.components.reit_llm import ReitLLM
|
||||||
@@ -85,7 +85,7 @@ class ToolNode(ToolNodeBase):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChattyToolNodeConfig(LLMKeyConfig, ToolNodeConfig):
|
class ChattyToolNodeConfig(LLMNodeConfig, ToolNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: ChattyToolNode)
|
_target: Type = field(default_factory=lambda: ChattyToolNode)
|
||||||
|
|
||||||
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
from lang_agent.base import GraphBase, ToolNodeBase
|
from lang_agent.base import GraphBase, ToolNodeBase
|
||||||
@@ -90,7 +90,7 @@ class VisionRoutingState(TypedDict):
|
|||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class VisionRoutingConfig(LLMKeyConfig):
|
class VisionRoutingConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: VisionRoutingGraph)
|
_target: Type = field(default_factory=lambda: VisionRoutingGraph)
|
||||||
|
|
||||||
tool_llm_name: str = "qwen-flash"
|
tool_llm_name: str = "qwen-flash"
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
|||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.components import conv_store
|
from lang_agent.components import conv_store
|
||||||
@@ -52,7 +52,7 @@ DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共
|
|||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineConfig(LLMKeyConfig):
|
class PipelineConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: Pipeline)
|
_target: Type = field(default_factory=lambda: Pipeline)
|
||||||
|
|
||||||
config_f: str = None
|
config_f: str = None
|
||||||
|
|||||||
Reference in New Issue
Block a user