use LLMKeyConfig

This commit is contained in:
2026-02-12 14:35:27 +08:00
parent 43dad177ab
commit c2cc2628dd
6 changed files with 14 additions and 41 deletions

View File

@@ -1 +1 @@
from lang_agent.config.core_config import InstantiateConfig, KeyConfig, ToolConfig, LLMKeyConfig from lang_agent.config.core_config import InstantiateConfig, ToolConfig, LLMKeyConfig

View File

@@ -3,7 +3,7 @@ from typing import Type, Callable, List
 import tyro
 import random
-from lang_agent.config import KeyConfig
+from lang_agent.config import LLMKeyConfig
 from lang_agent.pipeline import Pipeline, PipelineConfig
 from langchain.chat_models import init_chat_model
@@ -11,7 +11,7 @@ from langchain_core.messages import BaseMessage, ToolMessage
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
-class ValidatorConfig(KeyConfig):
+class ValidatorConfig(LLMKeyConfig):
     _target: Type = field(default_factory=lambda:Validator)
@@ -34,9 +34,9 @@ class Validator:
     def populate_modules(self):
         self.judge_llm = init_chat_model(
-            model="qwen-plus",
-            model_provider="openai",
-            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+            model=self.config.llm_name,
+            model_provider=self.config.llm_provider,
+            base_url=self.config.base_url,
             api_key=self.config.api_key
         )

View File

@@ -4,7 +4,7 @@ import tyro
 import os.path as osp
 from loguru import logger
-from lang_agent.config import KeyConfig
+from lang_agent.config import LLMKeyConfig
 from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
 from lang_agent.components.prompt_store import build_prompt_store
 from lang_agent.base import GraphBase
@@ -20,21 +20,12 @@ from langgraph.graph import StateGraph, START, END
 # NOTE: maybe make this into a base_graph_config?
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
-class ReactGraphConfig(KeyConfig):
+class ReactGraphConfig(LLMKeyConfig):
     _target: Type = field(default_factory=lambda: ReactGraph)
-    llm_name: str = "qwen-plus"
-    """name of llm"""
-    llm_provider:str = "openai"
-    """provider of the llm"""
     sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "prompts", "blueberry.txt")
     """path to system prompt"""
-    base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
-    """base url; could be used to overwrite the baseurl in llm provider"""
     pipeline_id: Optional[str] = None
     """If set, load prompts from database (with file fallback)"""

View File

@@ -6,7 +6,7 @@ import time
 import asyncio
 from loguru import logger
-from lang_agent.config import InstantiateConfig, KeyConfig
+from lang_agent.config import InstantiateConfig, LLMKeyConfig
 from lang_agent.components.tool_manager import ToolManager
 from lang_agent.components.prompt_store import build_prompt_store
 from lang_agent.components.reit_llm import ReitLLM
@@ -85,18 +85,9 @@ class ToolNode(ToolNodeBase):
 @dataclass
-class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
+class ChattyToolNodeConfig(LLMKeyConfig, ToolNodeConfig):
     _target: Type = field(default_factory=lambda: ChattyToolNode)
-    llm_name: str = "qwen-plus"
-    """name of llm"""
-    llm_provider:str = "openai"
-    """provider of the llm"""
-    base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
-    """base url; could be used to overwrite the baseurl in llm provider"""
     chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
     """path to chatty system prompt"""

View File

@@ -13,7 +13,7 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
 from langchain.agents import create_agent
 from langgraph.checkpoint.memory import MemorySaver
-from lang_agent.config import InstantiateConfig, KeyConfig
+from lang_agent.config import LLMKeyConfig
 from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
 from lang_agent.base import GraphBase
 from lang_agent.components import conv_store
@@ -52,21 +52,12 @@ DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
-class PipelineConfig(KeyConfig):
+class PipelineConfig(LLMKeyConfig):
     _target: Type = field(default_factory=lambda: Pipeline)
     config_f: str = None
     """path to config file"""
-    llm_name: str = "qwen-plus"
-    """name of llm; use default for qwen-plus"""
-    llm_provider:str = "openai"
-    """provider of the llm; use default for openai"""
-    base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
-    """base url; could be used to overwrite the baseurl in llm provider"""
     host:str = "0.0.0.0"
     """where am I hosted"""

View File

@@ -9,13 +9,13 @@ from langchain_community.vectorstores import FAISS
 from langchain_core.documents.base import Document
 from lang_agent.rag.emb import QwenEmbeddings
-from lang_agent.config import ToolConfig, KeyConfig
+from lang_agent.config import ToolConfig, LLMKeyConfig
 from lang_agent.base import LangToolBase
 @tyro.conf.configure(tyro.conf.SuppressFixed)
 @dataclass
-class SimpleRagConfig(ToolConfig, KeyConfig):
+class SimpleRagConfig(ToolConfig, LLMKeyConfig):
     _target: Type = field(default_factory=lambda: SimpleRag)
     model_name:str = "text-embedding-v4"