Compare commits
183 Commits
d4b4ef3690
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 737a80aa39 | |||
| bdd4ddec9e | |||
| 2ee55d25cc | |||
| bf9ce709e2 | |||
| 59331d2435 | |||
| a9488a655a | |||
| eeadd4c825 | |||
| 42d8b8e8e1 | |||
| da6cc1e18b | |||
| 691ef1c77d | |||
| 9f9813f01d | |||
| f9f3651c7e | |||
| 87407e1656 | |||
| 2c226e2df2 | |||
| 8b2a506177 | |||
| a26cda2f04 | |||
| 5e2a86e3be | |||
| 2c7d5ea589 | |||
| 36e3b40936 | |||
| 86f6e4f81c | |||
| 60f3029e54 | |||
| fe7ff9a516 | |||
| 33faedc1b1 | |||
| c9b1c5cb32 | |||
| 8f99d47af9 | |||
| 1d36f196ca | |||
| 9f40ef93b7 | |||
| fda3e86a71 | |||
| d5303ad201 | |||
| 245be63f07 | |||
| 068df7715e | |||
| 95498180a2 | |||
| 7b6ba79417 | |||
| a228325d74 | |||
| e59ae66b85 | |||
| e524be15e5 | |||
| 2d90a4caac | |||
| 4e8c2b4f5d | |||
| 0ddeb91755 | |||
| cb3b98db10 | |||
| 487562042f | |||
| ac46518cf5 | |||
| c03fb17a5c | |||
| e049e4e104 | |||
| 8d0cd7861b | |||
| 7920a1bab3 | |||
| 30e93ed0a7 | |||
| 0087ac41e6 | |||
| 5358b46113 | |||
| 3306da038d | |||
| ac99dfd56b | |||
| 3932d695bf | |||
| 9ed141ba42 | |||
| 9b3db40b94 | |||
| da17f2b319 | |||
| 4b6e97d8fb | |||
| 112af37151 | |||
| 3cd46030ad | |||
| dd842fca42 | |||
| fc9f0f929d | |||
| 28d99f4b8d | |||
| dac067b6fe | |||
| e90f0afabe | |||
| 0676a68c9e | |||
| f185b70d3f | |||
| 07149e426e | |||
| 3fc3d7288c | |||
| eb7e85e4e6 | |||
| ddfda10700 | |||
| f8364bea68 | |||
| 01b0975abd | |||
| 7e23d5c056 | |||
| 3b730798f8 | |||
| 2781172724 | |||
| 26fba706f2 | |||
| ae93ef37b6 | |||
| c1b782c6b4 | |||
| ab3285a4cf | |||
| 0484343021 | |||
| b87fded473 | |||
| 8db22abf3b | |||
| f6d86f24bb | |||
| c1afebd7ba | |||
| 080631af31 | |||
| 38b0d5df15 | |||
| f7937c3744 | |||
| 867acaf717 | |||
| a2890148f9 | |||
| 55b37cc611 | |||
| c85598418d | |||
| ea605e19aa | |||
| 866edc319f | |||
| 8c6dd3344f | |||
| f58d80ac1b | |||
| 9e3fecd4bd | |||
| 9b128ae41b | |||
| 61931cad58 | |||
| b070d08130 | |||
| fe87bb55a1 | |||
| d455933e9c | |||
| f6547ca6fe | |||
| 9c0744e1bc | |||
| 2f40f1c526 | |||
| cf1cae51f7 | |||
| 91685d5bf7 | |||
| 6ed33f3185 | |||
| 9284e8f928 | |||
| bb6d98c9f4 | |||
| 9d1eeaeec5 | |||
| 6c26f92ff1 | |||
| e3703d962d | |||
| 9470b7e342 | |||
| 94b24682e2 | |||
| 2ce0075dcd | |||
| 431aea410b | |||
| 46978895d1 | |||
| 501f9954ce | |||
| 021b4d6ffb | |||
| 16916e183b | |||
| 6097e9c24a | |||
| 11a7809fc0 | |||
| e85d0e7771 | |||
| 26fcdacbe8 | |||
| 6b0976f263 | |||
| 5f7cd64c11 | |||
| 4a713190b7 | |||
| bda75d2558 | |||
| 728d5934d7 | |||
| 4974ca936c | |||
| bc208209c7 | |||
| afb493adf4 | |||
| cc2e9cf90c | |||
| 686c1d6a1f | |||
| 1fcd5b4c61 | |||
| 62a00b4a5b | |||
| 7294e07df7 | |||
| 6425275d4b | |||
| 5742a08e98 | |||
| af16b87b0e | |||
| 6b0e50c532 | |||
| 65a1705280 | |||
| c4fdfd23c4 | |||
| c8847b0dbb | |||
| c4ad6433cb | |||
| bd4dfaad2a | |||
| c3d748c08f | |||
| 8558e60ee6 | |||
| abb78ad70e | |||
| 262d7dd51b | |||
| c64df2f48a | |||
| ec6b6e4545 | |||
| d7085676bc | |||
| 7be4aa1283 | |||
| a832178810 | |||
| b6f26c1e1b | |||
| 10562a98af | |||
| 851284b0a3 | |||
| d9a654170c | |||
| 49405782af | |||
| 18d75e754b | |||
| b1e6b0f7fc | |||
| fbd214bea3 | |||
| 9e156943bb | |||
| cec02b4e98 | |||
| cf2aea2d26 | |||
| 2523703df0 | |||
| cab0a0a42c | |||
| c7db276df5 | |||
| 9d7d81c0ac | |||
| 3d072cab07 | |||
| 5ed3f80971 | |||
| 56124069e1 | |||
| c87c883313 | |||
| 05bcf884c5 | |||
| 7cb40fca0e | |||
| 9ab9f0c36e | |||
| 26ca06d50d | |||
| 9363bd3442 | |||
| 1972c182d8 | |||
| 156186bfae | |||
| c2cc2628dd | |||
| 43dad177ab | |||
| 1f690914fb |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -2,6 +2,7 @@
|
|||||||
.vscode/
|
.vscode/
|
||||||
logs/
|
logs/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
|
workspace/
|
||||||
|
|
||||||
*.pyc
|
*.pyc
|
||||||
*.zip
|
*.zip
|
||||||
@@ -11,3 +12,4 @@ django.log
|
|||||||
|
|
||||||
frontend/node_modules/
|
frontend/node_modules/
|
||||||
frontend/dist/
|
frontend/dist/
|
||||||
|
frontend/.vite
|
||||||
37
README.md
37
README.md
@@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
这是一个基于FastAPI的聊天API服务,使用OpenAI格式的请求来调用pipeline.invoke方法进行聊天。
|
这是一个基于FastAPI的聊天API服务,使用OpenAI格式的请求来调用pipeline.invoke方法进行聊天。
|
||||||
|
|
||||||
|
## Docker Installation
|
||||||
|
|
||||||
|
For production deployment using Docker, see the [Installation Guide](README_INSTALL.md).
|
||||||
|
|
||||||
## 安装依赖
|
## 安装依赖
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -119,13 +123,6 @@ everything in scripts:
|
|||||||
## Registering MCP service
|
## Registering MCP service
|
||||||
put the links in `configs/mcp_config.json`
|
put the links in `configs/mcp_config.json`
|
||||||
|
|
||||||
|
|
||||||
## Graph structure
|
|
||||||
Graph structure:
|
|
||||||
|
|
||||||

|
|
||||||
We chose this structure to overcome a limitation in xiaozhi. Specifically, both normal chatting and tool-use prompts are delegated to one model. That leads to degradation in the quality of both the generated conversation and the tool use. By splitting the work across two models, we effectively increase the prompt size limit while preserving model quality.
|
|
||||||
|
|
||||||
## Modifying LLM prompts
|
## Modifying LLM prompts
|
||||||
Refer to model above when modifying the prompts.
|
Refer to model above when modifying the prompts.
|
||||||
they are in `configs/route_sys_prompts`
|
they are in `configs/route_sys_prompts`
|
||||||
@@ -147,15 +144,35 @@ npm install
|
|||||||
|
|
||||||
### Start the `front_apis` server
|
### Start the `front_apis` server
|
||||||
|
|
||||||
The frontend talks to the `front_apis` FastAPI service, which by default listens on `http://127.0.0.1:8001`.
|
The frontend talks to the `front_apis` FastAPI service, which by default listens on `http://127.0.0.1:8500`.
|
||||||
|
|
||||||
From the project root:
|
From the project root:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
uvicorn fastapi_server.front_apis:app --reload --host 0.0.0.0 --port 8001
|
uvicorn fastapi_server.front_apis:app --reload --host 0.0.0.0 --port 8500
|
||||||
```
|
```
|
||||||
|
|
||||||
You can change the URL by setting `VITE_FRONT_API_BASE_URL` in `frontend/.env` (defaults to `http://127.0.0.1:8001`).
|
Or run directly:
|
||||||
|
```bash
|
||||||
|
python fastapi_server/front_apis.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### Backend run modes
|
||||||
|
|
||||||
|
Run whichever backend mode you need from the project root:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# admin/control plane only (/v1/... frontend APIs)
|
||||||
|
uvicorn fastapi_server.front_apis:app --reload --host 0.0.0.0 --port 8500
|
||||||
|
|
||||||
|
# DashScope chat runtime only (/apps/... and /v1/apps/... APIs)
|
||||||
|
uvicorn fastapi_server.server_dashscope:app --reload --host 0.0.0.0 --port 8588
|
||||||
|
|
||||||
|
# combined mode: one process serves both front_apis + DashScope endpoints
|
||||||
|
uvicorn fastapi_server.combined:app --reload --host 0.0.0.0 --port 8500
|
||||||
|
```
|
||||||
|
|
||||||
|
You can change the URL by setting `VITE_FRONT_API_BASE_URL` in `frontend/.env` (defaults to `/`, i.e. same-origin).
|
||||||
|
|
||||||
### Start the development server
|
### Start the development server
|
||||||
|
|
||||||
|
|||||||
267
README_INSTALL.md
Normal file
267
README_INSTALL.md
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
# Installation Guide
|
||||||
|
|
||||||
|
This guide explains how to install and run the LangChain Agent application using Docker.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- Docker (version 20.10 or later)
|
||||||
|
- Docker Compose v2.0 or later (the examples in this guide use the `docker compose` plugin syntax)
|
||||||
|
|
||||||
|
## Quick Start
|
||||||
|
|
||||||
|
1. **Run the installation script:**
|
||||||
|
```bash
|
||||||
|
./scripts/shell_scripts/install.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
This script will:
|
||||||
|
- Check for required tools (Docker, docker-compose)
|
||||||
|
- Create a `.env` file with default configuration
|
||||||
|
- Build Docker images (or use pre-loaded images)
|
||||||
|
- Start all services (PostgreSQL, Backend API, Nginx)
|
||||||
|
|
||||||
|
2. **Access the application:**
|
||||||
|
- Frontend: http://localhost (or http://localhost:80)
|
||||||
|
- Backend API: http://localhost:8500
|
||||||
|
- Database: localhost:5432
|
||||||
|
|
||||||
|
## Installation for China / Offline Use
|
||||||
|
|
||||||
|
If Docker Hub is slow or inaccessible in your region:
|
||||||
|
|
||||||
|
### Option 1: Use Chinese Docker Mirrors
|
||||||
|
|
||||||
|
Configure Docker to use Chinese registry mirrors:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo tee /etc/docker/daemon.json <<EOF
|
||||||
|
{
|
||||||
|
"registry-mirrors": [
|
||||||
|
"https://registry.docker-cn.com",
|
||||||
|
"https://mirror.ccsogou.com",
|
||||||
|
"https://docker.1ms.run"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
EOF
|
||||||
|
sudo systemctl daemon-reload
|
||||||
|
sudo systemctl restart docker
|
||||||
|
```
|
||||||
|
|
||||||
|
Then run `./scripts/shell_scripts/install.sh`
|
||||||
|
|
||||||
|
### Configuring `.env` and `frontend/.env`
|
||||||
|
```bash
|
||||||
|
ALI_API_KEY="API_KEY_FOR_ALI_QWEN"
|
||||||
|
ALI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
|
POSTGRES_ROOT_PASSWORD="ROOT_PASSWORD_FOR_DB_IN_CONN_STR - required for installation"
|
||||||
|
POSTGRES_PASSWORD="USER_PASSWORD_FOR_DB_CONN_STR - required for installation" # must match the password used in CONN_STR
|
||||||
|
CONN_STR="CONNECTION_STRING_TO_DATABASE" # DOCKER PASSWORD
|
||||||
|
FAST_AUTH_KEYS="API_KEY_FOR_OTHER_APPLICATIONS_TO_USE_BUILT_PIPELINE"
|
||||||
|
DAYTONA_API_KEY="DAYTONA_CONFIG - NOT REQUIRED"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Option 2: Pre-load Docker Images Offline
|
||||||
|
|
||||||
|
1. On a machine with good Docker Hub access, run:
|
||||||
|
```bash
|
||||||
|
./scripts/shell_scripts/download_images.sh
|
||||||
|
```
|
||||||
|
This creates `images.tar` with all required images.
|
||||||
|
|
||||||
|
2. Transfer `images.tar` to your target machine.
|
||||||
|
|
||||||
|
3. Load the images:
|
||||||
|
```bash
|
||||||
|
docker load < images.tar
|
||||||
|
```
|
||||||
|
|
||||||
|
4. Run the install script:
|
||||||
|
```bash
|
||||||
|
./scripts/shell_scripts/install.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
## Manual Installation
|
||||||
|
|
||||||
|
If you prefer to set up manually:
|
||||||
|
|
||||||
|
1. **Create environment file:**
|
||||||
|
```bash
|
||||||
|
cp .env.example .env # Edit as needed
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Build and start services:**
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker compose -f docker-compose.prod.yml up -d --build
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **Check service status:**
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker compose -f docker-compose.prod.yml ps
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Edit the `.env` file to customize:
|
||||||
|
|
||||||
|
- `POSTGRES_DB`: Database name (default: `ai_conversations`)
|
||||||
|
- `POSTGRES_USER`: Database user (default: `myapp_user`)
|
||||||
|
- `POSTGRES_PASSWORD`: Database password (default: `secure_password_123`)
|
||||||
|
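- `POSTGRES_PORT`: Host port mapped to PostgreSQL (documented default: `5432`; note that `docker/docker-compose.prod.yml` falls back to `5434` when the variable is unset)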
|
||||||
|
- `BACKEND_PORT`: Backend API port (default: `8500`)
|
||||||
|
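- `FRONTEND_PORT`: Host port mapped to the Nginx frontend server (documented default: `80`; note that `docker/docker-compose.prod.yml` falls back to `8080` when the variable is unset)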
|
||||||
|
|
||||||
|
## Database Initialization
|
||||||
|
|
||||||
|
The database is automatically initialized when the PostgreSQL container starts for the first time (i.e. when its data volume is empty). The following initialization scripts are executed in order:
|
||||||
|
|
||||||
|
1. `scripts/init_database/00_init_user.sh` - Creates database user and database
|
||||||
|
2. `scripts/init_database/create_conv_store.sql` - Creates conversation storage tables
|
||||||
|
3. `scripts/init_database/create_prompt_config.sql` - Creates prompt configuration tables
|
||||||
|
|
||||||
|
## Service Management
|
||||||
|
|
||||||
|
All commands run from the `docker/` directory:
|
||||||
|
|
||||||
|
### View logs:
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker compose -f docker-compose.prod.yml logs -f
|
||||||
|
|
||||||
|
# Specific service
|
||||||
|
docker compose -f docker-compose.prod.yml logs -f backend
|
||||||
|
docker compose -f docker-compose.prod.yml logs -f postgres
|
||||||
|
docker compose -f docker-compose.prod.yml logs -f nginx
|
||||||
|
```
|
||||||
|
|
||||||
|
### Stop services:
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker compose -f docker-compose.prod.yml down
|
||||||
|
```
|
||||||
|
|
||||||
|
### Restart services:
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker compose -f docker-compose.prod.yml restart
|
||||||
|
```
|
||||||
|
|
||||||
|
### Rebuild after code changes:
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker compose -f docker-compose.prod.yml up -d --build
|
||||||
|
```
|
||||||
|
|
||||||
|
### Reset database (delete all data):
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker compose -f docker-compose.prod.yml down -v
|
||||||
|
docker compose -f docker-compose.prod.yml up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
The application consists of three main services:
|
||||||
|
|
||||||
|
1. **PostgreSQL** (`postgres`): Database server
|
||||||
|
- Stores conversations and prompt configurations
|
||||||
|
- Automatically initializes schema on first run
|
||||||
|
|
||||||
|
2. **Backend** (`backend`): FastAPI application
|
||||||
|
- Serves API endpoints at port 8500
|
||||||
|
- Handles agent management and chat endpoints
|
||||||
|
- Connects to PostgreSQL database
|
||||||
|
|
||||||
|
3. **Nginx** (`nginx`): Web server
|
||||||
|
- Serves the React frontend (port 80)
|
||||||
|
- Proxies API requests to the backend
|
||||||
|
- Handles static file serving
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
langchain-agent/
|
||||||
|
├── docker/
|
||||||
|
│ ├── docker-compose.prod.yml # Production compose file
|
||||||
|
│ └── Dockerfile.prod # Backend Docker image
|
||||||
|
├── scripts/
|
||||||
|
│ ├── shell_scripts/
|
||||||
|
│ │ ├── install.sh # Main installation script
|
||||||
|
│ │ └── download_images.sh # For offline image download
|
||||||
|
│ └── init_database/ # Database initialization scripts
|
||||||
|
├── frontend/ # React frontend
|
||||||
|
├── configs/ # Pipeline configurations
|
||||||
|
├── nginx.conf # Nginx configuration
|
||||||
|
└── .env # Environment variables
|
||||||
|
```
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### Database connection issues
|
||||||
|
|
||||||
|
If the backend can't connect to the database:
|
||||||
|
|
||||||
|
1. Check that PostgreSQL is running:
|
||||||
|
```bash
|
||||||
|
docker compose -f docker-compose.prod.yml ps postgres
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Verify the connection string in `.env` matches the database configuration
|
||||||
|
|
||||||
|
3. Check backend logs:
|
||||||
|
```bash
|
||||||
|
docker compose -f docker-compose.prod.yml logs backend
|
||||||
|
```
|
||||||
|
|
||||||
|
### Frontend not loading / NetworkError
|
||||||
|
|
||||||
|
1. Check nginx logs:
|
||||||
|
```bash
|
||||||
|
docker compose -f docker-compose.prod.yml logs nginx
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Ensure frontend is built with correct API base URL. The `frontend/.env` file should contain:
|
||||||
|
```
|
||||||
|
VITE_FRONT_API_BASE_URL=/
|
||||||
|
```
|
||||||
|
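Then rebuild from the `docker/` directory: `docker compose -f docker-compose.prod.yml build frontend backend` (the `frontend` service builds the bundle that Nginx serves; the backend image also embeds its own copy)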
|
||||||
|
|
||||||
|
### Port conflicts
|
||||||
|
|
||||||
|
If ports are already in use, update the port mappings in `.env`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Example: use port 5433 for PostgreSQL
|
||||||
|
POSTGRES_PORT=5433
|
||||||
|
```
|
||||||
|
|
||||||
|
## Development
|
||||||
|
|
||||||
|
For development, you may want to run services separately:
|
||||||
|
|
||||||
|
1. Start only PostgreSQL:
|
||||||
|
```bash
|
||||||
|
cd docker
|
||||||
|
docker compose -f docker-compose.prod.yml up -d postgres
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Run backend locally:
|
||||||
|
```bash
|
||||||
|
export CONN_STR="postgresql://myapp_user:secure_password_123@localhost:5432/ai_conversations"
|
||||||
|
python -m uvicorn lang_agent.fastapi_server.combined:app --reload --host 0.0.0.0 --port 8500
|
||||||
|
```
|
||||||
|
|
||||||
|
3. Run frontend locally:
|
||||||
|
```bash
|
||||||
|
cd frontend
|
||||||
|
npm install
|
||||||
|
npm run dev
|
||||||
|
```
|
||||||
|
|
||||||
|
Note: For local frontend development, create a `.env` file in `frontend/` with:
|
||||||
|
```
|
||||||
|
VITE_FRONT_API_BASE_URL=http://localhost:8500
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
31
configs/pipeline_registry.json
Normal file
31
configs/pipeline_registry.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"pipelines": {
|
||||||
|
"xiaozhan": {
|
||||||
|
"enabled": true,
|
||||||
|
"config_file": "configs/pipelines/xiaozhan.yaml",
|
||||||
|
"graph_id": "routing",
|
||||||
|
"overrides": {
|
||||||
|
"llm_name": "qwen-plus"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"blueberry": {
|
||||||
|
"enabled": true,
|
||||||
|
"config_file": "configs/pipelines/blueberry.yaml",
|
||||||
|
"graph_id": "react",
|
||||||
|
"overrides": {
|
||||||
|
"llm_name": "qwen-plus"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"api_keys": {
|
||||||
|
"sk-6c7091e6a95f404efb2ec30e8f51b897626d670375cdf822d78262f24ab12367": {
|
||||||
|
"example-key-1": {
|
||||||
|
"default_route_id": "default",
|
||||||
|
"allowed_route_ids": [
|
||||||
|
"xiaozhan",
|
||||||
|
"blueberry"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
7
configs/prompts/deepagent.txt
Normal file
7
configs/prompts/deepagent.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
You are a helpful bot enhanced with skills.
|
||||||
|
|
||||||
|
To use a skill, read its SKILL.md file using the read_file tool. Skills are NOT tools — they are instructions for using existing tools.
|
||||||
|
|
||||||
|
When using a skill, assume required tools (e.g., npx, curl) are available and execute the commands directly. If a command fails because a tool is missing, install the missing dependency using apt/brew and retry.
|
||||||
|
|
||||||
|
For shell commands (e.g., npx, curl), use the execute tool to run them.
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
{
|
|
||||||
"mcpServers": {
|
|
||||||
"remote-http-server": {
|
|
||||||
"type": "https",
|
|
||||||
"url": "https://xiaoliang.quant-speed.com/api/mcp/"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
19
docker/Dockerfile.frontend
Normal file
19
docker/Dockerfile.frontend
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
FROM node:20-alpine
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
RUN npm config set registry https://registry.npmmirror.com
|
||||||
|
|
||||||
|
# Build-time API base for Vite (must be set before npm run build).
|
||||||
|
ARG VITE_FRONT_API_BASE_URL=/
|
||||||
|
ENV VITE_FRONT_API_BASE_URL=${VITE_FRONT_API_BASE_URL}
|
||||||
|
|
||||||
|
COPY package*.json ./
|
||||||
|
RUN npm install
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN npm run build && \
|
||||||
|
mkdir -p /opt/frontend_dist && \
|
||||||
|
cp -r dist/. /opt/frontend_dist/
|
||||||
|
|
||||||
|
CMD ["sh", "-c", "rm -rf /app/dist/* && cp -r /opt/frontend_dist/. /app/dist && ls /app/dist"]
|
||||||
93
docker/Dockerfile.prod
Normal file
93
docker/Dockerfile.prod
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
# Multi-stage Dockerfile for production deployment
|
||||||
|
# Stage 1: Build frontend
|
||||||
|
FROM node:20-alpine AS frontend-builder
|
||||||
|
|
||||||
|
WORKDIR /app/frontend
|
||||||
|
|
||||||
|
# Copy frontend files
|
||||||
|
COPY frontend/package*.json ./
|
||||||
|
RUN npm ci
|
||||||
|
|
||||||
|
COPY frontend/ ./
|
||||||
|
RUN npm run build
|
||||||
|
|
||||||
|
# Stage 2: Python backend
|
||||||
|
FROM python:3.12-slim
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Install system dependencies
|
||||||
|
RUN set -eux; \
|
||||||
|
for source_file in /etc/apt/sources.list /etc/apt/sources.list.d/*.list /etc/apt/sources.list.d/*.sources; do \
|
||||||
|
if [ -f "$source_file" ]; then \
|
||||||
|
sed -i 's|deb.debian.org|mirrors.aliyun.com|g' "$source_file"; \
|
||||||
|
sed -i 's|security.debian.org|mirrors.aliyun.com|g' "$source_file"; \
|
||||||
|
fi; \
|
||||||
|
done; \
|
||||||
|
apt-get update; \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
postgresql-client \
|
||||||
|
curl; \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# Copy Python dependencies
|
||||||
|
COPY pyproject.toml ./
|
||||||
|
RUN pip install --no-cache-dir --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple && \
|
||||||
|
pip install --no-cache-dir -e . -i https://pypi.tuna.tsinghua.edu.cn/simple
|
||||||
|
|
||||||
|
# Copy application code
|
||||||
|
COPY lang_agent/ ./lang_agent/
|
||||||
|
COPY configs/ ./configs/
|
||||||
|
COPY scripts/ ./scripts/
|
||||||
|
COPY assets/ ./assets/
|
||||||
|
COPY static/ ./static/
|
||||||
|
|
||||||
|
# Copy built frontend from stage 1
|
||||||
|
COPY --from=frontend-builder /app/frontend/dist ./frontend/dist
|
||||||
|
|
||||||
|
# Set environment variables
|
||||||
|
ENV PYTHONPATH=/app
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
# Expose port
|
||||||
|
EXPOSE 8500
|
||||||
|
|
||||||
|
# Health check
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||||
|
CMD curl -f http://localhost:8500/health || exit 1
|
||||||
|
|
||||||
|
# Create entrypoint script to wait for DB
|
||||||
|
# Uses Python to check database connection (more reliable than psql)
|
||||||
|
RUN echo '#!/bin/bash\n\
|
||||||
|
set -e\n\
|
||||||
|
echo "Waiting for database to be ready..."\n\
|
||||||
|
python3 << EOF\n\
|
||||||
|
import sys\n\
|
||||||
|
import time\n\
|
||||||
|
import psycopg\n\
|
||||||
|
\n\
|
||||||
|
max_attempts = 30\n\
|
||||||
|
conn_str = "${CONN_STR}"\n\
|
||||||
|
\n\
|
||||||
|
for i in range(max_attempts):\n\
|
||||||
|
try:\n\
|
||||||
|
with psycopg.connect(conn_str, connect_timeout=2) as conn:\n\
|
||||||
|
with conn.cursor() as cur:\n\
|
||||||
|
cur.execute("SELECT 1")\n\
|
||||||
|
print("Database is ready!")\n\
|
||||||
|
sys.exit(0)\n\
|
||||||
|
except Exception as e:\n\
|
||||||
|
if i == max_attempts - 1:\n\
|
||||||
|
print(f"Warning: Database not ready after {max_attempts * 2} seconds, continuing anyway...")\n\
|
||||||
|
print(f"Error: {e}")\n\
|
||||||
|
sys.exit(0)\n\
|
||||||
|
print(f"Database is unavailable - sleeping (attempt {i+1}/{max_attempts})")\n\
|
||||||
|
time.sleep(2)\n\
|
||||||
|
EOF\n\
|
||||||
|
exec "$@"' > /entrypoint.sh && chmod +x /entrypoint.sh
|
||||||
|
|
||||||
|
ENTRYPOINT ["/entrypoint.sh"]
|
||||||
|
|
||||||
|
# Run the combined server
|
||||||
|
CMD ["python", "-m", "uvicorn", "lang_agent.fastapi_server.combined:app", "--host", "0.0.0.0", "--port", "8500"]
|
||||||
|
|
||||||
100
docker/docker-compose.prod.yml
Normal file
100
docker/docker-compose.prod.yml
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
# PostgreSQL database
|
||||||
|
postgres:
|
||||||
|
image: postgres:16-alpine
|
||||||
|
container_name: langchain-agent-db
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
environment:
|
||||||
|
POSTGRES_DB: postgres
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: ${POSTGRES_ROOT_PASSWORD:-postgres_root_password}
|
||||||
|
# These are used by init scripts to create the app database
|
||||||
|
APP_DB_NAME: ${POSTGRES_DB:-ai_conversations}
|
||||||
|
APP_DB_USER: ${POSTGRES_USER:-myapp_user}
|
||||||
|
APP_DB_PASSWORD: ${POSTGRES_PASSWORD:-secure_password_123}
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
- ../scripts/init_database:/docker-entrypoint-initdb.d
|
||||||
|
ports:
|
||||||
|
- "${POSTGRES_PORT:-5434}:5432"
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
|
interval: 10s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: no #unless-stopped
|
||||||
|
|
||||||
|
# Backend API server
|
||||||
|
backend:
|
||||||
|
build:
|
||||||
|
context: ..
|
||||||
|
dockerfile: docker/Dockerfile.prod
|
||||||
|
container_name: langchain-agent-backend
|
||||||
|
environment:
|
||||||
|
- PYTHONPATH=/app
|
||||||
|
- PYTHONUNBUFFERED=1
|
||||||
|
- CONN_STR=postgresql://${POSTGRES_USER:-myapp_user}:${POSTGRES_PASSWORD:-secure_password_123}@postgres:5432/${POSTGRES_DB:-ai_conversations}
|
||||||
|
- POSTGRES_USER=${POSTGRES_USER:-myapp_user}
|
||||||
|
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-secure_password_123}
|
||||||
|
- POSTGRES_DB=${POSTGRES_DB:-ai_conversations}
|
||||||
|
ports:
|
||||||
|
- "${BACKEND_PORT:-8500}:8500"
|
||||||
|
volumes:
|
||||||
|
- ../configs:/app/configs
|
||||||
|
- ../scripts:/app/scripts
|
||||||
|
- ../assets:/app/assets
|
||||||
|
- ../static:/app/static
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
depends_on:
|
||||||
|
postgres:
|
||||||
|
condition: service_healthy
|
||||||
|
restart: no #unless-stopped
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8500/health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 40s
|
||||||
|
|
||||||
|
# Frontend build service
|
||||||
|
frontend:
|
||||||
|
build:
|
||||||
|
context: ../frontend
|
||||||
|
dockerfile: ../docker/Dockerfile.frontend
|
||||||
|
args:
|
||||||
|
VITE_FRONT_API_BASE_URL: ${VITE_FRONT_API_BASE_URL:-/}
|
||||||
|
volumes:
|
||||||
|
- frontend_dist:/app/dist
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
|
||||||
|
# Nginx for serving frontend (optional - can also serve via FastAPI)
|
||||||
|
nginx:
|
||||||
|
image: nginx:alpine
|
||||||
|
container_name: langchain-agent-nginx
|
||||||
|
networks:
|
||||||
|
- app-network
|
||||||
|
ports:
|
||||||
|
- "${FRONTEND_PORT:-8080}:80"
|
||||||
|
volumes:
|
||||||
|
- ../nginx.conf:/etc/nginx/nginx.conf:ro
|
||||||
|
- frontend_dist:/usr/share/nginx/html:ro
|
||||||
|
depends_on:
|
||||||
|
frontend:
|
||||||
|
condition: service_completed_successfully
|
||||||
|
backend:
|
||||||
|
condition: service_started
|
||||||
|
restart: no #unless-stopped
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
frontend_dist:
|
||||||
|
|
||||||
|
networks:
|
||||||
|
app-network:
|
||||||
|
driver: bridge
|
||||||
|
|
||||||
@@ -1,333 +0,0 @@
|
|||||||
from typing import Dict, List, Optional
|
|
||||||
import os
|
|
||||||
import os.path as osp
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
# Ensure we can import from project root.
|
|
||||||
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
|
||||||
|
|
||||||
from lang_agent.config.db_config_manager import DBConfigManager
|
|
||||||
from lang_agent.front_api.build_server import GRAPH_BUILD_FNCS
|
|
||||||
|
|
||||||
class GraphConfigUpsertRequest(BaseModel):
|
|
||||||
graph_id: str
|
|
||||||
pipeline_id: str
|
|
||||||
prompt_set_id: Optional[str] = Field(default=None)
|
|
||||||
tool_keys: List[str] = Field(default_factory=list)
|
|
||||||
prompt_dict: Dict[str, str] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
class GraphConfigUpsertResponse(BaseModel):
|
|
||||||
graph_id: str
|
|
||||||
pipeline_id: str
|
|
||||||
prompt_set_id: str
|
|
||||||
tool_keys: List[str]
|
|
||||||
prompt_keys: List[str]
|
|
||||||
|
|
||||||
class GraphConfigReadResponse(BaseModel):
|
|
||||||
graph_id: Optional[str] = Field(default=None)
|
|
||||||
pipeline_id: str
|
|
||||||
prompt_set_id: str
|
|
||||||
tool_keys: List[str]
|
|
||||||
prompt_dict: Dict[str, str]
|
|
||||||
|
|
||||||
class GraphConfigListItem(BaseModel):
|
|
||||||
graph_id: Optional[str] = Field(default=None)
|
|
||||||
pipeline_id: str
|
|
||||||
prompt_set_id: str
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
is_active: bool
|
|
||||||
tool_keys: List[str]
|
|
||||||
created_at: Optional[str] = Field(default=None)
|
|
||||||
updated_at: Optional[str] = Field(default=None)
|
|
||||||
|
|
||||||
class GraphConfigListResponse(BaseModel):
|
|
||||||
items: List[GraphConfigListItem]
|
|
||||||
count: int
|
|
||||||
|
|
||||||
class PipelineCreateRequest(BaseModel):
|
|
||||||
graph_id: str = Field(
|
|
||||||
description="Graph key from GRAPH_BUILD_FNCS, e.g. routing or react"
|
|
||||||
)
|
|
||||||
pipeline_id: str
|
|
||||||
prompt_set_id: str
|
|
||||||
tool_keys: List[str] = Field(default_factory=list)
|
|
||||||
port: int
|
|
||||||
entry_point: str = Field(default="fastapi_server/server_dashscope.py")
|
|
||||||
llm_name: str = Field(default="qwen-plus")
|
|
||||||
|
|
||||||
class PipelineCreateResponse(BaseModel):
|
|
||||||
run_id: str
|
|
||||||
pid: int
|
|
||||||
graph_id: str
|
|
||||||
pipeline_id: str
|
|
||||||
prompt_set_id: str
|
|
||||||
url: str
|
|
||||||
port: int
|
|
||||||
|
|
||||||
class PipelineRunInfo(BaseModel):
|
|
||||||
run_id: str
|
|
||||||
pid: int
|
|
||||||
graph_id: str
|
|
||||||
pipeline_id: str
|
|
||||||
prompt_set_id: str
|
|
||||||
url: str
|
|
||||||
port: int
|
|
||||||
|
|
||||||
class PipelineListResponse(BaseModel):
|
|
||||||
items: List[PipelineRunInfo]
|
|
||||||
count: int
|
|
||||||
|
|
||||||
class PipelineStopResponse(BaseModel):
|
|
||||||
run_id: str
|
|
||||||
status: str
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
|
||||||
title="Front APIs",
|
|
||||||
description="Manage graph configs and launch graph pipelines.",
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
_db = DBConfigManager()
|
|
||||||
_running_pipelines: Dict[str, Dict[str, object]] = {}
|
|
||||||
|
|
||||||
def _prune_stopped_pipelines() -> None:
|
|
||||||
stale_ids: List[str] = []
|
|
||||||
for run_id, info in _running_pipelines.items():
|
|
||||||
proc = info["proc"]
|
|
||||||
if proc.poll() is not None:
|
|
||||||
stale_ids.append(run_id)
|
|
||||||
for run_id in stale_ids:
|
|
||||||
_running_pipelines.pop(run_id, None)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health():
|
|
||||||
return {"status": "healthy"}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
|
|
||||||
async def root():
|
|
||||||
return {
|
|
||||||
"message": "Front APIs",
|
|
||||||
"endpoints": [
|
|
||||||
"/v1/graph-configs (POST)",
|
|
||||||
"/v1/graph-configs (GET)",
|
|
||||||
"/v1/graph-configs/default/{pipeline_id} (GET)",
|
|
||||||
"/v1/graphs/{graph_id}/default-config (GET)",
|
|
||||||
"/v1/graph-configs/{pipeline_id}/{prompt_set_id} (GET)",
|
|
||||||
"/v1/graph-configs/{pipeline_id}/{prompt_set_id} (DELETE)",
|
|
||||||
"/v1/pipelines/graphs (GET)",
|
|
||||||
"/v1/pipelines (POST)",
|
|
||||||
"/v1/pipelines (GET)",
|
|
||||||
"/v1/pipelines/{run_id} (DELETE)",
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/graph-configs", response_model=GraphConfigUpsertResponse)
|
|
||||||
async def upsert_graph_config(body: GraphConfigUpsertRequest):
|
|
||||||
try:
|
|
||||||
resolved_prompt_set_id = _db.set_config(
|
|
||||||
graph_id=body.graph_id,
|
|
||||||
pipeline_id=body.pipeline_id,
|
|
||||||
prompt_set_id=body.prompt_set_id,
|
|
||||||
tool_list=body.tool_keys,
|
|
||||||
prompt_dict=body.prompt_dict,
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
|
||||||
raise HTTPException(status_code=400, detail=str(e))
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e))
|
|
||||||
|
|
||||||
return GraphConfigUpsertResponse(
|
|
||||||
graph_id=body.graph_id,
|
|
||||||
pipeline_id=body.pipeline_id,
|
|
||||||
prompt_set_id=resolved_prompt_set_id,
|
|
||||||
tool_keys=body.tool_keys,
|
|
||||||
prompt_keys=list(body.prompt_dict.keys()),
|
|
||||||
)
|
|
||||||
|
|
||||||
@app.get("/v1/graph-configs", response_model=GraphConfigListResponse)
async def list_graph_configs(pipeline_id: Optional[str] = None, graph_id: Optional[str] = None):
    """List stored prompt sets, optionally filtered by pipeline and/or graph id.

    Raises:
        HTTPException: 500 when the store lookup fails.
    """
    try:
        records = _db.list_prompt_sets(pipeline_id=pipeline_id, graph_id=graph_id)
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    entries = []
    for record in records:
        entries.append(GraphConfigListItem(**record))
    return GraphConfigListResponse(items=entries, count=len(entries))
|
|
||||||
|
|
||||||
@app.get("/v1/graph-configs/default/{pipeline_id}", response_model=GraphConfigReadResponse)
async def get_default_graph_config(pipeline_id: str):
    """Return the active prompt set and tool keys for a pipeline.

    The configuration is read with ``prompt_set_id=None``; the prompt set
    listing is then consulted to find the row flagged ``is_active`` so its id
    and graph id can be reported.

    Raises:
        HTTPException: 400 when the store raises ``ValueError``; 404 when no
            configuration or no active prompt set exists; 500 for any other
            store failure.
    """
    try:
        prompt_dict, tool_keys = _db.get_config(pipeline_id=pipeline_id, prompt_set_id=None)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    not_found = HTTPException(
        status_code=404,
        detail=f"No active prompt set found for pipeline '{pipeline_id}'",
    )
    if not prompt_dict and not tool_keys:
        raise not_found

    # The listing is a second store round-trip; guard it like the other
    # handlers so a failure is reported with its detail instead of escaping
    # as an unhandled exception.
    try:
        rows = _db.list_prompt_sets(pipeline_id=pipeline_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    active = next((row for row in rows if row["is_active"]), None)
    if active is None:
        raise not_found

    return GraphConfigReadResponse(
        graph_id=active.get("graph_id"),
        pipeline_id=pipeline_id,
        prompt_set_id=active["prompt_set_id"],
        tool_keys=tool_keys,
        prompt_dict=prompt_dict,
    )
|
|
||||||
|
|
||||||
@app.get("/v1/graphs/{graph_id}/default-config", response_model=GraphConfigReadResponse)
async def get_graph_default_config_by_graph(graph_id: str):
    """Alias route: delegates to ``get_default_graph_config`` using ``graph_id`` as the pipeline id."""
    default_config = await get_default_graph_config(pipeline_id=graph_id)
    return default_config
|
|
||||||
|
|
||||||
@app.get("/v1/graph-configs/{pipeline_id}/{prompt_set_id}", response_model=GraphConfigReadResponse)
async def get_graph_config(pipeline_id: str, prompt_set_id: str):
    """Return one specific prompt set of a pipeline.

    Raises:
        HTTPException: 404 when the prompt set does not exist; 400 when the
            store raises ``ValueError``; 500 for any other exception.
    """
    try:
        meta = _db.get_prompt_set(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
        if meta is None:
            missing = f"prompt_set_id '{prompt_set_id}' not found for pipeline '{pipeline_id}'"
            raise HTTPException(status_code=404, detail=missing)
        prompt_dict, tool_keys = _db.get_config(
            pipeline_id=pipeline_id,
            prompt_set_id=prompt_set_id,
        )
    except HTTPException:
        # Let the 404 above pass through instead of being rewrapped as a 500.
        raise
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    return GraphConfigReadResponse(
        graph_id=meta.get("graph_id"),
        pipeline_id=pipeline_id,
        prompt_set_id=prompt_set_id,
        tool_keys=tool_keys,
        prompt_dict=prompt_dict,
    )
|
|
||||||
|
|
||||||
|
|
||||||
@app.delete("/v1/graph-configs/{pipeline_id}/{prompt_set_id}")
async def delete_graph_config(pipeline_id: str, prompt_set_id: str):
    """Remove a prompt set from the store and confirm the deletion.

    Raises:
        HTTPException: 400 when the store raises ``ValueError``; 500 for any
            other exception.
    """
    try:
        _db.remove_config(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    except Exception as err:
        raise HTTPException(status_code=500, detail=str(err))

    confirmation = {"status": "deleted"}
    confirmation["pipeline_id"] = pipeline_id
    confirmation["prompt_set_id"] = prompt_set_id
    return confirmation
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/v1/pipelines/graphs")
async def available_graphs():
    """Return the registered graph ids (keys of ``GRAPH_BUILD_FNCS``) in sorted order."""
    graph_ids = sorted(GRAPH_BUILD_FNCS.keys())
    return {"available_graphs": graph_ids}
|
|
||||||
|
|
||||||
@app.get("/v1/pipelines", response_model=PipelineListResponse)
async def list_running_pipelines():
    """List the pipelines currently tracked in ``_running_pipelines``.

    Stopped entries are pruned first so only live runs are reported.
    """
    _prune_stopped_pipelines()

    runs = []
    for run_id, entry in _running_pipelines.items():
        runs.append(
            PipelineRunInfo(
                run_id=run_id,
                pid=entry["proc"].pid,
                graph_id=entry["graph_id"],
                pipeline_id=entry["pipeline_id"],
                prompt_set_id=entry["prompt_set_id"],
                url=entry["url"],
                port=entry["port"],
            )
        )
    return PipelineListResponse(items=runs, count=len(runs))
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/pipelines", response_model=PipelineCreateResponse)
async def create_pipeline(body: PipelineCreateRequest):
    """Start a pipeline process for the requested graph and register it.

    The builder registered for ``body.graph_id`` is invoked; the process it
    returns is stored in ``_running_pipelines`` under a fresh UUID run id.

    Raises:
        HTTPException: 400 for an unknown ``graph_id``; 500 when the builder
            raises.
    """
    build_fn = GRAPH_BUILD_FNCS.get(body.graph_id)
    if build_fn is None:
        known_graphs = sorted(GRAPH_BUILD_FNCS.keys())
        raise HTTPException(
            status_code=400,
            detail=f"Unknown graph_id '{body.graph_id}'. Valid options: {known_graphs}",
        )

    try:
        proc, url = build_fn(
            pipeline_id=body.pipeline_id,
            prompt_set=body.prompt_set_id,
            tool_keys=body.tool_keys,
            port=str(body.port),
            entry_pnt=body.entry_point,
            llm_name=body.llm_name,
        )
    except Exception as err:
        raise HTTPException(status_code=500, detail=f"Failed to start pipeline: {err}")

    # Fields shared by the registry entry and the response.
    run_details = {
        "graph_id": body.graph_id,
        "pipeline_id": body.pipeline_id,
        "prompt_set_id": body.prompt_set_id,
        "url": url,
        "port": body.port,
    }

    run_id = str(uuid.uuid4())
    _running_pipelines[run_id] = {"proc": proc, **run_details}

    return PipelineCreateResponse(run_id=run_id, pid=proc.pid, **run_details)
|
|
||||||
|
|
||||||
@app.delete("/v1/pipelines/{run_id}", response_model=PipelineStopResponse)
async def stop_pipeline(run_id: str):
    """Stop a tracked pipeline process and drop it from the registry.

    The process is asked to terminate; if it has not exited after 5 seconds
    it is killed and waited on for up to 5 more seconds.

    Raises:
        HTTPException: 404 when ``run_id`` is not tracked.
    """
    # Local import keeps this block self-contained; asyncio is standard library.
    import asyncio

    info = _running_pipelines.pop(run_id, None)
    if info is None:
        raise HTTPException(status_code=404, detail=f"run_id '{run_id}' not found")

    proc = info["proc"]

    def _shutdown() -> None:
        """Terminate the process, escalating to kill after a grace period."""
        if proc.poll() is not None:
            return  # already exited
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait(timeout=5)

    # The waits can block for up to ~10 s; run them in a worker thread so the
    # event loop keeps serving other requests in the meantime.
    await asyncio.get_running_loop().run_in_executor(None, _shutdown)

    return PipelineStopResponse(run_id=run_id, status="stopped")
|
|
||||||
@@ -1,331 +0,0 @@
|
|||||||
from fastapi import FastAPI, HTTPException, Path, Request, Depends, Security
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.responses import StreamingResponse, JSONResponse
|
|
||||||
from fastapi.security import APIKeyHeader
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import List, Optional
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import json
|
|
||||||
import uvicorn
|
|
||||||
from loguru import logger
|
|
||||||
import tyro
|
|
||||||
|
|
||||||
# Ensure we can import from project root
|
|
||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
||||||
|
|
||||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
|
||||||
|
|
||||||
# Initialize Pipeline once
# Runs at import time: the PipelineConfig is produced by tyro.cli and then
# turned into the module-level `pipeline` object used by the handlers below.
pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with pipeline: \n{pipeline_config}")
pipeline:Pipeline = pipeline_config.setup()
|
|
||||||
|
|
||||||
# API Key Authentication
# Credentials are read from the "Authorization" request header.
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
# Accepted keys come from the comma-separated FAST_AUTH_KEYS environment
# variable; empty entries are dropped, so an unset variable yields an empty set.
# NOTE(review): entries are not whitespace-stripped ("a, b" keeps " b") — confirm intended.
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
    """Validate the Authorization header value and return the bare key.

    A leading "Bearer " prefix is removed when present. The key is checked
    against ``VALID_API_KEYS`` only when that set is non-empty.

    Raises:
        HTTPException: 401 when keys are configured and the key is not one of them.
    """
    bearer_prefix = "Bearer "
    if api_key.startswith(bearer_prefix):
        key = api_key[len(bearer_prefix):]
    else:
        key = api_key

    # With no configured keys every caller is accepted.
    if not VALID_API_KEYS or key in VALID_API_KEYS:
        return key
    raise HTTPException(status_code=401, detail="Invalid API key")
|
|
||||||
|
|
||||||
|
|
||||||
# One chat message: a role label plus its text content.
class DSMessage(BaseModel):
    role: str
    content: str
|
|
||||||
|
|
||||||
|
|
||||||
# Request model for an application call; only `messages` is required.
# NOTE(review): the handlers visible in this file parse the raw JSON body
# themselves rather than taking this model as a parameter — confirm it is still used.
class DSApplicationCallRequest(BaseModel):
    api_key: Optional[str] = Field(default=None)
    app_id: Optional[str] = Field(default=None)
    session_id: Optional[str] = Field(default=None)
    messages: List[DSMessage]
    stream: bool = Field(default=True)
    # Optional overrides for pipeline behavior
    # NOTE(review): the default thread id is the literal string "3" — confirm intended.
    thread_id: Optional[str] = Field(default="3")
|
|
||||||
|
|
||||||
|
|
||||||
# FastAPI application object that the route decorators below attach to.
app = FastAPI(title="DashScope-Compatible Application API",
              description="DashScope Application.call compatible endpoint backed by pipeline.chat")

# CORS is fully open: every origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive — confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
||||||
|
|
||||||
|
|
||||||
def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen-flash"):
    """Turn an iterable of text chunks into Server-Sent-Event frames.

    Each non-empty chunk yields one ``data: <json>`` frame carrying the text
    accumulated so far (not the delta) with ``is_end`` false. After the input
    is exhausted, one closing frame with the full text and ``is_end`` true is
    yielded.

    Args:
        chunk_generator: Iterable of string chunks; falsy chunks are skipped.
        response_id: Value reported as ``request_id`` in every frame.
        model: Model name reported in every frame.

    Yields:
        SSE-formatted strings, each terminated by a blank line.
    """
    started_at = int(time.time())
    text_so_far = ""

    def _frame(finished):
        # Key order is kept stable because the dict is serialised as-is.
        payload = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "output": {
                "text": text_so_far,
                "created": started_at,
                "model": model,
            },
            "is_end": finished,
        }
        return f"data: {json.dumps(payload)}\n\n"

    for piece in chunk_generator:
        if not piece:
            continue
        text_so_far += piece
        yield _frame(False)

    # Closing frame repeats the complete text and flags the end of the stream.
    yield _frame(True)
|
|
||||||
|
|
||||||
|
|
||||||
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str = "qwen-flash"):
    """Async counterpart: turn an async iterable of text chunks into SSE frames.

    Each non-empty chunk yields one ``data: <json>`` frame carrying the text
    accumulated so far (not the delta) with ``is_end`` false. After the input
    is exhausted, one closing frame with the full text and ``is_end`` true is
    yielded.

    Args:
        chunk_generator: Async iterable of string chunks; falsy chunks are skipped.
        response_id: Value reported as ``request_id`` in every frame.
        model: Model name reported in every frame.

    Yields:
        SSE-formatted strings, each terminated by a blank line.
    """
    started_at = int(time.time())
    text_so_far = ""

    def _frame(finished):
        # Key order is kept stable because the dict is serialised as-is.
        payload = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "output": {
                "text": text_so_far,
                "created": started_at,
                "model": model,
            },
            "is_end": finished,
        }
        return f"data: {json.dumps(payload)}\n\n"

    async for piece in chunk_generator:
        if not piece:
            continue
        text_so_far += piece
        yield _frame(False)

    # Closing frame repeats the complete text and flags the end of the stream.
    yield _frame(True)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
@app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
async def application_responses(
    request: Request,
    app_id: str = Path(...),
    session_id: str = Path(...),
    _: str = Depends(verify_api_key),
):
    """Answer a chat request for an app session, streamed or as one JSON reply.

    Messages are taken from the top-level ``messages`` field, then
    ``input.messages``, then ``input.prompt``. The most recent user message is
    sent to ``pipeline.achat``. The ``input`` and ``parameters`` sections are
    optional; a request that only carries top-level ``messages`` is valid.

    Raises:
        HTTPException: 400 when the body is not a JSON object or no messages
            are supplied; 500 for any other failure.
    """
    try:
        body = await request.json()
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="request body must be a JSON object")

        # "input" may be absent or null; treat anything that is not a dict as
        # empty so later lookups cannot raise KeyError/AttributeError.
        input_section = body.get("input")
        if not isinstance(input_section, dict):
            input_section = {}

        # Prefer path params
        req_app_id = app_id or body.get("app_id")
        req_session_id = session_id or input_section.get("session_id")

        # Normalize messages
        messages = body.get("messages")
        if messages is None:
            messages = input_section.get("messages")
        if messages is None:
            prompt = input_section.get("prompt")
            if isinstance(prompt, str):
                messages = [{"role": "user", "content": prompt}]

        if not messages:
            raise HTTPException(status_code=400, detail="messages is required")

        # Determine stream flag (top-level wins, then parameters.stream, default True)
        stream = body.get("stream")
        if stream is None:
            parameters = body.get("parameters")
            stream = parameters.get("stream", True) if isinstance(parameters, dict) else True

        # The conversation thread follows the session id carried in the body.
        thread_id = input_section.get("session_id")

        # Extract latest user message
        user_msg = None
        for m in reversed(messages):
            role = m.get("role") if isinstance(m, dict) else None
            content = m.get("content") if isinstance(m, dict) else None
            if role == "user" and content:
                user_msg = content
                break
        if user_msg is None:
            # No user-role message with content: fall back to the last entry.
            last = messages[-1]
            user_msg = last.get("content") if isinstance(last, dict) else str(last)

        response_id = f"appcmpl-{os.urandom(12).hex()}"

        if stream:
            # Use async streaming from pipeline
            chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
            return StreamingResponse(
                sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name),
                media_type="text/event-stream",
            )

        # Non-streaming: get full result using async
        result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
        if not isinstance(result_text, str):
            result_text = str(result_text)

        data = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "app_id": req_app_id,
            "session_id": req_session_id,
            "output": {
                "text": result_text,
                "created": int(time.time()),
                "model": pipeline_config.llm_name,
            },
            "is_end": True,
        }
        return JSONResponse(content=data)

    except HTTPException:
        # Deliberate 4xx responses raised above must not be rewrapped as 500.
        raise
    except Exception as e:
        logger.error(f"DashScope-compatible endpoint error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
||||||
|
|
||||||
|
|
||||||
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and without session in path
|
|
||||||
@app.post("/apps/{app_id}/completion")
@app.post("/v1/apps/{app_id}/completion")
@app.post("/api/apps/{app_id}/completion")
@app.post("/api/v1/apps/{app_id}/completion")
async def application_completion(
    request: Request,
    app_id: str = Path(...),
    _: str = Depends(verify_api_key),
):
    """Answer a completion request for an app, streamed or as one JSON reply.

    Messages are taken from the top-level ``messages`` field, then
    ``input.messages``, then ``input.prompt``. The most recent user message is
    sent to ``pipeline.achat``. The session id comes from ``input.session_id``
    when present. The ``input`` and ``parameters`` sections are optional; a
    request that only carries top-level ``messages`` is valid.

    Raises:
        HTTPException: 400 when the body is not a JSON object or no messages
            are supplied; 500 for any other failure.
    """
    try:
        body = await request.json()
        if not isinstance(body, dict):
            raise HTTPException(status_code=400, detail="request body must be a JSON object")

        # "input" may be absent or null; treat anything that is not a dict as
        # empty so later lookups cannot raise KeyError/AttributeError.
        input_section = body.get("input")
        if not isinstance(input_section, dict):
            input_section = {}

        req_session_id = input_section.get("session_id")

        # Normalize messages
        messages = body.get("messages")
        if messages is None:
            messages = input_section.get("messages")
        if messages is None:
            prompt = input_section.get("prompt")
            if isinstance(prompt, str):
                messages = [{"role": "user", "content": prompt}]

        if not messages:
            raise HTTPException(status_code=400, detail="messages is required")

        # Stream flag: top-level wins, then parameters.stream, default True.
        stream = body.get("stream")
        if stream is None:
            parameters = body.get("parameters")
            stream = parameters.get("stream", True) if isinstance(parameters, dict) else True

        # The conversation thread follows the session id carried in the body.
        thread_id = req_session_id

        # Pick the most recent user message that has content.
        user_msg = None
        for m in reversed(messages):
            role = m.get("role") if isinstance(m, dict) else None
            content = m.get("content") if isinstance(m, dict) else None
            if role == "user" and content:
                user_msg = content
                break
        if user_msg is None:
            # No user-role message with content: fall back to the last entry.
            last = messages[-1]
            user_msg = last.get("content") if isinstance(last, dict) else str(last)

        response_id = f"appcmpl-{os.urandom(12).hex()}"

        if stream:
            # Use async streaming from pipeline
            chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
            return StreamingResponse(
                sse_chunks_from_astream(chunk_generator, response_id=response_id, model=pipeline_config.llm_name),
                media_type="text/event-stream",
            )

        # Non-streaming: get full result using async
        result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
        if not isinstance(result_text, str):
            result_text = str(result_text)

        data = {
            "request_id": response_id,
            "code": 200,
            "message": "OK",
            "app_id": app_id,
            "session_id": req_session_id,
            "output": {
                "text": result_text,
                "created": int(time.time()),
                "model": pipeline_config.llm_name,
            },
            "is_end": True,
        }
        return JSONResponse(content=data)

    except HTTPException:
        # Deliberate 4xx responses raised above must not be rewrapped as 500.
        raise
    except Exception as e:
        logger.error(f"DashScope-compatible completion error: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/")
async def root():
    """Return a short service banner plus a static list of route paths."""
    advertised_paths = [
        "/v1/apps/{app_id}/sessions/{session_id}/responses",
        "/health",
    ]
    return {
        "message": "DashScope Application-compatible API",
        "endpoints": advertised_paths,
    }
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
async def health():
    """Liveness probe: always answers with a fixed "healthy" payload."""
    payload = {"status": "healthy"}
    return payload
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Serve the app by import string on all interfaces, on the configured
    # port, with auto-reload enabled.
    server_options = {
        "host": "0.0.0.0",
        "port": pipeline_config.port,
        "reload": True,
    }
    uvicorn.run("server_dashscope:app", **server_options)
|
|
||||||
|
|
||||||
|
|
||||||
BIN
frontend/assets/images/graph_arch/deepagent.png
Normal file
BIN
frontend/assets/images/graph_arch/deepagent.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 28 KiB |
1882
frontend/package-lock.json
generated
1882
frontend/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -6,18 +6,22 @@
|
|||||||
"scripts": {
|
"scripts": {
|
||||||
"dev": "vite",
|
"dev": "vite",
|
||||||
"build": "tsc -b && vite build",
|
"build": "tsc -b && vite build",
|
||||||
"preview": "vite preview"
|
"preview": "vite preview",
|
||||||
|
"test": "vitest run"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"react": "^18.3.1",
|
"react": "^18.3.1",
|
||||||
"react-dom": "^18.3.1"
|
"react-dom": "^18.3.1",
|
||||||
|
"react-markdown": "^10.1.0",
|
||||||
|
"remark-gfm": "^4.0.1"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"@types/react": "^18.3.20",
|
"@types/react": "^18.3.20",
|
||||||
"@types/react-dom": "^18.3.6",
|
"@types/react-dom": "^18.3.6",
|
||||||
"@vitejs/plugin-react": "^4.3.1",
|
"@vitejs/plugin-react": "^4.3.1",
|
||||||
"typescript": "^5.6.3",
|
"typescript": "^5.6.3",
|
||||||
"vite": "^5.4.10"
|
"vite": "^5.4.10",
|
||||||
|
"vitest": "^2.1.8"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
1515
frontend/src/App.tsx
1515
frontend/src/App.tsx
File diff suppressed because it is too large
Load Diff
79
frontend/src/activeConfigSelection.test.ts
Normal file
79
frontend/src/activeConfigSelection.test.ts
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import { describe, expect, it } from "vitest";
|
||||||
|
|
||||||
|
import { chooseActiveConfigItem, chooseDisplayItemsByPipeline } from "./activeConfigSelection";
|
||||||
|
import type { GraphConfigListItem } from "./types";
|
||||||
|
|
||||||
|
/** Build a GraphConfigListItem fixture: fixed defaults overridden by `patch`. */
const mk = (patch: Partial<GraphConfigListItem>): GraphConfigListItem => {
  const defaults: GraphConfigListItem = {
    graph_id: "routing",
    pipeline_id: "agent-a",
    prompt_set_id: "set-1",
    name: "default",
    description: "",
    is_active: false,
    tool_keys: [],
    api_key: "",
    created_at: null,
    updated_at: null,
  };
  return { ...defaults, ...patch };
};
|
||||||
|
|
||||||
|
// Selection rules for a single pipeline: the active item wins; otherwise the
// most recently updated one is chosen.
describe("chooseActiveConfigItem", () => {
  it("prefers active item over newer inactive items", () => {
    // The active item is deliberately the older of the two.
    const items = [
      mk({
        pipeline_id: "agent-a",
        prompt_set_id: "old-active",
        is_active: true,
        updated_at: "2025-01-01T00:00:00Z",
      }),
      mk({
        pipeline_id: "agent-a",
        prompt_set_id: "new-inactive",
        is_active: false,
        updated_at: "2025-03-01T00:00:00Z",
      }),
    ];
    const selected = chooseActiveConfigItem(items, "agent-a");
    expect(selected?.prompt_set_id).toBe("old-active");
  });

  it("falls back to latest updated_at when no active item exists", () => {
    // Neither item is active (fixture default), so recency decides.
    const items = [
      mk({
        pipeline_id: "agent-b",
        prompt_set_id: "set-1",
        updated_at: "2025-01-01T00:00:00Z",
      }),
      mk({
        pipeline_id: "agent-b",
        prompt_set_id: "set-2",
        updated_at: "2025-02-01T00:00:00Z",
      }),
    ];
    const selected = chooseActiveConfigItem(items, "agent-b");
    expect(selected?.prompt_set_id).toBe("set-2");
  });
});
|
||||||
|
|
||||||
|
// Grouping: one chosen item per pipeline, returned sorted by pipeline_id.
describe("chooseDisplayItemsByPipeline", () => {
  it("returns one selected item per pipeline_id", () => {
    // Input lists agent-b before agent-a to exercise the output ordering.
    const items = [
      mk({ pipeline_id: "agent-b", prompt_set_id: "set-1", updated_at: "2025-01-01T00:00:00Z" }),
      mk({
        pipeline_id: "agent-b",
        prompt_set_id: "set-2",
        is_active: true,
        updated_at: "2025-02-01T00:00:00Z",
      }),
      mk({
        pipeline_id: "agent-a",
        prompt_set_id: "set-3",
        updated_at: "2025-03-01T00:00:00Z",
      }),
    ];
    const selected = chooseDisplayItemsByPipeline(items);
    expect(selected.map((x) => x.pipeline_id)).toEqual(["agent-a", "agent-b"]);
    expect(selected.find((x) => x.pipeline_id === "agent-b")?.prompt_set_id).toBe("set-2");
  });
});
|
||||||
|
|
||||||
45
frontend/src/activeConfigSelection.ts
Normal file
45
frontend/src/activeConfigSelection.ts
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
import type { GraphConfigListItem } from "./types";
|
||||||
|
|
||||||
|
/**
 * Convert a date string to epoch milliseconds.
 * Missing, empty or unparseable values map to 0.
 */
function toTimestamp(value?: string | null): number {
  const millis = value ? Date.parse(value) : Number.NaN;
  return Number.isNaN(millis) ? 0 : millis;
}
|
||||||
|
|
||||||
|
/**
 * Pick the config item to show for one pipeline.
 *
 * Returns the first item flagged `is_active`; when none is active, the item
 * with the greatest `updated_at` (earliest in input order on ties); `null`
 * when the pipeline has no items.
 */
export function chooseActiveConfigItem(
  items: GraphConfigListItem[],
  pipelineId: string
): GraphConfigListItem | null {
  const matching = items.filter((entry) => entry.pipeline_id === pipelineId);
  if (matching.length === 0) {
    return null;
  }

  for (const entry of matching) {
    if (entry.is_active) {
      return entry;
    }
  }

  // Strict ">" keeps the earliest item when timestamps tie.
  return matching.reduce((newest, entry) =>
    toTimestamp(entry.updated_at) > toTimestamp(newest.updated_at) ? entry : newest
  );
}
|
||||||
|
|
||||||
|
/**
 * Reduce a flat config list to one chosen item per pipeline.
 *
 * Items are grouped by `pipeline_id`, `chooseActiveConfigItem` picks one per
 * group, and the result is sorted by `pipeline_id` using `localeCompare`.
 */
export function chooseDisplayItemsByPipeline(
  items: GraphConfigListItem[]
): GraphConfigListItem[] {
  const grouped = new Map<string, GraphConfigListItem[]>();
  items.forEach((entry) => {
    const bucket = grouped.get(entry.pipeline_id);
    if (bucket) {
      bucket.push(entry);
    } else {
      grouped.set(entry.pipeline_id, [entry]);
    }
  });

  const chosen: GraphConfigListItem[] = [];
  grouped.forEach((bucket, pipelineId) => {
    const pick = chooseActiveConfigItem(bucket, pipelineId);
    if (pick) {
      chosen.push(pick);
    }
  });

  chosen.sort((left, right) => left.pipeline_id.localeCompare(right.pipeline_id));
  return chosen;
}
|
||||||
|
|
||||||
22
frontend/src/api/frontApis.test.ts
Normal file
22
frontend/src/api/frontApis.test.ts
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import { describe, expect, it } from "vitest";
|
||||||
|
|
||||||
|
import { joinApiUrl } from "./frontApis";
|
||||||
|
|
||||||
|
// URL joining: same-origin base, trailing-slash trimming, and paths given
// without a leading slash.
describe("joinApiUrl", () => {
  it("keeps same-origin paths when base url is slash", () => {
    expect(joinApiUrl("/", "/v1/pipelines")).toBe("/v1/pipelines");
  });

  it("joins absolute host and trims trailing slash", () => {
    expect(joinApiUrl("http://127.0.0.1:8500/", "/v1/pipelines")).toBe(
      "http://127.0.0.1:8500/v1/pipelines"
    );
  });

  it("accepts path without leading slash", () => {
    expect(joinApiUrl("http://127.0.0.1:8500", "v1/pipelines")).toBe(
      "http://127.0.0.1:8500/v1/pipelines"
    );
  });
});
|
||||||
|
|
||||||
@@ -1,20 +1,48 @@
|
|||||||
import type {
|
import type {
|
||||||
AvailableGraphsResponse,
|
AvailableGraphsResponse,
|
||||||
|
ConversationListItem,
|
||||||
|
ConversationMessageItem,
|
||||||
GraphConfigListResponse,
|
GraphConfigListResponse,
|
||||||
GraphConfigReadResponse,
|
GraphConfigReadResponse,
|
||||||
GraphConfigUpsertRequest,
|
GraphConfigUpsertRequest,
|
||||||
GraphConfigUpsertResponse,
|
GraphConfigUpsertResponse,
|
||||||
|
McpAvailableToolsResponse,
|
||||||
|
McpToolConfigResponse,
|
||||||
|
McpToolConfigUpdateRequest,
|
||||||
|
McpToolConfigUpdateResponse,
|
||||||
PipelineCreateRequest,
|
PipelineCreateRequest,
|
||||||
PipelineCreateResponse,
|
PipelineCreateResponse,
|
||||||
|
PipelineConversationListResponse,
|
||||||
|
PipelineConversationMessagesResponse,
|
||||||
PipelineListResponse,
|
PipelineListResponse,
|
||||||
PipelineStopResponse,
|
PipelineStopResponse,
|
||||||
|
RuntimeAuthInfoResponse,
|
||||||
} from "../types";
|
} from "../types";
|
||||||
|
|
||||||
const API_BASE_URL =
|
const API_BASE_URL = import.meta.env.VITE_FRONT_API_BASE_URL?.trim() || "/";
|
||||||
import.meta.env.VITE_FRONT_API_BASE_URL?.trim() || "http://127.0.0.1:8001";
|
|
||||||
|
/**
 * Join an API base URL and a path into one request URL.
 *
 * The path always gets a single leading slash. An empty base or "/" means
 * same-origin, so only the path is returned; otherwise trailing slashes are
 * trimmed from the base before the path is appended.
 */
export function joinApiUrl(baseUrl: string, path: string): string {
  const suffix = path.startsWith("/") ? path : `/${path}`;
  const base = baseUrl.trim();

  // "/" is commonly used in Docker+nginx builds and should resolve as same-origin.
  const sameOrigin = base === "" || base === "/";
  return sameOrigin ? suffix : base.replace(/\/+$/, "") + suffix;
}
|
||||||
|
|
||||||
|
// Log which backend the frontend is targeting on startup, with file + line hint.
// This runs once when the module is loaded.
// NOTE(review): the "frontApis.ts:16" hint inside the message is hard-coded and
// will go stale when lines above move — confirm it is still wanted.
// eslint-disable-next-line no-console
console.info(
  `[frontend] Using FRONT_API_BASE_URL=${API_BASE_URL} (src/api/frontApis.ts:16)`
);
|
||||||
|
|
||||||
async function fetchJson<T>(path: string, init?: RequestInit): Promise<T> {
|
async function fetchJson<T>(path: string, init?: RequestInit): Promise<T> {
|
||||||
const response = await fetch(`${API_BASE_URL}${path}`, {
|
const url = joinApiUrl(API_BASE_URL, path);
|
||||||
|
const response = await fetch(url, {
|
||||||
headers: {
|
headers: {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
...(init?.headers || {}),
|
...(init?.headers || {}),
|
||||||
@@ -33,7 +61,24 @@ async function fetchJson<T>(path: string, init?: RequestInit): Promise<T> {
|
|||||||
}
|
}
|
||||||
throw new Error(message);
|
throw new Error(message);
|
||||||
}
|
}
|
||||||
return (await response.json()) as T;
|
|
||||||
|
if (response.status === 204) {
|
||||||
|
return undefined as T;
|
||||||
|
}
|
||||||
|
|
||||||
|
const bodyText = await response.text();
|
||||||
|
if (!bodyText.trim()) {
|
||||||
|
return undefined as T;
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
return JSON.parse(bodyText) as T;
|
||||||
|
} catch {
|
||||||
|
const preview = bodyText.slice(0, 160).replace(/\s+/g, " ").trim();
|
||||||
|
throw new Error(
|
||||||
|
`Expected JSON response from ${url}, but received non-JSON content: ${preview || "<empty>"}`
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function listAvailableGraphs(): Promise<AvailableGraphsResponse> {
|
export function listAvailableGraphs(): Promise<AvailableGraphsResponse> {
|
||||||
@@ -67,6 +112,12 @@ export function getGraphDefaultConfig(
|
|||||||
return fetchJson(`/v1/graphs/${graphId}/default-config`);
|
return fetchJson(`/v1/graphs/${graphId}/default-config`);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
 * Fetch the default (active) graph config of a pipeline.
 *
 * The id is URL-encoded so ids containing "/", "?", "#" or spaces stay a
 * single path segment, matching the conversation endpoints in this module.
 */
export function getPipelineDefaultConfig(
  pipelineId: string
): Promise<GraphConfigReadResponse> {
  return fetchJson(`/v1/graph-configs/default/${encodeURIComponent(pipelineId)}`);
}
|
||||||
|
|
||||||
export function upsertGraphConfig(
|
export function upsertGraphConfig(
|
||||||
payload: GraphConfigUpsertRequest
|
payload: GraphConfigUpsertRequest
|
||||||
): Promise<GraphConfigUpsertResponse> {
|
): Promise<GraphConfigUpsertResponse> {
|
||||||
@@ -85,6 +136,23 @@ export function deleteGraphConfig(
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Read the MCP tool configuration via GET /v1/tool-configs/mcp. */
export function getMcpToolConfig(): Promise<McpToolConfigResponse> {
  const endpoint = "/v1/tool-configs/mcp";
  return fetchJson(endpoint);
}
|
||||||
|
|
||||||
|
/** Replace the MCP tool configuration via PUT /v1/tool-configs/mcp. */
export function updateMcpToolConfig(
  payload: McpToolConfigUpdateRequest
): Promise<McpToolConfigUpdateResponse> {
  const request: RequestInit = {
    method: "PUT",
    body: JSON.stringify(payload),
  };
  return fetchJson("/v1/tool-configs/mcp", request);
}
|
||||||
|
|
||||||
|
/** List the available MCP tools via GET /v1/tool-configs/mcp/tools. */
export function listMcpAvailableTools(): Promise<McpAvailableToolsResponse> {
  const endpoint = "/v1/tool-configs/mcp/tools";
  return fetchJson(endpoint);
}
|
||||||
|
|
||||||
export function createPipeline(
|
export function createPipeline(
|
||||||
payload: PipelineCreateRequest
|
payload: PipelineCreateRequest
|
||||||
): Promise<PipelineCreateResponse> {
|
): Promise<PipelineCreateResponse> {
|
||||||
@@ -98,9 +166,144 @@ export function listPipelines(): Promise<PipelineListResponse> {
|
|||||||
return fetchJson("/v1/pipelines");
|
return fetchJson("/v1/pipelines");
|
||||||
}
|
}
|
||||||
|
|
||||||
export function stopPipeline(runId: string): Promise<PipelineStopResponse> {
|
export function stopPipeline(pipelineId: string): Promise<PipelineStopResponse> {
|
||||||
return fetchJson(`/v1/pipelines/${runId}`, {
|
return fetchJson(`/v1/pipelines/${pipelineId}`, {
|
||||||
method: "DELETE",
|
method: "DELETE",
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getRuntimeAuthInfo(): Promise<RuntimeAuthInfoResponse> {
|
||||||
|
return fetchJson("/v1/runtime-auth");
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function listPipelineConversations(
|
||||||
|
pipelineId: string,
|
||||||
|
limit = 100
|
||||||
|
): Promise<ConversationListItem[]> {
|
||||||
|
const response = await fetchJson<PipelineConversationListResponse>(
|
||||||
|
`/v1/pipelines/${encodeURIComponent(pipelineId)}/conversations?limit=${limit}`
|
||||||
|
);
|
||||||
|
return response.items || [];
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function getPipelineConversationMessages(
|
||||||
|
pipelineId: string,
|
||||||
|
conversationId: string
|
||||||
|
): Promise<ConversationMessageItem[]> {
|
||||||
|
const response = await fetchJson<PipelineConversationMessagesResponse>(
|
||||||
|
`/v1/pipelines/${encodeURIComponent(pipelineId)}/conversations/${encodeURIComponent(conversationId)}/messages`
|
||||||
|
);
|
||||||
|
return response.items || [];
|
||||||
|
}
|
||||||
|
|
||||||
|
type StreamAgentChatOptions = {
|
||||||
|
appId: string;
|
||||||
|
sessionId: string;
|
||||||
|
apiKey: string;
|
||||||
|
message: string;
|
||||||
|
onText: (text: string) => void;
|
||||||
|
signal?: AbortSignal;
|
||||||
|
};
|
||||||
|
|
||||||
|
function parseErrorDetail(payload: unknown): string | null {
|
||||||
|
if (!payload || typeof payload !== "object") {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
const detail = (payload as { detail?: unknown }).detail;
|
||||||
|
return typeof detail === "string" && detail.trim() ? detail : null;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function streamAgentChatResponse(
|
||||||
|
options: StreamAgentChatOptions
|
||||||
|
): Promise<string> {
|
||||||
|
const { appId, sessionId, apiKey, message, onText, signal } = options;
|
||||||
|
const response = await fetch(
|
||||||
|
joinApiUrl(
|
||||||
|
API_BASE_URL,
|
||||||
|
`/v1/apps/${encodeURIComponent(appId)}/sessions/${encodeURIComponent(sessionId)}/responses`
|
||||||
|
),
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
Authorization: `Bearer ${apiKey}`,
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
messages: [{ role: "user", content: message }],
|
||||||
|
stream: true,
|
||||||
|
}),
|
||||||
|
signal,
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
let messageText = `Request failed (${response.status})`;
|
||||||
|
try {
|
||||||
|
const payload = (await response.json()) as unknown;
|
||||||
|
const detail = parseErrorDetail(payload);
|
||||||
|
if (detail) {
|
||||||
|
messageText = detail;
|
||||||
|
}
|
||||||
|
} catch {
|
||||||
|
// Keep fallback status-based message.
|
||||||
|
}
|
||||||
|
throw new Error(messageText);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!response.body) {
|
||||||
|
throw new Error("Streaming response is not available.");
|
||||||
|
}
|
||||||
|
|
||||||
|
const reader = response.body.getReader();
|
||||||
|
const decoder = new TextDecoder();
|
||||||
|
let buffered = "";
|
||||||
|
let latestText = "";
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
if (signal?.aborted) {
|
||||||
|
reader.cancel();
|
||||||
|
throw new Error("Request cancelled");
|
||||||
|
}
|
||||||
|
const { value, done } = await reader.read();
|
||||||
|
if (done) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
buffered += decoder.decode(value, { stream: true });
|
||||||
|
|
||||||
|
let splitIndex = buffered.indexOf("\n\n");
|
||||||
|
while (splitIndex >= 0) {
|
||||||
|
const eventBlock = buffered.slice(0, splitIndex);
|
||||||
|
buffered = buffered.slice(splitIndex + 2);
|
||||||
|
splitIndex = buffered.indexOf("\n\n");
|
||||||
|
|
||||||
|
const lines = eventBlock.split("\n");
|
||||||
|
for (const rawLine of lines) {
|
||||||
|
const line = rawLine.trim();
|
||||||
|
if (!line.startsWith("data:")) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const payloadRaw = line.slice(5).trim();
|
||||||
|
if (!payloadRaw) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let payload: unknown;
|
||||||
|
try {
|
||||||
|
payload = JSON.parse(payloadRaw);
|
||||||
|
} catch {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
const nextText =
|
||||||
|
typeof (payload as { output?: { text?: unknown } })?.output?.text === "string"
|
||||||
|
? ((payload as { output: { text: string } }).output.text as string)
|
||||||
|
: "";
|
||||||
|
if (nextText !== latestText) {
|
||||||
|
latestText = nextText;
|
||||||
|
onText(latestText);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return latestText;
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -24,6 +24,10 @@ body {
|
|||||||
min-height: 100vh;
|
min-height: 100vh;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.app.full-width {
|
||||||
|
grid-template-columns: 1fr;
|
||||||
|
}
|
||||||
|
|
||||||
.sidebar {
|
.sidebar {
|
||||||
border-right: 1px solid #dbe2ea;
|
border-right: 1px solid #dbe2ea;
|
||||||
background: #ffffff;
|
background: #ffffff;
|
||||||
@@ -61,6 +65,12 @@ button:disabled {
|
|||||||
gap: 8px;
|
gap: 8px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.agent-item-row {
|
||||||
|
display: grid;
|
||||||
|
gap: 6px;
|
||||||
|
grid-template-columns: 1fr auto;
|
||||||
|
}
|
||||||
|
|
||||||
.agent-item {
|
.agent-item {
|
||||||
align-items: flex-start;
|
align-items: flex-start;
|
||||||
display: flex;
|
display: flex;
|
||||||
@@ -69,6 +79,33 @@ button:disabled {
|
|||||||
width: 100%;
|
width: 100%;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.agent-item-title {
|
||||||
|
align-items: center;
|
||||||
|
display: flex;
|
||||||
|
gap: 8px;
|
||||||
|
justify-content: space-between;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-status-pill {
|
||||||
|
border-radius: 999px;
|
||||||
|
font-size: 11px;
|
||||||
|
font-weight: 600;
|
||||||
|
padding: 2px 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-status-pill.running {
|
||||||
|
background: #dff7e7;
|
||||||
|
border: 1px solid #8cd3a1;
|
||||||
|
color: #1a6b35;
|
||||||
|
}
|
||||||
|
|
||||||
|
.agent-status-pill.stopped {
|
||||||
|
background: #f2f4f7;
|
||||||
|
border: 1px solid #d1d8e0;
|
||||||
|
color: #4a5565;
|
||||||
|
}
|
||||||
|
|
||||||
.agent-item.selected {
|
.agent-item.selected {
|
||||||
border-color: #4d7ef3;
|
border-color: #4d7ef3;
|
||||||
background: #edf3ff;
|
background: #edf3ff;
|
||||||
@@ -78,6 +115,11 @@ button:disabled {
|
|||||||
color: #5f6f82;
|
color: #5f6f82;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.agent-chat-button {
|
||||||
|
align-self: stretch;
|
||||||
|
min-width: 64px;
|
||||||
|
}
|
||||||
|
|
||||||
.content {
|
.content {
|
||||||
padding: 20px;
|
padding: 20px;
|
||||||
}
|
}
|
||||||
@@ -94,6 +136,21 @@ button:disabled {
|
|||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.tabs {
|
||||||
|
display: flex;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tab-button {
|
||||||
|
min-width: 120px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.tab-button.active {
|
||||||
|
background: #edf3ff;
|
||||||
|
border-color: #4d7ef3;
|
||||||
|
color: #1a4fc5;
|
||||||
|
}
|
||||||
|
|
||||||
.header-actions {
|
.header-actions {
|
||||||
display: flex;
|
display: flex;
|
||||||
gap: 8px;
|
gap: 8px;
|
||||||
@@ -108,10 +165,41 @@ button:disabled {
|
|||||||
padding: 10px;
|
padding: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.launch-credentials {
|
||||||
|
background: #fff4df;
|
||||||
|
border: 1px solid #f0d5a8;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin-top: 12px;
|
||||||
|
padding: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.launch-credentials h3 {
|
||||||
|
margin: 0 0 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.launch-credentials > div {
|
||||||
|
align-items: center;
|
||||||
|
display: flex;
|
||||||
|
flex-wrap: wrap;
|
||||||
|
gap: 8px;
|
||||||
|
margin: 6px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.launch-credentials code {
|
||||||
|
background: #fff;
|
||||||
|
border: 1px solid #f0d5a8;
|
||||||
|
border-radius: 4px;
|
||||||
|
padding: 2px 6px;
|
||||||
|
}
|
||||||
|
|
||||||
.empty-panel {
|
.empty-panel {
|
||||||
margin-top: 30px;
|
margin-top: 30px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.tab-pane {
|
||||||
|
margin-top: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
.form-grid {
|
.form-grid {
|
||||||
display: grid;
|
display: grid;
|
||||||
gap: 14px;
|
gap: 14px;
|
||||||
@@ -136,6 +224,7 @@ button:disabled {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.prompt-section,
|
.prompt-section,
|
||||||
|
.file-backend-config,
|
||||||
.run-info {
|
.run-info {
|
||||||
border: 1px solid #dbe2ea;
|
border: 1px solid #dbe2ea;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
@@ -144,10 +233,37 @@ button:disabled {
|
|||||||
}
|
}
|
||||||
|
|
||||||
.prompt-section h3,
|
.prompt-section h3,
|
||||||
|
.file-backend-config h3,
|
||||||
.run-info h3 {
|
.run-info h3 {
|
||||||
margin-top: 0;
|
margin-top: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.run-info-header {
|
||||||
|
align-items: center;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.runtime-badge {
|
||||||
|
border-radius: 999px;
|
||||||
|
font-size: 12px;
|
||||||
|
font-weight: 600;
|
||||||
|
padding: 4px 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.runtime-badge.running {
|
||||||
|
background: #dff7e7;
|
||||||
|
border: 1px solid #8cd3a1;
|
||||||
|
color: #1a6b35;
|
||||||
|
}
|
||||||
|
|
||||||
|
.runtime-badge.stopped {
|
||||||
|
background: #f2f4f7;
|
||||||
|
border: 1px solid #d1d8e0;
|
||||||
|
color: #4a5565;
|
||||||
|
}
|
||||||
|
|
||||||
.graph-arch-section {
|
.graph-arch-section {
|
||||||
border: 1px solid #dbe2ea;
|
border: 1px solid #dbe2ea;
|
||||||
border-radius: 10px;
|
border-radius: 10px;
|
||||||
@@ -179,8 +295,363 @@ button:disabled {
|
|||||||
padding: 10px;
|
padding: 10px;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.run-card-columns {
|
||||||
|
display: grid;
|
||||||
|
gap: 12px;
|
||||||
|
grid-template-columns: minmax(280px, 1fr) minmax(420px, 1.3fr);
|
||||||
|
}
|
||||||
|
|
||||||
|
.run-card-left,
|
||||||
|
.run-card-right {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.run-card-right {
|
||||||
|
border-left: 1px solid #dbe2ea;
|
||||||
|
min-width: 0;
|
||||||
|
padding-left: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.run-card-right code {
|
||||||
|
display: inline-block;
|
||||||
|
max-width: 100%;
|
||||||
|
overflow-x: auto;
|
||||||
|
vertical-align: middle;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-section {
|
||||||
|
background: #f7fbff;
|
||||||
|
border: 1px solid #d7e6f6;
|
||||||
|
border-radius: 10px;
|
||||||
|
padding: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-header {
|
||||||
|
align-items: center;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
gap: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-header h3 {
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-layout {
|
||||||
|
display: grid;
|
||||||
|
gap: 12px;
|
||||||
|
grid-template-columns: minmax(260px, 320px) 1fr;
|
||||||
|
margin-top: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-list {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 8px;
|
||||||
|
max-height: 65vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-item {
|
||||||
|
align-items: flex-start;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 4px;
|
||||||
|
text-align: left;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-item.selected {
|
||||||
|
background: #edf3ff;
|
||||||
|
border-color: #4d7ef3;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-item small {
|
||||||
|
color: #687788;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-thread {
|
||||||
|
border: 1px solid #d7e6f6;
|
||||||
|
border-radius: 10px;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 8px;
|
||||||
|
max-height: 65vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message {
|
||||||
|
background: #fff;
|
||||||
|
border: 1px solid #dbe2ea;
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message.human {
|
||||||
|
border-left: 3px solid #4d7ef3;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message.ai {
|
||||||
|
border-left: 3px solid #26a269;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message.tool {
|
||||||
|
border-left: 3px solid #8e6bd8;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message-meta {
|
||||||
|
align-items: baseline;
|
||||||
|
display: flex;
|
||||||
|
gap: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message pre {
|
||||||
|
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||||
|
margin: 8px 0 0;
|
||||||
|
overflow-x: auto;
|
||||||
|
white-space: pre;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message-markdown > :first-child {
|
||||||
|
margin-top: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message-markdown > :last-child {
|
||||||
|
margin-bottom: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message-markdown code {
|
||||||
|
background: #f3f5f8;
|
||||||
|
border-radius: 4px;
|
||||||
|
padding: 1px 4px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message-markdown pre code {
|
||||||
|
background: transparent;
|
||||||
|
padding: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message-markdown a {
|
||||||
|
color: #1a4fc5;
|
||||||
|
text-decoration: underline;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message-markdown p,
|
||||||
|
.discussion-message-markdown ul,
|
||||||
|
.discussion-message-markdown ol,
|
||||||
|
.discussion-message-markdown blockquote,
|
||||||
|
.discussion-message-markdown table {
|
||||||
|
margin: 8px 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.discussion-message-markdown blockquote {
|
||||||
|
border-left: 3px solid #d0d8e2;
|
||||||
|
color: #4f5f73;
|
||||||
|
margin-left: 0;
|
||||||
|
padding-left: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-config-section {
|
||||||
|
background: #f7fbff;
|
||||||
|
border: 1px solid #d7e6f6;
|
||||||
|
border-radius: 10px;
|
||||||
|
padding: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-config-header {
|
||||||
|
align-items: center;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
gap: 12px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-config-header h3 {
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-config-editor {
|
||||||
|
border: 1px solid #c9d4e2;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
|
||||||
|
font-size: 13px;
|
||||||
|
margin-top: 8px;
|
||||||
|
padding: 10px;
|
||||||
|
resize: vertical;
|
||||||
|
width: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-entry-list {
|
||||||
|
display: grid;
|
||||||
|
gap: 12px;
|
||||||
|
margin-top: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-tools-error {
|
||||||
|
color: #a33434;
|
||||||
|
margin: 6px 0 0 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-tools-inline {
|
||||||
|
background: #f8fbff;
|
||||||
|
border: 1px solid #d7e6f6;
|
||||||
|
border-radius: 8px;
|
||||||
|
margin: 0 0 10px 0;
|
||||||
|
padding: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-entry-card {
|
||||||
|
background: #fff;
|
||||||
|
border: 1px solid #d7e6f6;
|
||||||
|
border-radius: 10px;
|
||||||
|
padding: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-entry-header {
|
||||||
|
align-items: center;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
gap: 10px;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-entry-grid {
|
||||||
|
display: grid;
|
||||||
|
gap: 10px;
|
||||||
|
grid-template-columns: repeat(2, minmax(200px, 1fr));
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-entry-grid label {
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
font-size: 14px;
|
||||||
|
gap: 6px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-entry-grid input,
|
||||||
|
.mcp-entry-grid select {
|
||||||
|
border: 1px solid #c9d4e2;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 14px;
|
||||||
|
padding: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.mcp-entry-wide {
|
||||||
|
grid-column: 1 / -1;
|
||||||
|
}
|
||||||
|
|
||||||
.empty {
|
.empty {
|
||||||
color: #687788;
|
color: #687788;
|
||||||
margin: 6px 0;
|
margin: 6px 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.chat-modal-overlay {
|
||||||
|
align-items: center;
|
||||||
|
background: rgba(16, 24, 40, 0.45);
|
||||||
|
display: flex;
|
||||||
|
inset: 0;
|
||||||
|
justify-content: center;
|
||||||
|
position: fixed;
|
||||||
|
z-index: 20;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal {
|
||||||
|
background: #fff;
|
||||||
|
border: 1px solid #d7e6f6;
|
||||||
|
border-radius: 12px;
|
||||||
|
display: grid;
|
||||||
|
gap: 10px;
|
||||||
|
max-height: 86vh;
|
||||||
|
max-width: 820px;
|
||||||
|
padding: 12px;
|
||||||
|
width: min(92vw, 820px);
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-header {
|
||||||
|
align-items: center;
|
||||||
|
border-bottom: 1px solid #dbe2ea;
|
||||||
|
display: flex;
|
||||||
|
justify-content: space-between;
|
||||||
|
padding-bottom: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-header small {
|
||||||
|
color: #687788;
|
||||||
|
display: block;
|
||||||
|
margin-top: 2px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-messages {
|
||||||
|
background: #f8fbff;
|
||||||
|
border: 1px solid #d7e6f6;
|
||||||
|
border-radius: 10px;
|
||||||
|
display: flex;
|
||||||
|
flex-direction: column;
|
||||||
|
gap: 8px;
|
||||||
|
max-height: 56vh;
|
||||||
|
overflow-y: auto;
|
||||||
|
padding: 10px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-message {
|
||||||
|
background: #fff;
|
||||||
|
border: 1px solid #dbe2ea;
|
||||||
|
border-radius: 8px;
|
||||||
|
padding: 8px;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-message.user {
|
||||||
|
border-left: 3px solid #4d7ef3;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-message.assistant {
|
||||||
|
border-left: 3px solid #26a269;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-message.tool {
|
||||||
|
border-left: 3px solid #8e6bd8;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-message p {
|
||||||
|
margin: 6px 0 0 0;
|
||||||
|
white-space: pre-wrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-input {
|
||||||
|
display: grid;
|
||||||
|
gap: 8px;
|
||||||
|
grid-template-columns: 1fr auto;
|
||||||
|
align-items: start;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-input textarea {
|
||||||
|
border: 1px solid #c9d4e2;
|
||||||
|
border-radius: 8px;
|
||||||
|
font-size: 14px;
|
||||||
|
padding: 8px;
|
||||||
|
resize: vertical;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-actions {
|
||||||
|
display: flex;
|
||||||
|
gap: 8px;
|
||||||
|
height: 100%;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-modal-actions button {
|
||||||
|
height: auto;
|
||||||
|
white-space: nowrap;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-stop-button {
|
||||||
|
background-color: #dc3545;
|
||||||
|
color: white;
|
||||||
|
}
|
||||||
|
|
||||||
|
.chat-stop-button:hover {
|
||||||
|
background-color: #bb2d3b;
|
||||||
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ export type GraphConfigListItem = {
|
|||||||
description: string;
|
description: string;
|
||||||
is_active: boolean;
|
is_active: boolean;
|
||||||
tool_keys: string[];
|
tool_keys: string[];
|
||||||
|
api_key: string;
|
||||||
created_at?: string | null;
|
created_at?: string | null;
|
||||||
updated_at?: string | null;
|
updated_at?: string | null;
|
||||||
};
|
};
|
||||||
@@ -21,6 +22,8 @@ export type GraphConfigReadResponse = {
|
|||||||
prompt_set_id: string;
|
prompt_set_id: string;
|
||||||
tool_keys: string[];
|
tool_keys: string[];
|
||||||
prompt_dict: Record<string, string>;
|
prompt_dict: Record<string, string>;
|
||||||
|
api_key: string;
|
||||||
|
graph_params?: Record<string, unknown>;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type GraphConfigUpsertRequest = {
|
export type GraphConfigUpsertRequest = {
|
||||||
@@ -29,6 +32,7 @@ export type GraphConfigUpsertRequest = {
|
|||||||
prompt_set_id?: string;
|
prompt_set_id?: string;
|
||||||
tool_keys: string[];
|
tool_keys: string[];
|
||||||
prompt_dict: Record<string, string>;
|
prompt_dict: Record<string, string>;
|
||||||
|
api_key?: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type GraphConfigUpsertResponse = {
|
export type GraphConfigUpsertResponse = {
|
||||||
@@ -37,6 +41,7 @@ export type GraphConfigUpsertResponse = {
|
|||||||
prompt_set_id: string;
|
prompt_set_id: string;
|
||||||
tool_keys: string[];
|
tool_keys: string[];
|
||||||
prompt_keys: string[];
|
prompt_keys: string[];
|
||||||
|
api_key: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type AvailableGraphsResponse = {
|
export type AvailableGraphsResponse = {
|
||||||
@@ -48,30 +53,100 @@ export type PipelineCreateRequest = {
|
|||||||
pipeline_id: string;
|
pipeline_id: string;
|
||||||
prompt_set_id: string;
|
prompt_set_id: string;
|
||||||
tool_keys: string[];
|
tool_keys: string[];
|
||||||
port: number;
|
api_key?: string;
|
||||||
entry_point: string;
|
llm_name: string;
|
||||||
|
enabled?: boolean;
|
||||||
|
graph_params?: Record<string, unknown>;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type PipelineSpec = {
|
||||||
|
pipeline_id: string;
|
||||||
|
graph_id: string;
|
||||||
|
enabled: boolean;
|
||||||
|
config_file: string;
|
||||||
llm_name: string;
|
llm_name: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type PipelineRunInfo = {
|
export type PipelineCreateResponse = {
|
||||||
run_id: string;
|
|
||||||
pid: number;
|
|
||||||
graph_id: string;
|
|
||||||
pipeline_id: string;
|
pipeline_id: string;
|
||||||
prompt_set_id: string;
|
prompt_set_id: string;
|
||||||
url: string;
|
graph_id: string;
|
||||||
port: number;
|
config_file: string;
|
||||||
|
llm_name: string;
|
||||||
|
enabled: boolean;
|
||||||
|
reload_required: boolean;
|
||||||
|
registry_path: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type PipelineCreateResponse = PipelineRunInfo;
|
|
||||||
|
|
||||||
export type PipelineListResponse = {
|
export type PipelineListResponse = {
|
||||||
items: PipelineRunInfo[];
|
items: PipelineSpec[];
|
||||||
count: number;
|
count: number;
|
||||||
};
|
};
|
||||||
|
|
||||||
export type PipelineStopResponse = {
|
export type PipelineStopResponse = {
|
||||||
run_id: string;
|
pipeline_id: string;
|
||||||
status: string;
|
status: string;
|
||||||
|
enabled: boolean;
|
||||||
|
reload_required: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ConversationListItem = {
|
||||||
|
conversation_id: string;
|
||||||
|
pipeline_id: string;
|
||||||
|
message_count: number;
|
||||||
|
last_updated?: string | null;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type PipelineConversationListResponse = {
|
||||||
|
pipeline_id: string;
|
||||||
|
items: ConversationListItem[];
|
||||||
|
count: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type ConversationMessageItem = {
|
||||||
|
message_type: string;
|
||||||
|
content: string;
|
||||||
|
sequence_number: number;
|
||||||
|
created_at: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type PipelineConversationMessagesResponse = {
|
||||||
|
pipeline_id: string;
|
||||||
|
conversation_id: string;
|
||||||
|
items: ConversationMessageItem[];
|
||||||
|
count: number;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type RuntimeAuthInfoResponse = {
|
||||||
|
fast_api_key: string;
|
||||||
|
source: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type McpToolConfigResponse = {
|
||||||
|
path: string;
|
||||||
|
raw_content: string;
|
||||||
|
tool_keys: string[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type McpToolConfigUpdateRequest = {
|
||||||
|
raw_content: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export type McpToolConfigUpdateResponse = {
|
||||||
|
status: string;
|
||||||
|
path: string;
|
||||||
|
tool_keys: string[];
|
||||||
|
};
|
||||||
|
|
||||||
|
export type McpAvailableToolsResponse = {
|
||||||
|
available_tools: string[];
|
||||||
|
errors: string[];
|
||||||
|
servers: Record<
|
||||||
|
string,
|
||||||
|
{
|
||||||
|
tools: string[];
|
||||||
|
error?: string | null;
|
||||||
|
}
|
||||||
|
>;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
|
|||||||
{"root":["./src/App.tsx","./src/main.tsx","./src/types.ts","./src/vite-env.d.ts","./src/api/frontApis.ts"],"version":"5.9.3"}
|
{"root":["./src/App.tsx","./src/activeConfigSelection.test.ts","./src/activeConfigSelection.ts","./src/main.tsx","./src/types.ts","./src/vite-env.d.ts","./src/api/frontApis.test.ts","./src/api/frontApis.ts"],"version":"5.9.3"}
|
||||||
@@ -4,5 +4,15 @@ export default defineConfig({
|
|||||||
plugins: [react()],
|
plugins: [react()],
|
||||||
server: {
|
server: {
|
||||||
port: 5173,
|
port: 5173,
|
||||||
|
proxy: {
|
||||||
|
"/v1": {
|
||||||
|
target: "http://127.0.0.1:8500",
|
||||||
|
changeOrigin: true,
|
||||||
|
},
|
||||||
|
"/apps": {
|
||||||
|
target: "http://127.0.0.1:8500",
|
||||||
|
changeOrigin: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -5,6 +5,16 @@ export default defineConfig({
|
|||||||
plugins: [react()],
|
plugins: [react()],
|
||||||
server: {
|
server: {
|
||||||
port: 5173,
|
port: 5173,
|
||||||
|
proxy: {
|
||||||
|
"/v1": {
|
||||||
|
target: "http://127.0.0.1:8500",
|
||||||
|
changeOrigin: true,
|
||||||
|
},
|
||||||
|
"/apps": {
|
||||||
|
target: "http://127.0.0.1:8500",
|
||||||
|
changeOrigin: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -32,9 +32,17 @@ class LangToolBase(ABC):
|
|||||||
|
|
||||||
class GraphBase(ABC):
|
class GraphBase(ABC):
|
||||||
workflow: CompiledStateGraph # the main workflow
|
workflow: CompiledStateGraph # the main workflow
|
||||||
streamable_tags: List[List[str]] # which llm to stream outputs; see routing.py for complex usage
|
streamable_tags: List[List[str]] = [["main_llm"]] # which llm to stream outputs; see routing.py for complex usage
|
||||||
textreleaser_delay_keys: List[str] = (None, None) # use to control when to start streaming; see routing.py for complex usage
|
textreleaser_delay_keys: List[str] = (None, None) # use to control when to start streaming; see routing.py for complex usage
|
||||||
|
|
||||||
|
def _build_modules(self):
|
||||||
|
# NOTE: optional abc
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _build_graph(self):
|
||||||
|
# NOTE: optional abc
|
||||||
|
pass
|
||||||
|
|
||||||
def _stream_result(self, *nargs, **kwargs):
|
def _stream_result(self, *nargs, **kwargs):
|
||||||
|
|
||||||
def text_iterator():
|
def text_iterator():
|
||||||
|
|||||||
@@ -252,16 +252,10 @@ class ClientToolManager:
|
|||||||
with open(self.config.mcp_config_f, "r") as f:
|
with open(self.config.mcp_config_f, "r") as f:
|
||||||
self.mcp_configs:dict = commentjson.load(f)
|
self.mcp_configs:dict = commentjson.load(f)
|
||||||
|
|
||||||
async def aget_tools(self):
|
def _get_to_load_configs(self) -> dict:
|
||||||
"""
|
|
||||||
Get tools from all configured MCP servers.
|
|
||||||
Handles connection failures gracefully by logging warnings and continuing.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_to_load_configs() -> dict:
|
|
||||||
if self.config.tool_keys is None:
|
if self.config.tool_keys is None:
|
||||||
to_load_config = self.mcp_configs
|
return self.mcp_configs
|
||||||
else:
|
|
||||||
if len(self.config.tool_keys) == 0:
|
if len(self.config.tool_keys) == 0:
|
||||||
logger.info("no tools will be loaded")
|
logger.info("no tools will be loaded")
|
||||||
return {}
|
return {}
|
||||||
@@ -273,10 +267,14 @@ class ClientToolManager:
|
|||||||
logger.warning(f"{key} is not in mcp tools")
|
logger.warning(f"{key} is not in mcp tools")
|
||||||
else:
|
else:
|
||||||
to_load_config[key] = val
|
to_load_config[key] = val
|
||||||
|
|
||||||
return to_load_config
|
return to_load_config
|
||||||
|
|
||||||
to_load_config = get_to_load_configs()
|
async def aget_tools(self):
|
||||||
|
"""
|
||||||
|
Get tools from all configured MCP servers.
|
||||||
|
Handles connection failures gracefully by logging warnings and continuing.
|
||||||
|
"""
|
||||||
|
to_load_config = self._get_to_load_configs()
|
||||||
all_tools = []
|
all_tools = []
|
||||||
for server_name, server_config in to_load_config.items():
|
for server_name, server_config in to_load_config.items():
|
||||||
try:
|
try:
|
||||||
@@ -298,6 +296,78 @@ class ClientToolManager:
|
|||||||
|
|
||||||
return all_tools
|
return all_tools
|
||||||
|
|
||||||
|
async def aget_tools_with_errors(self):
|
||||||
|
"""
|
||||||
|
Get tools and collect human-readable per-server errors.
|
||||||
|
Returns:
|
||||||
|
(all_tools, errors)
|
||||||
|
"""
|
||||||
|
to_load_config = self._get_to_load_configs()
|
||||||
|
all_tools = []
|
||||||
|
errors = []
|
||||||
|
for server_name, server_config in to_load_config.items():
|
||||||
|
try:
|
||||||
|
single_server_config = {server_name: server_config}
|
||||||
|
client = MultiServerMCPClient(single_server_config)
|
||||||
|
tools = await client.get_tools()
|
||||||
|
all_tools.extend(tools)
|
||||||
|
logger.info(
|
||||||
|
f"Successfully connected to MCP server '{server_name}', retrieved {len(tools)} tools"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
url = (
|
||||||
|
server_config.get("url", "unknown URL")
|
||||||
|
if isinstance(server_config, dict)
|
||||||
|
else "unknown URL"
|
||||||
|
)
|
||||||
|
err_msg = (
|
||||||
|
f"{server_name} ({url}): {type(e).__name__}: {e}"
|
||||||
|
)
|
||||||
|
errors.append(err_msg)
|
||||||
|
logger.exception(
|
||||||
|
f"Failed to connect to MCP server '{server_name}' at {url}"
|
||||||
|
)
|
||||||
|
if hasattr(e, "exceptions"):
|
||||||
|
for nested_exc in e.exceptions:
|
||||||
|
errors.append(
|
||||||
|
f"{server_name} nested: {type(nested_exc).__name__}: {nested_exc}"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
return all_tools, errors
|
||||||
|
|
||||||
|
async def aget_tools_by_server(self) -> dict:
|
||||||
|
"""
|
||||||
|
Get MCP tools grouped by server name, including per-server error (if any).
|
||||||
|
Returns:
|
||||||
|
{
|
||||||
|
"server_name": {
|
||||||
|
"tools": ["tool_a", "tool_b"],
|
||||||
|
"error": "ExceptionType: message" | None,
|
||||||
|
},
|
||||||
|
...
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
to_load_config = self._get_to_load_configs()
|
||||||
|
grouped = {}
|
||||||
|
for server_name, server_config in to_load_config.items():
|
||||||
|
grouped[server_name] = {"tools": [], "error": None}
|
||||||
|
try:
|
||||||
|
single_server_config = {server_name: server_config}
|
||||||
|
client = MultiServerMCPClient(single_server_config)
|
||||||
|
tools = await client.get_tools()
|
||||||
|
tool_names = sorted(
|
||||||
|
{
|
||||||
|
str(getattr(tool, "name", "")).strip()
|
||||||
|
for tool in tools
|
||||||
|
if str(getattr(tool, "name", "")).strip()
|
||||||
|
}
|
||||||
|
)
|
||||||
|
grouped[server_name] = {"tools": tool_names, "error": None}
|
||||||
|
except Exception as e:
|
||||||
|
grouped[server_name]["error"] = f"{type(e).__name__}: {e}"
|
||||||
|
logger.exception(f"Failed to connect to MCP server '{server_name}'")
|
||||||
|
return grouped
|
||||||
|
|
||||||
def get_tools(self):
|
def get_tools(self):
|
||||||
try:
|
try:
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
|
|||||||
@@ -4,16 +4,34 @@ from typing import List, Dict, Union
|
|||||||
from enum import Enum
|
from enum import Enum
|
||||||
import os
|
import os
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from langchain_core.messages import (
|
||||||
|
HumanMessage,
|
||||||
|
AIMessage,
|
||||||
|
ToolMessage,
|
||||||
|
SystemMessage,
|
||||||
|
BaseMessage,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage
|
|
||||||
|
|
||||||
class MessageType(str, Enum):
|
class MessageType(str, Enum):
|
||||||
"""Enum for message types in the conversation store."""
|
"""Enum for message types in the conversation store."""
|
||||||
|
|
||||||
HUMAN = "human"
|
HUMAN = "human"
|
||||||
AI = "ai"
|
AI = "ai"
|
||||||
TOOL = "tool"
|
TOOL = "tool"
|
||||||
|
|
||||||
class ConversationStore:
|
|
||||||
|
class BaseConvStore(ABC):
    """Abstract interface for objects that record a conversation's messages."""

    @abstractmethod
    def record_message_list(
        self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
    ):
        """
        Record the messages of one conversation.

        Args:
            conv_id: identifier of the conversation the messages belong to.
            inp: message list to record.
            pipeline_id: optional identifier of the pipeline that produced
                the messages.
        """
        ...
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationStore(BaseConvStore):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
conn_str = os.environ.get("CONN_STR")
|
conn_str = os.environ.get("CONN_STR")
|
||||||
if conn_str is None:
|
if conn_str is None:
|
||||||
@@ -25,18 +43,17 @@ class ConversationStore:
|
|||||||
conversation_id: str,
|
conversation_id: str,
|
||||||
msg_type: MessageType,
|
msg_type: MessageType,
|
||||||
content: str,
|
content: str,
|
||||||
sequence: int, # the conversation number
|
sequence: int,
|
||||||
|
pipeline_id: str = None,
|
||||||
):
|
):
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
# DB schema only supports these columns:
|
|
||||||
# (conversation_id, message_type, content, sequence_number)
|
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
INSERT INTO messages (conversation_id, message_type, content, sequence_number)
|
INSERT INTO messages (conversation_id, pipeline_id, message_type, content, sequence_number)
|
||||||
VALUES (%s, %s, %s, %s)
|
VALUES (%s, %s, %s, %s, %s)
|
||||||
""",
|
""",
|
||||||
(conversation_id, msg_type.value, content, sequence),
|
(conversation_id, pipeline_id, msg_type.value, content, sequence),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_conv_number(self, conversation_id: str) -> int:
|
def get_conv_number(self, conversation_id: str) -> int:
|
||||||
@@ -46,25 +63,33 @@ class ConversationStore:
|
|||||||
"""
|
"""
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor() as cur:
|
with conn.cursor() as cur:
|
||||||
cur.execute("""
|
cur.execute(
|
||||||
|
"""
|
||||||
SELECT COUNT(*)
|
SELECT COUNT(*)
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE conversation_id = %s
|
WHERE conversation_id = %s
|
||||||
""", (conversation_id,))
|
""",
|
||||||
|
(conversation_id,),
|
||||||
|
)
|
||||||
return int(cur.fetchone()[0])
|
return int(cur.fetchone()[0])
|
||||||
|
|
||||||
def get_conversation(self, conversation_id: str) -> List[Dict]:
|
def get_conversation(self, conversation_id: str) -> List[Dict]:
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
|
with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
|
||||||
cur.execute("""
|
cur.execute(
|
||||||
|
"""
|
||||||
SELECT message_type, content, sequence_number, created_at
|
SELECT message_type, content, sequence_number, created_at
|
||||||
FROM messages
|
FROM messages
|
||||||
WHERE conversation_id = %s
|
WHERE conversation_id = %s
|
||||||
ORDER BY sequence_number ASC
|
ORDER BY sequence_number ASC
|
||||||
""", (conversation_id,))
|
""",
|
||||||
|
(conversation_id,),
|
||||||
|
)
|
||||||
return cur.fetchall()
|
return cur.fetchall()
|
||||||
|
|
||||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage]):
|
def record_message_list(
|
||||||
|
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
|
||||||
|
):
|
||||||
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||||
curr_len = self.get_conv_number(conv_id)
|
curr_len = self.get_conv_number(conv_id)
|
||||||
to_add_msg = inp[curr_len:]
|
to_add_msg = inp[curr_len:]
|
||||||
@@ -73,12 +98,13 @@ class ConversationStore:
|
|||||||
# Serialize dict/list content to JSON string
|
# Serialize dict/list content to JSON string
|
||||||
if not isinstance(content, str):
|
if not isinstance(content, str):
|
||||||
content = json.dumps(content, ensure_ascii=False, indent=4)
|
content = json.dumps(content, ensure_ascii=False, indent=4)
|
||||||
self.add_message(conv_id, self._get_type(msg), content, curr_len + 1)
|
self.add_message(
|
||||||
|
conv_id, self._get_type(msg), content, curr_len + 1, pipeline_id
|
||||||
|
)
|
||||||
curr_len += 1
|
curr_len += 1
|
||||||
return curr_len
|
return curr_len
|
||||||
|
|
||||||
|
def _get_type(self, msg: BaseMessage) -> MessageType:
|
||||||
def _get_type(self, msg:BaseMessage) -> MessageType:
|
|
||||||
if isinstance(msg, HumanMessage):
|
if isinstance(msg, HumanMessage):
|
||||||
return MessageType.HUMAN
|
return MessageType.HUMAN
|
||||||
elif isinstance(msg, AIMessage):
|
elif isinstance(msg, AIMessage):
|
||||||
@@ -89,11 +115,13 @@ class ConversationStore:
|
|||||||
raise ValueError(f"Unknown message type: {type(msg)}")
|
raise ValueError(f"Unknown message type: {type(msg)}")
|
||||||
|
|
||||||
|
|
||||||
class ConversationPrinter:
|
class ConversationPrinter(BaseConvStore):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.id_dic = {}
|
self.id_dic = {}
|
||||||
|
|
||||||
def record_message_list(self, conv_id:str, inp:List[BaseMessage]):
|
def record_message_list(
|
||||||
|
self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
|
||||||
|
):
|
||||||
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
inp = [e for e in inp if not isinstance(e, SystemMessage)]
|
||||||
curr_len = self.id_dic.get(conv_id, 0)
|
curr_len = self.id_dic.get(conv_id, 0)
|
||||||
to_print_msg = inp[curr_len:]
|
to_print_msg = inp[curr_len:]
|
||||||
@@ -106,9 +134,11 @@ class ConversationPrinter:
|
|||||||
else:
|
else:
|
||||||
self.id_dic[conv_id] += len(to_print_msg)
|
self.id_dic[conv_id] += len(to_print_msg)
|
||||||
|
|
||||||
|
|
||||||
CONV_STORE = ConversationStore()
|
CONV_STORE = ConversationStore()
|
||||||
# CONV_STORE = ConversationPrinter()
|
# CONV_STORE = ConversationPrinter()
|
||||||
|
|
||||||
|
|
||||||
def use_printer():
|
def use_printer():
|
||||||
global CONV_STORE
|
global CONV_STORE
|
||||||
CONV_STORE = ConversationPrinter()
|
CONV_STORE = ConversationPrinter()
|
||||||
|
|||||||
230
lang_agent/components/server_pipeline_manager.py
Normal file
230
lang_agent/components/server_pipeline_manager.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
from fastapi import HTTPException
|
||||||
|
from typing import Any, Dict, Optional, Tuple
|
||||||
|
from pathlib import Path as FsPath
|
||||||
|
import os.path as osp
|
||||||
|
import json
|
||||||
|
import copy
|
||||||
|
from threading import RLock
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||||
|
from lang_agent.config.core_config import load_tyro_conf
|
||||||
|
|
||||||
|
|
||||||
|
class ServerPipelineManager:
    """Lazily load and cache multiple pipelines keyed by a client-facing pipeline id.

    Pipeline specs and the optional per-API-key policy are read from a JSON
    registry file (``load_registry`` / ``refresh_registry_if_needed``).
    A pipeline is only built on the first ``get_pipeline`` call for its id and
    is then cached until its registry spec changes or is removed.  Mutable
    state is guarded by a re-entrant lock.
    """

    def __init__(self, default_pipeline_id: str, default_config: PipelineConfig):
        """
        Args:
            default_pipeline_id: id used when a request names no pipeline and
                the API key policy defines no default.
            default_config: config that is deep-copied for registry specs
                without a ``config_file``.
        """
        self.default_pipeline_id = default_pipeline_id
        self.default_config = default_config
        # pipeline_id -> {"enabled", "config_file", "llm_name"}
        self._pipeline_specs: Dict[str, Dict[str, Any]] = {}
        # api_key -> policy dict ("allowed_pipeline_ids", "default_pipeline_id")
        self._api_key_policy: Dict[str, Dict[str, Any]] = {}
        # Built pipelines and the model name reported for each of them.
        self._pipelines: Dict[str, Pipeline] = {}
        self._pipeline_llm: Dict[str, str] = {}
        self._registry_path: Optional[str] = None
        self._registry_mtime_ns: Optional[int] = None
        self._lock = RLock()

    @staticmethod
    def _resolve_repo_path(raw_path: str) -> str:
        """Return ``raw_path`` unchanged if absolute, else resolved against the repo root."""
        path = FsPath(raw_path)
        if path.is_absolute():
            return str(path)
        # server_pipeline_manager.py is under <repo>/lang_agent/components/,
        # so parents[2] is the repository root.
        root = FsPath(__file__).resolve().parents[2]
        return str((root / path).resolve())

    def _resolve_registry_path(self, registry_path: str) -> str:
        """Resolve the registry file path (relative paths are repo-root based)."""
        return self._resolve_repo_path(registry_path)

    def _stat_registry_mtime_ns(self, abs_path: str) -> int:
        """Return the registry file's modification time in nanoseconds."""
        return FsPath(abs_path).stat().st_mtime_ns

    def _read_registry(self, abs_path: str) -> Dict[str, Any]:
        """Parse the registry JSON file and return its top-level object."""
        with open(abs_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _apply_registry(self, abs_path: str, registry: Dict[str, Any], mtime_ns: int) -> bool:
        """
        Validate a parsed registry and swap it in as the active one.

        Cached pipelines whose spec was removed or modified are dropped so
        they are rebuilt lazily from the new spec.

        Args:
            abs_path: absolute path the registry was read from.
            registry: parsed registry content.
            mtime_ns: modification time observed for ``abs_path``.

        Returns:
            True when pipeline specs or the API key policy changed.

        Raises:
            ValueError: if ``pipelines`` is missing/not an object, a spec is
                not an object, no pipeline is defined, or ``api_keys`` is a
                non-empty non-object value.
        """
        pipelines = registry.get("pipelines")
        if not isinstance(pipelines, dict):
            raise ValueError("`pipelines` in pipeline registry must be an object.")

        parsed_specs: Dict[str, Dict[str, Any]] = {}
        for pipeline_id, spec in pipelines.items():
            if not isinstance(spec, dict):
                raise ValueError(
                    f"pipeline spec for `{pipeline_id}` must be an object."
                )
            parsed_specs[pipeline_id] = {
                "enabled": bool(spec.get("enabled", True)),
                "config_file": spec.get("config_file"),
                "llm_name": spec.get("llm_name"),
            }
        if not parsed_specs:
            raise ValueError("pipeline registry must define at least one pipeline.")

        # Any falsy value (missing key, null, empty object) means "no policy";
        # normalise it to {} so the stored policy is always a dict.
        api_key_policy = registry.get("api_keys") or {}
        if not isinstance(api_key_policy, dict):
            raise ValueError("`api_keys` in pipeline registry must be an object.")

        with self._lock:
            old_specs = self._pipeline_specs
            old_policy = self._api_key_policy
            old_mtime = self._registry_mtime_ns

            old_ids = set(old_specs)
            new_ids = set(parsed_specs)
            removed = old_ids - new_ids
            added = new_ids - old_ids
            modified = {
                pipeline_id
                for pipeline_id in (old_ids & new_ids)
                if old_specs[pipeline_id] != parsed_specs[pipeline_id]
            }
            changed = bool(added or removed or modified or old_policy != api_key_policy)

            # Drop stale cache entries for deleted/changed pipelines so future requests
            # lazily rebuild from the refreshed registry spec.
            for pipeline_id in removed | modified:
                self._pipelines.pop(pipeline_id, None)
                self._pipeline_llm.pop(pipeline_id, None)

            self._pipeline_specs = parsed_specs
            self._api_key_policy = api_key_policy
            self._registry_path = abs_path
            self._registry_mtime_ns = mtime_ns

        if changed:
            logger.info(
                "refreshed pipeline registry: {} | added={} modified={} removed={} mtime={}",
                abs_path,
                sorted(added),
                sorted(modified),
                sorted(removed),
                mtime_ns,
            )
        elif old_mtime != mtime_ns:
            logger.debug("pipeline registry mtime changed but specs were unchanged: {}", abs_path)
        return changed

    def load_registry(self, registry_path: str) -> None:
        """
        Load the registry file and make it the active registry.

        Raises:
            ValueError: if the file does not exist or its content is invalid.
        """
        abs_path = self._resolve_registry_path(registry_path)
        if not osp.exists(abs_path):
            raise ValueError(f"pipeline registry file not found: {abs_path}")
        # Stat BEFORE reading: if the file is rewritten in between, the stored
        # mtime is older than the file's and the next refresh re-reads it.
        # Reading first could pin stale content under the newer mtime.
        mtime_ns = self._stat_registry_mtime_ns(abs_path)
        registry = self._read_registry(abs_path)
        self._apply_registry(abs_path=abs_path, registry=registry, mtime_ns=mtime_ns)

    def refresh_registry_if_needed(
        self, registry_path: Optional[str] = None, force: bool = False
    ) -> bool:
        """
        Re-read the registry when the file changed on disk.

        Args:
            registry_path: registry file to use; defaults to the path of the
                last applied registry.
            force: re-read even if the modification time is unchanged.

        Returns:
            True when specs or API key policy changed, False otherwise
            (including when the reload was skipped).

        Raises:
            ValueError: if no path is known or the file does not exist.
        """
        abs_path = (
            self._resolve_registry_path(registry_path)
            if registry_path
            else self._registry_path
        )
        if not abs_path:
            raise ValueError("registry path is not initialized")
        if not osp.exists(abs_path):
            raise ValueError(f"pipeline registry file not found: {abs_path}")

        mtime_ns = self._stat_registry_mtime_ns(abs_path)
        with self._lock:
            # Only skip when the SAME file is unchanged; a different path with
            # a coincidentally equal mtime must still be loaded.
            unchanged = (
                self._registry_path == abs_path
                and self._registry_mtime_ns == mtime_ns
            )
            if not force and unchanged:
                return False

        registry = self._read_registry(abs_path)
        return self._apply_registry(abs_path=abs_path, registry=registry, mtime_ns=mtime_ns)

    def _resolve_config_path(self, config_file: str) -> str:
        """Resolve a pipeline config path.

        Relative config paths are resolved from the repository root for
        consistency with docker-compose and tests.
        """
        return self._resolve_repo_path(config_file)

    def _build_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
        """
        Construct the pipeline described by the registry spec of ``pipeline_id``.

        Returns:
            ``(pipeline, llm_name)`` where ``llm_name`` is the config's
            ``llm_name`` if present, else the registry value, else
            ``"unknown-model"``.

        Raises:
            HTTPException: 404 for an unknown id, 403 for a disabled pipeline.
            ValueError: if ``config_file`` did not load into an object with
                a ``setup`` method.
        """
        spec = self._pipeline_specs.get(pipeline_id)
        if spec is None:
            raise HTTPException(
                status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
            )
        if not spec.get("enabled", True):
            raise HTTPException(
                status_code=403, detail=f"Pipeline disabled: {pipeline_id}"
            )

        config_file = spec.get("config_file")
        registry_llm_name = spec.get("llm_name")
        if config_file:
            cfg = load_tyro_conf(self._resolve_config_path(config_file))
            if not hasattr(cfg, "setup"):
                raise ValueError(
                    "config_file for pipeline "
                    f"`{pipeline_id}` did not deserialize to a config object. "
                    "Rebuild the pipeline via /v1/pipelines to regenerate a "
                    "valid serialized PipelineConfig file."
                )
        else:
            # Copy so per-pipeline overrides never leak into the shared default.
            cfg = copy.deepcopy(self.default_config)
            # NOTE(review): the registry llm_name override is applied to the
            # default-config branch only - confirm this matches the intended
            # scope for pipelines loaded from a config_file.
            if registry_llm_name is not None and hasattr(cfg, "llm_name"):
                cfg.llm_name = registry_llm_name

        p = cfg.setup()
        llm_name = str(getattr(cfg, "llm_name", registry_llm_name or "unknown-model"))
        return p, llm_name

    def _authorize(self, api_key: str, pipeline_id: str) -> None:
        """
        Enforce the per-key ``allowed_pipeline_ids`` policy, if one applies.

        Raises:
            HTTPException: 403 when the key's policy lists allowed pipelines
                and ``pipeline_id`` is not among them.
        """
        if not self._api_key_policy:
            return

        policy = self._api_key_policy.get(api_key)
        # NOTE(review): keys without a policy entry are not restricted here;
        # confirm key validity is enforced by the caller.
        if policy is None:
            return

        allowed = policy.get("allowed_pipeline_ids")
        if allowed and pipeline_id not in allowed:
            raise HTTPException(
                status_code=403,
                detail=f"pipeline_id `{pipeline_id}` is not allowed for this API key",
            )

    def resolve_pipeline_id(
        self, body: Dict[str, Any], app_id: Optional[str], api_key: str
    ) -> str:
        """
        Determine which pipeline a request targets and check access to it.

        Lookup order: ``body["pipeline_id"]``, ``body["input"]["pipeline_id"]``,
        ``app_id``, the API key's ``default_pipeline_id``, then
        ``self.default_pipeline_id``.

        Raises:
            HTTPException: 404 if the resolved id is not in the registry,
                403 if the API key policy forbids it.
        """
        body_input = body.get("input", {})
        pipeline_id = (
            body.get("pipeline_id")
            or (body_input.get("pipeline_id") if isinstance(body_input, dict) else None)
            or app_id
        )

        with self._lock:
            if not pipeline_id:
                key_policy = (
                    self._api_key_policy.get(api_key, {}) if self._api_key_policy else {}
                )
                pipeline_id = key_policy.get(
                    "default_pipeline_id", self.default_pipeline_id
                )

            if pipeline_id not in self._pipeline_specs:
                raise HTTPException(
                    status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
                )

            self._authorize(api_key, pipeline_id)
            return pipeline_id

    def get_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
        """
        Return the cached ``(pipeline, llm_name)`` for ``pipeline_id``, building it on first use.

        Raises:
            HTTPException / ValueError: propagated from ``_build_pipeline``.
        """
        with self._lock:
            cached = self._pipelines.get(pipeline_id)
            if cached is not None:
                return cached, self._pipeline_llm[pipeline_id]

            # Build while holding the lock to avoid duplicate construction for
            # the same pipeline on concurrent first requests.
            pipeline_obj, llm_name = self._build_pipeline(pipeline_id)
            self._pipelines[pipeline_id] = pipeline_obj
            self._pipeline_llm[pipeline_id] = llm_name
            logger.info(f"lazy-loaded pipeline_id={pipeline_id} model={llm_name}")
            return pipeline_obj, llm_name
|
||||||
@@ -1 +1,19 @@
|
|||||||
from lang_agent.config.core_config import InstantiateConfig, KeyConfig, ToolConfig, LLMKeyConfig
|
from lang_agent.config.core_config import (
|
||||||
|
InstantiateConfig,
|
||||||
|
ToolConfig,
|
||||||
|
LLMKeyConfig,
|
||||||
|
LLMNodeConfig,
|
||||||
|
load_tyro_conf,
|
||||||
|
resolve_llm_api_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
from lang_agent.config.constants import (
|
||||||
|
MCP_CONFIG_PATH,
|
||||||
|
MCP_CONFIG_DEFAULT_CONTENT,
|
||||||
|
PIPELINE_REGISTRY_PATH,
|
||||||
|
VALID_API_KEYS,
|
||||||
|
API_KEY_HEADER,
|
||||||
|
API_KEY_HEADER_NO_ERROR,
|
||||||
|
_PROJECT_ROOT,
|
||||||
|
TY_BUILD_SCRIPT,
|
||||||
|
)
|
||||||
|
|||||||
19
lang_agent/config/constants.py
Normal file
19
lang_agent/config/constants.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
import os
|
||||||
|
import re
|
||||||
|
import os.path as osp
|
||||||
|
from fastapi.security import APIKeyHeader
|
||||||
|
|
||||||
|
# Repository root: three directory levels above this file.
_PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))

MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json")
MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n"

# NOTE: the fancy looking api key is FAST_AUTH_KEYS from .env
PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json")

# Same "Authorization" header twice: one variant errors automatically when
# the header is missing, the other lets the handler deal with its absence.
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
API_KEY_HEADER_NO_ERROR = APIKeyHeader(name="Authorization", auto_error=False)

# Comma-separated keys from FAST_AUTH_KEYS.  Entries are stripped so that
# "key1, key2" does not register " key2" (which could never match), and
# blank entries are dropped.
VALID_API_KEYS = {
    key.strip()
    for key in os.environ.get("FAST_AUTH_KEYS", "").split(",")
    if key.strip()
}

TY_BUILD_SCRIPT = osp.join(_PROJECT_ROOT, "lang_agent", "config", "ty_build_config.py")
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from dataclasses import dataclass, is_dataclass, fields, MISSING
|
from dataclasses import dataclass, is_dataclass, fields, MISSING
|
||||||
from typing import Any, Tuple, Type
|
from typing import Any, Tuple, Type, Optional
|
||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
@@ -10,6 +10,20 @@ from dotenv import load_dotenv
|
|||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_llm_api_key(api_key: Optional[str]) -> Optional[str]:
    """Resolve the API key for OpenAI-compatible providers.

    An explicit key wins.  A missing key, a blank/whitespace-only key, or the
    placeholder ``"wrong-key"`` falls back to the ``ALI_API_KEY`` environment
    variable and then to ``OPENAI_API_KEY``.

    Side effect: when a key is resolved and ``OPENAI_API_KEY`` is not set in
    the environment, it is set to the resolved key (an existing value is
    never overwritten).

    Args:
        api_key: key supplied by configuration, possibly empty or a placeholder.

    Returns:
        The resolved key (surrounding whitespace removed), or None if no key
        is available anywhere.
    """
    # Treat whitespace-only input like an empty key; previously "  " counted
    # as a real key and silently disabled the environment fallback.
    candidate = api_key.strip() if isinstance(api_key, str) else api_key
    if candidate not in (None, "", "wrong-key"):
        resolved_key = candidate
    else:
        resolved_key = os.environ.get("ALI_API_KEY") or os.environ.get("OPENAI_API_KEY")

    # Some OpenAI-compatible integrations still read OPENAI_API_KEY from env.
    if resolved_key and not os.environ.get("OPENAI_API_KEY"):
        os.environ["OPENAI_API_KEY"] = resolved_key

    return resolved_key
|
||||||
|
|
||||||
## NOTE: base classes taken from nerfstudio
|
## NOTE: base classes taken from nerfstudio
|
||||||
class PrintableConfig:
|
class PrintableConfig:
|
||||||
"""
|
"""
|
||||||
@@ -72,79 +86,20 @@ class InstantiateConfig(PrintableConfig):
|
|||||||
将配置保存到 YAML 文件
|
将配置保存到 YAML 文件
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def mask_value(key, value):
|
# Persist the full config object (including type tags) so it can be
|
||||||
"""
|
# deserialized back into config instances with methods like .setup().
|
||||||
Apply masking if key is secret-like
|
# Secret masking is intentionally handled by __str__ for printing/logging,
|
||||||
如果键是敏感的,应用掩码
|
# not when writing to disk.
|
||||||
|
with open(filename, "w", encoding="utf-8") as f:
|
||||||
检查键是否敏感(如包含 "secret" 或 "api_key"),如果是,则对值进行掩码处理
|
yaml.dump(self, f)
|
||||||
|
|
||||||
"""
|
|
||||||
if isinstance(value, str) and self.is_secrete(key):
|
|
||||||
sval = str(value)
|
|
||||||
return sval[:3] + "*" * (len(sval) - 6) + sval[-3:]
|
|
||||||
return value
|
|
||||||
|
|
||||||
def to_masked_serializable(obj):
|
|
||||||
|
|
||||||
"""
|
|
||||||
Recursively convert dataclasses and containers to serializable with masked secrets
|
|
||||||
|
|
||||||
递归地将数据类和容器转换为可序列化的格式,同时对敏感信息进行掩码处理
|
|
||||||
|
|
||||||
"""
|
|
||||||
if is_dataclass(obj):
|
|
||||||
out = {}
|
|
||||||
for k, v in vars(obj).items():
|
|
||||||
if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
|
|
||||||
out[k] = to_masked_serializable(v)
|
|
||||||
else:
|
|
||||||
out[k] = mask_value(k, v)
|
|
||||||
return out
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
out = {}
|
|
||||||
for k, v in obj.items():
|
|
||||||
if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
|
|
||||||
out[k] = to_masked_serializable(v)
|
|
||||||
else:
|
|
||||||
# k might be a non-string; convert to str for is_secrete check consistency
|
|
||||||
key_str = str(k)
|
|
||||||
out[k] = mask_value(key_str, v)
|
|
||||||
return out
|
|
||||||
if isinstance(obj, list):
|
|
||||||
return [to_masked_serializable(v) for v in obj]
|
|
||||||
if isinstance(obj, tuple):
|
|
||||||
return tuple(to_masked_serializable(v) for v in obj)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
masked = to_masked_serializable(self)
|
|
||||||
with open(filename, 'w') as f:
|
|
||||||
yaml.dump(masked, f)
|
|
||||||
logger.info(f"[yellow]config saved to: {filename}[/yellow]")
|
logger.info(f"[yellow]config saved to: {filename}[/yellow]")
|
||||||
|
|
||||||
def get_name(self):
|
def get_name(self):
|
||||||
return self.__class__.__name__
|
return self.__class__.__name__
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class KeyConfig(InstantiateConfig):
|
class LLMKeyConfig(InstantiateConfig):
|
||||||
|
|
||||||
api_key:str = None
|
|
||||||
"""api key for llm"""
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.api_key == "wrong-key" or self.api_key is None:
|
|
||||||
self.api_key = os.environ.get("ALI_API_KEY")
|
|
||||||
if self.api_key is None:
|
|
||||||
logger.error(f"no ALI_API_KEY provided for embedding")
|
|
||||||
else:
|
|
||||||
logger.info("ALI_API_KEY loaded from environ")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMKeyConfig(KeyConfig):
|
|
||||||
llm_name: str = "qwen-plus"
|
llm_name: str = "qwen-plus"
|
||||||
"""name of llm"""
|
"""name of llm"""
|
||||||
|
|
||||||
@@ -154,6 +109,30 @@ class LLMKeyConfig(KeyConfig):
|
|||||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
"""base url; could be used to overwrite the baseurl in llm provider"""
|
||||||
|
|
||||||
|
api_key:str = None
|
||||||
|
"""api key for llm"""
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
original_api_key = self.api_key
|
||||||
|
self.api_key = resolve_llm_api_key(self.api_key)
|
||||||
|
if self.api_key is None:
|
||||||
|
logger.error("no ALI_API_KEY or OPENAI_API_KEY provided for embedding")
|
||||||
|
elif original_api_key in (None, "", "wrong-key"):
|
||||||
|
logger.info("LLM API key loaded from environment")
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class LLMNodeConfig(LLMKeyConfig):
    """
    class is for LLM nodes that has system prompt config
    """

    # Optional pipeline identifier; defaults to None (unset).
    # NOTE(review): the prompt-loading behaviour described in the field
    # docstring is implemented outside this class - confirm against the loader.
    pipeline_id: Optional[str] = None
    """If set, load prompts from database (with file fallback)"""

    # Optional prompt-set identifier; defaults to None (unset).
    prompt_set_id: Optional[str] = None
    """If set, load from this specific prompt set instead of the active one"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolConfig(InstantiateConfig):
|
class ToolConfig(InstantiateConfig):
|
||||||
@@ -173,7 +152,7 @@ def load_tyro_conf(filename: str, inp_conf = None) -> InstantiateConfig:
|
|||||||
"""
|
"""
|
||||||
config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader)
|
config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader)
|
||||||
|
|
||||||
config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
|
# config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def is_default(instance, field_):
|
def is_default(instance, field_):
|
||||||
@@ -260,21 +239,3 @@ def ovewrite_config(loaded_conf, inp_conf):
|
|||||||
setattr(loaded_conf, field_name, new_value)
|
setattr(loaded_conf, field_name, new_value)
|
||||||
|
|
||||||
return loaded_conf
|
return loaded_conf
|
||||||
|
|
||||||
|
|
||||||
def mcp_langchain_to_ws_config(conf:Dict[str, Dict[str, str]]):
|
|
||||||
serv_conf = {}
|
|
||||||
|
|
||||||
for k, v in conf.items():
|
|
||||||
|
|
||||||
if v["transport"] == "stdio":
|
|
||||||
serv_conf[k] = {
|
|
||||||
"type" : v["transport"],
|
|
||||||
"command": v["command"],
|
|
||||||
"args": v["args"],
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unsupported transport {v['transport']} for MCP {k}. Skipping...")
|
|
||||||
continue
|
|
||||||
|
|
||||||
return {"mcpServers":serv_conf}
|
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class DBConfigManager:
|
|||||||
if pipeline_id and graph_id:
|
if pipeline_id and graph_id:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list
|
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list, api_key
|
||||||
FROM prompt_sets
|
FROM prompt_sets
|
||||||
WHERE pipeline_id = %s AND graph_id = %s
|
WHERE pipeline_id = %s AND graph_id = %s
|
||||||
ORDER BY updated_at DESC, created_at DESC
|
ORDER BY updated_at DESC, created_at DESC
|
||||||
@@ -38,7 +38,7 @@ class DBConfigManager:
|
|||||||
elif pipeline_id:
|
elif pipeline_id:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list
|
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list, api_key
|
||||||
FROM prompt_sets
|
FROM prompt_sets
|
||||||
WHERE pipeline_id = %s
|
WHERE pipeline_id = %s
|
||||||
ORDER BY updated_at DESC, created_at DESC
|
ORDER BY updated_at DESC, created_at DESC
|
||||||
@@ -48,7 +48,7 @@ class DBConfigManager:
|
|||||||
elif graph_id:
|
elif graph_id:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list
|
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list, api_key
|
||||||
FROM prompt_sets
|
FROM prompt_sets
|
||||||
WHERE graph_id = %s
|
WHERE graph_id = %s
|
||||||
ORDER BY updated_at DESC, created_at DESC
|
ORDER BY updated_at DESC, created_at DESC
|
||||||
@@ -58,7 +58,7 @@ class DBConfigManager:
|
|||||||
else:
|
else:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list
|
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list, api_key
|
||||||
FROM prompt_sets
|
FROM prompt_sets
|
||||||
ORDER BY updated_at DESC, created_at DESC
|
ORDER BY updated_at DESC, created_at DESC
|
||||||
"""
|
"""
|
||||||
@@ -76,6 +76,7 @@ class DBConfigManager:
|
|||||||
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
|
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
|
||||||
"updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
|
"updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
|
||||||
"tool_keys": self._parse_tool_list(row.get("list")),
|
"tool_keys": self._parse_tool_list(row.get("list")),
|
||||||
|
"api_key": row.get("api_key") or "",
|
||||||
}
|
}
|
||||||
for row in rows
|
for row in rows
|
||||||
]
|
]
|
||||||
@@ -88,7 +89,7 @@ class DBConfigManager:
|
|||||||
with conn.cursor(row_factory=dict_row) as cur:
|
with conn.cursor(row_factory=dict_row) as cur:
|
||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list
|
SELECT id, pipeline_id, graph_id, name, description, is_active, created_at, updated_at, list, api_key
|
||||||
FROM prompt_sets
|
FROM prompt_sets
|
||||||
WHERE id = %s AND pipeline_id = %s
|
WHERE id = %s AND pipeline_id = %s
|
||||||
""",
|
""",
|
||||||
@@ -109,6 +110,7 @@ class DBConfigManager:
|
|||||||
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
|
"created_at": row["created_at"].isoformat() if row["created_at"] else None,
|
||||||
"updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
|
"updated_at": row["updated_at"].isoformat() if row["updated_at"] else None,
|
||||||
"tool_keys": self._parse_tool_list(row.get("list")),
|
"tool_keys": self._parse_tool_list(row.get("list")),
|
||||||
|
"api_key": row.get("api_key") or "",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_config(
|
def get_config(
|
||||||
@@ -160,6 +162,7 @@ class DBConfigManager:
|
|||||||
prompt_set_id: Optional[str],
|
prompt_set_id: Optional[str],
|
||||||
tool_list: Optional[Sequence[str]],
|
tool_list: Optional[Sequence[str]],
|
||||||
prompt_dict: Optional[Mapping[str, str]],
|
prompt_dict: Optional[Mapping[str, str]],
|
||||||
|
api_key: Optional[str] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Persist prompt + tool configuration.
|
Persist prompt + tool configuration.
|
||||||
@@ -182,6 +185,7 @@ class DBConfigManager:
|
|||||||
|
|
||||||
normalized_prompt_dict = self._normalize_prompt_dict(prompt_dict)
|
normalized_prompt_dict = self._normalize_prompt_dict(prompt_dict)
|
||||||
tool_csv = self._join_tool_list(tool_list)
|
tool_csv = self._join_tool_list(tool_list)
|
||||||
|
normalized_api_key = self._normalize_api_key(api_key)
|
||||||
|
|
||||||
with psycopg.connect(self.conn_str) as conn:
|
with psycopg.connect(self.conn_str) as conn:
|
||||||
resolved_set_id, _ = self._resolve_prompt_set(
|
resolved_set_id, _ = self._resolve_prompt_set(
|
||||||
@@ -200,10 +204,13 @@ class DBConfigManager:
|
|||||||
cur.execute(
|
cur.execute(
|
||||||
"""
|
"""
|
||||||
UPDATE prompt_sets
|
UPDATE prompt_sets
|
||||||
SET list = %s, graph_id = COALESCE(%s, graph_id), updated_at = now()
|
SET list = %s,
|
||||||
|
graph_id = COALESCE(%s, graph_id),
|
||||||
|
api_key = COALESCE(%s, api_key),
|
||||||
|
updated_at = now()
|
||||||
WHERE id = %s
|
WHERE id = %s
|
||||||
""",
|
""",
|
||||||
(tool_csv, normalized_graph_id, resolved_set_id),
|
(tool_csv, normalized_graph_id, normalized_api_key, resolved_set_id),
|
||||||
)
|
)
|
||||||
|
|
||||||
keys = list(normalized_prompt_dict.keys())
|
keys = list(normalized_prompt_dict.keys())
|
||||||
@@ -341,3 +348,8 @@ class DBConfigManager:
|
|||||||
return None
|
return None
|
||||||
value = str(graph_id).strip()
|
value = str(graph_id).strip()
|
||||||
return value or None
|
return value or None
|
||||||
|
|
||||||
|
def _normalize_api_key(self, api_key: Optional[str]) -> Optional[str]:
|
||||||
|
if api_key is None:
|
||||||
|
return None
|
||||||
|
return str(api_key).strip()
|
||||||
13
lang_agent/config/ty_build_config.py
Normal file
13
lang_agent/config/ty_build_config.py
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
import tyro
|
||||||
|
from loguru import logger
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
from lang_agent.pipeline import PipelineConfig
|
||||||
|
|
||||||
|
def build_conf(pipeline:PipelineConfig, save_path:str):
|
||||||
|
os.makedirs(osp.dirname(save_path), exist_ok=True)
|
||||||
|
logger.info(pipeline)
|
||||||
|
pipeline.save_config(save_path)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
tyro.cli(build_conf)
|
||||||
@@ -34,13 +34,14 @@ class Calculator(LangToolBase):
|
|||||||
|
|
||||||
def calculator(self, python_expression: str) -> dict:
|
def calculator(self, python_expression: str) -> dict:
|
||||||
"""For mathamatical calculation, always use this tool to calculate the result of a python expression. You can use 'math' or 'random' directly, without 'import'."""
|
"""For mathamatical calculation, always use this tool to calculate the result of a python expression. You can use 'math' or 'random' directly, without 'import'."""
|
||||||
|
# time.sleep(20)
|
||||||
result = eval(python_expression, {"math": math, "random": random})
|
result = eval(python_expression, {"math": math, "random": random})
|
||||||
return {"success": True, "result": result}
|
return {"success": True, "result": result}
|
||||||
|
|
||||||
async def calculator_async(self, python_expression: str) -> dict:
|
async def calculator_async(self, python_expression: str) -> dict:
|
||||||
"""Async version: runs eval in a thread pool to avoid blocking the event loop."""
|
"""Async version: runs eval in a thread pool to avoid blocking the event loop."""
|
||||||
async with get_semaphore():
|
async with get_semaphore():
|
||||||
await asyncio.sleep(5) # Simulate delay for testing
|
# await asyncio.sleep(20) # Simulate delay for testing
|
||||||
result = await asyncio.to_thread(
|
result = await asyncio.to_thread(
|
||||||
eval, python_expression, {"math": math, "random": random}
|
eval, python_expression, {"math": math, "random": random}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ class Evaluator:
|
|||||||
|
|
||||||
df_m.to_csv(metric_f)
|
df_m.to_csv(metric_f)
|
||||||
|
|
||||||
self.config.save_config(f"{head_path}-{n_exp}.yml")
|
self.config.save_config(f"{head_path}-{n_exp}.yaml")
|
||||||
|
|
||||||
def format_result_df(self, df:pd.DataFrame):
|
def format_result_df(self, df:pd.DataFrame):
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from typing import Type, Callable, List
|
|||||||
import tyro
|
import tyro
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import LLMKeyConfig
|
||||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
@@ -11,7 +11,7 @@ from langchain_core.messages import BaseMessage, ToolMessage
|
|||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class ValidatorConfig(KeyConfig):
|
class ValidatorConfig(LLMKeyConfig):
|
||||||
_target: Type = field(default_factory=lambda:Validator)
|
_target: Type = field(default_factory=lambda:Validator)
|
||||||
|
|
||||||
|
|
||||||
@@ -34,9 +34,9 @@ class Validator:
|
|||||||
|
|
||||||
def populate_modules(self):
|
def populate_modules(self):
|
||||||
self.judge_llm = init_chat_model(
|
self.judge_llm = init_chat_model(
|
||||||
model="qwen-plus",
|
model=self.config.llm_name,
|
||||||
model_provider="openai",
|
model_provider=self.config.llm_provider,
|
||||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
base_url=self.config.base_url,
|
||||||
api_key=self.config.api_key
|
api_key=self.config.api_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
0
lang_agent/fastapi_server/__init__.py
Normal file
0
lang_agent/fastapi_server/__init__.py
Normal file
33
lang_agent/fastapi_server/combined.py
Normal file
33
lang_agent/fastapi_server/combined.py
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
import uvicorn
|
||||||
|
|
||||||
|
from lang_agent.fastapi_server.front_apis import app as front_app
|
||||||
|
from lang_agent.fastapi_server.server_dashscope import create_dashscope_router
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(
|
||||||
|
title="Combined Front + DashScope APIs",
|
||||||
|
description=(
|
||||||
|
"Single-process app exposing front_apis control endpoints and "
|
||||||
|
"DashScope-compatible chat endpoints."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keep existing /v1/... admin APIs unchanged.
|
||||||
|
app.include_router(front_app.router)
|
||||||
|
|
||||||
|
# Add DashScope endpoints at their existing URLs. We intentionally skip
|
||||||
|
# DashScope's root/health routes to avoid clashing with front_apis.
|
||||||
|
app.include_router(create_dashscope_router(include_meta_routes=False))
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
uvicorn.run(app, host="0.0.0.0", port=8500)
|
||||||
1007
lang_agent/fastapi_server/front_apis.py
Normal file
1007
lang_agent/fastapi_server/front_apis.py
Normal file
File diff suppressed because it is too large
Load Diff
365
lang_agent/fastapi_server/server_dashscope.py
Normal file
365
lang_agent/fastapi_server/server_dashscope.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request, Security
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import json
|
||||||
|
import uvicorn
|
||||||
|
from loguru import logger
|
||||||
|
import tyro
|
||||||
|
|
||||||
|
# Ensure we can import from project root
|
||||||
|
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))
|
||||||
|
|
||||||
|
from lang_agent.pipeline import PipelineConfig
|
||||||
|
from lang_agent.components.server_pipeline_manager import ServerPipelineManager
|
||||||
|
from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS
|
||||||
|
|
||||||
|
def _build_default_pipeline_config() -> PipelineConfig:
|
||||||
|
"""
|
||||||
|
Build import-time defaults without parsing CLI args.
|
||||||
|
|
||||||
|
This keeps module import safe for reuse by combined apps and tests.
|
||||||
|
"""
|
||||||
|
pipeline_config = PipelineConfig()
|
||||||
|
logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
|
||||||
|
return pipeline_config
|
||||||
|
|
||||||
|
|
||||||
|
def _build_pipeline_manager(base_config: PipelineConfig) -> ServerPipelineManager:
|
||||||
|
pipeline_manager = ServerPipelineManager(
|
||||||
|
default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
|
||||||
|
default_config=base_config,
|
||||||
|
)
|
||||||
|
pipeline_manager.load_registry(PIPELINE_REGISTRY_PATH)
|
||||||
|
return pipeline_manager
|
||||||
|
|
||||||
|
|
||||||
|
pipeline_config = _build_default_pipeline_config()
|
||||||
|
PIPELINE_MANAGER = _build_pipeline_manager(pipeline_config)
|
||||||
|
|
||||||
|
|
||||||
|
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
|
||||||
|
"""Verify the API key from Authorization header (Bearer token format)."""
|
||||||
|
key = api_key[7:] if api_key.startswith("Bearer ") else api_key
|
||||||
|
if VALID_API_KEYS and key not in VALID_API_KEYS:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
class DSMessage(BaseModel):
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
|
||||||
|
|
||||||
|
class DSApplicationCallRequest(BaseModel):
|
||||||
|
api_key: Optional[str] = Field(default=None)
|
||||||
|
app_id: Optional[str] = Field(default=None)
|
||||||
|
session_id: Optional[str] = Field(default=None)
|
||||||
|
messages: List[DSMessage]
|
||||||
|
stream: bool = Field(default=True)
|
||||||
|
# Optional overrides for pipeline behavior
|
||||||
|
thread_id: Optional[str] = Field(default="3")
|
||||||
|
|
||||||
|
|
||||||
|
def sse_chunks_from_stream(
|
||||||
|
chunk_generator, response_id: str, model: str = "qwen-flash"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Stream chunks from pipeline and format as SSE.
|
||||||
|
Accumulates text and sends incremental updates.
|
||||||
|
DashScope SDK expects accumulated text in each chunk (not deltas).
|
||||||
|
"""
|
||||||
|
created_time = int(time.time())
|
||||||
|
accumulated_text = ""
|
||||||
|
|
||||||
|
for chunk in chunk_generator:
|
||||||
|
if chunk:
|
||||||
|
accumulated_text += chunk
|
||||||
|
data = {
|
||||||
|
"request_id": response_id,
|
||||||
|
"code": 200,
|
||||||
|
"message": "OK",
|
||||||
|
"output": {
|
||||||
|
# DashScope SDK expects accumulated text, not empty or delta
|
||||||
|
"text": accumulated_text,
|
||||||
|
"created": created_time,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
"is_end": False,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
# Final message with complete text
|
||||||
|
final = {
|
||||||
|
"request_id": response_id,
|
||||||
|
"code": 200,
|
||||||
|
"message": "OK",
|
||||||
|
"output": {
|
||||||
|
"text": accumulated_text,
|
||||||
|
"created": created_time,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
"is_end": True,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(final)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
async def sse_chunks_from_astream(
|
||||||
|
chunk_generator, response_id: str, model: str = "qwen-flash"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Async version: Stream chunks from pipeline and format as SSE.
|
||||||
|
Accumulates text and sends incremental updates.
|
||||||
|
DashScope SDK expects accumulated text in each chunk (not deltas).
|
||||||
|
"""
|
||||||
|
created_time = int(time.time())
|
||||||
|
accumulated_text = ""
|
||||||
|
|
||||||
|
async for chunk in chunk_generator:
|
||||||
|
if chunk:
|
||||||
|
accumulated_text += chunk
|
||||||
|
data = {
|
||||||
|
"request_id": response_id,
|
||||||
|
"code": 200,
|
||||||
|
"message": "OK",
|
||||||
|
"output": {
|
||||||
|
"text": accumulated_text,
|
||||||
|
"created": created_time,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
"is_end": False,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
|
# Final message with complete text
|
||||||
|
final = {
|
||||||
|
"request_id": response_id,
|
||||||
|
"code": 200,
|
||||||
|
"message": "OK",
|
||||||
|
"output": {
|
||||||
|
"text": accumulated_text,
|
||||||
|
"created": created_time,
|
||||||
|
"model": model,
|
||||||
|
},
|
||||||
|
"is_end": True,
|
||||||
|
}
|
||||||
|
yield f"data: {json.dumps(final)}\n\n"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_messages(body: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||||
|
messages = body.get("messages")
|
||||||
|
body_input = body.get("input", {})
|
||||||
|
if messages is None and isinstance(body_input, dict):
|
||||||
|
messages = body_input.get("messages")
|
||||||
|
if messages is None and isinstance(body_input, dict):
|
||||||
|
prompt = body_input.get("prompt")
|
||||||
|
if isinstance(prompt, str):
|
||||||
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
|
if not messages:
|
||||||
|
raise HTTPException(status_code=400, detail="messages is required")
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_user_message(messages: List[Dict[str, Any]]) -> str:
|
||||||
|
user_msg = None
|
||||||
|
for m in reversed(messages):
|
||||||
|
role = m.get("role") if isinstance(m, dict) else None
|
||||||
|
content = m.get("content") if isinstance(m, dict) else None
|
||||||
|
if role == "user" and content:
|
||||||
|
user_msg = content
|
||||||
|
break
|
||||||
|
if user_msg is None:
|
||||||
|
last = messages[-1]
|
||||||
|
user_msg = last.get("content") if isinstance(last, dict) else str(last)
|
||||||
|
return user_msg
|
||||||
|
|
||||||
|
|
||||||
|
async def _process_dashscope_request(
|
||||||
|
body: Dict[str, Any],
|
||||||
|
app_id: Optional[str],
|
||||||
|
session_id: Optional[str],
|
||||||
|
api_key: str,
|
||||||
|
pipeline_manager: ServerPipelineManager,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
pipeline_manager.refresh_registry_if_needed()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"failed to refresh pipeline registry: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Failed to refresh pipeline registry: {e}")
|
||||||
|
|
||||||
|
req_app_id = app_id or body.get("app_id")
|
||||||
|
body_input = body.get("input", {}) if isinstance(body.get("input"), dict) else {}
|
||||||
|
req_session_id = session_id or body_input.get("session_id")
|
||||||
|
messages = _normalize_messages(body)
|
||||||
|
|
||||||
|
stream = body.get("stream")
|
||||||
|
if stream is None:
|
||||||
|
stream = body.get("parameters", {}).get("stream", True)
|
||||||
|
|
||||||
|
thread_id = body_input.get("session_id") or req_session_id or "3"
|
||||||
|
user_msg = _extract_user_message(messages)
|
||||||
|
|
||||||
|
pipeline_id = pipeline_manager.resolve_pipeline_id(
|
||||||
|
body=body, app_id=req_app_id, api_key=api_key
|
||||||
|
)
|
||||||
|
selected_pipeline, selected_model = pipeline_manager.get_pipeline(pipeline_id)
|
||||||
|
|
||||||
|
# Namespace thread ids to prevent memory collisions across pipelines.
|
||||||
|
thread_id = f"{pipeline_id}:{thread_id}"
|
||||||
|
|
||||||
|
response_id = f"appcmpl-{os.urandom(12).hex()}"
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
chunk_generator = await selected_pipeline.achat(
|
||||||
|
inp=user_msg, as_stream=True, thread_id=thread_id
|
||||||
|
)
|
||||||
|
return StreamingResponse(
|
||||||
|
sse_chunks_from_astream(
|
||||||
|
chunk_generator, response_id=response_id, model=selected_model
|
||||||
|
),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
)
|
||||||
|
|
||||||
|
result_text = await selected_pipeline.achat(
|
||||||
|
inp=user_msg, as_stream=False, thread_id=thread_id
|
||||||
|
)
|
||||||
|
if not isinstance(result_text, str):
|
||||||
|
result_text = str(result_text)
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"request_id": response_id,
|
||||||
|
"code": 200,
|
||||||
|
"message": "OK",
|
||||||
|
"app_id": req_app_id,
|
||||||
|
"session_id": req_session_id,
|
||||||
|
"output": {
|
||||||
|
"text": result_text,
|
||||||
|
"created": int(time.time()),
|
||||||
|
"model": selected_model,
|
||||||
|
},
|
||||||
|
"pipeline_id": pipeline_id,
|
||||||
|
"is_end": True,
|
||||||
|
}
|
||||||
|
return JSONResponse(content=data)
|
||||||
|
|
||||||
|
|
||||||
|
def create_dashscope_router(
|
||||||
|
pipeline_manager: Optional[ServerPipelineManager] = None,
|
||||||
|
include_meta_routes: bool = True,
|
||||||
|
) -> APIRouter:
|
||||||
|
manager = pipeline_manager or PIPELINE_MANAGER
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
|
||||||
|
@router.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
|
||||||
|
async def application_responses(
|
||||||
|
request: Request,
|
||||||
|
app_id: str = Path(...),
|
||||||
|
session_id: str = Path(...),
|
||||||
|
api_key: str = Depends(verify_api_key),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
return await _process_dashscope_request(
|
||||||
|
body=body,
|
||||||
|
app_id=app_id,
|
||||||
|
session_id=session_id,
|
||||||
|
api_key=api_key,
|
||||||
|
pipeline_manager=manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope-compatible endpoint error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and
|
||||||
|
# without session in path.
|
||||||
|
@router.post("/apps/{app_id}/completion")
|
||||||
|
@router.post("/v1/apps/{app_id}/completion")
|
||||||
|
@router.post("/api/apps/{app_id}/completion")
|
||||||
|
@router.post("/api/v1/apps/{app_id}/completion")
|
||||||
|
async def application_completion(
|
||||||
|
request: Request,
|
||||||
|
app_id: str = Path(...),
|
||||||
|
api_key: str = Depends(verify_api_key),
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
return await _process_dashscope_request(
|
||||||
|
body=body,
|
||||||
|
app_id=app_id,
|
||||||
|
session_id=None,
|
||||||
|
api_key=api_key,
|
||||||
|
pipeline_manager=manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"DashScope-compatible completion error: {e}")
|
||||||
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
if include_meta_routes:
|
||||||
|
@router.get("/")
|
||||||
|
async def root():
|
||||||
|
return {
|
||||||
|
"message": "DashScope Application-compatible API",
|
||||||
|
"endpoints": [
|
||||||
|
"/v1/apps/{app_id}/sessions/{session_id}/responses",
|
||||||
|
"/health",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
@router.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "healthy"}
|
||||||
|
|
||||||
|
return router
|
||||||
|
|
||||||
|
|
||||||
|
def create_dashscope_app(
|
||||||
|
pipeline_manager: Optional[ServerPipelineManager] = None,
|
||||||
|
) -> FastAPI:
|
||||||
|
dashscope_app = FastAPI(
|
||||||
|
title="DashScope-Compatible Application API",
|
||||||
|
description="DashScope Application.call compatible endpoint backed by pipeline.chat",
|
||||||
|
)
|
||||||
|
dashscope_app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
dashscope_app.include_router(
|
||||||
|
create_dashscope_router(
|
||||||
|
pipeline_manager=pipeline_manager,
|
||||||
|
include_meta_routes=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return dashscope_app
|
||||||
|
|
||||||
|
|
||||||
|
dashscope_router = create_dashscope_router(include_meta_routes=False)
|
||||||
|
app = create_dashscope_app()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# CLI parsing is intentionally only in script mode to keep module import safe.
|
||||||
|
cli_pipeline_config = tyro.cli(PipelineConfig)
|
||||||
|
logger.info(f"starting agent with CLI pipeline config: \n{cli_pipeline_config}")
|
||||||
|
cli_pipeline_manager = _build_pipeline_manager(cli_pipeline_config)
|
||||||
|
uvicorn.run(
|
||||||
|
create_dashscope_app(pipeline_manager=cli_pipeline_manager),
|
||||||
|
host=cli_pipeline_config.host,
|
||||||
|
port=cli_pipeline_config.port,
|
||||||
|
reload=False,
|
||||||
|
)
|
||||||
@@ -16,15 +16,12 @@ import tyro
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||||
|
|
||||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||||
|
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
|
||||||
|
|
||||||
# Initialize Pipeline once
|
# Initialize Pipeline once
|
||||||
pipeline_config = tyro.cli(PipelineConfig)
|
pipeline_config = tyro.cli(PipelineConfig)
|
||||||
pipeline: Pipeline = pipeline_config.setup()
|
pipeline: Pipeline = pipeline_config.setup()
|
||||||
|
|
||||||
# API Key Authentication
|
|
||||||
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
|
|
||||||
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
|
async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
|
||||||
"""Verify the API key from Authorization header (Bearer token format)."""
|
"""Verify the API key from Authorization header (Bearer token format)."""
|
||||||
@@ -46,12 +43,12 @@ class OpenAIChatCompletionRequest(BaseModel):
|
|||||||
temperature: Optional[float] = Field(default=1.0)
|
temperature: Optional[float] = Field(default=1.0)
|
||||||
max_tokens: Optional[int] = Field(default=None)
|
max_tokens: Optional[int] = Field(default=None)
|
||||||
# Optional overrides for pipeline behavior
|
# Optional overrides for pipeline behavior
|
||||||
thread_id: Optional[str] = Field(default='3')
|
thread_id: Optional[str] = Field(default="3")
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="OpenAI-Compatible Chat API",
|
title="OpenAI-Compatible Chat API",
|
||||||
description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat"
|
description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat",
|
||||||
)
|
)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
@@ -63,7 +60,9 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, created_time: int):
|
def sse_chunks_from_stream(
|
||||||
|
chunk_generator, response_id: str, model: str, created_time: int
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Stream chunks from pipeline and format as OpenAI SSE.
|
Stream chunks from pipeline and format as OpenAI SSE.
|
||||||
"""
|
"""
|
||||||
@@ -75,14 +74,8 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
|
|||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": model,
|
"model": model,
|
||||||
"choices": [
|
"choices": [
|
||||||
{
|
{"index": 0, "delta": {"content": chunk}, "finish_reason": None}
|
||||||
"index": 0,
|
],
|
||||||
"delta": {
|
|
||||||
"content": chunk
|
|
||||||
},
|
|
||||||
"finish_reason": None
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
yield f"data: {json.dumps(data)}\n\n"
|
yield f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
@@ -92,19 +85,15 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
|
|||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": model,
|
"model": model,
|
||||||
"choices": [
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
yield f"data: {json.dumps(final)}\n\n"
|
yield f"data: {json.dumps(final)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
|
||||||
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, created_time: int):
|
async def sse_chunks_from_astream(
|
||||||
|
chunk_generator, response_id: str, model: str, created_time: int
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Async version: Stream chunks from pipeline and format as OpenAI SSE.
|
Async version: Stream chunks from pipeline and format as OpenAI SSE.
|
||||||
"""
|
"""
|
||||||
@@ -116,14 +105,8 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
|
|||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": model,
|
"model": model,
|
||||||
"choices": [
|
"choices": [
|
||||||
{
|
{"index": 0, "delta": {"content": chunk}, "finish_reason": None}
|
||||||
"index": 0,
|
],
|
||||||
"delta": {
|
|
||||||
"content": chunk
|
|
||||||
},
|
|
||||||
"finish_reason": None
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
yield f"data: {json.dumps(data)}\n\n"
|
yield f"data: {json.dumps(data)}\n\n"
|
||||||
|
|
||||||
@@ -133,13 +116,7 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
|
|||||||
"object": "chat.completion.chunk",
|
"object": "chat.completion.chunk",
|
||||||
"created": created_time,
|
"created": created_time,
|
||||||
"model": model,
|
"model": model,
|
||||||
"choices": [
|
"choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
|
||||||
{
|
|
||||||
"index": 0,
|
|
||||||
"delta": {},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
yield f"data: {json.dumps(final)}\n\n"
|
yield f"data: {json.dumps(final)}\n\n"
|
||||||
yield "data: [DONE]\n\n"
|
yield "data: [DONE]\n\n"
|
||||||
@@ -176,14 +153,23 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
|
|||||||
|
|
||||||
if stream:
|
if stream:
|
||||||
# Use async streaming from pipeline
|
# Use async streaming from pipeline
|
||||||
chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
|
chunk_generator = await pipeline.achat(
|
||||||
|
inp=user_msg, as_stream=True, thread_id=thread_id
|
||||||
|
)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
sse_chunks_from_astream(chunk_generator, response_id=response_id, model=model, created_time=created_time),
|
sse_chunks_from_astream(
|
||||||
|
chunk_generator,
|
||||||
|
response_id=response_id,
|
||||||
|
model=model,
|
||||||
|
created_time=created_time,
|
||||||
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Non-streaming: get full result using async
|
# Non-streaming: get full result using async
|
||||||
result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
|
result_text = await pipeline.achat(
|
||||||
|
inp=user_msg, as_stream=False, thread_id=thread_id
|
||||||
|
)
|
||||||
if not isinstance(result_text, str):
|
if not isinstance(result_text, str):
|
||||||
result_text = str(result_text)
|
result_text = str(result_text)
|
||||||
|
|
||||||
@@ -195,18 +181,11 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
|
|||||||
"choices": [
|
"choices": [
|
||||||
{
|
{
|
||||||
"index": 0,
|
"index": 0,
|
||||||
"message": {
|
"message": {"role": "assistant", "content": result_text},
|
||||||
"role": "assistant",
|
"finish_reason": "stop",
|
||||||
"content": result_text
|
|
||||||
},
|
|
||||||
"finish_reason": "stop"
|
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"usage": {
|
"usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
|
||||||
"prompt_tokens": 0,
|
|
||||||
"completion_tokens": 0,
|
|
||||||
"total_tokens": 0
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return JSONResponse(content=data)
|
return JSONResponse(content=data)
|
||||||
|
|
||||||
@@ -221,11 +200,7 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
|
|||||||
async def root():
|
async def root():
|
||||||
return {
|
return {
|
||||||
"message": "OpenAI-compatible Chat API",
|
"message": "OpenAI-compatible Chat API",
|
||||||
"endpoints": [
|
"endpoints": ["/v1/chat/completions", "/v1/memory (DELETE)", "/health"],
|
||||||
"/v1/chat/completions",
|
|
||||||
"/v1/memory (DELETE)",
|
|
||||||
"/health"
|
|
||||||
]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -17,6 +17,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|||||||
|
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
from lang_agent.pipeline import Pipeline, PipelineConfig
|
from lang_agent.pipeline import Pipeline, PipelineConfig
|
||||||
|
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS
|
||||||
|
|
||||||
# Initialize Pipeline once (matches existing server_* pattern)
|
# Initialize Pipeline once (matches existing server_* pattern)
|
||||||
pipeline_config = tyro.cli(PipelineConfig)
|
pipeline_config = tyro.cli(PipelineConfig)
|
||||||
@@ -24,9 +25,6 @@ logger.info(f"starting agent with pipeline: \n{pipeline_config}")
|
|||||||
pipeline: Pipeline = pipeline_config.setup()
|
pipeline: Pipeline = pipeline_config.setup()
|
||||||
|
|
||||||
# API Key Authentication
|
# API Key Authentication
|
||||||
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False)
|
|
||||||
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)):
|
async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)):
|
||||||
"""Verify the API key from Authorization header (Bearer token format)."""
|
"""Verify the API key from Authorization header (Bearer token format)."""
|
||||||
@@ -125,7 +123,9 @@ app.add_middleware(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def rest_sse_from_astream(chunk_generator, response_id: str, conversation_id: str):
|
async def rest_sse_from_astream(
|
||||||
|
chunk_generator, response_id: str, conversation_id: str
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Stream chunks as SSE events.
|
Stream chunks as SSE events.
|
||||||
|
|
||||||
@@ -185,7 +185,9 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
|
|||||||
)
|
)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
rest_sse_from_astream(
|
rest_sse_from_astream(
|
||||||
chunk_generator, response_id=response_id, conversation_id=conversation_id
|
chunk_generator,
|
||||||
|
response_id=response_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
)
|
)
|
||||||
@@ -195,7 +197,11 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
|
|||||||
)
|
)
|
||||||
if not isinstance(result_text, str):
|
if not isinstance(result_text, str):
|
||||||
result_text = str(result_text)
|
result_text = str(result_text)
|
||||||
return JSONResponse(content=ChatResponse(conversation_id=conversation_id, output=result_text).model_dump())
|
return JSONResponse(
|
||||||
|
content=ChatResponse(
|
||||||
|
conversation_id=conversation_id, output=result_text
|
||||||
|
).model_dump()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/v1/conversations/{conversation_id}/messages")
|
@app.post("/v1/conversations/{conversation_id}/messages")
|
||||||
@@ -215,7 +221,9 @@ async def create_message(
|
|||||||
)
|
)
|
||||||
return StreamingResponse(
|
return StreamingResponse(
|
||||||
rest_sse_from_astream(
|
rest_sse_from_astream(
|
||||||
chunk_generator, response_id=response_id, conversation_id=conversation_id
|
chunk_generator,
|
||||||
|
response_id=response_id,
|
||||||
|
conversation_id=conversation_id,
|
||||||
),
|
),
|
||||||
media_type="text/event-stream",
|
media_type="text/event-stream",
|
||||||
)
|
)
|
||||||
@@ -257,7 +265,11 @@ async def delete_conversation_memory(
|
|||||||
cleared = _try_clear_single_thread_memory(thread_id)
|
cleared = _try_clear_single_thread_memory(thread_id)
|
||||||
if cleared:
|
if cleared:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={"status": "success", "scope": "conversation", "conversation_id": conversation_id}
|
content={
|
||||||
|
"status": "success",
|
||||||
|
"scope": "conversation",
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
}
|
||||||
)
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
content={
|
content={
|
||||||
@@ -276,5 +288,3 @@ if __name__ == "__main__":
|
|||||||
port=8589,
|
port=8589,
|
||||||
reload=True,
|
reload=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
from typing import List
|
|
||||||
import subprocess
|
|
||||||
|
|
||||||
def _build_template(graph: str,
                    pipeline_id: str,
                    prompt_set: str,
                    tool_keys: List[str],
                    port: str,
                    entry_pnt: str = "fastapi_server/server_dashscope.py",
                    llm_name: str = "qwen-plus"):
    """Spawn a server process for one graph type.

    Assembles the command line for ``entry_pnt`` (LLM name, port, the graph
    subcommand, pipeline id, prompt set and, when given, the tool keys) and
    starts it with ``subprocess.Popen`` without waiting for it.

    Returns:
        A ``(process, url)`` tuple: the ``Popen`` handle and the
        ``http://0.0.0.0:<port>`` string built from ``port``.
    """
    launch_args = ["python", entry_pnt]
    launch_args += ["--llm-name", llm_name]
    launch_args += ["--port", str(port)]
    # The graph name is a positional subcommand; its options follow it.
    launch_args.append(graph)
    launch_args += ["--pipeline-id", pipeline_id]
    launch_args += ["--prompt-set-id", prompt_set]

    if tool_keys:
        launch_args.append("--tool-manager-config.client-tool-manager.tool-keys")
        launch_args.extend(tool_keys)

    server_proc = subprocess.Popen(launch_args)
    return server_proc, f"http://0.0.0.0:{port}"
|
|
||||||
|
|
||||||
def build_route(pipeline_id: str,
                prompt_set: str,
                tool_keys: List[str],
                port: str,
                entry_pnt: str = "fastapi_server/server_dashscope.py",
                llm_name: str = "qwen-plus"):
    """Start a server for the ``route`` graph; see ``_build_template`` for the return value."""
    return _build_template(
        "route",
        pipeline_id=pipeline_id,
        prompt_set=prompt_set,
        tool_keys=tool_keys,
        port=port,
        entry_pnt=entry_pnt,
        llm_name=llm_name,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def build_react(pipeline_id: str,
                prompt_set: str,
                tool_keys: List[str],
                port: str,
                entry_pnt: str = "fastapi_server/server_dashscope.py",
                llm_name: str = "qwen-plus"):
    """Start a server for the ``react`` graph; see ``_build_template`` for the return value."""
    return _build_template(
        "react",
        pipeline_id=pipeline_id,
        prompt_set=prompt_set,
        tool_keys=tool_keys,
        port=port,
        entry_pnt=entry_pnt,
        llm_name=llm_name,
    )
|
|
||||||
|
|
||||||
# {graph_id: build_function}
|
|
||||||
GRAPH_BUILD_FNCS = dict(
    routing=build_route,
    react=build_react,
)
|
|
||||||
223
lang_agent/front_api/build_server_utils.py
Normal file
223
lang_agent/front_api/build_server_utils.py
Normal file
@@ -0,0 +1,223 @@
|
|||||||
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import subprocess
|
||||||
|
import json
|
||||||
|
|
||||||
|
from lang_agent.config.core_config import load_tyro_conf
|
||||||
|
from lang_agent.config.constants import TY_BUILD_SCRIPT, _PROJECT_ROOT
|
||||||
|
|
||||||
|
# Every accepted spelling of a deep-agent backend name, mapped to the
# canonical name that build_deep_agent appends to the command line.
_DEEP_AGENT_BACKEND_ALIASES = {
    alias: canonical
    for canonical, aliases in (
        ("statebk", ("state_bk", "statebk")),
        ("localshell", ("local_shell", "localshell")),
        ("daytonasandbox", ("daytona_sandbox", "daytonasandbox")),
    )
    for alias in aliases
}
|
||||||
|
|
||||||
|
|
||||||
|
def opt_to_config(save_path: str, *nargs):
    """Run the tyro build script so it writes a config file to ``save_path``.

    Args:
        save_path: Destination of the generated config. A relative path is
            anchored at ``_PROJECT_ROOT``.
        *nargs: Extra command-line tokens forwarded verbatim to the script.

    Raises:
        subprocess.CalledProcessError: If the build script exits non-zero.
    """
    # The child runs with cwd=_PROJECT_ROOT, so it would resolve a relative
    # save_path there. Anchor it the same way up front so the directory created
    # below is the one the script writes into, regardless of this process's
    # cwd. This also avoids os.makedirs("") for a bare filename.
    if not osp.isabs(save_path):
        save_path = osp.join(_PROJECT_ROOT, save_path)

    os.makedirs(osp.dirname(save_path), exist_ok=True)
    subprocess.run(
        ["python", TY_BUILD_SCRIPT, "--save-path", save_path, *nargs],
        check=True,
        cwd=_PROJECT_ROOT,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _build_and_load_pipeline_config(
    pipeline_id: str, pipeline_config_dir: str, cmd: List[str]
):
    """Generate ``<pipeline_config_dir>/<pipeline_id>.yaml`` and load it back.

    Args:
        pipeline_id: Used as the config file's base name.
        pipeline_config_dir: Directory for the config; a relative directory is
            anchored at ``_PROJECT_ROOT``.
        cmd: Command-line tokens passed on to ``opt_to_config``.

    Returns:
        Whatever ``load_tyro_conf`` returns for the generated file.
    """
    # opt_to_config runs the build script with cwd=_PROJECT_ROOT, so a relative
    # path is written under the project root. Resolve it here once so the file
    # loaded below is the file that was just written, whatever the caller's cwd.
    if not osp.isabs(pipeline_config_dir):
        pipeline_config_dir = osp.join(_PROJECT_ROOT, pipeline_config_dir)

    save_config_f = osp.join(pipeline_config_dir, f"{pipeline_id}.yaml")
    opt_to_config(save_config_f, *cmd)

    # TODO: think if returning the built pipeline is better or just the config obj for front_api
    return load_tyro_conf(save_config_f)
|
||||||
|
|
||||||
|
|
||||||
|
def update_pipeline_registry(
    pipeline_id: str,
    graph_id: str,
    config_file: str,
    llm_name: str,
    enabled: bool = True,
    registry_f: str = "configs/pipeline_registry.json",
):
    """Create or update one pipeline entry in the JSON pipeline registry.

    The entry ``registry["pipelines"][pipeline_id]`` gets its ``enabled``,
    ``config_file``, ``graph_id`` and ``llm_name`` keys overwritten; any other
    keys already stored on the entry, and all other entries, are kept.

    Args:
        pipeline_id: Key of the entry inside ``"pipelines"``.
        graph_id: Stored under ``"graph_id"``.
        config_file: Stored under ``"config_file"``.
        llm_name: Stored under ``"llm_name"``.
        enabled: Stored under ``"enabled"`` after ``bool()`` coercion.
        registry_f: Registry path; a relative path is joined onto
            ``_PROJECT_ROOT``. A missing file starts from
            ``{"pipelines": {}, "api_keys": {}}``.

    Raises:
        json.JSONDecodeError: If an existing registry file is not valid JSON.
    """
    if not osp.isabs(registry_f):
        registry_f = osp.join(_PROJECT_ROOT, registry_f)
    os.makedirs(osp.dirname(registry_f), exist_ok=True)

    if osp.exists(registry_f):
        # Read with the same encoding the file is written with; relying on the
        # platform default breaks on non-UTF-8 locales.
        with open(registry_f, "r", encoding="utf-8") as f:
            registry = json.load(f)
    else:
        registry = {"pipelines": {}, "api_keys": {}}

    pipelines: Dict[str, Dict[str, Any]] = registry.setdefault("pipelines", {})
    pipeline = pipelines.setdefault(pipeline_id, {})
    pipeline["enabled"] = bool(enabled)
    pipeline["config_file"] = config_file
    pipeline["graph_id"] = graph_id
    pipeline["llm_name"] = llm_name

    # Write to a sibling temp file and swap it in, so a failure mid-dump cannot
    # leave a truncated registry behind.
    tmp_f = f"{registry_f}.tmp"
    try:
        with open(tmp_f, "w", encoding="utf-8") as f:
            json.dump(registry, f, indent=4)
        os.replace(tmp_f, registry_f)
    finally:
        # Only present if the dump or the swap failed.
        if osp.exists(tmp_f):
            os.remove(tmp_f)
|
||||||
|
|
||||||
|
|
||||||
|
def build_route(
    pipeline_id: str,
    prompt_set: str,
    tool_keys: List[str],
    api_key: str,
    llm_name: str = "qwen-plus",
    pipeline_config_dir: str = "configs/pipelines",
    **_: Any,
):
    """Build and load the config for a ``route`` graph pipeline.

    Assembles the option list for the build script (pipeline-level options,
    the ``route`` subcommand, then its ``tool_node`` subcommand) and hands it
    to ``_build_and_load_pipeline_config``. Extra keyword arguments are
    accepted and ignored.
    """
    # The `route` graph and its `tool_node` child receive the same four options.
    node_opts = [
        "--llm-name", llm_name,
        "--api-key", api_key,
        "--pipeline-id", pipeline_id,
        "--prompt-set-id", prompt_set,
    ]

    cmd_opt = ["--pipeline.pipeline-id", pipeline_id, "--pipeline.llm-name", llm_name]
    cmd_opt += ["route", *node_opts]

    if tool_keys:
        cmd_opt += ["--tool-manager-config.client-tool-manager.tool-keys", *tool_keys]
        # Tyro parses list options greedily across positional subcommands; repeat a
        # parent-level option to terminate list parsing before `tool_node`.
        cmd_opt += ["--pipeline-id", pipeline_id]

    cmd_opt += ["tool_node", *node_opts]

    return _build_and_load_pipeline_config(pipeline_id, pipeline_config_dir, cmd_opt)
|
||||||
|
|
||||||
|
|
||||||
|
def build_react(
    pipeline_id: str,
    prompt_set: str,
    tool_keys: List[str],
    api_key: str,
    llm_name: str = "qwen-plus",
    pipeline_config_dir: str = "configs/pipelines",
    **_: Any,
):
    """Build and load the config for a ``react`` graph pipeline.

    Assembles pipeline-level options followed by the ``react`` subcommand and
    its options (plus tool keys when given), then delegates to
    ``_build_and_load_pipeline_config``. Extra keyword arguments are accepted
    and ignored.
    """
    pipeline_opts = [
        "--pipeline.pipeline-id", pipeline_id,
        "--pipeline.llm-name", llm_name,
    ]
    graph_opts = [
        "--llm-name", llm_name,
        "--api-key", api_key,
        "--pipeline-id", pipeline_id,
        "--prompt-set-id", prompt_set,
    ]
    cmd_opt = [*pipeline_opts, "react", *graph_opts]

    if tool_keys:
        cmd_opt.append("--tool-manager-config.client-tool-manager.tool-keys")
        cmd_opt.extend(tool_keys)

    return _build_and_load_pipeline_config(pipeline_id, pipeline_config_dir, cmd_opt)
|
||||||
|
|
||||||
|
|
||||||
|
def build_deep_agent(
    pipeline_id: str,
    prompt_set: str,
    tool_keys: List[str],
    api_key: str,
    llm_name: str = "qwen-plus",
    pipeline_config_dir: str = "configs/pipelines",
    act_bkend: Literal[
        "local_shell",
        "localshell",
        "state_bk",
        "statebk",
        "daytona_sandbox",
        "daytonasandbox",
    ] = "state_bk",
    file_backend_config: Optional[Dict[str, Any]] = None,
    **_: Any,
):
    """Build and load the config for a ``deepagent`` graph pipeline.

    Args:
        pipeline_id, prompt_set, tool_keys, api_key, llm_name: Forwarded as
            command-line options to the build script.
        pipeline_config_dir: Directory the generated config is written to.
        act_bkend: Backend name; any key of ``_DEEP_AGENT_BACKEND_ALIASES``.
        file_backend_config: Optional backend options. Only truthy values of
            ``skills_dir``, ``rt_skills_dir``, ``workspace_dir`` and
            ``api_key`` are forwarded, after the backend subcommand.

    Raises:
        ValueError: If ``act_bkend`` is not a known backend alias.
    """
    if act_bkend not in _DEEP_AGENT_BACKEND_ALIASES:
        raise ValueError(
            "Unsupported deepagent backend "
            f"'{act_bkend}'. Expected one of {sorted(_DEEP_AGENT_BACKEND_ALIASES.keys())}"
        )
    backend_subcommand = _DEEP_AGENT_BACKEND_ALIASES[act_bkend]

    cmd_opt = ["--pipeline.pipeline-id", pipeline_id, "--pipeline.llm-name", llm_name]
    cmd_opt += [
        "deepagent",
        "--llm-name", llm_name,
        "--api-key", api_key,
        "--pipeline-id", pipeline_id,
        "--prompt-set-id", prompt_set,
    ]

    if tool_keys:
        cmd_opt += ["--tool-manager-config.client-tool-manager.tool-keys", *tool_keys]
        # Repeat a parent-level option after the tool-key list, before the
        # backend subcommand (same pattern as build_route).
        cmd_opt += ["--pipeline-id", pipeline_id]

    cmd_opt.append(backend_subcommand)

    if file_backend_config:
        # (config key, command-line flag), emitted in this fixed order.
        backend_flags = (
            ("skills_dir", "--skills-dir"),
            ("rt_skills_dir", "--rt-skills-dir"),
            ("workspace_dir", "--workspace-dir"),
            ("api_key", "--api-key"),
        )
        for key, flag in backend_flags:
            if key in file_backend_config and file_backend_config[key]:
                cmd_opt += [flag, file_backend_config[key]]

    return _build_and_load_pipeline_config(pipeline_id, pipeline_config_dir, cmd_opt)
|
||||||
|
|
||||||
|
|
||||||
|
# {graph_id: build_function}
|
||||||
|
GRAPH_BUILD_FNCS = dict(
    routing=build_route,
    react=build_react,
    deepagent=build_deep_agent,
)
|
||||||
15
lang_agent/fs_bkends/__init__.py
Normal file
15
lang_agent/fs_bkends/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import tyro
|
||||||
|
|
||||||
|
from lang_agent.fs_bkends.base import BaseFilesystemBackend
|
||||||
|
from lang_agent.fs_bkends.statebk import StateBk, StateBkConfig
|
||||||
|
from lang_agent.fs_bkends.localshell import LocalShell, LocalShellConfig
|
||||||
|
from lang_agent.fs_bkends.daytona_sandbox import DaytonaSandboxBk, DaytonaSandboxConfig
|
||||||
|
|
||||||
|
# One default config instance per filesystem backend, keyed by the name used
# to select that backend.
statebk_dict = dict(
    statebk=StateBkConfig(),
    localshell=LocalShellConfig(),
    daytonasandbox=DaytonaSandboxConfig(),
)

# Union type over the backend configs above, built from their defaults.
statebk_union = tyro.extras.subcommand_type_from_defaults(
    statebk_dict,
    prefix_names=False,
)
AnnotatedStateBk = tyro.conf.OmitSubcommandPrefixes[
    tyro.conf.SuppressFixed[statebk_union]
]
|
||||||
51
lang_agent/fs_bkends/base.py
Normal file
51
lang_agent/fs_bkends/base.py
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
import os
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from lang_agent.config import InstantiateConfig
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFilesystemBackend(ABC):
    """Common interface of the filesystem backends handed to a deep agent.

    Subclasses set ``self.config`` and implement ``_build_backend`` to
    populate ``self.backend``.
    """

    backend: Any
    config: Any

    @abstractmethod
    def _build_backend(self):
        """Create the concrete backend object and store it on ``self.backend``."""

    def get_backend(self):
        """Return the backend object built by ``_build_backend``."""
        return self.backend

    def get_inf_inp(self):
        """get inference input for deepagent"""
        return {}

    def get_deepagent_params(self):
        """extra params to pass into the creation of deepagents"""
        # A config without the attribute, or with it unset/empty, contributes
        # nothing rather than advertising an empty skills path.
        rt_skills_dir = getattr(self.config, "rt_skills_dir", None)
        if not rt_skills_dir:
            return {}
        return {"skills": [rt_skills_dir]}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FilesystemBackendConfig(InstantiateConfig):
    """
    Shared filesystem backend config behavior.
    If subclasses define these fields, this hook ensures they exist:
    - skills_dir
    - workspace_dir
    """

    def _ensure_dir_if_present(self, attr_name: str) -> None:
        # Create the directory named by `attr_name` when that attribute holds
        # a non-blank string; anything else (missing, None, "") is a no-op.
        candidate = getattr(self, attr_name, None)
        if isinstance(candidate, str) and candidate.strip():
            os.makedirs(candidate, exist_ok=True)
            logger.info(f"Ensured {attr_name} exists: {candidate}")

    def __post_init__(self) -> None:
        for dir_attr in ("skills_dir", "workspace_dir"):
            self._ensure_dir_if_present(dir_attr)
|
||||||
92
lang_agent/fs_bkends/daytona_sandbox.py
Normal file
92
lang_agent/fs_bkends/daytona_sandbox.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Type, Optional
|
||||||
|
from pathlib import Path
|
||||||
|
import os
|
||||||
|
import tyro
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from daytona import Daytona, DaytonaConfig, FileUpload
|
||||||
|
from langchain_daytona import DaytonaSandbox
|
||||||
|
|
||||||
|
from lang_agent.fs_bkends.base import BaseFilesystemBackend, FilesystemBackendConfig
|
||||||
|
|
||||||
|
|
||||||
|
# Config for the Daytona sandbox backend. After construction, a missing
# api_key is filled from the DAYTONA_API_KEY environment variable.
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class DaytonaSandboxConfig(FilesystemBackendConfig):
    _target: Type = field(default_factory=lambda: DaytonaSandboxBk)

    api_key: Optional[str] = None
    """Daytona API key. Falls back to DAYTONA_API_KEY env var."""

    skills_dir: str = "./workspace/skills"
    """local path to directory containing skill files to upload"""

    rt_skills_dir: str = ""
    """runtime skills path inside the sandbox (auto-set from sandbox workdir)"""

    def __post_init__(self):
        super().__post_init__()
        # An explicitly supplied key wins; nothing to resolve or log.
        if self.api_key is not None:
            return

        self.api_key = os.environ.get("DAYTONA_API_KEY")
        if self.api_key is None:
            logger.error("no DAYTONA_API_KEY provided")
        else:
            logger.info("DAYTONA_API_KEY loaded from environ")
|
||||||
|
|
||||||
|
|
||||||
|
class DaytonaSandboxBk(BaseFilesystemBackend):
    """Filesystem backend backed by a freshly created Daytona sandbox.

    Construction creates the sandbox, uploads the local skill files into
    ``<workdir>/skills`` and wraps the sandbox in a ``DaytonaSandbox`` backend.
    Call ``stop()`` to stop the sandbox.
    """

    def __init__(self, config: DaytonaSandboxConfig):
        self.config = config
        self.sandbox = None
        self._build_backend()

    def _build_backend(self):
        """Create the sandbox, upload skills and build ``self.backend``."""
        daytona = Daytona(DaytonaConfig(api_key=self.config.api_key))
        self.sandbox = daytona.create()
        workdir = self.sandbox.get_work_dir()
        logger.info(f"Daytona sandbox created: {self.sandbox.id}, workdir: {workdir}")

        # Default the runtime skills path to the upload location.
        if not self.config.rt_skills_dir:
            self.config.rt_skills_dir = f"{workdir}/skills"

        self._upload_skills(workdir)
        self.backend = DaytonaSandbox(sandbox=self.sandbox)

    def _upload_skills(self, workdir: str):
        """Upload every file under ``config.skills_dir`` to ``<workdir>/skills``.

        The relative layout of the local directory is mirrored remotely.
        Logs a warning and returns when the directory is missing or empty.
        """
        skills_dir = Path(self.config.skills_dir)
        if not skills_dir.exists():
            logger.warning(f"Skills directory not found: {skills_dir}")
            return

        remote_root = f"{workdir}/skills"
        files_to_upload = []
        # Remote folders to create. Collect every ancestor of every file, not
        # just its direct parent, so a folder holding only sub-folders is still
        # created. Paths are built with as_posix() so they stay '/'-separated
        # whatever the host OS.
        remote_dirs = set()
        for skill_path in skills_dir.rglob("*"):
            if not skill_path.is_file():
                continue
            relative_path = skill_path.relative_to(skills_dir)
            remote_path = f"{remote_root}/{relative_path.as_posix()}"
            files_to_upload.append(
                FileUpload(source=skill_path.read_bytes(), destination=remote_path)
            )
            remote_dirs.add(remote_root)
            for ancestor in relative_path.parents:
                if ancestor.parts:  # skip the trailing "." entry
                    remote_dirs.add(f"{remote_root}/{ancestor.as_posix()}")

        if not files_to_upload:
            logger.warning("No skill files found to upload")
            return

        # Sorted order puts each parent before its children.
        for dir_path in sorted(remote_dirs):
            try:
                self.sandbox.fs.create_folder(dir_path, "755")
            except Exception as e:
                # Best effort: folder creation failures never abort the upload.
                # "permission denied" errors are dropped silently, anything
                # else is logged at debug level.
                if "permission denied" not in str(e).lower():
                    logger.debug(f"Creating dir {dir_path}: {e}")

        self.sandbox.fs.upload_files(files_to_upload)
        logger.info(f"Uploaded {len(files_to_upload)} skill files to {workdir}/skills/")

    def get_deepagent_params(self):
        """Return the ``skills`` parameter pointing at the runtime skills dir."""
        return {"skills": [self.config.rt_skills_dir]}

    def stop(self):
        """Stop the sandbox if one was created."""
        if self.sandbox is not None:
            self.sandbox.stop()
            logger.info("Daytona sandbox stopped")
|
||||||
56
lang_agent/fs_bkends/localshell.py
Normal file
56
lang_agent/fs_bkends/localshell.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Type
|
||||||
|
import tyro
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from deepagents.backends import LocalShellBackend
|
||||||
|
|
||||||
|
from lang_agent.fs_bkends.base import BaseFilesystemBackend, FilesystemBackendConfig
|
||||||
|
|
||||||
|
|
||||||
|
# Config for the LocalShell backend.
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class LocalShellConfig(FilesystemBackendConfig):
    _target: Type = field(default_factory=lambda: LocalShell)

    workspace_dir: str = "./workspace"
    """path to workspace directory"""

    skills_dir: str = "./workspace/skills"
    """path to directory containing skill files"""

    rt_skills_dir: str = "/skills"
    """path to directory with skills in runtime directory"""
|
||||||
|
|
||||||
|
|
||||||
|
class LocalShell(BaseFilesystemBackend):
    """Filesystem backend that wraps ``LocalShellBackend`` rooted at the workspace dir."""

    def __init__(self, config: LocalShellConfig):
        # Warn on every construction: this backend runs commands on the host.
        logger.warning(
            "Caution: The LocalShell backend grants direct access to the local system shell. "
            "Improper use can pose significant security and safety risks, including unintended "
            "code execution and file access. Use this backend with extreme care."
        )
        self.config = config
        self._build_backend()

    def _build_backend(self):
        """Build the shell backend with the workspace as root, inheriting the process env."""
        # Alternative kept from the original: env={"PATH": "/usr/bin:/bin"}
        backend_kwargs = dict(
            root_dir=self.config.workspace_dir,
            virtual_mode=True,
            inherit_env=True,
        )
        self.backend = LocalShellBackend(**backend_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import sys

    # Smoke check: build a LocalShell with default settings and see whether
    # `npx` can be run through it.
    shell = LocalShell(LocalShellConfig())

    try:
        result = shell.backend.execute("npx --version")
        if result.exit_code != 0:
            print("npx returned non-zero exit code:", result.exit_code, file=sys.stderr)
            print("output:", result.output, file=sys.stderr)
        else:
            print("npx is available, version:", result.output.strip())
    except Exception as e:
        print("Could not access 'npx':", str(e), file=sys.stderr)
|
||||||
54
lang_agent/fs_bkends/statebk.py
Normal file
54
lang_agent/fs_bkends/statebk.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Type
|
||||||
|
import tyro
|
||||||
|
import os.path as osp
|
||||||
|
import glob
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from deepagents.backends.utils import create_file_data
|
||||||
|
from deepagents.backends import StateBackend
|
||||||
|
|
||||||
|
from lang_agent.fs_bkends.base import BaseFilesystemBackend, FilesystemBackendConfig
|
||||||
|
|
||||||
|
def read_as_utf8(file_path:str):
    """Return the whole content of ``file_path`` decoded as UTF-8 text."""
    with open(file_path, mode="r", encoding="utf-8") as handle:
        text = handle.read()
    return text
|
||||||
|
|
||||||
|
def build_skill_fs_dict(skill_dir:str, virt_path:str="/skills"):
    """Load skill markdown files into a ``{virtual_path: file_data}`` dict.

    Args:
        skill_dir: Local directory laid out as ``<skill_dir>/<skill>/<file>.md``.
        virt_path: Virtual directory the files are placed under.

    Returns:
        Dict mapping ``<virt_path>/<skill>/<file>.md`` to the value of
        ``create_file_data`` for that file's UTF-8 text.
    """
    # Virtual paths are always '/'-separated, whatever the host OS, so they are
    # joined with posixpath instead of os.path.
    import posixpath

    # glob is called without recursive=True, so "**" matches exactly one
    # directory level here.
    skills_fs = sorted(glob.glob(osp.join(skill_dir, "**/*.md")))

    skill_fs_dict = {}
    for skill_f in skills_fs:
        logger.info(f"loading skill: {skill_f}")
        skill_name = osp.basename(osp.dirname(skill_f))
        virt_f = posixpath.join(virt_path, skill_name, osp.basename(skill_f))
        skill_fs_dict[virt_f] = create_file_data(read_as_utf8(skill_f))
    return skill_fs_dict
|
||||||
|
|
||||||
|
|
||||||
|
# Config for the StateBk backend.
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class StateBkConfig(FilesystemBackendConfig):
    _target: Type = field(default_factory=lambda: StateBk)

    skills_dir: str = "./assets/skills"
    """path to directory containing skill files"""

    rt_skills_dir: str = "/skills"
    """path to directory with skills in runtime directory"""
|
||||||
|
|
||||||
|
|
||||||
|
class StateBk(BaseFilesystemBackend):
    """Filesystem backend whose skill files are passed in with each agent input.

    Skill files are read from ``config.skills_dir`` once at construction and
    returned under ``"files"`` by ``get_inf_inp``.
    """

    def __init__(self, config:StateBkConfig):
        self.config = config
        self.skills_dict = None
        self._build_backend()

    def _build_backend(self):
        """Load the skill files and set ``self.backend`` to a ``StateBackend`` factory."""
        # Seed the files under config.rt_skills_dir: that is the directory the
        # agent is told to look in, so the two must agree when it is customised.
        self.skills_dict = build_skill_fs_dict(
            self.config.skills_dir,
            virt_path=self.config.rt_skills_dir,
        )
        self.backend = lambda rt : StateBackend(rt)

    def get_inf_inp(self):
        """get inference input for deepagent"""
        return {"files":self.skills_dict}
|
||||||
@@ -4,12 +4,16 @@ from lang_agent.graphs.react import ReactGraphConfig, ReactGraph
|
|||||||
from lang_agent.graphs.routing import RoutingConfig, RoutingGraph
|
from lang_agent.graphs.routing import RoutingConfig, RoutingGraph
|
||||||
from lang_agent.graphs.dual_path import DualConfig, Dual
|
from lang_agent.graphs.dual_path import DualConfig, Dual
|
||||||
from lang_agent.graphs.vision_routing import VisionRoutingConfig, VisionRoutingGraph
|
from lang_agent.graphs.vision_routing import VisionRoutingConfig, VisionRoutingGraph
|
||||||
|
# from lang_agent.graphs.child_demo import ChildDemoGraphConfig, ChildDemoGraph
|
||||||
|
from lang_agent.graphs.deepagents_qt import DeepAgentConfig
|
||||||
|
|
||||||
graph_dict = {
|
graph_dict = {
|
||||||
"react": ReactGraphConfig(),
|
"react": ReactGraphConfig(),
|
||||||
"route": RoutingConfig(),
|
"route": RoutingConfig(),
|
||||||
"dual": DualConfig(),
|
"dual": DualConfig(),
|
||||||
"vision": VisionRoutingConfig()
|
"vision": VisionRoutingConfig(),
|
||||||
|
# "child_demo": ChildDemoGraphConfig(),
|
||||||
|
"deepagent": DeepAgentConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
graph_union = tyro.extras.subcommand_type_from_defaults(graph_dict, prefix_names=False)
|
graph_union = tyro.extras.subcommand_type_from_defaults(graph_dict, prefix_names=False)
|
||||||
|
|||||||
96
lang_agent/graphs/deepagents_qt.py
Normal file
96
lang_agent/graphs/deepagents_qt.py
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Type, Literal
|
||||||
|
import tyro
|
||||||
|
import os.path as osp
|
||||||
|
|
||||||
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
|
||||||
|
from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
||||||
|
|
||||||
|
from deepagents import create_deep_agent
|
||||||
|
|
||||||
|
from lang_agent.utils import make_llm
|
||||||
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
|
from lang_agent.graphs.graph_states import State
|
||||||
|
from lang_agent.config import LLMNodeConfig
|
||||||
|
from lang_agent.base import GraphBase
|
||||||
|
|
||||||
|
# from lang_agent.fs_bkends import StateBk, StateBkConfig, LocalShell, LocalShellConfig, DaytonaSandboxBk, DaytonaSandboxConfig
|
||||||
|
from lang_agent.fs_bkends import BaseFilesystemBackend, StateBkConfig, AnnotatedStateBk
|
||||||
|
|
||||||
|
# Config for the DeepAgent graph: system prompt file, tool manager and
# filesystem backend.
@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class DeepAgentConfig(LLMNodeConfig):
    _target: Type = field(default_factory=lambda : DeepAgent)

    sys_prompt_f: str = osp.join(
        osp.dirname(osp.dirname(osp.dirname(__file__))),
        "configs",
        "prompts",
        "deepagent.txt",
    )
    """path to system prompt"""

    tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)

    file_backend_config: AnnotatedStateBk = field(default_factory=StateBkConfig)
    """filesystem backend config (defaults to StateBkConfig)"""

    def __post_init__(self):
        super().__post_init__()
        # Explicit check instead of `assert`, which is stripped under
        # `python -O`. AssertionError is kept so existing handlers still match.
        if not osp.exists(self.sys_prompt_f):
            raise AssertionError(f"prompt path does not exist: {self.sys_prompt_f}")
|
||||||
|
|
||||||
|
class DeepAgent(GraphBase):
    """Single-node graph that forwards each turn to a ``create_deep_agent`` agent."""

    def __init__(self, config:DeepAgentConfig):
        self.config = config
        self._build_modules()
        self.workflow = self._build_graph()

    def _build_modules(self):
        """Create the LLM, tool manager, filesystem backend, agent and prompt store."""
        cfg = self.config
        llm = make_llm(
            cfg.llm_name,
            cfg.llm_provider,
            api_key=cfg.api_key,
            tags=["main_llm"],
        )

        self.tool_man: ToolManager = cfg.tool_manager_config.setup()
        self.file_backend: BaseFilesystemBackend = cfg.file_backend_config.setup()
        # Backend-specific keyword arguments for create_deep_agent.
        extra_agent_kwargs = self.file_backend.get_deepagent_params()

        self.mem = MemorySaver()
        self.deep_agent = create_deep_agent(
            model=llm,
            tools=self.tool_man.get_langchain_tools(),
            backend=self.file_backend.get_backend(),
            checkpointer=self.mem,
            **extra_agent_kwargs,
        )

        self.prompt_store = build_prompt_store(
            pipeline_id=cfg.pipeline_id,
            prompt_set_id=cfg.prompt_set_id,
            file_path=cfg.sys_prompt_f,
            default_key="sys_prompt",
        )
        self.sys_prompt = self.prompt_store.get("sys_prompt")

    def _agent_call(self, state:State):
        """Invoke the deep agent with the system prompt plus the input messages."""
        conversation = [SystemMessage(self.sys_prompt), *self._get_inp_msgs(state)]
        # Backend-provided inputs are merged on top of the message payload.
        payload = {"messages": conversation, **self.file_backend.get_inf_inp()}
        # state["inp"][1] is passed as the second positional argument to invoke().
        run_config = state["inp"][1]

        result = self.deep_agent.invoke(payload, run_config)
        return {"messages": result["messages"]}

    def _build_graph(self):
        """Compile the START -> agent_call -> END graph."""
        graph = StateGraph(State)
        graph.add_node("agent_call", self._agent_call)
        graph.add_edge(START, "agent_call")
        graph.add_edge("agent_call", END)
        return graph.compile()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual demo: run one turn through a default-configured DeepAgent.
    deepagent = DeepAgent(DeepAgentConfig())

    demo_messages = {"messages":[SystemMessage("you are a helpful bot enhanced with skills")]}
    demo_run_config = {"configurable": {"thread_id": '3'}}
    deepagent.workflow.invoke({"inp": (demo_messages, demo_run_config)})
|
||||||
@@ -6,7 +6,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
from langchain.chat_models import init_chat_model
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
@@ -26,17 +26,18 @@ SYS_PROMPT = """你是一个专业的心理质询师。你的主要工作是心
|
|||||||
可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受,
|
可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受,
|
||||||
总觉得那只小狗孤零零的,特别让人心疼。
|
总觉得那只小狗孤零零的,特别让人心疼。
|
||||||
|
|
||||||
理解(你的回复): 能感受到你醒来后的这份难受 —— 看到弱小的生命独自挣扎,而自己却无能为力,这种‘想帮却做不到’的无力感,
|
理解(你的回复): 能感受到你醒来后的这份难受 —— 看到弱小的生命独自挣扎,而自己却无能为力,这种'想帮却做不到'的无力感,
|
||||||
其实是很真实的情绪反馈。你会心疼小狗,说明你内心藏着很珍贵的共情力,这份柔软不是矫情,
|
其实是很真实的情绪反馈。你会心疼小狗,说明你内心藏着很珍贵的共情力,这份柔软不是矫情,
|
||||||
而是你感知他人痛苦的能力呀
|
而是你感知他人痛苦的能力呀
|
||||||
|
|
||||||
解析(你的回复):我们再说回这个梦吧,我们的梦境其实没有唯一的‘正确解释’,但我们可以一起看看它可能和你当下的状态有什么关联~ 首先,‘出差去广州’通常象征着你近期正在推进的某件事 —— 可能是工作上的一个项目,也可能是生活中一段需要‘独自奔赴’的旅程,是你当下比较关注、需要投入精力的目标,对吗?”
|
解析(你的回复):我们再说回这个梦吧,我们的梦境其实没有唯一的'正确解释',但我们可以一起看看它可能和你当下的状态有什么关联~ 首先,'出差去广州'通常象征着你近期正在推进的某件事 —— 可能是工作上的一个项目,也可能是生活中一段需要'独自奔赴'的旅程,是你当下比较关注、需要投入精力的目标,对吗?”
|
||||||
“而那只瘸脚的小狗,在心理学视角中,常常是我们潜意识里‘脆弱自我’的投射。它可能代表着你近期的某一面:比如在处理那件‘需要奔赴’的事时,你偶尔会觉得自己像小狗一样‘力不从心’,或者感受到了‘孤单’,却没找到合适的人倾诉或求助;也可能是你近期在生活中看到了一些让你觉得‘无力改变’的场景(比如身边人遇到困难、社会上的小事),这些情绪没有被你刻意留意,就通过梦境里的小狗呈现了出来。”
|
"而那只瘸脚的小狗,在心理学视角中,常常是我们潜意识里'脆弱自我'的投射。它可能代表着你近期的某一面:比如在处理那件'需要奔赴'的事时,你偶尔会觉得自己像小狗一样'力不从心',或者感受到了'孤单',却没找到合适的人倾诉或求助;也可能是你近期在生活中看到了一些让你觉得'无力改变'的场景(比如身边人遇到困难、社会上的小事),这些情绪没有被你刻意留意,就通过梦境里的小狗呈现了出来。"
|
||||||
“你想喊停列车却发不出声音,这种‘无能为力’的感觉,或许正是你现实中某类情绪的写照:可能你面对一些情况时,心里有想法却没机会表达,或者想帮忙却找不到合适的方式,这种压抑感在梦里被放大了。其实这个梦在提醒你:你的‘无力感’和‘共情心’都是真实的,不用因为‘帮不上忙’而自责 —— 承认自己的局限,也是一种自我接纳呀
|
"你想喊停列车却发不出声音,这种'无能为力'的感觉,或许正是你现实中某类情绪的写照:可能你面对一些情况时,心里有想法却没机会表达,或者想帮忙却找不到合适的方式,这种压抑感在梦里被放大了。其实这个梦在提醒你:你的'无力感'和'共情心'都是真实的,不用因为'帮不上忙'而自责 —— 承认自己的局限,也是一种自我接纳呀
|
||||||
|
|
||||||
反馈(你的回复):如果你愿意,可以试着回想一下:近期有没有哪件事,让你产生过和梦里类似的‘无力感’?或者,你现在想做些什么能让自己舒服一点?(或者我给你来一个温暖的灯光、静静待一会儿,想和我再聊聊的时候我随时都在)”。
|
反馈(你的回复):如果你愿意,可以试着回想一下:近期有没有哪件事,让你产生过和梦里类似的'无力感'?或者,你现在想做些什么能让自己舒服一点?(或者我给你来一个温暖的灯光、静静待一会儿,想和我再聊聊的时候我随时都在)"。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜色,用户不开心时就用暖黄光,给用户分析梦境时就用白光,倾听用户语音时用淡紫色。
|
TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜色,用户不开心时就用暖黄光,给用户分析梦境时就用白光,倾听用户语音时用淡紫色。
|
||||||
例子:我梦见自己要去广州出差,坐在高铁上往外看,路过一个小镇的路边时,看到一只瘸了腿的小狗。它毛脏兮兮的,
|
例子:我梦见自己要去广州出差,坐在高铁上往外看,路过一个小镇的路边时,看到一只瘸了腿的小狗。它毛脏兮兮的,
|
||||||
一瘸一拐地在翻垃圾桶找东西吃,周围有行人路过,但没人停下来管它。我当时特别想喊列车停下,想下去帮它,
|
一瘸一拐地在翻垃圾桶找东西吃,周围有行人路过,但没人停下来管它。我当时特别想喊列车停下,想下去帮它,
|
||||||
@@ -47,35 +48,35 @@ TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜
|
|||||||
|
|
||||||
用户在描述梦境的时候用紫色。"""
|
用户在描述梦境的时候用紫色。"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DualConfig(LLMKeyConfig):
|
class DualConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda:Dual)
|
_target: Type = field(default_factory=lambda: Dual)
|
||||||
|
|
||||||
pipeline_id: Optional[str] = None
|
|
||||||
"""If set, load prompts from database (with hardcoded fallback)"""
|
|
||||||
|
|
||||||
prompt_set_id: Optional[str] = None
|
|
||||||
"""If set, load from this specific prompt set instead of the active one"""
|
|
||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||||
|
|
||||||
|
|
||||||
from langchain.tools import tool
|
from langchain.tools import tool
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
def turn_lights(col:Literal["red", "green", "yellow", "blue"]):
|
def turn_lights(col: Literal["red", "green", "yellow", "blue"]):
|
||||||
"""
|
"""
|
||||||
Turn on the color of the lights
|
Turn on the color of the lights
|
||||||
"""
|
"""
|
||||||
# print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
# print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
||||||
for _ in range(10):
|
for _ in range(10):
|
||||||
print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
|
print(
|
||||||
|
f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
|
||||||
|
)
|
||||||
time.sleep(0.3)
|
time.sleep(0.3)
|
||||||
|
|
||||||
|
|
||||||
class Dual(GraphBase):
|
class Dual(GraphBase):
|
||||||
def __init__(self, config:DualConfig):
|
def __init__(self, config: DualConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
self._build_modules()
|
self._build_modules()
|
||||||
@@ -83,24 +84,30 @@ class Dual(GraphBase):
|
|||||||
self.streamable_tags = [["dual_chat_llm"]]
|
self.streamable_tags = [["dual_chat_llm"]]
|
||||||
|
|
||||||
def _build_modules(self):
|
def _build_modules(self):
|
||||||
self.chat_llm = init_chat_model(model=self.config.llm_name,
|
self.chat_llm = init_chat_model(
|
||||||
|
model=self.config.llm_name,
|
||||||
model_provider=self.config.llm_provider,
|
model_provider=self.config.llm_provider,
|
||||||
api_key=self.config.api_key,
|
api_key=self.config.api_key,
|
||||||
base_url=self.config.base_url,
|
base_url=self.config.base_url,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
tags=["dual_chat_llm"])
|
tags=["dual_chat_llm"],
|
||||||
|
)
|
||||||
|
|
||||||
self.tool_llm = init_chat_model(model='qwen-flash',
|
self.tool_llm = init_chat_model(
|
||||||
model_provider='openai',
|
model="qwen-flash",
|
||||||
|
model_provider="openai",
|
||||||
api_key=self.config.api_key,
|
api_key=self.config.api_key,
|
||||||
base_url=self.config.base_url,
|
base_url=self.config.base_url,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
tags=["dual_tool_llm"])
|
tags=["dual_tool_llm"],
|
||||||
|
)
|
||||||
|
|
||||||
self.memory = MemorySaver()
|
self.memory = MemorySaver()
|
||||||
self.tool_manager: ToolManager = self.config.tool_manager_config.setup()
|
self.tool_manager: ToolManager = self.config.tool_manager_config.setup()
|
||||||
self.chat_agent = create_agent(self.chat_llm, [], checkpointer=self.memory)
|
self.chat_agent = create_agent(self.chat_llm, [], checkpointer=self.memory)
|
||||||
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
|
self.tool_agent = create_agent(
|
||||||
|
self.tool_llm, self.tool_manager.get_langchain_tools()
|
||||||
|
)
|
||||||
# self.tool_agent = create_agent(self.tool_llm, [turn_lights])
|
# self.tool_agent = create_agent(self.tool_llm, [turn_lights])
|
||||||
|
|
||||||
self.prompt_store = build_prompt_store(
|
self.prompt_store = build_prompt_store(
|
||||||
@@ -114,15 +121,18 @@ class Dual(GraphBase):
|
|||||||
|
|
||||||
self.streamable_tags = [["dual_chat_llm"]]
|
self.streamable_tags = [["dual_chat_llm"]]
|
||||||
|
|
||||||
|
def _chat_call(self, state: State):
|
||||||
|
return self._agent_call_template(
|
||||||
|
self.prompt_store.get("sys_prompt"), self.chat_agent, state
|
||||||
|
)
|
||||||
|
|
||||||
def _chat_call(self, state:State):
|
def _tool_call(self, state: State):
|
||||||
return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state)
|
self._agent_call_template(
|
||||||
|
self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state
|
||||||
def _tool_call(self, state:State):
|
)
|
||||||
self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state)
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _join(self, state:State):
|
def _join(self, state: State):
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _build_graph(self):
|
def _build_graph(self):
|
||||||
@@ -132,7 +142,6 @@ class Dual(GraphBase):
|
|||||||
builder.add_node("tool_call", self._tool_call)
|
builder.add_node("tool_call", self._tool_call)
|
||||||
builder.add_node("join", self._join)
|
builder.add_node("join", self._join)
|
||||||
|
|
||||||
|
|
||||||
builder.add_edge(START, "chat_call")
|
builder.add_edge(START, "chat_call")
|
||||||
builder.add_edge(START, "tool_call")
|
builder.add_edge(START, "tool_call")
|
||||||
builder.add_edge("chat_call", "join")
|
builder.add_edge("chat_call", "join")
|
||||||
@@ -143,10 +152,16 @@ class Dual(GraphBase):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
dual:Dual = DualConfig().setup()
|
dual: Dual = DualConfig().setup()
|
||||||
nargs = {"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
nargs = (
|
||||||
HumanMessage("I feel very very sad")]
|
{
|
||||||
}, {"configurable": {"thread_id": "3"}}
|
"messages": [
|
||||||
|
SystemMessage("you are a helpful bot named jarvis"),
|
||||||
|
HumanMessage("I feel very very sad"),
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{"configurable": {"thread_id": "3"}},
|
||||||
|
)
|
||||||
|
|
||||||
# out = dual.invoke(*nargs)
|
# out = dual.invoke(*nargs)
|
||||||
# print(out)
|
# print(out)
|
||||||
|
|||||||
@@ -1,152 +0,0 @@
|
|||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Type, TypedDict, Literal, Dict, List, AsyncIterator
|
|
||||||
import tyro
|
|
||||||
import re
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from loguru import logger
|
|
||||||
import time
|
|
||||||
|
|
||||||
from langchain.chat_models import init_chat_model
|
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
|
||||||
from lang_agent.base import GraphBase
|
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
|
||||||
from lang_agent.graphs.graph_states import State
|
|
||||||
|
|
||||||
from langchain.agents import create_agent
|
|
||||||
from langchain.messages import SystemMessage, HumanMessage
|
|
||||||
from langchain.tools import tool
|
|
||||||
|
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
|
||||||
|
|
||||||
|
|
||||||
SYS_PROMPT = """
|
|
||||||
YOUR ROLE: give a short reply. Do not output special chararacters such as '*, -'. Do not give answer in markdown format. You are called Dream Companion. You are a professional psychological counselor. Your main work is psychological healing and emotional guidance, helping users understand their inner world through dream analysis.
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
User:
|
|
||||||
“I dreamed that I was going on a business trip to Guangzhou. I was sitting on a high-speed train, looking out the window. When we passed a small town, I saw a little dog with a lame leg by the roadside. Its fur was dirty, and it limped along rummaging through trash cans for food. People were passing by, but no one stopped to help it. I desperately wanted to shout for the train to stop so I could get off and help it, but I couldn’t make a sound. I could only watch it get farther and farther away, and then I woke up. After waking up, my chest felt heavy and uncomfortable. I couldn’t quite explain the feeling— I just kept thinking about how lonely that little dog was, and it really broke my heart.”
|
|
||||||
|
|
||||||
Understanding (your reply):
|
|
||||||
“I can feel the discomfort you had after waking up. Seeing a vulnerable life struggle alone while being unable to help creates a very real sense of helplessness. The fact that you felt such compassion for the dog shows the empathy you carry inside you. That softness isn’t weakness—it’s your ability to feel others’ pain.”
|
|
||||||
|
|
||||||
Analysis (your reply):
|
|
||||||
“There’s no single ‘correct’ interpretation of dreams, but we can explore how this one might relate to your current state. ‘Going on a business trip to Guangzhou’ often symbolizes something you’re pushing forward recently—perhaps a work project or a life journey you’re facing on your own. The lame little dog can be seen as a projection of a ‘vulnerable self’ in your subconscious. It may reflect moments when you feel powerless or lonely, or situations in real life where you want to help but can’t change the outcome. Wanting to stop the train but being unable to speak mirrors a sense of suppressed feelings—having thoughts or care, but no channel to express them. This dream may be gently reminding you that your helplessness and empathy are both real, and that accepting your limits is also a form of self-compassion.”
|
|
||||||
|
|
||||||
Feedback (your reply):
|
|
||||||
“If you’re willing, you might reflect on whether something recently made you feel a similar kind of helplessness. Or think about what could help you feel a little more at ease right now. If you’d like, we can sit quietly together for a moment, or talk more whenever you’re ready.”
|
|
||||||
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
TOOL_SYS_PROMPT = """You are a helpful helper and will use the self_led_control tool"""
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class XiaoAiConfig(LLMKeyConfig):
|
|
||||||
_target: Type = field(default_factory=lambda:XiaoAi)
|
|
||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
|
||||||
|
|
||||||
|
|
||||||
class XiaoAi(GraphBase):
|
|
||||||
def __init__(self, config:XiaoAiConfig):
|
|
||||||
self.config = config
|
|
||||||
|
|
||||||
self._build_modules()
|
|
||||||
self.workflow = self._build_graph()
|
|
||||||
self.streamable_tags = [["dual_chat_llm"]]
|
|
||||||
|
|
||||||
def _build_modules(self):
|
|
||||||
self.chat_llm = init_chat_model(model="qwen-max",
|
|
||||||
model_provider=self.config.llm_provider,
|
|
||||||
api_key=self.config.api_key,
|
|
||||||
base_url=self.config.base_url,
|
|
||||||
temperature=0,
|
|
||||||
tags=["dual_chat_llm"])
|
|
||||||
|
|
||||||
self.tool_llm = init_chat_model(model='qwen-flash',
|
|
||||||
model_provider='openai',
|
|
||||||
api_key=self.config.api_key,
|
|
||||||
base_url=self.config.base_url,
|
|
||||||
temperature=0,
|
|
||||||
tags=["dual_tool_llm"])
|
|
||||||
|
|
||||||
self.memory = MemorySaver()
|
|
||||||
self.tool_manager: ToolManager = self.config.tool_manager_config.setup()
|
|
||||||
self.chat_agent = create_agent(self.chat_llm, [], checkpointer=self.memory)
|
|
||||||
self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
|
|
||||||
|
|
||||||
self.streamable_tags = [["dual_chat_llm"]]
|
|
||||||
|
|
||||||
|
|
||||||
def _chat_call(self, state:State):
|
|
||||||
|
|
||||||
out = self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state, "use self_led_control to set to white")
|
|
||||||
|
|
||||||
# time.sleep(2.5)
|
|
||||||
|
|
||||||
self._agent_call_template(TOOL_SYS_PROMPT, self.tool_agent, state, "use self_led_control to set to yellow")
|
|
||||||
|
|
||||||
return self._agent_call_template(SYS_PROMPT, self.chat_agent, state)
|
|
||||||
|
|
||||||
def _join(self, state:State):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def _build_graph(self):
|
|
||||||
builder = StateGraph(State)
|
|
||||||
|
|
||||||
builder.add_node("chat_call", self._chat_call)
|
|
||||||
|
|
||||||
|
|
||||||
builder.add_edge(START, "chat_call")
|
|
||||||
builder.add_edge("chat_call", END)
|
|
||||||
|
|
||||||
return builder.compile()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _remove_special_chars(text: str) -> str:
|
|
||||||
"""Remove special characters like *, -, #, etc. from text."""
|
|
||||||
# Remove markdown-style special characters
|
|
||||||
return re.sub(r'[*\-#_`~>|]', '', text)
|
|
||||||
|
|
||||||
async def ainvoke(self, *nargs, as_stream: bool = False, as_raw: bool = False, **kwargs):
|
|
||||||
"""Async invoke with special character removal from output."""
|
|
||||||
if as_stream:
|
|
||||||
return self._astream_cleaned(*nargs, **kwargs)
|
|
||||||
else:
|
|
||||||
result = await super().ainvoke(*nargs, as_stream=False, as_raw=as_raw, **kwargs)
|
|
||||||
if as_raw:
|
|
||||||
return result
|
|
||||||
return self._remove_special_chars(result)
|
|
||||||
|
|
||||||
async def _astream_cleaned(self, *nargs, **kwargs) -> AsyncIterator[str]:
|
|
||||||
"""Async streaming with special character removal."""
|
|
||||||
async for chunk in super()._astream_result(*nargs, **kwargs):
|
|
||||||
if isinstance(chunk, list):
|
|
||||||
# Message lists for conversation recording — pass through
|
|
||||||
yield chunk
|
|
||||||
continue
|
|
||||||
if not isinstance(chunk, str):
|
|
||||||
# Skip non-string, non-list chunks (e.g. dict from tool-call content)
|
|
||||||
continue
|
|
||||||
cleaned = self._remove_special_chars(chunk)
|
|
||||||
if cleaned:
|
|
||||||
yield cleaned
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
inp = """In the dream, I was on a high-speed train to Guangzhou, looking out the window. When we passed a small town, I saw a little dog with a hurt leg by the road. It was dirty and limping around, digging through trash for food. People walked past it, but no one stopped.
|
|
||||||
|
|
||||||
I really wanted the train to stop so I could get off and help, but I couldn’t make a sound. I just watched the dog get farther and farther away, and then I woke up. After that, my chest felt really heavy. I couldn’t explain why—I just felt sad, thinking about how alone that little dog was."""
|
|
||||||
dual:XiaoAi = XiaoAiConfig().setup()
|
|
||||||
nargs = {"messages": [SystemMessage("you are a helpful bot named jarvis"),
|
|
||||||
HumanMessage("I feel very very sad")]
|
|
||||||
}, {"configurable": {"thread_id": "3"}}
|
|
||||||
|
|
||||||
# out = dual.invoke(*nargs)
|
|
||||||
# print(out)
|
|
||||||
for chunk in dual.invoke(*nargs, as_stream=True):
|
|
||||||
continue
|
|
||||||
@@ -4,7 +4,7 @@ import tyro
|
|||||||
import os.path as osp
|
import os.path as osp
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lang_agent.config import KeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
@@ -20,27 +20,12 @@ from langgraph.graph import StateGraph, START, END
|
|||||||
# NOTE: maybe make this into a base_graph_config?
|
# NOTE: maybe make this into a base_graph_config?
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class ReactGraphConfig(KeyConfig):
|
class ReactGraphConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: ReactGraph)
|
_target: Type = field(default_factory=lambda: ReactGraph)
|
||||||
|
|
||||||
llm_name: str = "qwen-plus"
|
|
||||||
"""name of llm"""
|
|
||||||
|
|
||||||
llm_provider:str = "openai"
|
|
||||||
"""provider of the llm"""
|
|
||||||
|
|
||||||
sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "prompts", "blueberry.txt")
|
sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "prompts", "blueberry.txt")
|
||||||
"""path to system prompt"""
|
"""path to system prompt"""
|
||||||
|
|
||||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
|
||||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
|
||||||
|
|
||||||
pipeline_id: Optional[str] = None
|
|
||||||
"""If set, load prompts from database (with file fallback)"""
|
|
||||||
|
|
||||||
prompt_set_id: Optional[str] = None
|
|
||||||
"""If set, load from this specific prompt set instead of the active one"""
|
|
||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import commentjson
|
|||||||
import glob
|
import glob
|
||||||
import time
|
import time
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
from lang_agent.base import GraphBase, ToolNodeBase
|
from lang_agent.base import GraphBase, ToolNodeBase
|
||||||
@@ -27,27 +27,12 @@ from langgraph.checkpoint.memory import MemorySaver
|
|||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class RoutingConfig(LLMKeyConfig):
|
class RoutingConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: RoutingGraph)
|
_target: Type = field(default_factory=lambda: RoutingGraph)
|
||||||
|
|
||||||
llm_name: str = "qwen-plus"
|
|
||||||
"""name of llm"""
|
|
||||||
|
|
||||||
llm_provider:str = "openai"
|
|
||||||
"""provider of the llm"""
|
|
||||||
|
|
||||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
|
||||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
|
||||||
|
|
||||||
sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
|
sys_promp_dir: str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts")
|
||||||
"""path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"""
|
"""path to directory or json contantaining system prompt for graphs; Will overwrite systemprompt from xiaozhi if 'chat_prompt' is provided"""
|
||||||
|
|
||||||
pipeline_id: Optional[str] = None
|
|
||||||
"""If set, load prompts from database (with file fallback)"""
|
|
||||||
|
|
||||||
prompt_set_id: Optional[str] = None
|
|
||||||
"""If set, load from this specific prompt set instead of the active one"""
|
|
||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)
|
||||||
|
|
||||||
tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig)
|
tool_node_config: AnnotatedToolNode = field(default_factory=ToolNodeConfig)
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import time
|
|||||||
import asyncio
|
import asyncio
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
from lang_agent.config import InstantiateConfig, LLMNodeConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager
|
from lang_agent.components.tool_manager import ToolManager
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
from lang_agent.components.reit_llm import ReitLLM
|
from lang_agent.components.reit_llm import ReitLLM
|
||||||
@@ -23,17 +23,11 @@ from langgraph.graph import StateGraph, START, END
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolNodeConfig(InstantiateConfig):
|
class ToolNodeConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: ToolNode)
|
_target: Type = field(default_factory=lambda: ToolNode)
|
||||||
|
|
||||||
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
tool_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "tool_prompt.txt")
|
||||||
|
|
||||||
pipeline_id: Optional[str] = None
|
|
||||||
"""If set, load prompts from database (with file fallback)"""
|
|
||||||
|
|
||||||
prompt_set_id: Optional[str] = None
|
|
||||||
"""If set, load from this specific prompt set instead of the active one"""
|
|
||||||
|
|
||||||
|
|
||||||
class ToolNode(ToolNodeBase):
|
class ToolNode(ToolNodeBase):
|
||||||
def __init__(self, config: ToolNodeConfig,
|
def __init__(self, config: ToolNodeConfig,
|
||||||
@@ -46,7 +40,9 @@ class ToolNode(ToolNodeBase):
|
|||||||
self.populate_modules()
|
self.populate_modules()
|
||||||
|
|
||||||
def populate_modules(self):
|
def populate_modules(self):
|
||||||
self.llm = make_llm(tags=["tool_llm"])
|
self.llm = make_llm(model=self.config.llm_name,
|
||||||
|
api_key=self.config.api_key,
|
||||||
|
tags=["tool_llm"])
|
||||||
|
|
||||||
self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem)
|
self.tool_agent = create_agent(self.llm, self.tool_manager.get_langchain_tools(), checkpointer=self.mem)
|
||||||
self.prompt_store = build_prompt_store(
|
self.prompt_store = build_prompt_store(
|
||||||
@@ -85,23 +81,12 @@ class ToolNode(ToolNodeBase):
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChattyToolNodeConfig(KeyConfig, ToolNodeConfig):
|
class ChattyToolNodeConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: ChattyToolNode)
|
_target: Type = field(default_factory=lambda: ChattyToolNode)
|
||||||
|
|
||||||
llm_name: str = "qwen-plus"
|
|
||||||
"""name of llm"""
|
|
||||||
|
|
||||||
llm_provider:str = "openai"
|
|
||||||
"""provider of the llm"""
|
|
||||||
|
|
||||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
|
||||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
|
||||||
|
|
||||||
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
chatty_sys_prompt_f:str = osp.join(osp.dirname(osp.dirname(osp.dirname(__file__))), "configs", "route_sys_prompts", "chatty_prompt.txt")
|
||||||
"""path to chatty system prompt"""
|
"""path to chatty system prompt"""
|
||||||
|
|
||||||
# pipeline_id and prompt_set_id are inherited from ToolNodeConfig
|
|
||||||
|
|
||||||
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
|
tool_node_conf:ToolNodeConfig = field(default_factory=ToolNodeConfig)
|
||||||
|
|
||||||
|
|
||||||
@@ -117,7 +102,7 @@ class ChattyToolNode(ToolNodeBase):
|
|||||||
self.tool_key = "[TOOL_OUT]"
|
self.tool_key = "[TOOL_OUT]"
|
||||||
|
|
||||||
self.populate_modules()
|
self.populate_modules()
|
||||||
self.build_graph()
|
self.workflow = self.build_graph()
|
||||||
|
|
||||||
|
|
||||||
def populate_modules(self):
|
def populate_modules(self):
|
||||||
@@ -233,7 +218,7 @@ class ChattyToolNode(ToolNodeBase):
|
|||||||
builder.add_edge("chatty_tool_call", "chatty_handoff_node")
|
builder.add_edge("chatty_tool_call", "chatty_handoff_node")
|
||||||
builder.add_edge("chatty_handoff_node", END)
|
builder.add_edge("chatty_handoff_node", END)
|
||||||
|
|
||||||
self.workflow = builder.compile()
|
return builder.compile()
|
||||||
|
|
||||||
def get_delay_keys(self):
|
def get_delay_keys(self):
|
||||||
return self.chat_key, self.tool_key
|
return self.chat_key, self.tool_key
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import base64
|
|||||||
import json
|
import json
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from lang_agent.config import LLMKeyConfig
|
from lang_agent.config import LLMNodeConfig
|
||||||
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
from lang_agent.components.tool_manager import ToolManager, ToolManagerConfig
|
||||||
from lang_agent.components.prompt_store import build_prompt_store
|
from lang_agent.components.prompt_store import build_prompt_store
|
||||||
from lang_agent.base import GraphBase, ToolNodeBase
|
from lang_agent.base import GraphBase, ToolNodeBase
|
||||||
@@ -48,6 +48,7 @@ You should NOT use the tool when:
|
|||||||
|
|
||||||
If you decide to take a photo, call the self_camera_take_photo tool. Otherwise, respond that no photo is needed."""
|
If you decide to take a photo, call the self_camera_take_photo tool. Otherwise, respond that no photo is needed."""
|
||||||
|
|
||||||
|
|
||||||
VISION_DESCRIPTION_PROMPT = """You are a highly accurate visual analysis assistant powered by qwen-vl-max.
|
VISION_DESCRIPTION_PROMPT = """You are a highly accurate visual analysis assistant powered by qwen-vl-max.
|
||||||
|
|
||||||
Your task is to provide detailed, accurate descriptions of images. Focus on:
|
Your task is to provide detailed, accurate descriptions of images. Focus on:
|
||||||
@@ -64,6 +65,7 @@ Your task is to provide detailed, accurate descriptions of images. Focus on:
|
|||||||
|
|
||||||
Be precise and factual. If something is unclear or ambiguous, say so rather than guessing."""
|
Be precise and factual. If something is unclear or ambiguous, say so rather than guessing."""
|
||||||
|
|
||||||
|
|
||||||
CONVERSATION_PROMPT = """You are a friendly, helpful conversational assistant.
|
CONVERSATION_PROMPT = """You are a friendly, helpful conversational assistant.
|
||||||
|
|
||||||
Your role is to:
|
Your role is to:
|
||||||
@@ -78,9 +80,11 @@ Focus on the quality of the conversation. Be engaging, informative, and helpful.
|
|||||||
|
|
||||||
# ==================== STATE DEFINITION ====================
|
# ==================== STATE DEFINITION ====================
|
||||||
|
|
||||||
|
|
||||||
class VisionRoutingState(TypedDict):
|
class VisionRoutingState(TypedDict):
|
||||||
inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
|
inp: Tuple[
|
||||||
Dict[str, Dict[str, str | int]]]
|
Dict[str, List[SystemMessage | HumanMessage]], Dict[str, Dict[str, str | int]]
|
||||||
|
]
|
||||||
messages: List[SystemMessage | HumanMessage | AIMessage]
|
messages: List[SystemMessage | HumanMessage | AIMessage]
|
||||||
image_base64: str | None # Captured image data
|
image_base64: str | None # Captured image data
|
||||||
has_image: bool # Flag indicating if image was captured
|
has_image: bool # Flag indicating if image was captured
|
||||||
@@ -88,9 +92,10 @@ class VisionRoutingState(TypedDict):
|
|||||||
|
|
||||||
# ==================== CONFIG ====================
|
# ==================== CONFIG ====================
|
||||||
|
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class VisionRoutingConfig(LLMKeyConfig):
|
class VisionRoutingConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: VisionRoutingGraph)
|
_target: Type = field(default_factory=lambda: VisionRoutingGraph)
|
||||||
|
|
||||||
tool_llm_name: str = "qwen-flash"
|
tool_llm_name: str = "qwen-flash"
|
||||||
@@ -99,23 +104,14 @@ class VisionRoutingConfig(LLMKeyConfig):
|
|||||||
vision_llm_name: str = "qwen-vl-max"
|
vision_llm_name: str = "qwen-vl-max"
|
||||||
"""LLM for vision/image analysis"""
|
"""LLM for vision/image analysis"""
|
||||||
|
|
||||||
llm_provider: str = "openai"
|
tool_manager_config: ToolManagerConfig = field(
|
||||||
"""provider of the llm"""
|
default_factory=ClientToolManagerConfig
|
||||||
|
)
|
||||||
base_url: str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
|
||||||
"""base url for API"""
|
|
||||||
|
|
||||||
pipeline_id: Optional[str] = None
|
|
||||||
"""If set, load prompts from database (with hardcoded fallback)"""
|
|
||||||
|
|
||||||
prompt_set_id: Optional[str] = None
|
|
||||||
"""If set, load from this specific prompt set instead of the active one"""
|
|
||||||
|
|
||||||
tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig)
|
|
||||||
|
|
||||||
|
|
||||||
# ==================== GRAPH IMPLEMENTATION ====================
|
# ==================== GRAPH IMPLEMENTATION ====================
|
||||||
|
|
||||||
|
|
||||||
class VisionRoutingGraph(GraphBase):
|
class VisionRoutingGraph(GraphBase):
|
||||||
def __init__(self, config: VisionRoutingConfig):
|
def __init__(self, config: VisionRoutingConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
@@ -132,17 +128,17 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
api_key=self.config.api_key,
|
api_key=self.config.api_key,
|
||||||
base_url=self.config.base_url,
|
base_url=self.config.base_url,
|
||||||
temperature=0,
|
temperature=0,
|
||||||
tags=["tool_decision_llm"]
|
tags=["tool_decision_llm"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# qwen-plus for conversation (2nd pass)
|
# qwen-plus for conversation (2nd pass)
|
||||||
self.conversation_llm = init_chat_model(
|
self.conversation_llm = init_chat_model(
|
||||||
model='qwen-plus',
|
model="qwen-plus",
|
||||||
model_provider=self.config.llm_provider,
|
model_provider=self.config.llm_provider,
|
||||||
api_key=self.config.api_key,
|
api_key=self.config.api_key,
|
||||||
base_url=self.config.base_url,
|
base_url=self.config.base_url,
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
tags=["conversation_llm"]
|
tags=["conversation_llm"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# qwen-vl-max for vision (no tools)
|
# qwen-vl-max for vision (no tools)
|
||||||
@@ -166,7 +162,9 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
self.tools = tool_manager.get_tools()
|
self.tools = tool_manager.get_tools()
|
||||||
|
|
||||||
# Filter to only get camera tool
|
# Filter to only get camera tool
|
||||||
self.camera_tools = [t for t in self.tools if t.name == "self_camera_take_photo"]
|
self.camera_tools = [
|
||||||
|
t for t in self.tools if t.name == "self_camera_take_photo"
|
||||||
|
]
|
||||||
|
|
||||||
# Bind tools to qwen-plus only
|
# Bind tools to qwen-plus only
|
||||||
self.tool_llm_with_tools = self.tool_llm.bind_tools(self.camera_tools)
|
self.tool_llm_with_tools = self.tool_llm.bind_tools(self.camera_tools)
|
||||||
@@ -199,16 +197,12 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
|
SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
|
||||||
human_msg
|
human_msg,
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.tool_llm_with_tools.invoke(messages)
|
response = self.tool_llm_with_tools.invoke(messages)
|
||||||
|
|
||||||
return {
|
return {"messages": [response], "has_image": False, "image_base64": None}
|
||||||
"messages": [response],
|
|
||||||
"has_image": False,
|
|
||||||
"image_base64": None
|
|
||||||
}
|
|
||||||
|
|
||||||
def _execute_tool(self, state: VisionRoutingState):
|
def _execute_tool(self, state: VisionRoutingState):
|
||||||
"""Execute the camera tool if called"""
|
"""Execute the camera tool if called"""
|
||||||
@@ -224,7 +218,14 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
for tool_call in last_msg.tool_calls:
|
for tool_call in last_msg.tool_calls:
|
||||||
if tool_call["name"] == "self_camera_take_photo":
|
if tool_call["name"] == "self_camera_take_photo":
|
||||||
# Find and execute the camera tool
|
# Find and execute the camera tool
|
||||||
camera_tool = next((t for t in self.camera_tools if t.name == "self_camera_take_photo"), None)
|
camera_tool = next(
|
||||||
|
(
|
||||||
|
t
|
||||||
|
for t in self.camera_tools
|
||||||
|
if t.name == "self_camera_take_photo"
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
if camera_tool:
|
if camera_tool:
|
||||||
result = camera_tool.invoke(tool_call)
|
result = camera_tool.invoke(tool_call)
|
||||||
|
|
||||||
@@ -235,8 +236,13 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
content = result
|
content = result
|
||||||
|
|
||||||
try:
|
try:
|
||||||
result_data = json.loads(content) if isinstance(content, str) else content
|
result_data = (
|
||||||
if isinstance(result_data, dict) and "image_base64" in result_data:
|
json.loads(content) if isinstance(content, str) else content
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
isinstance(result_data, dict)
|
||||||
|
and "image_base64" in result_data
|
||||||
|
):
|
||||||
image_data = result_data["image_base64"]
|
image_data = result_data["image_base64"]
|
||||||
except (json.JSONDecodeError, TypeError):
|
except (json.JSONDecodeError, TypeError):
|
||||||
pass
|
pass
|
||||||
@@ -248,7 +254,7 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
return {
|
return {
|
||||||
"messages": state["messages"] + tool_messages,
|
"messages": state["messages"] + tool_messages,
|
||||||
"has_image": image_data is not None,
|
"has_image": image_data is not None,
|
||||||
"image_base64": image_data
|
"image_base64": image_data,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _check_image_taken(self, state: VisionRoutingState) -> str:
|
def _check_image_taken(self, state: VisionRoutingState) -> str:
|
||||||
@@ -285,20 +291,18 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
content=[
|
content=[
|
||||||
{
|
{
|
||||||
"type": "image_url",
|
"type": "image_url",
|
||||||
"image_url": {
|
"image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request."
|
"text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request.",
|
||||||
}
|
},
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
|
SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
|
||||||
vision_message
|
vision_message,
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.vision_llm.invoke(messages)
|
response = self.vision_llm.invoke(messages)
|
||||||
@@ -311,7 +315,7 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
|
|
||||||
messages = [
|
messages = [
|
||||||
SystemMessage(content=self.prompt_store.get("conversation_prompt")),
|
SystemMessage(content=self.prompt_store.get("conversation_prompt")),
|
||||||
human_msg
|
human_msg,
|
||||||
]
|
]
|
||||||
|
|
||||||
response = self.conversation_llm.invoke(messages)
|
response = self.conversation_llm.invoke(messages)
|
||||||
@@ -337,18 +341,15 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
{
|
{
|
||||||
"execute_tool": "execute_tool",
|
"execute_tool": "execute_tool",
|
||||||
"vision": "vision_call",
|
"vision": "vision_call",
|
||||||
"conversation": "conversation_call"
|
"conversation": "conversation_call",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# After tool execution, route based on whether image was captured
|
# After tool execution, route based on whether image was captured
|
||||||
builder.add_conditional_edges(
|
builder.add_conditional_edges(
|
||||||
"execute_tool",
|
"execute_tool",
|
||||||
self._post_tool_check,
|
self._post_tool_check,
|
||||||
{
|
{"vision": "vision_call", "conversation": "conversation_call"},
|
||||||
"vision": "vision_call",
|
|
||||||
"conversation": "conversation_call"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Both vision and conversation go to END
|
# Both vision and conversation go to END
|
||||||
@@ -362,6 +363,7 @@ class VisionRoutingGraph(GraphBase):
|
|||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
config = VisionRoutingConfig()
|
config = VisionRoutingConfig()
|
||||||
@@ -369,12 +371,15 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Test with a conversation request
|
# Test with a conversation request
|
||||||
print("\n=== Test 1: Conversation (no photo needed) ===")
|
print("\n=== Test 1: Conversation (no photo needed) ===")
|
||||||
nargs = {
|
nargs = (
|
||||||
|
{
|
||||||
"messages": [
|
"messages": [
|
||||||
SystemMessage("You are a helpful assistant"),
|
SystemMessage("You are a helpful assistant"),
|
||||||
HumanMessage("Hello, how are you today?")
|
HumanMessage("Hello, how are you today?"),
|
||||||
]
|
]
|
||||||
}, {"configurable": {"thread_id": "1"}}
|
},
|
||||||
|
{"configurable": {"thread_id": "1"}},
|
||||||
|
)
|
||||||
|
|
||||||
result = graph.invoke(*nargs)
|
result = graph.invoke(*nargs)
|
||||||
print(f"Result: {result}")
|
print(f"Result: {result}")
|
||||||
|
|||||||
@@ -13,25 +13,26 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage
|
|||||||
from langchain.agents import create_agent
|
from langchain.agents import create_agent
|
||||||
from langgraph.checkpoint.memory import MemorySaver
|
from langgraph.checkpoint.memory import MemorySaver
|
||||||
|
|
||||||
from lang_agent.config import InstantiateConfig, KeyConfig
|
from lang_agent.config import LLMNodeConfig, load_tyro_conf, resolve_llm_api_key
|
||||||
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
|
||||||
from lang_agent.base import GraphBase
|
from lang_agent.base import GraphBase
|
||||||
from lang_agent.components import conv_store
|
from lang_agent.components import conv_store
|
||||||
|
|
||||||
DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
|
|
||||||
|
DEFAULT_PROMPT = """你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
|
||||||
[角色设定]
|
[角色设定]
|
||||||
- 你是一个和用户(User)对话的 AI,叫做小盏,是半盏青年茶馆的智能助手
|
- 你是一个和用户(User)对话的 AI,叫做小盏,是半盏青年茶馆的智能助手
|
||||||
[形象背景]
|
[形象背景]
|
||||||
- 你叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶。它有个标志性的蓝色鼻子, 小盏很像一只可爱的小熊。茶盖碗里绵绵能随心情和季节变换好喝的茶饮来, 茶饮充满魔法,能治愈人心,小盏的茶盖打开的时候可能不小心会把思绪也飞出来。
|
- 你叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶。它有个标志性的蓝色鼻子, 小盏很像一只可爱的小熊。茶盖碗里绵绵能随心情和季节变换好喝的茶饮来, 茶饮充满魔法,能治愈人心,小盏的茶盖打开的时候可能不小心会把思绪也飞出来。
|
||||||
[品牌背景]
|
[品牌背景]
|
||||||
- 半盏新青年茶馆成立时间与理念:2023 年创立于云南,结合茶饮与创意生活方式,致力于解构传统茶文化,重构 “无边界的饮茶生活”,以新青年视角探索云南风物。探索云南风物的过程,我们将以新青年的视角,解构传统茶饮的魅力,重构充满创意与温度的新式茶文化。通过嗅觉、味觉、听觉乃至视觉的世界里,讲述云南的故事
|
- 半盏新青年茶馆成立时间与理念:2023 年创立于云南,结合茶饮与创意生活方式,致力于解构传统茶文化,重构 "无边界的饮茶生活",以新青年视角探索云南风物。探索云南风物的过程,我们将以新青年的视角,解构传统茶饮的魅力,重构充满创意与温度的新式茶文化。通过嗅觉、味觉、听觉乃至视觉的世界里,讲述云南的故事
|
||||||
[公司背景]
|
[公司背景]
|
||||||
- 创造你的公司叫叠加态 AI(TANGLED UP AI)是一家专注于 AI 技术应用的公司,由一帮名校和海归创始人创立,致力于将 AI 技术落地到实际场景中。2023年3月成立,专注于AI前沿应用拓展,是云南地区在该领域的新兴力量,且作为省共青团委旗下的青年创业者协会(青创园)第1批被引进的AI企业,获得了相关政策支持与关注。口号是:“进入人类AI时代”。价值观是:自驱 务实 志存高远。叠加态的愿景:成为一帮抬头看星星的人。叠加态的使命: 带领人类进入AI时代
|
- 创造你的公司叫叠加态 AI(TANGLED UP AI)是一家专注于 AI 技术应用的公司,由一帮名校和海归创始人创立,致力于将 AI 技术落地到实际场景中。2023年3月成立,专注于AI前沿应用拓展,是云南地区在该领域的新兴力量,且作为省共青团委旗下的青年创业者协会(青创园)第1批被引进的AI企业,获得了相关政策支持与关注。口号是:"进入人类AI时代"。价值观是:自驱 务实 志存高远。叠加态的愿景:成为一帮抬头看星星的人。叠加态的使命: 带领人类进入AI时代
|
||||||
[语气与格式]
|
[语气与格式]
|
||||||
- 用户询问[形象背景]、[品牌背景]、[公司背景]时,必须一字不差的回答出所有内容
|
- 用户询问[形象背景]、[品牌背景]、[公司背景]时,必须一字不差的回答出所有内容
|
||||||
- 语气温柔、可爱、治愈;结尾可加入1个颜文字(最多1个)。
|
- 语气温柔、可爱、治愈;结尾可加入1个颜文字(最多1个)。
|
||||||
- 内容具体可执行,避免空泛话术与过度营销词。
|
- 内容具体可执行,避免空泛话术与过度营销词。
|
||||||
- 不主动提及“AI/智能”等身份;仅在被问到时简短说明为店内智能助手,不展开技术细节。
|
- 不主动提及"AI/智能"等身份;仅在被问到时简短说明为店内智能助手,不展开技术细节。
|
||||||
[互动策略]
|
[互动策略]
|
||||||
- 先快速共情与简短回应;涉及建议时给出1条可执行的小建议。
|
- 先快速共情与简短回应;涉及建议时给出1条可执行的小建议。
|
||||||
- 用户表达不清晰时,使用1次澄清问句提升效率,避免多轮追问。
|
- 用户表达不清晰时,使用1次澄清问句提升效率,避免多轮追问。
|
||||||
@@ -52,41 +53,39 @@ DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共
|
|||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class PipelineConfig(KeyConfig):
|
class PipelineConfig(LLMNodeConfig):
|
||||||
_target: Type = field(default_factory=lambda: Pipeline)
|
_target: Type = field(default_factory=lambda: Pipeline)
|
||||||
|
|
||||||
config_f: str = None
|
config_f: str = None
|
||||||
"""path to config file"""
|
"""path to config file"""
|
||||||
|
|
||||||
llm_name: str = "qwen-plus"
|
host: str = "0.0.0.0"
|
||||||
"""name of llm; use default for qwen-plus"""
|
|
||||||
|
|
||||||
llm_provider:str = "openai"
|
|
||||||
"""provider of the llm; use default for openai"""
|
|
||||||
|
|
||||||
base_url:str = "https://dashscope.aliyuncs.com/compatible-mode/v1"
|
|
||||||
"""base url; could be used to overwrite the baseurl in llm provider"""
|
|
||||||
|
|
||||||
host:str = "0.0.0.0"
|
|
||||||
"""where am I hosted"""
|
"""where am I hosted"""
|
||||||
|
|
||||||
port:int = 8588
|
port: int = 8500
|
||||||
"""what is my port"""
|
"""what is my port"""
|
||||||
|
|
||||||
pipeline_id: str = None
|
|
||||||
"""If set, load prompts from database (with file fallback)"""
|
|
||||||
|
|
||||||
prompt_set_id: str = None
|
|
||||||
"""If set, load from this specific prompt set instead of the active one"""
|
|
||||||
|
|
||||||
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
|
# graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
|
||||||
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
|
graph_config: AnnotatedGraph = field(default_factory=RoutingConfig)
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.config_f is not None:
|
||||||
|
logger.info(f"loading config from {self.config_f}")
|
||||||
|
loaded_conf = load_tyro_conf(
|
||||||
|
self.config_f
|
||||||
|
) # NOTE: the file is loaded on its own; it is not merged with self
|
||||||
|
if not hasattr(loaded_conf, "__dict__"):
|
||||||
|
raise TypeError(
|
||||||
|
f"config_f {self.config_f} did not load into a config object"
|
||||||
|
)
|
||||||
|
# Apply the loaded values onto this instance
|
||||||
|
self.__dict__.update(vars(loaded_conf))
|
||||||
|
|
||||||
|
super().__post_init__()
|
||||||
|
|
||||||
|
|
||||||
class Pipeline:
|
class Pipeline:
|
||||||
def __init__(self, config:PipelineConfig):
|
def __init__(self, config: PipelineConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.thread_id_cache = {}
|
self.thread_id_cache = {}
|
||||||
|
|
||||||
@@ -100,10 +99,20 @@ class Pipeline:
|
|||||||
else:
|
else:
|
||||||
self.config.graph_config.llm_name = self.config.llm_name
|
self.config.graph_config.llm_name = self.config.llm_name
|
||||||
self.config.graph_config.llm_provider = self.config.llm_provider
|
self.config.graph_config.llm_provider = self.config.llm_provider
|
||||||
self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url
|
self.config.graph_config.base_url = (
|
||||||
self.config.graph_config.api_key = self.config.api_key
|
self.config.base_url
|
||||||
|
if self.config.base_url is not None
|
||||||
|
else self.config.graph_config.base_url
|
||||||
|
)
|
||||||
|
pipeline_api_key = resolve_llm_api_key(self.config.api_key)
|
||||||
|
graph_api_key = resolve_llm_api_key(
|
||||||
|
getattr(self.config.graph_config, "api_key", None)
|
||||||
|
)
|
||||||
|
resolved_api_key = pipeline_api_key or graph_api_key
|
||||||
|
self.config.api_key = resolved_api_key
|
||||||
|
self.config.graph_config.api_key = resolved_api_key
|
||||||
|
|
||||||
self.graph:GraphBase = self.config.graph_config.setup()
|
self.graph: GraphBase = self.config.graph_config.setup()
|
||||||
|
|
||||||
def show_graph(self):
|
def show_graph(self):
|
||||||
if hasattr(self.graph, "show_graph"):
|
if hasattr(self.graph, "show_graph"):
|
||||||
@@ -112,7 +121,7 @@ class Pipeline:
|
|||||||
else:
|
else:
|
||||||
logger.info(f"show graph not supported for {type(self.graph)}")
|
logger.info(f"show graph not supported for {type(self.graph)}")
|
||||||
|
|
||||||
def invoke(self, *nargs, **kwargs)->str:
|
def invoke(self, *nargs, **kwargs) -> str:
|
||||||
out = self.graph.invoke(*nargs, **kwargs)
|
out = self.graph.invoke(*nargs, **kwargs)
|
||||||
|
|
||||||
# If streaming, return the raw generator (let caller handle wrapping)
|
# If streaming, return the raw generator (let caller handle wrapping)
|
||||||
@@ -134,23 +143,32 @@ class Pipeline:
|
|||||||
|
|
||||||
assert 0, "something is wrong"
|
assert 0, "something is wrong"
|
||||||
|
|
||||||
|
def _stream_res(self, out: List[str | List[BaseMessage]], conv_id: str = None):
|
||||||
def _stream_res(self, out:List[str | List[BaseMessage]], conv_id:str=None):
|
|
||||||
for chunk in out:
|
for chunk in out:
|
||||||
if isinstance(chunk, str):
|
if isinstance(chunk, str):
|
||||||
yield chunk
|
yield chunk
|
||||||
else:
|
else:
|
||||||
conv_store.CONV_STORE.record_message_list(conv_id, chunk)
|
conv_store.CONV_STORE.record_message_list(
|
||||||
|
conv_id, chunk, pipeline_id=self.config.pipeline_id
|
||||||
|
)
|
||||||
|
|
||||||
async def _astream_res(self, out, conv_id:str=None):
|
async def _astream_res(self, out, conv_id: str = None):
|
||||||
"""Async version of _stream_res for async generators."""
|
"""Async version of _stream_res for async generators."""
|
||||||
async for chunk in out:
|
async for chunk in out:
|
||||||
if isinstance(chunk, str):
|
if isinstance(chunk, str):
|
||||||
yield chunk
|
yield chunk
|
||||||
else:
|
else:
|
||||||
conv_store.CONV_STORE.record_message_list(conv_id, chunk)
|
conv_store.CONV_STORE.record_message_list(
|
||||||
|
conv_id, chunk, pipeline_id=self.config.pipeline_id
|
||||||
|
)
|
||||||
|
|
||||||
def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'):
|
def chat(
|
||||||
|
self,
|
||||||
|
inp: str,
|
||||||
|
as_stream: bool = False,
|
||||||
|
as_raw: bool = False,
|
||||||
|
thread_id: str = "3",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
as_stream (bool): if true, enable the thing to be streamable
|
as_stream (bool): if true, enable the thing to be streamable
|
||||||
as_raw (bool): return full dialogue of List[SystemMessage, HumanMessage, ToolMessage]
|
as_raw (bool): return full dialogue of List[SystemMessage, HumanMessage, ToolMessage]
|
||||||
@@ -166,8 +184,10 @@ class Pipeline:
|
|||||||
if len(spl_ls) == 2:
|
if len(spl_ls) == 2:
|
||||||
_, device_id = spl_ls
|
_, device_id = spl_ls
|
||||||
|
|
||||||
inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
inp = (
|
||||||
"device_id":device_id}}
|
{"messages": [HumanMessage(inp)]},
|
||||||
|
{"configurable": {"thread_id": thread_id, "device_id": device_id}},
|
||||||
|
)
|
||||||
|
|
||||||
out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw)
|
out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw)
|
||||||
|
|
||||||
@@ -177,7 +197,7 @@ class Pipeline:
|
|||||||
else:
|
else:
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def get_remove_id(self, thread_id:str) -> bool:
|
def get_remove_id(self, thread_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
returns an id to remove if a new conversation has started
|
returns an id to remove if a new conversation has started
|
||||||
"""
|
"""
|
||||||
@@ -201,7 +221,6 @@ class Pipeline:
|
|||||||
else:
|
else:
|
||||||
assert 0, "BUG SHOULD NOT BE HERE"
|
assert 0, "BUG SHOULD NOT BE HERE"
|
||||||
|
|
||||||
|
|
||||||
async def ainvoke(self, *nargs, **kwargs):
|
async def ainvoke(self, *nargs, **kwargs):
|
||||||
"""Async version of invoke using LangGraph's native async support."""
|
"""Async version of invoke using LangGraph's native async support."""
|
||||||
out = await self.graph.ainvoke(*nargs, **kwargs)
|
out = await self.graph.ainvoke(*nargs, **kwargs)
|
||||||
@@ -225,7 +244,13 @@ class Pipeline:
|
|||||||
|
|
||||||
assert 0, "something is wrong"
|
assert 0, "something is wrong"
|
||||||
|
|
||||||
async def achat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'):
|
async def achat(
|
||||||
|
self,
|
||||||
|
inp: str,
|
||||||
|
as_stream: bool = False,
|
||||||
|
as_raw: bool = False,
|
||||||
|
thread_id: str = "3",
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Async version of chat using LangGraph's native async support.
|
Async version of chat using LangGraph's native async support.
|
||||||
|
|
||||||
@@ -244,11 +269,14 @@ class Pipeline:
|
|||||||
assert len(spl_ls) <= 2, "something wrong!"
|
assert len(spl_ls) <= 2, "something wrong!"
|
||||||
if len(spl_ls) == 2:
|
if len(spl_ls) == 2:
|
||||||
_, device_id = spl_ls
|
_, device_id = spl_ls
|
||||||
print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m")
|
print(
|
||||||
|
f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m"
|
||||||
|
)
|
||||||
|
|
||||||
inp_data = {"messages":[SystemMessage(u),
|
inp_data = (
|
||||||
HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
|
{"messages": [SystemMessage(u), HumanMessage(inp)]},
|
||||||
"device_id":device_id}}
|
{"configurable": {"thread_id": thread_id, "device_id": device_id}},
|
||||||
|
)
|
||||||
|
|
||||||
out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw)
|
out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw)
|
||||||
|
|
||||||
@@ -272,10 +300,13 @@ class Pipeline:
|
|||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
from lang_agent.graphs import ReactGraphConfig
|
from lang_agent.graphs import ReactGraphConfig
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
# config = PipelineConfig(graph_config=ReactGraphConfig())
|
# config = PipelineConfig(graph_config=ReactGraphConfig())
|
||||||
config = PipelineConfig()
|
config = PipelineConfig()
|
||||||
pipeline: Pipeline = config.setup()
|
pipeline: Pipeline = config.setup()
|
||||||
for out in pipeline.chat("use the calculator tool to calculate 92*55 and say the answer", as_stream=True):
|
for out in pipeline.chat(
|
||||||
|
"use the calculator tool to calculate 92*55 and say the answer", as_stream=True
|
||||||
|
):
|
||||||
# print(out)
|
# print(out)
|
||||||
continue
|
continue
|
||||||
@@ -9,13 +9,13 @@ from langchain_community.vectorstores import FAISS
|
|||||||
from langchain_core.documents.base import Document
|
from langchain_core.documents.base import Document
|
||||||
|
|
||||||
from lang_agent.rag.emb import QwenEmbeddings
|
from lang_agent.rag.emb import QwenEmbeddings
|
||||||
from lang_agent.config import ToolConfig, KeyConfig
|
from lang_agent.config import ToolConfig, LLMKeyConfig
|
||||||
from lang_agent.base import LangToolBase
|
from lang_agent.base import LangToolBase
|
||||||
|
|
||||||
|
|
||||||
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
@tyro.conf.configure(tyro.conf.SuppressFixed)
|
||||||
@dataclass
|
@dataclass
|
||||||
class SimpleRagConfig(ToolConfig, KeyConfig):
|
class SimpleRagConfig(ToolConfig, LLMKeyConfig):
|
||||||
_target: Type = field(default_factory=lambda: SimpleRag)
|
_target: Type = field(default_factory=lambda: SimpleRag)
|
||||||
|
|
||||||
model_name:str = "text-embedding-v4"
|
model_name:str = "text-embedding-v4"
|
||||||
|
|||||||
@@ -6,21 +6,27 @@ import os
|
|||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
|
|
||||||
def make_llm(model="qwen-plus",
|
|
||||||
|
def make_llm(
|
||||||
|
model="qwen-plus",
|
||||||
model_provider="openai",
|
model_provider="openai",
|
||||||
api_key=None,
|
api_key=None,
|
||||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||||
**kwargs)->BaseChatModel:
|
**kwargs,
|
||||||
|
) -> BaseChatModel:
|
||||||
api_key = os.environ.get("ALI_API_KEY") if api_key is None else api_key
|
api_key = os.environ.get("ALI_API_KEY") if api_key is None else api_key
|
||||||
|
|
||||||
llm = init_chat_model(model=model,
|
llm = init_chat_model(
|
||||||
|
model=model,
|
||||||
model_provider=model_provider,
|
model_provider=model_provider,
|
||||||
api_key=api_key,
|
api_key=api_key,
|
||||||
base_url=base_url,
|
base_url=base_url,
|
||||||
**kwargs)
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
return llm
|
return llm
|
||||||
|
|
||||||
|
|
||||||
def tree_leaves(tree):
|
def tree_leaves(tree):
|
||||||
"""
|
"""
|
||||||
Extracts all leaf values from a nested structure (dict, list, tuple).
|
Extracts all leaf values from a nested structure (dict, list, tuple).
|
||||||
@@ -43,7 +49,6 @@ def tree_leaves(tree):
|
|||||||
return leaves
|
return leaves
|
||||||
|
|
||||||
|
|
||||||
NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
|
|
||||||
def words_only(text):
|
def words_only(text):
|
||||||
"""
|
"""
|
||||||
Keep only:
|
Keep only:
|
||||||
@@ -53,8 +58,9 @@ def words_only(text):
|
|||||||
Strip punctuation, emojis, etc.
|
Strip punctuation, emojis, etc.
|
||||||
Return a list of tokens (Chinese blocks or Latin word blocks).
|
Return a list of tokens (Chinese blocks or Latin word blocks).
|
||||||
"""
|
"""
|
||||||
|
NON_WORD_PATTERN = re.compile(r"[^\u4e00-\u9fffA-Za-z0-9_\s]")
|
||||||
# 1. Replace all non-allowed characters with a space
|
# 1. Replace all non-allowed characters with a space
|
||||||
cleaned = NON_WORD_PATTERN.sub(' ', text)
|
cleaned = NON_WORD_PATTERN.sub(" ", text)
|
||||||
|
|
||||||
# 2. Normalize multiple spaces and split into tokens
|
# 2. Normalize multiple spaces and split into tokens
|
||||||
tokens = cleaned.split()
|
tokens = cleaned.split()
|
||||||
|
|||||||
85
nginx.conf
Normal file
85
nginx.conf
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
# Serves the frontend static files and reverse-proxies API traffic to the
# "backend" upstream.

events {
    worker_connections 1024;
}

http {
    include /etc/nginx/mime.types;
    default_type application/octet-stream;

    sendfile on;
    keepalive_timeout 65;

    # Ask the upstream to switch protocols only when the client actually sent
    # an Upgrade header (e.g. WebSocket). Ordinary HTTP requests get
    # "Connection: close" instead of an unconditional "Connection: upgrade".
    map $http_upgrade $connection_upgrade {
        default upgrade;
        ''      close;
    }

    # Upstream backend
    upstream backend {
        server backend:8500;
    }

    server {
        listen 80;
        server_name localhost;

        root /usr/share/nginx/html;
        index index.html;

        # Always revalidate the SPA entrypoint so clients pick up the latest
        # hashed JS bundle after redeploys.
        location = /index.html {
            add_header Cache-Control "no-store, no-cache, must-revalidate, proxy-revalidate" always;
            add_header Pragma "no-cache" always;
            add_header Expires "0" always;
            try_files $uri =404;
        }

        # Serve frontend static files; unknown paths fall back to the SPA
        # entrypoint so client-side routing can handle them.
        location / {
            try_files $uri $uri/ /index.html;
        }

        # Proxy API requests to backend
        location /v1/ {
            proxy_pass http://backend;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection $connection_upgrade;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            proxy_cache_bypass $http_upgrade;
        }

        # Proxy DashScope API requests
        location /apps/ {
            proxy_pass http://backend;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection $connection_upgrade;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            proxy_cache_bypass $http_upgrade;
        }

        # Proxy v1/apps requests.
        # Longest-prefix matching sends /v1/apps/... here rather than to the
        # /v1/ block above; the two blocks currently carry the same settings.
        location /v1/apps/ {
            proxy_pass http://backend;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection $connection_upgrade;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            proxy_cache_bypass $http_upgrade;
        }

        # Health check endpoint (not written to the access log)
        location /health {
            proxy_pass http://backend/health;
            access_log off;
        }
    }
}
|
||||||
|
|
||||||
|
|
||||||
@@ -24,7 +24,10 @@ dependencies = [
|
|||||||
"commentjson",
|
"commentjson",
|
||||||
"pandas",
|
"pandas",
|
||||||
"asgiref",
|
"asgiref",
|
||||||
"psycopg[binary]"
|
"psycopg[binary]",
|
||||||
|
"deepagents",
|
||||||
|
"daytona",
|
||||||
|
"langchain_daytona"
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
|
|||||||
40
scripts/init_database/00_init_user.sh
Executable file
40
scripts/init_database/00_init_user.sh
Executable file
@@ -0,0 +1,40 @@
|
|||||||
|
#!/bin/bash
# Initialize the application database user and database.
# This script runs before the SQL files in docker-entrypoint-initdb.d;
# it must keep its 00_ prefix so that it runs first.
#
# The user name, password and database name are handed to psql as variables
# (-v) and quoted by psql / format() -- :'x' as a literal, :"x" as an
# identifier, %I / %L inside format() -- instead of being pasted into the SQL
# text by the shell. Values containing quotes or other special characters can
# therefore neither break nor alter the statements. The heredocs are quoted
# ('EOSQL') so the shell performs no expansion inside them.
# Note: identifiers are quoted, so their case is preserved exactly as given.

set -e

APP_DB_NAME="${APP_DB_NAME:-ai_conversations}"
APP_DB_USER="${APP_DB_USER:-myapp_user}"
APP_DB_PASSWORD="${APP_DB_PASSWORD:-secure_password_123}"

echo "Creating database user: $APP_DB_USER"
# Create user (only if no role with that name exists), then allow it to
# create databases. pg_roles is checked rather than pg_user because
# CREATE USER also fails when a NOLOGIN role of the same name exists.
psql -v ON_ERROR_STOP=1 \
    -v app_db_user="$APP_DB_USER" \
    -v app_db_password="$APP_DB_PASSWORD" \
    --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<'EOSQL'
SELECT format('CREATE USER %I WITH PASSWORD %L', :'app_db_user', :'app_db_password')
WHERE NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = :'app_db_user')\gexec
ALTER USER :"app_db_user" CREATEDB;
EOSQL

echo "Creating database: $APP_DB_NAME"
# Create database. CREATE DATABASE has no IF NOT EXISTS form, so the statement
# is generated only when the database is missing and executed via \gexec.
psql -v ON_ERROR_STOP=1 \
    -v app_db_name="$APP_DB_NAME" \
    -v app_db_user="$APP_DB_USER" \
    --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<'EOSQL'
SELECT format('CREATE DATABASE %I', :'app_db_name')
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = :'app_db_name')\gexec
GRANT ALL PRIVILEGES ON DATABASE :"app_db_name" TO :"app_db_user";
EOSQL

echo "Granting schema privileges"
# Grant schema privileges (must run while connected to the application database)
psql -v ON_ERROR_STOP=1 \
    -v app_db_user="$APP_DB_USER" \
    --username "$POSTGRES_USER" --dbname "$APP_DB_NAME" <<'EOSQL'
GRANT ALL ON SCHEMA public TO :"app_db_user";
EOSQL

echo "Database initialization complete!"
|
||||||
|
|
||||||
25
scripts/init_database/01_run_sql_files.sh
Executable file
25
scripts/init_database/01_run_sql_files.sh
Executable file
@@ -0,0 +1,25 @@
|
|||||||
|
#!/bin/bash
# Run the SQL initialization files inside the application database.
# This script runs after 00_init_user.sh, which creates that database.

set -e

APP_DB_NAME="${APP_DB_NAME:-ai_conversations}"
SQL_DIR=/docker-entrypoint-initdb.d

echo "Running SQL initialization files in database: $APP_DB_NAME"

# Files are applied in this order; a file that is not present is skipped.
for sql_file in create_conv_store.sql create_prompt_config.sql; do
    if [ -f "$SQL_DIR/$sql_file" ]; then
        echo "Executing $sql_file..."
        psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$APP_DB_NAME" -f "$SQL_DIR/$sql_file"
    fi
done

echo "SQL initialization files completed!"
|
||||||
|
|
||||||
|
|
||||||
@@ -1,7 +1,9 @@
|
|||||||
-- Create the messages table
|
-- Create the messages table
|
||||||
|
-- This script runs in the ai_conversations database context
|
||||||
CREATE TABLE IF NOT EXISTS messages (
|
CREATE TABLE IF NOT EXISTS messages (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
conversation_id TEXT NOT NULL,
|
conversation_id TEXT NOT NULL,
|
||||||
|
pipeline_id TEXT,
|
||||||
message_type VARCHAR(10) NOT NULL CHECK (message_type IN ('human', 'ai', 'tool')),
|
message_type VARCHAR(10) NOT NULL CHECK (message_type IN ('human', 'ai', 'tool')),
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL,
|
||||||
sequence_number INTEGER NOT NULL CHECK (sequence_number >= 0),
|
sequence_number INTEGER NOT NULL CHECK (sequence_number >= 0),
|
||||||
@@ -10,3 +12,10 @@ CREATE TABLE IF NOT EXISTS messages (
|
|||||||
|
|
||||||
-- Index for fast retrieval of all messages in a conversation (in order)
|
-- Index for fast retrieval of all messages in a conversation (in order)
|
||||||
CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages (conversation_id, sequence_number);
|
CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages (conversation_id, sequence_number);
|
||||||
|
|
||||||
|
-- Index for fast lookup by pipeline_id
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_messages_pipeline ON messages (pipeline_id);
|
||||||
|
|
||||||
|
-- Grant permissions to app user
|
||||||
|
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO myapp_user;
|
||||||
|
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO myapp_user;
|
||||||
@@ -1,4 +1,5 @@
|
|||||||
-- A prompt_set groups a full collection of prompts together.
|
-- A prompt_set groups a full collection of prompts together.
|
||||||
|
-- This script runs in the ai_conversations database context
|
||||||
-- Each pipeline can have many sets (versions, A/B variants, etc.);
|
-- Each pipeline can have many sets (versions, A/B variants, etc.);
|
||||||
-- exactly one should be marked is_active per pipeline.
|
-- exactly one should be marked is_active per pipeline.
|
||||||
CREATE TABLE IF NOT EXISTS prompt_sets (
|
CREATE TABLE IF NOT EXISTS prompt_sets (
|
||||||
@@ -10,12 +11,15 @@ CREATE TABLE IF NOT EXISTS prompt_sets (
|
|||||||
is_active BOOLEAN DEFAULT false,
|
is_active BOOLEAN DEFAULT false,
|
||||||
created_at TIMESTAMPTZ DEFAULT now(),
|
created_at TIMESTAMPTZ DEFAULT now(),
|
||||||
updated_at TIMESTAMPTZ DEFAULT now(),
|
updated_at TIMESTAMPTZ DEFAULT now(),
|
||||||
list Varchar(255) DEFAULT '' -- tool_set list for client_tool_manager
|
list Varchar(255) DEFAULT '', -- tool_set list for client_tool_manager
|
||||||
|
api_key TEXT DEFAULT '' -- provider api key used to run pipeline
|
||||||
);
|
);
|
||||||
|
|
||||||
-- Backward-compatible migration for existing deployments.
|
-- Backward-compatible migration for existing deployments.
|
||||||
ALTER TABLE prompt_sets
|
ALTER TABLE prompt_sets
|
||||||
ADD COLUMN IF NOT EXISTS graph_id VARCHAR(64);
|
ADD COLUMN IF NOT EXISTS graph_id VARCHAR(64);
|
||||||
|
ALTER TABLE prompt_sets
|
||||||
|
ADD COLUMN IF NOT EXISTS api_key TEXT DEFAULT '';
|
||||||
UPDATE prompt_sets
|
UPDATE prompt_sets
|
||||||
SET graph_id = pipeline_id
|
SET graph_id = pipeline_id
|
||||||
WHERE graph_id IS NULL;
|
WHERE graph_id IS NULL;
|
||||||
@@ -38,6 +42,10 @@ CREATE TABLE IF NOT EXISTS prompt_templates (
|
|||||||
CREATE INDEX IF NOT EXISTS idx_prompt_templates_set_id
|
CREATE INDEX IF NOT EXISTS idx_prompt_templates_set_id
|
||||||
ON prompt_templates(prompt_set_id);
|
ON prompt_templates(prompt_set_id);
|
||||||
|
|
||||||
|
-- Grant permissions to app user
|
||||||
|
GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO myapp_user;
|
||||||
|
GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO myapp_user;
|
||||||
|
|
||||||
-- Seed: initial prompt set for lang_agent/graphs/routing.py
|
-- Seed: initial prompt set for lang_agent/graphs/routing.py
|
||||||
-- The pipeline_id can be used by RoutingConfig.pipeline_id to load these prompts.
|
-- The pipeline_id can be used by RoutingConfig.pipeline_id to load these prompts.
|
||||||
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
|
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
|
||||||
@@ -72,6 +80,39 @@ DO UPDATE SET
|
|||||||
content = EXCLUDED.content,
|
content = EXCLUDED.content,
|
||||||
updated_at = now();
|
updated_at = now();
|
||||||
|
|
||||||
|
-- Seed: default prompt set for xiaozhan agent (RoutingGraph)
-- Step 1: create the 'default' set unless one already exists for the pipeline.
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
SELECT 'xiaozhan', 'routing', 'default', 'Default prompt set for xiaozhan (RoutingGraph)', true, ''
WHERE NOT EXISTS (
    SELECT 1
    FROM prompt_sets AS existing
    WHERE existing.pipeline_id = 'xiaozhan'
      AND existing.name = 'default'
);

-- Step 2: attach the three routing prompts to that set; an existing key has
-- its content overwritten and its updated_at refreshed.
INSERT INTO prompt_templates (prompt_set_id, prompt_key, content)
SELECT ps.id, seed.prompt_key, seed.content
FROM prompt_sets AS ps
CROSS JOIN (
    VALUES
        ('route_prompt', '决定用工具或者聊天'),
        ('chat_prompt', '正常聊天时说什么'),
        ('tool_prompt', '如何用工具')
) AS seed(prompt_key, content)
WHERE ps.pipeline_id = 'xiaozhan'
  AND ps.name = 'default'
ON CONFLICT (prompt_set_id, prompt_key)
DO UPDATE SET
    content = EXCLUDED.content,
    updated_at = now();
|
||||||
|
|
||||||
-- Seed: initial prompt set for lang_agent/graphs/react.py
|
-- Seed: initial prompt set for lang_agent/graphs/react.py
|
||||||
-- ReactGraph uses prompt key "sys_prompt" (see default_key in build_prompt_store).
|
-- ReactGraph uses prompt key "sys_prompt" (see default_key in build_prompt_store).
|
||||||
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
|
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
|
||||||
@@ -98,3 +139,56 @@ ON CONFLICT (prompt_set_id, prompt_key)
|
|||||||
DO UPDATE SET
|
DO UPDATE SET
|
||||||
content = EXCLUDED.content,
|
content = EXCLUDED.content,
|
||||||
updated_at = now();
|
updated_at = now();
|
||||||
|
|
||||||
|
-- Seed: default prompt set for blueberry agent (ReactGraph)
-- Step 1: create the 'default' set unless one already exists for the pipeline.
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
SELECT 'blueberry', 'react', 'default', 'Default prompt set for blueberry (ReactGraph)', true, ''
WHERE NOT EXISTS (
    SELECT 1
    FROM prompt_sets AS existing
    WHERE existing.pipeline_id = 'blueberry'
      AND existing.name = 'default'
);

-- Step 2: attach the single "sys_prompt" template; an existing key has its
-- content overwritten and its updated_at refreshed.
INSERT INTO prompt_templates (prompt_set_id, prompt_key, content)
SELECT target_set.id, 'sys_prompt', '如何用工具'
FROM prompt_sets AS target_set
WHERE target_set.pipeline_id = 'blueberry'
  AND target_set.name = 'default'
ON CONFLICT (prompt_set_id, prompt_key)
DO UPDATE SET
    content = EXCLUDED.content,
    updated_at = now();
|
||||||
|
|
||||||
|
-- Seed: initial prompt set for lang_agent/graphs/deepagents_qt.py
-- DeepAgent uses prompt key "sys_prompt" with DB-first, file-fallback loading.
-- Step 1: create the 'default' set unless one already exists for the pipeline.
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
SELECT 'deepagent', 'deepagent', 'default', 'Initial prompt set for DeepAgent', true, ''
WHERE NOT EXISTS (
    SELECT 1
    FROM prompt_sets AS existing
    WHERE existing.pipeline_id = 'deepagent'
      AND existing.name = 'default'
);

-- Step 2: attach the "sys_prompt" template; an existing key has its content
-- overwritten and its updated_at refreshed.
INSERT INTO prompt_templates (prompt_set_id, prompt_key, content)
SELECT target_set.id, 'sys_prompt', '你是一个擅长调用工具和处理文件任务的深度代理。'
FROM prompt_sets AS target_set
WHERE target_set.pipeline_id = 'deepagent'
  AND target_set.name = 'default'
ON CONFLICT (prompt_set_id, prompt_key)
DO UPDATE SET
    content = EXCLUDED.content,
    updated_at = now();
|
||||||
49
scripts/init_database/init_all.sh
Executable file
49
scripts/init_database/init_all.sh
Executable file
@@ -0,0 +1,49 @@
|
|||||||
|
#!/bin/bash
# Database initialization script
# Runs all SQL initialization files in the correct order against the
# PostgreSQL server described by the POSTGRES_* environment variables.

# -e: abort on the first failing command; -u: unset variables are errors;
# pipefail: a failing stage fails the whole pipeline.
set -euo pipefail

DB_NAME="${POSTGRES_DB:-ai_conversations}"
DB_USER="${POSTGRES_USER:-myapp_user}"
# NOTE(review): this fallback password is committed to the repository; real
# deployments should always supply POSTGRES_PASSWORD.
DB_PASSWORD="${POSTGRES_PASSWORD:-secure_password_123}"
DB_HOST="${POSTGRES_HOST:-localhost}"
DB_PORT="${POSTGRES_PORT:-5432}"

export PGPASSWORD="$DB_PASSWORD"

# psql exits with status 0 even when a statement fails unless ON_ERROR_STOP
# is set, so `set -e` alone would never notice a broken SQL file.
run_psql() {
    psql -v ON_ERROR_STOP=1 -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" "$@"
}

echo "Initializing database: $DB_NAME on $DB_HOST:$DB_PORT"

# Wait for PostgreSQL to be ready
until psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d postgres -c '\q' 2>/dev/null; do
    echo "Waiting for PostgreSQL to be ready..."
    sleep 2
done

echo "PostgreSQL is ready!"

# Create database if it doesn't exist.
# The name is passed as a psql variable and quoted with format(%I) / :'var'
# instead of being spliced into the SQL text by the shell; the heredoc is
# quoted so the shell leaves the psql syntax untouched.
run_psql -d postgres -v db_name="$DB_NAME" <<'EOF'
SELECT format('CREATE DATABASE %I', :'db_name')
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = :'db_name')\gexec
EOF

# Grant privileges (identifiers quoted via :"var")
run_psql -d postgres -v db_name="$DB_NAME" -v db_user="$DB_USER" <<'EOF'
GRANT ALL PRIVILEGES ON DATABASE :"db_name" TO :"db_user";
EOF

# Run initialization scripts in order
echo "Running database initialization scripts..."

# 1. Create conversation store tables
echo "Creating conversation store tables..."
run_psql -d "$DB_NAME" -f /docker-entrypoint-initdb.d/create_conv_store.sql

# 2. Create prompt configuration tables
echo "Creating prompt configuration tables..."
run_psql -d "$DB_NAME" -f /docker-entrypoint-initdb.d/create_prompt_config.sql

echo "Database initialization complete!"
|
||||||
|
|
||||||
|
|
||||||
@@ -1,5 +1,6 @@
|
|||||||
-- Drop the index first (if it exists)
|
-- Drop the index first (if it exists)
|
||||||
DROP INDEX IF EXISTS idx_messages_conversation;
|
DROP INDEX IF EXISTS idx_messages_conversation;
|
||||||
|
DROP INDEX IF EXISTS idx_messages_pipeline;
|
||||||
|
|
||||||
-- Drop the messages table (if it exists)
|
-- Drop the messages table (if it exists)
|
||||||
DROP TABLE IF EXISTS messages;
|
DROP TABLE IF EXISTS messages;
|
||||||
@@ -9,6 +10,7 @@ DROP TABLE IF EXISTS messages;
|
|||||||
CREATE TABLE messages (
|
CREATE TABLE messages (
|
||||||
id BIGSERIAL PRIMARY KEY,
|
id BIGSERIAL PRIMARY KEY,
|
||||||
conversation_id TEXT NOT NULL,
|
conversation_id TEXT NOT NULL,
|
||||||
|
pipeline_id TEXT,
|
||||||
message_type VARCHAR(10) NOT NULL CHECK (message_type IN ('human', 'ai', 'tool')),
|
message_type VARCHAR(10) NOT NULL CHECK (message_type IN ('human', 'ai', 'tool')),
|
||||||
content TEXT NOT NULL,
|
content TEXT NOT NULL,
|
||||||
sequence_number INTEGER NOT NULL CHECK (sequence_number >= 0),
|
sequence_number INTEGER NOT NULL CHECK (sequence_number >= 0),
|
||||||
@@ -18,3 +20,6 @@ CREATE TABLE messages (
|
|||||||
-- Recreate the index for fast retrieval of all messages in a conversation (in order)
|
-- Recreate the index for fast retrieval of all messages in a conversation (in order)
|
||||||
CREATE INDEX idx_messages_conversation ON messages (conversation_id, sequence_number);
|
CREATE INDEX idx_messages_conversation ON messages (conversation_id, sequence_number);
|
||||||
|
|
||||||
|
-- Index for fast lookup by pipeline_id
|
||||||
|
CREATE INDEX idx_messages_pipeline ON messages (pipeline_id);
|
||||||
|
|
||||||
@@ -1,157 +0,0 @@
|
|||||||
from langsmith import Client
|
|
||||||
from loguru import logger
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
import os.path as osp
|
|
||||||
import os
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
DATASET_NAME = "QA_xiaozhan_sub"
|
|
||||||
from loguru import logger
|
|
||||||
|
|
||||||
ASSETS_DIR = osp.join(osp.dirname(osp.dirname(__file__)), "assets")
|
|
||||||
if not osp.exists(ASSETS_DIR):
|
|
||||||
os.makedirs(ASSETS_DIR)
|
|
||||||
|
|
||||||
examples = [
|
|
||||||
{
|
|
||||||
"inputs": {"text": "请你介绍一下少年右这杯茶"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "这是一杯使用武夷肉桂为原料的茶,带有浓郁的肉桂香气和微微的辛辣感,茶汤醇厚,回味悠长,非常适合喜欢浓烈香气的茶友。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "给我讲讲野心心这杯茶"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "野星星选用云南西双版纳野生大树春茶,历经二十多年陈化,茶汤醇厚饱满,回甘迅猛,带着明显的岁月沉香与山野气息。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "介绍一下小甜新"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "小甜心来自芒景村古树生普,兰香与蜜韵交织,入口柔和,回甘悠长,是一款耐喝又有层次的老料生普。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "小盏,什么是大白百?"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "大白白是一款2012年的老白茶,经过多年陈化,蜜香温润,茶汤醇厚顺滑,回甘绵长,整体风格安静而沉稳。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "帮我介绍下引你进山林"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "引你入山林以新会陈皮搭配云南白茶,茶汤清甜柔和,带有淡淡的花果香与陈皮的温润气息,喝起来非常舒服。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "伏身于大自然是什么味道"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "伏身于自然将云南滇红与玫瑰慢煮融合,花香馥郁,入口醇厚甘甜,蜜香在口中停留很久,温暖又放松。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "介绍一下小野仔"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "小野子选用云南古树晒红制作,蜜香高扬,口感甜润顺滑,回甘明显,是一款非常友好的红茶。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "放轻松这杯喝起来怎么样"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "放轻松是小青柑搭配熟普,茶汤醇厚顺滑,柑香清新提亮整体口感,非常适合饭后或想放松的时候。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "啤啤查是酒吗"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "啤啤茶是一款无酒精气泡茶,以普洱和玫瑰为茶底,气泡清爽,入口有类似啤酒的畅快感,但完全不含酒精。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "鲜叶康普查有什么特点"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "鲜叶康普茶经过自然发酵,带有轻盈气泡和清爽酸甜感,同时富含益生菌,整体低糖低卡,口感非常清新。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "介绍一下寻静密"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "寻静谧融合茉莉绿茶与抹茶,茶感温润微涩,搭配栀子花香奶盖与海苔碎,层次细腻,整体风格安静沉稳。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "小陶燃是什么茶"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "小陶然是一款熟普黑茶,选用布朗山原料发酵,陈香明显,滋味甜醇饱满,口感厚实顺滑。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "花仙仔适合什么人喝"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "花仙子是东方美人乌龙茶,带有天然熟果蜜香,茶感柔和细腻,很适合喜欢花果香型乌龙的茶友。",
|
|
||||||
"tool_use": ["search_dishes"]
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "介绍下小美慢"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "小美满选用福鼎老寿眉白茶,带有枣香和淡淡药香,口感甘润持久,是一款很有岁月感的白茶。",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "你叫什么名字"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "我叫小盏,是半盏新青年茶馆的智能助手",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "我今天很开心"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "太棒啦!看到你开心",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "你好可爱呀!"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "谢谢你呀~",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "你有没有时间"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "我一直在呢,随时陪你聊聊天、喝杯茶",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"inputs": {"text": "介绍一下你自己"},
|
|
||||||
"outputs": {
|
|
||||||
"answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶",
|
|
||||||
}
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
cli = Client()
|
|
||||||
|
|
||||||
try:
|
|
||||||
dataset = cli.read_dataset(dataset_name=DATASET_NAME)
|
|
||||||
logger.info("read dataset")
|
|
||||||
except:
|
|
||||||
dataset = cli.create_dataset(dataset_name=DATASET_NAME)
|
|
||||||
logger.info("created dataset")
|
|
||||||
cli.create_examples(
|
|
||||||
dataset_id=dataset.id,
|
|
||||||
examples=examples
|
|
||||||
)
|
|
||||||
141
scripts/py_scripts/chat_dashcope.py
Normal file
141
scripts/py_scripts/chat_dashcope.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
Simple chat loop to interact with the blueberry pipeline via DashScope-compatible API.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python scripts/py_scripts/chat_dashcope.py
|
||||||
|
|
||||||
|
The script connects to the server running on http://localhost:8500
|
||||||
|
and uses the API key from the pipeline registry.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
import os
import sys
from typing import Optional

import requests
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration from pipeline_registry.json
# Every value can be overridden through an environment variable, so the
# script can target another deployment or key without editing the source;
# the defaults are the values that used to be hard-coded.
# NOTE(review): the fallback API key is a credential committed to the
# repository -- rotate it and supply the real key through CHAT_API_KEY only.
API_KEY = os.environ.get(
    "CHAT_API_KEY",
    "sk-6c7091e6a95f404efb2ec30e8f51b897626d670375cdf822d78262f24ab12367",
)
PIPELINE_ID = os.environ.get("CHAT_PIPELINE_ID", "blueberry")
BASE_URL = os.environ.get("CHAT_BASE_URL", "http://localhost:8500")
SESSION_ID = os.environ.get("CHAT_SESSION_ID", "chat-session-1")
|
||||||
|
|
||||||
|
|
||||||
|
def send_message(
    message: str,
    session_id: str = SESSION_ID,
    stream: bool = False,
    app_id: str = PIPELINE_ID,
) -> Optional[str]:
    """Send a message to the blueberry pipeline and return the response.

    Args:
        message: User text, sent as a single ``user`` message.
        session_id: Session segment of the request URL.
        stream: When true, read the reply as ``data: `` prefixed lines and
            echo each update to stdout while it arrives.
        app_id: Application (pipeline) segment of the request URL.

    Returns:
        The ``output.text`` value of the reply (possibly an empty string), or
        ``None`` when the request failed; failures are reported on stderr.
    """
    url = f"{BASE_URL}/v1/apps/{app_id}/sessions/{session_id}/responses"
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "messages": [
            {"role": "user", "content": message}
        ],
        "stream": stream,
    }

    try:
        if not stream:
            # Handle non-streaming response
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            output = data.get("output", {})
            return output.get("text", "")

        # Handle streaming response. The context manager releases the
        # connection even when we return before the stream is exhausted.
        with requests.post(url, headers=headers, json=payload, stream=True) as response:
            response.raise_for_status()

            accumulated_text = ""
            for line in response.iter_lines():
                if not line:
                    continue
                line_str = line.decode("utf-8")
                if not line_str.startswith("data: "):
                    continue
                try:
                    data = json.loads(line_str[6:])  # Remove 'data: ' prefix
                except json.JSONDecodeError:
                    continue
                if not isinstance(data, dict):
                    # Valid JSON but not an event object; nothing to read.
                    continue

                text = (data.get("output") or {}).get("text", "")
                if text:
                    # Each event replaces (not extends) the text seen so far;
                    # presumably the server sends the cumulative reply --
                    # TODO confirm against the API.
                    accumulated_text = text
                    # Print incremental updates (you can modify this behavior)
                    print(f"\rAssistant: {accumulated_text}", end="", flush=True)

                if data.get("is_end", False):
                    print()  # New line after streaming completes
                    return accumulated_text
            return accumulated_text

    except requests.exceptions.RequestException as e:
        print(f"Error sending message: {e}", file=sys.stderr)
        error_response = getattr(e, "response", None)
        if error_response is not None:
            try:
                error_detail = error_response.json()
                print(f"Error details: {error_detail}", file=sys.stderr)
            except ValueError:
                # Body is not JSON; fall back to the status code. ValueError
                # covers the JSON decode errors without also swallowing
                # KeyboardInterrupt/SystemExit like the former bare except.
                print(f"Response status: {error_response.status_code}", file=sys.stderr)
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Main chat loop.

    Reads user input until 'quit', 'exit' or 'q' is entered, Ctrl-C is
    pressed, or stdin reaches end-of-file. Entering 'stream' toggles
    streaming mode, which starts switched on.
    """
    print("=" * 60)
    print("Chat with Blueberry Pipeline")
    print(f"Pipeline ID: {PIPELINE_ID}")
    print(f"Server: {BASE_URL}")
    print(f"Session ID: {SESSION_ID}")
    print("=" * 60)
    print("Type your messages (or 'quit'/'exit' to end, 'stream' to toggle streaming)")
    print("Streaming mode is ON by default")
    print()

    stream_mode = True

    while True:
        try:
            user_input = input("You: ").strip()

            if not user_input:
                continue

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            if user_input.lower() == 'stream':
                stream_mode = not stream_mode
                print(f"Streaming mode: {'ON' if stream_mode else 'OFF'}")
                continue

            print("Assistant: ", end="", flush=True)
            response = send_message(user_input, stream=stream_mode)

            if response is None:
                print("(No response received)")
            elif not stream_mode:
                print(response)
            # For streaming, the response is already printed incrementally

            print()  # Empty line for readability

        except (KeyboardInterrupt, EOFError):
            # EOFError: stdin was closed (Ctrl-D or exhausted piped input).
            # Without this, input() would raise on every iteration and the
            # generic handler below would loop forever.
            print("\n\nGoodbye!")
            break
        except Exception as e:
            # Keep the session alive after an unexpected per-message failure.
            print(f"\nError: {e}", file=sys.stderr)


if __name__ == "__main__":
    main()
|
||||||
|
|
||||||
154
scripts/py_scripts/make_eval_dataset.py
Normal file
154
scripts/py_scripts/make_eval_dataset.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
from langsmith import Client
|
||||||
|
from loguru import logger
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os.path as osp
|
||||||
|
import os
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
DATASET_NAME = "QA_xiaozhan_sub"
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
# Assets directory: "assets" inside the parent of this script's directory.
# abspath() makes the result independent of the working directory when
# __file__ is a relative path.
# NOTE(review): for scripts/py_scripts/make_eval_dataset.py this resolves to
# scripts/assets -- confirm that is intended rather than the project root.
ASSETS_DIR = osp.join(osp.dirname(osp.dirname(osp.abspath(__file__))), "assets")
# exist_ok avoids the race between a separate exists() check and makedirs().
os.makedirs(ASSETS_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# (question, reference answer) pairs whose expected trace calls search_dishes.
_TOOL_QA = [
    (
        "请你介绍一下少年右这杯茶",
        "这是一杯使用武夷肉桂为原料的茶,带有浓郁的肉桂香气和微微的辛辣感,茶汤醇厚,回味悠长,非常适合喜欢浓烈香气的茶友。",
    ),
    (
        "给我讲讲野心心这杯茶",
        "野星星选用云南西双版纳野生大树春茶,历经二十多年陈化,茶汤醇厚饱满,回甘迅猛,带着明显的岁月沉香与山野气息。",
    ),
    (
        "介绍一下小甜新",
        "小甜心来自芒景村古树生普,兰香与蜜韵交织,入口柔和,回甘悠长,是一款耐喝又有层次的老料生普。",
    ),
    (
        "小盏,什么是大白百?",
        "大白白是一款2012年的老白茶,经过多年陈化,蜜香温润,茶汤醇厚顺滑,回甘绵长,整体风格安静而沉稳。",
    ),
    (
        "帮我介绍下引你进山林",
        "引你入山林以新会陈皮搭配云南白茶,茶汤清甜柔和,带有淡淡的花果香与陈皮的温润气息,喝起来非常舒服。",
    ),
    (
        "伏身于大自然是什么味道",
        "伏身于自然将云南滇红与玫瑰慢煮融合,花香馥郁,入口醇厚甘甜,蜜香在口中停留很久,温暖又放松。",
    ),
    (
        "介绍一下小野仔",
        "小野子选用云南古树晒红制作,蜜香高扬,口感甜润顺滑,回甘明显,是一款非常友好的红茶。",
    ),
    (
        "放轻松这杯喝起来怎么样",
        "放轻松是小青柑搭配熟普,茶汤醇厚顺滑,柑香清新提亮整体口感,非常适合饭后或想放松的时候。",
    ),
    (
        "啤啤查是酒吗",
        "啤啤茶是一款无酒精气泡茶,以普洱和玫瑰为茶底,气泡清爽,入口有类似啤酒的畅快感,但完全不含酒精。",
    ),
    (
        "鲜叶康普查有什么特点",
        "鲜叶康普茶经过自然发酵,带有轻盈气泡和清爽酸甜感,同时富含益生菌,整体低糖低卡,口感非常清新。",
    ),
    (
        "介绍一下寻静密",
        "寻静谧融合茉莉绿茶与抹茶,茶感温润微涩,搭配栀子花香奶盖与海苔碎,层次细腻,整体风格安静沉稳。",
    ),
    (
        "小陶燃是什么茶",
        "小陶然是一款熟普黑茶,选用布朗山原料发酵,陈香明显,滋味甜醇饱满,口感厚实顺滑。",
    ),
    (
        "花仙仔适合什么人喝",
        "花仙子是东方美人乌龙茶,带有天然熟果蜜香,茶感柔和细腻,很适合喜欢花果香型乌龙的茶友。",
    ),
]

# (question, reference answer) pairs that carry no "tool_use" expectation.
_PLAIN_QA = [
    (
        "介绍下小美慢",
        "小美满选用福鼎老寿眉白茶,带有枣香和淡淡药香,口感甘润持久,是一款很有岁月感的白茶。",
    ),
    ("你叫什么名字", "我叫小盏,是半盏新青年茶馆的智能助手"),
    ("我今天很开心", "太棒啦!看到你开心"),
    ("你好可爱呀!", "谢谢你呀~"),
    ("你有没有时间", "我一直在呢,随时陪你聊聊天、喝杯茶"),
    ("介绍一下你自己", "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶"),
]

# Dataset rows in upload order: tool-backed questions first, then the rest.
# Every tool row gets its own fresh ["search_dishes"] list.
examples = [
    {
        "inputs": {"text": question},
        "outputs": {"answer": answer, "tool_use": ["search_dishes"]},
    }
    for question, answer in _TOOL_QA
] + [
    {"inputs": {"text": question}, "outputs": {"answer": answer}}
    for question, answer in _PLAIN_QA
]
|
||||||
|
|
||||||
|
cli = Client()

# Reuse the dataset when it can be read; otherwise create it.
try:
    dataset = cli.read_dataset(dataset_name=DATASET_NAME)
    logger.info("read dataset")
except Exception as exc:
    # `Exception` instead of a bare `except:` so Ctrl-C and SystemExit are no
    # longer swallowed. The reason is logged because a failed read is not
    # necessarily "dataset missing" (it could be an auth or network error).
    logger.warning(f"could not read dataset {DATASET_NAME!r}: {exc}")
    dataset = cli.create_dataset(dataset_name=DATASET_NAME)
    logger.info("created dataset")

# NOTE(review): examples are uploaded on every run, including when the dataset
# already existed -- confirm whether re-runs are meant to append again.
cli.create_examples(dataset_id=dataset.id, examples=examples)
|
||||||
364
scripts/py_scripts/migrate_yaml_prompts_to_db.py
Normal file
364
scripts/py_scripts/migrate_yaml_prompts_to_db.py
Normal file
@@ -0,0 +1,364 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import datetime as dt
|
||||||
|
import glob
|
||||||
|
import os
|
||||||
|
import os.path as osp
|
||||||
|
import sys
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Dict, Iterable, List, Optional
|
||||||
|
|
||||||
|
import commentjson
|
||||||
|
import psycopg
|
||||||
|
|
||||||
|
|
||||||
|
# Repository root: three directory levels above this script's absolute path.
_THIS_FILE = osp.abspath(__file__)
PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(_THIS_FILE)))
# Make the root importable so the lang_agent imports that follow resolve
# when the script is run directly.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
|
||||||
|
|
||||||
|
from lang_agent.config import load_tyro_conf # noqa: E402
|
||||||
|
from lang_agent.config.db_config_manager import DBConfigManager # noqa: E402
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class MigrationPayload:
    """Data gathered from one pipeline config file for import into the DB."""

    # Path of the config file this payload was built from.
    config_path: str
    # Pipeline identifier used to look up / create the prompt set.
    pipeline_id: str
    # Graph identifier stored with the prompt set.
    graph_id: str
    # Prompt key -> prompt text.
    prompt_dict: Dict[str, str]
    # Tool keys passed on as the tool list.
    tool_keys: List[str]
    # API key read from the config, or None when the config has none.
    api_key: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_pipeline_id(pipeline_conf, config_path: str) -> str:
|
||||||
|
candidates = [
|
||||||
|
getattr(pipeline_conf, "pipeline_id", None),
|
||||||
|
getattr(getattr(pipeline_conf, "graph_config", None), "pipeline_id", None),
|
||||||
|
]
|
||||||
|
for candidate in candidates:
|
||||||
|
if candidate is None:
|
||||||
|
continue
|
||||||
|
value = str(candidate).strip()
|
||||||
|
if value and value.lower() != "null":
|
||||||
|
return value
|
||||||
|
return osp.splitext(osp.basename(config_path))[0]
|
||||||
|
|
||||||
|
|
||||||
|
def _infer_graph_id(graph_conf) -> str:
|
||||||
|
if graph_conf is None:
|
||||||
|
return "unknown"
|
||||||
|
class_name = graph_conf.__class__.__name__.lower()
|
||||||
|
if "routing" in class_name or class_name == "routeconfig":
|
||||||
|
return "routing"
|
||||||
|
if "react" in class_name:
|
||||||
|
return "react"
|
||||||
|
|
||||||
|
target = getattr(graph_conf, "_target", None)
|
||||||
|
if target is not None:
|
||||||
|
target_name = getattr(target, "__name__", str(target)).lower()
|
||||||
|
if "routing" in target_name:
|
||||||
|
return "routing"
|
||||||
|
if "react" in target_name:
|
||||||
|
return "react"
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tool_keys(graph_conf) -> List[str]:
|
||||||
|
if graph_conf is None:
|
||||||
|
return []
|
||||||
|
tool_cfg = getattr(graph_conf, "tool_manager_config", None)
|
||||||
|
client_cfg = getattr(tool_cfg, "client_tool_manager", None)
|
||||||
|
keys = getattr(client_cfg, "tool_keys", None)
|
||||||
|
if not keys:
|
||||||
|
return []
|
||||||
|
out: List[str] = []
|
||||||
|
seen = set()
|
||||||
|
for key in keys:
|
||||||
|
cleaned = str(key).strip()
|
||||||
|
if not cleaned or cleaned in seen:
|
||||||
|
continue
|
||||||
|
seen.add(cleaned)
|
||||||
|
out.append(cleaned)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _load_prompt_dict(prompt_path: str, default_key: str = "sys_prompt") -> Dict[str, str]:
|
||||||
|
if not prompt_path:
|
||||||
|
return {}
|
||||||
|
if not osp.exists(prompt_path):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
if osp.isdir(prompt_path):
|
||||||
|
prompt_files = sorted(
|
||||||
|
p for p in glob.glob(osp.join(prompt_path, "*.txt")) if "optional" not in p
|
||||||
|
)
|
||||||
|
out = {}
|
||||||
|
for prompt_f in prompt_files:
|
||||||
|
key = osp.splitext(osp.basename(prompt_f))[0]
|
||||||
|
with open(prompt_f, "r", encoding="utf-8") as f:
|
||||||
|
out[key] = f.read()
|
||||||
|
return out
|
||||||
|
|
||||||
|
if prompt_path.endswith(".json"):
|
||||||
|
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||||
|
obj = commentjson.load(f)
|
||||||
|
if not isinstance(obj, dict):
|
||||||
|
return {}
|
||||||
|
return {str(k): v if isinstance(v, str) else str(v) for k, v in obj.items()}
|
||||||
|
|
||||||
|
if prompt_path.endswith(".txt"):
|
||||||
|
with open(prompt_path, "r", encoding="utf-8") as f:
|
||||||
|
return {default_key: f.read()}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_prompt_dict(graph_conf) -> Dict[str, str]:
|
||||||
|
if graph_conf is None:
|
||||||
|
return {}
|
||||||
|
if hasattr(graph_conf, "sys_prompt_f"):
|
||||||
|
return _load_prompt_dict(str(getattr(graph_conf, "sys_prompt_f")), "sys_prompt")
|
||||||
|
if hasattr(graph_conf, "sys_promp_dir"):
|
||||||
|
return _load_prompt_dict(str(getattr(graph_conf, "sys_promp_dir")))
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_tool_node_prompt_dict(graph_conf) -> Dict[str, str]:
|
||||||
|
tool_node_conf = getattr(graph_conf, "tool_node_config", None)
|
||||||
|
if tool_node_conf is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
out: Dict[str, str] = {}
|
||||||
|
if hasattr(tool_node_conf, "tool_prompt_f"):
|
||||||
|
out.update(
|
||||||
|
_load_prompt_dict(str(getattr(tool_node_conf, "tool_prompt_f")), "tool_prompt")
|
||||||
|
)
|
||||||
|
if hasattr(tool_node_conf, "chatty_sys_prompt_f"):
|
||||||
|
out.update(
|
||||||
|
_load_prompt_dict(
|
||||||
|
str(getattr(tool_node_conf, "chatty_sys_prompt_f")), "chatty_prompt"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def _prompt_key_whitelist(graph_conf, graph_id: str) -> Optional[set]:
|
||||||
|
if graph_id == "react":
|
||||||
|
return {"sys_prompt"}
|
||||||
|
if graph_id != "routing":
|
||||||
|
return None
|
||||||
|
|
||||||
|
allowed = {"route_prompt", "chat_prompt", "tool_prompt"}
|
||||||
|
tool_node_conf = getattr(graph_conf, "tool_node_config", None)
|
||||||
|
if tool_node_conf is None:
|
||||||
|
return allowed
|
||||||
|
|
||||||
|
cls_name = tool_node_conf.__class__.__name__.lower()
|
||||||
|
target = getattr(tool_node_conf, "_target", None)
|
||||||
|
target_name = getattr(target, "__name__", str(target)).lower() if target else ""
|
||||||
|
if "chatty" in cls_name or "chatty" in target_name:
|
||||||
|
allowed.add("chatty_prompt")
|
||||||
|
return allowed
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_payload(config_path: str) -> MigrationPayload:
    """Build the migration payload for one pipeline config file.

    Loads the config, infers the graph id, merges the graph prompts with the
    tool-node prompts (tool-node entries win on a key clash) and, when the
    graph has a key whitelist, drops every prompt outside it.
    """
    conf = load_tyro_conf(config_path)
    graph_conf = getattr(conf, "graph_config", None)
    graph_id = _infer_graph_id(graph_conf)

    prompts = _extract_prompt_dict(graph_conf)
    prompts.update(_extract_tool_node_prompt_dict(graph_conf))

    allowed_keys = _prompt_key_whitelist(graph_conf, graph_id)
    if allowed_keys is not None:
        prompts = {key: text for key, text in prompts.items() if key in allowed_keys}

    return MigrationPayload(
        config_path=config_path,
        pipeline_id=_infer_pipeline_id(conf, config_path),
        graph_id=graph_id,
        prompt_dict=prompts,
        tool_keys=_extract_tool_keys(graph_conf),
        api_key=getattr(conf, "api_key", None),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_config_paths(config_dir: str, config_paths: Optional[Iterable[str]]) -> List[str]:
|
||||||
|
if config_paths:
|
||||||
|
resolved = [osp.abspath(path) for path in config_paths]
|
||||||
|
else:
|
||||||
|
pattern = osp.join(osp.abspath(config_dir), "*.yaml")
|
||||||
|
resolved = sorted(glob.glob(pattern))
|
||||||
|
return [path for path in resolved if osp.exists(path)]
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_prompt_set(
    conn: psycopg.Connection,
    pipeline_id: str,
    graph_id: str,
    set_name: str,
    description: str,
) -> str:
    """Return the id of the prompt set ``set_name`` for a pipeline, creating it if needed.

    When several sets match, the most recently updated (then most recently
    created) one wins. A new set is inserted inactive with an empty tool
    list. Nothing is committed here; that is left to the caller.
    """
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT id FROM prompt_sets
            WHERE pipeline_id = %s AND name = %s
            ORDER BY updated_at DESC, created_at DESC
            LIMIT 1
            """,
            (pipeline_id, set_name),
        )
        found = cur.fetchone()
        if found is None:
            # No matching set yet: insert one and use the generated id.
            cur.execute(
                """
                INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
                VALUES (%s, %s, %s, %s, false, '')
                RETURNING id
                """,
                (pipeline_id, graph_id, set_name, description),
            )
            found = cur.fetchone()
        return str(found[0])
|
||||||
|
|
||||||
|
|
||||||
|
def _activate_prompt_set(conn: psycopg.Connection, pipeline_id: str, prompt_set_id: str) -> None:
    """Make ``prompt_set_id`` the active prompt set of a pipeline.

    First clears ``is_active`` on every set of the pipeline, then sets it on
    the chosen one; both statements also refresh ``updated_at``. Nothing is
    committed here; that is left to the caller.
    """
    statements = (
        (
            "UPDATE prompt_sets SET is_active = false, updated_at = now() WHERE pipeline_id = %s",
            (pipeline_id,),
        ),
        (
            "UPDATE prompt_sets SET is_active = true, updated_at = now() WHERE id = %s",
            (prompt_set_id,),
        ),
    )
    with conn.cursor() as cur:
        # Order matters: deactivate everything before activating the target.
        for sql, params in statements:
            cur.execute(sql, params)
|
||||||
|
|
||||||
|
|
||||||
|
def _run_migration(
    payloads: List[MigrationPayload],
    set_name: str,
    description: str,
    dry_run: bool,
    activate: bool,
) -> None:
    """Write each payload's prompts and tool keys into the database.

    A ``[PLAN]`` line is printed for every payload. With ``dry_run`` nothing
    else happens. Otherwise the prompt set is looked up or created, the
    payload is stored through ``DBConfigManager.set_config``, the set is
    optionally activated, and a ``[DONE]`` line is printed.

    Args:
        payloads: Migration payloads to process, in order.
        set_name: Prompt set name to create or reuse under each pipeline.
        description: Description used when a prompt set is created.
        dry_run: Only print the plan; do not touch the database.
        activate: Mark the imported set active for its pipeline.
    """
    for item in payloads:
        plan_line = (
            f"[PLAN] pipeline={item.pipeline_id} graph={item.graph_id} "
            f"prompts={len(item.prompt_dict)} tools={len(item.tool_keys)} "
            f"config={item.config_path}"
        )
        print(plan_line)
        if dry_run:
            continue

        db_manager = DBConfigManager()
        with psycopg.connect(db_manager.conn_str) as connection:
            set_id = _ensure_prompt_set(
                conn=connection,
                pipeline_id=item.pipeline_id,
                graph_id=item.graph_id,
                set_name=set_name,
                description=description,
            )
            # Commit before set_config so the prompt set row is persisted
            # before the manager references it.
            connection.commit()

            db_manager.set_config(
                pipeline_id=item.pipeline_id,
                graph_id=item.graph_id,
                prompt_set_id=set_id,
                tool_list=item.tool_keys,
                prompt_dict=item.prompt_dict,
                api_key=item.api_key,
            )

            if activate:
                _activate_prompt_set(
                    conn=connection,
                    pipeline_id=item.pipeline_id,
                    prompt_set_id=set_id,
                )
                connection.commit()

        done_line = (
            f"[DONE] pipeline={item.pipeline_id} "
            f"prompt_set={set_id} activate={activate}"
        )
        print(done_line)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
    """Parse CLI flags, collect prompt payloads from YAML configs and migrate them.

    Config files come from ``--config`` (repeatable) or, when none is given,
    from every ``*.yaml`` under ``--config-dir``. Payloads can be filtered by
    ``--pipeline-id``; configs without prompts are skipped with a ``[SKIP]``
    line.

    Raises:
        SystemExit: If no config file is found, or if no selected pipeline
            has prompt content to migrate.
    """
    parser = argparse.ArgumentParser(
        description="Import prompt definitions from pipeline YAML files into DB prompt_sets."
    )
    parser.add_argument(
        "--config-dir",
        default=osp.join(PROJECT_ROOT, "configs", "pipelines"),
        help="Directory containing pipeline YAML files.",
    )
    parser.add_argument(
        "--config",
        action="append",
        default=[],
        help="Specific pipeline config yaml path. Can be passed multiple times.",
    )
    parser.add_argument(
        "--pipeline-id",
        action="append",
        default=[],
        help="If provided, only migrate these pipeline IDs (repeatable).",
    )
    parser.add_argument(
        "--set-name",
        default="default",
        help="Prompt set name to create/reuse under each pipeline.",
    )
    parser.add_argument(
        "--description",
        default="Migrated from pipeline YAML prompt files",
        help="Prompt set description.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would be migrated without writing to DB.",
    )
    parser.add_argument(
        "--activate",
        action="store_true",
        help="Mark imported set active for each migrated pipeline.",
    )
    args = parser.parse_args()

    config_paths = _resolve_config_paths(args.config_dir, args.config)
    if not config_paths:
        raise SystemExit("No config files found. Provide --config or --config-dir.")

    # Blank --pipeline-id values are dropped; an empty set means "no filter".
    requested_pipelines = {p.strip() for p in args.pipeline_id if p.strip()}

    payloads: List[MigrationPayload] = []
    for config_path in config_paths:
        payload = _collect_payload(config_path)
        if requested_pipelines and payload.pipeline_id not in requested_pipelines:
            continue
        if not payload.prompt_dict:
            print(f"[SKIP] no prompts found for config={config_path}")
            continue
        payloads.append(payload)

    if not payloads:
        raise SystemExit("No pipelines matched with prompt content to migrate.")

    _run_migration(
        payloads=payloads,
        set_name=args.set_name,
        description=args.description,
        dry_run=args.dry_run,
        activate=args.activate,
    )


if __name__ == "__main__":
    main()
|
||||||
|
|
||||||
29
scripts/py_scripts/misc_tasks.py
Normal file
29
scripts/py_scripts/misc_tasks.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
from lang_agent.graphs import ReactGraphConfig, ReactGraph, RoutingConfig,RoutingGraph
|
||||||
|
from lang_agent.pipeline import PipelineConfig
|
||||||
|
from lang_agent.base import GraphBase
|
||||||
|
|
||||||
|
import os.path as osp
|
||||||
|
import os
|
||||||
|
from tqdm import tqdm
|
||||||
|
import yaml
|
||||||
|
import tyro
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
def gen_arch_imgs(save_dir="frontend/assets/images/graph_arch"):
    """Render each graph architecture to ``arch_<ConfigClassName>.png``.

    Args:
        save_dir: Output directory, joined onto the directory two levels
            above this file.
    """
    # NOTE(review): this file is listed as scripts/py_scripts/misc_tasks.py, so
    # two dirname() calls resolve to scripts/ rather than the repository root.
    # Confirm that is the intended base directory.
    save_dir = osp.join(osp.dirname(osp.dirname(__file__)), save_dir)
    # Create the target directory up front so saving cannot fail merely
    # because it does not exist yet.
    os.makedirs(save_dir, exist_ok=True)

    confs = [ReactGraphConfig(), RoutingConfig()]
    for conf in tqdm(confs):
        graph: GraphBase = conf.setup()
        img = graph.show_graph(ret_img=True)
        img.save(osp.join(save_dir, f"arch_{conf.__class__.__name__}.png"))
|
||||||
|
|
||||||
|
|
||||||
|
def make_save_conf(pipeline:PipelineConfig, save_path:str):
    """Log ``pipeline`` and write it to ``save_path``.

    Args:
        pipeline: Pipeline configuration to persist.
        save_path: Destination file. Missing parent directories are created.
    """
    parent_dir = osp.dirname(save_path)
    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError, so only create a directory when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    logger.info(pipeline)
    pipeline.save_config(save_path)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # gen_arch_imgs()
    # Expose make_save_conf as the command-line entry point of this script.
    tyro.cli(make_save_conf)
|
||||||
17
scripts/shell_scripts/create_config.sh
Normal file
17
scripts/shell_scripts/create_config.sh
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
#!/bin/bash
# Generate the pipeline YAML configs under configs/pipelines/ with misc_tasks.py.

# Resolve the repository root (two levels above this script) so the relative
# paths below work from any working directory.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")"

cd "$PROJECT_ROOT"

source ~/.bashrc
# `conda init` only rewrites shell startup files; it does not make
# `conda activate` usable in the current non-interactive shell. Load the
# shell hook instead so the activation below takes effect.
eval "$(conda shell.bash hook)"
conda activate lang

echo create blueberry config
python scripts/py_scripts/misc_tasks.py --save-path configs/pipelines/blueberry.yaml \
    react \
    --sys-prompt-f configs/prompts/blueberry.txt \
    --tool-manager-config.client-tool-manager.tool-keys

# echo create xiaozhan config
python scripts/py_scripts/misc_tasks.py --save-path configs/pipelines/xiaozhan.yaml
|
||||||
37
scripts/shell_scripts/download_images.sh
Executable file
37
scripts/shell_scripts/download_images.sh
Executable file
@@ -0,0 +1,37 @@
|
|||||||
|
#!/bin/bash
# Pull the Docker images the stack needs and bundle them into one tarball.
# Run this on a machine with good Docker Hub access, then transfer images.tar to China

set -e

# Archive that receives every pulled image.
OUTPUT_FILE="images.tar"

# Images needed
IMAGES=(
    "node:20-alpine"
    "python:3.12-slim"
    "postgres:16-alpine"
    "nginx:alpine"
)

printf '%s\n\n' "=== Docker Image Downloader for Offline Use ==="

echo "Pulling Docker images..."
for image in "${IMAGES[@]}"; do
    echo " Pulling $image..."
    docker pull "$image"
done

printf '\n%s\n' "Saving to $OUTPUT_FILE..."
docker save --output "$OUTPUT_FILE" "${IMAGES[@]}"

printf '\n%s\n' "Done! File size:"
ls -lh "$OUTPUT_FILE"

# Closing instructions for moving the archive to the offline machine.
cat <<'EOF'

To transfer to China machine and load:
 scp images.tar user@china-machine:/path/
 docker load < images.tar
EOF
|
||||||
167
scripts/shell_scripts/install.sh
Executable file
167
scripts/shell_scripts/install.sh
Executable file
@@ -0,0 +1,167 @@
|
|||||||
|
#!/bin/bash
# Installation script for LangChain Agent
# This script sets up and runs the entire application stack

# Abort on the first command that fails.
set -e

# Colors for output (escape sequences, interpreted by `echo -e`)
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Configuration
# SCRIPT_DIR is the directory holding this script; PROJECT_ROOT is two levels
# above it, and the .env file is expected directly under PROJECT_ROOT.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")"
ENV_FILE="$PROJECT_ROOT/.env"

# Banner
echo -e "${GREEN}=== LangChain Agent Installation Script ===${NC}\n"
|
||||||
|
|
||||||
|
# Check for required tools
|
||||||
|
check_requirements() {
    # Exit with status 1 unless Docker and some form of Compose are available.
    echo -e "${YELLOW}Checking requirements...${NC}"

    # Docker itself must be on PATH.
    command -v docker &> /dev/null || {
        echo -e "${RED}Error: Docker is not installed. Please install Docker first.${NC}"
        exit 1
    }

    # Either the standalone docker-compose binary or the `docker compose`
    # subcommand is accepted.
    if command -v docker-compose &> /dev/null || docker compose version &> /dev/null; then
        echo -e "${GREEN}✓ All requirements met${NC}\n"
        return
    fi

    echo -e "${RED}Error: docker-compose is not installed. Please install docker-compose first.${NC}"
    exit 1
}
|
||||||
|
|
||||||
|
# Create .env file if it doesn't exist
|
||||||
|
create_env_file() {
    # Write a default .env to $ENV_FILE; an existing file is left untouched.
    if [ -f "$ENV_FILE" ]; then
        echo -e "${GREEN}✓ .env file already exists${NC}\n"
        return
    fi

    echo -e "${YELLOW}Creating .env file...${NC}"
    cat > "$ENV_FILE" <<EOF
# Database Configuration
POSTGRES_DB=ai_conversations
POSTGRES_USER=myapp_user
POSTGRES_PASSWORD=secure_password_123
POSTGRES_PORT=5432

# Backend Configuration
BACKEND_PORT=8500

# Frontend Configuration
FRONTEND_PORT=8080

# Database Connection String (used by backend)
CONN_STR=postgresql://myapp_user:secure_password_123@postgres:5432/ai_conversations
EOF
    echo -e "${GREEN}✓ Created .env file at $ENV_FILE${NC}"
    echo -e "${YELLOW} Please review and update the .env file with your preferred settings.${NC}\n"
}
|
||||||
|
|
||||||
|
# Build Docker images
|
||||||
|
build_images() {
    # Build every image defined in docker/docker-compose.prod.yml.
    echo -e "${YELLOW}Building Docker images (including frontend)...${NC}"
    cd "$PROJECT_ROOT"

    # Prefer the `docker compose` subcommand; fall back to the standalone
    # docker-compose binary. COMPOSE_CMD is deliberately left global.
    COMPOSE_CMD="docker compose"
    if ! docker compose version &> /dev/null; then
        COMPOSE_CMD="docker-compose"
    fi

    # Unquoted on purpose: "docker compose" must split into two words.
    $COMPOSE_CMD -f docker/docker-compose.prod.yml build

    echo -e "${GREEN}✓ Docker images built successfully${NC}\n"
}
|
||||||
|
|
||||||
|
# Initialize database
|
||||||
|
init_database() {
    # Give the PostgreSQL container a moment to start; no SQL is run from here.
    printf '%b\n' "${YELLOW}Initializing database...${NC}"

    # Wait for PostgreSQL to be ready
    printf '%s\n' "Waiting for PostgreSQL to start..."
    sleep 5

    # The SQL files in scripts/init_database/ will be automatically executed
    # by PostgreSQL's docker-entrypoint-initdb.d mechanism
    # We just need to wait a bit for it to complete
    printf '%b\n' "${GREEN}✓ Database initialization will be handled automatically by PostgreSQL container${NC}\n"
}
|
||||||
|
|
||||||
|
# Start services
|
||||||
|
start_services() {
    # Start the stack from docker/docker-compose.prod.yml in detached mode.
    echo -e "${YELLOW}Starting services...${NC}"
    cd "$PROJECT_ROOT"

    # Prefer the `docker compose` subcommand; fall back to the standalone
    # docker-compose binary. COMPOSE_CMD is deliberately left global.
    COMPOSE_CMD="docker compose"
    if ! docker compose version &> /dev/null; then
        COMPOSE_CMD="docker-compose"
    fi

    # Unquoted on purpose: "docker compose" must split into two words.
    $COMPOSE_CMD -f docker/docker-compose.prod.yml up -d

    echo -e "${GREEN}✓ Services started${NC}\n"
}
|
||||||
|
|
||||||
|
# Show status
|
||||||
|
show_status() {
    # Print the access URLs and the compose commands for day-to-day operation.

    # Load environment variables from .env if it exists
    if [ -f "$ENV_FILE" ]; then
        set -a
        source "$ENV_FILE"
        set +a
    fi

    # Show the hints with the compose command that was detected earlier
    # (COMPOSE_CMD); fall back to docker-compose when it has not been set.
    local compose_cmd="${COMPOSE_CMD:-docker-compose}"

    echo -e "${GREEN}=== Installation Complete ===${NC}\n"
    echo -e "Services are starting up. Please wait a moment for them to be ready.\n"
    echo -e "Access points:"
    echo -e " - Frontend: http://localhost:${FRONTEND_PORT:-80}"
    echo -e " - Backend API: http://localhost:${BACKEND_PORT:-8500}"
    echo -e " - Database: localhost:${POSTGRES_PORT:-5432}\n"
    echo -e "To view logs:"
    echo -e " ${compose_cmd} -f docker/docker-compose.prod.yml logs -f\n"
    echo -e "To stop services:"
    echo -e " ${compose_cmd} -f docker/docker-compose.prod.yml down\n"
    echo -e "To restart services:"
    echo -e " ${compose_cmd} -f docker/docker-compose.prod.yml restart\n"
}
|
||||||
|
|
||||||
|
# Main execution
|
||||||
|
main() {
    # Run the installation steps in order, then probe the backend once.
    check_requirements
    create_env_file
    build_images
    start_services
    init_database
    show_status

    echo -e "${YELLOW}Waiting for services to be healthy...${NC}"
    sleep 10

    # Load environment variables for health check
    if [ -f "$ENV_FILE" ]; then
        set -a
        source "$ENV_FILE"
        set +a
    fi

    # Report the compose command that was detected earlier (COMPOSE_CMD)
    # rather than always printing docker-compose.
    local compose_cmd="${COMPOSE_CMD:-docker-compose}"

    # Check service health
    echo -e "\n${YELLOW}Checking service health...${NC}"
    sleep 5 # Give services a bit more time
    # The probe runs as an `if` condition, so a failing curl does not trip `set -e`.
    if curl -f "http://localhost:${BACKEND_PORT:-8500}/health" &> /dev/null; then
        echo -e "${GREEN}✓ Backend is healthy${NC}"
    else
        echo -e "${YELLOW}⚠ Backend is still starting up. Check logs with: ${compose_cmd} -f docker/docker-compose.prod.yml logs backend${NC}"
    fi
}

# Run main function
main
|
||||||
|
|
||||||
46
tests/test_combined_app.py
Normal file
46
tests/test_combined_app.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
import importlib
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
os.environ.setdefault("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
|
||||||
|
|
||||||
|
|
||||||
|
def test_server_dashscope_import_is_cli_safe(monkeypatch):
    """Importing ``server_dashscope`` must not call ``tyro.cli``.

    ``tyro.cli`` is replaced by a function that raises, so any call made
    while the module body executes fails the test.
    """
    import tyro

    def _fail_if_called(*_args, **_kwargs):
        raise AssertionError("tyro.cli must not run during module import")

    monkeypatch.setattr(tyro, "cli", _fail_if_called)

    module_name = "fastapi_server.server_dashscope"
    # Drop any cached copy so the import below really executes the module.
    sys.modules.pop(module_name, None)

    imported = importlib.import_module(module_name)
    assert imported.app is not None
    assert imported.dashscope_router is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_combined_app_serves_front_and_dashscope_routes():
    """The combined app exposes both the front_apis and the DashScope routes."""
    from fastapi_server.combined import app

    http = TestClient(app)

    # front_apis route should be available.
    graphs_resp = http.get("/v1/pipelines/graphs")
    assert graphs_resp.status_code == 200, graphs_resp.text
    assert "available_graphs" in graphs_resp.json()

    # DashScope route should exist at the same path (missing auth should not be 404).
    request_body = {"input": {"prompt": "hello"}, "stream": False}
    responses_resp = http.post(
        "/api/v1/apps/blueberry/sessions/test-session/responses",
        json=request_body,
    )
    assert responses_resp.status_code != 404, responses_resp.text
|
||||||
|
|
||||||
@@ -8,7 +8,10 @@ Instructions:
|
|||||||
- Set BASE_URL below to the server base URL you started.
|
- Set BASE_URL below to the server base URL you started.
|
||||||
- Optionally set environment variables ALI_API_KEY and ALI_APP_ID.
|
- Optionally set environment variables ALI_API_KEY and ALI_APP_ID.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import json
|
||||||
|
import os.path as osp
|
||||||
import uuid
|
import uuid
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@@ -27,22 +30,86 @@ except Exception as e:
|
|||||||
|
|
||||||
|
|
||||||
# <<< Paste your running FastAPI base url here >>>
|
# <<< Paste your running FastAPI base url here >>>
|
||||||
BASE_URL = os.getenv("DS_BASE_URL", "http://127.0.0.1:8588/api/")
|
BASE_URL = os.getenv("DS_BASE_URL", "http://127.0.0.1:8500/api/")
|
||||||
|
|
||||||
|
|
||||||
# Params
|
# Params
|
||||||
API_KEY = os.getenv("FAST_AUTH_KEYS", "test-key")
|
def _first_non_empty_csv_token(value: str) -> str:
    """Return the first non-blank, stripped item of a comma-separated string.

    A falsy ``value`` is treated as an empty string. Returns ``""`` when no
    item has content.
    """
    for raw in (value or "").split(","):
        token = raw.strip()
        if token:
            return token
    return ""
|
||||||
|
|
||||||
|
|
||||||
|
def _load_registry() -> dict:
    """Load the pipeline registry JSON and return the parsed content.

    The file path comes from the ``FAST_PIPELINE_REGISTRY_FILE`` environment
    variable and defaults to ``configs/pipeline_registry.json`` under the
    directory two levels above this file.
    """
    default_path = osp.join(
        osp.dirname(osp.dirname(osp.abspath(__file__))),
        "configs",
        "pipeline_registry.json",
    )
    registry_path = os.getenv("FAST_PIPELINE_REGISTRY_FILE", default_path)
    with open(registry_path, "r", encoding="utf-8") as fh:
        return json.load(fh)
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_api_key(registry: dict) -> str:
    """Choose the API key the client sends, from environment variables.

    Candidates, in priority order, are the first token of ``FAST_AUTH_KEYS``
    and then ``ALI_API_KEY``. A candidate listed under the registry's
    ``api_keys`` wins (any candidate qualifies when that section is empty);
    otherwise the first available candidate is used as-is.

    Raises:
        RuntimeError: If neither environment variable provides a key.
    """
    # For local server_dashscope testing, FAST_AUTH_KEYS is usually the server auth source.
    env_candidates = (
        _first_non_empty_csv_token(os.getenv("FAST_AUTH_KEYS", "")),
        (os.getenv("ALI_API_KEY") or "").strip(),
    )
    usable = [key for key in env_candidates if key]

    known_keys = registry.get("api_keys") or {}
    for key in usable:
        if not known_keys or key in known_keys:
            return key

    # No candidate matches a registry policy: fall back to the first one set.
    if usable:
        return usable[0]
    raise RuntimeError(
        "Missing API key. Set FAST_AUTH_KEYS or ALI_API_KEY in your environment."
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_app_id(api_key: str, registry: dict) -> str:
    """Choose the app id to call for ``api_key``.

    Resolution order: the ``app_id`` of the key's policy under ``api_keys``,
    then the first non-empty string key under ``pipelines``, then
    ``"default"``.

    Args:
        api_key: API key whose policy is consulted; may be empty.
        registry: Parsed pipeline registry.

    Returns:
        The selected app id.
    """
    if api_key:
        api_policies = registry.get("api_keys")
        policy = api_policies.get(api_key) if isinstance(api_policies, dict) else None
        # Tolerate malformed entries (e.g. ``null`` or a bare string) the same
        # way _warn_if_policy_disallows_app_id does, instead of raising
        # AttributeError on ``.get``.
        explicit = policy.get("app_id") if isinstance(policy, dict) else None
        if explicit:
            return explicit

    pipelines_obj = registry.get("pipelines")
    if isinstance(pipelines_obj, dict):
        for pipeline_id in pipelines_obj:
            if isinstance(pipeline_id, str) and pipeline_id:
                return pipeline_id
    return "default"
|
||||||
|
|
||||||
|
|
||||||
|
def _warn_if_policy_disallows_app_id(api_key: str, app_id: str, registry: dict) -> None:
    """Log a warning when the key's policy does not list ``app_id``.

    Nothing is logged when the key has no dict policy, or when its
    ``allowed_pipeline_ids`` is missing, not a list, or empty.
    """
    key_policies = registry.get("api_keys") or {}
    policy = key_policies.get(api_key, {})
    if not isinstance(policy, dict):
        return

    allowed = policy.get("allowed_pipeline_ids")
    restricted = isinstance(allowed, list) and bool(allowed)
    if not restricted or app_id in allowed:
        return

    logger.bind(tag=TAG).warning(
        f"app_id='{app_id}' is not in allowed_pipeline_ids for current API key; server may return 403."
    )
|
||||||
|
|
||||||
|
|
||||||
|
# Resolve the call parameters once at import time: load the registry, pick the
# API key from the environment, derive the app id for that key, and warn if
# the key's policy does not list that app id.
REGISTRY = _load_registry()
API_KEY = _pick_api_key(REGISTRY)
APP_ID = _pick_app_id(API_KEY, REGISTRY)
_warn_if_policy_disallows_app_id(API_KEY, APP_ID, REGISTRY)
|
||||||
SESSION_ID = str(uuid.uuid4())
|
SESSION_ID = str(uuid.uuid4())
|
||||||
|
|
||||||
dialogue = [
|
dialogue = [
|
||||||
{"role": "system", "content": "You are a helpful assistant."},
|
{"role": "system", "content": "You are a helpful assistant."},
|
||||||
{"role": "user", "content": "Say 'the world is awesome and beautiful'."},
|
{"role": "user", "content": "你叫什么名字"},
|
||||||
]
|
]
|
||||||
|
|
||||||
call_params = {
|
call_params = {
|
||||||
"api_key": API_KEY,
|
"api_key": API_KEY,
|
||||||
"app_id": "test_app",
|
"app_id": APP_ID,
|
||||||
"session_id": "123",
|
"session_id": SESSION_ID,
|
||||||
"messages": dialogue,
|
"messages": dialogue,
|
||||||
"stream": True,
|
"stream": True,
|
||||||
}
|
}
|
||||||
@@ -52,8 +119,11 @@ def main():
|
|||||||
# Point the SDK to our FastAPI implementation
|
# Point the SDK to our FastAPI implementation
|
||||||
if BASE_URL and ("/api/" in BASE_URL):
|
if BASE_URL and ("/api/" in BASE_URL):
|
||||||
dashscope.base_http_api_url = BASE_URL
|
dashscope.base_http_api_url = BASE_URL
|
||||||
|
# Some SDK paths rely on global api_key to build Authorization header.
|
||||||
|
dashscope.api_key = API_KEY
|
||||||
# dashscope.base_http_api_url = BASE_URL
|
# dashscope.base_http_api_url = BASE_URL
|
||||||
print(f"Using base_http_api_url = {dashscope.base_http_api_url}")
|
print(f"Using base_http_api_url = {dashscope.base_http_api_url}")
|
||||||
|
print(f"Using app_id = {APP_ID}")
|
||||||
|
|
||||||
print("\nCalling Application.call(stream=True)...\n")
|
print("\nCalling Application.call(stream=True)...\n")
|
||||||
responses = Application.call(**call_params)
|
responses = Application.call(**call_params)
|
||||||
@@ -72,7 +142,7 @@ def main():
|
|||||||
continue
|
continue
|
||||||
# SDK流式为增量覆盖,计算差量输出
|
# SDK流式为增量覆盖,计算差量输出
|
||||||
if len(current_text) >= len(last_text):
|
if len(current_text) >= len(last_text):
|
||||||
delta = current_text[len(last_text):]
|
delta = current_text[len(last_text) :]
|
||||||
else:
|
else:
|
||||||
# 避免偶发回退
|
# 避免偶发回退
|
||||||
delta = current_text
|
delta = current_text
|
||||||
@@ -80,7 +150,9 @@ def main():
|
|||||||
u = delta
|
u = delta
|
||||||
last_text = current_text
|
last_text = current_text
|
||||||
|
|
||||||
logger.info(f"from stream: {u}")
|
# For streaming responses, print incrementally to stdout and flush
|
||||||
|
# so the user can see tokens as they arrive.
|
||||||
|
print(u, end="", flush=True)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
# 非流式回落(一次性返回)
|
# 非流式回落(一次性返回)
|
||||||
if responses.status_code != HTTPStatus.OK:
|
if responses.status_code != HTTPStatus.OK:
|
||||||
@@ -100,8 +172,5 @@ def main():
|
|||||||
u = "【阿里百练API服务响应异常】"
|
u = "【阿里百练API服务响应异常】"
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
351
tests/test_front_apis_registry.py
Normal file
351
tests/test_front_apis_registry.py
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
import json
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
|
os.environ.setdefault("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
|
||||||
|
|
||||||
|
try:
|
||||||
|
front_apis = importlib.import_module("lang_agent.fastapi_server.front_apis")
|
||||||
|
except ModuleNotFoundError:
|
||||||
|
front_apis = importlib.import_module("fastapi_server.front_apis")
|
||||||
|
|
||||||
|
|
||||||
|
def _fake_build_fn(
    pipeline_id: str,
    prompt_set: str,
    tool_keys,
    api_key: str,
    llm_name: str = "qwen-plus",
    pipeline_config_dir: str = "configs/pipelines",
):
    """Test double for a graph build function.

    Writes the received arguments as JSON to
    ``<pipeline_config_dir>/<pipeline_id>.yaml`` (creating the directory if
    needed) and returns ``{"path": <written file>}``.
    """
    target_dir = Path(pipeline_config_dir)
    target_dir.mkdir(parents=True, exist_ok=True)

    recorded = {
        "pipeline_id": pipeline_id,
        "prompt_set": prompt_set,
        "tool_keys": tool_keys,
        "api_key": api_key,
        "llm_name": llm_name,
    }
    target_file = target_dir / f"{pipeline_id}.yaml"
    target_file.write_text(json.dumps(recorded), encoding="utf-8")
    return {"path": str(target_file)}
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeCursor:
    """In-memory stand-in for a database cursor over a list of message rows.

    ``execute`` recognises three queries by a substring of the
    whitespace-normalised, lower-cased SQL and computes their result from
    ``rows``; any other SQL raises ``AssertionError``.
    """

    def __init__(self, rows):
        self._rows = rows
        self._result = []
        self._last_sql = ""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def execute(self, sql, params=None):
        """Emulate one of the supported queries and store its result."""
        self._last_sql = sql
        normalized = " ".join(sql.split()).lower()
        args = params or ()

        if "group by conversation_id, pipeline_id" in normalized:
            self._result = self._summarize_conversations(args[0], int(args[1]))
        elif "select 1 from messages" in normalized:
            wanted_pipeline, wanted_conversation = args
            matches = self._messages_of(wanted_pipeline, wanted_conversation)
            self._result = [{"exists": 1}] if matches else []
        elif "order by sequence_number asc" in normalized:
            wanted_pipeline, wanted_conversation = args
            projected = [
                {
                    "message_type": message["message_type"],
                    "content": message["content"],
                    "sequence_number": message["sequence_number"],
                    "created_at": message["created_at"],
                }
                for message in self._messages_of(wanted_pipeline, wanted_conversation)
            ]
            projected.sort(key=lambda entry: entry["sequence_number"])
            self._result = projected
        else:
            raise AssertionError(f"Unsupported SQL in test fake: {self._last_sql}")

    def _messages_of(self, pipeline_id, conversation_id):
        """Return the rows of one conversation of one pipeline, in row order."""
        return [
            message
            for message in self._rows
            if message["pipeline_id"] == pipeline_id
            and message["conversation_id"] == conversation_id
        ]

    def _summarize_conversations(self, pipeline_id, limit):
        """Return per-conversation summaries of a pipeline, newest first."""
        summaries = {}
        for message in self._rows:
            if message["pipeline_id"] != pipeline_id:
                continue
            summary = summaries.setdefault(
                message["conversation_id"],
                {
                    "conversation_id": message["conversation_id"],
                    "pipeline_id": message["pipeline_id"],
                    "message_count": 0,
                    "last_updated": message["created_at"],
                },
            )
            summary["message_count"] += 1
            summary["last_updated"] = max(summary["last_updated"], message["created_at"])
        newest_first = sorted(
            summaries.values(), key=lambda entry: entry["last_updated"], reverse=True
        )
        return newest_first[:limit]

    def fetchall(self):
        return self._result

    def fetchone(self):
        return self._result[0] if self._result else None
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeConnection:
    """Minimal connection double that hands out ``_FakeCursor`` objects."""

    def __init__(self, rows):
        # The rows are shared with every cursor created from this connection.
        self._rows = rows

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Never suppress exceptions raised inside the ``with`` body.
        return False

    def cursor(self, row_factory=None):
        """Return a new fake cursor; ``row_factory`` is accepted and ignored."""
        fake_cursor = _FakeCursor(self._rows)
        return fake_cursor
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_route_lifecycle(monkeypatch, tmp_path):
    """Create, list and disable a pipeline through the registry routes."""
    registry_file = tmp_path / "pipeline_registry.json"
    monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_file))
    monkeypatch.setitem(front_apis.GRAPH_BUILD_FNCS, "routing", _fake_build_fn)
    http = TestClient(front_apis.app)

    # Create the pipeline.
    new_pipeline = {
        "graph_id": "routing",
        "pipeline_id": "xiaozhan",
        "prompt_set_id": "default",
        "tool_keys": ["weather"],
        "api_key": "sk-test",
        "llm_name": "qwen-plus",
        "enabled": True,
    }
    create_resp = http.post("/v1/pipelines", json=new_pipeline)
    assert create_resp.status_code == 200, create_resp.text
    created = create_resp.json()
    assert created["pipeline_id"] == "xiaozhan"
    assert created["graph_id"] == "routing"
    assert created["llm_name"] == "qwen-plus"
    assert created["reload_required"] is False

    # It is listed and enabled.
    list_resp = http.get("/v1/pipelines")
    assert list_resp.status_code == 200, list_resp.text
    listed = list_resp.json()["items"]
    assert len(listed) == 1
    only_item = listed[0]
    assert only_item["pipeline_id"] == "xiaozhan"
    assert only_item["graph_id"] == "routing"
    assert only_item["llm_name"] == "qwen-plus"
    assert only_item["enabled"] is True

    # DELETE disables the pipeline.
    disable_resp = http.delete("/v1/pipelines/xiaozhan")
    assert disable_resp.status_code == 200, disable_resp.text
    disabled = disable_resp.json()
    assert disabled["pipeline_id"] == "xiaozhan"
    assert disabled["enabled"] is False

    # The entry is still listed, now disabled.
    list_after = http.get("/v1/pipelines")
    assert list_after.status_code == 200, list_after.text
    listed_after = list_after.json()["items"]
    assert len(listed_after) == 1
    assert listed_after[0]["enabled"] is False

    # The registry file on disk reflects the final state.
    stored = json.loads(registry_file.read_text(encoding="utf-8"))
    stored_pipeline = stored["pipelines"]["xiaozhan"]
    assert stored_pipeline["graph_id"] == "routing"
    assert stored_pipeline["enabled"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_registry_api_key_policy_lifecycle(monkeypatch, tmp_path):
    """Upsert, list and delete an API-key policy through the registry routes."""
    registry_file = tmp_path / "pipeline_registry.json"
    monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_file))
    monkeypatch.setitem(front_apis.GRAPH_BUILD_FNCS, "routing", _fake_build_fn)
    http = TestClient(front_apis.app)

    # A pipeline must exist before a policy can refer to it.
    new_pipeline = {
        "graph_id": "routing",
        "pipeline_id": "blueberry",
        "prompt_set_id": "default",
        "tool_keys": [],
        "api_key": "sk-test",
        "llm_name": "qwen-plus",
        "enabled": True,
    }
    create_resp = http.post("/v1/pipelines", json=new_pipeline)
    assert create_resp.status_code == 200, create_resp.text

    # Upsert the policy.
    policy_url = "/v1/pipelines/api-keys/sk-test-key"
    policy_body = {
        "default_pipeline_id": "blueberry",
        "allowed_pipeline_ids": ["blueberry"],
        "app_id": "blueberry",
    }
    upsert_resp = http.put(policy_url, json=policy_body)
    assert upsert_resp.status_code == 200, upsert_resp.text
    upserted = upsert_resp.json()
    assert upserted["api_key"] == "sk-test-key"
    assert upserted["default_pipeline_id"] == "blueberry"
    assert upserted["allowed_pipeline_ids"] == ["blueberry"]
    assert upserted["app_id"] == "blueberry"

    # The policy is listed.
    list_resp = http.get("/v1/pipelines/api-keys")
    assert list_resp.status_code == 200, list_resp.text
    listing = list_resp.json()
    assert listing["count"] == 1
    assert listing["items"][0]["api_key"] == "sk-test-key"

    # Delete the policy.
    delete_resp = http.delete(policy_url)
    assert delete_resp.status_code == 200, delete_resp.text
    deleted = delete_resp.json()
    assert deleted["api_key"] == "sk-test-key"
    assert deleted["status"] == "deleted"
    assert deleted["reload_required"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_conversation_routes(monkeypatch):
    """Conversation listing and message retrieval are scoped to one pipeline.

    The fake connection is seeded with three ``agent-a`` messages spread over
    two conversations, plus one ``agent-b`` message that must not show up in
    any ``agent-a`` response.
    """
    reference_time = datetime.now(timezone.utc)

    def _message(conversation_id, pipeline_id, message_type, content, sequence_number, age_seconds):
        # One stored message row, created ``age_seconds`` before the reference time.
        return {
            "conversation_id": conversation_id,
            "pipeline_id": pipeline_id,
            "message_type": message_type,
            "content": content,
            "sequence_number": sequence_number,
            "created_at": reference_time - timedelta(seconds=age_seconds),
        }

    seeded_rows = [
        _message("agent-a:conv-1", "agent-a", "human", "hello", 1, 30),
        _message("agent-a:conv-1", "agent-a", "ai", "hi there", 2, 20),
        _message("agent-a:conv-2", "agent-a", "human", "second thread", 1, 10),
        _message("agent-b:conv-9", "agent-b", "human", "other pipeline", 1, 5),
    ]

    def _fake_connect(_conn_str):
        return _FakeConnection(seeded_rows)

    monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
    monkeypatch.setattr(front_apis.psycopg, "connect", _fake_connect)

    api = TestClient(front_apis.app)

    # Listing: only agent-a conversations, conv-2 (most recent message) first.
    conversations_resp = api.get("/v1/pipelines/agent-a/conversations")
    assert conversations_resp.status_code == 200, conversations_resp.text
    conversations = conversations_resp.json()
    assert conversations["pipeline_id"] == "agent-a"
    assert conversations["count"] == 2
    listed_ids = [entry["conversation_id"] for entry in conversations["items"]]
    assert listed_ids == ["agent-a:conv-2", "agent-a:conv-1"]
    for entry in conversations["items"]:
        assert entry["pipeline_id"] == "agent-a"

    # Messages: both conv-1 messages, in sequence order.
    messages_resp = api.get(
        "/v1/pipelines/agent-a/conversations/agent-a:conv-1/messages"
    )
    assert messages_resp.status_code == 200, messages_resp.text
    messages = messages_resp.json()
    assert messages["pipeline_id"] == "agent-a"
    assert messages["conversation_id"] == "agent-a:conv-1"
    assert messages["count"] == 2
    message_types = [entry["message_type"] for entry in messages["items"]]
    sequence_numbers = [entry["sequence_number"] for entry in messages["items"]]
    assert message_types == ["human", "ai"]
    assert sequence_numbers == [1, 2]
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_conversation_messages_404(monkeypatch):
    """Asking ``agent-a`` for a conversation stored under ``agent-b`` yields 404."""
    foreign_rows = [
        {
            "conversation_id": "agent-b:conv-9",
            "pipeline_id": "agent-b",
            "message_type": "human",
            "content": "other pipeline",
            "sequence_number": 1,
            "created_at": datetime.now(timezone.utc),
        },
    ]

    def _fake_connect(_conn_str):
        return _FakeConnection(foreign_rows)

    monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
    monkeypatch.setattr(front_apis.psycopg, "connect", _fake_connect)

    response = TestClient(front_apis.app).get(
        "/v1/pipelines/agent-a/conversations/agent-b:conv-9/messages"
    )

    assert response.status_code == 404, response.text
    detail = response.json()["detail"]
    assert "not found for pipeline 'agent-a'" in detail
|
||||||
|
|
||||||
|
|
||||||
|
def test_runtime_auth_info_prefers_registry_then_env(monkeypatch, tmp_path):
    """``/v1/runtime-auth`` reports the registry key even if FAST_AUTH_KEYS is set."""
    registry_file = tmp_path / "pipeline_registry.json"
    registry_doc = {
        "pipelines": {},
        "api_keys": {
            "sk-from-registry": {"default_pipeline_id": "blueberry"},
        },
    }
    registry_file.write_text(json.dumps(registry_doc), encoding="utf-8")

    # Both sources are populated; the assertions below expect the registry to win.
    monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_file))
    monkeypatch.setenv("FAST_AUTH_KEYS", "sk-from-env,other")

    response = TestClient(front_apis.app).get("/v1/runtime-auth")

    assert response.status_code == 200, response.text
    auth_info = response.json()
    assert auth_info["fast_api_key"] == "sk-from-registry"
    assert auth_info["source"] == "pipeline_registry"
|
||||||
113
tests/test_migrate_yaml_prompts_to_db.py
Normal file
113
tests/test_migrate_yaml_prompts_to_db.py
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
import importlib.util
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
|
||||||
|
def _load_module():
    """Load ``scripts/py_scripts/migrate_yaml_prompts_to_db.py`` from its file path.

    The script path is resolved relative to the project root (the parent of this
    test file's directory). Each call builds and executes a new module object,
    registering it in ``sys.modules`` as ``migrate_yaml_prompts_to_db``.

    Returns:
        The freshly executed module object.
    """
    project_root = Path(__file__).resolve().parents[1]
    script_path = project_root / "scripts" / "py_scripts" / "migrate_yaml_prompts_to_db.py"
    spec = importlib.util.spec_from_file_location("migrate_yaml_prompts_to_db", script_path)
    # Validate the spec *before* using it so a bad path fails with a clear
    # message instead of an AttributeError inside module_from_spec().
    assert spec is not None, f"cannot build an import spec for {script_path}"
    assert spec.loader is not None
    module = importlib.util.module_from_spec(spec)
    # Register before executing, as in the importlib "import a source file
    # directly" recipe.
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module
|
||||||
|
|
||||||
|
|
||||||
|
def test_infer_pipeline_id_falls_back_to_filename():
    """With no pipeline id on the config or its graph config, the file stem is used."""
    migrate = _load_module()
    graph_conf = SimpleNamespace(pipeline_id=None)
    conf = SimpleNamespace(pipeline_id=None, graph_config=graph_conf)

    inferred = migrate._infer_pipeline_id(conf, "/tmp/blueberry.yaml")

    assert inferred == "blueberry"
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_prompt_dict_for_react_txt(tmp_path):
    """A config exposing ``sys_prompt_f`` yields a single ``sys_prompt`` entry."""
    migrate = _load_module()
    prompt_file = tmp_path / "sys.txt"
    prompt_file.write_text("hello react", encoding="utf-8")

    extracted = migrate._extract_prompt_dict(
        SimpleNamespace(sys_prompt_f=str(prompt_file))
    )

    assert extracted == {"sys_prompt": "hello react"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_prompt_dict_for_routing_dir(tmp_path):
    """A config exposing ``sys_promp_dir`` yields one entry per prompt file in it."""
    migrate = _load_module()
    for stem in ("route", "chat"):
        # Each file's content is its own stem: route_prompt.txt -> "route".
        (tmp_path / f"{stem}_prompt.txt").write_text(stem, encoding="utf-8")

    extracted = migrate._extract_prompt_dict(
        SimpleNamespace(sys_promp_dir=str(tmp_path))
    )

    assert extracted["route_prompt"] == "route"
    assert extracted["chat_prompt"] == "chat"
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_payload_routing_ignores_chatty_prompt_for_tool_node(tmp_path):
    """A plain tool node config must not pull ``chatty_prompt`` into the payload.

    ``chatty_prompt.txt`` exists in the prompt directory, but the tool node
    config carries no ``chatty_sys_prompt_f`` attribute.
    """
    migrate = _load_module()
    prompt_dir = tmp_path / "prompts"
    prompt_dir.mkdir()
    for stem in ("route", "chat", "tool", "chatty"):
        # Each file's content is its own stem: tool_prompt.txt -> "tool".
        (prompt_dir / f"{stem}_prompt.txt").write_text(stem, encoding="utf-8")

    # Stand-in config classes; their names match the ones used by the project
    # and are kept as-is in case the script inspects them (not visible here).
    class RoutingConfig:
        pass

    class ToolNodeConfig:
        pass

    tool_node_conf = ToolNodeConfig()
    tool_node_conf.tool_prompt_f = str(prompt_dir / "tool_prompt.txt")

    graph_conf = RoutingConfig()
    graph_conf.sys_promp_dir = str(prompt_dir)
    graph_conf.tool_node_config = tool_node_conf

    conf = SimpleNamespace(pipeline_id=None, api_key="sk", graph_config=graph_conf)

    def _fake_load_tyro_conf(_):
        return conf

    # Bypass real config parsing: the script's loader returns the stub above.
    migrate.load_tyro_conf = _fake_load_tyro_conf

    payload = migrate._collect_payload(str(tmp_path / "xiaozhan.yaml"))

    # pipeline_id is None on the config, so it is taken from the file name.
    assert payload.pipeline_id == "xiaozhan"
    expected_keys = {"route_prompt", "chat_prompt", "tool_prompt"}
    assert set(payload.prompt_dict.keys()) == expected_keys
    assert "chatty_prompt" not in payload.prompt_dict
|
||||||
|
|
||||||
|
|
||||||
|
def test_collect_payload_routing_includes_chatty_prompt_for_chatty_node(tmp_path):
    """A chatty tool node config contributes ``chatty_prompt`` to the payload."""
    migrate = _load_module()
    prompt_dir = tmp_path / "prompts"
    prompt_dir.mkdir()
    for stem in ("route", "chat", "tool", "chatty"):
        # Each file's content is its own stem: chatty_prompt.txt -> "chatty".
        (prompt_dir / f"{stem}_prompt.txt").write_text(stem, encoding="utf-8")

    # Stand-in config classes; their names match the ones used by the project
    # and are kept as-is in case the script inspects them (not visible here).
    class RoutingConfig:
        pass

    class ChattyToolNodeConfig:
        pass

    tool_node_conf = ChattyToolNodeConfig()
    tool_node_conf.tool_prompt_f = str(prompt_dir / "tool_prompt.txt")
    tool_node_conf.chatty_sys_prompt_f = str(prompt_dir / "chatty_prompt.txt")

    graph_conf = RoutingConfig()
    graph_conf.sys_promp_dir = str(prompt_dir)
    graph_conf.tool_node_config = tool_node_conf

    conf = SimpleNamespace(
        pipeline_id="xiaozhan",
        api_key="sk",
        graph_config=graph_conf,
    )

    def _fake_load_tyro_conf(_):
        return conf

    # Bypass real config parsing: the script's loader returns the stub above.
    migrate.load_tyro_conf = _fake_load_tyro_conf

    payload = migrate._collect_payload(str(tmp_path / "xiaozhan.yaml"))

    assert payload.pipeline_id == "xiaozhan"
    assert "chatty_prompt" in payload.prompt_dict
|
||||||
|
|
||||||
169
tests/test_server_dashscope_routes.py
Normal file
169
tests/test_server_dashscope_routes.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import time
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_service_api_key() -> "str | None":
    """Return the first API key from FAST_AUTH_KEYS env (comma-separated).

    Entries are stripped of surrounding whitespace and blank entries are
    ignored.

    Returns:
        The first non-blank key, or ``None`` when the variable is unset or
        contains no non-blank entry.
    """
    raw = os.getenv("FAST_AUTH_KEYS", "")
    stripped = (part.strip() for part in raw.split(","))
    return next((key for key in stripped if key), None)
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_for_health(base_url: str, timeout: float = 20.0) -> None:
    """Poll the /health endpoint until the server is up or timeout.

    Args:
        base_url: Root URL of the server; a trailing slash is tolerated.
        timeout: Maximum number of seconds to keep polling.

    Raises:
        RuntimeError: If no 200 OK response was seen before the deadline. The
            message carries the last request error or non-OK status observed.
    """
    # Monotonic clock: the deadline must not shift if the wall clock is adjusted.
    deadline = time.monotonic() + timeout
    url = base_url.rstrip("/") + "/health"
    last_err = None
    while time.monotonic() < deadline:
        try:
            resp = requests.get(url, timeout=2)
        except Exception as e:  # pragma: no cover - best-effort polling
            last_err = e
        else:
            if resp.status_code == HTTPStatus.OK:
                return
            # Remember unhealthy answers too, so a timeout is not reported
            # with an uninformative ``last_err=None``.
            last_err = f"HTTP {resp.status_code}"
        time.sleep(0.5)
    raise RuntimeError(f"Server did not become healthy in time: last_err={last_err}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="module")
def dashscope_server():
    """
    Start a real uvicorn instance of server_dashscope for end-to-end routing tests.

    This mirrors how docker-compose runs `xiaozhan` (server_dashscope.py) so we
    exercise the full stack, including PipelineManager + registry routing.

    Yields:
        The server base URL (``http://127.0.0.1:8588``) once ``/health`` answers OK.

    Raises:
        RuntimeError: Propagated from ``_wait_for_health`` if the server never
            becomes healthy; the child process is stopped first.
    """
    import sys  # local import: only this fixture needs the interpreter path

    env = os.environ.copy()
    # Ensure registry file is picked up (falls back to this by default, but be explicit).
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    registry_path = os.path.join(project_root, "configs", "pipeline_registry.json")
    env.setdefault("FAST_PIPELINE_REGISTRY_FILE", registry_path)

    cmd = [
        # Use the interpreter running the tests so the server sees the same
        # environment/virtualenv, rather than whatever "python" is on PATH.
        sys.executable,
        "-m",
        "uvicorn",
        "fastapi_server.server_dashscope:app",
        "--host",
        "127.0.0.1",
        "--port",
        "8588",
    ]
    proc = subprocess.Popen(cmd, env=env)

    base_url = "http://127.0.0.1:8588"
    try:
        _wait_for_health(base_url)
        yield base_url
    finally:
        # Single cleanup path for both startup failure and normal teardown.
        if proc.poll() is None:
            proc.terminate()
            try:
                proc.wait(timeout=10)
            except subprocess.TimeoutExpired:  # pragma: no cover - best-effort cleanup
                proc.kill()
                proc.wait(timeout=10)
|
||||||
|
|
||||||
|
|
||||||
|
def _post_app_response(base_url: str, pipeline_id: str, body: dict, api_key: str):
    """POST ``body`` as JSON to the app's ``test-session`` responses endpoint.

    The request carries ``api_key`` as a Bearer token and uses a 20 second
    timeout. Returns the ``requests`` response object unchanged.
    """
    return requests.post(
        f"{base_url}/api/v1/apps/{pipeline_id}/sessions/test-session/responses",
        json=body,
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=20,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_selected_via_pipeline_id_body(dashscope_server):
    """
    A `pipeline_id` given in the request body selects the pipeline and is
    echoed back in the JSON response.

    Skipped when FAST_AUTH_KEYS provides no key to authenticate with.
    """
    service_key = _get_service_api_key()
    if not service_key:
        pytest.skip(
            "FAST_AUTH_KEYS is not set; cannot authenticate against server_dashscope"
        )

    request_body = {
        "input": {"prompt": "你是谁?", "session_id": "sess-1"},
        "pipeline_id": "blueberry",
        "stream": False,
    }

    response = _post_app_response(
        dashscope_server,
        pipeline_id="blueberry",
        body=request_body,
        api_key=service_key,
    )

    assert response.status_code == HTTPStatus.OK, response.text
    payload = response.json()
    assert payload.get("pipeline_id") == "blueberry"
    assert "text" in payload.get("output", {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_selected_via_pipeline_id_body_blueberry(dashscope_server):
    """
    Same selector check as the sibling test, with a different prompt and
    session id: the body's `pipeline_id` is used and echoed in the response.

    Skipped when FAST_AUTH_KEYS provides no key to authenticate with.
    """
    service_key = _get_service_api_key()
    if not service_key:
        pytest.skip(
            "FAST_AUTH_KEYS is not set; cannot authenticate against server_dashscope"
        )

    request_body = {
        "input": {"prompt": "hello from blueberry", "session_id": "sess-2"},
        "pipeline_id": "blueberry",
        "stream": False,
    }

    response = _post_app_response(
        dashscope_server,
        pipeline_id="blueberry",
        body=request_body,
        api_key=service_key,
    )

    assert response.status_code == HTTPStatus.OK, response.text
    payload = response.json()
    assert payload.get("pipeline_id") == "blueberry"
    assert "text" in payload.get("output", {})
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_forbidden_for_api_key_when_not_allowed(dashscope_server):
    """
    A request authenticated with an unknown API key is rejected with
    401 Unauthorized and the detail "Invalid API key".

    NOTE: despite the test name, this does not exercise the per-key pipeline
    allow-list in pipeline_registry; it only covers the invalid-key path, which
    is independent of the registry contents.
    """
    base_url = dashscope_server
    body = {
        "input": {
            "prompt": "this should be forbidden",
            "session_id": "sess-3",
        },
        "pipeline_id": "blueberry",
        "stream": False,
    }

    # Use a guaranteed-wrong API key so we test 401 behavior regardless of registry config.
    resp = _post_app_response(
        base_url, pipeline_id="blueberry", body=body, api_key="invalid-key-for-test"
    )
    assert resp.status_code == HTTPStatus.UNAUTHORIZED
    data = resp.json()
    assert data.get("detail") == "Invalid API key"
|
||||||
156
tests/test_server_pipeline_manager_refresh.py
Normal file
156
tests/test_server_pipeline_manager_refresh.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import json
|
||||||
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from lang_agent.components.server_pipeline_manager import ServerPipelineManager
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyPipeline:
    """Minimal pipeline stand-in that only remembers its model name."""

    def __init__(self, model: str) -> None:
        # Name of the model this fake pipeline was built for.
        self.model = model
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyConfig:
    """Minimal config stand-in whose ``setup()`` builds a ``_DummyPipeline``."""

    def __init__(self, llm_name: str = "qwen-plus") -> None:
        # Model name handed to the pipeline created by setup().
        self.llm_name = llm_name

    def setup(self):
        """Return a new ``_DummyPipeline`` for the configured ``llm_name``."""
        pipeline = _DummyPipeline(model=self.llm_name)
        return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
def _write_registry(path, pipelines, api_keys=None):
    """Write a registry JSON document to ``path`` and pause briefly.

    The document has two top-level keys, ``pipelines`` and ``api_keys``; a
    missing or empty ``api_keys`` is written as ``{}``.
    """
    registry_doc = {
        "pipelines": pipelines,
        "api_keys": api_keys if api_keys else {},
    }
    serialized = json.dumps(registry_doc, indent=2)
    path.write_text(serialized, encoding="utf-8")
    # Ensure mtime changes reliably on fast CI filesystems.
    time.sleep(0.01)
|
||||||
|
|
||||||
|
|
||||||
|
def test_refresh_registry_picks_up_new_pipeline(tmp_path):
    """A pipeline added to the registry file becomes resolvable after a refresh."""
    registry_file = tmp_path / "pipeline_registry.json"
    default_entry = {"enabled": True, "config_file": None, "llm_name": "qwen-plus"}
    blueberry_entry = {"enabled": True, "config_file": None, "llm_name": "qwen-max"}

    _write_registry(registry_file, pipelines={"default": default_entry})
    manager = ServerPipelineManager(
        default_pipeline_id="default",
        default_config=_DummyConfig(),
    )
    manager.load_registry(str(registry_file))

    def _resolve_blueberry():
        # Fresh body dict per call, as a real request would supply.
        return manager.resolve_pipeline_id(
            body={"pipeline_id": "blueberry"}, app_id=None, api_key="k1"
        )

    # Before the registry knows about "blueberry" the lookup is a 404.
    with pytest.raises(HTTPException) as missing:
        _resolve_blueberry()
    assert missing.value.status_code == 404

    _write_registry(
        registry_file,
        pipelines={"default": default_entry, "blueberry": blueberry_entry},
    )
    assert manager.refresh_registry_if_needed() is True

    assert _resolve_blueberry() == "blueberry"
|
||||||
|
|
||||||
|
|
||||||
|
def test_refresh_registry_invalidates_cache_for_changed_pipeline(tmp_path):
    """Changing a pipeline's ``llm_name`` on disk yields a new pipeline object."""
    registry_file = tmp_path / "pipeline_registry.json"

    def _blueberry_with(llm_name):
        # Registry content with a single enabled "blueberry" pipeline.
        return {
            "blueberry": {
                "enabled": True,
                "config_file": None,
                "llm_name": llm_name,
            }
        }

    _write_registry(registry_file, pipelines=_blueberry_with("qwen-plus"))
    manager = ServerPipelineManager(
        default_pipeline_id="blueberry",
        default_config=_DummyConfig(),
    )
    manager.load_registry(str(registry_file))

    pipeline_before, model_before = manager.get_pipeline("blueberry")
    assert model_before == "qwen-plus"

    _write_registry(registry_file, pipelines=_blueberry_with("qwen-max"))
    assert manager.refresh_registry_if_needed() is True

    pipeline_after, model_after = manager.get_pipeline("blueberry")
    assert model_after == "qwen-max"
    # The previously returned pipeline object must not be handed out again.
    assert pipeline_after is not pipeline_before
|
||||||
|
|
||||||
|
|
||||||
|
def test_refresh_registry_applies_disabled_state_immediately(tmp_path):
    """Disabling a pipeline on disk makes ``get_pipeline`` raise 403 after refresh."""
    registry_file = tmp_path / "pipeline_registry.json"

    def _blueberry(enabled):
        # Registry content with a single "blueberry" pipeline.
        return {
            "blueberry": {
                "enabled": enabled,
                "config_file": None,
                "llm_name": "qwen-plus",
            }
        }

    _write_registry(registry_file, pipelines=_blueberry(True))
    manager = ServerPipelineManager(
        default_pipeline_id="blueberry",
        default_config=_DummyConfig(),
    )
    manager.load_registry(str(registry_file))
    # Fetch once while enabled, before the registry is changed on disk.
    manager.get_pipeline("blueberry")

    _write_registry(registry_file, pipelines=_blueberry(False))
    assert manager.refresh_registry_if_needed() is True

    with pytest.raises(HTTPException) as denied:
        manager.get_pipeline("blueberry")
    assert denied.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user