diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index a053cac..6f3f1cc 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -13,7 +13,7 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage from langchain.agents import create_agent from langgraph.checkpoint.memory import MemorySaver -from lang_agent.config import LLMNodeConfig +from lang_agent.config import LLMNodeConfig, load_tyro_conf from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig from lang_agent.base import GraphBase from lang_agent.components import conv_store @@ -67,6 +67,12 @@ class PipelineConfig(LLMNodeConfig): # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) graph_config: AnnotatedGraph = field(default_factory=RoutingConfig) + def __post_init__(self): + if self.config_f is not None: + logger.info(f"loading config from {self.config_f}") + self.config = load_tyro_conf(self.config_f) + + super().__post_init__()