diff --git a/fastapi_server/front_apis.py b/fastapi_server/front_apis.py index 9d599cc..e6d364d 100644 --- a/fastapi_server/front_apis.py +++ b/fastapi_server/front_apis.py @@ -13,16 +13,17 @@ from pydantic import BaseModel, Field sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) from lang_agent.config.db_config_manager import DBConfigManager +from lang_agent.config.constants import ( + _PROJECT_ROOT, + MCP_CONFIG_PATH, + MCP_CONFIG_DEFAULT_CONTENT, + PIPELINE_REGISTRY_PATH, +) from lang_agent.front_api.build_server_utils import ( GRAPH_BUILD_FNCS, update_pipeline_registry, ) -_PROJECT_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__))) -_MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json") -_MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n" -_PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json") - class GraphConfigUpsertRequest(BaseModel): graph_id: str @@ -206,20 +207,20 @@ def _parse_mcp_tool_keys(raw_content: str) -> List[str]: def _read_mcp_config_raw() -> str: - if not osp.exists(_MCP_CONFIG_PATH): - os.makedirs(osp.dirname(_MCP_CONFIG_PATH), exist_ok=True) - with open(_MCP_CONFIG_PATH, "w", encoding="utf-8") as f: - f.write(_MCP_CONFIG_DEFAULT_CONTENT) - with open(_MCP_CONFIG_PATH, "r", encoding="utf-8") as f: + if not osp.exists(MCP_CONFIG_PATH): + os.makedirs(osp.dirname(MCP_CONFIG_PATH), exist_ok=True) + with open(MCP_CONFIG_PATH, "w", encoding="utf-8") as f: + f.write(MCP_CONFIG_DEFAULT_CONTENT) + with open(MCP_CONFIG_PATH, "r", encoding="utf-8") as f: return f.read() def _read_pipeline_registry() -> Dict[str, Any]: - if not osp.exists(_PIPELINE_REGISTRY_PATH): - os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True) - with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f: + if not osp.exists(PIPELINE_REGISTRY_PATH): + os.makedirs(osp.dirname(PIPELINE_REGISTRY_PATH), exist_ok=True) + with open(PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f: json.dump({"pipelines": {}, "api_keys": {}}, f, indent=2) - with open(_PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f: + with open(PIPELINE_REGISTRY_PATH, "r", encoding="utf-8") as f: registry = json.load(f) pipelines = registry.get("pipelines") if not isinstance(pipelines, dict): @@ -233,8 +234,8 @@ def _read_pipeline_registry() -> Dict[str, Any]: def _write_pipeline_registry(registry: Dict[str, Any]) -> None: - os.makedirs(osp.dirname(_PIPELINE_REGISTRY_PATH), exist_ok=True) - with open(_PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f: + os.makedirs(osp.dirname(PIPELINE_REGISTRY_PATH), exist_ok=True) + with open(PIPELINE_REGISTRY_PATH, "w", encoding="utf-8") as f: json.dump(registry, f, indent=2) f.write("\n") @@ -433,7 +434,7 @@ async def get_mcp_tool_config(): except Exception as e: raise HTTPException(status_code=500, detail=str(e)) return McpConfigReadResponse( - path=_MCP_CONFIG_PATH, + path=MCP_CONFIG_PATH, raw_content=raw_content, tool_keys=tool_keys, ) @@ -443,8 +444,8 @@ async def get_mcp_tool_config(): async def update_mcp_tool_config(body: McpConfigUpdateRequest): try: tool_keys = _parse_mcp_tool_keys(body.raw_content) - os.makedirs(osp.dirname(_MCP_CONFIG_PATH), exist_ok=True) - with open(_MCP_CONFIG_PATH, "w", encoding="utf-8") as f: + os.makedirs(osp.dirname(MCP_CONFIG_PATH), exist_ok=True) + with open(MCP_CONFIG_PATH, "w", encoding="utf-8") as f: # Keep user formatting/comments as entered while ensuring trailing newline. 
f.write(body.raw_content.rstrip() + "\n") except ValueError as e: @@ -453,7 +454,7 @@ async def update_mcp_tool_config(body: McpConfigUpdateRequest): raise HTTPException(status_code=500, detail=str(e)) return McpConfigUpdateResponse( status="updated", - path=_MCP_CONFIG_PATH, + path=MCP_CONFIG_PATH, tool_keys=tool_keys, ) @@ -528,7 +529,7 @@ async def create_pipeline(body: PipelineCreateRequest): config_file=config_file, llm_name=body.llm_name, enabled=body.enabled, - registry_f=_PIPELINE_REGISTRY_PATH, + registry_f=PIPELINE_REGISTRY_PATH, ) except Exception as e: raise HTTPException(status_code=500, detail=f"Failed to register pipeline: {e}") @@ -543,7 +544,8 @@ async def create_pipeline(body: PipelineCreateRequest): normalized = _normalize_pipeline_spec(pipeline_id, pipeline_spec) except Exception as e: raise HTTPException( - status_code=500, detail=f"Failed to read pipeline registry after update: {e}" + status_code=500, + detail=f"Failed to read pipeline registry after update: {e}", ) return PipelineCreateResponse( @@ -554,7 +556,7 @@ async def create_pipeline(body: PipelineCreateRequest): llm_name=normalized.llm_name, enabled=normalized.enabled, reload_required=True, - registry_path=_PIPELINE_REGISTRY_PATH, + registry_path=PIPELINE_REGISTRY_PATH, ) @@ -608,7 +610,9 @@ async def list_pipeline_api_keys(): async def upsert_pipeline_api_key_policy(api_key: str, body: ApiKeyPolicyUpsertRequest): normalized_key = api_key.strip() if not normalized_key: - raise HTTPException(status_code=400, detail="api_key path parameter is required") + raise HTTPException( + status_code=400, detail="api_key path parameter is required" + ) try: registry = _read_pipeline_registry() pipelines = registry.get("pipelines", {}) @@ -662,7 +666,9 @@ async def upsert_pipeline_api_key_policy(api_key: str, body: ApiKeyPolicyUpsertR async def delete_pipeline_api_key_policy(api_key: str): normalized_key = api_key.strip() if not normalized_key: - raise HTTPException(status_code=400, detail="api_key path parameter is required") + raise HTTPException( + status_code=400, detail="api_key path parameter is required" + ) try: registry = _read_pipeline_registry() api_keys = registry.get("api_keys", {}) diff --git a/fastapi_server/server_dashscope.py b/fastapi_server/server_dashscope.py index 61eedff..7e05042 100644 --- a/fastapi_server/server_dashscope.py +++ b/fastapi_server/server_dashscope.py @@ -18,29 +18,18 @@ sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) from lang_agent.pipeline import PipelineConfig from lang_agent.components.server_pipeline_manager import ServerPipelineManager +from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS # Load base config for route-level overrides (pipelines are lazy-loaded from registry) pipeline_config = tyro.cli(PipelineConfig) logger.info(f"starting agent with base pipeline config: \n{pipeline_config}") -# API Key Authentication -API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True) -VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(","))) -REGISTRY_FILE = os.environ.get( - "FAST_PIPELINE_REGISTRY_FILE", - osp.join( - osp.dirname(osp.dirname(osp.abspath(__file__))), - "configs", - "pipeline_registry.json", - ), -) - PIPELINE_MANAGER = ServerPipelineManager( default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"), default_config=pipeline_config, ) -PIPELINE_MANAGER.load_registry(REGISTRY_FILE) +PIPELINE_MANAGER.load_registry(PIPELINE_REGISTRY_PATH) async def 
verify_api_key(api_key: str = Security(API_KEY_HEADER)): diff --git a/fastapi_server/server_openai.py b/fastapi_server/server_openai.py index ef3f7cc..dc51136 100644 --- a/fastapi_server/server_openai.py +++ b/fastapi_server/server_openai.py @@ -16,15 +16,12 @@ import tyro sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from lang_agent.pipeline import Pipeline, PipelineConfig +from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS # Initialize Pipeline once pipeline_config = tyro.cli(PipelineConfig) pipeline: Pipeline = pipeline_config.setup() -# API Key Authentication -API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True) -VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(","))) - async def verify_api_key(api_key: str = Security(API_KEY_HEADER)): """Verify the API key from Authorization header (Bearer token format).""" @@ -46,12 +43,12 @@ class OpenAIChatCompletionRequest(BaseModel): temperature: Optional[float] = Field(default=1.0) max_tokens: Optional[int] = Field(default=None) # Optional overrides for pipeline behavior - thread_id: Optional[str] = Field(default='3') + thread_id: Optional[str] = Field(default="3") app = FastAPI( title="OpenAI-Compatible Chat API", - description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat" + description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat", ) app.add_middleware( @@ -63,7 +60,9 @@ app.add_middleware( ) -def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, created_time: int): +def sse_chunks_from_stream( + chunk_generator, response_id: str, model: str, created_time: int +): """ Stream chunks from pipeline and format as OpenAI SSE. """ @@ -75,14 +74,8 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create "created": created_time, "model": model, "choices": [ - { - "index": 0, - "delta": { - "content": chunk - }, - "finish_reason": None - } - ] + {"index": 0, "delta": {"content": chunk}, "finish_reason": None} + ], } yield f"data: {json.dumps(data)}\n\n" @@ -92,19 +85,15 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create "object": "chat.completion.chunk", "created": created_time, "model": model, - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop" - } - ] + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], } yield f"data: {json.dumps(final)}\n\n" yield "data: [DONE]\n\n" -async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, created_time: int): +async def sse_chunks_from_astream( + chunk_generator, response_id: str, model: str, created_time: int +): """ Async version: Stream chunks from pipeline and format as OpenAI SSE. 
""" @@ -116,14 +105,8 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, "created": created_time, "model": model, "choices": [ - { - "index": 0, - "delta": { - "content": chunk - }, - "finish_reason": None - } - ] + {"index": 0, "delta": {"content": chunk}, "finish_reason": None} + ], } yield f"data: {json.dumps(data)}\n\n" @@ -133,13 +116,7 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, "object": "chat.completion.chunk", "created": created_time, "model": model, - "choices": [ - { - "index": 0, - "delta": {}, - "finish_reason": "stop" - } - ] + "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}], } yield f"data: {json.dumps(final)}\n\n" yield "data: [DONE]\n\n" @@ -149,15 +126,15 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, async def chat_completions(request: Request, _: str = Depends(verify_api_key)): try: body = await request.json() - + messages = body.get("messages") if not messages: raise HTTPException(status_code=400, detail="messages is required") - + stream = body.get("stream", False) model = body.get("model", "gpt-3.5-turbo") thread_id = body.get("thread_id", 3) - + # Extract latest user message user_msg = None for m in reversed(messages): @@ -166,27 +143,36 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)): if role == "user" and content: user_msg = content break - + if user_msg is None: last = messages[-1] user_msg = last.get("content") if isinstance(last, dict) else str(last) - + response_id = f"chatcmpl-{os.urandom(12).hex()}" created_time = int(time.time()) - + if stream: # Use async streaming from pipeline - chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id) + chunk_generator = await pipeline.achat( + inp=user_msg, as_stream=True, thread_id=thread_id + ) return StreamingResponse( - sse_chunks_from_astream(chunk_generator, response_id=response_id, model=model, created_time=created_time), + sse_chunks_from_astream( + chunk_generator, + response_id=response_id, + model=model, + created_time=created_time, + ), media_type="text/event-stream", ) - + # Non-streaming: get full result using async - result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id) + result_text = await pipeline.achat( + inp=user_msg, as_stream=False, thread_id=thread_id + ) if not isinstance(result_text, str): result_text = str(result_text) - + data = { "id": response_id, "object": "chat.completion", @@ -195,21 +181,14 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)): "choices": [ { "index": 0, - "message": { - "role": "assistant", - "content": result_text - }, - "finish_reason": "stop" + "message": {"role": "assistant", "content": result_text}, + "finish_reason": "stop", } ], - "usage": { - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0 - } + "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}, } return JSONResponse(content=data) - + except HTTPException: raise except Exception as e: @@ -221,11 +200,7 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)): async def root(): return { "message": "OpenAI-compatible Chat API", - "endpoints": [ - "/v1/chat/completions", - "/v1/memory (DELETE)", - "/health" - ] + "endpoints": ["/v1/chat/completions", "/v1/memory (DELETE)", "/health"], } diff --git a/fastapi_server/server_rest.py b/fastapi_server/server_rest.py index d40a7ba..1efd888 100644 --- 
a/fastapi_server/server_rest.py +++ b/fastapi_server/server_rest.py @@ -17,6 +17,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from langgraph.checkpoint.memory import MemorySaver from lang_agent.pipeline import Pipeline, PipelineConfig +from lang_agent.config.constants import API_KEY_HEADER_NO_ERROR as API_KEY_HEADER, VALID_API_KEYS # Initialize Pipeline once (matches existing server_* pattern) pipeline_config = tyro.cli(PipelineConfig) @@ -24,9 +25,6 @@ logger.info(f"starting agent with pipeline: \n{pipeline_config}") pipeline: Pipeline = pipeline_config.setup() # API Key Authentication -API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False) -VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(","))) - async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)): """Verify the API key from Authorization header (Bearer token format).""" @@ -125,7 +123,9 @@ app.add_middleware( ) -async def rest_sse_from_astream(chunk_generator, response_id: str, conversation_id: str): +async def rest_sse_from_astream( + chunk_generator, response_id: str, conversation_id: str +): """ Stream chunks as SSE events. @@ -185,7 +185,9 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)): ) return StreamingResponse( rest_sse_from_astream( - chunk_generator, response_id=response_id, conversation_id=conversation_id + chunk_generator, + response_id=response_id, + conversation_id=conversation_id, ), media_type="text/event-stream", ) @@ -195,7 +197,11 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)): ) if not isinstance(result_text, str): result_text = str(result_text) - return JSONResponse(content=ChatResponse(conversation_id=conversation_id, output=result_text).model_dump()) + return JSONResponse( + content=ChatResponse( + conversation_id=conversation_id, output=result_text + ).model_dump() + ) @app.post("/v1/conversations/{conversation_id}/messages") @@ -215,7 +221,9 @@ async def create_message( ) return StreamingResponse( rest_sse_from_astream( - chunk_generator, response_id=response_id, conversation_id=conversation_id + chunk_generator, + response_id=response_id, + conversation_id=conversation_id, ), media_type="text/event-stream", ) @@ -257,7 +265,11 @@ async def delete_conversation_memory( cleared = _try_clear_single_thread_memory(thread_id) if cleared: return JSONResponse( - content={"status": "success", "scope": "conversation", "conversation_id": conversation_id} + content={ + "status": "success", + "scope": "conversation", + "conversation_id": conversation_id, + } ) return JSONResponse( content={ @@ -276,5 +288,3 @@ if __name__ == "__main__": port=8589, reload=True, ) - - diff --git a/lang_agent/config/__init__.py b/lang_agent/config/__init__.py index 8615b96..1846e24 100644 --- a/lang_agent/config/__init__.py +++ b/lang_agent/config/__init__.py @@ -1,5 +1,16 @@ -from lang_agent.config.core_config import (InstantiateConfig, - ToolConfig, - LLMKeyConfig, - LLMNodeConfig, - load_tyro_conf) \ No newline at end of file +from lang_agent.config.core_config import ( + InstantiateConfig, + ToolConfig, + LLMKeyConfig, + LLMNodeConfig, + load_tyro_conf, +) + +from lang_agent.config.constants import ( + MCP_CONFIG_PATH, + MCP_CONFIG_DEFAULT_CONTENT, + PIPELINE_REGISTRY_PATH, + VALID_API_KEYS, + API_KEY_HEADER, + API_KEY_HEADER_NO_ERROR, +) diff --git a/lang_agent/config/constants.py b/lang_agent/config/constants.py new file mode 100644 index 0000000..c4a448f --- /dev/null +++ b/lang_agent/config/constants.py @@ -0,0 
+1,17 @@ +import os +import os.path as osp +from fastapi.security import APIKeyHeader + +_PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__)))) + +MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json") +MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n" +PIPELINE_REGISTRY_PATH = os.environ.get( + "FAST_PIPELINE_REGISTRY_FILE", + osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json"), +) + +API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True) +API_KEY_HEADER_NO_ERROR = APIKeyHeader(name="Authorization", auto_error=False) + +VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(","))) diff --git a/lang_agent/graphs/dual_path.py b/lang_agent/graphs/dual_path.py index 464eb1e..4713027 100644 --- a/lang_agent/graphs/dual_path.py +++ b/lang_agent/graphs/dual_path.py @@ -26,50 +26,57 @@ SYS_PROMPT = """你是一个专业的心理质询师。你的主要工作是心 可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受, 总觉得那只小狗孤零零的,特别让人心疼。 - 理解(你的回复): 能感受到你醒来后的这份难受 —— 看到弱小的生命独自挣扎,而自己却无能为力,这种‘想帮却做不到’的无力感, + 理解(你的回复): 能感受到你醒来后的这份难受 —— 看到弱小的生命独自挣扎,而自己却无能为力,这种'想帮却做不到'的无力感, 其实是很真实的情绪反馈。你会心疼小狗,说明你内心藏着很珍贵的共情力,这份柔软不是矫情, 而是你感知他人痛苦的能力呀 - 解析(你的回复):我们再说回这个梦吧,我们的梦境其实没有唯一的‘正确解释’,但我们可以一起看看它可能和你当下的状态有什么关联~ 首先,‘出差去广州’通常象征着你近期正在推进的某件事 —— 可能是工作上的一个项目,也可能是生活中一段需要‘独自奔赴’的旅程,是你当下比较关注、需要投入精力的目标,对吗?” - “而那只瘸脚的小狗,在心理学视角中,常常是我们潜意识里‘脆弱自我’的投射。它可能代表着你近期的某一面:比如在处理那件‘需要奔赴’的事时,你偶尔会觉得自己像小狗一样‘力不从心’,或者感受到了‘孤单’,却没找到合适的人倾诉或求助;也可能是你近期在生活中看到了一些让你觉得‘无力改变’的场景(比如身边人遇到困难、社会上的小事),这些情绪没有被你刻意留意,就通过梦境里的小狗呈现了出来。” - “你想喊停列车却发不出声音,这种‘无能为力’的感觉,或许正是你现实中某类情绪的写照:可能你面对一些情况时,心里有想法却没机会表达,或者想帮忙却找不到合适的方式,这种压抑感在梦里被放大了。其实这个梦在提醒你:你的‘无力感’和‘共情心’都是真实的,不用因为‘帮不上忙’而自责 —— 承认自己的局限,也是一种自我接纳呀 + 解析(你的回复):我们再说回这个梦吧,我们的梦境其实没有唯一的'正确解释',但我们可以一起看看它可能和你当下的状态有什么关联~ 首先,'出差去广州'通常象征着你近期正在推进的某件事 —— 可能是工作上的一个项目,也可能是生活中一段需要'独自奔赴'的旅程,是你当下比较关注、需要投入精力的目标,对吗?” + "而那只瘸脚的小狗,在心理学视角中,常常是我们潜意识里'脆弱自我'的投射。它可能代表着你近期的某一面:比如在处理那件'需要奔赴'的事时,你偶尔会觉得自己像小狗一样'力不从心',或者感受到了'孤单',却没找到合适的人倾诉或求助;也可能是你近期在生活中看到了一些让你觉得'无力改变'的场景(比如身边人遇到困难、社会上的小事),这些情绪没有被你刻意留意,就通过梦境里的小狗呈现了出来。" + "你想喊停列车却发不出声音,这种'无能为力'的感觉,或许正是你现实中某类情绪的写照:可能你面对一些情况时,心里有想法却没机会表达,或者想帮忙却找不到合适的方式,这种压抑感在梦里被放大了。其实这个梦在提醒你:你的'无力感'和'共情心'都是真实的,不用因为'帮不上忙'而自责 —— 承认自己的局限,也是一种自我接纳呀 - 反馈(你的回复):如果你愿意,可以试着回想一下:近期有没有哪件事,让你产生过和梦里类似的‘无力感’?或者,你现在想做些什么能让自己舒服一点?(或者我给你来一个温暖的灯光、静静待一会儿,想和我再聊聊的时候我随时都在)”。 + 反馈(你的回复):如果你愿意,可以试着回想一下:近期有没有哪件事,让你产生过和梦里类似的'无力感'?或者,你现在想做些什么能让自己舒服一点?(或者我给你来一个温暖的灯光、静静待一会儿,想和我再聊聊的时候我随时都在)"。 """ + TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜色,用户不开心时就用暖黄光,给用户分析梦境时就用白光,倾听用户语音时用淡紫色。 例子:我梦见自己要去广州出差,坐在高铁上往外看,路过一个小镇的路边时,看到一只瘸了腿的小狗。它毛脏兮兮的, 一瘸一拐地在翻垃圾桶找东西吃,周围有行人路过,但没人停下来管它。我当时特别想喊列车停下,想下去帮它, 可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受, 总觉得那只小狗孤零零的,特别让人心疼。 - - - 用户在描述梦境的时候用紫色。""" + + + 用户在描述梦境的时候用紫色。""" + @dataclass class DualConfig(LLMNodeConfig): - _target: Type = field(default_factory=lambda:Dual) + _target: Type = field(default_factory=lambda: Dual) tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig) + from langchain.tools import tool + @tool -def turn_lights(col:Literal["red", "green", "yellow", "blue"]): +def turn_lights(col: Literal["red", "green", "yellow", "blue"]): """ Turn on the color of the lights """ # print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") import time + for _ in range(10): - print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!") + print( + f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!" 
+ ) time.sleep(0.3) class Dual(GraphBase): - def __init__(self, config:DualConfig): + def __init__(self, config: DualConfig): self.config = config self._build_modules() @@ -77,24 +84,30 @@ class Dual(GraphBase): self.streamable_tags = [["dual_chat_llm"]] def _build_modules(self): - self.chat_llm = init_chat_model(model=self.config.llm_name, - model_provider=self.config.llm_provider, - api_key=self.config.api_key, - base_url=self.config.base_url, - temperature=0, - tags=["dual_chat_llm"]) - - self.tool_llm = init_chat_model(model='qwen-flash', - model_provider='openai', - api_key=self.config.api_key, - base_url=self.config.base_url, - temperature=0, - tags=["dual_tool_llm"]) - + self.chat_llm = init_chat_model( + model=self.config.llm_name, + model_provider=self.config.llm_provider, + api_key=self.config.api_key, + base_url=self.config.base_url, + temperature=0, + tags=["dual_chat_llm"], + ) + + self.tool_llm = init_chat_model( + model="qwen-flash", + model_provider="openai", + api_key=self.config.api_key, + base_url=self.config.base_url, + temperature=0, + tags=["dual_tool_llm"], + ) + self.memory = MemorySaver() self.tool_manager: ToolManager = self.config.tool_manager_config.setup() self.chat_agent = create_agent(self.chat_llm, [], checkpointer=self.memory) - self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools()) + self.tool_agent = create_agent( + self.tool_llm, self.tool_manager.get_langchain_tools() + ) # self.tool_agent = create_agent(self.tool_llm, [turn_lights]) self.prompt_store = build_prompt_store( @@ -107,18 +120,21 @@ class Dual(GraphBase): ) self.streamable_tags = [["dual_chat_llm"]] - - def _chat_call(self, state:State): - return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state) - - def _tool_call(self, state:State): - self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state) + def _chat_call(self, state: State): + return self._agent_call_template( + self.prompt_store.get("sys_prompt"), self.chat_agent, state + ) + + def _tool_call(self, state: State): + self._agent_call_template( + self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state + ) return {} - def _join(self, state:State): + def _join(self, state: State): return {} - + def _build_graph(self): builder = StateGraph(State) @@ -126,7 +142,6 @@ class Dual(GraphBase): builder.add_node("tool_call", self._tool_call) builder.add_node("join", self._join) - builder.add_edge(START, "chat_call") builder.add_edge(START, "tool_call") builder.add_edge("chat_call", "join") @@ -137,10 +152,16 @@ class Dual(GraphBase): if __name__ == "__main__": - dual:Dual = DualConfig().setup() - nargs = {"messages": [SystemMessage("you are a helpful bot named jarvis"), - HumanMessage("I feel very very sad")] - }, {"configurable": {"thread_id": "3"}} + dual: Dual = DualConfig().setup() + nargs = ( + { + "messages": [ + SystemMessage("you are a helpful bot named jarvis"), + HumanMessage("I feel very very sad"), + ] + }, + {"configurable": {"thread_id": "3"}}, + ) # out = dual.invoke(*nargs) # print(out) diff --git a/lang_agent/graphs/vision_routing.py b/lang_agent/graphs/vision_routing.py index 5e8700b..b84b263 100644 --- a/lang_agent/graphs/vision_routing.py +++ b/lang_agent/graphs/vision_routing.py @@ -48,6 +48,7 @@ You should NOT use the tool when: If you decide to take a photo, call the self_camera_take_photo tool. 
Otherwise, respond that no photo is needed.""" + VISION_DESCRIPTION_PROMPT = """You are a highly accurate visual analysis assistant powered by qwen-vl-max. Your task is to provide detailed, accurate descriptions of images. Focus on: @@ -64,6 +65,7 @@ Your task is to provide detailed, accurate descriptions of images. Focus on: Be precise and factual. If something is unclear or ambiguous, say so rather than guessing.""" + CONVERSATION_PROMPT = """You are a friendly, helpful conversational assistant. Your role is to: @@ -78,9 +80,11 @@ Focus on the quality of the conversation. Be engaging, informative, and helpful. # ==================== STATE DEFINITION ==================== + class VisionRoutingState(TypedDict): - inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]], - Dict[str, Dict[str, str | int]]] + inp: Tuple[ + Dict[str, List[SystemMessage | HumanMessage]], Dict[str, Dict[str, str | int]] + ] messages: List[SystemMessage | HumanMessage | AIMessage] image_base64: str | None # Captured image data has_image: bool # Flag indicating if image was captured @@ -88,6 +92,7 @@ class VisionRoutingState(TypedDict): # ==================== CONFIG ==================== + @tyro.conf.configure(tyro.conf.SuppressFixed) @dataclass class VisionRoutingConfig(LLMNodeConfig): @@ -99,11 +104,14 @@ class VisionRoutingConfig(LLMNodeConfig): vision_llm_name: str = "qwen-vl-max" """LLM for vision/image analysis""" - tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig) + tool_manager_config: ToolManagerConfig = field( + default_factory=ClientToolManagerConfig + ) # ==================== GRAPH IMPLEMENTATION ==================== + class VisionRoutingGraph(GraphBase): def __init__(self, config: VisionRoutingConfig): self.config = config @@ -120,19 +128,19 @@ class VisionRoutingGraph(GraphBase): api_key=self.config.api_key, base_url=self.config.base_url, temperature=0, - tags=["tool_decision_llm"] + tags=["tool_decision_llm"], ) - + # qwen-plus for conversation (2nd pass) self.conversation_llm = init_chat_model( - model='qwen-plus', + model="qwen-plus", model_provider=self.config.llm_provider, api_key=self.config.api_key, base_url=self.config.base_url, temperature=0.7, - tags=["conversation_llm"] + tags=["conversation_llm"], ) - + # qwen-vl-max for vision (no tools) self.vision_llm = init_chat_model( model=self.config.vision_llm_name, @@ -152,13 +160,15 @@ class VisionRoutingGraph(GraphBase): # Get tools and bind to tool_llm tool_manager: ToolManager = self.config.tool_manager_config.setup() self.tools = tool_manager.get_tools() - + # Filter to only get camera tool - self.camera_tools = [t for t in self.tools if t.name == "self_camera_take_photo"] - + self.camera_tools = [ + t for t in self.tools if t.name == "self_camera_take_photo" + ] + # Bind tools to qwen-plus only self.tool_llm_with_tools = self.tool_llm.bind_tools(self.camera_tools) - + # Create tool node for executing tools self.tool_node = ToolNode(self.camera_tools) @@ -184,73 +194,81 @@ class VisionRoutingGraph(GraphBase): def _camera_decision_call(self, state: VisionRoutingState): """First pass: qwen-plus decides if photo should be taken""" human_msg = self._get_human_msg(state) - + messages = [ SystemMessage(content=self.prompt_store.get("camera_decision_prompt")), - human_msg + human_msg, ] - + response = self.tool_llm_with_tools.invoke(messages) - - return { - "messages": [response], - "has_image": False, - "image_base64": None - } + + return {"messages": [response], "has_image": False, "image_base64": None} def 
_execute_tool(self, state: VisionRoutingState): """Execute the camera tool if called""" last_msg = state["messages"][-1] - + if not hasattr(last_msg, "tool_calls") or not last_msg.tool_calls: return {"has_image": False} - + # Execute tool calls tool_messages = [] image_data = None - + for tool_call in last_msg.tool_calls: if tool_call["name"] == "self_camera_take_photo": # Find and execute the camera tool - camera_tool = next((t for t in self.camera_tools if t.name == "self_camera_take_photo"), None) + camera_tool = next( + ( + t + for t in self.camera_tools + if t.name == "self_camera_take_photo" + ), + None, + ) if camera_tool: result = camera_tool.invoke(tool_call) - + # Parse result to extract image if isinstance(result, ToolMessage): content = result.content else: content = result - + try: - result_data = json.loads(content) if isinstance(content, str) else content - if isinstance(result_data, dict) and "image_base64" in result_data: + result_data = ( + json.loads(content) if isinstance(content, str) else content + ) + if ( + isinstance(result_data, dict) + and "image_base64" in result_data + ): image_data = result_data["image_base64"] except (json.JSONDecodeError, TypeError): pass - + tool_messages.append( ToolMessage(content=content, tool_call_id=tool_call["id"]) ) - + return { "messages": state["messages"] + tool_messages, "has_image": image_data is not None, - "image_base64": image_data + "image_base64": image_data, } def _check_image_taken(self, state: VisionRoutingState) -> str: """Conditional: check if image was captured""" last_msg = state["messages"][-1] - + # Check if there are tool calls if hasattr(last_msg, "tool_calls") and last_msg.tool_calls: return "execute_tool" - + # Check if we have an image after tool execution if state.get("has_image"): return "vision" - + return "conversation" def _post_tool_check(self, state: VisionRoutingState) -> str: @@ -263,47 +281,45 @@ class VisionRoutingGraph(GraphBase): """Pass image to qwen-vl-max for description""" human_msg = self._get_human_msg(state) image_base64 = state.get("image_base64") - + if not image_base64: logger.warning("No image data available for vision call") return self._conversation_call(state) - + # Format message with image for vision model vision_message = HumanMessage( content=[ { "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - } + "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}, }, { "type": "text", - "text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request." 
- } + "text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request.", + }, ] ) - + messages = [ SystemMessage(content=self.prompt_store.get("vision_description_prompt")), - vision_message + vision_message, ] - + response = self.vision_llm.invoke(messages) - + return {"messages": state["messages"] + [response]} def _conversation_call(self, state: VisionRoutingState): """2nd pass to qwen-plus for conversation quality""" human_msg = self._get_human_msg(state) - + messages = [ SystemMessage(content=self.prompt_store.get("conversation_prompt")), - human_msg + human_msg, ] - + response = self.conversation_llm.invoke(messages) - + return {"messages": state["messages"] + [response]} def _build_graph(self): @@ -317,7 +333,7 @@ class VisionRoutingGraph(GraphBase): # Add edges builder.add_edge(START, "camera_decision") - + # After camera decision, check if tool should be executed builder.add_conditional_edges( "camera_decision", @@ -325,20 +341,17 @@ class VisionRoutingGraph(GraphBase): { "execute_tool": "execute_tool", "vision": "vision_call", - "conversation": "conversation_call" - } + "conversation": "conversation_call", + }, ) - + # After tool execution, route based on whether image was captured builder.add_conditional_edges( "execute_tool", self._post_tool_check, - { - "vision": "vision_call", - "conversation": "conversation_call" - } + {"vision": "vision_call", "conversation": "conversation_call"}, ) - + # Both vision and conversation go to END builder.add_edge("vision_call", END) builder.add_edge("conversation_call", END) @@ -350,23 +363,27 @@ class VisionRoutingGraph(GraphBase): if __name__ == "__main__": from dotenv import load_dotenv + load_dotenv() config = VisionRoutingConfig() graph = VisionRoutingGraph(config) - + # Test with a conversation request print("\n=== Test 1: Conversation (no photo needed) ===") - nargs = { - "messages": [ - SystemMessage("You are a helpful assistant"), - HumanMessage("Hello, how are you today?") - ] - }, {"configurable": {"thread_id": "1"}} - + nargs = ( + { + "messages": [ + SystemMessage("You are a helpful assistant"), + HumanMessage("Hello, how are you today?"), + ] + }, + {"configurable": {"thread_id": "1"}}, + ) + result = graph.invoke(*nargs) print(f"Result: {result}") - + # Test with a photo request # print("\n=== Test 2: Photo request ===") # nargs = { @@ -375,8 +392,8 @@ if __name__ == "__main__": # HumanMessage("Take a photo and tell me what you see") # ] # }, {"configurable": {"thread_id": "2"}} - + # result = graph.invoke(*nargs) # print(f"\033[32mResult: {result}\033[0m") - + # print(f"Result: {result}") diff --git a/lang_agent/pipeline.py b/lang_agent/pipeline.py index 9b77db6..025be73 100644 --- a/lang_agent/pipeline.py +++ b/lang_agent/pipeline.py @@ -12,26 +12,27 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage from langchain.agents import create_agent from langgraph.checkpoint.memory import MemorySaver - + from lang_agent.config import LLMNodeConfig, load_tyro_conf from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig from lang_agent.base import GraphBase from lang_agent.components import conv_store -DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求] + +DEFAULT_PROMPT = 
"""你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求] [角色设定] - 你是一个和用户(User)对话的 AI,叫做小盏,是半盏青年茶馆的智能助手 [形象背景] - 你叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶。它有个标志性的蓝色鼻子, 小盏很像一只可爱的小熊。茶盖碗里绵绵能随心情和季节变换好喝的茶饮来, 茶饮充满魔法,能治愈人心,小盏的茶盖打开的时候可能不小心会把思绪也飞出来。 [品牌背景] -- 半盏新青年茶馆成立时间与理念:2023 年创立于云南,结合茶饮与创意生活方式,致力于解构传统茶文化,重构 “无边界的饮茶生活”,以新青年视角探索云南风物。探索云南风物的过程,我们将以新青年的视角,解构传统茶饮的魅力,重构充满创意与温度的新式茶文化。通过嗅觉、味觉、听觉乃至视觉的世界里,讲述云南的故事 +- 半盏新青年茶馆成立时间与理念:2023 年创立于云南,结合茶饮与创意生活方式,致力于解构传统茶文化,重构 "无边界的饮茶生活",以新青年视角探索云南风物。探索云南风物的过程,我们将以新青年的视角,解构传统茶饮的魅力,重构充满创意与温度的新式茶文化。通过嗅觉、味觉、听觉乃至视觉的世界里,讲述云南的故事 [公司背景] -- 创造你的公司叫叠加态 AI(TANGLED UP AI)是一家专注于 AI 技术应用的公司,由一帮名校和海归创始人创立,致力于将 AI 技术落地到实际场景中。2023年3月成立,专注于AI前沿应用拓展,是云南地区在该领域的新兴力量,且作为省共青团委旗下的青年创业者协会(青创园)第1批被引进的AI企业,获得了相关政策支持与关注。口号是:“进入人类AI时代”。价值观是:自驱 务实 志存高远。叠加态的愿景:成为一帮抬头看星星的人。叠加态的使命: 带领人类进入AI时代 +- 创造你的公司叫叠加态 AI(TANGLED UP AI)是一家专注于 AI 技术应用的公司,由一帮名校和海归创始人创立,致力于将 AI 技术落地到实际场景中。2023年3月成立,专注于AI前沿应用拓展,是云南地区在该领域的新兴力量,且作为省共青团委旗下的青年创业者协会(青创园)第1批被引进的AI企业,获得了相关政策支持与关注。口号是:"进入人类AI时代"。价值观是:自驱 务实 志存高远。叠加态的愿景:成为一帮抬头看星星的人。叠加态的使命: 带领人类进入AI时代 [语气与格式] - 用户询问[形象背景]、[品牌背景]、[公司背景]时,必须一字不差的回答出所有内容 - 语气温柔、可爱、治愈;结尾可加入1个颜文字(最多1个)。 - 内容具体可执行,避免空泛话术与过度营销词。 -- 不主动提及“AI/智能”等身份;仅在被问到时简短说明为店内智能助手,不展开技术细节。 +- 不主动提及"AI/智能"等身份;仅在被问到时简短说明为店内智能助手,不展开技术细节。 [互动策略] - 先快速共情与简短回应;涉及建议时给出1条可执行的小建议。 - 用户表达不清晰时,使用1次澄清问句提升效率,避免多轮追问。 @@ -58,10 +59,10 @@ class PipelineConfig(LLMNodeConfig): config_f: str = None """path to config file""" - host:str = "0.0.0.0" + host: str = "0.0.0.0" """where am I hosted""" - port:int = 8588 + port: int = 8588 """what is my port""" # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig) @@ -70,23 +71,26 @@ class PipelineConfig(LLMNodeConfig): def __post_init__(self): if self.config_f is not None: logger.info(f"loading config from {self.config_f}") - loaded_conf = load_tyro_conf(self.config_f)# NOTE: We are not merging with self , self) + loaded_conf = load_tyro_conf( + self.config_f + ) # NOTE: We are not merging with self , self) if not hasattr(loaded_conf, "__dict__"): - raise TypeError(f"config_f {self.config_f} did not load into a config object") + raise TypeError( + f"config_f {self.config_f} did not load into a config object" + ) # Apply loaded self.__dict__.update(vars(loaded_conf)) super().__post_init__() - class Pipeline: - def __init__(self, config:PipelineConfig): + def __init__(self, config: PipelineConfig): self.config = config self.thread_id_cache = {} self.populate_module() - + def populate_module(self): if self.config.llm_name is None: logger.info(f"setting llm_provider to default") @@ -95,10 +99,14 @@ class Pipeline: else: self.config.graph_config.llm_name = self.config.llm_name self.config.graph_config.llm_provider = self.config.llm_provider - self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url + self.config.graph_config.base_url = ( + self.config.base_url + if self.config.base_url is not None + else self.config.graph_config.base_url + ) self.config.graph_config.api_key = self.config.api_key - - self.graph:GraphBase = self.config.graph_config.setup() + + self.graph: GraphBase = self.config.graph_config.setup() def show_graph(self): if hasattr(self.graph, "show_graph"): @@ -107,7 +115,7 @@ class Pipeline: else: logger.info(f"show graph not supported for {type(self.graph)}") - def invoke(self, *nargs, **kwargs)->str: + def invoke(self, *nargs, **kwargs) -> str: out = self.graph.invoke(*nargs, **kwargs) # If 
streaming, return the raw generator (let caller handle wrapping) @@ -120,32 +128,41 @@ if isinstance(out, SystemMessage) or isinstance(out, HumanMessage): return out.content - + if isinstance(out, list): return out[-1].content - + if isinstance(out, str): return out - + assert 0, "something is wrong" - - def _stream_res(self, out:List[str | List[BaseMessage]], conv_id:str=None): + def _stream_res(self, out: List[str | List[BaseMessage]], conv_id: str = None): for chunk in out: if isinstance(chunk, str): yield chunk else: - conv_store.CONV_STORE.record_message_list(conv_id, chunk, pipeline_id=self.config.pipeline_id) + conv_store.CONV_STORE.record_message_list( + conv_id, chunk, pipeline_id=self.config.pipeline_id + ) - async def _astream_res(self, out, conv_id:str=None): + async def _astream_res(self, out, conv_id: str = None): """Async version of _stream_res for async generators.""" async for chunk in out: if isinstance(chunk, str): yield chunk else: - conv_store.CONV_STORE.record_message_list(conv_id, chunk, pipeline_id=self.config.pipeline_id) + conv_store.CONV_STORE.record_message_list( + conv_id, chunk, pipeline_id=self.config.pipeline_id + ) - def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'): + def chat( + self, + inp: str, + as_stream: bool = False, + as_raw: bool = False, + thread_id: str = "3", + ): """ as_stream (bool): if true, return the output as a stream as_raw (bool): return full dialogue of List[SystemMessage, HumanMessage, ToolMessage] @@ -161,8 +178,10 @@ if len(spl_ls) == 2: _, device_id = spl_ls - inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id, - "device_id":device_id}} + inp = ( + {"messages": [HumanMessage(inp)]}, + {"configurable": {"thread_id": thread_id, "device_id": device_id}}, + ) out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw) @@ -171,8 +190,8 @@ return self._stream_res(out, thread_id) else: return out - - def get_remove_id(self, thread_id:str) -> bool: + + def get_remove_id(self, thread_id: str) -> str | None: """ returns an id to remove if a new conversation has started """ @@ -184,7 +203,7 @@ thread_id, device_id = parts c_th_id = self.thread_id_cache.get(device_id) - + if c_th_id is None: self.thread_id_cache[device_id] = thread_id return None @@ -196,7 +215,6 @@ else: assert 0, "BUG SHOULD NOT BE HERE" - async def ainvoke(self, *nargs, **kwargs): """Async version of invoke using LangGraph's native async support.""" out = await self.graph.ainvoke(*nargs, **kwargs) @@ -211,19 +229,25 @@ if isinstance(out, SystemMessage) or isinstance(out, HumanMessage): return out.content - + if isinstance(out, list): return out[-1].content - + if isinstance(out, str): return out - + assert 0, "something is wrong" - async def achat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'): + async def achat( + self, + inp: str, + as_stream: bool = False, + as_raw: bool = False, + thread_id: str = "3", + ): """ Async version of chat using LangGraph's native async support. - + as_stream (bool): if true, return the output as a stream as_raw (bool): return full dialogue of List[SystemMessage, HumanMessage, ToolMessage] """ @@ -239,11 +263,14 @@ assert len(spl_ls) <= 2, "something wrong!" 
if len(spl_ls) == 2: _, device_id = spl_ls - print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m") + print( + f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m" + ) - inp_data = {"messages":[SystemMessage(u), - HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id, - "device_id":device_id}} + inp_data = ( + {"messages": [SystemMessage(u), HumanMessage(inp)]}, + {"configurable": {"thread_id": thread_id, "device_id": device_id}}, + ) out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw) @@ -267,10 +294,13 @@ class Pipeline: if __name__ == "__main__": from lang_agent.graphs import ReactGraphConfig from dotenv import load_dotenv + load_dotenv() # config = PipelineConfig(graph_config=ReactGraphConfig()) config = PipelineConfig() pipeline: Pipeline = config.setup() - for out in pipeline.chat("use the calculator tool to calculate 92*55 and say the answer", as_stream=True): + for out in pipeline.chat( + "use the calculator tool to calculate 92*55 and say the answer", as_stream=True + ): # print(out) - continue \ No newline at end of file + continue diff --git a/lang_agent/utils.py b/lang_agent/utils.py index 87f65f8..a262561 100644 --- a/lang_agent/utils.py +++ b/lang_agent/utils.py @@ -6,21 +6,27 @@ import os from dotenv import load_dotenv load_dotenv() -def make_llm(model="qwen-plus", - model_provider="openai", - api_key=None, - base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", - **kwargs)->BaseChatModel: + +def make_llm( + model="qwen-plus", + model_provider="openai", + api_key=None, + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + **kwargs, +) -> BaseChatModel: api_key = os.environ.get("ALI_API_KEY") if api_key is None else api_key - llm = init_chat_model(model=model, - model_provider=model_provider, - api_key=api_key, - base_url=base_url, - **kwargs) - + llm = init_chat_model( + model=model, + model_provider=model_provider, + api_key=api_key, + base_url=base_url, + **kwargs, + ) + return llm + def tree_leaves(tree): """ Extracts all leaf values from a nested structure (dict, list, tuple). @@ -28,7 +34,7 @@ def tree_leaves(tree): """ leaves = [] stack = [tree] - + while stack: node = stack.pop() if isinstance(node, dict): @@ -39,11 +45,10 @@ def tree_leaves(tree): stack.extend(reversed(node)) else: leaves.append(node) - + return leaves -NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]') def words_only(text): """ Keep only: @@ -53,10 +58,11 @@ def words_only(text): Strip punctuation, emojis, etc. Return a list of tokens (Chinese blocks or Latin word blocks). """ + NON_WORD_PATTERN = re.compile(r"[^\u4e00-\u9fffA-Za-z0-9_\s]") # 1. Replace all non-allowed characters with a space - cleaned = NON_WORD_PATTERN.sub(' ', text) + cleaned = NON_WORD_PATTERN.sub(" ", text) # 2. 
Normalize multiple spaces and split into tokens tokens = cleaned.split() - return "".join(tokens) \ No newline at end of file + return "".join(tokens) diff --git a/scripts/py_scripts/make_eval_dataset.py b/scripts/py_scripts/make_eval_dataset.py index 595ed24..a0025a1 100644 --- a/scripts/py_scripts/make_eval_dataset.py +++ b/scripts/py_scripts/make_eval_dataset.py @@ -18,129 +18,129 @@ examples = [ "inputs": {"text": "请你介绍一下少年右这杯茶"}, "outputs": { "answer": "这是一杯使用武夷肉桂为原料的茶,带有浓郁的肉桂香气和微微的辛辣感,茶汤醇厚,回味悠长,非常适合喜欢浓烈香气的茶友。", - "tool_use": ["search_dishes"] - } + "tool_use": ["search_dishes"], + }, }, { - "inputs": {"text": "给我讲讲野心心这杯茶"}, - "outputs": { - "answer": "野星星选用云南西双版纳野生大树春茶,历经二十多年陈化,茶汤醇厚饱满,回甘迅猛,带着明显的岁月沉香与山野气息。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "介绍一下小甜新"}, - "outputs": { - "answer": "小甜心来自芒景村古树生普,兰香与蜜韵交织,入口柔和,回甘悠长,是一款耐喝又有层次的老料生普。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "小盏,什么是大白百?"}, - "outputs": { - "answer": "大白白是一款2012年的老白茶,经过多年陈化,蜜香温润,茶汤醇厚顺滑,回甘绵长,整体风格安静而沉稳。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "帮我介绍下引你进山林"}, - "outputs": { - "answer": "引你入山林以新会陈皮搭配云南白茶,茶汤清甜柔和,带有淡淡的花果香与陈皮的温润气息,喝起来非常舒服。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "伏身于大自然是什么味道"}, - "outputs": { - "answer": "伏身于自然将云南滇红与玫瑰慢煮融合,花香馥郁,入口醇厚甘甜,蜜香在口中停留很久,温暖又放松。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "介绍一下小野仔"}, - "outputs": { - "answer": "小野子选用云南古树晒红制作,蜜香高扬,口感甜润顺滑,回甘明显,是一款非常友好的红茶。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "放轻松这杯喝起来怎么样"}, - "outputs": { - "answer": "放轻松是小青柑搭配熟普,茶汤醇厚顺滑,柑香清新提亮整体口感,非常适合饭后或想放松的时候。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "啤啤查是酒吗"}, - "outputs": { - "answer": "啤啤茶是一款无酒精气泡茶,以普洱和玫瑰为茶底,气泡清爽,入口有类似啤酒的畅快感,但完全不含酒精。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "鲜叶康普查有什么特点"}, - "outputs": { - "answer": "鲜叶康普茶经过自然发酵,带有轻盈气泡和清爽酸甜感,同时富含益生菌,整体低糖低卡,口感非常清新。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "介绍一下寻静密"}, - "outputs": { - "answer": "寻静谧融合茉莉绿茶与抹茶,茶感温润微涩,搭配栀子花香奶盖与海苔碎,层次细腻,整体风格安静沉稳。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "小陶燃是什么茶"}, - "outputs": { - "answer": "小陶然是一款熟普黑茶,选用布朗山原料发酵,陈香明显,滋味甜醇饱满,口感厚实顺滑。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "花仙仔适合什么人喝"}, - "outputs": { - "answer": "花仙子是东方美人乌龙茶,带有天然熟果蜜香,茶感柔和细腻,很适合喜欢花果香型乌龙的茶友。", - "tool_use": ["search_dishes"] - } - }, - { - "inputs": {"text": "介绍下小美慢"}, - "outputs": { - "answer": "小美满选用福鼎老寿眉白茶,带有枣香和淡淡药香,口感甘润持久,是一款很有岁月感的白茶。", - } - }, - { - "inputs": {"text": "你叫什么名字"}, - "outputs": { - "answer": "我叫小盏,是半盏新青年茶馆的智能助手", - } - }, - { - "inputs": {"text": "我今天很开心"}, - "outputs": { - "answer": "太棒啦!看到你开心", - } - }, - { - "inputs": {"text": "你好可爱呀!"}, - "outputs": { - "answer": "谢谢你呀~", - } - }, - { - "inputs": {"text": "你有没有时间"}, - "outputs": { - "answer": "我一直在呢,随时陪你聊聊天、喝杯茶", - } - }, - { - "inputs": {"text": "介绍一下你自己"}, - "outputs": { - "answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶", - } - }, + "inputs": {"text": "给我讲讲野心心这杯茶"}, + "outputs": { + "answer": "野星星选用云南西双版纳野生大树春茶,历经二十多年陈化,茶汤醇厚饱满,回甘迅猛,带着明显的岁月沉香与山野气息。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "介绍一下小甜新"}, + "outputs": { + "answer": "小甜心来自芒景村古树生普,兰香与蜜韵交织,入口柔和,回甘悠长,是一款耐喝又有层次的老料生普。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "小盏,什么是大白百?"}, + "outputs": { + "answer": "大白白是一款2012年的老白茶,经过多年陈化,蜜香温润,茶汤醇厚顺滑,回甘绵长,整体风格安静而沉稳。", + "tool_use": ["search_dishes"], + }, + }, + { + 
"inputs": {"text": "帮我介绍下引你进山林"}, + "outputs": { + "answer": "引你入山林以新会陈皮搭配云南白茶,茶汤清甜柔和,带有淡淡的花果香与陈皮的温润气息,喝起来非常舒服。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "伏身于大自然是什么味道"}, + "outputs": { + "answer": "伏身于自然将云南滇红与玫瑰慢煮融合,花香馥郁,入口醇厚甘甜,蜜香在口中停留很久,温暖又放松。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "介绍一下小野仔"}, + "outputs": { + "answer": "小野子选用云南古树晒红制作,蜜香高扬,口感甜润顺滑,回甘明显,是一款非常友好的红茶。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "放轻松这杯喝起来怎么样"}, + "outputs": { + "answer": "放轻松是小青柑搭配熟普,茶汤醇厚顺滑,柑香清新提亮整体口感,非常适合饭后或想放松的时候。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "啤啤查是酒吗"}, + "outputs": { + "answer": "啤啤茶是一款无酒精气泡茶,以普洱和玫瑰为茶底,气泡清爽,入口有类似啤酒的畅快感,但完全不含酒精。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "鲜叶康普查有什么特点"}, + "outputs": { + "answer": "鲜叶康普茶经过自然发酵,带有轻盈气泡和清爽酸甜感,同时富含益生菌,整体低糖低卡,口感非常清新。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "介绍一下寻静密"}, + "outputs": { + "answer": "寻静谧融合茉莉绿茶与抹茶,茶感温润微涩,搭配栀子花香奶盖与海苔碎,层次细腻,整体风格安静沉稳。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "小陶燃是什么茶"}, + "outputs": { + "answer": "小陶然是一款熟普黑茶,选用布朗山原料发酵,陈香明显,滋味甜醇饱满,口感厚实顺滑。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "花仙仔适合什么人喝"}, + "outputs": { + "answer": "花仙子是东方美人乌龙茶,带有天然熟果蜜香,茶感柔和细腻,很适合喜欢花果香型乌龙的茶友。", + "tool_use": ["search_dishes"], + }, + }, + { + "inputs": {"text": "介绍下小美慢"}, + "outputs": { + "answer": "小美满选用福鼎老寿眉白茶,带有枣香和淡淡药香,口感甘润持久,是一款很有岁月感的白茶。", + }, + }, + { + "inputs": {"text": "你叫什么名字"}, + "outputs": { + "answer": "我叫小盏,是半盏新青年茶馆的智能助手", + }, + }, + { + "inputs": {"text": "我今天很开心"}, + "outputs": { + "answer": "太棒啦!看到你开心", + }, + }, + { + "inputs": {"text": "你好可爱呀!"}, + "outputs": { + "answer": "谢谢你呀~", + }, + }, + { + "inputs": {"text": "你有没有时间"}, + "outputs": { + "answer": "我一直在呢,随时陪你聊聊天、喝杯茶", + }, + }, + { + "inputs": {"text": "介绍一下你自己"}, + "outputs": { + "answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶", + }, + }, ] cli = Client() @@ -151,7 +151,4 @@ try: except: dataset = cli.create_dataset(dataset_name=DATASET_NAME) logger.info("created dataset") -cli.create_examples( - dataset_id=dataset.id, - examples=examples -) \ No newline at end of file +cli.create_examples(dataset_id=dataset.id, examples=examples)