load config
This commit is contained in:
@@ -13,7 +13,7 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
|||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
from lang_agent.config import LLMNodeConfig
|
from lang_agent.config import LLMNodeConfig, load_tyro_conf
|
||||||
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.components import conv_store
|
from lang_agent.components import conv_store
|
||||||
@@ -67,6 +67,12 @@ class PipelineConfig(LLMNodeConfig):
|
|||||||
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
|
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
|
||||||
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
|
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.config_f is not None:
|
||||||
|
logger.info(f"loading config from {self.config_f}")
|
||||||
|
self.config = load_tyro_conf(self.config_f)
|
||||||
|
|
||||||
|
super().__post_init__()
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user