front_api update
This commit is contained in:
@@ -85,6 +85,8 @@ class PipelineCreateRequest(BaseModel):
|
|||||||
api_key: Optional[str] = Field(default=None)
|
api_key: Optional[str] = Field(default=None)
|
||||||
llm_name: str = Field(default="qwen-plus")
|
llm_name: str = Field(default="qwen-plus")
|
||||||
enabled: bool = Field(default=True)
|
enabled: bool = Field(default=True)
|
||||||
|
# Arbitrary per-graph options forwarded to the graph build function.
|
||||||
|
graph_params: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class PipelineSpec(BaseModel):
|
class PipelineSpec(BaseModel):
|
||||||
@@ -93,7 +95,6 @@ class PipelineSpec(BaseModel):
|
|||||||
enabled: bool
|
enabled: bool
|
||||||
config_file: str
|
config_file: str
|
||||||
llm_name: str
|
llm_name: str
|
||||||
overrides: Dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineCreateResponse(BaseModel):
|
class PipelineCreateResponse(BaseModel):
|
||||||
@@ -314,19 +315,13 @@ def _resolve_runtime_fast_api_key() -> RuntimeAuthInfoResponse:
|
|||||||
def _normalize_pipeline_spec(pipeline_id: str, spec: Dict[str, Any]) -> PipelineSpec:
|
def _normalize_pipeline_spec(pipeline_id: str, spec: Dict[str, Any]) -> PipelineSpec:
|
||||||
if not isinstance(spec, dict):
|
if not isinstance(spec, dict):
|
||||||
raise ValueError(f"pipeline spec for '{pipeline_id}' must be an object")
|
raise ValueError(f"pipeline spec for '{pipeline_id}' must be an object")
|
||||||
overrides = spec.get("overrides", {})
|
llm_name = str(spec.get("llm_name") or "unknown")
|
||||||
if overrides is None:
|
|
||||||
overrides = {}
|
|
||||||
if not isinstance(overrides, dict):
|
|
||||||
raise ValueError(f"`overrides` for pipeline '{pipeline_id}' must be an object")
|
|
||||||
llm_name = str(overrides.get("llm_name") or "unknown")
|
|
||||||
return PipelineSpec(
|
return PipelineSpec(
|
||||||
pipeline_id=pipeline_id,
|
pipeline_id=pipeline_id,
|
||||||
graph_id=str(spec.get("graph_id") or pipeline_id),
|
graph_id=str(spec.get("graph_id") or pipeline_id),
|
||||||
enabled=bool(spec.get("enabled", True)),
|
enabled=bool(spec.get("enabled", True)),
|
||||||
config_file=str(spec.get("config_file") or ""),
|
config_file=str(spec.get("config_file") or ""),
|
||||||
llm_name=llm_name,
|
llm_name=llm_name,
|
||||||
overrides=overrides,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -616,6 +611,7 @@ async def create_pipeline(body: PipelineCreateRequest):
|
|||||||
|
|
||||||
config_file = f"configs/pipelines/{pipeline_id}.yaml"
|
config_file = f"configs/pipelines/{pipeline_id}.yaml"
|
||||||
config_abs_dir = osp.join(_PROJECT_ROOT, "configs", "pipelines")
|
config_abs_dir = osp.join(_PROJECT_ROOT, "configs", "pipelines")
|
||||||
|
extra_params = dict(body.graph_params or {})
|
||||||
try:
|
try:
|
||||||
build_fn(
|
build_fn(
|
||||||
pipeline_id=pipeline_id,
|
pipeline_id=pipeline_id,
|
||||||
@@ -624,6 +620,7 @@ async def create_pipeline(body: PipelineCreateRequest):
|
|||||||
api_key=resolved_api_key,
|
api_key=resolved_api_key,
|
||||||
llm_name=body.llm_name,
|
llm_name=body.llm_name,
|
||||||
pipeline_config_dir=config_abs_dir,
|
pipeline_config_dir=config_abs_dir,
|
||||||
|
**extra_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
update_pipeline_registry(
|
update_pipeline_registry(
|
||||||
|
|||||||
Reference in New Issue
Block a user