Compare commits

...

14 Commits

Author SHA1 Message Date
728d5934d7 create config file 2026-03-03 15:51:34 +08:00
4974ca936c dashscope pipeline manages multiple pipelines 2026-03-03 15:44:11 +08:00
bc208209c7 serverpipelinemanager impl 2026-03-03 15:40:50 +08:00
afb493adf4 update load 2026-03-03 15:29:40 +08:00
cc2e9cf90c some default deepagent prompt 2026-03-03 15:29:33 +08:00
686c1d6a1f load config 2026-03-03 14:54:22 +08:00
1fcd5b4c61 import load conf 2026-03-03 14:54:09 +08:00
62a00b4a5b print the config first 2026-03-03 14:40:03 +08:00
7294e07df7 save key 2026-03-03 14:39:54 +08:00
6425275d4b moved path locations 2026-03-03 14:16:08 +08:00
5742a08e98 remove unused 2026-03-03 14:14:37 +08:00
af16b87b0e rename 2026-03-03 14:11:07 +08:00
6b0e50c532 moved files 2026-03-03 14:07:01 +08:00
65a1705280 pipeline_manager v1 2026-03-02 18:14:24 +08:00
18 changed files with 378 additions and 180 deletions

View File

@@ -1 +1,3 @@
you are a helpful bot enhanced with skills you are a helpful bot enhanced with skills.
To use a skill, read its SKILL.md file using the read_file tool. Skills are NOT tools — they are instructions for using existing tools.
Skills with available="false" need dependencies installed first - you can try installing them with apt/brew. You can check if the environment the packages you need.

View File

@@ -1,8 +0,0 @@
{
"mcpServers": {
"remote-http-server": {
"type": "https",
"url": "https://xiaoliang.quant-speed.com/api/mcp/"
}
}
}

View File

@@ -3,28 +3,45 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import APIKeyHeader from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing import List, Optional from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path as FsPath
import os import os
import os.path as osp
import sys import sys
import time import time
import json import json
import copy
import uvicorn import uvicorn
from loguru import logger from loguru import logger
import tyro import tyro
# Ensure we can import from project root # Ensure we can import from project root
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from lang_agent.pipeline import Pipeline, PipelineConfig from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.core_config import load_tyro_conf
from lang_agent.components.server_pipeline_manager import ServerPipelineManager
# Initialize Pipeline once # Initialize default pipeline once (used when no explicit pipeline id is provided)
pipeline_config = tyro.cli(PipelineConfig) pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with pipeline: \n{pipeline_config}") logger.info(f"starting agent with default pipeline: \n{pipeline_config}")
pipeline:Pipeline = pipeline_config.setup() pipeline: Pipeline = pipeline_config.setup()
# API Key Authentication # API Key Authentication
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True) API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(","))) VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
REGISTRY_FILE = os.environ.get(
"FAST_PIPELINE_REGISTRY_FILE",
osp.join(osp.dirname(osp.dirname(osp.abspath(__file__))), "configs", "pipeline_registry.json"),
)
PIPELINE_MANAGER = ServerPipelineManager(
default_route_id=os.environ.get("FAST_DEFAULT_ROUTE_ID", os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default")),
default_config=pipeline_config,
default_pipeline=pipeline,
)
PIPELINE_MANAGER.load_registry(REGISTRY_FILE)
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)): async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
@@ -143,81 +160,107 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str
yield f"data: {json.dumps(final)}\n\n" yield f"data: {json.dumps(final)}\n\n"
def _normalize_messages(body: Dict[str, Any]) -> List[Dict[str, Any]]:
messages = body.get("messages")
body_input = body.get("input", {})
if messages is None and isinstance(body_input, dict):
messages = body_input.get("messages")
if messages is None and isinstance(body_input, dict):
prompt = body_input.get("prompt")
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
if not messages:
raise HTTPException(status_code=400, detail="messages is required")
return messages
def _extract_user_message(messages: List[Dict[str, Any]]) -> str:
user_msg = None
for m in reversed(messages):
role = m.get("role") if isinstance(m, dict) else None
content = m.get("content") if isinstance(m, dict) else None
if role == "user" and content:
user_msg = content
break
if user_msg is None:
last = messages[-1]
user_msg = last.get("content") if isinstance(last, dict) else str(last)
return user_msg
async def _process_dashscope_request(
body: Dict[str, Any],
app_id: Optional[str],
session_id: Optional[str],
api_key: str,
):
req_app_id = app_id or body.get("app_id")
body_input = body.get("input", {}) if isinstance(body.get("input"), dict) else {}
req_session_id = session_id or body_input.get("session_id")
messages = _normalize_messages(body)
stream = body.get("stream")
if stream is None:
stream = body.get("parameters", {}).get("stream", True)
thread_id = body_input.get("session_id") or req_session_id or "3"
user_msg = _extract_user_message(messages)
route_id = PIPELINE_MANAGER.resolve_route_id(body=body, app_id=req_app_id, api_key=api_key)
selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(route_id)
# Namespace thread ids to prevent memory collisions across pipelines.
thread_id = f"{route_id}:{thread_id}"
response_id = f"appcmpl-{os.urandom(12).hex()}"
if stream:
chunk_generator = await selected_pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
return StreamingResponse(
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=selected_model),
media_type="text/event-stream",
)
result_text = await selected_pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
if not isinstance(result_text, str):
result_text = str(result_text)
data = {
"request_id": response_id,
"code": 200,
"message": "OK",
"app_id": req_app_id,
"session_id": req_session_id,
"output": {
"text": result_text,
"created": int(time.time()),
"model": selected_model,
},
"route_id": route_id,
# Backward compatibility: keep pipeline_id in response as the route id selector.
"pipeline_id": route_id,
"is_end": True,
}
return JSONResponse(content=data)
@app.post("/v1/apps/{app_id}/sessions/{session_id}/responses") @app.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
@app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses") @app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
async def application_responses( async def application_responses(
request: Request, request: Request,
app_id: str = Path(...), app_id: str = Path(...),
session_id: str = Path(...), session_id: str = Path(...),
_: str = Depends(verify_api_key), api_key: str = Depends(verify_api_key),
): ):
try: try:
body = await request.json() body = await request.json()
return await _process_dashscope_request(
# Prefer path params body=body,
req_app_id = app_id or body.get("app_id") app_id=app_id,
req_session_id = session_id or body['input'].get("session_id") session_id=session_id,
api_key=api_key,
# Normalize messages )
messages = body.get("messages")
if messages is None and isinstance(body.get("input"), dict):
messages = body.get("input", {}).get("messages")
if messages is None and isinstance(body.get("input"), dict):
prompt = body.get("input", {}).get("prompt")
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
if not messages:
raise HTTPException(status_code=400, detail="messages is required")
# Determine stream flag
stream = body.get("stream")
if stream is None:
stream = body.get("parameters", {}).get("stream", True)
thread_id = body['input'].get("session_id")
# Extract latest user message
user_msg = None
for m in reversed(messages):
role = m.get("role") if isinstance(m, dict) else None
content = m.get("content") if isinstance(m, dict) else None
if role == "user" and content:
user_msg = content
break
if user_msg is None:
last = messages[-1]
user_msg = last.get("content") if isinstance(last, dict) else str(last)
response_id = f"appcmpl-{os.urandom(12).hex()}"
if stream:
# Use async streaming from pipeline
chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
return StreamingResponse(
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name),
media_type="text/event-stream",
)
# Non-streaming: get full result using async
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
if not isinstance(result_text, str):
result_text = str(result_text)
data = {
"request_id": response_id,
"code": 200,
"message": "OK",
"app_id": req_app_id,
"session_id": req_session_id,
"output": {
"text": result_text,
"created": int(time.time()),
"model": pipeline_config.llm_name,
},
"is_end": True,
}
return JSONResponse(content=data)
except HTTPException: except HTTPException:
raise raise
@@ -234,71 +277,16 @@ async def application_responses(
async def application_completion( async def application_completion(
request: Request, request: Request,
app_id: str = Path(...), app_id: str = Path(...),
_: str = Depends(verify_api_key), api_key: str = Depends(verify_api_key),
): ):
try: try:
body = await request.json() body = await request.json()
return await _process_dashscope_request(
req_session_id = body['input'].get("session_id") body=body,
app_id=app_id,
# Normalize messages session_id=None,
messages = body.get("messages") api_key=api_key,
if messages is None and isinstance(body.get("input"), dict): )
messages = body.get("input", {}).get("messages")
if messages is None and isinstance(body.get("input"), dict):
prompt = body.get("input", {}).get("prompt")
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
if not messages:
raise HTTPException(status_code=400, detail="messages is required")
stream = body.get("stream")
if stream is None:
stream = body.get("parameters", {}).get("stream", True)
thread_id = body['input'].get("session_id")
user_msg = None
for m in reversed(messages):
role = m.get("role") if isinstance(m, dict) else None
content = m.get("content") if isinstance(m, dict) else None
if role == "user" and content:
user_msg = content
break
if user_msg is None:
last = messages[-1]
user_msg = last.get("content") if isinstance(last, dict) else str(last)
response_id = f"appcmpl-{os.urandom(12).hex()}"
if stream:
# Use async streaming from pipeline
chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
return StreamingResponse(
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name),
media_type="text/event-stream",
)
# Non-streaming: get full result using async
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
if not isinstance(result_text, str):
result_text = str(result_text)
data = {
"request_id": response_id,
"code": 200,
"message": "OK",
"app_id": app_id,
"session_id": req_session_id,
"output": {
"text": result_text,
"created": int(time.time()),
"model": pipeline_config.llm_name,
},
"is_end": True,
}
return JSONResponse(content=data)
except HTTPException: except HTTPException:
raise raise

View File

@@ -0,0 +1,179 @@
from fastapi import HTTPException
from typing import Any, Dict, Optional, Tuple
from pathlib import Path as FsPath
import os.path as osp
import json
import copy
from loguru import logger
from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.core_config import load_tyro_conf
class ServerPipelineManager:
"""Lazily load and cache multiple pipelines keyed by a client-facing route id."""
def __init__(self, default_route_id: str, default_config: PipelineConfig, default_pipeline: Pipeline):
self.default_route_id = default_route_id
self.default_config = default_config
self._route_specs: Dict[str, Dict[str, Any]] = {}
self._api_key_policy: Dict[str, Dict[str, Any]] = {}
self._pipelines: Dict[str, Pipeline] = {default_route_id: default_pipeline}
self._pipeline_llm: Dict[str, str] = {default_route_id: default_config.llm_name}
self._route_specs[default_route_id] = {
"enabled": True,
"config_file": None,
"overrides": {},
"prompt_pipeline_id": None,
}
def _resolve_registry_path(self, registry_path: str) -> str:
path = FsPath(registry_path)
if path.is_absolute():
return str(path)
# server_pipeline_manager.py is under <repo>/lang_agent/components/,
# so parents[2] is the repository root.
root = FsPath(__file__).resolve().parents[2]
return str((root / path).resolve())
def load_registry(self, registry_path: str) -> None:
abs_path = self._resolve_registry_path(registry_path)
if not osp.exists(abs_path):
logger.warning(f"pipeline registry file not found: {abs_path}. Using default pipeline only.")
return
with open(abs_path, "r", encoding="utf-8") as f:
registry:dict = json.load(f)
routes = registry.get("routes")
if routes is None:
# Backward compatibility with initial schema.
routes = registry.get("pipelines", {})
if not isinstance(routes, dict):
raise ValueError("`routes` in pipeline registry must be an object.")
for route_id, spec in routes.items():
if not isinstance(spec, dict):
raise ValueError(f"route spec for `{route_id}` must be an object.")
self._route_specs[route_id] = {
"enabled": bool(spec.get("enabled", True)),
"config_file": spec.get("config_file"),
"overrides": spec.get("overrides", {}),
# Explicitly separates routing id from prompt config pipeline_id.
"prompt_pipeline_id": spec.get("prompt_pipeline_id"),
}
api_key_policy = registry.get("api_keys", {})
if api_key_policy and not isinstance(api_key_policy, dict):
raise ValueError("`api_keys` in pipeline registry must be an object.")
self._api_key_policy = api_key_policy
logger.info(f"loaded pipeline registry: {abs_path}, routes={list(self._route_specs.keys())}")
def _resolve_config_path(self, config_file: str) -> str:
path = FsPath(config_file)
if path.is_absolute():
return str(path)
# Resolve relative config paths from repository root for consistency
# with docker-compose and tests.
root = FsPath(__file__).resolve().parents[2]
return str((root / path).resolve())
def _build_pipeline(self, route_id: str) -> Tuple[Pipeline, str]:
spec = self._route_specs.get(route_id)
if spec is None:
raise HTTPException(status_code=404, detail=f"Unknown route_id: {route_id}")
if not spec.get("enabled", True):
raise HTTPException(status_code=403, detail=f"Route disabled: {route_id}")
config_file = spec.get("config_file")
overrides = spec.get("overrides", {})
if not config_file and not overrides:
# default pipeline
p = self._pipelines[self.default_route_id]
llm_name = self._pipeline_llm[self.default_route_id]
return p, llm_name
if config_file:
loaded_cfg = load_tyro_conf(self._resolve_config_path(config_file))
# Some legacy yaml configs deserialize to plain dicts instead of
# InstantiateConfig dataclasses. Fall back to default config in that case.
if hasattr(loaded_cfg, "setup"):
cfg = loaded_cfg
else:
logger.warning(
f"config_file for route `{route_id}` did not deserialize to config object; "
"falling back to default config and applying route-level overrides."
)
cfg = copy.deepcopy(self.default_config)
else:
# Build from default config + shallow overrides so new pipelines can be
# added via registry without additional yaml files.
cfg = copy.deepcopy(self.default_config)
if not isinstance(overrides, dict):
raise ValueError(f"route `overrides` for `{route_id}` must be an object.")
for key, value in overrides.items():
if not hasattr(cfg, key):
raise ValueError(f"unknown override field `{key}` for route `{route_id}`")
setattr(cfg, key, value)
prompt_pipeline_id = spec.get("prompt_pipeline_id")
if prompt_pipeline_id and (not isinstance(overrides, dict) or "pipeline_id" not in overrides):
if hasattr(cfg, "pipeline_id"):
cfg.pipeline_id = prompt_pipeline_id
p = cfg.setup()
llm_name = getattr(cfg, "llm_name", "unknown-model")
return p, llm_name
def _authorize(self, api_key: str, route_id: str) -> None:
if not self._api_key_policy:
return
policy = self._api_key_policy.get(api_key)
if policy is None:
return
allowed = policy.get("allowed_route_ids")
if allowed is None:
# Backward compatibility.
allowed = policy.get("allowed_pipeline_ids")
if allowed and route_id not in allowed:
raise HTTPException(status_code=403, detail=f"route_id `{route_id}` is not allowed for this API key")
def resolve_route_id(self, body: Dict[str, Any], app_id: Optional[str], api_key: str) -> str:
body_input = body.get("input", {})
route_id = (
body.get("route_id")
or (body_input.get("route_id") if isinstance(body_input, dict) else None)
or body.get("pipeline_key")
or (body_input.get("pipeline_key") if isinstance(body_input, dict) else None)
# Backward compatibility: pipeline_id still accepted as route selector.
or body.get("pipeline_id")
or (body_input.get("pipeline_id") if isinstance(body_input, dict) else None)
or app_id
)
if not route_id:
key_policy = self._api_key_policy.get(api_key, {}) if self._api_key_policy else {}
route_id = key_policy.get("default_route_id")
if not route_id:
# Backward compatibility.
route_id = key_policy.get("default_pipeline_id", self.default_route_id)
if route_id not in self._route_specs:
raise HTTPException(status_code=404, detail=f"Unknown route_id: {route_id}")
self._authorize(api_key, route_id)
return route_id
def get_pipeline(self, route_id: str) -> Tuple[Pipeline, str]:
cached = self._pipelines.get(route_id)
if cached is not None:
return cached, self._pipeline_llm[route_id]
pipeline_obj, llm_name = self._build_pipeline(route_id)
self._pipelines[route_id] = pipeline_obj
self._pipeline_llm[route_id] = llm_name
logger.info(f"lazy-loaded route_id={route_id} model={llm_name}")
return pipeline_obj, llm_name

View File

@@ -1,4 +1,5 @@
from lang_agent.config.core_config import (InstantiateConfig, from lang_agent.config.core_config import (InstantiateConfig,
ToolConfig, ToolConfig,
LLMKeyConfig, LLMKeyConfig,
LLMNodeConfig) LLMNodeConfig,
load_tyro_conf)

View File

@@ -72,54 +72,56 @@ class InstantiateConfig(PrintableConfig):
将配置保存到 YAML 文件 将配置保存到 YAML 文件
""" """
def mask_value(key, value): def mask_value(key, value):
""" """
Apply masking if key is secret-like Apply masking if key is secret-like
如果键是敏感的,应用掩码 如果键是敏感的,应用掩码
检查键是否敏感(如包含 "secret""api_key"),如果是,则对值进行掩码处理 检查键是否敏感(如包含 "secret""api_key"),如果是,则对值进行掩码处理
""" """
if isinstance(value, str) and self.is_secrete(key): if isinstance(value, str) and self.is_secrete(str(key)):
sval = str(value) sval = str(value)
return sval[:3] + "*" * (len(sval) - 6) + sval[-3:] return sval[:3] + "*" * (len(sval) - 6) + sval[-3:]
return value return value
def to_masked_serializable(obj): def to_serializable(obj, apply_mask: bool):
""" """
Recursively convert dataclasses and containers to serializable with masked secrets Recursively convert dataclasses and containers to serializable format,
optionally masking secrets.
递归地将数据类和容器转换为可序列化的格式,同时对敏感信息进行掩码处理
递归地将数据类和容器转换为可序列化的格式,可选地对敏感信息进行掩码处理
""" """
if is_dataclass(obj): if is_dataclass(obj):
out = {} out = {}
for k, v in vars(obj).items(): for k, v in vars(obj).items():
if is_dataclass(v) or isinstance(v, (dict, list, tuple)): if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
out[k] = to_masked_serializable(v) out[k] = to_serializable(v, apply_mask)
else: else:
out[k] = mask_value(k, v) out[k] = mask_value(k, v) if apply_mask else v
return out return out
if isinstance(obj, dict): if isinstance(obj, dict):
out = {} out = {}
for k, v in obj.items(): for k, v in obj.items():
if is_dataclass(v) or isinstance(v, (dict, list, tuple)): if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
out[k] = to_masked_serializable(v) out[k] = to_serializable(v, apply_mask)
else: else:
# k might be a non-string; convert to str for is_secrete check consistency
key_str = str(k) key_str = str(k)
out[k] = mask_value(key_str, v) out[k] = mask_value(key_str, v) if apply_mask else v
return out return out
if isinstance(obj, list): if isinstance(obj, list):
return [to_masked_serializable(v) for v in obj] return [to_serializable(v, apply_mask) for v in obj]
if isinstance(obj, tuple): if isinstance(obj, tuple):
return tuple(to_masked_serializable(v) for v in obj) return tuple(to_serializable(v, apply_mask) for v in obj)
return obj return obj
masked = to_masked_serializable(self) # NOTE: we intentionally do NOT mask secrets when saving to disk so that
# configs can be reloaded with real values. Masking is handled in __str__
# for safe logging/printing. If you need a redacted copy, call
# to_serializable(self, apply_mask=True) manually and dump it yourself.
serializable = to_serializable(self, apply_mask=False)
with open(filename, 'w') as f: with open(filename, 'w') as f:
yaml.dump(masked, f) yaml.dump(serializable, f)
logger.info(f"[yellow]config saved to: {filename}[/yellow]") logger.info(f"[yellow]config saved to: {filename}[/yellow]")
def get_name(self): def get_name(self):

View File

@@ -13,7 +13,7 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langchain.agents import create_agent from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from lang_agent.config import LLMNodeConfig from lang_agent.config import LLMNodeConfig, load_tyro_conf
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
from lang_agent.base import GraphBase from lang_agent.base import GraphBase
from lang_agent.components import conv_store from lang_agent.components import conv_store
@@ -67,6 +67,16 @@ class PipelineConfig(LLMNodeConfig):
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig) graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
def __post_init__(self):
if self.config_f is not None:
logger.info(f"loading config from {self.config_f}")
loaded_conf = load_tyro_conf(self.config_f)# NOTE: We are not merging with self , self)
if not hasattr(loaded_conf, "__dict__"):
raise TypeError(f"config_f {self.config_f} did not load into a config object")
# Apply loaded
self.__dict__.update(vars(loaded_conf))
super().__post_init__()

12
scripts/create_config.sh Normal file
View File

@@ -0,0 +1,12 @@
source ~/.bashrc
conda init
conda activate lang
echo create blueberry config
python scripts/py_scripts/misc_tasks.py --save-path config/pipelines/blueberry.yaml \
react \
--sys-prompt-f configs/prompts/blueberry.txt \
--tool-manager-config.client-tool-manager.tool-keys
# echo create xiaozhan config
python scripts/py_scripts/misc_tasks.py --save-path config/pipelines/xiaozhan.yaml

View File

@@ -1,17 +0,0 @@
from lang_agent.graphs import ReactGraphConfig, ReactGraph, RoutingConfig,RoutingGraph
from lang_agent.base import GraphBase
import os.path as osp
from tqdm import tqdm
def main():
save_dir = osp.join(osp.dirname(osp.dirname(__file__)), "frontend/assets/images/graph_arch")
confs:GraphBase = [ReactGraphConfig(), RoutingConfig()]
for conf in tqdm(confs):
graph:GraphBase = conf.setup()
img = graph.show_graph(ret_img=True)
img.save(osp.join(save_dir, f"arch_{conf.__class__.__name__}.png"))
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,29 @@
from lang_agent.graphs import ReactGraphConfig, ReactGraph, RoutingConfig,RoutingGraph
from lang_agent.pipeline import PipelineConfig
from lang_agent.base import GraphBase
import os.path as osp
import os
from tqdm import tqdm
import yaml
import tyro
from loguru import logger
def gen_arch_imgs(save_dir="frontend/assets/images/graph_arch"):
save_dir = osp.join(osp.dirname(osp.dirname(__file__)), save_dir)
confs:GraphBase = [ReactGraphConfig(), RoutingConfig()]
for conf in tqdm(confs):
graph:GraphBase = conf.setup()
img = graph.show_graph(ret_img=True)
img.save(osp.join(save_dir, f"arch_{conf.__class__.__name__}.png"))
def make_save_conf(pipeline:PipelineConfig, save_path:str):
os.makedirs(osp.dirname(save_path), exist_ok=True)
logger.info(pipeline)
pipeline.save_config(save_path)
if __name__ == "__main__":
# gen_arch_imgs()
tyro.cli(make_save_conf)