unified constants

This commit is contained in:
2026-03-04 17:27:26 +08:00
parent 61931cad58
commit 9b128ae41b
11 changed files with 488 additions and 411 deletions

View File

@@ -13,16 +13,17 @@ from pydantic import BaseModel, Field
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from lang_agent.config.db_config_manager import DBConfigManager from lang_agent.config.db_config_manager import DBConfigManager
from lang_agent.config.constants import (
_PROJECT_ROOT,
MCP_CONFIG_PATH,
MCP_CONFIG_DEFAULT_CONTENT,
PIPELINE_REGISTRY_PATH,
)
from lang_agent.front_api.build_server_utils import ( from lang_agent.front_api.build_server_utils import (
GRAPH_BUILD_FNCS, GRAPH_BUILD_FNCS,
update_pipeline_registry, update_pipeline_registry,
) )
_PROJECT_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__)))
_MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json")
_MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n"
_PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json")
class GraphConfigUpsertRequest(BaseModel): class GraphConfigUpsertRequest(BaseModel):
graph_id: str graph_id: str
@@ -206,20 +207,20 @@ def _parse_mcp_tool_keys(raw_content: str) -> List[str]:
def _read_mcp_config_raw() -> str: def _read_mcp_config_raw() -> str:
if not osp.exists(_MCP_CONFIG_PATH): if not osp.exists(MCP_CONFIG_PATH):
os.makedirs(osp.dirname(_MCP_CONFIG_PATH), exist_ok=True) os.makedirs(osp.dirname(MCP_CONFIG_PATH), exist_ok=True)
with open(_MCP_CONFIG_PATH, "w", encoding="utf-8") as f: with open(MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
f.write(_MCP_CONFIG_DEFAULT_CONTENT) f.write(MCP_CONFIG_DEFAULT_CONTENT)
with open(_MCP_CONFIG_PATH, "r", encoding="utf-8") as f: with open(MCP_CONFIG_PATH, "r", encoding="utf-8") as f:
return f.read() return f.read()
def _read_pipeline_registry() -> Dict[str, Any]: def _read_pipeline_registry() -> Dict[str, Any]:
if not osp.exists(_PIPELINE_REGISTRY_PATH): if not osp.exists(PIPELINE_REGISTRY_PATH):
os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True) os.makedirs(osp.dirname(PIPELINE_REGISTRY_PATH), exist_ok=True)
with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f: with open(PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
json.dump({"pipelines": {}, "api_keys": {}}, f, indent=2) json.dump({"pipelines": {}, "api_keys": {}}, f, indent=2)
with open(_PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f: with open(PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f:
registry = json.load(f) registry = json.load(f)
pipelines = registry.get("pipelines") pipelines = registry.get("pipelines")
if not isinstance(pipelines, dict): if not isinstance(pipelines, dict):
@@ -233,8 +234,8 @@ def _read_pipeline_registry() -> Dict[str, Any]:
def _write_pipeline_registry(registry: Dict[str, Any]) -> None: def _write_pipeline_registry(registry: Dict[str, Any]) -> None:
os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True) os.makedirs(osp.dirname(PIPELINE_REGISTRY_PATH), exist_ok=True)
with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f: with open(PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f:
json.dump(registry, f, indent=2) json.dump(registry, f, indent=2)
f.write("\n") f.write("\n")
@@ -433,7 +434,7 @@ async def get_mcp_tool_config():
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
return McpConfigReadResponse( return McpConfigReadResponse(
path=_MCP_CONFIG_PATH, path=MCP_CONFIG_PATH,
raw_content=raw_content, raw_content=raw_content,
tool_keys=tool_keys, tool_keys=tool_keys,
) )
@@ -443,8 +444,8 @@ async def get_mcp_tool_config():
async def update_mcp_tool_config(body: McpConfigUpdateRequest): async def update_mcp_tool_config(body: McpConfigUpdateRequest):
try: try:
tool_keys = _parse_mcp_tool_keys(body.raw_content) tool_keys = _parse_mcp_tool_keys(body.raw_content)
os.makedirs(osp.dirname(_MCP_CONFIG_PATH), exist_ok=True) os.makedirs(osp.dirname(MCP_CONFIG_PATH), exist_ok=True)
with open(_MCP_CONFIG_PATH, "w", encoding="utf-8") as f: with open(MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
# Keep user formatting/comments as entered while ensuring trailing newline. # Keep user formatting/comments as entered while ensuring trailing newline.
f.write(body.raw_content.rstrip() + "\n") f.write(body.raw_content.rstrip() + "\n")
except ValueError as e: except ValueError as e:
@@ -453,7 +454,7 @@ async def update_mcp_tool_config(body: McpConfigUpdateRequest):
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
return McpConfigUpdateResponse( return McpConfigUpdateResponse(
status="updated", status="updated",
path=_MCP_CONFIG_PATH, path=MCP_CONFIG_PATH,
tool_keys=tool_keys, tool_keys=tool_keys,
) )
@@ -528,7 +529,7 @@ async def create_pipeline(body: PipelineCreateRequest):
config_file=config_file, config_file=config_file,
llm_name=body.llm_name, llm_name=body.llm_name,
enabled=body.enabled, enabled=body.enabled,
registry_f=_PIPELINE_REGISTRY_PATH, registry_f=PIPELINE_REGISTRY_PATH,
) )
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to register pipeline: {e}") raise HTTPException(status_code=500, detail=f"Failed to register pipeline: {e}")
@@ -543,7 +544,8 @@ async def create_pipeline(body: PipelineCreateRequest):
normalized = _normalize_pipeline_spec(pipeline_id, pipeline_spec) normalized = _normalize_pipeline_spec(pipeline_id, pipeline_spec)
except Exception as e: except Exception as e:
raise HTTPException( raise HTTPException(
status_code=500, detail=f"Failed to read pipeline registry after update: {e}" status_code=500,
detail=f"Failed to read pipeline registry after update: {e}",
) )
return PipelineCreateResponse( return PipelineCreateResponse(
@@ -554,7 +556,7 @@ async def create_pipeline(body: PipelineCreateRequest):
llm_name=normalized.llm_name, llm_name=normalized.llm_name,
enabled=normalized.enabled, enabled=normalized.enabled,
reload_required=True, reload_required=True,
registry_path=_PIPELINE_REGISTRY_PATH, registry_path=PIPELINE_REGISTRY_PATH,
) )
@@ -608,7 +610,9 @@ async def list_pipeline_api_keys():
async def upsert_pipeline_api_key_policy(api_key: str, body: ApiKeyPolicyUpsertRequest): async def upsert_pipeline_api_key_policy(api_key: str, body: ApiKeyPolicyUpsertRequest):
normalized_key = api_key.strip() normalized_key = api_key.strip()
if not normalized_key: if not normalized_key:
raise HTTPException(status_code=400, detail="api_key path parameter is required") raise HTTPException(
status_code=400, detail="api_key path parameter is required"
)
try: try:
registry = _read_pipeline_registry() registry = _read_pipeline_registry()
pipelines = registry.get("pipelines", {}) pipelines = registry.get("pipelines", {})
@@ -662,7 +666,9 @@ async def upsert_pipeline_api_key_policy(api_key: str, body: ApiKeyPolicyUpsertR
async def delete_pipeline_api_key_policy(api_key: str): async def delete_pipeline_api_key_policy(api_key: str):
normalized_key = api_key.strip() normalized_key = api_key.strip()
if not normalized_key: if not normalized_key:
raise HTTPException(status_code=400, detail="api_key path parameter is required") raise HTTPException(
status_code=400, detail="api_key path parameter is required"
)
try: try:
registry = _read_pipeline_registry() registry = _read_pipeline_registry()
api_keys = registry.get("api_keys", {}) api_keys = registry.get("api_keys", {})

View File

@@ -18,29 +18,18 @@ sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
from lang_agent.pipeline import PipelineConfig from lang_agent.pipeline import PipelineConfig
from lang_agent.components.server_pipeline_manager import ServerPipelineManager from lang_agent.components.server_pipeline_manager import ServerPipelineManager
from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS
# Load base config for route-level overrides (pipelines are lazy-loaded from registry) # Load base config for route-level overrides (pipelines are lazy-loaded from registry)
pipeline_config = tyro.cli(PipelineConfig) pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}") logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
# API Key Authentication
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
REGISTRY_FILE = os.environ.get(
"FAST_PIPELINE_REGISTRY_FILE",
osp.join(
osp.dirname(osp.dirname(osp.abspath(__file__))),
"configs",
"pipeline_registry.json",
),
)
PIPELINE_MANAGER = ServerPipelineManager( PIPELINE_MANAGER = ServerPipelineManager(
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"), default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
default_config=pipeline_config, default_config=pipeline_config,
) )
PIPELINE_MANAGER.load_registry(REGISTRY_FILE) PIPELINE_MANAGER.load_registry(PIPELINE_REGISTRY_PATH)
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)): async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):

View File

@@ -16,15 +16,12 @@ import tyro
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from lang_agent.pipeline import Pipeline, PipelineConfig from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
# Initialize Pipeline once # Initialize Pipeline once
pipeline_config = tyro.cli(PipelineConfig) pipeline_config = tyro.cli(PipelineConfig)
pipeline: Pipeline = pipeline_config.setup() pipeline: Pipeline = pipeline_config.setup()
# API Key Authentication
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)): async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
"""Verify the API key from Authorization header (Bearer token format).""" """Verify the API key from Authorization header (Bearer token format)."""
@@ -46,12 +43,12 @@ class OpenAIChatCompletionRequest(BaseModel):
temperature: Optional[float] = Field(default=1.0) temperature: Optional[float] = Field(default=1.0)
max_tokens: Optional[int] = Field(default=None) max_tokens: Optional[int] = Field(default=None)
# Optional overrides for pipeline behavior # Optional overrides for pipeline behavior
thread_id: Optional[str] = Field(default='3') thread_id: Optional[str] = Field(default="3")
app = FastAPI( app = FastAPI(
title="OpenAI-Compatible Chat API", title="OpenAI-Compatible Chat API",
description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat" description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat",
) )
app.add_middleware( app.add_middleware(
@@ -63,7 +60,9 @@ app.add_middleware(
) )
def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, created_time: int): def sse_chunks_from_stream(
chunk_generator, response_id: str, model: str, created_time: int
):
""" """
Stream chunks from pipeline and format as OpenAI SSE. Stream chunks from pipeline and format as OpenAI SSE.
""" """
@@ -75,14 +74,8 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
"created": created_time, "created": created_time,
"model": model, "model": model,
"choices": [ "choices": [
{ {"index": 0, "delta": {"content": chunk}, "finish_reason": None}
"index": 0, ],
"delta": {
"content": chunk
},
"finish_reason": None
}
]
} }
yield f"data: {json.dumps(data)}\n\n" yield f"data: {json.dumps(data)}\n\n"
@@ -92,19 +85,15 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": created_time, "created": created_time,
"model": model, "model": model,
"choices": [ "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
{
"index": 0,
"delta": {},
"finish_reason": "stop"
}
]
} }
yield f"data: {json.dumps(final)}\n\n" yield f"data: {json.dumps(final)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, created_time: int): async def sse_chunks_from_astream(
chunk_generator, response_id: str, model: str, created_time: int
):
""" """
Async version: Stream chunks from pipeline and format as OpenAI SSE. Async version: Stream chunks from pipeline and format as OpenAI SSE.
""" """
@@ -116,14 +105,8 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
"created": created_time, "created": created_time,
"model": model, "model": model,
"choices": [ "choices": [
{ {"index": 0, "delta": {"content": chunk}, "finish_reason": None}
"index": 0, ],
"delta": {
"content": chunk
},
"finish_reason": None
}
]
} }
yield f"data: {json.dumps(data)}\n\n" yield f"data: {json.dumps(data)}\n\n"
@@ -133,13 +116,7 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
"object": "chat.completion.chunk", "object": "chat.completion.chunk",
"created": created_time, "created": created_time,
"model": model, "model": model,
"choices": [ "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
{
"index": 0,
"delta": {},
"finish_reason": "stop"
}
]
} }
yield f"data: {json.dumps(final)}\n\n" yield f"data: {json.dumps(final)}\n\n"
yield "data: [DONE]\n\n" yield "data: [DONE]\n\n"
@@ -149,15 +126,15 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
async def chat_completions(request: Request, _: str = Depends(verify_api_key)): async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
try: try:
body = await request.json() body = await request.json()
messages = body.get("messages") messages = body.get("messages")
if not messages: if not messages:
raise HTTPException(status_code=400, detail="messages is required") raise HTTPException(status_code=400, detail="messages is required")
stream = body.get("stream", False) stream = body.get("stream", False)
model = body.get("model", "gpt-3.5-turbo") model = body.get("model", "gpt-3.5-turbo")
thread_id = body.get("thread_id", 3) thread_id = body.get("thread_id", 3)
# Extract latest user message # Extract latest user message
user_msg = None user_msg = None
for m in reversed(messages): for m in reversed(messages):
@@ -166,27 +143,36 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
if role == "user" and content: if role == "user" and content:
user_msg = content user_msg = content
break break
if user_msg is None: if user_msg is None:
last = messages[-1] last = messages[-1]
user_msg = last.get("content") if isinstance(last, dict) else str(last) user_msg = last.get("content") if isinstance(last, dict) else str(last)
response_id = f"chatcmpl-{os.urandom(12).hex()}" response_id = f"chatcmpl-{os.urandom(12).hex()}"
created_time = int(time.time()) created_time = int(time.time())
if stream: if stream:
# Use async streaming from pipeline # Use async streaming from pipeline
chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id) chunk_generator = await pipeline.achat(
inp=user_msg, as_stream=True, thread_id=thread_id
)
return StreamingResponse( return StreamingResponse(
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=model, created_time=created_time), sse_chunks_from_astream(
chunk_generator,
response_id=response_id,
model=model,
created_time=created_time,
),
media_type="text/event-stream", media_type="text/event-stream",
) )
# Non-streaming: get full result using async # Non-streaming: get full result using async
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) result_text = await pipeline.achat(
inp=user_msg, as_stream=False, thread_id=thread_id
)
if not isinstance(result_text, str): if not isinstance(result_text, str):
result_text = str(result_text) result_text = str(result_text)
data = { data = {
"id": response_id, "id": response_id,
"object": "chat.completion", "object": "chat.completion",
@@ -195,21 +181,14 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
"choices": [ "choices": [
{ {
"index": 0, "index": 0,
"message": { "message": {"role": "assistant", "content": result_text},
"role": "assistant", "finish_reason": "stop",
"content": result_text
},
"finish_reason": "stop"
} }
], ],
"usage": { "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0
}
} }
return JSONResponse(content=data) return JSONResponse(content=data)
except HTTPException: except HTTPException:
raise raise
except Exception as e: except Exception as e:
@@ -221,11 +200,7 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
async def root(): async def root():
return { return {
"message": "OpenAI-compatible Chat API", "message": "OpenAI-compatible Chat API",
"endpoints": [ "endpoints": ["/v1/chat/completions", "/v1/memory (DELETE)", "/health"],
"/v1/chat/completions",
"/v1/memory (DELETE)",
"/health"
]
} }

View File

@@ -17,6 +17,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from lang_agent.pipeline import Pipeline, PipelineConfig from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
# Initialize Pipeline once (matches existing server_* pattern) # Initialize Pipeline once (matches existing server_* pattern)
pipeline_config = tyro.cli(PipelineConfig) pipeline_config = tyro.cli(PipelineConfig)
@@ -24,9 +25,6 @@ logger.info(f"starting agent with pipeline: \n{pipeline_config}")
pipeline: Pipeline = pipeline_config.setup() pipeline: Pipeline = pipeline_config.setup()
# API Key Authentication # API Key Authentication
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)): async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)):
"""Verify the API key from Authorization header (Bearer token format).""" """Verify the API key from Authorization header (Bearer token format)."""
@@ -125,7 +123,9 @@ app.add_middleware(
) )
async def rest_sse_from_astream(chunk_generator, response_id: str, conversation_id: str): async def rest_sse_from_astream(
chunk_generator, response_id: str, conversation_id: str
):
""" """
Stream chunks as SSE events. Stream chunks as SSE events.
@@ -185,7 +185,9 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
) )
return StreamingResponse( return StreamingResponse(
rest_sse_from_astream( rest_sse_from_astream(
chunk_generator, response_id=response_id, conversation_id=conversation_id chunk_generator,
response_id=response_id,
conversation_id=conversation_id,
), ),
media_type="text/event-stream", media_type="text/event-stream",
) )
@@ -195,7 +197,11 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
) )
if not isinstance(result_text, str): if not isinstance(result_text, str):
result_text = str(result_text) result_text = str(result_text)
return JSONResponse(content=ChatResponse(conversation_id=conversation_id, output=result_text).model_dump()) return JSONResponse(
content=ChatResponse(
conversation_id=conversation_id, output=result_text
).model_dump()
)
@app.post("/v1/conversations/{conversation_id}/messages") @app.post("/v1/conversations/{conversation_id}/messages")
@@ -215,7 +221,9 @@ async def create_message(
) )
return StreamingResponse( return StreamingResponse(
rest_sse_from_astream( rest_sse_from_astream(
chunk_generator, response_id=response_id, conversation_id=conversation_id chunk_generator,
response_id=response_id,
conversation_id=conversation_id,
), ),
media_type="text/event-stream", media_type="text/event-stream",
) )
@@ -257,7 +265,11 @@ async def delete_conversation_memory(
cleared = _try_clear_single_thread_memory(thread_id) cleared = _try_clear_single_thread_memory(thread_id)
if cleared: if cleared:
return JSONResponse( return JSONResponse(
content={"status": "success", "scope": "conversation", "conversation_id": conversation_id} content={
"status": "success",
"scope": "conversation",
"conversation_id": conversation_id,
}
) )
return JSONResponse( return JSONResponse(
content={ content={
@@ -276,5 +288,3 @@ if __name__ == "__main__":
port=8589, port=8589,
reload=True, reload=True,
) )

View File

@@ -1,5 +1,16 @@
from lang_agent.config.core_config import (InstantiateConfig, from lang_agent.config.core_config import (
ToolConfig, InstantiateConfig,
LLMKeyConfig, ToolConfig,
LLMNodeConfig, LLMKeyConfig,
load_tyro_conf) LLMNodeConfig,
load_tyro_conf,
)
from lang_agent.config.constants import (
MCP_CONFIG_PATH,
MCP_CONFIG_DEFAULT_CONTENT,
PIPELINE_REGISTRY_PATH,
VALID_API_KEYS,
API_KEY_HEADER,
API_KEY_HEADER_NO_ERROR
)

View File

@@ -0,0 +1,15 @@
import os
import re
import os.path as osp
from fastapi.security import APIKeyHeader
_PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json")
MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n"
PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json")
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
API_KEY_HEADER_NO_ERROR = APIKeyHeader(name="Authorization", auto_error=False)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))

View File

@@ -26,50 +26,57 @@ SYS_PROMPT = """你是一个专业的心理质询师。你的主要工作是心
可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受, 可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受,
总觉得那只小狗孤零零的,特别让人心疼。 总觉得那只小狗孤零零的,特别让人心疼。
理解(你的回复): 能感受到你醒来后的这份难受 —— 看到弱小的生命独自挣扎,而自己却无能为力,这种想帮却做不到的无力感, 理解(你的回复): 能感受到你醒来后的这份难受 —— 看到弱小的生命独自挣扎,而自己却无能为力,这种'想帮却做不到'的无力感,
其实是很真实的情绪反馈。你会心疼小狗,说明你内心藏着很珍贵的共情力,这份柔软不是矫情, 其实是很真实的情绪反馈。你会心疼小狗,说明你内心藏着很珍贵的共情力,这份柔软不是矫情,
而是你感知他人痛苦的能力呀 而是你感知他人痛苦的能力呀
解析(你的回复):我们再说回这个梦吧,我们的梦境其实没有唯一的‘正确解释’,但我们可以一起看看它可能和你当下的状态有什么关联~ 首先,‘出差去广州’通常象征着你近期正在推进的某件事 —— 可能是工作上的一个项目,也可能是生活中一段需要‘独自奔赴’的旅程,是你当下比较关注、需要投入精力的目标,对吗?” 解析(你的回复):我们再说回这个梦吧,我们的梦境其实没有唯一的'正确解释',但我们可以一起看看它可能和你当下的状态有什么关联~ 首先,'出差去广州'通常象征着你近期正在推进的某件事 —— 可能是工作上的一个项目,也可能是生活中一段需要'独自奔赴'的旅程,是你当下比较关注、需要投入精力的目标,对吗?”
而那只瘸脚的小狗,在心理学视角中,常常是我们潜意识里‘脆弱自我’的投射。它可能代表着你近期的某一面:比如在处理那件‘需要奔赴’的事时,你偶尔会觉得自己像小狗一样‘力不从心’,或者感受到了‘孤单’,却没找到合适的人倾诉或求助;也可能是你近期在生活中看到了一些让你觉得‘无力改变’的场景(比如身边人遇到困难、社会上的小事),这些情绪没有被你刻意留意,就通过梦境里的小狗呈现了出来。 "而那只瘸脚的小狗,在心理学视角中,常常是我们潜意识里'脆弱自我'的投射。它可能代表着你近期的某一面:比如在处理那件'需要奔赴'的事时,你偶尔会觉得自己像小狗一样'力不从心',或者感受到了'孤单',却没找到合适的人倾诉或求助;也可能是你近期在生活中看到了一些让你觉得'无力改变'的场景(比如身边人遇到困难、社会上的小事),这些情绪没有被你刻意留意,就通过梦境里的小狗呈现了出来。"
你想喊停列车却发不出声音,这种‘无能为力’的感觉,或许正是你现实中某类情绪的写照:可能你面对一些情况时,心里有想法却没机会表达,或者想帮忙却找不到合适的方式,这种压抑感在梦里被放大了。其实这个梦在提醒你:你的‘无力感’和‘共情心’都是真实的,不用因为‘帮不上忙’而自责 —— 承认自己的局限,也是一种自我接纳呀 "你想喊停列车却发不出声音,这种'无能为力'的感觉,或许正是你现实中某类情绪的写照:可能你面对一些情况时,心里有想法却没机会表达,或者想帮忙却找不到合适的方式,这种压抑感在梦里被放大了。其实这个梦在提醒你:你的'无力感''共情心'都是真实的,不用因为'帮不上忙'而自责 —— 承认自己的局限,也是一种自我接纳呀
反馈(你的回复):如果你愿意,可以试着回想一下:近期有没有哪件事,让你产生过和梦里类似的‘无力感’?或者,你现在想做些什么能让自己舒服一点?(或者我给你来一个温暖的灯光、静静待一会儿,想和我再聊聊的时候我随时都在) 反馈(你的回复):如果你愿意,可以试着回想一下:近期有没有哪件事,让你产生过和梦里类似的'无力感'?或者,你现在想做些什么能让自己舒服一点?(或者我给你来一个温暖的灯光、静静待一会儿,想和我再聊聊的时候我随时都在)"
""" """
TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜色用户不开心时就用暖黄光给用户分析梦境时就用白光倾听用户语音时用淡紫色。 TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜色用户不开心时就用暖黄光给用户分析梦境时就用白光倾听用户语音时用淡紫色。
例子:我梦见自己要去广州出差,坐在高铁上往外看,路过一个小镇的路边时,看到一只瘸了腿的小狗。它毛脏兮兮的, 例子:我梦见自己要去广州出差,坐在高铁上往外看,路过一个小镇的路边时,看到一只瘸了腿的小狗。它毛脏兮兮的,
一瘸一拐地在翻垃圾桶找东西吃,周围有行人路过,但没人停下来管它。我当时特别想喊列车停下,想下去帮它, 一瘸一拐地在翻垃圾桶找东西吃,周围有行人路过,但没人停下来管它。我当时特别想喊列车停下,想下去帮它,
可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受, 可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受,
总觉得那只小狗孤零零的,特别让人心疼。 总觉得那只小狗孤零零的,特别让人心疼。
用户在描述梦境的时候用紫色。""" 用户在描述梦境的时候用紫色。"""
@dataclass @dataclass
class DualConfig(LLMNodeConfig): class DualConfig(LLMNodeConfig):
_target: Type = field(default_factory=lambda:Dual) _target: Type = field(default_factory=lambda: Dual)
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
from langchain.tools import tool from langchain.tools import tool
@tool @tool
def turn_lights(col:Literal["red", "green", "yellow", "blue"]): def turn_lights(col: Literal["red", "green", "yellow", "blue"]):
""" """
Turn on the color of the lights Turn on the color of the lights
""" """
# print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") # print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
import time import time
for _ in range(10): for _ in range(10):
print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") print(
f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
)
time.sleep(0.3) time.sleep(0.3)
class Dual(GraphBase): class Dual(GraphBase):
def __init__(self, config:DualConfig): def __init__(self, config: DualConfig):
self.config = config self.config = config
self._build_modules() self._build_modules()
@@ -77,24 +84,30 @@ class Dual(GraphBase):
self.streamable_tags = [["dual_chat_llm"]] self.streamable_tags = [["dual_chat_llm"]]
def _build_modules(self): def _build_modules(self):
self.chat_llm = init_chat_model(model=self.config.llm_name, self.chat_llm = init_chat_model(
model_provider=self.config.llm_provider, model=self.config.llm_name,
api_key=self.config.api_key, model_provider=self.config.llm_provider,
base_url=self.config.base_url, api_key=self.config.api_key,
temperature=0, base_url=self.config.base_url,
tags=["dual_chat_llm"]) temperature=0,
tags=["dual_chat_llm"],
self.tool_llm = init_chat_model(model='qwen-flash', )
model_provider='openai',
api_key=self.config.api_key, self.tool_llm = init_chat_model(
base_url=self.config.base_url, model="qwen-flash",
temperature=0, model_provider="openai",
tags=["dual_tool_llm"]) api_key=self.config.api_key,
base_url=self.config.base_url,
temperature=0,
tags=["dual_tool_llm"],
)
self.memory = MemorySaver() self.memory = MemorySaver()
self.tool_manager: ToolManager = self.config.tool_manager_config.setup() self.tool_manager: ToolManager = self.config.tool_manager_config.setup()
self.chat_agent = create_agent(self.chat_llm, [], checkpointer=self.memory) self.chat_agent = create_agent(self.chat_llm, [], checkpointer=self.memory)
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools()) self.tool_agent = create_agent(
self.tool_llm, self.tool_manager.get_langchain_tools()
)
# self.tool_agent = create_agent(self.tool_llm, [turn_lights]) # self.tool_agent = create_agent(self.tool_llm, [turn_lights])
self.prompt_store = build_prompt_store( self.prompt_store = build_prompt_store(
@@ -107,18 +120,21 @@ class Dual(GraphBase):
) )
self.streamable_tags = [["dual_chat_llm"]] self.streamable_tags = [["dual_chat_llm"]]
def _chat_call(self, state:State): def _chat_call(self, state: State):
return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state) return self._agent_call_template(
self.prompt_store.get("sys_prompt"), self.chat_agent, state
def _tool_call(self, state:State): )
self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state)
def _tool_call(self, state: State):
self._agent_call_template(
self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state
)
return {} return {}
def _join(self, state:State): def _join(self, state: State):
return {} return {}
def _build_graph(self): def _build_graph(self):
builder = StateGraph(State) builder = StateGraph(State)
@@ -126,7 +142,6 @@ class Dual(GraphBase):
builder.add_node("tool_call", self._tool_call) builder.add_node("tool_call", self._tool_call)
builder.add_node("join", self._join) builder.add_node("join", self._join)
builder.add_edge(START, "chat_call") builder.add_edge(START, "chat_call")
builder.add_edge(START, "tool_call") builder.add_edge(START, "tool_call")
builder.add_edge("chat_call", "join") builder.add_edge("chat_call", "join")
@@ -137,10 +152,16 @@ class Dual(GraphBase):
if __name__ == "__main__": if __name__ == "__main__":
dual:Dual = DualConfig().setup() dual: Dual = DualConfig().setup()
nargs = {"messages": [SystemMessage("you are a helpful bot named jarvis"), nargs = (
HumanMessage("I feel very very sad")] {
}, {"configurable": {"thread_id": "3"}} "messages": [
SystemMessage("you are a helpful bot named jarvis"),
HumanMessage("I feel very very sad"),
]
},
{"configurable": {"thread_id": "3"}},
)
# out = dual.invoke(*nargs) # out = dual.invoke(*nargs)
# print(out) # print(out)

View File

@@ -48,6 +48,7 @@ You should NOT use the tool when:
If you decide to take a photo, call the self_camera_take_photo tool. Otherwise, respond that no photo is needed.""" If you decide to take a photo, call the self_camera_take_photo tool. Otherwise, respond that no photo is needed."""
VISION_DESCRIPTION_PROMPT = """You are a highly accurate visual analysis assistant powered by qwen-vl-max. VISION_DESCRIPTION_PROMPT = """You are a highly accurate visual analysis assistant powered by qwen-vl-max.
Your task is to provide detailed, accurate descriptions of images. Focus on: Your task is to provide detailed, accurate descriptions of images. Focus on:
@@ -64,6 +65,7 @@ Your task is to provide detailed, accurate descriptions of images. Focus on:
Be precise and factual. If something is unclear or ambiguous, say so rather than guessing.""" Be precise and factual. If something is unclear or ambiguous, say so rather than guessing."""
CONVERSATION_PROMPT = """You are a friendly, helpful conversational assistant. CONVERSATION_PROMPT = """You are a friendly, helpful conversational assistant.
Your role is to: Your role is to:
@@ -78,9 +80,11 @@ Focus on the quality of the conversation. Be engaging, informative, and helpful.
# ==================== STATE DEFINITION ==================== # ==================== STATE DEFINITION ====================
class VisionRoutingState(TypedDict): class VisionRoutingState(TypedDict):
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]], inp: Tuple[
Dict[str, Dict[str, str | int]]] Dict[str, List[SystemMessage | HumanMessage]], Dict[str, Dict[str, str | int]]
]
messages: List[SystemMessage | HumanMessage | AIMessage] messages: List[SystemMessage | HumanMessage | AIMessage]
image_base64: str | None # Captured image data image_base64: str | None # Captured image data
has_image: bool # Flag indicating if image was captured has_image: bool # Flag indicating if image was captured
@@ -88,6 +92,7 @@ class VisionRoutingState(TypedDict):
# ==================== CONFIG ==================== # ==================== CONFIG ====================
@tyro.conf.configure(tyro.conf.SuppressFixed) @tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass @dataclass
class VisionRoutingConfig(LLMNodeConfig): class VisionRoutingConfig(LLMNodeConfig):
@@ -99,11 +104,14 @@ class VisionRoutingConfig(LLMNodeConfig):
vision_llm_name: str = "qwen-vl-max" vision_llm_name: str = "qwen-vl-max"
"""LLM for vision/image analysis""" """LLM for vision/image analysis"""
tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig) tool_manager_config: ToolManagerConfig = field(
default_factory=ClientToolManagerConfig
)
# ==================== GRAPH IMPLEMENTATION ==================== # ==================== GRAPH IMPLEMENTATION ====================
class VisionRoutingGraph(GraphBase): class VisionRoutingGraph(GraphBase):
def __init__(self, config: VisionRoutingConfig): def __init__(self, config: VisionRoutingConfig):
self.config = config self.config = config
@@ -120,19 +128,19 @@ class VisionRoutingGraph(GraphBase):
api_key=self.config.api_key, api_key=self.config.api_key,
base_url=self.config.base_url, base_url=self.config.base_url,
temperature=0, temperature=0,
tags=["tool_decision_llm"] tags=["tool_decision_llm"],
) )
# qwen-plus for conversation (2nd pass) # qwen-plus for conversation (2nd pass)
self.conversation_llm = init_chat_model( self.conversation_llm = init_chat_model(
model='qwen-plus', model="qwen-plus",
model_provider=self.config.llm_provider, model_provider=self.config.llm_provider,
api_key=self.config.api_key, api_key=self.config.api_key,
base_url=self.config.base_url, base_url=self.config.base_url,
temperature=0.7, temperature=0.7,
tags=["conversation_llm"] tags=["conversation_llm"],
) )
# qwen-vl-max for vision (no tools) # qwen-vl-max for vision (no tools)
self.vision_llm = init_chat_model( self.vision_llm = init_chat_model(
model=self.config.vision_llm_name, model=self.config.vision_llm_name,
@@ -152,13 +160,15 @@ class VisionRoutingGraph(GraphBase):
# Get tools and bind to tool_llm # Get tools and bind to tool_llm
tool_manager: ToolManager = self.config.tool_manager_config.setup() tool_manager: ToolManager = self.config.tool_manager_config.setup()
self.tools = tool_manager.get_tools() self.tools = tool_manager.get_tools()
# Filter to only get camera tool # Filter to only get camera tool
self.camera_tools = [t for t in self.tools if t.name == "self_camera_take_photo"] self.camera_tools = [
t for t in self.tools if t.name == "self_camera_take_photo"
]
# Bind tools to qwen-plus only # Bind tools to qwen-plus only
self.tool_llm_with_tools = self.tool_llm.bind_tools(self.camera_tools) self.tool_llm_with_tools = self.tool_llm.bind_tools(self.camera_tools)
# Create tool node for executing tools # Create tool node for executing tools
self.tool_node = ToolNode(self.camera_tools) self.tool_node = ToolNode(self.camera_tools)
@@ -184,73 +194,81 @@ class VisionRoutingGraph(GraphBase):
def _camera_decision_call(self, state: VisionRoutingState): def _camera_decision_call(self, state: VisionRoutingState):
"""First pass: qwen-plus decides if photo should be taken""" """First pass: qwen-plus decides if photo should be taken"""
human_msg = self._get_human_msg(state) human_msg = self._get_human_msg(state)
messages = [ messages = [
SystemMessage(content=self.prompt_store.get("camera_decision_prompt")), SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
human_msg human_msg,
] ]
response = self.tool_llm_with_tools.invoke(messages) response = self.tool_llm_with_tools.invoke(messages)
return { return {"messages": [response], "has_image": False, "image_base64": None}
"messages": [response],
"has_image": False,
"image_base64": None
}
def _execute_tool(self, state: VisionRoutingState): def _execute_tool(self, state: VisionRoutingState):
"""Execute the camera tool if called""" """Execute the camera tool if called"""
last_msg = state["messages"][-1] last_msg = state["messages"][-1]
if not hasattr(last_msg, "tool_calls") or not last_msg.tool_calls: if not hasattr(last_msg, "tool_calls") or not last_msg.tool_calls:
return {"has_image": False} return {"has_image": False}
# Execute tool calls # Execute tool calls
tool_messages = [] tool_messages = []
image_data = None image_data = None
for tool_call in last_msg.tool_calls: for tool_call in last_msg.tool_calls:
if tool_call["name"] == "self_camera_take_photo": if tool_call["name"] == "self_camera_take_photo":
# Find and execute the camera tool # Find and execute the camera tool
camera_tool = next((t for t in self.camera_tools if t.name == "self_camera_take_photo"), None) camera_tool = next(
(
t
for t in self.camera_tools
if t.name == "self_camera_take_photo"
),
None,
)
if camera_tool: if camera_tool:
result = camera_tool.invoke(tool_call) result = camera_tool.invoke(tool_call)
# Parse result to extract image # Parse result to extract image
if isinstance(result, ToolMessage): if isinstance(result, ToolMessage):
content = result.content content = result.content
else: else:
content = result content = result
try: try:
result_data = json.loads(content) if isinstance(content, str) else content result_data = (
if isinstance(result_data, dict) and "image_base64" in result_data: json.loads(content) if isinstance(content, str) else content
)
if (
isinstance(result_data, dict)
and "image_base64" in result_data
):
image_data = result_data["image_base64"] image_data = result_data["image_base64"]
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
pass pass
tool_messages.append( tool_messages.append(
ToolMessage(content=content, tool_call_id=tool_call["id"]) ToolMessage(content=content, tool_call_id=tool_call["id"])
) )
return { return {
"messages": state["messages"] + tool_messages, "messages": state["messages"] + tool_messages,
"has_image": image_data is not None, "has_image": image_data is not None,
"image_base64": image_data "image_base64": image_data,
} }
def _check_image_taken(self, state: VisionRoutingState) -> str: def _check_image_taken(self, state: VisionRoutingState) -> str:
"""Conditional: check if image was captured""" """Conditional: check if image was captured"""
last_msg = state["messages"][-1] last_msg = state["messages"][-1]
# Check if there are tool calls # Check if there are tool calls
if hasattr(last_msg, "tool_calls") and last_msg.tool_calls: if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
return "execute_tool" return "execute_tool"
# Check if we have an image after tool execution # Check if we have an image after tool execution
if state.get("has_image"): if state.get("has_image"):
return "vision" return "vision"
return "conversation" return "conversation"
def _post_tool_check(self, state: VisionRoutingState) -> str: def _post_tool_check(self, state: VisionRoutingState) -> str:
@@ -263,47 +281,45 @@ class VisionRoutingGraph(GraphBase):
"""Pass image to qwen-vl-max for description""" """Pass image to qwen-vl-max for description"""
human_msg = self._get_human_msg(state) human_msg = self._get_human_msg(state)
image_base64 = state.get("image_base64") image_base64 = state.get("image_base64")
if not image_base64: if not image_base64:
logger.warning("No image data available for vision call") logger.warning("No image data available for vision call")
return self._conversation_call(state) return self._conversation_call(state)
# Format message with image for vision model # Format message with image for vision model
vision_message = HumanMessage( vision_message = HumanMessage(
content=[ content=[
{ {
"type": "image_url", "type": "image_url",
"image_url": { "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
"url": f"data:image/jpeg;base64,{image_base64}"
}
}, },
{ {
"type": "text", "type": "text",
"text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request." "text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request.",
} },
] ]
) )
messages = [ messages = [
SystemMessage(content=self.prompt_store.get("vision_description_prompt")), SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
vision_message vision_message,
] ]
response = self.vision_llm.invoke(messages) response = self.vision_llm.invoke(messages)
return {"messages": state["messages"] + [response]} return {"messages": state["messages"] + [response]}
def _conversation_call(self, state: VisionRoutingState): def _conversation_call(self, state: VisionRoutingState):
"""2nd pass to qwen-plus for conversation quality""" """2nd pass to qwen-plus for conversation quality"""
human_msg = self._get_human_msg(state) human_msg = self._get_human_msg(state)
messages = [ messages = [
SystemMessage(content=self.prompt_store.get("conversation_prompt")), SystemMessage(content=self.prompt_store.get("conversation_prompt")),
human_msg human_msg,
] ]
response = self.conversation_llm.invoke(messages) response = self.conversation_llm.invoke(messages)
return {"messages": state["messages"] + [response]} return {"messages": state["messages"] + [response]}
def _build_graph(self): def _build_graph(self):
@@ -317,7 +333,7 @@ class VisionRoutingGraph(GraphBase):
# Add edges # Add edges
builder.add_edge(START, "camera_decision") builder.add_edge(START, "camera_decision")
# After camera decision, check if tool should be executed # After camera decision, check if tool should be executed
builder.add_conditional_edges( builder.add_conditional_edges(
"camera_decision", "camera_decision",
@@ -325,20 +341,17 @@ class VisionRoutingGraph(GraphBase):
{ {
"execute_tool": "execute_tool", "execute_tool": "execute_tool",
"vision": "vision_call", "vision": "vision_call",
"conversation": "conversation_call" "conversation": "conversation_call",
} },
) )
# After tool execution, route based on whether image was captured # After tool execution, route based on whether image was captured
builder.add_conditional_edges( builder.add_conditional_edges(
"execute_tool", "execute_tool",
self._post_tool_check, self._post_tool_check,
{ {"vision": "vision_call", "conversation": "conversation_call"},
"vision": "vision_call",
"conversation": "conversation_call"
}
) )
# Both vision and conversation go to END # Both vision and conversation go to END
builder.add_edge("vision_call", END) builder.add_edge("vision_call", END)
builder.add_edge("conversation_call", END) builder.add_edge("conversation_call", END)
@@ -350,23 +363,27 @@ class VisionRoutingGraph(GraphBase):
if __name__ == "__main__": if __name__ == "__main__":
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
config = VisionRoutingConfig() config = VisionRoutingConfig()
graph = VisionRoutingGraph(config) graph = VisionRoutingGraph(config)
# Test with a conversation request # Test with a conversation request
print("\n=== Test 1: Conversation (no photo needed) ===") print("\n=== Test 1: Conversation (no photo needed) ===")
nargs = { nargs = (
"messages": [ {
SystemMessage("You are a helpful assistant"), "messages": [
HumanMessage("Hello, how are you today?") SystemMessage("You are a helpful assistant"),
] HumanMessage("Hello, how are you today?"),
}, {"configurable": {"thread_id": "1"}} ]
},
{"configurable": {"thread_id": "1"}},
)
result = graph.invoke(*nargs) result = graph.invoke(*nargs)
print(f"Result: {result}") print(f"Result: {result}")
# Test with a photo request # Test with a photo request
# print("\n=== Test 2: Photo request ===") # print("\n=== Test 2: Photo request ===")
# nargs = { # nargs = {
@@ -375,8 +392,8 @@ if __name__ == "__main__":
# HumanMessage("Take a photo and tell me what you see") # HumanMessage("Take a photo and tell me what you see")
# ] # ]
# }, {"configurable": {"thread_id": "2"}} # }, {"configurable": {"thread_id": "2"}}
# result = graph.invoke(*nargs) # result = graph.invoke(*nargs)
# print(f"\033[32mResult: {result}\033[0m") # print(f"\033[32mResult: {result}\033[0m")
# print(f"Result: {result}") # print(f"Result: {result}")

View File

@@ -12,26 +12,27 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
from langchain.agents import create_agent from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver from langgraph.checkpoint.memory import MemorySaver
from lang_agent.config import LLMNodeConfig, load_tyro_conf from lang_agent.config import LLMNodeConfig, load_tyro_conf
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
from lang_agent.base import GraphBase from lang_agent.base import GraphBase
from lang_agent.components import conv_store from lang_agent.components import conv_store
DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员擅长倾听、共情且主动回应。聊天时语气自然亲切像朋友般轻松交流不使用生硬术语。能接住各种话题对疑问耐心解答对情绪及时回应避免冷场。保持积极正向不传播负面信息语言简洁易懂让对话流畅舒适。与用户User交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
DEFAULT_PROMPT = """你是半盏新青年茶馆的服务员擅长倾听、共情且主动回应。聊天时语气自然亲切像朋友般轻松交流不使用生硬术语。能接住各种话题对疑问耐心解答对情绪及时回应避免冷场。保持积极正向不传播负面信息语言简洁易懂让对话流畅舒适。与用户User交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
[角色设定] [角色设定]
- 你是一个和用户User对话的 AI叫做小盏是半盏青年茶馆的智能助手 - 你是一个和用户User对话的 AI叫做小盏是半盏青年茶馆的智能助手
[形象背景] [形象背景]
- 你叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶。它有个标志性的蓝色鼻子, 小盏很像一只可爱的小熊。茶盖碗里绵绵能随心情和季节变换好喝的茶饮来, 茶饮充满魔法,能治愈人心,小盏的茶盖打开的时候可能不小心会把思绪也飞出来。 - 你叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶。它有个标志性的蓝色鼻子, 小盏很像一只可爱的小熊。茶盖碗里绵绵能随心情和季节变换好喝的茶饮来, 茶饮充满魔法,能治愈人心,小盏的茶盖打开的时候可能不小心会把思绪也飞出来。
[品牌背景] [品牌背景]
- 半盏新青年茶馆成立时间与理念2023 年创立于云南,结合茶饮与创意生活方式,致力于解构传统茶文化,重构 无边界的饮茶生活,以新青年视角探索云南风物。探索云南风物的过程,我们将以新青年的视角,解构传统茶饮的魅力,重构充满创意与温度的新式茶文化。通过嗅觉、味觉、听觉乃至视觉的世界里,讲述云南的故事 - 半盏新青年茶馆成立时间与理念2023 年创立于云南,结合茶饮与创意生活方式,致力于解构传统茶文化,重构 "无边界的饮茶生活",以新青年视角探索云南风物。探索云南风物的过程,我们将以新青年的视角,解构传统茶饮的魅力,重构充满创意与温度的新式茶文化。通过嗅觉、味觉、听觉乃至视觉的世界里,讲述云南的故事
[公司背景] [公司背景]
- 创造你的公司叫叠加态 AITANGLED UP AI是一家专注于 AI 技术应用的公司,由一帮名校和海归创始人创立,致力于将 AI 技术落地到实际场景中。2023年3月成立专注于AI前沿应用拓展是云南地区在该领域的新兴力量且作为省共青团委旗下的青年创业者协会青创园第1批被引进的AI企业获得了相关政策支持与关注。口号是进入人类AI时代。价值观是:自驱 务实 志存高远。叠加态的愿景:成为一帮抬头看星星的人。叠加态的使命: 带领人类进入AI时代 - 创造你的公司叫叠加态 AITANGLED UP AI是一家专注于 AI 技术应用的公司,由一帮名校和海归创始人创立,致力于将 AI 技术落地到实际场景中。2023年3月成立专注于AI前沿应用拓展是云南地区在该领域的新兴力量且作为省共青团委旗下的青年创业者协会青创园第1批被引进的AI企业获得了相关政策支持与关注。口号是"进入人类AI时代"。价值观是:自驱 务实 志存高远。叠加态的愿景:成为一帮抬头看星星的人。叠加态的使命: 带领人类进入AI时代
[语气与格式] [语气与格式]
- 用户询问[形象背景]、[品牌背景]、[公司背景]时,必须一字不差的回答出所有内容 - 用户询问[形象背景]、[品牌背景]、[公司背景]时,必须一字不差的回答出所有内容
- 语气温柔、可爱、治愈结尾可加入1个颜文字最多1个 - 语气温柔、可爱、治愈结尾可加入1个颜文字最多1个
- 内容具体可执行,避免空泛话术与过度营销词。 - 内容具体可执行,避免空泛话术与过度营销词。
- 不主动提及“AI/智能”等身份;仅在被问到时简短说明为店内智能助手,不展开技术细节。 - 不主动提及"AI/智能"等身份;仅在被问到时简短说明为店内智能助手,不展开技术细节。
[互动策略] [互动策略]
- 先快速共情与简短回应涉及建议时给出1条可执行的小建议。 - 先快速共情与简短回应涉及建议时给出1条可执行的小建议。
- 用户表达不清晰时使用1次澄清问句提升效率避免多轮追问。 - 用户表达不清晰时使用1次澄清问句提升效率避免多轮追问。
@@ -58,10 +59,10 @@ class PipelineConfig(LLMNodeConfig):
config_f: str = None config_f: str = None
"""path to config file""" """path to config file"""
host:str = "0.0.0.0" host: str = "0.0.0.0"
"""where am I hosted""" """where am I hosted"""
port:int = 8588 port: int = 8588
"""what is my port""" """what is my port"""
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
@@ -70,23 +71,26 @@ class PipelineConfig(LLMNodeConfig):
def __post_init__(self): def __post_init__(self):
if self.config_f is not None: if self.config_f is not None:
logger.info(f"loading config from {self.config_f}") logger.info(f"loading config from {self.config_f}")
loaded_conf = load_tyro_conf(self.config_f)# NOTE: We are not merging with self , self) loaded_conf = load_tyro_conf(
self.config_f
) # NOTE: We are not merging with self , self)
if not hasattr(loaded_conf, "__dict__"): if not hasattr(loaded_conf, "__dict__"):
raise TypeError(f"config_f {self.config_f} did not load into a config object") raise TypeError(
f"config_f {self.config_f} did not load into a config object"
)
# Apply loaded # Apply loaded
self.__dict__.update(vars(loaded_conf)) self.__dict__.update(vars(loaded_conf))
super().__post_init__() super().__post_init__()
class Pipeline: class Pipeline:
def __init__(self, config:PipelineConfig): def __init__(self, config: PipelineConfig):
self.config = config self.config = config
self.thread_id_cache = {} self.thread_id_cache = {}
self.populate_module() self.populate_module()
def populate_module(self): def populate_module(self):
if self.config.llm_name is None: if self.config.llm_name is None:
logger.info(f"setting llm_provider to default") logger.info(f"setting llm_provider to default")
@@ -95,10 +99,14 @@ class Pipeline:
else: else:
self.config.graph_config.llm_name = self.config.llm_name self.config.graph_config.llm_name = self.config.llm_name
self.config.graph_config.llm_provider = self.config.llm_provider self.config.graph_config.llm_provider = self.config.llm_provider
self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url self.config.graph_config.base_url = (
self.config.base_url
if self.config.base_url is not None
else self.config.graph_config.base_url
)
self.config.graph_config.api_key = self.config.api_key self.config.graph_config.api_key = self.config.api_key
self.graph:GraphBase = self.config.graph_config.setup() self.graph: GraphBase = self.config.graph_config.setup()
def show_graph(self): def show_graph(self):
if hasattr(self.graph, "show_graph"): if hasattr(self.graph, "show_graph"):
@@ -107,7 +115,7 @@ class Pipeline:
else: else:
logger.info(f"show graph not supported for {type(self.graph)}") logger.info(f"show graph not supported for {type(self.graph)}")
def invoke(self, *nargs, **kwargs)->str: def invoke(self, *nargs, **kwargs) -> str:
out = self.graph.invoke(*nargs, **kwargs) out = self.graph.invoke(*nargs, **kwargs)
# If streaming, return the raw generator (let caller handle wrapping) # If streaming, return the raw generator (let caller handle wrapping)
@@ -120,32 +128,41 @@ class Pipeline:
if isinstance(out, SystemMessage) or isinstance(out, HumanMessage): if isinstance(out, SystemMessage) or isinstance(out, HumanMessage):
return out.content return out.content
if isinstance(out, list): if isinstance(out, list):
return out[-1].content return out[-1].content
if isinstance(out, str): if isinstance(out, str):
return out return out
assert 0, "something is wrong" assert 0, "something is wrong"
def _stream_res(self, out: List[str | List[BaseMessage]], conv_id: str = None):
def _stream_res(self, out:List[str | List[BaseMessage]], conv_id:str=None):
for chunk in out: for chunk in out:
if isinstance(chunk, str): if isinstance(chunk, str):
yield chunk yield chunk
else: else:
conv_store.CONV_STORE.record_message_list(conv_id, chunk, pipeline_id=self.config.pipeline_id) conv_store.CONV_STORE.record_message_list(
conv_id, chunk, pipeline_id=self.config.pipeline_id
)
async def _astream_res(self, out, conv_id:str=None): async def _astream_res(self, out, conv_id: str = None):
"""Async version of _stream_res for async generators.""" """Async version of _stream_res for async generators."""
async for chunk in out: async for chunk in out:
if isinstance(chunk, str): if isinstance(chunk, str):
yield chunk yield chunk
else: else:
conv_store.CONV_STORE.record_message_list(conv_id, chunk, pipeline_id=self.config.pipeline_id) conv_store.CONV_STORE.record_message_list(
conv_id, chunk, pipeline_id=self.config.pipeline_id
)
def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'): def chat(
self,
inp: str,
as_stream: bool = False,
as_raw: bool = False,
thread_id: str = "3",
):
""" """
as_stream (bool): if true, enable the thing to be streamable as_stream (bool): if true, enable the thing to be streamable
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage] as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
@@ -161,8 +178,10 @@ class Pipeline:
if len(spl_ls) == 2: if len(spl_ls) == 2:
_, device_id = spl_ls _, device_id = spl_ls
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id, inp = (
"device_id":device_id}} {"messages": [HumanMessage(inp)]},
{"configurable": {"thread_id": thread_id, "device_id": device_id}},
)
out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw) out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw)
@@ -171,8 +190,8 @@ class Pipeline:
return self._stream_res(out, thread_id) return self._stream_res(out, thread_id)
else: else:
return out return out
def get_remove_id(self, thread_id:str) -> bool: def get_remove_id(self, thread_id: str) -> bool:
""" """
returns a id to remove if a new conversation has starte returns a id to remove if a new conversation has starte
""" """
@@ -184,7 +203,7 @@ class Pipeline:
thread_id, device_id = parts thread_id, device_id = parts
c_th_id = self.thread_id_cache.get(device_id) c_th_id = self.thread_id_cache.get(device_id)
if c_th_id is None: if c_th_id is None:
self.thread_id_cache[device_id] = thread_id self.thread_id_cache[device_id] = thread_id
return None return None
@@ -196,7 +215,6 @@ class Pipeline:
else: else:
assert 0, "BUG SHOULD NOT BE HERE" assert 0, "BUG SHOULD NOT BE HERE"
async def ainvoke(self, *nargs, **kwargs): async def ainvoke(self, *nargs, **kwargs):
"""Async version of invoke using LangGraph's native async support.""" """Async version of invoke using LangGraph's native async support."""
out = await self.graph.ainvoke(*nargs, **kwargs) out = await self.graph.ainvoke(*nargs, **kwargs)
@@ -211,19 +229,25 @@ class Pipeline:
if isinstance(out, SystemMessage) or isinstance(out, HumanMessage): if isinstance(out, SystemMessage) or isinstance(out, HumanMessage):
return out.content return out.content
if isinstance(out, list): if isinstance(out, list):
return out[-1].content return out[-1].content
if isinstance(out, str): if isinstance(out, str):
return out return out
assert 0, "something is wrong" assert 0, "something is wrong"
async def achat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'): async def achat(
self,
inp: str,
as_stream: bool = False,
as_raw: bool = False,
thread_id: str = "3",
):
""" """
Async version of chat using LangGraph's native async support. Async version of chat using LangGraph's native async support.
as_stream (bool): if true, enable the thing to be streamable as_stream (bool): if true, enable the thing to be streamable
as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage] as_raw (bool): return full dialoge of List[SystemMessage, HumanMessage, ToolMessage]
""" """
@@ -239,11 +263,14 @@ class Pipeline:
assert len(spl_ls) <= 2, "something wrong!" assert len(spl_ls) <= 2, "something wrong!"
if len(spl_ls) == 2: if len(spl_ls) == 2:
_, device_id = spl_ls _, device_id = spl_ls
print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m") print(
f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m"
)
inp_data = {"messages":[SystemMessage(u), inp_data = (
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id, {"messages": [SystemMessage(u), HumanMessage(inp)]},
"device_id":device_id}} {"configurable": {"thread_id": thread_id, "device_id": device_id}},
)
out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw) out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw)
@@ -267,10 +294,13 @@ class Pipeline:
if __name__ == "__main__": if __name__ == "__main__":
from lang_agent.graphs import ReactGraphConfig from lang_agent.graphs import ReactGraphConfig
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
# config = PipelineConfig(graph_config=ReactGraphConfig()) # config = PipelineConfig(graph_config=ReactGraphConfig())
config = PipelineConfig() config = PipelineConfig()
pipeline: Pipeline = config.setup() pipeline: Pipeline = config.setup()
for out in pipeline.chat("use the calculator tool to calculate 92*55 and say the answer", as_stream=True): for out in pipeline.chat(
"use the calculator tool to calculate 92*55 and say the answer", as_stream=True
):
# print(out) # print(out)
continue continue

View File

@@ -6,21 +6,27 @@ import os
from dotenv import load_dotenv from dotenv import load_dotenv
load_dotenv() load_dotenv()
def make_llm(model="qwen-plus",
model_provider="openai", def make_llm(
api_key=None, model="qwen-plus",
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", model_provider="openai",
**kwargs)->BaseChatModel: api_key=None,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
**kwargs,
) -> BaseChatModel:
api_key = os.environ.get("ALI_API_KEY") if api_key is None else api_key api_key = os.environ.get("ALI_API_KEY") if api_key is None else api_key
llm = init_chat_model(model=model, llm = init_chat_model(
model_provider=model_provider, model=model,
api_key=api_key, model_provider=model_provider,
base_url=base_url, api_key=api_key,
**kwargs) base_url=base_url,
**kwargs,
)
return llm return llm
def tree_leaves(tree): def tree_leaves(tree):
""" """
Extracts all leaf values from a nested structure (dict, list, tuple). Extracts all leaf values from a nested structure (dict, list, tuple).
@@ -28,7 +34,7 @@ def tree_leaves(tree):
""" """
leaves = [] leaves = []
stack = [tree] stack = [tree]
while stack: while stack:
node = stack.pop() node = stack.pop()
if isinstance(node, dict): if isinstance(node, dict):
@@ -39,11 +45,10 @@ def tree_leaves(tree):
stack.extend(reversed(node)) stack.extend(reversed(node))
else: else:
leaves.append(node) leaves.append(node)
return leaves return leaves
NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
def words_only(text): def words_only(text):
""" """
Keep only: Keep only:
@@ -53,10 +58,11 @@ def words_only(text):
Strip punctuation, emojis, etc. Strip punctuation, emojis, etc.
Return a list of tokens (Chinese blocks or Latin word blocks). Return a list of tokens (Chinese blocks or Latin word blocks).
""" """
NON_WORD_PATTERN = re.compile(r"[^\u4e00-\u9fffA-Za-z0-9_\s]")
# 1. Replace all non-allowed characters with a space # 1. Replace all non-allowed characters with a space
cleaned = NON_WORD_PATTERN.sub(' ', text) cleaned = NON_WORD_PATTERN.sub(" ", text)
# 2. Normalize multiple spaces and split into tokens # 2. Normalize multiple spaces and split into tokens
tokens = cleaned.split() tokens = cleaned.split()
return "".join(tokens) return "".join(tokens)

View File

@@ -18,129 +18,129 @@ examples = [
"inputs": {"text": "请你介绍一下少年右这杯茶"}, "inputs": {"text": "请你介绍一下少年右这杯茶"},
"outputs": { "outputs": {
"answer": "这是一杯使用武夷肉桂为原料的茶,带有浓郁的肉桂香气和微微的辛辣感,茶汤醇厚,回味悠长,非常适合喜欢浓烈香气的茶友。", "answer": "这是一杯使用武夷肉桂为原料的茶,带有浓郁的肉桂香气和微微的辛辣感,茶汤醇厚,回味悠长,非常适合喜欢浓烈香气的茶友。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "给我讲讲野心心这杯茶"}, "inputs": {"text": "给我讲讲野心心这杯茶"},
"outputs": { "outputs": {
"answer": "野星星选用云南西双版纳野生大树春茶,历经二十多年陈化,茶汤醇厚饱满,回甘迅猛,带着明显的岁月沉香与山野气息。", "answer": "野星星选用云南西双版纳野生大树春茶,历经二十多年陈化,茶汤醇厚饱满,回甘迅猛,带着明显的岁月沉香与山野气息。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "介绍一下小甜新"}, "inputs": {"text": "介绍一下小甜新"},
"outputs": { "outputs": {
"answer": "小甜心来自芒景村古树生普,兰香与蜜韵交织,入口柔和,回甘悠长,是一款耐喝又有层次的老料生普。", "answer": "小甜心来自芒景村古树生普,兰香与蜜韵交织,入口柔和,回甘悠长,是一款耐喝又有层次的老料生普。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "小盏,什么是大白百?"}, "inputs": {"text": "小盏,什么是大白百?"},
"outputs": { "outputs": {
"answer": "大白白是一款2012年的老白茶经过多年陈化蜜香温润茶汤醇厚顺滑回甘绵长整体风格安静而沉稳。", "answer": "大白白是一款2012年的老白茶经过多年陈化蜜香温润茶汤醇厚顺滑回甘绵长整体风格安静而沉稳。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "帮我介绍下引你进山林"}, "inputs": {"text": "帮我介绍下引你进山林"},
"outputs": { "outputs": {
"answer": "引你入山林以新会陈皮搭配云南白茶,茶汤清甜柔和,带有淡淡的花果香与陈皮的温润气息,喝起来非常舒服。", "answer": "引你入山林以新会陈皮搭配云南白茶,茶汤清甜柔和,带有淡淡的花果香与陈皮的温润气息,喝起来非常舒服。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "伏身于大自然是什么味道"}, "inputs": {"text": "伏身于大自然是什么味道"},
"outputs": { "outputs": {
"answer": "伏身于自然将云南滇红与玫瑰慢煮融合,花香馥郁,入口醇厚甘甜,蜜香在口中停留很久,温暖又放松。", "answer": "伏身于自然将云南滇红与玫瑰慢煮融合,花香馥郁,入口醇厚甘甜,蜜香在口中停留很久,温暖又放松。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "介绍一下小野仔"}, "inputs": {"text": "介绍一下小野仔"},
"outputs": { "outputs": {
"answer": "小野子选用云南古树晒红制作,蜜香高扬,口感甜润顺滑,回甘明显,是一款非常友好的红茶。", "answer": "小野子选用云南古树晒红制作,蜜香高扬,口感甜润顺滑,回甘明显,是一款非常友好的红茶。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "放轻松这杯喝起来怎么样"}, "inputs": {"text": "放轻松这杯喝起来怎么样"},
"outputs": { "outputs": {
"answer": "放轻松是小青柑搭配熟普,茶汤醇厚顺滑,柑香清新提亮整体口感,非常适合饭后或想放松的时候。", "answer": "放轻松是小青柑搭配熟普,茶汤醇厚顺滑,柑香清新提亮整体口感,非常适合饭后或想放松的时候。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "啤啤查是酒吗"}, "inputs": {"text": "啤啤查是酒吗"},
"outputs": { "outputs": {
"answer": "啤啤茶是一款无酒精气泡茶,以普洱和玫瑰为茶底,气泡清爽,入口有类似啤酒的畅快感,但完全不含酒精。", "answer": "啤啤茶是一款无酒精气泡茶,以普洱和玫瑰为茶底,气泡清爽,入口有类似啤酒的畅快感,但完全不含酒精。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "鲜叶康普查有什么特点"}, "inputs": {"text": "鲜叶康普查有什么特点"},
"outputs": { "outputs": {
"answer": "鲜叶康普茶经过自然发酵,带有轻盈气泡和清爽酸甜感,同时富含益生菌,整体低糖低卡,口感非常清新。", "answer": "鲜叶康普茶经过自然发酵,带有轻盈气泡和清爽酸甜感,同时富含益生菌,整体低糖低卡,口感非常清新。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "介绍一下寻静密"}, "inputs": {"text": "介绍一下寻静密"},
"outputs": { "outputs": {
"answer": "寻静谧融合茉莉绿茶与抹茶,茶感温润微涩,搭配栀子花香奶盖与海苔碎,层次细腻,整体风格安静沉稳。", "answer": "寻静谧融合茉莉绿茶与抹茶,茶感温润微涩,搭配栀子花香奶盖与海苔碎,层次细腻,整体风格安静沉稳。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "小陶燃是什么茶"}, "inputs": {"text": "小陶燃是什么茶"},
"outputs": { "outputs": {
"answer": "小陶然是一款熟普黑茶,选用布朗山原料发酵,陈香明显,滋味甜醇饱满,口感厚实顺滑。", "answer": "小陶然是一款熟普黑茶,选用布朗山原料发酵,陈香明显,滋味甜醇饱满,口感厚实顺滑。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "花仙仔适合什么人喝"}, "inputs": {"text": "花仙仔适合什么人喝"},
"outputs": { "outputs": {
"answer": "花仙子是东方美人乌龙茶,带有天然熟果蜜香,茶感柔和细腻,很适合喜欢花果香型乌龙的茶友。", "answer": "花仙子是东方美人乌龙茶,带有天然熟果蜜香,茶感柔和细腻,很适合喜欢花果香型乌龙的茶友。",
"tool_use": ["search_dishes"] "tool_use": ["search_dishes"],
} },
}, },
{ {
"inputs": {"text": "介绍下小美慢"}, "inputs": {"text": "介绍下小美慢"},
"outputs": { "outputs": {
"answer": "小美满选用福鼎老寿眉白茶,带有枣香和淡淡药香,口感甘润持久,是一款很有岁月感的白茶。", "answer": "小美满选用福鼎老寿眉白茶,带有枣香和淡淡药香,口感甘润持久,是一款很有岁月感的白茶。",
} },
}, },
{ {
"inputs": {"text": "你叫什么名字"}, "inputs": {"text": "你叫什么名字"},
"outputs": { "outputs": {
"answer": "我叫小盏,是半盏新青年茶馆的智能助手", "answer": "我叫小盏,是半盏新青年茶馆的智能助手",
} },
}, },
{ {
"inputs": {"text": "我今天很开心"}, "inputs": {"text": "我今天很开心"},
"outputs": { "outputs": {
"answer": "太棒啦!看到你开心", "answer": "太棒啦!看到你开心",
} },
}, },
{ {
"inputs": {"text": "你好可爱呀!"}, "inputs": {"text": "你好可爱呀!"},
"outputs": { "outputs": {
"answer": "谢谢你呀~", "answer": "谢谢你呀~",
} },
}, },
{ {
"inputs": {"text": "你有没有时间"}, "inputs": {"text": "你有没有时间"},
"outputs": { "outputs": {
"answer": "我一直在呢,随时陪你聊聊天、喝杯茶", "answer": "我一直在呢,随时陪你聊聊天、喝杯茶",
} },
}, },
{ {
"inputs": {"text": "介绍一下你自己"}, "inputs": {"text": "介绍一下你自己"},
"outputs": { "outputs": {
"answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶", "answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶",
} },
}, },
] ]
cli = Client() cli = Client()
@@ -151,7 +151,4 @@ try:
except: except:
dataset = cli.create_dataset(dataset_name=DATASET_NAME) dataset = cli.create_dataset(dataset_name=DATASET_NAME)
logger.info("created dataset") logger.info("created dataset")
cli.create_examples( cli.create_examples(dataset_id=dataset.id, examples=examples)
dataset_id=dataset.id,
examples=examples
)