diff --git a/lang_agent/config/__init__.py b/lang_agent/config/__init__.py
index da957f1..eddbded 100644
--- a/lang_agent/config/__init__.py
+++ b/lang_agent/config/__init__.py
@@ -4,6 +4,7 @@ from lang_agent.config.core_config import (
     LLMKeyConfig,
     LLMNodeConfig,
     load_tyro_conf,
+    resolve_llm_api_key,
 )
 
 from lang_agent.config.constants import (
diff --git a/lang_agent/config/core_config.py b/lang_agent/config/core_config.py
index 9b1afed..ca4a4b8 100644
--- a/lang_agent/config/core_config.py
+++ b/lang_agent/config/core_config.py
@@ -10,6 +10,20 @@ from dotenv import load_dotenv
 
 load_dotenv()
 
+
+def resolve_llm_api_key(api_key: Optional[str]) -> Optional[str]:
+    """Resolve the API key for OpenAI-compatible providers."""
+    if api_key not in (None, "", "wrong-key"):
+        resolved_key = api_key
+    else:
+        resolved_key = os.environ.get("ALI_API_KEY") or os.environ.get("OPENAI_API_KEY")
+
+    # Some OpenAI-compatible integrations still read OPENAI_API_KEY from env.
+    if resolved_key and not os.environ.get("OPENAI_API_KEY"):
+        os.environ["OPENAI_API_KEY"] = resolved_key
+
+    return resolved_key
+
 ## NOTE: base classes taken from nerfstudio
 class PrintableConfig:
     """
@@ -99,12 +113,12 @@ class LLMKeyConfig(InstantiateConfig):
     """api key for llm"""
 
     def __post_init__(self):
-        if self.api_key == "wrong-key" or self.api_key is None:
-            self.api_key = os.environ.get("ALI_API_KEY")
-            if self.api_key is None:
-                logger.error(f"no ALI_API_KEY provided for embedding")
-            else:
-                logger.info("ALI_API_KEY loaded from environ")
+        original_api_key = self.api_key
+        self.api_key = resolve_llm_api_key(self.api_key)
+        if self.api_key is None:
+            logger.error("no ALI_API_KEY or OPENAI_API_KEY provided for embedding")
+        elif original_api_key in (None, "", "wrong-key"):
+            logger.info("LLM API key loaded from environment")
 
 
 @dataclass
diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py
index c79beec..a3b72e9 100644
--- a/lang_agent/pipeline.py
+++ b/lang_agent/pipeline.py
@@ -13,7 +13,7 @@
 from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
 from langchain.agents import create_agent
 from langgraph.checkpoint.memory import MemorySaver
-from lang_agent.config import LLMNodeConfig, load_tyro_conf
+from lang_agent.config import LLMNodeConfig, load_tyro_conf, resolve_llm_api_key
 from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
 from lang_agent.base import GraphBase
 from lang_agent.components import conv_store
@@ -104,7 +104,13 @@ class Pipeline:
             if self.config.base_url is not None
             else self.config.graph_config.base_url
         )
-        self.config.graph_config.api_key = self.config.api_key
+        pipeline_api_key = resolve_llm_api_key(self.config.api_key)
+        graph_api_key = resolve_llm_api_key(
+            getattr(self.config.graph_config, "api_key", None)
+        )
+        resolved_api_key = pipeline_api_key or graph_api_key
+        self.config.api_key = resolved_api_key
+        self.config.graph_config.api_key = resolved_api_key
         self.graph: GraphBase = self.config.graph_config.setup()
 
 