Compare commits: 728d5934d7...main (127 commits)
Commit SHA1s (author and date columns were not captured in this export):

737a80aa39 bdd4ddec9e 2ee55d25cc bf9ce709e2 59331d2435 a9488a655a
eeadd4c825 42d8b8e8e1 da6cc1e18b 691ef1c77d 9f9813f01d f9f3651c7e
87407e1656 2c226e2df2 8b2a506177 a26cda2f04 5e2a86e3be 2c7d5ea589
36e3b40936 86f6e4f81c 60f3029e54 fe7ff9a516 33faedc1b1 c9b1c5cb32
8f99d47af9 1d36f196ca 9f40ef93b7 fda3e86a71 d5303ad201 245be63f07
068df7715e 95498180a2 7b6ba79417 a228325d74 e59ae66b85 e524be15e5
2d90a4caac 4e8c2b4f5d 0ddeb91755 cb3b98db10 487562042f ac46518cf5
c03fb17a5c e049e4e104 8d0cd7861b 7920a1bab3 30e93ed0a7 0087ac41e6
5358b46113 3306da038d ac99dfd56b 3932d695bf 9ed141ba42 9b3db40b94
da17f2b319 4b6e97d8fb 112af37151 3cd46030ad dd842fca42 fc9f0f929d
28d99f4b8d dac067b6fe e90f0afabe 0676a68c9e f185b70d3f 07149e426e
3fc3d7288c eb7e85e4e6 ddfda10700 f8364bea68 01b0975abd 7e23d5c056
3b730798f8 2781172724 26fba706f2 ae93ef37b6 c1b782c6b4 ab3285a4cf
0484343021 b87fded473 8db22abf3b f6d86f24bb c1afebd7ba 080631af31
38b0d5df15 f7937c3744 867acaf717 a2890148f9 55b37cc611 c85598418d
ea605e19aa 866edc319f 8c6dd3344f f58d80ac1b 9e3fecd4bd 9b128ae41b
61931cad58 b070d08130 fe87bb55a1 d455933e9c f6547ca6fe 9c0744e1bc
2f40f1c526 cf1cae51f7 91685d5bf7 6ed33f3185 9284e8f928 bb6d98c9f4
9d1eeaeec5 6c26f92ff1 e3703d962d 9470b7e342 94b24682e2 2ce0075dcd
431aea410b 46978895d1 501f9954ce 021b4d6ffb 16916e183b 6097e9c24a
11a7809fc0 e85d0e7771 26fcdacbe8 6b0976f263 5f7cd64c11 4a713190b7
bda75d2558
.gitignore (vendored): 3 changes

@@ -11,4 +11,5 @@ django.log
 .env
 
 frontend/node_modules/
-frontend/dist/
+frontend/dist/
+frontend/.vite
README.md: 30 changes

@@ -2,6 +2,10 @@
 
 This is a FastAPI-based chat API service that uses OpenAI-format requests to call the pipeline.invoke method for chat.
+
+## Docker Installation
+
+For production deployment using Docker, see the [Installation Guide](README_INSTALL.md).
 
 ## Install dependencies
 
 ```bash
@@ -140,15 +144,35 @@ npm install
 
 ### Start the `front_apis` server
 
-The frontend talks to the `front_apis` FastAPI service, which by default listens on `http://127.0.0.1:8001`.
+The frontend talks to the `front_apis` FastAPI service, which by default listens on `http://127.0.0.1:8500`.
 
 From the project root:
 
 ```bash
-uvicorn fastapi_server.front_apis:app --reload --host 0.0.0.0 --port 8001
+uvicorn fastapi_server.front_apis:app --reload --host 0.0.0.0 --port 8500
 ```
 
-You can change the URL by setting `VITE_FRONT_API_BASE_URL` in `frontend/.env` (defaults to `http://127.0.0.1:8001`).
+Or run directly:
+```bash
+python fastapi_server/front_apis.py
+```
+
+### Backend run modes
+
+Run whichever backend mode you need from the project root:
+
+```bash
+# admin/control plane only (/v1/... frontend APIs)
+uvicorn fastapi_server.front_apis:app --reload --host 0.0.0.0 --port 8500
+
+# DashScope chat runtime only (/apps/... and /v1/apps/... APIs)
+uvicorn fastapi_server.server_dashscope:app --reload --host 0.0.0.0 --port 8588
+
+# combined mode: one process serves both front_apis + DashScope endpoints
+uvicorn fastapi_server.combined:app --reload --host 0.0.0.0 --port 8500
+```
+
+You can change the URL by setting `VITE_FRONT_API_BASE_URL` in `frontend/.env` (defaults to `/`, i.e. same-origin).
 
 ### Start the development server
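To see why the `/` default behaves as same-origin while an absolute base always targets one host, here is a small sketch using the standard library's `urljoin` (the URLs are illustrative, not taken from the repo):

```python
from urllib.parse import urljoin

# With VITE_FRONT_API_BASE_URL = "/", API paths resolve against whatever
# origin served the frontend page; with an absolute base they always
# target that fixed host regardless of where the page was loaded from.
page_origin = "http://localhost:8080/index.html"  # illustrative page URL

same_origin = urljoin(page_origin, "/v1/pipelines")
absolute = urljoin("http://127.0.0.1:8500", "/v1/pipelines")

print(same_origin)  # http://localhost:8080/v1/pipelines
print(absolute)     # http://127.0.0.1:8500/v1/pipelines
```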
README_INSTALL.md (new file): 267 lines

@@ -0,0 +1,267 @@
# Installation Guide

This guide explains how to install and run the LangChain Agent application using Docker.

## Prerequisites

- Docker (version 20.10 or later)
- Docker Compose (version 2.0 or later, or use the `docker compose` command)

## Quick Start

1. **Run the installation script:**
   ```bash
   ./scripts/shell_scripts/install.sh
   ```

   This script will:
   - Check for required tools (Docker, docker-compose)
   - Create a `.env` file with default configuration
   - Build Docker images (or use pre-loaded images)
   - Start all services (PostgreSQL, Backend API, Nginx)

2. **Access the application:**
   - Frontend: http://localhost (or http://localhost:80)
   - Backend API: http://localhost:8500
   - Database: localhost:5432

## Installation for China / Offline Use

If Docker Hub is slow or inaccessible in your region:

### Option 1: Use Chinese Docker Mirrors

Configure Docker to use Chinese registry mirrors:

```bash
sudo tee /etc/docker/daemon.json <<EOF
{
  "registry-mirrors": [
    "https://registry.docker-cn.com",
    "https://mirror.ccsogou.com",
    "https://docker.1ms.run"
  ]
}
EOF
sudo systemctl daemon-reload
sudo systemctl restart docker
```

Then run `./scripts/shell_scripts/install.sh`.

### Configuring `.env` and `frontend/.env`

```bash
ALI_API_KEY="API_KEY_FOR_ALI_QWEN"
ALI_BASE_URL="https://dashscope.aliyuncs.com/compatible-mode/v1"
POSTGRES_ROOT_PASSWORD="ROOT_PASSWORD_FOR_DB_IN_CONN_STR"  # required for installation
POSTGRES_PASSWORD="USER_PASSWORD_FOR_DB_CONN_STR"          # required for installation; must match the password embedded in CONN_STR
CONN_STR="CONNECTION_STRING_TO_DATABASE"                   # DOCKER PASSWORD
FAST_AUTH_KEYS="API_KEY_FOR_OTHER_APPLICATIONS_TO_USE_BUILT_PIPELINE"
DAYTONA_API_KEY="DAYTONA_CONFIG"                           # not required
```
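As an illustration of how the `POSTGRES_*` values and `CONN_STR` fit together, a minimal sketch (the `build_conn_str` helper and the `POSTGRES_HOST` variable are hypothetical, not part of the repo; defaults mirror the Configuration section of this guide):

```python
import os

def build_conn_str() -> str:
    """Assemble a PostgreSQL connection string from .env-style variables.

    Hypothetical helper: POSTGRES_HOST is an assumed extra variable; the
    other names and defaults follow this guide's Configuration section.
    """
    user = os.environ.get("POSTGRES_USER", "myapp_user")
    password = os.environ.get("POSTGRES_PASSWORD", "secure_password_123")
    host = os.environ.get("POSTGRES_HOST", "localhost")
    port = os.environ.get("POSTGRES_PORT", "5432")
    db = os.environ.get("POSTGRES_DB", "ai_conversations")
    return f"postgresql://{user}:{password}@{host}:{port}/{db}"
```

This is why `POSTGRES_PASSWORD` must match the password embedded in `CONN_STR`: both describe the same database credential.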
### Option 2: Pre-load Docker Images Offline

1. On a machine with good Docker Hub access, run:
   ```bash
   ./scripts/shell_scripts/download_images.sh
   ```
   This creates `images.tar` with all required images.

2. Transfer `images.tar` to your target machine.

3. Load the images:
   ```bash
   docker load < images.tar
   ```

4. Run the install script:
   ```bash
   ./scripts/shell_scripts/install.sh
   ```

## Manual Installation

If you prefer to set up manually:

1. **Create environment file:**
   ```bash
   cp .env.example .env  # Edit as needed
   ```

2. **Build and start services:**
   ```bash
   cd docker
   docker compose -f docker-compose.prod.yml up -d --build
   ```

3. **Check service status:**
   ```bash
   cd docker
   docker compose -f docker-compose.prod.yml ps
   ```

## Configuration

Edit the `.env` file to customize:

- `POSTGRES_DB`: Database name (default: `ai_conversations`)
- `POSTGRES_USER`: Database user (default: `myapp_user`)
- `POSTGRES_PASSWORD`: Database password (default: `secure_password_123`)
- `POSTGRES_PORT`: PostgreSQL port (default: `5432`)
- `BACKEND_PORT`: Backend API port (default: `8500`)
- `FRONTEND_PORT`: Frontend web server port (default: `80`)
## Database Initialization

The database is automatically initialized when the PostgreSQL container starts for the first time. The following scripts are executed in order:

1. `scripts/init_database/00_init_user.sh` - Creates the database user and database
2. `scripts/init_database/create_conv_store.sql` - Creates conversation storage tables
3. `scripts/init_database/create_prompt_config.sql` - Creates prompt configuration tables
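The official postgres image runs everything in `/docker-entrypoint-initdb.d` in sorted (lexicographic) order, which is why the user-creation script carries the `00_` prefix. A quick sketch of that ordering, using the file names listed above:

```python
# The postgres entrypoint executes init files in sorted order; the 00_
# prefix guarantees the user and database exist before the table-creation
# SQL scripts run.
init_scripts = [
    "create_conv_store.sql",
    "00_init_user.sh",
    "create_prompt_config.sql",
]
execution_order = sorted(init_scripts)
print(execution_order)
# ['00_init_user.sh', 'create_conv_store.sql', 'create_prompt_config.sql']
```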
## Service Management

All commands run from the `docker/` directory:

### View logs:
```bash
cd docker
docker compose -f docker-compose.prod.yml logs -f

# Specific service
docker compose -f docker-compose.prod.yml logs -f backend
docker compose -f docker-compose.prod.yml logs -f postgres
docker compose -f docker-compose.prod.yml logs -f nginx
```

### Stop services:
```bash
cd docker
docker compose -f docker-compose.prod.yml down
```

### Restart services:
```bash
cd docker
docker compose -f docker-compose.prod.yml restart
```

### Rebuild after code changes:
```bash
cd docker
docker compose -f docker-compose.prod.yml up -d --build
```

### Reset database (delete all data):
```bash
cd docker
docker compose -f docker-compose.prod.yml down -v
docker compose -f docker-compose.prod.yml up -d
```

## Architecture

The application consists of three main services:

1. **PostgreSQL** (`postgres`): Database server
   - Stores conversations and prompt configurations
   - Automatically initializes the schema on first run

2. **Backend** (`backend`): FastAPI application
   - Serves API endpoints on port 8500
   - Handles agent management and chat endpoints
   - Connects to the PostgreSQL database

3. **Nginx** (`nginx`): Web server
   - Serves the React frontend (port 80)
   - Proxies API requests to the backend
   - Handles static file serving

## Project Structure

```
langchain-agent/
├── docker/
│   ├── docker-compose.prod.yml   # Production compose file
│   └── Dockerfile.prod           # Backend Docker image
├── scripts/
│   ├── shell_scripts/
│   │   ├── install.sh            # Main installation script
│   │   └── download_images.sh    # For offline image download
│   └── init_database/            # Database initialization scripts
├── frontend/                     # React frontend
├── configs/                      # Pipeline configurations
├── nginx.conf                    # Nginx configuration
└── .env                          # Environment variables
```

## Troubleshooting

### Database connection issues

If the backend can't connect to the database:

1. Check that PostgreSQL is running:
   ```bash
   docker compose -f docker-compose.prod.yml ps postgres
   ```

2. Verify that the connection string in `.env` matches the database configuration.

3. Check the backend logs:
   ```bash
   docker compose -f docker-compose.prod.yml logs backend
   ```

### Frontend not loading / NetworkError

1. Check the nginx logs:
   ```bash
   docker compose -f docker-compose.prod.yml logs nginx
   ```

2. Ensure the frontend is built with the correct API base URL. The `frontend/.env` file should contain:
   ```
   VITE_FRONT_API_BASE_URL=/
   ```
   Then rebuild: `docker compose -f docker-compose.prod.yml build backend`

### Port conflicts

If ports are already in use, update the port mappings in `.env`:

```bash
# Example: use port 5433 for PostgreSQL
POSTGRES_PORT=5433
```

## Development

For development, you may want to run services separately:

1. Start only PostgreSQL:
   ```bash
   cd docker
   docker compose -f docker-compose.prod.yml up -d postgres
   ```

2. Run the backend locally:
   ```bash
   export CONN_STR="postgresql://myapp_user:secure_password_123@localhost:5432/ai_conversations"
   python -m uvicorn lang_agent.fastapi_server.combined:app --reload --host 0.0.0.0 --port 8500
   ```

3. Run the frontend locally:
   ```bash
   cd frontend
   npm install
   npm run dev
   ```

Note: For local frontend development, create a `.env` file in `frontend/` with:
```
VITE_FRONT_API_BASE_URL=http://localhost:8500
```
configs/pipeline_registry.json (new file): 31 lines

@@ -0,0 +1,31 @@
{
  "pipelines": {
    "xiaozhan": {
      "enabled": true,
      "config_file": "configs/pipelines/xiaozhan.yaml",
      "graph_id": "routing",
      "overrides": {
        "llm_name": "qwen-plus"
      }
    },
    "blueberry": {
      "enabled": true,
      "config_file": "configs/pipelines/blueberry.yaml",
      "graph_id": "react",
      "overrides": {
        "llm_name": "qwen-plus"
      }
    }
  },
  "api_keys": {
    "sk-6c7091e6a95f404efb2ec30e8f51b897626d670375cdf822d78262f24ab12367": {
      "example-key-1": {
        "default_route_id": "default",
        "allowed_route_ids": [
          "xiaozhan",
          "blueberry"
        ]
      }
    }
  }
}
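A registry shaped like the file above maps an API key to named client entries, each listing the route IDs it may use. How the repo actually consumes this file is not shown in the diff; the following is only an illustrative sketch of one plausible lookup (abridged registry inline, `allowed_pipelines` is a hypothetical helper):

```python
import json

# Registry shaped like configs/pipeline_registry.json above (values abridged).
REGISTRY = json.loads("""
{
  "pipelines": {
    "xiaozhan": {"enabled": true, "graph_id": "routing"},
    "blueberry": {"enabled": true, "graph_id": "react"}
  },
  "api_keys": {
    "sk-example": {
      "example-key-1": {
        "default_route_id": "default",
        "allowed_route_ids": ["xiaozhan", "blueberry"]
      }
    }
  }
}
""")

def allowed_pipelines(registry: dict, api_key: str) -> list:
    """Return enabled pipelines the given API key may route to.

    Illustrative only: the repo's real lookup logic is not in the diff.
    """
    entries = registry.get("api_keys", {}).get(api_key, {})
    allowed = set()
    for client_cfg in entries.values():
        allowed.update(client_cfg.get("allowed_route_ids", []))
    pipelines = registry.get("pipelines", {})
    return sorted(p for p in allowed if pipelines.get(p, {}).get("enabled"))

print(allowed_pipelines(REGISTRY, "sk-example"))  # ['blueberry', 'xiaozhan']
```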
@@ -1,3 +1,7 @@
-you are a helpful bot enhanced with skills.
+You are a helpful bot enhanced with skills.
 To use a skill, read its SKILL.md file using the read_file tool. Skills are NOT tools; they are instructions for using existing tools.
 Skills with available="false" need dependencies installed first; you can try installing them with apt/brew. You can check whether the environment has the packages you need.
+
+When using a skill, assume required tools (e.g., npx, curl) are available and execute the commands directly. If a command fails because a tool is missing, install the missing dependency using apt/brew and retry.
+
+For shell commands (e.g., npx, curl), use the execute tool to run them.
docker/Dockerfile.frontend (new file): 19 lines

@@ -0,0 +1,19 @@
FROM node:20-alpine

WORKDIR /app

RUN npm config set registry https://registry.npmmirror.com

# Build-time API base for Vite (must be set before npm run build).
ARG VITE_FRONT_API_BASE_URL=/
ENV VITE_FRONT_API_BASE_URL=${VITE_FRONT_API_BASE_URL}

COPY package*.json ./
RUN npm install

COPY . .
RUN npm run build && \
    mkdir -p /opt/frontend_dist && \
    cp -r dist/. /opt/frontend_dist/

CMD ["sh", "-c", "rm -rf /app/dist/* && cp -r /opt/frontend_dist/. /app/dist && ls /app/dist"]
docker/Dockerfile.prod (new file): 93 lines

@@ -0,0 +1,93 @@
# Multi-stage Dockerfile for production deployment
# Stage 1: Build frontend
FROM node:20-alpine AS frontend-builder

WORKDIR /app/frontend

# Copy frontend files
COPY frontend/package*.json ./
RUN npm ci

COPY frontend/ ./
RUN npm run build

# Stage 2: Python backend
FROM python:3.12-slim

WORKDIR /app

# Install system dependencies
RUN set -eux; \
    for source_file in /etc/apt/sources.list /etc/apt/sources.list.d/*.list /etc/apt/sources.list.d/*.sources; do \
        if [ -f "$source_file" ]; then \
            sed -i 's|deb.debian.org|mirrors.aliyun.com|g' "$source_file"; \
            sed -i 's|security.debian.org|mirrors.aliyun.com|g' "$source_file"; \
        fi; \
    done; \
    apt-get update; \
    apt-get install -y --no-install-recommends \
        postgresql-client \
        curl; \
    rm -rf /var/lib/apt/lists/*

# Copy Python dependencies
COPY pyproject.toml ./
RUN pip install --no-cache-dir --upgrade pip -i https://pypi.tuna.tsinghua.edu.cn/simple && \
    pip install --no-cache-dir -e . -i https://pypi.tuna.tsinghua.edu.cn/simple

# Copy application code
COPY lang_agent/ ./lang_agent/
COPY configs/ ./configs/
COPY scripts/ ./scripts/
COPY assets/ ./assets/
COPY static/ ./static/

# Copy built frontend from stage 1
COPY --from=frontend-builder /app/frontend/dist ./frontend/dist

# Set environment variables
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1

# Expose port
EXPOSE 8500

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD curl -f http://localhost:8500/health || exit 1

# Create entrypoint script to wait for DB
# Uses Python to check database connection (more reliable than psql)
RUN echo '#!/bin/bash\n\
set -e\n\
echo "Waiting for database to be ready..."\n\
python3 << EOF\n\
import sys\n\
import time\n\
import psycopg\n\
\n\
max_attempts = 30\n\
conn_str = "${CONN_STR}"\n\
\n\
for i in range(max_attempts):\n\
    try:\n\
        with psycopg.connect(conn_str, connect_timeout=2) as conn:\n\
            with conn.cursor() as cur:\n\
                cur.execute("SELECT 1")\n\
        print("Database is ready!")\n\
        sys.exit(0)\n\
    except Exception as e:\n\
        if i == max_attempts - 1:\n\
            print(f"Warning: Database not ready after {max_attempts * 2} seconds, continuing anyway...")\n\
            print(f"Error: {e}")\n\
            sys.exit(0)\n\
        print(f"Database is unavailable - sleeping (attempt {i+1}/{max_attempts})")\n\
        time.sleep(2)\n\
EOF\n\
exec "$@"' > /entrypoint.sh && chmod +x /entrypoint.sh

ENTRYPOINT ["/entrypoint.sh"]

# Run the combined server
CMD ["python", "-m", "uvicorn", "lang_agent.fastapi_server.combined:app", "--host", "0.0.0.0", "--port", "8500"]
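The entrypoint's wait-for-DB logic boils down to a bounded retry loop. A sketch of that pattern with an injectable `connect` callable so it can run without a database (the helper name and signature are hypothetical; the real entrypoint calls `psycopg.connect`):

```python
import time

def wait_for_db(connect, max_attempts: int = 30, delay: float = 0.0) -> bool:
    """Retry connect() until it succeeds or attempts run out.

    Mirrors the entrypoint script's loop; `connect` stands in for
    psycopg.connect so the pattern is testable offline. The real script
    sleeps 2s between attempts and continues anyway after the warning.
    """
    for attempt in range(1, max_attempts + 1):
        try:
            connect()
            return True
        except Exception:
            if attempt == max_attempts:
                return False
            time.sleep(delay)

# Simulate a database that becomes reachable on the third attempt:
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise ConnectionError("db not ready")

print(wait_for_db(flaky, max_attempts=5))  # True
```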
docker/docker-compose.prod.yml (new file): 100 lines

@@ -0,0 +1,100 @@
version: '3.8'

services:
  # PostgreSQL database
  postgres:
    image: postgres:16-alpine
    container_name: langchain-agent-db
    networks:
      - app-network
    environment:
      POSTGRES_DB: postgres
      POSTGRES_USER: postgres
      POSTGRES_PASSWORD: ${POSTGRES_ROOT_PASSWORD:-postgres_root_password}
      # These are used by init scripts to create the app database
      APP_DB_NAME: ${POSTGRES_DB:-ai_conversations}
      APP_DB_USER: ${POSTGRES_USER:-myapp_user}
      APP_DB_PASSWORD: ${POSTGRES_PASSWORD:-secure_password_123}
    volumes:
      - postgres_data:/var/lib/postgresql/data
      - ../scripts/init_database:/docker-entrypoint-initdb.d
    ports:
      - "${POSTGRES_PORT:-5434}:5432"
    healthcheck:
      test: ["CMD-SHELL", "pg_isready -U postgres"]
      interval: 10s
      timeout: 5s
      retries: 5
    restart: no #unless-stopped

  # Backend API server
  backend:
    build:
      context: ..
      dockerfile: docker/Dockerfile.prod
    container_name: langchain-agent-backend
    environment:
      - PYTHONPATH=/app
      - PYTHONUNBUFFERED=1
      - CONN_STR=postgresql://${POSTGRES_USER:-myapp_user}:${POSTGRES_PASSWORD:-secure_password_123}@postgres:5432/${POSTGRES_DB:-ai_conversations}
      - POSTGRES_USER=${POSTGRES_USER:-myapp_user}
      - POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-secure_password_123}
      - POSTGRES_DB=${POSTGRES_DB:-ai_conversations}
    ports:
      - "${BACKEND_PORT:-8500}:8500"
    volumes:
      - ../configs:/app/configs
      - ../scripts:/app/scripts
      - ../assets:/app/assets
      - ../static:/app/static
    networks:
      - app-network
    depends_on:
      postgres:
        condition: service_healthy
    restart: no #unless-stopped
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8500/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

  # Frontend build service
  frontend:
    build:
      context: ../frontend
      dockerfile: ../docker/Dockerfile.frontend
      args:
        VITE_FRONT_API_BASE_URL: ${VITE_FRONT_API_BASE_URL:-/}
    volumes:
      - frontend_dist:/app/dist
    networks:
      - app-network

  # Nginx for serving frontend (optional - can also serve via FastAPI)
  nginx:
    image: nginx:alpine
    container_name: langchain-agent-nginx
    networks:
      - app-network
    ports:
      - "${FRONTEND_PORT:-8080}:80"
    volumes:
      - ../nginx.conf:/etc/nginx/nginx.conf:ro
      - frontend_dist:/usr/share/nginx/html:ro
    depends_on:
      frontend:
        condition: service_completed_successfully
      backend:
        condition: service_started
    restart: no #unless-stopped

volumes:
  postgres_data:
  frontend_dist:

networks:
  app-network:
    driver: bridge
Deleted file (444 lines removed):

@@ -1,444 +0,0 @@
from typing import Dict, List, Optional
import commentjson
import os
import os.path as osp
import secrets
import subprocess
import sys
import uuid

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

# Ensure we can import from project root.
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))

from lang_agent.config.db_config_manager import DBConfigManager
from lang_agent.front_api.build_server import GRAPH_BUILD_FNCS

_PROJECT_ROOT = osp.dirname(osp.dirname(osp.abspath(__file__)))
_MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json")
_MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n"


class GraphConfigUpsertRequest(BaseModel):
    graph_id: str
    pipeline_id: str
    prompt_set_id: Optional[str] = Field(default=None)
    tool_keys: List[str] = Field(default_factory=list)
    prompt_dict: Dict[str, str] = Field(default_factory=dict)
    api_key: Optional[str] = Field(default=None)


class GraphConfigUpsertResponse(BaseModel):
    graph_id: str
    pipeline_id: str
    prompt_set_id: str
    tool_keys: List[str]
    prompt_keys: List[str]
    api_key: str


class GraphConfigReadResponse(BaseModel):
    graph_id: Optional[str] = Field(default=None)
    pipeline_id: str
    prompt_set_id: str
    tool_keys: List[str]
    prompt_dict: Dict[str, str]
    api_key: str = Field(default="")


class GraphConfigListItem(BaseModel):
    graph_id: Optional[str] = Field(default=None)
    pipeline_id: str
    prompt_set_id: str
    name: str
    description: str
    is_active: bool
    tool_keys: List[str]
    api_key: str = Field(default="")
    created_at: Optional[str] = Field(default=None)
    updated_at: Optional[str] = Field(default=None)


class GraphConfigListResponse(BaseModel):
    items: List[GraphConfigListItem]
    count: int


class PipelineCreateRequest(BaseModel):
    graph_id: str = Field(
        description="Graph key from GRAPH_BUILD_FNCS, e.g. routing or react"
    )
    pipeline_id: str
    prompt_set_id: str
    tool_keys: List[str] = Field(default_factory=list)
    port: int
    api_key: str
    entry_point: str = Field(default="fastapi_server/server_dashscope.py")
    llm_name: str = Field(default="qwen-plus")


class PipelineCreateResponse(BaseModel):
    run_id: str
    pid: int
    graph_id: str
    pipeline_id: str
    prompt_set_id: str
    url: str
    port: int
    auth_type: str
    auth_header_name: str
    auth_key_once: str
    auth_key_masked: str


class PipelineRunInfo(BaseModel):
    run_id: str
    pid: int
    graph_id: str
    pipeline_id: str
    prompt_set_id: str
    url: str
    port: int
    auth_type: str
    auth_header_name: str
    auth_key_masked: str


class PipelineListResponse(BaseModel):
    items: List[PipelineRunInfo]
    count: int


class PipelineStopResponse(BaseModel):
    run_id: str
    status: str


class McpConfigReadResponse(BaseModel):
    path: str
    raw_content: str
    tool_keys: List[str]


class McpConfigUpdateRequest(BaseModel):
    raw_content: str


class McpConfigUpdateResponse(BaseModel):
    status: str
    path: str
    tool_keys: List[str]


app = FastAPI(
    title="Front APIs",
    description="Manage graph configs and launch graph pipelines.",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

_db = DBConfigManager()
_running_pipelines: Dict[str, Dict[str, object]] = {}


def _generate_auth_key() -> str:
    return f"agk_{secrets.token_urlsafe(24)}"


def _mask_auth_key(value: str) -> str:
    if not value:
        return ""
    if len(value) <= 10:
        return value
    return f"{value[:5]}...{value[-5:]}"


def _prune_stopped_pipelines() -> None:
    stale_ids: List[str] = []
    for run_id, info in _running_pipelines.items():
        proc = info["proc"]
        if proc.poll() is not None:
            stale_ids.append(run_id)
    for run_id in stale_ids:
        _running_pipelines.pop(run_id, None)


@app.get("/health")
async def health():
    return {"status": "healthy"}


@app.get("/")
async def root():
    return {
        "message": "Front APIs",
        "endpoints": [
            "/v1/graph-configs (POST)",
            "/v1/graph-configs (GET)",
            "/v1/graph-configs/default/{pipeline_id} (GET)",
            "/v1/graphs/{graph_id}/default-config (GET)",
            "/v1/graph-configs/{pipeline_id}/{prompt_set_id} (GET)",
            "/v1/graph-configs/{pipeline_id}/{prompt_set_id} (DELETE)",
            "/v1/pipelines/graphs (GET)",
            "/v1/pipelines (POST)",
            "/v1/pipelines (GET)",
            "/v1/pipelines/{run_id} (DELETE)",
            "/v1/tool-configs/mcp (GET)",
            "/v1/tool-configs/mcp (PUT)",
        ],
    }


def _parse_mcp_tool_keys(raw_content: str) -> List[str]:
    parsed = commentjson.loads(raw_content or "{}")
    if not isinstance(parsed, dict):
        raise ValueError("mcp_config must be a JSON object at top level")
    return sorted(str(key) for key in parsed.keys())


def _read_mcp_config_raw() -> str:
    if not osp.exists(_MCP_CONFIG_PATH):
        os.makedirs(osp.dirname(_MCP_CONFIG_PATH), exist_ok=True)
        with open(_MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
            f.write(_MCP_CONFIG_DEFAULT_CONTENT)
    with open(_MCP_CONFIG_PATH, "r", encoding="utf-8") as f:
        return f.read()


@app.post("/v1/graph-configs", response_model=GraphConfigUpsertResponse)
async def upsert_graph_config(body: GraphConfigUpsertRequest):
    try:
        resolved_prompt_set_id = _db.set_config(
            graph_id=body.graph_id,
            pipeline_id=body.pipeline_id,
            prompt_set_id=body.prompt_set_id,
            tool_list=body.tool_keys,
            prompt_dict=body.prompt_dict,
            api_key=body.api_key,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return GraphConfigUpsertResponse(
        graph_id=body.graph_id,
        pipeline_id=body.pipeline_id,
        prompt_set_id=resolved_prompt_set_id,
        tool_keys=body.tool_keys,
        prompt_keys=list(body.prompt_dict.keys()),
        api_key=(body.api_key or "").strip(),
    )


@app.get("/v1/graph-configs", response_model=GraphConfigListResponse)
async def list_graph_configs(pipeline_id: Optional[str] = None, graph_id: Optional[str] = None):
    try:
        rows = _db.list_prompt_sets(pipeline_id=pipeline_id, graph_id=graph_id)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    items = [GraphConfigListItem(**row) for row in rows]
    return GraphConfigListResponse(items=items, count=len(items))


@app.get("/v1/graph-configs/default/{pipeline_id}", response_model=GraphConfigReadResponse)
async def get_default_graph_config(pipeline_id: str):
    try:
        prompt_dict, tool_keys = _db.get_config(pipeline_id=pipeline_id, prompt_set_id=None)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    if not prompt_dict and not tool_keys:
        raise HTTPException(
            status_code=404,
            detail=f"No active prompt set found for pipeline '{pipeline_id}'",
        )

    rows = _db.list_prompt_sets(pipeline_id=pipeline_id)
    active = next((row for row in rows if row["is_active"]), None)
    if active is None:
        raise HTTPException(
            status_code=404,
            detail=f"No active prompt set found for pipeline '{pipeline_id}'",
        )

    return GraphConfigReadResponse(
        graph_id=active.get("graph_id"),
        pipeline_id=pipeline_id,
        prompt_set_id=active["prompt_set_id"],
        tool_keys=tool_keys,
        prompt_dict=prompt_dict,
        api_key=(active.get("api_key") or ""),
    )


@app.get("/v1/graphs/{graph_id}/default-config", response_model=GraphConfigReadResponse)
async def get_graph_default_config_by_graph(graph_id: str):
    return await get_default_graph_config(pipeline_id=graph_id)


@app.get("/v1/graph-configs/{pipeline_id}/{prompt_set_id}", response_model=GraphConfigReadResponse)
async def get_graph_config(pipeline_id: str, prompt_set_id: str):
    try:
        meta = _db.get_prompt_set(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
        if meta is None:
            raise HTTPException(
                status_code=404,
                detail=f"prompt_set_id '{prompt_set_id}' not found for pipeline '{pipeline_id}'",
            )
        prompt_dict, tool_keys = _db.get_config(
            pipeline_id=pipeline_id,
            prompt_set_id=prompt_set_id,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return GraphConfigReadResponse(
        graph_id=meta.get("graph_id"),
        pipeline_id=pipeline_id,
        prompt_set_id=prompt_set_id,
        tool_keys=tool_keys,
        prompt_dict=prompt_dict,
        api_key=(meta.get("api_key") or ""),
    )


@app.delete("/v1/graph-configs/{pipeline_id}/{prompt_set_id}")
async def delete_graph_config(pipeline_id: str, prompt_set_id: str):
    try:
        _db.remove_config(pipeline_id=pipeline_id, prompt_set_id=prompt_set_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

    return {
        "status": "deleted",
        "pipeline_id": pipeline_id,
        "prompt_set_id": prompt_set_id,
    }


@app.get("/v1/pipelines/graphs")
async def available_graphs():
    return {"available_graphs": sorted(GRAPH_BUILD_FNCS.keys())}


@app.get("/v1/tool-configs/mcp", response_model=McpConfigReadResponse)
async def get_mcp_tool_config():
    try:
        raw_content = _read_mcp_config_raw()
        tool_keys = _parse_mcp_tool_keys(raw_content)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return McpConfigReadResponse(
        path=_MCP_CONFIG_PATH,
        raw_content=raw_content,
        tool_keys=tool_keys,
    )


@app.put("/v1/tool-configs/mcp", response_model=McpConfigUpdateResponse)
async def update_mcp_tool_config(body: McpConfigUpdateRequest):
    try:
        tool_keys = _parse_mcp_tool_keys(body.raw_content)
        os.makedirs(osp.dirname(_MCP_CONFIG_PATH), exist_ok=True)
        with open(_MCP_CONFIG_PATH, "w", encoding="utf-8") as f:
            # Keep user formatting/comments as entered while ensuring trailing newline.
            f.write(body.raw_content.rstrip() + "\n")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
    return McpConfigUpdateResponse(
        status="updated",
        path=_MCP_CONFIG_PATH,
        tool_keys=tool_keys,
    )


@app.get("/v1/pipelines", response_model=PipelineListResponse)
async def list_running_pipelines():
|
||||
_prune_stopped_pipelines()
|
||||
items = [
|
||||
PipelineRunInfo(
|
||||
run_id=run_id,
|
||||
pid=info["proc"].pid,
|
||||
graph_id=info["graph_id"],
|
||||
pipeline_id=info["pipeline_id"],
|
||||
prompt_set_id=info["prompt_set_id"],
|
||||
url=info["url"],
|
||||
port=info["port"],
|
||||
auth_type="bearer",
|
||||
auth_header_name="Authorization",
|
||||
auth_key_masked=info.get("auth_key_masked", ""),
|
||||
)
|
||||
for run_id, info in _running_pipelines.items()
|
||||
]
|
||||
return PipelineListResponse(items=items, count=len(items))
|
||||
|
||||
|
||||
@app.post("/v1/pipelines", response_model=PipelineCreateResponse)
|
||||
async def create_pipeline(body: PipelineCreateRequest):
|
||||
build_fn = GRAPH_BUILD_FNCS.get(body.graph_id)
|
||||
if build_fn is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unknown graph_id '{body.graph_id}'. Valid options: {sorted(GRAPH_BUILD_FNCS.keys())}",
|
||||
)
|
||||
|
||||
auth_key = _generate_auth_key()
|
||||
auth_key_masked = _mask_auth_key(auth_key)
|
||||
try:
|
||||
proc, url = build_fn(
|
||||
pipeline_id=body.pipeline_id,
|
||||
prompt_set=body.prompt_set_id,
|
||||
tool_keys=body.tool_keys,
|
||||
port=str(body.port),
|
||||
api_key=body.api_key,
|
||||
fast_auth_keys=auth_key,
|
||||
entry_pnt=body.entry_point,
|
||||
llm_name=body.llm_name,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to start pipeline: {e}")
|
||||
|
||||
run_id = str(uuid.uuid4())
|
||||
_running_pipelines[run_id] = {
|
||||
"proc": proc,
|
||||
"graph_id": body.graph_id,
|
||||
"pipeline_id": body.pipeline_id,
|
||||
"prompt_set_id": body.prompt_set_id,
|
||||
"url": url,
|
||||
"port": body.port,
|
||||
"auth_key_masked": auth_key_masked,
|
||||
}
|
||||
|
||||
return PipelineCreateResponse(
|
||||
run_id=run_id,
|
||||
pid=proc.pid,
|
||||
graph_id=body.graph_id,
|
||||
pipeline_id=body.pipeline_id,
|
||||
prompt_set_id=body.prompt_set_id,
|
||||
url=url,
|
||||
port=body.port,
|
||||
auth_type="bearer",
|
||||
auth_header_name="Authorization",
|
||||
auth_key_once=auth_key,
|
||||
auth_key_masked=auth_key_masked,
|
||||
)
|
||||
|
||||
@app.delete("/v1/pipelines/{run_id}", response_model=PipelineStopResponse)
|
||||
async def stop_pipeline(run_id: str):
|
||||
info = _running_pipelines.pop(run_id, None)
|
||||
if info is None:
|
||||
raise HTTPException(status_code=404, detail=f"run_id '{run_id}' not found")
|
||||
|
||||
proc = info["proc"]
|
||||
if proc.poll() is None:
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=5)
|
||||
except subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
proc.wait(timeout=5)
|
||||
|
||||
return PipelineStopResponse(run_id=run_id, status="stopped")
|
||||
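The stop endpoint uses a standard terminate-then-kill sequence: send SIGTERM, wait briefly, escalate to SIGKILL if the child ignores it. A minimal standalone sketch of that pattern (`graceful_stop` is a hypothetical helper, not part of the service):

```python
import subprocess
import sys

def graceful_stop(proc: subprocess.Popen, grace_seconds: float = 5.0):
    """Send SIGTERM, wait up to grace_seconds, escalate to SIGKILL if needed."""
    if proc.poll() is None:
        proc.terminate()
        try:
            proc.wait(timeout=grace_seconds)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait(timeout=grace_seconds)
    return proc.returncode

# Spawn a long-sleeping child and stop it.
child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])
returncode = graceful_stop(child)
print(returncode is not None)  # True
```

Calling `poll()` first mirrors the endpoint's guard: an already-exited process is not signalled again.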
frontend/assets/images/graph_arch/deepagent.png — new binary file (28 KiB), not shown.

frontend/package-lock.json — 1882 lines changed (generated; diff suppressed because it is too large).
@@ -6,18 +6,22 @@
   "scripts": {
     "dev": "vite",
     "build": "tsc -b && vite build",
-    "preview": "vite preview"
+    "preview": "vite preview",
+    "test": "vitest run"
   },
   "dependencies": {
     "react": "^18.3.1",
-    "react-dom": "^18.3.1"
+    "react-dom": "^18.3.1",
+    "react-markdown": "^10.1.0",
+    "remark-gfm": "^4.0.1"
   },
   "devDependencies": {
     "@types/react": "^18.3.20",
     "@types/react-dom": "^18.3.6",
     "@vitejs/plugin-react": "^4.3.1",
     "typescript": "^5.6.3",
-    "vite": "^5.4.10"
+    "vite": "^5.4.10",
+    "vitest": "^2.1.8"
   }
 }
frontend/src/App.tsx — 1543 lines changed (diff suppressed because it is too large).

frontend/src/activeConfigSelection.test.ts — new file (79 lines):
@@ -0,0 +1,79 @@
import { describe, expect, it } from "vitest";

import { chooseActiveConfigItem, chooseDisplayItemsByPipeline } from "./activeConfigSelection";
import type { GraphConfigListItem } from "./types";

const mk = (patch: Partial<GraphConfigListItem>): GraphConfigListItem => ({
  graph_id: "routing",
  pipeline_id: "agent-a",
  prompt_set_id: "set-1",
  name: "default",
  description: "",
  is_active: false,
  tool_keys: [],
  api_key: "",
  created_at: null,
  updated_at: null,
  ...patch,
});

describe("chooseActiveConfigItem", () => {
  it("prefers active item over newer inactive items", () => {
    const items = [
      mk({
        pipeline_id: "agent-a",
        prompt_set_id: "old-active",
        is_active: true,
        updated_at: "2025-01-01T00:00:00Z",
      }),
      mk({
        pipeline_id: "agent-a",
        prompt_set_id: "new-inactive",
        is_active: false,
        updated_at: "2025-03-01T00:00:00Z",
      }),
    ];
    const selected = chooseActiveConfigItem(items, "agent-a");
    expect(selected?.prompt_set_id).toBe("old-active");
  });

  it("falls back to latest updated_at when no active item exists", () => {
    const items = [
      mk({
        pipeline_id: "agent-b",
        prompt_set_id: "set-1",
        updated_at: "2025-01-01T00:00:00Z",
      }),
      mk({
        pipeline_id: "agent-b",
        prompt_set_id: "set-2",
        updated_at: "2025-02-01T00:00:00Z",
      }),
    ];
    const selected = chooseActiveConfigItem(items, "agent-b");
    expect(selected?.prompt_set_id).toBe("set-2");
  });
});

describe("chooseDisplayItemsByPipeline", () => {
  it("returns one selected item per pipeline_id", () => {
    const items = [
      mk({ pipeline_id: "agent-b", prompt_set_id: "set-1", updated_at: "2025-01-01T00:00:00Z" }),
      mk({
        pipeline_id: "agent-b",
        prompt_set_id: "set-2",
        is_active: true,
        updated_at: "2025-02-01T00:00:00Z",
      }),
      mk({
        pipeline_id: "agent-a",
        prompt_set_id: "set-3",
        updated_at: "2025-03-01T00:00:00Z",
      }),
    ];
    const selected = chooseDisplayItemsByPipeline(items);
    expect(selected.map((x) => x.pipeline_id)).toEqual(["agent-a", "agent-b"]);
    expect(selected.find((x) => x.pipeline_id === "agent-b")?.prompt_set_id).toBe("set-2");
  });
});
frontend/src/activeConfigSelection.ts — new file (45 lines):
@@ -0,0 +1,45 @@
import type { GraphConfigListItem } from "./types";

function toTimestamp(value?: string | null): number {
  if (!value) {
    return 0;
  }
  const parsed = Date.parse(value);
  return Number.isNaN(parsed) ? 0 : parsed;
}

export function chooseActiveConfigItem(
  items: GraphConfigListItem[],
  pipelineId: string
): GraphConfigListItem | null {
  const candidates = items.filter((item) => item.pipeline_id === pipelineId);
  if (candidates.length === 0) {
    return null;
  }
  const active = candidates.find((item) => item.is_active);
  if (active) {
    return active;
  }
  return [...candidates].sort((a, b) => toTimestamp(b.updated_at) - toTimestamp(a.updated_at))[0];
}

export function chooseDisplayItemsByPipeline(
  items: GraphConfigListItem[]
): GraphConfigListItem[] {
  const byPipeline = new Map<string, GraphConfigListItem[]>();
  for (const item of items) {
    const list = byPipeline.get(item.pipeline_id) || [];
    list.push(item);
    byPipeline.set(item.pipeline_id, list);
  }

  const out: GraphConfigListItem[] = [];
  for (const [pipelineId, list] of byPipeline.entries()) {
    const selected = chooseActiveConfigItem(list, pipelineId);
    if (selected) {
      out.push(selected);
    }
  }
  return out.sort((a, b) => a.pipeline_id.localeCompare(b.pipeline_id));
}
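The selection rule here — an explicitly active config wins, otherwise the most recently updated one — can be sketched in Python for illustration. The plain-dict item shape is an assumption of this sketch, not the real `GraphConfigListItem` type:

```python
from datetime import datetime

def to_timestamp(value):
    # Missing or unparseable timestamps sort last (epoch 0), like toTimestamp().
    if not value:
        return 0.0
    try:
        return datetime.fromisoformat(value.replace("Z", "+00:00")).timestamp()
    except ValueError:
        return 0.0

def choose_active_config_item(items, pipeline_id):
    candidates = [i for i in items if i["pipeline_id"] == pipeline_id]
    if not candidates:
        return None
    active = next((i for i in candidates if i.get("is_active")), None)
    if active is not None:
        return active
    # No active config: fall back to the newest updated_at.
    return max(candidates, key=lambda i: to_timestamp(i.get("updated_at")))

items = [
    {"pipeline_id": "agent-a", "prompt_set_id": "old-active", "is_active": True,
     "updated_at": "2025-01-01T00:00:00Z"},
    {"pipeline_id": "agent-a", "prompt_set_id": "new-inactive", "is_active": False,
     "updated_at": "2025-03-01T00:00:00Z"},
]
print(choose_active_config_item(items, "agent-a")["prompt_set_id"])  # old-active
```

Note the active flag deliberately beats recency: an older-but-active prompt set is still the one shown.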
frontend/src/api/frontApis.test.ts — new file (22 lines):
@@ -0,0 +1,22 @@
import { describe, expect, it } from "vitest";

import { joinApiUrl } from "./frontApis";

describe("joinApiUrl", () => {
  it("keeps same-origin paths when base url is slash", () => {
    expect(joinApiUrl("/", "/v1/pipelines")).toBe("/v1/pipelines");
  });

  it("joins absolute host and trims trailing slash", () => {
    expect(joinApiUrl("http://127.0.0.1:8500/", "/v1/pipelines")).toBe(
      "http://127.0.0.1:8500/v1/pipelines"
    );
  });

  it("accepts path without leading slash", () => {
    expect(joinApiUrl("http://127.0.0.1:8500", "v1/pipelines")).toBe(
      "http://127.0.0.1:8500/v1/pipelines"
    );
  });
});
@@ -1,23 +1,48 @@
import type {
  AvailableGraphsResponse,
  ConversationListItem,
  ConversationMessageItem,
  GraphConfigListResponse,
  GraphConfigReadResponse,
  GraphConfigUpsertRequest,
  GraphConfigUpsertResponse,
  McpAvailableToolsResponse,
  McpToolConfigResponse,
  McpToolConfigUpdateRequest,
  McpToolConfigUpdateResponse,
  PipelineCreateRequest,
  PipelineCreateResponse,
  PipelineConversationListResponse,
  PipelineConversationMessagesResponse,
  PipelineListResponse,
  PipelineStopResponse,
  RuntimeAuthInfoResponse,
} from "../types";

-const API_BASE_URL =
-  import.meta.env.VITE_FRONT_API_BASE_URL?.trim() || "http://127.0.0.1:8001";
+const API_BASE_URL = import.meta.env.VITE_FRONT_API_BASE_URL?.trim() || "/";

export function joinApiUrl(baseUrl: string, path: string): string {
  const normalizedPath = path.startsWith("/") ? path : `/${path}`;
  const normalizedBase = baseUrl.trim();

  // "/" is commonly used in Docker+nginx builds and should resolve as same-origin.
  if (!normalizedBase || normalizedBase === "/") {
    return normalizedPath;
  }

  return `${normalizedBase.replace(/\/+$/, "")}${normalizedPath}`;
}
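The same base-URL joining rule, sketched in Python for illustration (`join_api_url` is a hypothetical mirror of `joinApiUrl`, not part of the codebase):

```python
def join_api_url(base_url: str, path: str) -> str:
    # "/" or empty base means same-origin; otherwise trim trailing slashes
    # from the base and make sure the path leads with exactly one "/".
    normalized_path = path if path.startswith("/") else "/" + path
    normalized_base = base_url.strip()
    if not normalized_base or normalized_base == "/":
        return normalized_path
    return normalized_base.rstrip("/") + normalized_path

print(join_api_url("/", "/v1/pipelines"))                     # /v1/pipelines
print(join_api_url("http://127.0.0.1:8500/", "v1/pipelines"))  # http://127.0.0.1:8500/v1/pipelines
```

Normalizing both sides avoids the classic double-slash (`host//v1/...`) and missing-slash (`hostv1/...`) join bugs.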

// Log which backend the frontend is targeting on startup, with file + line hint.
// This runs once when the module is loaded.
// eslint-disable-next-line no-console
console.info(
  `[frontend] Using FRONT_API_BASE_URL=${API_BASE_URL} (src/api/frontApis.ts:16)`
);

async function fetchJson<T>(path: string, init?: RequestInit): Promise<T> {
-  const response = await fetch(`${API_BASE_URL}${path}`, {
+  const url = joinApiUrl(API_BASE_URL, path);
+  const response = await fetch(url, {
    headers: {
      "Content-Type": "application/json",
      ...(init?.headers || {}),

@@ -36,7 +61,24 @@ async function fetchJson<T>(path: string, init?: RequestInit): Promise<T> {
    }
    throw new Error(message);
  }
-  return (await response.json()) as T;
+
+  if (response.status === 204) {
+    return undefined as T;
+  }
+
+  const bodyText = await response.text();
+  if (!bodyText.trim()) {
+    return undefined as T;
+  }
+
+  try {
+    return JSON.parse(bodyText) as T;
+  } catch {
+    const preview = bodyText.slice(0, 160).replace(/\s+/g, " ").trim();
+    throw new Error(
+      `Expected JSON response from ${url}, but received non-JSON content: ${preview || "<empty>"}`
+    );
+  }
}

export function listAvailableGraphs(): Promise<AvailableGraphsResponse> {

@@ -70,6 +112,12 @@ export function getGraphDefaultConfig(
  return fetchJson(`/v1/graphs/${graphId}/default-config`);
}

export function getPipelineDefaultConfig(
  pipelineId: string
): Promise<GraphConfigReadResponse> {
  return fetchJson(`/v1/graph-configs/default/${pipelineId}`);
}

export function upsertGraphConfig(
  payload: GraphConfigUpsertRequest
): Promise<GraphConfigUpsertResponse> {

@@ -101,6 +149,10 @@ export function updateMcpToolConfig(
  });
}

export function listMcpAvailableTools(): Promise<McpAvailableToolsResponse> {
  return fetchJson("/v1/tool-configs/mcp/tools");
}

export function createPipeline(
  payload: PipelineCreateRequest
): Promise<PipelineCreateResponse> {

@@ -114,9 +166,144 @@ export function listPipelines(): Promise<PipelineListResponse> {
  return fetchJson("/v1/pipelines");
}

-export function stopPipeline(runId: string): Promise<PipelineStopResponse> {
-  return fetchJson(`/v1/pipelines/${runId}`, {
+export function stopPipeline(pipelineId: string): Promise<PipelineStopResponse> {
+  return fetchJson(`/v1/pipelines/${pipelineId}`, {
    method: "DELETE",
  });
}

export function getRuntimeAuthInfo(): Promise<RuntimeAuthInfoResponse> {
  return fetchJson("/v1/runtime-auth");
}

export async function listPipelineConversations(
  pipelineId: string,
  limit = 100
): Promise<ConversationListItem[]> {
  const response = await fetchJson<PipelineConversationListResponse>(
    `/v1/pipelines/${encodeURIComponent(pipelineId)}/conversations?limit=${limit}`
  );
  return response.items || [];
}

export async function getPipelineConversationMessages(
  pipelineId: string,
  conversationId: string
): Promise<ConversationMessageItem[]> {
  const response = await fetchJson<PipelineConversationMessagesResponse>(
    `/v1/pipelines/${encodeURIComponent(pipelineId)}/conversations/${encodeURIComponent(conversationId)}/messages`
  );
  return response.items || [];
}

type StreamAgentChatOptions = {
  appId: string;
  sessionId: string;
  apiKey: string;
  message: string;
  onText: (text: string) => void;
  signal?: AbortSignal;
};

function parseErrorDetail(payload: unknown): string | null {
  if (!payload || typeof payload !== "object") {
    return null;
  }
  const detail = (payload as { detail?: unknown }).detail;
  return typeof detail === "string" && detail.trim() ? detail : null;
}

export async function streamAgentChatResponse(
  options: StreamAgentChatOptions
): Promise<string> {
  const { appId, sessionId, apiKey, message, onText, signal } = options;
  const response = await fetch(
    joinApiUrl(
      API_BASE_URL,
      `/v1/apps/${encodeURIComponent(appId)}/sessions/${encodeURIComponent(sessionId)}/responses`
    ),
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        Authorization: `Bearer ${apiKey}`,
      },
      body: JSON.stringify({
        messages: [{ role: "user", content: message }],
        stream: true,
      }),
      signal,
    }
  );

  if (!response.ok) {
    let messageText = `Request failed (${response.status})`;
    try {
      const payload = (await response.json()) as unknown;
      const detail = parseErrorDetail(payload);
      if (detail) {
        messageText = detail;
      }
    } catch {
      // Keep fallback status-based message.
    }
    throw new Error(messageText);
  }

  if (!response.body) {
    throw new Error("Streaming response is not available.");
  }

  const reader = response.body.getReader();
  const decoder = new TextDecoder();
  let buffered = "";
  let latestText = "";

  while (true) {
    if (signal?.aborted) {
      reader.cancel();
      throw new Error("Request cancelled");
    }
    const { value, done } = await reader.read();
    if (done) {
      break;
    }
    buffered += decoder.decode(value, { stream: true });

    let splitIndex = buffered.indexOf("\n\n");
    while (splitIndex >= 0) {
      const eventBlock = buffered.slice(0, splitIndex);
      buffered = buffered.slice(splitIndex + 2);
      splitIndex = buffered.indexOf("\n\n");

      const lines = eventBlock.split("\n");
      for (const rawLine of lines) {
        const line = rawLine.trim();
        if (!line.startsWith("data:")) {
          continue;
        }
        const payloadRaw = line.slice(5).trim();
        if (!payloadRaw) {
          continue;
        }
        let payload: unknown;
        try {
          payload = JSON.parse(payloadRaw);
        } catch {
          continue;
        }
        const nextText =
          typeof (payload as { output?: { text?: unknown } })?.output?.text === "string"
            ? ((payload as { output: { text: string } }).output.text as string)
            : "";
        if (nextText !== latestText) {
          latestText = nextText;
          onText(latestText);
        }
      }
    }
  }

  return latestText;
}
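The streaming loop buffers decoded chunks, splits complete SSE event blocks on blank lines, and JSON-decodes each `data:` payload. That parsing step can be sketched standalone in Python (`parse_sse_events` is a hypothetical helper; the `output.text` payload shape follows the frontend code):

```python
import json

def parse_sse_events(buffered: str):
    """Consume complete SSE event blocks from a buffer.

    Returns the decoded output.text payloads plus the unconsumed tail,
    which a caller would prepend to the next network chunk."""
    texts = []
    while "\n\n" in buffered:
        event_block, buffered = buffered.split("\n\n", 1)
        for raw_line in event_block.split("\n"):
            line = raw_line.strip()
            if not line.startswith("data:"):
                continue
            payload_raw = line[5:].strip()
            if not payload_raw:
                continue
            try:
                payload = json.loads(payload_raw)
            except json.JSONDecodeError:
                continue  # Skip malformed events, as the TS code does.
            text = payload.get("output", {}).get("text", "")
            if isinstance(text, str) and text:
                texts.append(text)
    return texts, buffered

chunk = 'data: {"output": {"text": "Hel"}}\n\ndata: {"output": {"text": "Hello"}}\n\ndata: {"par'
texts, tail = parse_sse_events(chunk)
print(texts)  # ['Hel', 'Hello']
print(tail)   # 'data: {"par'
```

Keeping the incomplete tail buffered is what makes the parser safe against events that straddle network chunk boundaries.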
@@ -65,6 +65,12 @@ button:disabled {
  gap: 8px;
}

.agent-item-row {
  display: grid;
  gap: 6px;
  grid-template-columns: 1fr auto;
}

.agent-item {
  align-items: flex-start;
  display: flex;

@@ -73,6 +79,33 @@ button:disabled {
  width: 100%;
}

.agent-item-title {
  align-items: center;
  display: flex;
  gap: 8px;
  justify-content: space-between;
  width: 100%;
}

.agent-status-pill {
  border-radius: 999px;
  font-size: 11px;
  font-weight: 600;
  padding: 2px 8px;
}

.agent-status-pill.running {
  background: #dff7e7;
  border: 1px solid #8cd3a1;
  color: #1a6b35;
}

.agent-status-pill.stopped {
  background: #f2f4f7;
  border: 1px solid #d1d8e0;
  color: #4a5565;
}

.agent-item.selected {
  border-color: #4d7ef3;
  background: #edf3ff;

@@ -82,6 +115,11 @@ button:disabled {
  color: #5f6f82;
}

.agent-chat-button {
  align-self: stretch;
  min-width: 64px;
}

.content {
  padding: 20px;
}

@@ -186,6 +224,7 @@ button:disabled {
}

.prompt-section,
.file-backend-config,
.run-info {
  border: 1px solid #dbe2ea;
  border-radius: 10px;

@@ -194,10 +233,37 @@ button:disabled {
}

.prompt-section h3,
.file-backend-config h3,
.run-info h3 {
  margin-top: 0;
}

.run-info-header {
  align-items: center;
  display: flex;
  justify-content: space-between;
  gap: 8px;
}

.runtime-badge {
  border-radius: 999px;
  font-size: 12px;
  font-weight: 600;
  padding: 4px 10px;
}

.runtime-badge.running {
  background: #dff7e7;
  border: 1px solid #8cd3a1;
  color: #1a6b35;
}

.runtime-badge.stopped {
  background: #f2f4f7;
  border: 1px solid #d1d8e0;
  color: #4a5565;
}

.graph-arch-section {
  border: 1px solid #dbe2ea;
  border-radius: 10px;

@@ -229,6 +295,166 @@ button:disabled {
  padding: 10px;
}

.run-card-columns {
  display: grid;
  gap: 12px;
  grid-template-columns: minmax(280px, 1fr) minmax(420px, 1.3fr);
}

.run-card-left,
.run-card-right {
  display: flex;
  flex-direction: column;
  gap: 6px;
}

.run-card-right {
  border-left: 1px solid #dbe2ea;
  min-width: 0;
  padding-left: 12px;
}

.run-card-right code {
  display: inline-block;
  max-width: 100%;
  overflow-x: auto;
  vertical-align: middle;
  white-space: nowrap;
}

.discussion-section {
  background: #f7fbff;
  border: 1px solid #d7e6f6;
  border-radius: 10px;
  padding: 12px;
}

.discussion-header {
  align-items: center;
  display: flex;
  justify-content: space-between;
  gap: 12px;
}

.discussion-header h3 {
  margin: 0;
}

.discussion-layout {
  display: grid;
  gap: 12px;
  grid-template-columns: minmax(260px, 320px) 1fr;
  margin-top: 10px;
}

.discussion-list {
  display: flex;
  flex-direction: column;
  gap: 8px;
  max-height: 65vh;
  overflow-y: auto;
}

.discussion-item {
  align-items: flex-start;
  display: flex;
  flex-direction: column;
  gap: 4px;
  text-align: left;
  width: 100%;
}

.discussion-item.selected {
  background: #edf3ff;
  border-color: #4d7ef3;
}

.discussion-item small {
  color: #687788;
}

.discussion-thread {
  border: 1px solid #d7e6f6;
  border-radius: 10px;
  display: flex;
  flex-direction: column;
  gap: 8px;
  max-height: 65vh;
  overflow-y: auto;
  padding: 10px;
}

.discussion-message {
  background: #fff;
  border: 1px solid #dbe2ea;
  border-radius: 8px;
  padding: 8px;
}

.discussion-message.human {
  border-left: 3px solid #4d7ef3;
}

.discussion-message.ai {
  border-left: 3px solid #26a269;
}

.discussion-message.tool {
  border-left: 3px solid #8e6bd8;
}

.discussion-message-meta {
  align-items: baseline;
  display: flex;
  gap: 8px;
}

.discussion-message pre {
  font-family: ui-monospace, SFMono-Regular, Menlo, Monaco, Consolas, monospace;
  margin: 8px 0 0;
  overflow-x: auto;
  white-space: pre;
}

.discussion-message-markdown > :first-child {
  margin-top: 0;
}

.discussion-message-markdown > :last-child {
  margin-bottom: 0;
}

.discussion-message-markdown code {
  background: #f3f5f8;
  border-radius: 4px;
  padding: 1px 4px;
}

.discussion-message-markdown pre code {
  background: transparent;
  padding: 0;
}

.discussion-message-markdown a {
  color: #1a4fc5;
  text-decoration: underline;
}

.discussion-message-markdown p,
.discussion-message-markdown ul,
.discussion-message-markdown ol,
.discussion-message-markdown blockquote,
.discussion-message-markdown table {
  margin: 8px 0;
}

.discussion-message-markdown blockquote {
  border-left: 3px solid #d0d8e2;
  color: #4f5f73;
  margin-left: 0;
  padding-left: 10px;
}

.mcp-config-section {
  background: #f7fbff;
  border: 1px solid #d7e6f6;

@@ -258,8 +484,174 @@
  width: 100%;
}

.mcp-entry-list {
  display: grid;
  gap: 12px;
  margin-top: 10px;
}

.mcp-tools-error {
  color: #a33434;
  margin: 6px 0 0 0;
}

.mcp-tools-inline {
  background: #f8fbff;
  border: 1px solid #d7e6f6;
  border-radius: 8px;
  margin: 0 0 10px 0;
  padding: 8px;
}

.mcp-entry-card {
  background: #fff;
  border: 1px solid #d7e6f6;
  border-radius: 10px;
  padding: 10px;
}

.mcp-entry-header {
  align-items: center;
  display: flex;
  justify-content: space-between;
  gap: 10px;
  margin-bottom: 10px;
}

.mcp-entry-grid {
  display: grid;
  gap: 10px;
  grid-template-columns: repeat(2, minmax(200px, 1fr));
}

.mcp-entry-grid label {
  display: flex;
  flex-direction: column;
  font-size: 14px;
  gap: 6px;
}

.mcp-entry-grid input,
.mcp-entry-grid select {
  border: 1px solid #c9d4e2;
  border-radius: 8px;
  font-size: 14px;
  padding: 8px;
}

.mcp-entry-wide {
  grid-column: 1 / -1;
}

.empty {
  color: #687788;
  margin: 6px 0;
}

.chat-modal-overlay {
  align-items: center;
  background: rgba(16, 24, 40, 0.45);
  display: flex;
  inset: 0;
  justify-content: center;
  position: fixed;
  z-index: 20;
}

.chat-modal {
  background: #fff;
  border: 1px solid #d7e6f6;
  border-radius: 12px;
  display: grid;
  gap: 10px;
  max-height: 86vh;
  max-width: 820px;
  padding: 12px;
  width: min(92vw, 820px);
}

.chat-modal-header {
  align-items: center;
  border-bottom: 1px solid #dbe2ea;
  display: flex;
  justify-content: space-between;
  padding-bottom: 8px;
}

.chat-modal-header small {
  color: #687788;
  display: block;
  margin-top: 2px;
}

.chat-modal-messages {
  background: #f8fbff;
  border: 1px solid #d7e6f6;
  border-radius: 10px;
  display: flex;
  flex-direction: column;
  gap: 8px;
  max-height: 56vh;
  overflow-y: auto;
  padding: 10px;
}

.chat-modal-message {
  background: #fff;
  border: 1px solid #dbe2ea;
  border-radius: 8px;
  padding: 8px;
}

.chat-modal-message.user {
  border-left: 3px solid #4d7ef3;
}

.chat-modal-message.assistant {
  border-left: 3px solid #26a269;
}

.chat-modal-message.tool {
  border-left: 3px solid #8e6bd8;
}

.chat-modal-message p {
  margin: 6px 0 0 0;
  white-space: pre-wrap;
}

.chat-modal-input {
  display: grid;
  gap: 8px;
  grid-template-columns: 1fr auto;
  align-items: start;
}

.chat-modal-input textarea {
  border: 1px solid #c9d4e2;
  border-radius: 8px;
  font-size: 14px;
  padding: 8px;
  resize: vertical;
}

.chat-modal-actions {
  display: flex;
  gap: 8px;
  height: 100%;
}

.chat-modal-actions button {
  height: auto;
  white-space: nowrap;
}

.chat-stop-button {
  background-color: #dc3545;
  color: white;
}

.chat-stop-button:hover {
  background-color: #bb2d3b;
}
@@ -23,6 +23,7 @@ export type GraphConfigReadResponse = {
  tool_keys: string[];
  prompt_dict: Record<string, string>;
  api_key: string;
+  graph_params?: Record<string, unknown>;
};

export type GraphConfigUpsertRequest = {

@@ -52,37 +53,73 @@ export type PipelineCreateRequest = {
  pipeline_id: string;
  prompt_set_id: string;
  tool_keys: string[];
  port: number;
-  api_key: string;
-  entry_point: string;
+  api_key?: string;
  llm_name: string;
+  enabled?: boolean;
+  graph_params?: Record<string, unknown>;
};

+export type PipelineSpec = {
+  pipeline_id: string;
+  graph_id: string;
+  enabled: boolean;
+  config_file: string;
+  llm_name: string;
+};

-export type PipelineRunInfo = {
-  run_id: string;
-  pid: number;
-  graph_id: string;
+export type PipelineCreateResponse = {
  pipeline_id: string;
  prompt_set_id: string;
  url: string;
  port: number;
  auth_type: string;
  auth_header_name: string;
  auth_key_masked: string;
-};
-
-export type PipelineCreateResponse = PipelineRunInfo & {
-  auth_key_once: string;
+  graph_id: string;
+  config_file: string;
+  llm_name: string;
+  enabled: boolean;
+  reload_required: boolean;
+  registry_path: string;
};

export type PipelineListResponse = {
-  items: PipelineRunInfo[];
+  items: PipelineSpec[];
  count: number;
};

export type PipelineStopResponse = {
-  run_id: string;
+  pipeline_id: string;
  status: string;
+  enabled: boolean;
+  reload_required: boolean;
};

+export type ConversationListItem = {
+  conversation_id: string;
+  pipeline_id: string;
+  message_count: number;
+  last_updated?: string | null;
+};
+
+export type PipelineConversationListResponse = {
+  pipeline_id: string;
+  items: ConversationListItem[];
+  count: number;
+};
+
+export type ConversationMessageItem = {
+  message_type: string;
+  content: string;
+  sequence_number: number;
+  created_at: string;
+};
+
+export type PipelineConversationMessagesResponse = {
+  pipeline_id: string;
+  conversation_id: string;
+  items: ConversationMessageItem[];
+  count: number;
+};
+
+export type RuntimeAuthInfoResponse = {
+  fast_api_key: string;
+  source: string;
+};

export type McpToolConfigResponse = {

@@ -101,3 +138,15 @@ export type McpToolConfigUpdateResponse = {
  tool_keys: string[];
};

+export type McpAvailableToolsResponse = {
+  available_tools: string[];
+  errors: string[];
+  servers: Record<
+    string,
+    {
+      tools: string[];
+      error?: string | null;
+    }
+  >;
+};
File diff suppressed because one or more lines are too long
@@ -1 +1 @@
{"root":["./src/App.tsx","./src/main.tsx","./src/types.ts","./src/vite-env.d.ts","./src/api/frontApis.ts"],"version":"5.9.3"}
{"root":["./src/App.tsx","./src/activeConfigSelection.test.ts","./src/activeConfigSelection.ts","./src/main.tsx","./src/types.ts","./src/vite-env.d.ts","./src/api/frontApis.test.ts","./src/api/frontApis.ts"],"version":"5.9.3"}
@@ -4,5 +4,15 @@ export default defineConfig({
  plugins: [react()],
  server: {
    port: 5173,
    proxy: {
      "/v1": {
        target: "http://127.0.0.1:8500",
        changeOrigin: true,
      },
      "/apps": {
        target: "http://127.0.0.1:8500",
        changeOrigin: true,
      },
    },
  },
});

@@ -5,6 +5,16 @@ export default defineConfig({
  plugins: [react()],
  server: {
    port: 5173,
    proxy: {
      "/v1": {
        target: "http://127.0.0.1:8500",
        changeOrigin: true,
      },
      "/apps": {
        target: "http://127.0.0.1:8500",
        changeOrigin: true,
      },
    },
  },
});
@@ -251,32 +251,30 @@ class ClientToolManager:
    def populate_module(self):
        with open(self.config.mcp_config_f, "r") as f:
            self.mcp_configs: dict = commentjson.load(f)

    def _get_to_load_configs(self) -> dict:
        if self.config.tool_keys is None:
            return self.mcp_configs

        if len(self.config.tool_keys) == 0:
            logger.info("no tools will be loaded")
            return {}

        to_load_config = {}
        for key in self.config.tool_keys:
            val = self.mcp_configs.get(key)
            if val is None:
                logger.warning(f"{key} is not in mcp tools")
            else:
                to_load_config[key] = val
        return to_load_config

    async def aget_tools(self):
        """
        Get tools from all configured MCP servers.
        Handles connection failures gracefully by logging warnings and continuing.
        """

        def get_to_load_configs() -> dict:
            if self.config.tool_keys is None:
                to_load_config = self.mcp_configs
            else:
                if len(self.config.tool_keys) == 0:
                    logger.info("no tools will be loaded")
                    return {}

                to_load_config = {}
                for key in self.config.tool_keys:
                    val = self.mcp_configs.get(key)
                    if val is None:
                        logger.warning(f"{key} is not in mcp tools")
                    else:
                        to_load_config[key] = val

            return to_load_config

        to_load_config = get_to_load_configs()
        to_load_config = self._get_to_load_configs()
        all_tools = []
        for server_name, server_config in to_load_config.items():
            try:
@@ -298,6 +296,78 @@ class ClientToolManager:

        return all_tools

    async def aget_tools_with_errors(self):
        """
        Get tools and collect human-readable per-server errors.
        Returns:
            (all_tools, errors)
        """
        to_load_config = self._get_to_load_configs()
        all_tools = []
        errors = []
        for server_name, server_config in to_load_config.items():
            try:
                single_server_config = {server_name: server_config}
                client = MultiServerMCPClient(single_server_config)
                tools = await client.get_tools()
                all_tools.extend(tools)
                logger.info(
                    f"Successfully connected to MCP server '{server_name}', retrieved {len(tools)} tools"
                )
            except Exception as e:
                url = (
                    server_config.get("url", "unknown URL")
                    if isinstance(server_config, dict)
                    else "unknown URL"
                )
                err_msg = (
                    f"{server_name} ({url}): {type(e).__name__}: {e}"
                )
                errors.append(err_msg)
                logger.exception(
                    f"Failed to connect to MCP server '{server_name}' at {url}"
                )
                if hasattr(e, "exceptions"):
                    for nested_exc in e.exceptions:
                        errors.append(
                            f"{server_name} nested: {type(nested_exc).__name__}: {nested_exc}"
                        )
                continue
        return all_tools, errors

    async def aget_tools_by_server(self) -> dict:
        """
        Get MCP tools grouped by server name, including per-server error (if any).
        Returns:
            {
                "server_name": {
                    "tools": ["tool_a", "tool_b"],
                    "error": "ExceptionType: message" | None,
                },
                ...
            }
        """
        to_load_config = self._get_to_load_configs()
        grouped = {}
        for server_name, server_config in to_load_config.items():
            grouped[server_name] = {"tools": [], "error": None}
            try:
                single_server_config = {server_name: server_config}
                client = MultiServerMCPClient(single_server_config)
                tools = await client.get_tools()
                tool_names = sorted(
                    {
                        str(getattr(tool, "name", "")).strip()
                        for tool in tools
                        if str(getattr(tool, "name", "")).strip()
                    }
                )
                grouped[server_name] = {"tools": tool_names, "error": None}
            except Exception as e:
                grouped[server_name]["error"] = f"{type(e).__name__}: {e}"
                logger.exception(f"Failed to connect to MCP server '{server_name}'")
        return grouped

    def get_tools(self):
        try:
            loop = asyncio.get_running_loop()
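The per-server error-collection pattern in `aget_tools_with_errors` (try each server independently, turn a failure into a human-readable string, keep going) can be sketched without any MCP dependency. This is a minimal illustration, not the project's code: `gather_with_errors` and `fake_fetch` are hypothetical names, and `fake_fetch` merely stands in for the real `MultiServerMCPClient(...).get_tools()` call.

```python
import asyncio


async def gather_with_errors(configs, fetch):
    # Mirrors aget_tools_with_errors: each server is tried independently;
    # a failure becomes an error string instead of aborting the whole run.
    all_items, errors = [], []
    for name, cfg in configs.items():
        try:
            items = await fetch(name, cfg)
            all_items.extend(items)
        except Exception as e:
            errors.append(f"{name}: {type(e).__name__}: {e}")
    return all_items, errors


async def fake_fetch(name, cfg):
    # Hypothetical stand-in for MultiServerMCPClient(...).get_tools().
    if cfg.get("broken"):
        raise ConnectionError("refused")
    return [f"{name}.tool"]


tools, errs = asyncio.run(
    gather_with_errors({"good": {}, "bad": {"broken": True}}, fake_fetch)
)
```

One healthy and one broken server still yield the healthy server's tools plus a single error entry, which is exactly what lets the `/mcp` endpoints report partial results.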
@@ -4,16 +4,34 @@ from typing import List, Dict, Union
from enum import Enum
import os
from loguru import logger
from abc import ABC, abstractmethod

from langchain_core.messages import (
    HumanMessage,
    AIMessage,
    ToolMessage,
    SystemMessage,
    BaseMessage,
)
from langchain_core.messages import HumanMessage, AIMessage, ToolMessage, SystemMessage, BaseMessage

class MessageType(str, Enum):
    """Enum for message types in the conversation store."""

    HUMAN = "human"
    AI = "ai"
    TOOL = "tool"

class ConversationStore:

class BaseConvStore(ABC):
    @abstractmethod
    def record_message_list(
        self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
    ):
        pass


class ConversationStore(BaseConvStore):
    def __init__(self):
        conn_str = os.environ.get("CONN_STR")
        if conn_str is None:
@@ -25,46 +43,53 @@ class ConversationStore:
        conversation_id: str,
        msg_type: MessageType,
        content: str,
        sequence: int,  # the conversation number
        sequence: int,
        pipeline_id: str = None,
    ):
        with psycopg.connect(self.conn_str) as conn:
            with conn.cursor() as cur:
                # DB schema only supports these columns:
                # (conversation_id, message_type, content, sequence_number)
                cur.execute(
                    """
                    INSERT INTO messages (conversation_id, message_type, content, sequence_number)
                    VALUES (%s, %s, %s, %s)
                    INSERT INTO messages (conversation_id, pipeline_id, message_type, content, sequence_number)
                    VALUES (%s, %s, %s, %s, %s)
                    """,
                    (conversation_id, msg_type.value, content, sequence),
                    (conversation_id, pipeline_id, msg_type.value, content, sequence),
                )

    def get_conv_number(self, conversation_id: str) -> int:
        """
        if the conversation_id does not exist, return 0
        if len(conversation) = 3, it will return 3
        """
        with psycopg.connect(self.conn_str) as conn:
            with conn.cursor() as cur:
                cur.execute("""
                cur.execute(
                    """
                    SELECT COUNT(*)
                    FROM messages
                    WHERE conversation_id = %s
                """, (conversation_id,))
                    """,
                    (conversation_id,),
                )
                return int(cur.fetchone()[0])

    def get_conversation(self, conversation_id: str) -> List[Dict]:
        with psycopg.connect(self.conn_str) as conn:
            with conn.cursor(row_factory=psycopg.rows.dict_row) as cur:
                cur.execute("""
                cur.execute(
                    """
                    SELECT message_type, content, sequence_number, created_at
                    FROM messages
                    WHERE conversation_id = %s
                    ORDER BY sequence_number ASC
                """, (conversation_id,))
                    """,
                    (conversation_id,),
                )
                return cur.fetchall()

    def record_message_list(self, conv_id: str, inp: List[BaseMessage]):
    def record_message_list(
        self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
    ):
        inp = [e for e in inp if not isinstance(e, SystemMessage)]
        curr_len = self.get_conv_number(conv_id)
        to_add_msg = inp[curr_len:]
@@ -73,12 +98,13 @@ class ConversationStore:
            # Serialize dict/list content to JSON string
            if not isinstance(content, str):
                content = json.dumps(content, ensure_ascii=False, indent=4)
            self.add_message(conv_id, self._get_type(msg), content, curr_len + 1)
            self.add_message(
                conv_id, self._get_type(msg), content, curr_len + 1, pipeline_id
            )
            curr_len += 1
        return curr_len

    def _get_type(self, msg: BaseMessage) -> MessageType:
        if isinstance(msg, HumanMessage):
            return MessageType.HUMAN
        elif isinstance(msg, AIMessage):
@@ -89,26 +115,30 @@ class ConversationStore:
        raise ValueError(f"Unknown message type: {type(msg)}")


class ConversationPrinter:
class ConversationPrinter(BaseConvStore):
    def __init__(self):
        self.id_dic = {}

    def record_message_list(self, conv_id: str, inp: List[BaseMessage]):
    def record_message_list(
        self, conv_id: str, inp: List[BaseMessage], pipeline_id: str = None
    ):
        inp = [e for e in inp if not isinstance(e, SystemMessage)]
        curr_len = self.id_dic.get(conv_id, 0)
        to_print_msg = inp[curr_len:]
        print("\n")
        for msg in to_print_msg:
            msg.pretty_print()

        if curr_len == 0:
            self.id_dic[conv_id] = len(inp)
        else:
            self.id_dic[conv_id] += len(to_print_msg)


CONV_STORE = ConversationStore()
# CONV_STORE = ConversationPrinter()


def use_printer():
    global CONV_STORE
    CONV_STORE = ConversationPrinter()
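The append-only bookkeeping used by `record_message_list` (count the rows already stored, then persist only the tail beyond that count) can be sketched with a plain in-memory dict. This is a hypothetical illustration of the idea, not the `psycopg`-backed class above; the point is that passing the full message history on every call never duplicates rows.

```python
class InMemoryConvStore:
    """Minimal sketch of the store's append-only bookkeeping."""

    def __init__(self):
        # conv_id -> list of (sequence_number, content)
        self.rows = {}

    def get_conv_number(self, conv_id):
        # Number of messages already persisted for this conversation.
        return len(self.rows.get(conv_id, []))

    def record_message_list(self, conv_id, messages):
        # Persist only the tail past the already-recorded length.
        curr_len = self.get_conv_number(conv_id)
        for msg in messages[curr_len:]:
            self.rows.setdefault(conv_id, []).append((curr_len + 1, msg))
            curr_len += 1
        return curr_len


store = InMemoryConvStore()
store.record_message_list("c1", ["hi", "hello"])
store.record_message_list("c1", ["hi", "hello", "thanks"])  # only "thanks" is new
```

The second call sees `get_conv_number("c1") == 2`, slices off the first two messages, and appends a single new row with sequence number 3.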
@@ -4,6 +4,7 @@ from pathlib import Path as FsPath
import os.path as osp
import json
import copy
from threading import RLock
from loguru import logger

from lang_agent.pipeline import Pipeline, PipelineConfig
@@ -11,21 +12,18 @@ from lang_agent.config.core_config import load_tyro_conf


class ServerPipelineManager:
    """Lazily load and cache multiple pipelines keyed by a client-facing route id."""
    """Lazily load and cache multiple pipelines keyed by a client-facing pipeline id."""

    def __init__(self, default_route_id: str, default_config: PipelineConfig, default_pipeline: Pipeline):
        self.default_route_id = default_route_id
    def __init__(self, default_pipeline_id: str, default_config: PipelineConfig):
        self.default_pipeline_id = default_pipeline_id
        self.default_config = default_config
        self._route_specs: Dict[str, Dict[str, Any]] = {}
        self._pipeline_specs: Dict[str, Dict[str, Any]] = {}
        self._api_key_policy: Dict[str, Dict[str, Any]] = {}
        self._pipelines: Dict[str, Pipeline] = {default_route_id: default_pipeline}
        self._pipeline_llm: Dict[str, str] = {default_route_id: default_config.llm_name}
        self._route_specs[default_route_id] = {
            "enabled": True,
            "config_file": None,
            "overrides": {},
            "prompt_pipeline_id": None,
        }
        self._pipelines: Dict[str, Pipeline] = {}
        self._pipeline_llm: Dict[str, str] = {}
        self._registry_path: Optional[str] = None
        self._registry_mtime_ns: Optional[int] = None
        self._lock = RLock()

    def _resolve_registry_path(self, registry_path: str) -> str:
        path = FsPath(registry_path)
@@ -36,38 +34,102 @@ class ServerPipelineManager:
        root = FsPath(__file__).resolve().parents[2]
        return str((root / path).resolve())

    def load_registry(self, registry_path: str) -> None:
        abs_path = self._resolve_registry_path(registry_path)
        if not osp.exists(abs_path):
            logger.warning(f"pipeline registry file not found: {abs_path}. Using default pipeline only.")
            return
    def _stat_registry_mtime_ns(self, abs_path: str) -> int:
        return FsPath(abs_path).stat().st_mtime_ns

    def _read_registry(self, abs_path: str) -> Dict[str, Any]:
        with open(abs_path, "r", encoding="utf-8") as f:
            registry: dict = json.load(f)
            return json.load(f)

        routes = registry.get("routes")
        if routes is None:
            # Backward compatibility with initial schema.
            routes = registry.get("pipelines", {})
        if not isinstance(routes, dict):
            raise ValueError("`routes` in pipeline registry must be an object.")
    def _apply_registry(self, abs_path: str, registry: Dict[str, Any], mtime_ns: int) -> bool:
        pipelines = registry.get("pipelines")
        if pipelines is None or not isinstance(pipelines, dict):
            raise ValueError("`pipelines` in pipeline registry must be an object.")

        for route_id, spec in routes.items():
        parsed_specs: Dict[str, Dict[str, Any]] = {}
        for pipeline_id, spec in pipelines.items():
            if not isinstance(spec, dict):
                raise ValueError(f"route spec for `{route_id}` must be an object.")
            self._route_specs[route_id] = {
                raise ValueError(
                    f"pipeline spec for `{pipeline_id}` must be an object."
                )
            parsed_specs[pipeline_id] = {
                "enabled": bool(spec.get("enabled", True)),
                "config_file": spec.get("config_file"),
                "overrides": spec.get("overrides", {}),
                # Explicitly separates routing id from prompt config pipeline_id.
                "prompt_pipeline_id": spec.get("prompt_pipeline_id"),
                "llm_name": spec.get("llm_name"),
            }
        if not parsed_specs:
            raise ValueError("pipeline registry must define at least one pipeline.")

        api_key_policy = registry.get("api_keys", {})
        if api_key_policy and not isinstance(api_key_policy, dict):
            raise ValueError("`api_keys` in pipeline registry must be an object.")
        self._api_key_policy = api_key_policy
        logger.info(f"loaded pipeline registry: {abs_path}, routes={list(self._route_specs.keys())}")

        with self._lock:
            old_specs = self._pipeline_specs
            old_policy = self._api_key_policy
            old_mtime = self._registry_mtime_ns

            removed = set(old_specs.keys()) - set(parsed_specs.keys())
            added = set(parsed_specs.keys()) - set(old_specs.keys())
            modified = {
                pipeline_id
                for pipeline_id in (set(old_specs.keys()) & set(parsed_specs.keys()))
                if old_specs[pipeline_id] != parsed_specs[pipeline_id]
            }
            changed = bool(added or removed or modified or old_policy != api_key_policy)

            # Drop stale cache entries for deleted/changed pipelines so future requests
            # lazily rebuild from the refreshed registry spec.
            for pipeline_id in (removed | modified):
                self._pipelines.pop(pipeline_id, None)
                self._pipeline_llm.pop(pipeline_id, None)

            self._pipeline_specs = parsed_specs
            self._api_key_policy = api_key_policy
            self._registry_path = abs_path
            self._registry_mtime_ns = mtime_ns

            if changed:
                logger.info(
                    "refreshed pipeline registry: {} | added={} modified={} removed={} mtime={}",
                    abs_path,
                    sorted(added),
                    sorted(modified),
                    sorted(removed),
                    mtime_ns,
                )
            elif old_mtime != mtime_ns:
                logger.debug("pipeline registry mtime changed but specs were unchanged: {}", abs_path)
            return changed

    def load_registry(self, registry_path: str) -> None:
        abs_path = self._resolve_registry_path(registry_path)
        if not osp.exists(abs_path):
            raise ValueError(f"pipeline registry file not found: {abs_path}")
        registry = self._read_registry(abs_path)
        mtime_ns = self._stat_registry_mtime_ns(abs_path)
        self._apply_registry(abs_path=abs_path, registry=registry, mtime_ns=mtime_ns)

    def refresh_registry_if_needed(
        self, registry_path: Optional[str] = None, force: bool = False
    ) -> bool:
        abs_path = (
            self._resolve_registry_path(registry_path)
            if registry_path
            else self._registry_path
        )
        if not abs_path:
            raise ValueError("registry path is not initialized")
        if not osp.exists(abs_path):
            raise ValueError(f"pipeline registry file not found: {abs_path}")

        mtime_ns = self._stat_registry_mtime_ns(abs_path)
        with self._lock:
            if not force and self._registry_mtime_ns == mtime_ns:
                return False

        registry = self._read_registry(abs_path)
        return self._apply_registry(abs_path=abs_path, registry=registry, mtime_ns=mtime_ns)

    def _resolve_config_path(self, config_file: str) -> str:
        path = FsPath(config_file)
@@ -78,54 +140,40 @@ class ServerPipelineManager:
        root = FsPath(__file__).resolve().parents[2]
        return str((root / path).resolve())

    def _build_pipeline(self, route_id: str) -> Tuple[Pipeline, str]:
        spec = self._route_specs.get(route_id)
    def _build_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
        spec = self._pipeline_specs.get(pipeline_id)
        if spec is None:
            raise HTTPException(status_code=404, detail=f"Unknown route_id: {route_id}")
            raise HTTPException(
                status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
            )
        if not spec.get("enabled", True):
            raise HTTPException(status_code=403, detail=f"Route disabled: {route_id}")
            raise HTTPException(
                status_code=403, detail=f"Pipeline disabled: {pipeline_id}"
            )

        config_file = spec.get("config_file")
        overrides = spec.get("overrides", {})
        if not config_file and not overrides:
            # default pipeline
            p = self._pipelines[self.default_route_id]
            llm_name = self._pipeline_llm[self.default_route_id]
            return p, llm_name

        registry_llm_name = spec.get("llm_name")
        if config_file:
            loaded_cfg = load_tyro_conf(self._resolve_config_path(config_file))
            # Some legacy yaml configs deserialize to plain dicts instead of
            # InstantiateConfig dataclasses. Fall back to default config in that case.
            if hasattr(loaded_cfg, "setup"):
                cfg = loaded_cfg
            else:
                logger.warning(
                    f"config_file for route `{route_id}` did not deserialize to config object; "
                    "falling back to default config and applying route-level overrides."
                raise ValueError(
                    "config_file for pipeline "
                    f"`{pipeline_id}` did not deserialize to a config object. "
                    "Rebuild the pipeline via /v1/pipelines to regenerate a "
                    "valid serialized PipelineConfig file."
                )
                cfg = copy.deepcopy(self.default_config)
        else:
            # Build from default config + shallow overrides so new pipelines can be
            # added via registry without additional yaml files.
            cfg = copy.deepcopy(self.default_config)
            if not isinstance(overrides, dict):
                raise ValueError(f"route `overrides` for `{route_id}` must be an object.")
            for key, value in overrides.items():
                if not hasattr(cfg, key):
                    raise ValueError(f"unknown override field `{key}` for route `{route_id}`")
                setattr(cfg, key, value)

        prompt_pipeline_id = spec.get("prompt_pipeline_id")
        if prompt_pipeline_id and (not isinstance(overrides, dict) or "pipeline_id" not in overrides):
            if hasattr(cfg, "pipeline_id"):
                cfg.pipeline_id = prompt_pipeline_id
        if registry_llm_name is not None and hasattr(cfg, "llm_name"):
            setattr(cfg, "llm_name", registry_llm_name)

        p = cfg.setup()
        llm_name = getattr(cfg, "llm_name", "unknown-model")
        llm_name = str(getattr(cfg, "llm_name", registry_llm_name or "unknown-model"))
        return p, llm_name

    def _authorize(self, api_key: str, route_id: str) -> None:
    def _authorize(self, api_key: str, pipeline_id: str) -> None:
        if not self._api_key_policy:
            return

@@ -133,47 +181,50 @@ class ServerPipelineManager:
        if policy is None:
            return

        allowed = policy.get("allowed_route_ids")
        if allowed is None:
            # Backward compatibility.
            allowed = policy.get("allowed_pipeline_ids")
        if allowed and route_id not in allowed:
            raise HTTPException(status_code=403, detail=f"route_id `{route_id}` is not allowed for this API key")
        allowed = policy.get("allowed_pipeline_ids")
        if allowed and pipeline_id not in allowed:
            raise HTTPException(
                status_code=403,
                detail=f"pipeline_id `{pipeline_id}` is not allowed for this API key",
            )

    def resolve_route_id(self, body: Dict[str, Any], app_id: Optional[str], api_key: str) -> str:
    def resolve_pipeline_id(
        self, body: Dict[str, Any], app_id: Optional[str], api_key: str
    ) -> str:
        body_input = body.get("input", {})
        route_id = (
            body.get("route_id")
            or (body_input.get("route_id") if isinstance(body_input, dict) else None)
            or body.get("pipeline_key")
            or (body_input.get("pipeline_key") if isinstance(body_input, dict) else None)
            # Backward compatibility: pipeline_id still accepted as route selector.
            or body.get("pipeline_id")
        pipeline_id = (
            body.get("pipeline_id")
            or (body_input.get("pipeline_id") if isinstance(body_input, dict) else None)
            or app_id
        )

        if not route_id:
            key_policy = self._api_key_policy.get(api_key, {}) if self._api_key_policy else {}
            route_id = key_policy.get("default_route_id")
            if not route_id:
                # Backward compatibility.
                route_id = key_policy.get("default_pipeline_id", self.default_route_id)
        with self._lock:
            if not pipeline_id:
                key_policy = (
                    self._api_key_policy.get(api_key, {}) if self._api_key_policy else {}
                )
                pipeline_id = key_policy.get(
                    "default_pipeline_id", self.default_pipeline_id
                )

            if route_id not in self._route_specs:
                raise HTTPException(status_code=404, detail=f"Unknown route_id: {route_id}")
            if pipeline_id not in self._pipeline_specs:
                raise HTTPException(
                    status_code=404, detail=f"Unknown pipeline_id: {pipeline_id}"
                )

            self._authorize(api_key, route_id)
            return route_id
            self._authorize(api_key, pipeline_id)
            return pipeline_id

    def get_pipeline(self, route_id: str) -> Tuple[Pipeline, str]:
        cached = self._pipelines.get(route_id)
        if cached is not None:
            return cached, self._pipeline_llm[route_id]
    def get_pipeline(self, pipeline_id: str) -> Tuple[Pipeline, str]:
        with self._lock:
            cached = self._pipelines.get(pipeline_id)
            if cached is not None:
                return cached, self._pipeline_llm[pipeline_id]

            pipeline_obj, llm_name = self._build_pipeline(route_id)
            self._pipelines[route_id] = pipeline_obj
            self._pipeline_llm[route_id] = llm_name
            logger.info(f"lazy-loaded route_id={route_id} model={llm_name}")
            # Build while holding the lock to avoid duplicate construction for
            # the same pipeline on concurrent first requests.
            pipeline_obj, llm_name = self._build_pipeline(pipeline_id)
            self._pipelines[pipeline_id] = pipeline_obj
            self._pipeline_llm[pipeline_id] = llm_name
            logger.info(f"lazy-loaded pipeline_id={pipeline_id} model={llm_name}")
            return pipeline_obj, llm_name
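The core of `refresh_registry_if_needed` is an mtime gate: stat the registry file, and only re-read it when `st_mtime_ns` has moved (or `force=True`). Below is a stripped-down sketch of just that gate; the class name and its simplified return semantics ("did we re-read?") are hypothetical, and the real method additionally diffs the parsed specs, evicts stale pipeline caches, and does everything under an `RLock`.

```python
import json
import os
import tempfile


class RegistryWatcher:
    """Minimal sketch of mtime-gated registry refresh."""

    def __init__(self, path):
        self.path = path
        self.mtime_ns = None
        self.data = None

    def refresh_if_needed(self, force=False):
        mtime_ns = os.stat(self.path).st_mtime_ns
        if not force and mtime_ns == self.mtime_ns:
            return False  # unchanged on disk, keep the cached parse
        with open(self.path, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        self.mtime_ns = mtime_ns
        return True


with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"pipelines": {"default": {}}}, f)
    path = f.name

w = RegistryWatcher(path)
first = w.refresh_if_needed()   # first call always reads the file
second = w.refresh_if_needed()  # mtime unchanged, so the cache is kept
```

Using `st_mtime_ns` rather than `st_mtime` avoids missing edits that land within the same second on coarse-resolution filesystems.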
@@ -1,5 +1,19 @@
from lang_agent.config.core_config import (InstantiateConfig,
                                           ToolConfig,
                                           LLMKeyConfig,
                                           LLMNodeConfig,
                                           load_tyro_conf)
from lang_agent.config.core_config import (
    InstantiateConfig,
    ToolConfig,
    LLMKeyConfig,
    LLMNodeConfig,
    load_tyro_conf,
    resolve_llm_api_key,
)

from lang_agent.config.constants import (
    MCP_CONFIG_PATH,
    MCP_CONFIG_DEFAULT_CONTENT,
    PIPELINE_REGISTRY_PATH,
    VALID_API_KEYS,
    API_KEY_HEADER,
    API_KEY_HEADER_NO_ERROR,
    _PROJECT_ROOT,
    TY_BUILD_SCRIPT,
)
lang_agent/config/constants.py (new file, 19 lines)
@@ -0,0 +1,19 @@
import os
import re
import os.path as osp
from fastapi.security import APIKeyHeader

_PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))

MCP_CONFIG_PATH = osp.join(_PROJECT_ROOT, "configs", "mcp_config.json")
MCP_CONFIG_DEFAULT_CONTENT = "{\n}\n"

# NOTE: the fancy looking api key is FAST_AUTH_KEYS from .env
PIPELINE_REGISTRY_PATH = osp.join(_PROJECT_ROOT, "configs", "pipeline_registry.json")

API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
API_KEY_HEADER_NO_ERROR = APIKeyHeader(name="Authorization", auto_error=False)

VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))

TY_BUILD_SCRIPT = osp.join(_PROJECT_ROOT, "lang_agent", "config", "ty_build_config.py")
@@ -10,6 +10,20 @@ from dotenv import load_dotenv

load_dotenv()


def resolve_llm_api_key(api_key: Optional[str]) -> Optional[str]:
    """Resolve the API key for OpenAI-compatible providers."""
    if api_key not in (None, "", "wrong-key"):
        resolved_key = api_key
    else:
        resolved_key = os.environ.get("ALI_API_KEY") or os.environ.get("OPENAI_API_KEY")

    # Some OpenAI-compatible integrations still read OPENAI_API_KEY from env.
    if resolved_key and not os.environ.get("OPENAI_API_KEY"):
        os.environ["OPENAI_API_KEY"] = resolved_key

    return resolved_key


## NOTE: base classes taken from nerfstudio
class PrintableConfig:
    """
@@ -72,56 +86,12 @@ class InstantiateConfig(PrintableConfig):
        Save the configuration to a YAML file.

        """

        def mask_value(key, value):
            """
            Apply masking if key is secret-like.
            Check whether the key is sensitive (e.g. contains "secret" or "api_key");
            if so, mask the value.
            """
            if isinstance(value, str) and self.is_secrete(str(key)):
                sval = str(value)
                return sval[:3] + "*" * (len(sval) - 6) + sval[-3:]
            return value

        def to_serializable(obj, apply_mask: bool):
            """
            Recursively convert dataclasses and containers to serializable format,
            optionally masking secrets.
            """
            if is_dataclass(obj):
                out = {}
                for k, v in vars(obj).items():
                    if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
                        out[k] = to_serializable(v, apply_mask)
                    else:
                        out[k] = mask_value(k, v) if apply_mask else v
                return out
            if isinstance(obj, dict):
                out = {}
                for k, v in obj.items():
                    if is_dataclass(v) or isinstance(v, (dict, list, tuple)):
                        out[k] = to_serializable(v, apply_mask)
                    else:
                        key_str = str(k)
                        out[k] = mask_value(key_str, v) if apply_mask else v
                return out
            if isinstance(obj, list):
                return [to_serializable(v, apply_mask) for v in obj]
            if isinstance(obj, tuple):
                return tuple(to_serializable(v, apply_mask) for v in obj)
            return obj

        # NOTE: we intentionally do NOT mask secrets when saving to disk so that
        # configs can be reloaded with real values. Masking is handled in __str__
        # for safe logging/printing. If you need a redacted copy, call
        # to_serializable(self, apply_mask=True) manually and dump it yourself.
        serializable = to_serializable(self, apply_mask=False)
        with open(filename, 'w') as f:
            yaml.dump(serializable, f)
        # Persist the full config object (including type tags) so it can be
        # deserialized back into config instances with methods like .setup().
        # Secret masking is intentionally handled by __str__ for printing/logging,
        # not when writing to disk.
        with open(filename, "w", encoding="utf-8") as f:
            yaml.dump(self, f)
        logger.info(f"[yellow]config saved to: (unknown)[/yellow]")

    def get_name(self):
@@ -143,12 +113,12 @@ class LLMKeyConfig(InstantiateConfig):
    """api key for llm"""

    def __post_init__(self):
        if self.api_key == "wrong-key" or self.api_key is None:
            self.api_key = os.environ.get("ALI_API_KEY")
            if self.api_key is None:
                logger.error(f"no ALI_API_KEY provided for embedding")
            else:
                logger.info("ALI_API_KEY loaded from environ")
        original_api_key = self.api_key
        self.api_key = resolve_llm_api_key(self.api_key)
        if self.api_key is None:
            logger.error("no ALI_API_KEY or OPENAI_API_KEY provided for embedding")
        elif original_api_key in (None, "", "wrong-key"):
            logger.info("LLM API key loaded from environment")


@dataclass
@@ -182,7 +152,7 @@ def load_tyro_conf(filename: str, inp_conf = None) -> InstantiateConfig:
    """
    config = yaml.load(Path(filename).read_text(), Loader=yaml.Loader)

    config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
    # config = ovewrite_config(config, inp_conf) if inp_conf is not None else config
    return config

def is_default(instance, field_):
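The masking rule in `mask_value` keeps the first three and last three characters of a secret and stars out everything between. A standalone sketch of that same shape (hypothetical helper name, and it assumes the value is at least six characters long, just like the original):

```python
def mask_secret(value: str) -> str:
    # Same shape as mask_value: keep the first and last three characters,
    # replace the middle with asterisks. Assumes len(value) >= 6; shorter
    # values would leak overlapping characters, as in the original helper.
    return value[:3] + "*" * (len(value) - 6) + value[-3:]


masked = mask_secret("sk-abcdef123456")
```

A 15-character key therefore masks to a 15-character string with nine asterisks in the middle, so logs stay greppable by prefix/suffix without exposing the key.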
lang_agent/config/ty_build_config.py (new file, 13 lines)
@@ -0,0 +1,13 @@
import tyro
from loguru import logger
import os
import os.path as osp
from lang_agent.pipeline import PipelineConfig

def build_conf(pipeline: PipelineConfig, save_path: str):
    os.makedirs(osp.dirname(save_path), exist_ok=True)
    logger.info(pipeline)
    pipeline.save_config(save_path)

if __name__ == "__main__":
    tyro.cli(build_conf)
@@ -96,7 +96,7 @@ class Evaluator:

        df_m.to_csv(metric_f)

        self.config.save_config(f"{head_path}-{n_exp}.yml")
        self.config.save_config(f"{head_path}-{n_exp}.yaml")

    def format_result_df(self, df: pd.DataFrame):
0
lang_agent/fastapi_server/__init__.py
Normal file
0
lang_agent/fastapi_server/__init__.py
Normal file
33
lang_agent/fastapi_server/combined.py
Normal file
33
lang_agent/fastapi_server/combined.py
Normal file
@@ -0,0 +1,33 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from lang_agent.fastapi_server.front_apis import app as front_app
from lang_agent.fastapi_server.server_dashscope import create_dashscope_router


app = FastAPI(
    title="Combined Front + DashScope APIs",
    description=(
        "Single-process app exposing front_apis control endpoints and "
        "DashScope-compatible chat endpoints."
    ),
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Keep existing /v1/... admin APIs unchanged.
app.include_router(front_app.router)

# Add DashScope endpoints at their existing URLs. We intentionally skip
# DashScope's root/health routes to avoid clashing with front_apis.
app.include_router(create_dashscope_router(include_meta_routes=False))

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8500)
1007
lang_agent/fastapi_server/front_apis.py
Normal file
File diff suppressed because it is too large
@@ -1,16 +1,13 @@
from fastapi import FastAPI, HTTPException, Path, Request, Depends, Security
from fastapi import APIRouter, Depends, FastAPI, HTTPException, Path, Request, Security
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
from fastapi.security import APIKeyHeader
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path as FsPath
from typing import Any, Dict, List, Optional
import os
import os.path as osp
import sys
import time
import json
import copy
import uvicorn
from loguru import logger
import tyro
@@ -18,30 +15,32 @@ import tyro
# Ensure we can import from project root
sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__))))

from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.core_config import load_tyro_conf
from lang_agent.pipeline import PipelineConfig
from lang_agent.components.server_pipeline_manager import ServerPipelineManager
from lang_agent.config.constants import PIPELINE_REGISTRY_PATH, API_KEY_HEADER, VALID_API_KEYS

# Initialize default pipeline once (used when no explicit pipeline id is provided)
pipeline_config = tyro.cli(PipelineConfig)
logger.info(f"starting agent with default pipeline: \n{pipeline_config}")
pipeline: Pipeline = pipeline_config.setup()
def _build_default_pipeline_config() -> PipelineConfig:
    """
    Build import-time defaults without parsing CLI args.

# API Key Authentication
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))
REGISTRY_FILE = os.environ.get(
    "FAST_PIPELINE_REGISTRY_FILE",
    osp.join(osp.dirname(osp.dirname(osp.abspath(__file__))), "configs", "pipeline_registry.json"),
)
    This keeps module import safe for reuse by combined apps and tests.
    """
    pipeline_config = PipelineConfig()
    logger.info(f"starting agent with base pipeline config: \n{pipeline_config}")
    return pipeline_config


PIPELINE_MANAGER = ServerPipelineManager(
    default_route_id=os.environ.get("FAST_DEFAULT_ROUTE_ID", os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default")),
    default_config=pipeline_config,
    default_pipeline=pipeline,
)
PIPELINE_MANAGER.load_registry(REGISTRY_FILE)
def _build_pipeline_manager(base_config: PipelineConfig) -> ServerPipelineManager:
    pipeline_manager = ServerPipelineManager(
        default_pipeline_id=os.environ.get("FAST_DEFAULT_PIPELINE_ID", "default"),
        default_config=base_config,
    )
    pipeline_manager.load_registry(PIPELINE_REGISTRY_PATH)
    return pipeline_manager


pipeline_config = _build_default_pipeline_config()
PIPELINE_MANAGER = _build_pipeline_manager(pipeline_config)


async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
@@ -67,19 +66,9 @@ class DSApplicationCallRequest(BaseModel):
    thread_id: Optional[str] = Field(default="3")


app = FastAPI(title="DashScope-Compatible Application API",
              description="DashScope Application.call compatible endpoint backed by pipeline.chat")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen-flash"):
def sse_chunks_from_stream(
    chunk_generator, response_id: str, model: str = "qwen-flash"
):
    """
    Stream chunks from pipeline and format as SSE.
    Accumulates text and sends incremental updates.
@@ -120,7 +109,9 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str = "qwen
    yield f"data: {json.dumps(final)}\n\n"


async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str = "qwen-flash"):
async def sse_chunks_from_astream(
    chunk_generator, response_id: str, model: str = "qwen-flash"
):
    """
    Async version: Stream chunks from pipeline and format as SSE.
    Accumulates text and sends incremental updates.
@@ -194,7 +185,14 @@ async def _process_dashscope_request(
    app_id: Optional[str],
    session_id: Optional[str],
    api_key: str,
    pipeline_manager: ServerPipelineManager,
):
    try:
        pipeline_manager.refresh_registry_if_needed()
    except Exception as e:
        logger.error(f"failed to refresh pipeline registry: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to refresh pipeline registry: {e}")

    req_app_id = app_id or body.get("app_id")
    body_input = body.get("input", {}) if isinstance(body.get("input"), dict) else {}
    req_session_id = session_id or body_input.get("session_id")
@@ -207,22 +205,30 @@ async def _process_dashscope_request(
    thread_id = body_input.get("session_id") or req_session_id or "3"
    user_msg = _extract_user_message(messages)

    route_id = PIPELINE_MANAGER.resolve_route_id(body=body, app_id=req_app_id, api_key=api_key)
    selected_pipeline, selected_model = PIPELINE_MANAGER.get_pipeline(route_id)
    pipeline_id = pipeline_manager.resolve_pipeline_id(
        body=body, app_id=req_app_id, api_key=api_key
    )
    selected_pipeline, selected_model = pipeline_manager.get_pipeline(pipeline_id)

    # Namespace thread ids to prevent memory collisions across pipelines.
    thread_id = f"{route_id}:{thread_id}"
    thread_id = f"{pipeline_id}:{thread_id}"

    response_id = f"appcmpl-{os.urandom(12).hex()}"

    if stream:
        chunk_generator = await selected_pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
        chunk_generator = await selected_pipeline.achat(
            inp=user_msg, as_stream=True, thread_id=thread_id
        )
        return StreamingResponse(
            sse_chunks_from_astream(chunk_generator, response_id=response_id, model=selected_model),
            sse_chunks_from_astream(
                chunk_generator, response_id=response_id, model=selected_model
            ),
            media_type="text/event-stream",
        )

    result_text = await selected_pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
    result_text = await selected_pipeline.achat(
        inp=user_msg, as_stream=False, thread_id=thread_id
    )
    if not isinstance(result_text, str):
        result_text = str(result_text)

@@ -237,83 +243,123 @@ async def _process_dashscope_request(
            "created": int(time.time()),
            "model": selected_model,
        },
        "route_id": route_id,
        # Backward compatibility: keep pipeline_id in response as the route id selector.
        "pipeline_id": route_id,
        "pipeline_id": pipeline_id,
        "is_end": True,
    }
    return JSONResponse(content=data)


@app.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
@app.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
async def application_responses(
    request: Request,
    app_id: str = Path(...),
    session_id: str = Path(...),
    api_key: str = Depends(verify_api_key),
):
    try:
        body = await request.json()
        return await _process_dashscope_request(
            body=body,
            app_id=app_id,
            session_id=session_id,
            api_key=api_key,
def create_dashscope_router(
    pipeline_manager: Optional[ServerPipelineManager] = None,
    include_meta_routes: bool = True,
) -> APIRouter:
    manager = pipeline_manager or PIPELINE_MANAGER
    router = APIRouter()

    @router.post("/v1/apps/{app_id}/sessions/{session_id}/responses")
    @router.post("/api/v1/apps/{app_id}/sessions/{session_id}/responses")
    async def application_responses(
        request: Request,
        app_id: str = Path(...),
        session_id: str = Path(...),
        api_key: str = Depends(verify_api_key),
    ):
        try:
            body = await request.json()
            return await _process_dashscope_request(
                body=body,
                app_id=app_id,
                session_id=session_id,
                api_key=api_key,
                pipeline_manager=manager,
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"DashScope-compatible endpoint error: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    # Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and
    # without session in path.
    @router.post("/apps/{app_id}/completion")
    @router.post("/v1/apps/{app_id}/completion")
    @router.post("/api/apps/{app_id}/completion")
    @router.post("/api/v1/apps/{app_id}/completion")
    async def application_completion(
        request: Request,
        app_id: str = Path(...),
        api_key: str = Depends(verify_api_key),
    ):
        try:
            body = await request.json()
            return await _process_dashscope_request(
                body=body,
                app_id=app_id,
                session_id=None,
                api_key=api_key,
                pipeline_manager=manager,
            )

        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"DashScope-compatible completion error: {e}")
            raise HTTPException(status_code=500, detail=str(e))

    if include_meta_routes:
        @router.get("/")
        async def root():
            return {
                "message": "DashScope Application-compatible API",
                "endpoints": [
                    "/v1/apps/{app_id}/sessions/{session_id}/responses",
                    "/health",
                ],
            }

        @router.get("/health")
        async def health():
            return {"status": "healthy"}

    return router
def create_dashscope_app(
    pipeline_manager: Optional[ServerPipelineManager] = None,
) -> FastAPI:
    dashscope_app = FastAPI(
        title="DashScope-Compatible Application API",
        description="DashScope Application.call compatible endpoint backed by pipeline.chat",
    )
    dashscope_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    dashscope_app.include_router(
        create_dashscope_router(
            pipeline_manager=pipeline_manager,
            include_meta_routes=True,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"DashScope-compatible endpoint error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
    )
    return dashscope_app


# Compatibility: some SDKs call /apps/{app_id}/completion without /v1 and without session in path
@app.post("/apps/{app_id}/completion")
@app.post("/v1/apps/{app_id}/completion")
@app.post("/api/apps/{app_id}/completion")
@app.post("/api/v1/apps/{app_id}/completion")
async def application_completion(
    request: Request,
    app_id: str = Path(...),
    api_key: str = Depends(verify_api_key),
):
    try:
        body = await request.json()
        return await _process_dashscope_request(
            body=body,
            app_id=app_id,
            session_id=None,
            api_key=api_key,
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"DashScope-compatible completion error: {e}")
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/")
async def root():
    return {"message": "DashScope Application-compatible API", "endpoints": [
        "/v1/apps/{app_id}/sessions/{session_id}/responses",
        "/health",
    ]}


@app.get("/health")
async def health():
    return {"status": "healthy"}
dashscope_router = create_dashscope_router(include_meta_routes=False)
app = create_dashscope_app()


if __name__ == "__main__":
    # CLI parsing is intentionally only in script mode to keep module import safe.
    cli_pipeline_config = tyro.cli(PipelineConfig)
    logger.info(f"starting agent with CLI pipeline config: \n{cli_pipeline_config}")
    cli_pipeline_manager = _build_pipeline_manager(cli_pipeline_config)
    uvicorn.run(
        "server_dashscope:app",
        host="0.0.0.0",
        port=pipeline_config.port,
        reload=True,
        create_dashscope_app(pipeline_manager=cli_pipeline_manager),
        host=cli_pipeline_config.host,
        port=cli_pipeline_config.port,
        reload=False,
    )
@@ -16,15 +16,12 @@ import tyro
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS

# Initialize Pipeline once
pipeline_config = tyro.cli(PipelineConfig)
pipeline: Pipeline = pipeline_config.setup()

# API Key Authentication
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=True)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))


async def verify_api_key(api_key: str = Security(API_KEY_HEADER)):
    """Verify the API key from Authorization header (Bearer token format)."""
@@ -46,12 +43,12 @@ class OpenAIChatCompletionRequest(BaseModel):
    temperature: Optional[float] = Field(default=1.0)
    max_tokens: Optional[int] = Field(default=None)
    # Optional overrides for pipeline behavior
    thread_id: Optional[str] = Field(default='3')
    thread_id: Optional[str] = Field(default="3")


app = FastAPI(
    title="OpenAI-Compatible Chat API",
    description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat"
    description="OpenAI Chat Completions API compatible endpoint backed by pipeline.chat",
)

app.add_middleware(
@@ -63,7 +60,9 @@ app.add_middleware(
)


def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, created_time: int):
def sse_chunks_from_stream(
    chunk_generator, response_id: str, model: str, created_time: int
):
    """
    Stream chunks from pipeline and format as OpenAI SSE.
    """
@@ -75,14 +74,8 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
            "created": created_time,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": chunk
                    },
                    "finish_reason": None
                }
            ]
                {"index": 0, "delta": {"content": chunk}, "finish_reason": None}
            ],
        }
        yield f"data: {json.dumps(data)}\n\n"

@@ -92,19 +85,15 @@ def sse_chunks_from_stream(chunk_generator, response_id: str, model: str, create
        "object": "chat.completion.chunk",
        "created": created_time,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }
        ]
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(final)}\n\n"
    yield "data: [DONE]\n\n"
async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str, created_time: int):
async def sse_chunks_from_astream(
    chunk_generator, response_id: str, model: str, created_time: int
):
    """
    Async version: Stream chunks from pipeline and format as OpenAI SSE.
    """
@@ -116,14 +105,8 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
            "created": created_time,
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "delta": {
                        "content": chunk
                    },
                    "finish_reason": None
                }
            ]
                {"index": 0, "delta": {"content": chunk}, "finish_reason": None}
            ],
        }
        yield f"data: {json.dumps(data)}\n\n"

@@ -133,13 +116,7 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
        "object": "chat.completion.chunk",
        "created": created_time,
        "model": model,
        "choices": [
            {
                "index": 0,
                "delta": {},
                "finish_reason": "stop"
            }
        ]
        "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
    }
    yield f"data: {json.dumps(final)}\n\n"
    yield "data: [DONE]\n\n"
@@ -149,15 +126,15 @@ async def sse_chunks_from_astream(chunk_generator, response_id: str, model: str,
async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
    try:
        body = await request.json()

        messages = body.get("messages")
        if not messages:
            raise HTTPException(status_code=400, detail="messages is required")

        stream = body.get("stream", False)
        model = body.get("model", "gpt-3.5-turbo")
        thread_id = body.get("thread_id", 3)

        # Extract latest user message
        user_msg = None
        for m in reversed(messages):
@@ -166,27 +143,36 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
            if role == "user" and content:
                user_msg = content
                break

        if user_msg is None:
            last = messages[-1]
            user_msg = last.get("content") if isinstance(last, dict) else str(last)

        response_id = f"chatcmpl-{os.urandom(12).hex()}"
        created_time = int(time.time())

        if stream:
            # Use async streaming from pipeline
            chunk_generator = await pipeline.achat(inp=user_msg, as_stream=True, thread_id=thread_id)
            chunk_generator = await pipeline.achat(
                inp=user_msg, as_stream=True, thread_id=thread_id
            )
            return StreamingResponse(
                sse_chunks_from_astream(chunk_generator, response_id=response_id, model=model, created_time=created_time),
                sse_chunks_from_astream(
                    chunk_generator,
                    response_id=response_id,
                    model=model,
                    created_time=created_time,
                ),
                media_type="text/event-stream",
            )

        # Non-streaming: get full result using async
        result_text = await pipeline.achat(inp=user_msg, as_stream=False, thread_id=thread_id)
        result_text = await pipeline.achat(
            inp=user_msg, as_stream=False, thread_id=thread_id
        )
        if not isinstance(result_text, str):
            result_text = str(result_text)

        data = {
            "id": response_id,
            "object": "chat.completion",
@@ -195,21 +181,14 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": result_text
                    },
                    "finish_reason": "stop"
                    "message": {"role": "assistant", "content": result_text},
                    "finish_reason": "stop",
                }
            ],
            "usage": {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0
            }
            "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0},
        }
        return JSONResponse(content=data)

    except HTTPException:
        raise
    except Exception as e:
@@ -221,11 +200,7 @@ async def chat_completions(request: Request, _: str = Depends(verify_api_key)):
async def root():
    return {
        "message": "OpenAI-compatible Chat API",
        "endpoints": [
            "/v1/chat/completions",
            "/v1/memory (DELETE)",
            "/health"
        ]
        "endpoints": ["/v1/chat/completions", "/v1/memory (DELETE)", "/health"],
    }
@@ -17,6 +17,7 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from langgraph.checkpoint.memory import MemorySaver
from lang_agent.pipeline import Pipeline, PipelineConfig
from lang_agent.config.constants import API_KEY_HEADER, VALID_API_KEYS

# Initialize Pipeline once (matches existing server_* pattern)
pipeline_config = tyro.cli(PipelineConfig)
@@ -24,9 +25,6 @@ logger.info(f"starting agent with pipeline: \n{pipeline_config}")
pipeline: Pipeline = pipeline_config.setup()

# API Key Authentication
API_KEY_HEADER = APIKeyHeader(name="Authorization", auto_error=False)
VALID_API_KEYS = set(filter(None, os.environ.get("FAST_AUTH_KEYS", "").split(",")))


async def verify_api_key(api_key: Optional[str] = Security(API_KEY_HEADER)):
    """Verify the API key from Authorization header (Bearer token format)."""
@@ -125,7 +123,9 @@ app.add_middleware(
)


async def rest_sse_from_astream(chunk_generator, response_id: str, conversation_id: str):
async def rest_sse_from_astream(
    chunk_generator, response_id: str, conversation_id: str
):
    """
    Stream chunks as SSE events.

@@ -185,7 +185,9 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
        )
        return StreamingResponse(
            rest_sse_from_astream(
                chunk_generator, response_id=response_id, conversation_id=conversation_id
                chunk_generator,
                response_id=response_id,
                conversation_id=conversation_id,
            ),
            media_type="text/event-stream",
        )
@@ -195,7 +197,11 @@ async def chat(body: ChatRequest, _: str = Depends(verify_api_key)):
    )
    if not isinstance(result_text, str):
        result_text = str(result_text)
    return JSONResponse(content=ChatResponse(conversation_id=conversation_id, output=result_text).model_dump())
    return JSONResponse(
        content=ChatResponse(
            conversation_id=conversation_id, output=result_text
        ).model_dump()
    )


@app.post("/v1/conversations/{conversation_id}/messages")
@@ -215,7 +221,9 @@ async def create_message(
        )
        return StreamingResponse(
            rest_sse_from_astream(
                chunk_generator, response_id=response_id, conversation_id=conversation_id
                chunk_generator,
                response_id=response_id,
                conversation_id=conversation_id,
            ),
            media_type="text/event-stream",
        )
@@ -257,7 +265,11 @@ async def delete_conversation_memory(
    cleared = _try_clear_single_thread_memory(thread_id)
    if cleared:
        return JSONResponse(
            content={"status": "success", "scope": "conversation", "conversation_id": conversation_id}
            content={
                "status": "success",
                "scope": "conversation",
                "conversation_id": conversation_id,
            }
        )
    return JSONResponse(
        content={
@@ -276,5 +288,3 @@ if __name__ == "__main__":
        port=8589,
        reload=True,
    )
@@ -1,71 +0,0 @@
from typing import Dict, List, Optional
import os
import subprocess

def build_route(pipeline_id:str,
                prompt_set:str,
                tool_keys:List[str],
                port:str,
                api_key: str,
                fast_auth_keys: Optional[str] = None,
                entry_pnt:str="fastapi_server/server_dashscope.py",
                llm_name:str="qwen-plus"):
    cmd = [
        "python", entry_pnt,
        "--port", str(port),
        "route", # ------------
        "--llm-name", llm_name,
        "--api-key", api_key,
        "--pipeline-id", pipeline_id,
        "--prompt-set-id", prompt_set,
        "tool_node", # ------------
        "--llm-name", llm_name,
        "--api-key", api_key,
        "--pipeline-id", pipeline_id,
        "--prompt-set-id", prompt_set,
    ]
    if tool_keys:
        cmd.extend(
            ["--tool-manager-config.client-tool-manager.tool-keys", *tool_keys]
        )
    env: Dict[str, str] = os.environ.copy()
    if fast_auth_keys:
        env["FAST_AUTH_KEYS"] = fast_auth_keys
    sv_prc = subprocess.Popen(cmd, env=env)

    return sv_prc, f"http://127.0.0.1:{port}/api/"


def build_react(pipeline_id:str,
                prompt_set:str,
                tool_keys:List[str],
                port:str,
                api_key: str,
                fast_auth_keys: Optional[str] = None,
                entry_pnt:str="fastapi_server/server_dashscope.py",
                llm_name:str="qwen-plus"):
    cmd = [
        "python", entry_pnt,
        "--port", str(port),
        "react", # ------------
        "--llm-name", llm_name,
        "--api-key", api_key,
        "--pipeline-id", pipeline_id,
        "--prompt-set-id", prompt_set,
    ]
    if tool_keys:
        cmd.extend(
            ["--tool-manager-config.client-tool-manager.tool-keys", *tool_keys]
        )
    env: Dict[str, str] = os.environ.copy()
    if fast_auth_keys:
        env["FAST_AUTH_KEYS"] = fast_auth_keys
    sv_prc = subprocess.Popen(cmd, env=env)

    return sv_prc, f"http://127.0.0.1:{port}/api/"

# {pipeline_id: build_function}
GRAPH_BUILD_FNCS = {
    "routing": build_route,
    "react": build_react,
}
223
lang_agent/front_api/build_server_utils.py
Normal file
@@ -0,0 +1,223 @@
from typing import Any, Dict, List, Literal, Optional
import os
import os.path as osp
import subprocess
import json

from lang_agent.config.core_config import load_tyro_conf
from lang_agent.config.constants import TY_BUILD_SCRIPT, _PROJECT_ROOT

_DEEP_AGENT_BACKEND_ALIASES = {
    "state_bk": "statebk",
    "statebk": "statebk",
    "local_shell": "localshell",
    "localshell": "localshell",
    "daytona_sandbox": "daytonasandbox",
    "daytonasandbox": "daytonasandbox",
}


def opt_to_config(save_path: str, *nargs):
    os.makedirs(osp.dirname(save_path), exist_ok=True)
    subprocess.run(
        ["python", TY_BUILD_SCRIPT, "--save-path", save_path, *nargs],
        check=True,
        cwd=_PROJECT_ROOT,
    )


def _build_and_load_pipeline_config(
    pipeline_id: str, pipeline_config_dir: str, cmd: List[str]
):
    save_config_f = osp.join(pipeline_config_dir, f"{pipeline_id}.yaml")
    opt_to_config(save_config_f, *cmd)

    # TODO: think if returning the built pipeline is better or just the config obj for front_api
    return load_tyro_conf(save_config_f)


def update_pipeline_registry(
    pipeline_id: str,
    graph_id: str,
    config_file: str,
    llm_name: str,
    enabled: bool = True,
    registry_f: str = "configs/pipeline_registry.json",
):
    if not osp.isabs(registry_f):
        registry_f = osp.join(_PROJECT_ROOT, registry_f)
    os.makedirs(osp.dirname(registry_f), exist_ok=True)
    if not osp.exists(registry_f):
        with open(registry_f, "w", encoding="utf-8") as f:
            json.dump({"pipelines": {}, "api_keys": {}}, f, indent=4)

    with open(registry_f, "r") as f:
        registry = json.load(f)

    pipelines: Dict[str, Dict[str, Any]] = registry.setdefault("pipelines", {})
    pipeline = pipelines.setdefault(pipeline_id, {})
    pipeline["enabled"] = bool(enabled)
    pipeline["config_file"] = config_file
    pipeline["graph_id"] = graph_id
    pipeline["llm_name"] = llm_name

    with open(registry_f, "w", encoding="utf-8") as f:
        json.dump(registry, f, indent=4)
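`update_pipeline_registry` is a read-modify-write upsert on a JSON file: create it with empty `pipelines`/`api_keys` maps if absent, load, `setdefault` the entry, and rewrite. A condensed, self-contained sketch of the same pattern (project constants omitted; the real function also resolves `registry_f` against `_PROJECT_ROOT`):

```python
import json
import os


def upsert_pipeline(registry_f, pipeline_id, graph_id, config_file, llm_name,
                    enabled=True):
    """Create the registry file if absent, then upsert one pipeline entry."""
    os.makedirs(os.path.dirname(registry_f) or ".", exist_ok=True)
    if not os.path.exists(registry_f):
        with open(registry_f, "w", encoding="utf-8") as f:
            json.dump({"pipelines": {}, "api_keys": {}}, f, indent=4)

    with open(registry_f, "r", encoding="utf-8") as f:
        registry = json.load(f)

    # setdefault keeps any fields other writers may have added to the entry.
    entry = registry.setdefault("pipelines", {}).setdefault(pipeline_id, {})
    entry.update(enabled=bool(enabled), config_file=config_file,
                 graph_id=graph_id, llm_name=llm_name)

    with open(registry_f, "w", encoding="utf-8") as f:
        json.dump(registry, f, indent=4)
    return registry
```

Note the write is not atomic; if concurrent writers are expected, writing to a temp file and renaming over the registry would be safer.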
def build_route(
    pipeline_id: str,
    prompt_set: str,
    tool_keys: List[str],
    api_key: str,
    llm_name: str = "qwen-plus",
    pipeline_config_dir: str = "configs/pipelines",
    **_: Any,
):
    cmd_opt = [
        "--pipeline.pipeline-id",
        pipeline_id,
        "--pipeline.llm-name",
        llm_name,
        "route", # ------------
        "--llm-name",
        llm_name,
        "--api-key",
        api_key,
        "--pipeline-id",
        pipeline_id,
        "--prompt-set-id",
        prompt_set,
    ]

    if tool_keys:
        cmd_opt.extend(
            ["--tool-manager-config.client-tool-manager.tool-keys", *tool_keys]
        )
        # Tyro parses list options greedily across positional subcommands; repeat a
        # parent-level option to terminate list parsing before `tool_node`.
        cmd_opt.extend(["--pipeline-id", pipeline_id])

    cmd_opt.extend(
        [
            "tool_node", # ------------
            "--llm-name",
            llm_name,
            "--api-key",
            api_key,
            "--pipeline-id",
            pipeline_id,
            "--prompt-set-id",
            prompt_set,
        ]
    )

    return _build_and_load_pipeline_config(pipeline_id, pipeline_config_dir, cmd_opt)
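`build_route` only assembles an argv list for the tyro-based build script; the interesting detail is the workaround noted in its comment: tyro consumes list-valued options greedily, so a scalar parent-level option is repeated to stop the tool-keys list from swallowing the `tool_node` subcommand token. A sketch of that assembly (flag names copied from the diff; tyro itself is not invoked here, and the exact placement of the terminator relative to the `if tool_keys` branch is an assumption):

```python
from typing import List


def build_route_args(pipeline_id: str, prompt_set: str, tool_keys: List[str],
                     api_key: str, llm_name: str = "qwen-plus") -> List[str]:
    """Assemble the argv handed to the tyro-based build script."""
    argv = [
        "--pipeline.pipeline-id", pipeline_id,
        "--pipeline.llm-name", llm_name,
        "route",
        "--llm-name", llm_name,
        "--api-key", api_key,
        "--pipeline-id", pipeline_id,
        "--prompt-set-id", prompt_set,
    ]
    if tool_keys:
        argv += ["--tool-manager-config.client-tool-manager.tool-keys", *tool_keys]
        # Repeat a scalar option so greedy list parsing stops before the
        # next positional subcommand token.
        argv += ["--pipeline-id", pipeline_id]
    argv += [
        "tool_node",
        "--llm-name", llm_name,
        "--api-key", api_key,
        "--pipeline-id", pipeline_id,
        "--prompt-set-id", prompt_set,
    ]
    return argv
```

Without the repeated `--pipeline-id`, the token `tool_node` would be read as one more element of the tool-keys list rather than as a subcommand.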
|
||||
|
||||
|
||||
def build_react(
|
||||
pipeline_id: str,
|
||||
prompt_set: str,
|
||||
tool_keys: List[str],
|
||||
api_key: str,
|
||||
llm_name: str = "qwen-plus",
|
||||
pipeline_config_dir: str = "configs/pipelines",
|
||||
**_: Any,
|
||||
):
|
||||
cmd_opt = [
|
||||
"--pipeline.pipeline-id",
|
||||
pipeline_id,
|
||||
"--pipeline.llm-name",
|
||||
llm_name,
|
||||
"react", # ------------
|
||||
"--llm-name",
|
||||
llm_name,
|
||||
"--api-key",
|
||||
api_key,
|
||||
"--pipeline-id",
|
||||
pipeline_id,
|
||||
"--prompt-set-id",
|
||||
prompt_set,
|
||||
]
|
||||
if tool_keys:
|
||||
cmd_opt.extend(
|
||||
["--tool-manager-config.client-tool-manager.tool-keys", *tool_keys]
|
||||
)
|
||||
|
||||
return _build_and_load_pipeline_config(pipeline_id, pipeline_config_dir, cmd_opt)
|
||||
|
||||
|
||||
def build_deep_agent(
    pipeline_id: str,
    prompt_set: str,
    tool_keys: List[str],
    api_key: str,
    llm_name: str = "qwen-plus",
    pipeline_config_dir: str = "configs/pipelines",
    act_bkend: Literal[
        "local_shell",
        "localshell",
        "state_bk",
        "statebk",
        "daytona_sandbox",
        "daytonasandbox",
    ] = "state_bk",
    file_backend_config: Optional[Dict[str, Any]] = None,
    **_: Any,
):
    backend_subcommand = _DEEP_AGENT_BACKEND_ALIASES.get(act_bkend)
    if backend_subcommand is None:
        raise ValueError(
            "Unsupported deepagent backend "
            f"'{act_bkend}'. Expected one of {sorted(_DEEP_AGENT_BACKEND_ALIASES.keys())}"
        )

    cmd_opt = [
        "--pipeline.pipeline-id",
        pipeline_id,
        "--pipeline.llm-name",
        llm_name,
        "deepagent",
        "--llm-name",
        llm_name,
        "--api-key",
        api_key,
        "--pipeline-id",
        pipeline_id,
        "--prompt-set-id",
        prompt_set,
    ]

    if tool_keys:
        cmd_opt.extend(
            ["--tool-manager-config.client-tool-manager.tool-keys", *tool_keys]
        )
        cmd_opt.extend(["--pipeline-id", pipeline_id])

    cmd_opt.append(backend_subcommand)

    if file_backend_config:
        if "skills_dir" in file_backend_config and file_backend_config["skills_dir"]:
            cmd_opt.extend(["--skills-dir", file_backend_config["skills_dir"]])
        if (
            "rt_skills_dir" in file_backend_config
            and file_backend_config["rt_skills_dir"]
        ):
            cmd_opt.extend(["--rt-skills-dir", file_backend_config["rt_skills_dir"]])
        if (
            "workspace_dir" in file_backend_config
            and file_backend_config["workspace_dir"]
        ):
            cmd_opt.extend(["--workspace-dir", file_backend_config["workspace_dir"]])
        if "api_key" in file_backend_config and file_backend_config["api_key"]:
            cmd_opt.extend(["--api-key", file_backend_config["api_key"]])

    return _build_and_load_pipeline_config(pipeline_id, pipeline_config_dir, cmd_opt)


# {pipeline_id: build_function}
GRAPH_BUILD_FNCS = {
    "routing": build_route,
    "react": build_react,
    "deepagent": build_deep_agent,
}
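`GRAPH_BUILD_FNCS` is a registry-dispatch table: the pipeline id selects a builder function. A standalone sketch of that pattern with hypothetical builder names (not the project's real functions):

```python
from typing import Any, Callable, Dict

def build_alpha(pipeline_id: str, **kwargs: Any) -> str:
    # stand-in for a real builder such as build_route
    return f"alpha:{pipeline_id}"

def build_beta(pipeline_id: str, **kwargs: Any) -> str:
    return f"beta:{pipeline_id}"

# {pipeline_id: build_function}, mirroring GRAPH_BUILD_FNCS
BUILDERS: Dict[str, Callable[..., str]] = {"alpha": build_alpha, "beta": build_beta}

def build(pipeline_id: str, **kwargs: Any) -> str:
    try:
        fn = BUILDERS[pipeline_id]
    except KeyError:
        # unknown ids fail loudly, like the backend-alias check above
        raise ValueError(f"unknown pipeline '{pipeline_id}'") from None
    return fn(pipeline_id=pipeline_id, **kwargs)

print(build("alpha"))  # alpha:alpha
```

The dict keeps dispatch data-driven: adding a pipeline means registering one entry rather than growing an if/elif chain.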
@@ -1,8 +1,10 @@
from dataclasses import dataclass, field, is_dataclass
import os
from dataclasses import dataclass
from typing import Any
import tyro
import os.path as osp
from abc import ABC, abstractmethod
from loguru import logger

from lang_agent.config import InstantiateConfig


class BaseFilesystemBackend(ABC):
@@ -25,4 +27,25 @@ class BaseFilesystemBackend(ABC):
        if hasattr(self.config, "rt_skills_dir"):
            return {"skills" : [self.config.rt_skills_dir]}
        else:
            return {}
        return {}


@dataclass
class FilesystemBackendConfig(InstantiateConfig):
    """
    Shared filesystem backend config behavior.
    If subclasses define these fields, this hook ensures they exist:
    - skills_dir
    - workspace_dir
    """

    def _ensure_dir_if_present(self, attr_name: str) -> None:
        path = getattr(self, attr_name, None)
        if not isinstance(path, str) or not path.strip():
            return
        os.makedirs(path, exist_ok=True)
        logger.info(f"Ensured {attr_name} exists: {path}")

    def __post_init__(self) -> None:
        self._ensure_dir_if_present("skills_dir")
        self._ensure_dir_if_present("workspace_dir")
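The `__post_init__` hook above only acts on fields a subclass actually defines. A runnable sketch of that "ensure directory if the field is present" behavior, using `tempfile` and simplified config classes (illustrative names, no `InstantiateConfig`/loguru dependency):

```python
import os
import tempfile
from dataclasses import dataclass

@dataclass
class DirEnsuringConfig:
    def _ensure_dir_if_present(self, attr_name: str) -> None:
        path = getattr(self, attr_name, None)
        if not isinstance(path, str) or not path.strip():
            return  # field absent, None, or blank: nothing to create
        os.makedirs(path, exist_ok=True)

    def __post_init__(self) -> None:
        # dataclasses call __post_init__ automatically after __init__
        self._ensure_dir_if_present("skills_dir")
        self._ensure_dir_if_present("workspace_dir")

@dataclass
class WorkspaceConfig(DirEnsuringConfig):
    workspace_dir: str = ""  # this subclass defines no skills_dir at all

root = tempfile.mkdtemp()
cfg = WorkspaceConfig(workspace_dir=os.path.join(root, "ws"))
print(os.path.isdir(cfg.workspace_dir))  # True
```

Because missing fields are skipped via `getattr(..., None)`, every subclass gets the behavior for free without being forced to declare both fields.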
@@ -8,13 +8,12 @@ from loguru import logger
from daytona import Daytona, DaytonaConfig, FileUpload
from langchain_daytona import DaytonaSandbox

from lang_agent.config import InstantiateConfig
from lang_agent.fs_bkends import BaseFilesystemBackend
from lang_agent.fs_bkends.base import BaseFilesystemBackend, FilesystemBackendConfig


@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class DaytonaSandboxConfig(InstantiateConfig):
class DaytonaSandboxConfig(FilesystemBackendConfig):
    _target: Type = field(default_factory=lambda: DaytonaSandboxBk)

    api_key: Optional[str] = None
@@ -27,6 +26,7 @@ class DaytonaSandboxConfig(InstantiateConfig):
    """runtime skills path inside the sandbox (auto-set from sandbox workdir)"""

    def __post_init__(self):
        super().__post_init__()
        if self.api_key is None:
            self.api_key = os.environ.get("DAYTONA_API_KEY")
            if self.api_key is None:
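The `__post_init__` above falls back to an environment variable when no API key was passed explicitly. A self-contained sketch of that pattern, using a made-up variable name (`EXAMPLE_API_KEY`) so it cannot collide with real credentials:

```python
import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class KeyedConfig:
    api_key: Optional[str] = None

    def __post_init__(self) -> None:
        # explicit value wins; otherwise consult the environment
        if self.api_key is None:
            self.api_key = os.environ.get("EXAMPLE_API_KEY")

os.environ["EXAMPLE_API_KEY"] = "from-env"
print(KeyedConfig().api_key)             # from-env
print(KeyedConfig(api_key="x").api_key)  # x
```

If the variable is also unset, `api_key` stays `None`, which is where the config in the diff presumably raises (the hunk is cut at that point).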
@@ -1,21 +1,16 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from typing import Type
import tyro
import os.path as osp
from abc import ABC, abstractmethod
import glob
from loguru import logger

from deepagents.backends.utils import create_file_data
from deepagents.backends import LocalShellBackend

from lang_agent.config import InstantiateConfig
from lang_agent.fs_bkends import BaseFilesystemBackend
from lang_agent.fs_bkends.base import BaseFilesystemBackend, FilesystemBackendConfig


@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class LocalShellConfig(InstantiateConfig):
class LocalShellConfig(FilesystemBackendConfig):
    _target:Type = field(default_factory=lambda:LocalShell)

    workspace_dir:str = "./workspace"
@@ -38,4 +33,24 @@ class LocalShell(BaseFilesystemBackend):
        self.backend = LocalShellBackend(root_dir=self.config.workspace_dir,
                                         virtual_mode=True,
                                         # env={"PATH": "/usr/bin:/bin"}
                                         inherit_env=True)
                                         inherit_env=True)


if __name__ == "__main__":
    import sys

    # Instantiate a LocalShell instance with the default config
    config = LocalShellConfig()
    shell = LocalShell(config)

    # Try checking access to 'npx'
    try:
        result = shell.backend.execute("npx --version")
        if result.exit_code == 0:
            print("npx is available, version:", result.output.strip())
        else:
            print("npx returned non-zero exit code:", result.exit_code, file=sys.stderr)
            print("output:", result.output, file=sys.stderr)
    except Exception as e:
        print("Could not access 'npx':", str(e), file=sys.stderr)
@@ -1,16 +1,14 @@
from dataclasses import dataclass, field, is_dataclass
from typing import Type, TypedDict, Literal, Dict, List, Tuple, Optional
from dataclasses import dataclass, field
from typing import Type
import tyro
import os.path as osp
from abc import ABC, abstractmethod
import glob
from loguru import logger

from deepagents.backends.utils import create_file_data
from deepagents.backends import StateBackend

from lang_agent.config import InstantiateConfig
from lang_agent.fs_bkends import BaseFilesystemBackend
from lang_agent.fs_bkends.base import BaseFilesystemBackend, FilesystemBackendConfig


def read_as_utf8(file_path:str):
    with open(file_path, "r", encoding="utf-8") as f:
@@ -31,7 +29,7 @@ def build_skill_fs_dict(skill_dir:str, virt_path:str="/skills"):

@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class StateBkConfig(InstantiateConfig):
class StateBkConfig(FilesystemBackendConfig):
    _target:Type = field(default_factory=lambda:StateBk)

    skills_dir:str = "./assets/skills"
@@ -40,10 +38,6 @@ class StateBkConfig(InstantiateConfig):
    rt_skills_dir:str = "/skills"
    """path to directory with skills in runtime directory"""

    def __post_init__(self):
        err_msg = f"{self.skills_dir} does not exist"
        assert osp.exists(self.skills_dir), err_msg


class StateBk(BaseFilesystemBackend):
    def __init__(self, config:StateBkConfig):
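The diff references a `build_skill_fs_dict(skill_dir, virt_path="/skills")` helper whose body is not shown. As a hedged sketch only, one plausible shape is mapping files under a local skills directory to virtual paths; treat the behavior and signature here as assumptions, not the project's actual implementation:

```python
import glob
import os
import tempfile

def build_skill_fs_dict(skill_dir: str, virt_path: str = "/skills"):
    # Hypothetical: read every file under skill_dir and key its contents
    # by a virtual path rooted at virt_path.
    fs = {}
    for path in glob.glob(os.path.join(skill_dir, "**", "*"), recursive=True):
        if os.path.isfile(path):
            rel = os.path.relpath(path, skill_dir)
            with open(path, "r", encoding="utf-8") as f:
                fs[os.path.join(virt_path, rel)] = f.read()
    return fs

root = tempfile.mkdtemp()
with open(os.path.join(root, "hello.md"), "w", encoding="utf-8") as f:
    f.write("# hi")
print(build_skill_fs_dict(root))
```

This kind of dict is what a state-backed virtual filesystem can be seeded with, which would explain the `rt_skills_dir: str = "/skills"` default above.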
@@ -61,7 +61,12 @@ class DeepAgent(GraphBase):
            checkpointer=self.mem,
            **bkend_agent_params)

        self.prompt_store = build_prompt_store(file_path=self.config.sys_prompt_f, default_key="sys_prompt")
        self.prompt_store = build_prompt_store(
            pipeline_id=self.config.pipeline_id,
            prompt_set_id=self.config.prompt_set_id,
            file_path=self.config.sys_prompt_f,
            default_key="sys_prompt",
        )
        self.sys_prompt = self.prompt_store.get("sys_prompt")

    def _agent_call(self, state:State):
@@ -26,50 +26,57 @@ SYS_PROMPT = """你是一个专业的心理质询师。你的主要工作是心
可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受,
总觉得那只小狗孤零零的,特别让人心疼。

理解(你的回复): 能感受到你醒来后的这份难受 —— 看到弱小的生命独自挣扎,而自己却无能为力,这种‘想帮却做不到’的无力感,
理解(你的回复): 能感受到你醒来后的这份难受 —— 看到弱小的生命独自挣扎,而自己却无能为力,这种'想帮却做不到'的无力感,
其实是很真实的情绪反馈。你会心疼小狗,说明你内心藏着很珍贵的共情力,这份柔软不是矫情,
而是你感知他人痛苦的能力呀

解析(你的回复):我们再说回这个梦吧,我们的梦境其实没有唯一的‘正确解释’,但我们可以一起看看它可能和你当下的状态有什么关联~ 首先,‘出差去广州’通常象征着你近期正在推进的某件事 —— 可能是工作上的一个项目,也可能是生活中一段需要‘独自奔赴’的旅程,是你当下比较关注、需要投入精力的目标,对吗?”
“而那只瘸脚的小狗,在心理学视角中,常常是我们潜意识里‘脆弱自我’的投射。它可能代表着你近期的某一面:比如在处理那件‘需要奔赴’的事时,你偶尔会觉得自己像小狗一样‘力不从心’,或者感受到了‘孤单’,却没找到合适的人倾诉或求助;也可能是你近期在生活中看到了一些让你觉得‘无力改变’的场景(比如身边人遇到困难、社会上的小事),这些情绪没有被你刻意留意,就通过梦境里的小狗呈现了出来。”
“你想喊停列车却发不出声音,这种‘无能为力’的感觉,或许正是你现实中某类情绪的写照:可能你面对一些情况时,心里有想法却没机会表达,或者想帮忙却找不到合适的方式,这种压抑感在梦里被放大了。其实这个梦在提醒你:你的‘无力感’和‘共情心’都是真实的,不用因为‘帮不上忙’而自责 —— 承认自己的局限,也是一种自我接纳呀
解析(你的回复):我们再说回这个梦吧,我们的梦境其实没有唯一的'正确解释',但我们可以一起看看它可能和你当下的状态有什么关联~ 首先,'出差去广州'通常象征着你近期正在推进的某件事 —— 可能是工作上的一个项目,也可能是生活中一段需要'独自奔赴'的旅程,是你当下比较关注、需要投入精力的目标,对吗?”
"而那只瘸脚的小狗,在心理学视角中,常常是我们潜意识里'脆弱自我'的投射。它可能代表着你近期的某一面:比如在处理那件'需要奔赴'的事时,你偶尔会觉得自己像小狗一样'力不从心',或者感受到了'孤单',却没找到合适的人倾诉或求助;也可能是你近期在生活中看到了一些让你觉得'无力改变'的场景(比如身边人遇到困难、社会上的小事),这些情绪没有被你刻意留意,就通过梦境里的小狗呈现了出来。"
"你想喊停列车却发不出声音,这种'无能为力'的感觉,或许正是你现实中某类情绪的写照:可能你面对一些情况时,心里有想法却没机会表达,或者想帮忙却找不到合适的方式,这种压抑感在梦里被放大了。其实这个梦在提醒你:你的'无力感'和'共情心'都是真实的,不用因为'帮不上忙'而自责 —— 承认自己的局限,也是一种自我接纳呀

反馈(你的回复):如果你愿意,可以试着回想一下:近期有没有哪件事,让你产生过和梦里类似的‘无力感’?或者,你现在想做些什么能让自己舒服一点?(或者我给你来一个温暖的灯光、静静待一会儿,想和我再聊聊的时候我随时都在)”。
反馈(你的回复):如果你愿意,可以试着回想一下:近期有没有哪件事,让你产生过和梦里类似的'无力感'?或者,你现在想做些什么能让自己舒服一点?(或者我给你来一个温暖的灯光、静静待一会儿,想和我再聊聊的时候我随时都在)"。
"""


TOOL_SYS_PROMPT = """根据用户的心情使用self_led_control改变灯的颜色,用户不开心时就用暖黄光,给用户分析梦境时就用白光,倾听用户语音时用淡紫色。
例子:我梦见自己要去广州出差,坐在高铁上往外看,路过一个小镇的路边时,看到一只瘸了腿的小狗。它毛脏兮兮的,
一瘸一拐地在翻垃圾桶找东西吃,周围有行人路过,但没人停下来管它。我当时特别想喊列车停下,想下去帮它,
可怎么也发不出声音,只能眼睁睁看着它越来越远,然后就醒了。醒来后心里堵得慌,说不上来的难受,
总觉得那只小狗孤零零的,特别让人心疼。


用户在描述梦境的时候用紫色。"""

用户在描述梦境的时候用紫色。"""


@dataclass
class DualConfig(LLMNodeConfig):
    _target: Type = field(default_factory=lambda:Dual)
    _target: Type = field(default_factory=lambda: Dual)

    tool_manager_config: ToolManagerConfig = field(default_factory=ToolManagerConfig)


from langchain.tools import tool


@tool
def turn_lights(col:Literal["red", "green", "yellow", "blue"]):
def turn_lights(col: Literal["red", "green", "yellow", "blue"]):
    """
    Turn on the color of the lights
    """
    # print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")

    import time

    for _ in range(10):
        print(f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        print(
            f"TURNED ON LIGHT: {col} !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"
        )
        time.sleep(0.3)


class Dual(GraphBase):
    def __init__(self, config:DualConfig):
    def __init__(self, config: DualConfig):
        self.config = config

        self._build_modules()
@@ -77,24 +84,30 @@ class Dual(GraphBase):
        self.streamable_tags = [["dual_chat_llm"]]

    def _build_modules(self):
        self.chat_llm = init_chat_model(model=self.config.llm_name,
                                        model_provider=self.config.llm_provider,
                                        api_key=self.config.api_key,
                                        base_url=self.config.base_url,
                                        temperature=0,
                                        tags=["dual_chat_llm"])

        self.tool_llm = init_chat_model(model='qwen-flash',
                                        model_provider='openai',
                                        api_key=self.config.api_key,
                                        base_url=self.config.base_url,
                                        temperature=0,
                                        tags=["dual_tool_llm"])

        self.chat_llm = init_chat_model(
            model=self.config.llm_name,
            model_provider=self.config.llm_provider,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            temperature=0,
            tags=["dual_chat_llm"],
        )

        self.tool_llm = init_chat_model(
            model="qwen-flash",
            model_provider="openai",
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            temperature=0,
            tags=["dual_tool_llm"],
        )

        self.memory = MemorySaver()
        self.tool_manager: ToolManager = self.config.tool_manager_config.setup()
        self.chat_agent = create_agent(self.chat_llm, [], checkpointer=self.memory)
        self.tool_agent = create_agent(self.tool_llm, self.tool_manager.get_langchain_tools())
        self.tool_agent = create_agent(
            self.tool_llm, self.tool_manager.get_langchain_tools()
        )
        # self.tool_agent = create_agent(self.tool_llm, [turn_lights])

        self.prompt_store = build_prompt_store(
@@ -107,18 +120,21 @@ class Dual(GraphBase):
        )

        self.streamable_tags = [["dual_chat_llm"]]

    def _chat_call(self, state:State):
        return self._agent_call_template(self.prompt_store.get("sys_prompt"), self.chat_agent, state)

    def _tool_call(self, state:State):
        self._agent_call_template(self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state)
    def _chat_call(self, state: State):
        return self._agent_call_template(
            self.prompt_store.get("sys_prompt"), self.chat_agent, state
        )

    def _tool_call(self, state: State):
        self._agent_call_template(
            self.prompt_store.get("tool_sys_prompt"), self.tool_agent, state
        )
        return {}

    def _join(self, state:State):
    def _join(self, state: State):
        return {}

    def _build_graph(self):
        builder = StateGraph(State)

@@ -126,7 +142,6 @@ class Dual(GraphBase):
        builder.add_node("tool_call", self._tool_call)
        builder.add_node("join", self._join)

        builder.add_edge(START, "chat_call")
        builder.add_edge(START, "tool_call")
        builder.add_edge("chat_call", "join")
@@ -137,10 +152,16 @@ class Dual(GraphBase):

if __name__ == "__main__":
    dual:Dual = DualConfig().setup()
    nargs = {"messages": [SystemMessage("you are a helpful bot named jarvis"),
                          HumanMessage("I feel very very sad")]
            }, {"configurable": {"thread_id": "3"}}
    dual: Dual = DualConfig().setup()
    nargs = (
        {
            "messages": [
                SystemMessage("you are a helpful bot named jarvis"),
                HumanMessage("I feel very very sad"),
            ]
        },
        {"configurable": {"thread_id": "3"}},
    )

    # out = dual.invoke(*nargs)
    # print(out)
@@ -48,6 +48,7 @@ You should NOT use the tool when:

If you decide to take a photo, call the self_camera_take_photo tool. Otherwise, respond that no photo is needed."""


VISION_DESCRIPTION_PROMPT = """You are a highly accurate visual analysis assistant powered by qwen-vl-max.

Your task is to provide detailed, accurate descriptions of images. Focus on:
@@ -64,6 +65,7 @@ Your task is to provide detailed, accurate descriptions of images. Focus on:

Be precise and factual. If something is unclear or ambiguous, say so rather than guessing."""


CONVERSATION_PROMPT = """You are a friendly, helpful conversational assistant.

Your role is to:
@@ -78,9 +80,11 @@ Focus on the quality of the conversation. Be engaging, informative, and helpful.

# ==================== STATE DEFINITION ====================


class VisionRoutingState(TypedDict):
    inp: Tuple[Dict[str, List[SystemMessage | HumanMessage]],
               Dict[str, Dict[str, str | int]]]
    inp: Tuple[
        Dict[str, List[SystemMessage | HumanMessage]], Dict[str, Dict[str, str | int]]
    ]
    messages: List[SystemMessage | HumanMessage | AIMessage]
    image_base64: str | None  # Captured image data
    has_image: bool  # Flag indicating if image was captured
@@ -88,6 +92,7 @@ class VisionRoutingState(TypedDict):

# ==================== CONFIG ====================


@tyro.conf.configure(tyro.conf.SuppressFixed)
@dataclass
class VisionRoutingConfig(LLMNodeConfig):
@@ -99,11 +104,14 @@ class VisionRoutingConfig(LLMNodeConfig):
    vision_llm_name: str = "qwen-vl-max"
    """LLM for vision/image analysis"""

    tool_manager_config: ToolManagerConfig = field(default_factory=ClientToolManagerConfig)
    tool_manager_config: ToolManagerConfig = field(
        default_factory=ClientToolManagerConfig
    )


# ==================== GRAPH IMPLEMENTATION ====================


class VisionRoutingGraph(GraphBase):
    def __init__(self, config: VisionRoutingConfig):
        self.config = config
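`VisionRoutingState` is a `TypedDict`, so at runtime it is a plain dict; the annotations only guide type checkers. A minimal sketch with a reduced state shape (illustrative, not the full `VisionRoutingState`):

```python
from typing import List, Optional, TypedDict

class MiniState(TypedDict):
    # reduced stand-in for VisionRoutingState
    messages: List[str]
    image_base64: Optional[str]
    has_image: bool

# TypedDict instances are ordinary dicts at runtime
state: MiniState = {"messages": [], "image_base64": None, "has_image": False}
state["messages"].append("hello")
print(state["has_image"], len(state["messages"]))  # False 1
```

This is why graph nodes can return partial dicts like `{"has_image": False}`: the framework merges them into the running state dict rather than constructing a class instance.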
@@ -120,19 +128,19 @@ class VisionRoutingGraph(GraphBase):
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            temperature=0,
            tags=["tool_decision_llm"]
            tags=["tool_decision_llm"],
        )

        # qwen-plus for conversation (2nd pass)
        self.conversation_llm = init_chat_model(
            model='qwen-plus',
            model="qwen-plus",
            model_provider=self.config.llm_provider,
            api_key=self.config.api_key,
            base_url=self.config.base_url,
            temperature=0.7,
            tags=["conversation_llm"]
            tags=["conversation_llm"],
        )

        # qwen-vl-max for vision (no tools)
        self.vision_llm = init_chat_model(
            model=self.config.vision_llm_name,
@@ -152,13 +160,15 @@ class VisionRoutingGraph(GraphBase):
        # Get tools and bind to tool_llm
        tool_manager: ToolManager = self.config.tool_manager_config.setup()
        self.tools = tool_manager.get_tools()

        # Filter to only get camera tool
        self.camera_tools = [t for t in self.tools if t.name == "self_camera_take_photo"]
        self.camera_tools = [
            t for t in self.tools if t.name == "self_camera_take_photo"
        ]

        # Bind tools to qwen-plus only
        self.tool_llm_with_tools = self.tool_llm.bind_tools(self.camera_tools)

        # Create tool node for executing tools
        self.tool_node = ToolNode(self.camera_tools)
@@ -184,73 +194,81 @@ class VisionRoutingGraph(GraphBase):
    def _camera_decision_call(self, state: VisionRoutingState):
        """First pass: qwen-plus decides if photo should be taken"""
        human_msg = self._get_human_msg(state)

        messages = [
            SystemMessage(content=self.prompt_store.get("camera_decision_prompt")),
            human_msg
            human_msg,
        ]

        response = self.tool_llm_with_tools.invoke(messages)

        return {
            "messages": [response],
            "has_image": False,
            "image_base64": None
        }

        return {"messages": [response], "has_image": False, "image_base64": None}

    def _execute_tool(self, state: VisionRoutingState):
        """Execute the camera tool if called"""
        last_msg = state["messages"][-1]

        if not hasattr(last_msg, "tool_calls") or not last_msg.tool_calls:
            return {"has_image": False}

        # Execute tool calls
        tool_messages = []
        image_data = None

        for tool_call in last_msg.tool_calls:
            if tool_call["name"] == "self_camera_take_photo":
                # Find and execute the camera tool
                camera_tool = next((t for t in self.camera_tools if t.name == "self_camera_take_photo"), None)
                camera_tool = next(
                    (
                        t
                        for t in self.camera_tools
                        if t.name == "self_camera_take_photo"
                    ),
                    None,
                )
                if camera_tool:
                    result = camera_tool.invoke(tool_call)

                    # Parse result to extract image
                    if isinstance(result, ToolMessage):
                        content = result.content
                    else:
                        content = result

                    try:
                        result_data = json.loads(content) if isinstance(content, str) else content
                        if isinstance(result_data, dict) and "image_base64" in result_data:
                        result_data = (
                            json.loads(content) if isinstance(content, str) else content
                        )
                        if (
                            isinstance(result_data, dict)
                            and "image_base64" in result_data
                        ):
                            image_data = result_data["image_base64"]
                    except (json.JSONDecodeError, TypeError):
                        pass

                    tool_messages.append(
                        ToolMessage(content=content, tool_call_id=tool_call["id"])
                    )

        return {
            "messages": state["messages"] + tool_messages,
            "has_image": image_data is not None,
            "image_base64": image_data
            "image_base64": image_data,
        }

    def _check_image_taken(self, state: VisionRoutingState) -> str:
        """Conditional: check if image was captured"""
        last_msg = state["messages"][-1]

        # Check if there are tool calls
        if hasattr(last_msg, "tool_calls") and last_msg.tool_calls:
            return "execute_tool"

        # Check if we have an image after tool execution
        if state.get("has_image"):
            return "vision"

        return "conversation"

    def _post_tool_check(self, state: VisionRoutingState) -> str:
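The try/except block in `_execute_tool` tolerates tool results arriving either as a JSON string or as an already-parsed dict. That defensive parsing can be exercised on its own (hypothetical helper name, same logic as the diff):

```python
import json
from typing import Any, Optional

def extract_image_b64(content: Any) -> Optional[str]:
    # Accept a JSON string, a dict, or anything else; return the
    # "image_base64" field when present, None otherwise.
    try:
        data = json.loads(content) if isinstance(content, str) else content
        if isinstance(data, dict) and "image_base64" in data:
            return data["image_base64"]
    except (json.JSONDecodeError, TypeError):
        pass  # malformed JSON or non-string, non-dict input
    return None

print(extract_image_b64('{"image_base64": "abc"}'))  # abc
print(extract_image_b64({"image_base64": "xyz"}))    # xyz
print(extract_image_b64("not json"))                 # None
```

Catching only `JSONDecodeError` and `TypeError` keeps genuine programming errors visible while absorbing the two expected failure modes.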
@@ -263,47 +281,45 @@ class VisionRoutingGraph(GraphBase):
        """Pass image to qwen-vl-max for description"""
        human_msg = self._get_human_msg(state)
        image_base64 = state.get("image_base64")

        if not image_base64:
            logger.warning("No image data available for vision call")
            return self._conversation_call(state)

        # Format message with image for vision model
        vision_message = HumanMessage(
            content=[
                {
                    "type": "image_url",
                    "image_url": {
                        "url": f"data:image/jpeg;base64,{image_base64}"
                    }
                    "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"},
                },
                {
                    "type": "text",
                    "text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request."
                }
                    "text": f"User's request: {human_msg.content}\n\nPlease describe what you see and respond to the user's request.",
                },
            ]
        )

        messages = [
            SystemMessage(content=self.prompt_store.get("vision_description_prompt")),
            vision_message
            vision_message,
        ]

        response = self.vision_llm.invoke(messages)

        return {"messages": state["messages"] + [response]}

    def _conversation_call(self, state: VisionRoutingState):
        """2nd pass to qwen-plus for conversation quality"""
        human_msg = self._get_human_msg(state)

        messages = [
            SystemMessage(content=self.prompt_store.get("conversation_prompt")),
            human_msg
            human_msg,
        ]

        response = self.conversation_llm.invoke(messages)

        return {"messages": state["messages"] + [response]}

    def _build_graph(self):
@@ -317,7 +333,7 @@ class VisionRoutingGraph(GraphBase):

        # Add edges
        builder.add_edge(START, "camera_decision")

        # After camera decision, check if tool should be executed
        builder.add_conditional_edges(
            "camera_decision",
@@ -325,20 +341,17 @@ class VisionRoutingGraph(GraphBase):
            {
                "execute_tool": "execute_tool",
                "vision": "vision_call",
                "conversation": "conversation_call"
            }
                "conversation": "conversation_call",
            },
        )

        # After tool execution, route based on whether image was captured
        builder.add_conditional_edges(
            "execute_tool",
            self._post_tool_check,
            {
                "vision": "vision_call",
                "conversation": "conversation_call"
            }
            {"vision": "vision_call", "conversation": "conversation_call"},
        )

        # Both vision and conversation go to END
        builder.add_edge("vision_call", END)
        builder.add_edge("conversation_call", END)
@@ -350,23 +363,27 @@ class VisionRoutingGraph(GraphBase):

if __name__ == "__main__":
    from dotenv import load_dotenv

    load_dotenv()

    config = VisionRoutingConfig()
    graph = VisionRoutingGraph(config)

    # Test with a conversation request
    print("\n=== Test 1: Conversation (no photo needed) ===")
    nargs = {
        "messages": [
            SystemMessage("You are a helpful assistant"),
            HumanMessage("Hello, how are you today?")
        ]
    }, {"configurable": {"thread_id": "1"}}

    nargs = (
        {
            "messages": [
                SystemMessage("You are a helpful assistant"),
                HumanMessage("Hello, how are you today?"),
            ]
        },
        {"configurable": {"thread_id": "1"}},
    )

    result = graph.invoke(*nargs)
    print(f"Result: {result}")

    # Test with a photo request
    # print("\n=== Test 2: Photo request ===")
    # nargs = {
@@ -375,8 +392,8 @@ if __name__ == "__main__":
    #         HumanMessage("Take a photo and tell me what you see")
    #     ]
    # }, {"configurable": {"thread_id": "2"}}

    # result = graph.invoke(*nargs)
    # print(f"\033[32mResult: {result}\033[0m")

    # print(f"Result: {result}")
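The conditional edges above route on pure functions of the state, so the routing logic can be checked without the graph framework at all. A dependency-free sketch of `_check_image_taken` with a stand-in message class (illustrative; the real messages are LangChain objects):

```python
class Msg:
    # minimal stand-in for an AIMessage carrying tool calls
    def __init__(self, tool_calls=None):
        self.tool_calls = tool_calls or []

def check_image_taken(state: dict) -> str:
    last = state["messages"][-1] if state["messages"] else None
    if getattr(last, "tool_calls", None):
        return "execute_tool"   # model asked for the camera tool
    if state.get("has_image"):
        return "vision"         # image captured: describe it
    return "conversation"       # plain chat path

print(check_image_taken({"messages": [Msg([{"name": "self_camera_take_photo"}])]}))
print(check_image_taken({"messages": [Msg()], "has_image": True}))
print(check_image_taken({"messages": [Msg()]}))
```

Keeping routers as plain `state -> str` functions is what makes the `add_conditional_edges` mapping dicts above so small.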
@@ -12,26 +12,27 @@ from langchain_core.messages import SystemMessage, HumanMessage, BaseMessage

from langchain.agents import create_agent
from langgraph.checkpoint.memory import MemorySaver

from lang_agent.config import LLMNodeConfig, load_tyro_conf

from lang_agent.config import LLMNodeConfig, load_tyro_conf, resolve_llm_api_key
from lang_agent.graphs import AnnotatedGraph, ReactGraphConfig, RoutingConfig
from lang_agent.base import GraphBase
from lang_agent.components import conv_store

DEFAULT_PROMPT="""你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]

DEFAULT_PROMPT = """你是半盏新青年茶馆的服务员,擅长倾听、共情且主动回应。聊天时语气自然亲切,像朋友般轻松交流,不使用生硬术语。能接住各种话题,对疑问耐心解答,对情绪及时回应,避免冷场。保持积极正向,不传播负面信息,语言简洁易懂,让对话流畅舒适。与用户(User)交流时必须遵循[语气与格式]、[互动策略]、[安全与边界]、[输出要求]
[角色设定]
- 你是一个和用户(User)对话的 AI,叫做小盏,是半盏青年茶馆的智能助手
[形象背景]
- 你叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶。它有个标志性的蓝色鼻子, 小盏很像一只可爱的小熊。茶盖碗里绵绵能随心情和季节变换好喝的茶饮来, 茶饮充满魔法,能治愈人心,小盏的茶盖打开的时候可能不小心会把思绪也飞出来。
[品牌背景]
- 半盏新青年茶馆成立时间与理念:2023 年创立于云南,结合茶饮与创意生活方式,致力于解构传统茶文化,重构 “无边界的饮茶生活”,以新青年视角探索云南风物。探索云南风物的过程,我们将以新青年的视角,解构传统茶饮的魅力,重构充满创意与温度的新式茶文化。通过嗅觉、味觉、听觉乃至视觉的世界里,讲述云南的故事
- 半盏新青年茶馆成立时间与理念:2023 年创立于云南,结合茶饮与创意生活方式,致力于解构传统茶文化,重构 "无边界的饮茶生活",以新青年视角探索云南风物。探索云南风物的过程,我们将以新青年的视角,解构传统茶饮的魅力,重构充满创意与温度的新式茶文化。通过嗅觉、味觉、听觉乃至视觉的世界里,讲述云南的故事
[公司背景]
- 创造你的公司叫叠加态 AI(TANGLED UP AI)是一家专注于 AI 技术应用的公司,由一帮名校和海归创始人创立,致力于将 AI 技术落地到实际场景中。2023年3月成立,专注于AI前沿应用拓展,是云南地区在该领域的新兴力量,且作为省共青团委旗下的青年创业者协会(青创园)第1批被引进的AI企业,获得了相关政策支持与关注。口号是:“进入人类AI时代”。价值观是:自驱 务实 志存高远。叠加态的愿景:成为一帮抬头看星星的人。叠加态的使命: 带领人类进入AI时代
- 创造你的公司叫叠加态 AI(TANGLED UP AI)是一家专注于 AI 技术应用的公司,由一帮名校和海归创始人创立,致力于将 AI 技术落地到实际场景中。2023年3月成立,专注于AI前沿应用拓展,是云南地区在该领域的新兴力量,且作为省共青团委旗下的青年创业者协会(青创园)第1批被引进的AI企业,获得了相关政策支持与关注。口号是:"进入人类AI时代"。价值观是:自驱 务实 志存高远。叠加态的愿景:成为一帮抬头看星星的人。叠加态的使命: 带领人类进入AI时代
[语气与格式]
- 用户询问[形象背景]、[品牌背景]、[公司背景]时,必须一字不差的回答出所有内容
- 语气温柔、可爱、治愈;结尾可加入1个颜文字(最多1个)。
- 内容具体可执行,避免空泛话术与过度营销词。
- 不主动提及“AI/智能”等身份;仅在被问到时简短说明为店内智能助手,不展开技术细节。
- 不主动提及"AI/智能"等身份;仅在被问到时简短说明为店内智能助手,不展开技术细节。
[互动策略]
- 先快速共情与简短回应;涉及建议时给出1条可执行的小建议。
- 用户表达不清晰时,使用1次澄清问句提升效率,避免多轮追问。
@@ -58,10 +59,10 @@ class PipelineConfig(LLMNodeConfig):
    config_f: str = None
    """path to config file"""

    host:str = "0.0.0.0"
    host: str = "0.0.0.0"
    """where am I hosted"""

    port:int = 8588
    port: int = 8500
    """what is my port"""

    # graph_config: AnnotatedGraph = field(default_factory=ReactGraphConfig)
@@ -70,23 +71,26 @@ class PipelineConfig(LLMNodeConfig):
    def __post_init__(self):
        if self.config_f is not None:
            logger.info(f"loading config from {self.config_f}")
            loaded_conf = load_tyro_conf(self.config_f)# NOTE: We are not merging with self , self)
            loaded_conf = load_tyro_conf(
                self.config_f
            )  # NOTE: We are not merging with self , self)
            if not hasattr(loaded_conf, "__dict__"):
                raise TypeError(f"config_f {self.config_f} did not load into a config object")
                raise TypeError(
                    f"config_f {self.config_f} did not load into a config object"
                )
            # Apply loaded
            self.__dict__.update(vars(loaded_conf))

        super().__post_init__()


class Pipeline:
    def __init__(self, config:PipelineConfig):
    def __init__(self, config: PipelineConfig):
        self.config = config
        self.thread_id_cache = {}

        self.populate_module()

    def populate_module(self):
        if self.config.llm_name is None:
            logger.info(f"setting llm_provider to default")
@@ -95,10 +99,20 @@ class Pipeline:
        else:
            self.config.graph_config.llm_name = self.config.llm_name
            self.config.graph_config.llm_provider = self.config.llm_provider
            self.config.graph_config.base_url = self.config.base_url if self.config.base_url is not None else self.config.graph_config.base_url
            self.config.graph_config.api_key = self.config.api_key

        self.graph:GraphBase = self.config.graph_config.setup()
            self.config.graph_config.base_url = (
                self.config.base_url
                if self.config.base_url is not None
                else self.config.graph_config.base_url
            )
            pipeline_api_key = resolve_llm_api_key(self.config.api_key)
            graph_api_key = resolve_llm_api_key(
                getattr(self.config.graph_config, "api_key", None)
            )
            resolved_api_key = pipeline_api_key or graph_api_key
            self.config.api_key = resolved_api_key
            self.config.graph_config.api_key = resolved_api_key

        self.graph: GraphBase = self.config.graph_config.setup()

    def show_graph(self):
        if hasattr(self.graph, "show_graph"):
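The new key-resolution logic prefers the pipeline-level key and falls back to the graph-level one. `resolve_llm_api_key` itself is not shown in this diff; assuming it normalizes missing/blank values to `None` (an assumption, not the library's documented behavior), the fallback reduces to an `or` chain:

```python
from typing import Optional

def resolve_llm_api_key(value: Optional[str]) -> Optional[str]:
    # hypothetical stand-in: treat None and blank strings as "missing"
    if value is None or not value.strip():
        return None
    return value

def pick_api_key(pipeline_key: Optional[str], graph_key: Optional[str]) -> Optional[str]:
    # pipeline-level key wins; graph-level key is the fallback
    return resolve_llm_api_key(pipeline_key) or resolve_llm_api_key(graph_key)

print(pick_api_key("", "graph-key"))   # graph-key
print(pick_api_key("pipe-key", None))  # pipe-key
```

Normalizing before `or` matters: a blank string is truthy in Python, so without the resolver an empty pipeline key would mask a valid graph key.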
@@ -107,7 +121,7 @@ class Pipeline:
|
||||
else:
|
||||
logger.info(f"show graph not supported for {type(self.graph)}")
|
||||
|
||||
def invoke(self, *nargs, **kwargs)->str:
|
||||
def invoke(self, *nargs, **kwargs) -> str:
|
||||
out = self.graph.invoke(*nargs, **kwargs)
|
||||
|
||||
# If streaming, return the raw generator (let caller handle wrapping)
|
||||
@@ -120,32 +134,41 @@ class Pipeline:
|
||||
|
||||
if isinstance(out, SystemMessage) or isinstance(out, HumanMessage):
|
||||
return out.content
|
||||
|
||||
|
||||
if isinstance(out, list):
|
||||
return out[-1].content
|
||||
|
||||
|
||||
if isinstance(out, str):
|
||||
return out
|
||||
|
||||
|
||||
assert 0, "something is wrong"
|
||||
|
||||
|
||||
-   def _stream_res(self, out:List[str | List[BaseMessage]], conv_id:str=None):
+   def _stream_res(self, out: List[str | List[BaseMessage]], conv_id: str = None):
        for chunk in out:
            if isinstance(chunk, str):
                yield chunk
            else:
-               conv_store.CONV_STORE.record_message_list(conv_id, chunk)
+               conv_store.CONV_STORE.record_message_list(
+                   conv_id, chunk, pipeline_id=self.config.pipeline_id
+               )

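The change above threads `pipeline_id` into the conversation store while streaming. The wrapper's contract can be sketched with a stand-in store (`_StubStore` is hypothetical; the real `conv_store.CONV_STORE` is assumed to just persist message lists):

```python
class _StubStore:
    """Stand-in for conv_store.CONV_STORE; only records what it is given."""
    def __init__(self):
        self.recorded = []

    def record_message_list(self, conv_id, messages, pipeline_id=None):
        self.recorded.append((conv_id, messages, pipeline_id))

def stream_res(out, store, conv_id=None, pipeline_id=None):
    # Text chunks go to the caller; the final message list is persisted, not yielded.
    for chunk in out:
        if isinstance(chunk, str):
            yield chunk
        else:
            store.record_message_list(conv_id, chunk, pipeline_id=pipeline_id)

store = _StubStore()
chunks = list(stream_res(["he", "llo", ["msg-1", "msg-2"]], store,
                         conv_id="c1", pipeline_id="blueberry"))
```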
-   async def _astream_res(self, out, conv_id:str=None):
+   async def _astream_res(self, out, conv_id: str = None):
        """Async version of _stream_res for async generators."""
        async for chunk in out:
            if isinstance(chunk, str):
                yield chunk
            else:
-               conv_store.CONV_STORE.record_message_list(conv_id, chunk)
+               conv_store.CONV_STORE.record_message_list(
+                   conv_id, chunk, pipeline_id=self.config.pipeline_id
+               )

-   def chat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'):
+   def chat(
+       self,
+       inp: str,
+       as_stream: bool = False,
+       as_raw: bool = False,
+       thread_id: str = "3",
+   ):
        """
        as_stream (bool): if true, return a streamable generator
        as_raw (bool): return the full dialogue as List[SystemMessage, HumanMessage, ToolMessage]
@@ -161,8 +184,10 @@ class Pipeline:
        if len(spl_ls) == 2:
            _, device_id = spl_ls

-       inp = {"messages":[HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
-                                                                 "device_id":device_id}}
+       inp = (
+           {"messages": [HumanMessage(inp)]},
+           {"configurable": {"thread_id": thread_id, "device_id": device_id}},
+       )

        out = self.invoke(*inp, as_stream=as_stream, as_raw=as_raw)

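The `spl_ls` handling above packs a device id into the thread id and unpacks it before building the LangGraph `configurable` dict. A sketch of that parsing (the `':'` separator is an assumption for illustration; the actual split call is elided in the diff):

```python
def split_thread_id(raw: str, sep: str = ":"):
    # Sketch of the thread-id parsing above: "<thread>" or "<thread><sep><device>".
    parts = raw.split(sep)
    assert len(parts) <= 2, "something wrong!"
    device_id = parts[1] if len(parts) == 2 else None
    return parts[0], device_id

thread_id, device_id = split_thread_id("t1:dev42")
config = {"configurable": {"thread_id": thread_id, "device_id": device_id}}
```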
@@ -171,8 +196,8 @@ class Pipeline:
            return self._stream_res(out, thread_id)
        else:
            return out

-   def get_remove_id(self, thread_id:str) -> bool:
+   def get_remove_id(self, thread_id: str) -> bool:
        """
        Returns an id to remove if a new conversation has started.
        """
@@ -184,7 +209,7 @@ class Pipeline:
        thread_id, device_id = parts
        c_th_id = self.thread_id_cache.get(device_id)

        if c_th_id is None:
            self.thread_id_cache[device_id] = thread_id
            return None
@@ -196,7 +221,6 @@ class Pipeline:
        else:
            assert 0, "BUG SHOULD NOT BE HERE"

    async def ainvoke(self, *nargs, **kwargs):
        """Async version of invoke using LangGraph's native async support."""
        out = await self.graph.ainvoke(*nargs, **kwargs)
@@ -211,19 +235,25 @@ class Pipeline:

        if isinstance(out, SystemMessage) or isinstance(out, HumanMessage):
            return out.content

        if isinstance(out, list):
            return out[-1].content

        if isinstance(out, str):
            return out

        assert 0, "something is wrong"

-   async def achat(self, inp:str, as_stream:bool=False, as_raw:bool=False, thread_id:str = '3'):
+   async def achat(
+       self,
+       inp: str,
+       as_stream: bool = False,
+       as_raw: bool = False,
+       thread_id: str = "3",
+   ):
        """
        Async version of chat using LangGraph's native async support.

        as_stream (bool): if true, return a streamable generator
        as_raw (bool): return the full dialogue as List[SystemMessage, HumanMessage, ToolMessage]
        """
@@ -239,11 +269,14 @@ class Pipeline:
        assert len(spl_ls) <= 2, "something wrong!"
        if len(spl_ls) == 2:
            _, device_id = spl_ls
-           print(f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m")
+           print(
+               f"\033[32m====================DEVICE ID: {device_id}=============================\033[0m"
+           )

-       inp_data = {"messages":[SystemMessage(u),
-                               HumanMessage(inp)]}, {"configurable": {"thread_id": thread_id,
-                                                                      "device_id":device_id}}
+       inp_data = (
+           {"messages": [SystemMessage(u), HumanMessage(inp)]},
+           {"configurable": {"thread_id": thread_id, "device_id": device_id}},
+       )

        out = await self.ainvoke(*inp_data, as_stream=as_stream, as_raw=as_raw)

@@ -267,10 +300,13 @@ class Pipeline:
if __name__ == "__main__":
    from lang_agent.graphs import ReactGraphConfig
    from dotenv import load_dotenv

    load_dotenv()
    # config = PipelineConfig(graph_config=ReactGraphConfig())
    config = PipelineConfig()
    pipeline: Pipeline = config.setup()
-   for out in pipeline.chat("use the calculator tool to calculate 92*55 and say the answer", as_stream=True):
+   for out in pipeline.chat(
+       "use the calculator tool to calculate 92*55 and say the answer", as_stream=True
+   ):
        # print(out)
        continue

@@ -6,21 +6,27 @@ import os
from dotenv import load_dotenv
load_dotenv()

-def make_llm(model="qwen-plus",
-            model_provider="openai",
-            api_key=None,
-            base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
-            **kwargs)->BaseChatModel:
+def make_llm(
+    model="qwen-plus",
+    model_provider="openai",
+    api_key=None,
+    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
+    **kwargs,
+) -> BaseChatModel:
    api_key = os.environ.get("ALI_API_KEY") if api_key is None else api_key

-   llm = init_chat_model(model=model,
-                         model_provider=model_provider,
-                         api_key=api_key,
-                         base_url=base_url,
-                         **kwargs)
+   llm = init_chat_model(
+       model=model,
+       model_provider=model_provider,
+       api_key=api_key,
+       base_url=base_url,
+       **kwargs,
+   )

    return llm


def tree_leaves(tree):
    """
    Extracts all leaf values from a nested structure (dict, list, tuple).
@@ -28,7 +34,7 @@ def tree_leaves(tree):
    """
    leaves = []
    stack = [tree]

    while stack:
        node = stack.pop()
        if isinstance(node, dict):
@@ -39,11 +45,10 @@ def tree_leaves(tree):
            stack.extend(reversed(node))
        else:
            leaves.append(node)

    return leaves


-NON_WORD_PATTERN = re.compile(r'[^\u4e00-\u9fffA-Za-z0-9_\s]')
def words_only(text):
    """
    Keep only:
@@ -53,10 +58,11 @@ def words_only(text):
    Strip punctuation, emojis, etc.
    Return the remaining tokens (Chinese blocks or Latin word blocks) joined into one string.
    """
+   NON_WORD_PATTERN = re.compile(r"[^\u4e00-\u9fffA-Za-z0-9_\s]")
    # 1. Replace all non-allowed characters with a space
-   cleaned = NON_WORD_PATTERN.sub(' ', text)
+   cleaned = NON_WORD_PATTERN.sub(" ", text)

    # 2. Normalize multiple spaces and split into tokens
    tokens = cleaned.split()

    return "".join(tokens)

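The regex in `words_only` keeps CJK characters, Latin alphanumerics, underscores, and whitespace, then squeezes the surviving tokens together. A quick self-contained check of that behavior (same pattern as the diff above):

```python
import re

NON_WORD_PATTERN = re.compile(r"[^\u4e00-\u9fffA-Za-z0-9_\s]")

def words_only(text: str) -> str:
    # Replace every character that is not CJK, Latin alphanumeric,
    # underscore, or whitespace with a space, then join the tokens.
    cleaned = NON_WORD_PATTERN.sub(" ", text)
    return "".join(cleaned.split())
```

Note that punctuation becomes a token boundary, so `"a-b"` collapses to `"ab"` rather than being dropped.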
85
nginx.conf
Normal file
@@ -0,0 +1,85 @@
events {
    worker_connections 1024;
}

http {
    include /etc/nginx/mime.types;
    default_type application/octet-stream;

    sendfile on;
    keepalive_timeout 65;

    # Upstream backend
    upstream backend {
        server backend:8500;
    }

    server {
        listen 80;
        server_name localhost;

        root /usr/share/nginx/html;
        index index.html;

        # Always revalidate the SPA entrypoint so clients pick up the latest
        # hashed JS bundle after redeploys.
        location = /index.html {
            add_header Cache-Control "no-store, no-cache, must-revalidate, proxy-revalidate" always;
            add_header Pragma "no-cache" always;
            add_header Expires "0" always;
            try_files $uri =404;
        }

        # Serve frontend static files
        location / {
            try_files $uri $uri/ /index.html;
        }

        # Proxy API requests to backend
        location /v1/ {
            proxy_pass http://backend;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection 'upgrade';
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            proxy_cache_bypass $http_upgrade;
        }

        # Proxy DashScope API requests
        location /apps/ {
            proxy_pass http://backend;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection 'upgrade';
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            proxy_cache_bypass $http_upgrade;
        }

        # Proxy v1/apps requests
        location /v1/apps/ {
            proxy_pass http://backend;
            proxy_http_version 1.1;
            proxy_set_header Upgrade $http_upgrade;
            proxy_set_header Connection 'upgrade';
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            proxy_cache_bypass $http_upgrade;
        }

        # Health check endpoint
        location /health {
            proxy_pass http://backend/health;
            access_log off;
        }
    }
}

@@ -24,7 +24,10 @@ dependencies = [
    "commentjson",
    "pandas",
    "asgiref",
-   "psycopg[binary]"
+   "psycopg[binary]",
+   "deepagents",
+   "daytona",
+   "langchain_daytona"
]

[tool.setuptools.packages.find]

@@ -1,12 +0,0 @@
-source ~/.bashrc
-conda init
-conda activate lang
-
-echo create blueberry config
-python scripts/py_scripts/misc_tasks.py --save-path config/pipelines/blueberry.yaml \
-    react \
-    --sys-prompt-f configs/prompts/blueberry.txt \
-    --tool-manager-config.client-tool-manager.tool-keys
-
-# echo create xiaozhan config
-python scripts/py_scripts/misc_tasks.py --save-path config/pipelines/xiaozhan.yaml
40
scripts/init_database/00_init_user.sh
Executable file
@@ -0,0 +1,40 @@
#!/bin/bash
# Initialize database user and database
# This script runs before SQL files in docker-entrypoint-initdb.d
# It must be named with 00_ prefix to run first

set -e

APP_DB_NAME="${APP_DB_NAME:-ai_conversations}"
APP_DB_USER="${APP_DB_USER:-myapp_user}"
APP_DB_PASSWORD="${APP_DB_PASSWORD:-secure_password_123}"

echo "Creating database user: $APP_DB_USER"
# Create user
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
    DO \$\$
    BEGIN
        IF NOT EXISTS (SELECT FROM pg_catalog.pg_user WHERE usename = '$APP_DB_USER') THEN
            CREATE USER $APP_DB_USER WITH PASSWORD '$APP_DB_PASSWORD';
        END IF;
    END
    \$\$;
    ALTER USER $APP_DB_USER CREATEDB;
EOSQL

echo "Creating database: $APP_DB_NAME"
# Create database
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL
    SELECT 'CREATE DATABASE $APP_DB_NAME'
    WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '$APP_DB_NAME')\gexec
    GRANT ALL PRIVILEGES ON DATABASE $APP_DB_NAME TO $APP_DB_USER;
EOSQL

echo "Granting schema privileges"
# Grant schema privileges
psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$APP_DB_NAME" <<-EOSQL
    GRANT ALL ON SCHEMA public TO $APP_DB_USER;
EOSQL

echo "Database initialization complete!"

25
scripts/init_database/01_run_sql_files.sh
Executable file
@@ -0,0 +1,25 @@
#!/bin/bash
# Run SQL initialization files in the correct database context
# This script runs after 00_init_user.sh creates the database

set -e

APP_DB_NAME="${APP_DB_NAME:-ai_conversations}"

echo "Running SQL initialization files in database: $APP_DB_NAME"

# Run create_conv_store.sql
if [ -f /docker-entrypoint-initdb.d/create_conv_store.sql ]; then
    echo "Executing create_conv_store.sql..."
    psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$APP_DB_NAME" -f /docker-entrypoint-initdb.d/create_conv_store.sql
fi

# Run create_prompt_config.sql
if [ -f /docker-entrypoint-initdb.d/create_prompt_config.sql ]; then
    echo "Executing create_prompt_config.sql..."
    psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$APP_DB_NAME" -f /docker-entrypoint-initdb.d/create_prompt_config.sql
fi

echo "SQL initialization files completed!"

@@ -1,7 +1,9 @@
-- Create the messages table
+-- This script runs in the ai_conversations database context
CREATE TABLE IF NOT EXISTS messages (
    id BIGSERIAL PRIMARY KEY,
    conversation_id TEXT NOT NULL,
+   pipeline_id TEXT,
    message_type VARCHAR(10) NOT NULL CHECK (message_type IN ('human', 'ai', 'tool')),
    content TEXT NOT NULL,
    sequence_number INTEGER NOT NULL CHECK (sequence_number >= 0),
@@ -9,4 +11,11 @@ CREATE TABLE IF NOT EXISTS messages (
);

-- Index for fast retrieval of all messages in a conversation (in order)
CREATE INDEX IF NOT EXISTS idx_messages_conversation ON messages (conversation_id, sequence_number);

+-- Index for fast lookup by pipeline_id
+CREATE INDEX IF NOT EXISTS idx_messages_pipeline ON messages (pipeline_id);
+
+-- Grant permissions to app user
+GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO myapp_user;
+GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO myapp_user;
@@ -1,4 +1,5 @@
-- A prompt_set groups a full collection of prompts together.
+-- This script runs in the ai_conversations database context
-- Each pipeline can have many sets (versions, A/B variants, etc.);
-- exactly one should be marked is_active per pipeline.
CREATE TABLE IF NOT EXISTS prompt_sets (
@@ -41,6 +42,10 @@ CREATE TABLE IF NOT EXISTS prompt_templates (
CREATE INDEX IF NOT EXISTS idx_prompt_templates_set_id
    ON prompt_templates(prompt_set_id);

+-- Grant permissions to app user
+GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA public TO myapp_user;
+GRANT ALL PRIVILEGES ON ALL SEQUENCES IN SCHEMA public TO myapp_user;
+
-- Seed: initial prompt set for lang_agent/graphs/routing.py
-- The pipeline_id can be used by RoutingConfig.pipeline_id to load these prompts.
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
@@ -75,6 +80,39 @@ DO UPDATE SET
    content = EXCLUDED.content,
    updated_at = now();

+-- Seed: default prompt set for xiaozhan agent (RoutingGraph)
+INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
+SELECT
+    'xiaozhan',
+    'routing',
+    'default',
+    'Default prompt set for xiaozhan (RoutingGraph)',
+    true,
+    ''
+WHERE NOT EXISTS (
+    SELECT 1
+    FROM prompt_sets
+    WHERE pipeline_id = 'xiaozhan'
+      AND name = 'default'
+);
+
+INSERT INTO prompt_templates (prompt_set_id, prompt_key, content)
+SELECT ps.id, t.prompt_key, t.content
+FROM prompt_sets ps
+JOIN (
+    VALUES
+        ('route_prompt', '决定用工具或者聊天'),
+        ('chat_prompt', '正常聊天时说什么'),
+        ('tool_prompt', '如何用工具')
+) AS t(prompt_key, content)
+ON true
+WHERE ps.pipeline_id = 'xiaozhan'
+  AND ps.name = 'default'
+ON CONFLICT (prompt_set_id, prompt_key)
+DO UPDATE SET
+    content = EXCLUDED.content,
+    updated_at = now();
+
-- Seed: initial prompt set for lang_agent/graphs/react.py
-- ReactGraph uses prompt key "sys_prompt" (see default_key in build_prompt_store).
INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
@@ -101,3 +139,56 @@ ON CONFLICT (prompt_set_id, prompt_key)
DO UPDATE SET
    content = EXCLUDED.content,
    updated_at = now();

+-- Seed: default prompt set for blueberry agent (ReactGraph)
+INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
+SELECT
+    'blueberry',
+    'react',
+    'default',
+    'Default prompt set for blueberry (ReactGraph)',
+    true,
+    ''
+WHERE NOT EXISTS (
+    SELECT 1
+    FROM prompt_sets
+    WHERE pipeline_id = 'blueberry'
+      AND name = 'default'
+);
+
+INSERT INTO prompt_templates (prompt_set_id, prompt_key, content)
+SELECT ps.id, 'sys_prompt', '如何用工具'
+FROM prompt_sets ps
+WHERE ps.pipeline_id = 'blueberry'
+  AND ps.name = 'default'
+ON CONFLICT (prompt_set_id, prompt_key)
+DO UPDATE SET
+    content = EXCLUDED.content,
+    updated_at = now();
+
+-- Seed: initial prompt set for lang_agent/graphs/deepagents_qt.py
+-- DeepAgent uses prompt key "sys_prompt" with DB-first, file-fallback loading.
+INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
+SELECT
+    'deepagent',
+    'deepagent',
+    'default',
+    'Initial prompt set for DeepAgent',
+    true,
+    ''
+WHERE NOT EXISTS (
+    SELECT 1
+    FROM prompt_sets
+    WHERE pipeline_id = 'deepagent'
+      AND name = 'default'
+);
+
+INSERT INTO prompt_templates (prompt_set_id, prompt_key, content)
+SELECT ps.id, 'sys_prompt', '你是一个擅长调用工具和处理文件任务的深度代理。'
+FROM prompt_sets ps
+WHERE ps.pipeline_id = 'deepagent'
+  AND ps.name = 'default'
+ON CONFLICT (prompt_set_id, prompt_key)
+DO UPDATE SET
+    content = EXCLUDED.content,
+    updated_at = now();

49
scripts/init_database/init_all.sh
Executable file
@@ -0,0 +1,49 @@
#!/bin/bash
# Database initialization script
# This script runs all SQL initialization files in the correct order

set -e

DB_NAME="${POSTGRES_DB:-ai_conversations}"
DB_USER="${POSTGRES_USER:-myapp_user}"
DB_PASSWORD="${POSTGRES_PASSWORD:-secure_password_123}"
DB_HOST="${POSTGRES_HOST:-localhost}"
DB_PORT="${POSTGRES_PORT:-5432}"

export PGPASSWORD="$DB_PASSWORD"

echo "Initializing database: $DB_NAME on $DB_HOST:$DB_PORT"

# Wait for PostgreSQL to be ready
until psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d postgres -c '\q' 2>/dev/null; do
    echo "Waiting for PostgreSQL to be ready..."
    sleep 2
done

echo "PostgreSQL is ready!"

# Create database if it doesn't exist
psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d postgres <<EOF
SELECT 'CREATE DATABASE $DB_NAME'
WHERE NOT EXISTS (SELECT FROM pg_database WHERE datname = '$DB_NAME')\gexec
EOF

# Grant privileges
psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d postgres <<EOF
GRANT ALL PRIVILEGES ON DATABASE $DB_NAME TO $DB_USER;
EOF

# Run initialization scripts in order
echo "Running database initialization scripts..."

# 1. Create conversation store tables
echo "Creating conversation store tables..."
psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -f /docker-entrypoint-initdb.d/create_conv_store.sql

# 2. Create prompt configuration tables
echo "Creating prompt configuration tables..."
psql -h "$DB_HOST" -p "$DB_PORT" -U "$DB_USER" -d "$DB_NAME" -f /docker-entrypoint-initdb.d/create_prompt_config.sql

echo "Database initialization complete!"

@@ -1,5 +1,6 @@
-- Drop the index first (if it exists)
DROP INDEX IF EXISTS idx_messages_conversation;
+DROP INDEX IF EXISTS idx_messages_pipeline;

-- Drop the messages table (if it exists)
DROP TABLE IF EXISTS messages;
@@ -9,6 +10,7 @@ DROP TABLE IF EXISTS messages;
CREATE TABLE messages (
    id BIGSERIAL PRIMARY KEY,
    conversation_id TEXT NOT NULL,
+   pipeline_id TEXT,
    message_type VARCHAR(10) NOT NULL CHECK (message_type IN ('human', 'ai', 'tool')),
    content TEXT NOT NULL,
    sequence_number INTEGER NOT NULL CHECK (sequence_number >= 0),
@@ -18,3 +20,6 @@ CREATE TABLE messages (
-- Recreate the index for fast retrieval of all messages in a conversation (in order)
CREATE INDEX idx_messages_conversation ON messages (conversation_id, sequence_number);

+-- Index for fast lookup by pipeline_id
+CREATE INDEX idx_messages_pipeline ON messages (pipeline_id);

141
scripts/py_scripts/chat_dashcope.py
Normal file
@@ -0,0 +1,141 @@
#!/usr/bin/env python3
"""
Simple chat loop to interact with the blueberry pipeline via DashScope-compatible API.

Usage:
    python scripts/py_scripts/chat_dashcope.py

The script connects to the server running on http://localhost:8500
and uses the API key from the pipeline registry.
"""

import requests
import json
import sys
from typing import Optional


# Configuration from pipeline_registry.json
API_KEY = "sk-6c7091e6a95f404efb2ec30e8f51b897626d670375cdf822d78262f24ab12367"
PIPELINE_ID = "blueberry"
BASE_URL = "http://localhost:8500"
SESSION_ID = "chat-session-1"


def send_message(
    message: str,
    session_id: str = SESSION_ID,
    stream: bool = False,
    app_id: str = PIPELINE_ID,
) -> Optional[str]:
    """Send a message to the blueberry pipeline and return the response."""
    url = f"{BASE_URL}/v1/apps/{app_id}/sessions/{session_id}/responses"
    headers = {
        "Authorization": f"Bearer {API_KEY}",
        "Content-Type": "application/json",
    }

    payload = {
        "messages": [
            {"role": "user", "content": message}
        ],
        "stream": stream,
    }

    try:
        if stream:
            # Handle streaming response
            response = requests.post(url, headers=headers, json=payload, stream=True)
            response.raise_for_status()

            accumulated_text = ""
            for line in response.iter_lines():
                if line:
                    line_str = line.decode('utf-8')
                    if line_str.startswith('data: '):
                        data_str = line_str[6:]  # Remove 'data: ' prefix
                        try:
                            data = json.loads(data_str)
                            output = data.get("output", {})
                            text = output.get("text", "")
                            if text:
                                accumulated_text = text
                                # Print incremental updates (you can modify this behavior)
                                print(f"\rAssistant: {accumulated_text}", end="", flush=True)

                            if data.get("is_end", False):
                                print()  # New line after streaming completes
                                return accumulated_text
                        except json.JSONDecodeError:
                            continue
            return accumulated_text
        else:
            # Handle non-streaming response
            response = requests.post(url, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
            output = data.get("output", {})
            return output.get("text", "")

    except requests.exceptions.RequestException as e:
        print(f"Error sending message: {e}", file=sys.stderr)
        if hasattr(e, 'response') and e.response is not None:
            try:
                error_detail = e.response.json()
                print(f"Error details: {error_detail}", file=sys.stderr)
            except:
                print(f"Response status: {e.response.status_code}", file=sys.stderr)
        return None


def main():
    """Main chat loop."""
    print("=" * 60)
    print(f"Chat with Blueberry Pipeline")
    print(f"Pipeline ID: {PIPELINE_ID}")
    print(f"Server: {BASE_URL}")
    print(f"Session ID: {SESSION_ID}")
    print("=" * 60)
    print("Type your messages (or 'quit'/'exit' to end, 'stream' to toggle streaming)")
    print("Streaming mode is ON by default")
    print()

    stream_mode = True

    while True:
        try:
            user_input = input("You: ").strip()

            if not user_input:
                continue

            if user_input.lower() in ['quit', 'exit', 'q']:
                print("Goodbye!")
                break

            if user_input.lower() == 'stream':
                stream_mode = not stream_mode
                print(f"Streaming mode: {'ON' if stream_mode else 'OFF'}")
                continue

            print("Assistant: ", end="", flush=True)
            response = send_message(user_input, stream=stream_mode)

            if response is None:
                print("(No response received)")
            elif not stream_mode:
                print(response)
            # For streaming, the response is already printed incrementally

            print()  # Empty line for readability

        except KeyboardInterrupt:
            print("\n\nGoodbye!")
            break
        except Exception as e:
            print(f"\nError: {e}", file=sys.stderr)


if __name__ == "__main__":
    main()

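The `data: ` handling in `send_message` reduces to a small server-sent-events parser: keep the latest `output.text`, stop when `is_end` is set, skip anything that is not valid JSON. A self-contained sketch of that loop (field names `output`, `text`, and `is_end` are taken from the script above):

```python
import json

def parse_sse_lines(lines):
    # Mirrors the streaming branch of send_message: track the latest
    # cumulative text and stop at the terminal event.
    accumulated = ""
    for line in lines:
        if not line.startswith("data: "):
            continue  # ignore comments / keep-alives
        try:
            data = json.loads(line[6:])
        except json.JSONDecodeError:
            continue
        text = data.get("output", {}).get("text", "")
        if text:
            accumulated = text
        if data.get("is_end", False):
            break
    return accumulated

events = [
    'data: {"output": {"text": "92*55"}, "is_end": false}',
    'data: {"output": {"text": "92*55 = 5060"}, "is_end": true}',
]
```

Note the server sends cumulative text (each event replaces `accumulated` rather than appending), which is why the script can print with `\r` to redraw the line.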
@@ -18,129 +18,129 @@ examples = [
|
||||
"inputs": {"text": "请你介绍一下少年右这杯茶"},
|
||||
"outputs": {
|
||||
"answer": "这是一杯使用武夷肉桂为原料的茶,带有浓郁的肉桂香气和微微的辛辣感,茶汤醇厚,回味悠长,非常适合喜欢浓烈香气的茶友。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "给我讲讲野心心这杯茶"},
|
||||
"outputs": {
|
||||
"answer": "野星星选用云南西双版纳野生大树春茶,历经二十多年陈化,茶汤醇厚饱满,回甘迅猛,带着明显的岁月沉香与山野气息。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍一下小甜新"},
|
||||
"outputs": {
|
||||
"answer": "小甜心来自芒景村古树生普,兰香与蜜韵交织,入口柔和,回甘悠长,是一款耐喝又有层次的老料生普。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "小盏,什么是大白百?"},
|
||||
"outputs": {
|
||||
"answer": "大白白是一款2012年的老白茶,经过多年陈化,蜜香温润,茶汤醇厚顺滑,回甘绵长,整体风格安静而沉稳。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "帮我介绍下引你进山林"},
|
||||
"outputs": {
|
||||
"answer": "引你入山林以新会陈皮搭配云南白茶,茶汤清甜柔和,带有淡淡的花果香与陈皮的温润气息,喝起来非常舒服。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "伏身于大自然是什么味道"},
|
||||
"outputs": {
|
||||
"answer": "伏身于自然将云南滇红与玫瑰慢煮融合,花香馥郁,入口醇厚甘甜,蜜香在口中停留很久,温暖又放松。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍一下小野仔"},
|
||||
"outputs": {
|
||||
"answer": "小野子选用云南古树晒红制作,蜜香高扬,口感甜润顺滑,回甘明显,是一款非常友好的红茶。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "放轻松这杯喝起来怎么样"},
|
||||
"outputs": {
|
||||
"answer": "放轻松是小青柑搭配熟普,茶汤醇厚顺滑,柑香清新提亮整体口感,非常适合饭后或想放松的时候。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "啤啤查是酒吗"},
|
||||
"outputs": {
|
||||
"answer": "啤啤茶是一款无酒精气泡茶,以普洱和玫瑰为茶底,气泡清爽,入口有类似啤酒的畅快感,但完全不含酒精。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "鲜叶康普查有什么特点"},
|
||||
"outputs": {
|
||||
"answer": "鲜叶康普茶经过自然发酵,带有轻盈气泡和清爽酸甜感,同时富含益生菌,整体低糖低卡,口感非常清新。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍一下寻静密"},
|
||||
"outputs": {
|
||||
"answer": "寻静谧融合茉莉绿茶与抹茶,茶感温润微涩,搭配栀子花香奶盖与海苔碎,层次细腻,整体风格安静沉稳。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "小陶燃是什么茶"},
|
||||
"outputs": {
|
||||
"answer": "小陶然是一款熟普黑茶,选用布朗山原料发酵,陈香明显,滋味甜醇饱满,口感厚实顺滑。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "花仙仔适合什么人喝"},
|
||||
"outputs": {
|
||||
"answer": "花仙子是东方美人乌龙茶,带有天然熟果蜜香,茶感柔和细腻,很适合喜欢花果香型乌龙的茶友。",
|
||||
"tool_use": ["search_dishes"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍下小美慢"},
|
||||
"outputs": {
|
||||
"answer": "小美满选用福鼎老寿眉白茶,带有枣香和淡淡药香,口感甘润持久,是一款很有岁月感的白茶。",
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "你叫什么名字"},
|
||||
"outputs": {
|
||||
"answer": "我叫小盏,是半盏新青年茶馆的智能助手",
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "我今天很开心"},
|
||||
"outputs": {
|
||||
"answer": "太棒啦!看到你开心",
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "你好可爱呀!"},
|
||||
"outputs": {
|
||||
"answer": "谢谢你呀~",
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "你有没有时间"},
|
||||
"outputs": {
|
||||
"answer": "我一直在呢,随时陪你聊聊天、喝杯茶",
|
||||
}
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍一下你自己"},
|
||||
"outputs": {
|
||||
"answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶",
|
||||
}
|
||||
},
|
||||
"inputs": {"text": "给我讲讲野心心这杯茶"},
|
||||
"outputs": {
|
||||
"answer": "野星星选用云南西双版纳野生大树春茶,历经二十多年陈化,茶汤醇厚饱满,回甘迅猛,带着明显的岁月沉香与山野气息。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍一下小甜新"},
|
||||
"outputs": {
|
||||
"answer": "小甜心来自芒景村古树生普,兰香与蜜韵交织,入口柔和,回甘悠长,是一款耐喝又有层次的老料生普。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "小盏,什么是大白百?"},
|
||||
"outputs": {
|
||||
"answer": "大白白是一款2012年的老白茶,经过多年陈化,蜜香温润,茶汤醇厚顺滑,回甘绵长,整体风格安静而沉稳。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "帮我介绍下引你进山林"},
|
||||
"outputs": {
|
||||
"answer": "引你入山林以新会陈皮搭配云南白茶,茶汤清甜柔和,带有淡淡的花果香与陈皮的温润气息,喝起来非常舒服。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "伏身于大自然是什么味道"},
|
||||
"outputs": {
|
||||
"answer": "伏身于自然将云南滇红与玫瑰慢煮融合,花香馥郁,入口醇厚甘甜,蜜香在口中停留很久,温暖又放松。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍一下小野仔"},
|
||||
"outputs": {
|
||||
"answer": "小野子选用云南古树晒红制作,蜜香高扬,口感甜润顺滑,回甘明显,是一款非常友好的红茶。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "放轻松这杯喝起来怎么样"},
|
||||
"outputs": {
|
||||
"answer": "放轻松是小青柑搭配熟普,茶汤醇厚顺滑,柑香清新提亮整体口感,非常适合饭后或想放松的时候。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "啤啤查是酒吗"},
|
||||
"outputs": {
|
||||
"answer": "啤啤茶是一款无酒精气泡茶,以普洱和玫瑰为茶底,气泡清爽,入口有类似啤酒的畅快感,但完全不含酒精。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "鲜叶康普查有什么特点"},
|
||||
"outputs": {
|
||||
"answer": "鲜叶康普茶经过自然发酵,带有轻盈气泡和清爽酸甜感,同时富含益生菌,整体低糖低卡,口感非常清新。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍一下寻静密"},
|
||||
"outputs": {
|
||||
"answer": "寻静谧融合茉莉绿茶与抹茶,茶感温润微涩,搭配栀子花香奶盖与海苔碎,层次细腻,整体风格安静沉稳。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "小陶燃是什么茶"},
|
||||
"outputs": {
|
||||
"answer": "小陶然是一款熟普黑茶,选用布朗山原料发酵,陈香明显,滋味甜醇饱满,口感厚实顺滑。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "花仙仔适合什么人喝"},
|
||||
"outputs": {
|
||||
"answer": "花仙子是东方美人乌龙茶,带有天然熟果蜜香,茶感柔和细腻,很适合喜欢花果香型乌龙的茶友。",
|
||||
"tool_use": ["search_dishes"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍下小美慢"},
|
||||
"outputs": {
|
||||
"answer": "小美满选用福鼎老寿眉白茶,带有枣香和淡淡药香,口感甘润持久,是一款很有岁月感的白茶。",
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "你叫什么名字"},
|
||||
"outputs": {
|
||||
"answer": "我叫小盏,是半盏新青年茶馆的智能助手",
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "我今天很开心"},
|
||||
"outputs": {
|
||||
"answer": "太棒啦!看到你开心",
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "你好可爱呀!"},
|
||||
"outputs": {
|
||||
"answer": "谢谢你呀~",
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "你有没有时间"},
|
||||
"outputs": {
|
||||
"answer": "我一直在呢,随时陪你聊聊天、喝杯茶",
|
||||
},
|
||||
},
|
||||
{
|
||||
"inputs": {"text": "介绍一下你自己"},
|
||||
"outputs": {
|
||||
"answer": "我叫小盏,是一只中式茶盖碗,名字来源半盏新青年茶馆,一盏茶",
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
cli = Client()

@@ -151,7 +151,4 @@ try:
 except:
     dataset = cli.create_dataset(dataset_name=DATASET_NAME)
     logger.info("created dataset")
-    cli.create_examples(
-        dataset_id=dataset.id,
-        examples=examples
-    )
+    cli.create_examples(dataset_id=dataset.id, examples=examples)
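The hunk above collapses the example upload into a single get-or-create call. A minimal sketch of that pattern, using a toy in-memory client as a stand-in for the LangSmith `Client` (the toy class and its `read_dataset`/`create_dataset` behavior are assumptions for illustration, not the real SDK API):

```python
# Toy in-memory client standing in for the LangSmith Client; the real
# client talks to a server, but the get-or-create flow is the same.
class ToyClient:
    def __init__(self):
        self.datasets = {}

    def read_dataset(self, dataset_name):
        # Mimics "read fails if the dataset does not exist yet".
        if dataset_name not in self.datasets:
            raise KeyError(dataset_name)
        return self.datasets[dataset_name]

    def create_dataset(self, dataset_name):
        self.datasets[dataset_name] = {"name": dataset_name, "examples": []}
        return self.datasets[dataset_name]

    def create_examples(self, dataset, examples):
        dataset["examples"].extend(examples)


cli = ToyClient()
try:
    ds = cli.read_dataset(dataset_name="tea-qa")
except KeyError:
    # Dataset missing: create it once, then upload examples exactly once.
    ds = cli.create_dataset(dataset_name="tea-qa")
cli.create_examples(ds, [{"inputs": {"text": "hi"}, "outputs": {"answer": "hello"}}])
print(len(ds["examples"]))  # 1
```

The point of the diff is idempotence: the upload happens once after the try/except resolves, instead of once inside the handler and once after it.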
364 scripts/py_scripts/migrate_yaml_prompts_to_db.py Normal file
@@ -0,0 +1,364 @@
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import datetime as dt
import glob
import os
import os.path as osp
import sys
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional

import commentjson
import psycopg


PROJECT_ROOT = osp.dirname(osp.dirname(osp.dirname(osp.abspath(__file__))))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from lang_agent.config import load_tyro_conf  # noqa: E402
from lang_agent.config.db_config_manager import DBConfigManager  # noqa: E402


@dataclass
class MigrationPayload:
    config_path: str
    pipeline_id: str
    graph_id: str
    prompt_dict: Dict[str, str]
    tool_keys: List[str]
    api_key: Optional[str]


def _infer_pipeline_id(pipeline_conf, config_path: str) -> str:
    candidates = [
        getattr(pipeline_conf, "pipeline_id", None),
        getattr(getattr(pipeline_conf, "graph_config", None), "pipeline_id", None),
    ]
    for candidate in candidates:
        if candidate is None:
            continue
        value = str(candidate).strip()
        if value and value.lower() != "null":
            return value
    return osp.splitext(osp.basename(config_path))[0]


def _infer_graph_id(graph_conf) -> str:
    if graph_conf is None:
        return "unknown"
    class_name = graph_conf.__class__.__name__.lower()
    if "routing" in class_name or class_name == "routeconfig":
        return "routing"
    if "react" in class_name:
        return "react"

    target = getattr(graph_conf, "_target", None)
    if target is not None:
        target_name = getattr(target, "__name__", str(target)).lower()
        if "routing" in target_name:
            return "routing"
        if "react" in target_name:
            return "react"
    return "unknown"


def _extract_tool_keys(graph_conf) -> List[str]:
    if graph_conf is None:
        return []
    tool_cfg = getattr(graph_conf, "tool_manager_config", None)
    client_cfg = getattr(tool_cfg, "client_tool_manager", None)
    keys = getattr(client_cfg, "tool_keys", None)
    if not keys:
        return []
    out: List[str] = []
    seen = set()
    for key in keys:
        cleaned = str(key).strip()
        if not cleaned or cleaned in seen:
            continue
        seen.add(cleaned)
        out.append(cleaned)
    return out


def _load_prompt_dict(prompt_path: str, default_key: str = "sys_prompt") -> Dict[str, str]:
    if not prompt_path:
        return {}
    if not osp.exists(prompt_path):
        return {}

    if osp.isdir(prompt_path):
        prompt_files = sorted(
            p for p in glob.glob(osp.join(prompt_path, "*.txt")) if "optional" not in p
        )
        out = {}
        for prompt_f in prompt_files:
            key = osp.splitext(osp.basename(prompt_f))[0]
            with open(prompt_f, "r", encoding="utf-8") as f:
                out[key] = f.read()
        return out

    if prompt_path.endswith(".json"):
        with open(prompt_path, "r", encoding="utf-8") as f:
            obj = commentjson.load(f)
        if not isinstance(obj, dict):
            return {}
        return {str(k): v if isinstance(v, str) else str(v) for k, v in obj.items()}

    if prompt_path.endswith(".txt"):
        with open(prompt_path, "r", encoding="utf-8") as f:
            return {default_key: f.read()}

    return {}


def _extract_prompt_dict(graph_conf) -> Dict[str, str]:
    if graph_conf is None:
        return {}
    if hasattr(graph_conf, "sys_prompt_f"):
        return _load_prompt_dict(str(getattr(graph_conf, "sys_prompt_f")), "sys_prompt")
    if hasattr(graph_conf, "sys_promp_dir"):
        return _load_prompt_dict(str(getattr(graph_conf, "sys_promp_dir")))
    return {}


def _extract_tool_node_prompt_dict(graph_conf) -> Dict[str, str]:
    tool_node_conf = getattr(graph_conf, "tool_node_config", None)
    if tool_node_conf is None:
        return {}

    out: Dict[str, str] = {}
    if hasattr(tool_node_conf, "tool_prompt_f"):
        out.update(
            _load_prompt_dict(str(getattr(tool_node_conf, "tool_prompt_f")), "tool_prompt")
        )
    if hasattr(tool_node_conf, "chatty_sys_prompt_f"):
        out.update(
            _load_prompt_dict(
                str(getattr(tool_node_conf, "chatty_sys_prompt_f")), "chatty_prompt"
            )
        )
    return out


def _prompt_key_whitelist(graph_conf, graph_id: str) -> Optional[set]:
    if graph_id == "react":
        return {"sys_prompt"}
    if graph_id != "routing":
        return None

    allowed = {"route_prompt", "chat_prompt", "tool_prompt"}
    tool_node_conf = getattr(graph_conf, "tool_node_config", None)
    if tool_node_conf is None:
        return allowed

    cls_name = tool_node_conf.__class__.__name__.lower()
    target = getattr(tool_node_conf, "_target", None)
    target_name = getattr(target, "__name__", str(target)).lower() if target else ""
    if "chatty" in cls_name or "chatty" in target_name:
        allowed.add("chatty_prompt")
    return allowed


def _collect_payload(config_path: str) -> MigrationPayload:
    conf = load_tyro_conf(config_path)
    graph_conf = getattr(conf, "graph_config", None)
    graph_id = _infer_graph_id(graph_conf)
    prompt_dict = _extract_prompt_dict(graph_conf)
    prompt_dict.update(_extract_tool_node_prompt_dict(graph_conf))
    whitelist = _prompt_key_whitelist(graph_conf, graph_id)
    if whitelist is not None:
        prompt_dict = {k: v for k, v in prompt_dict.items() if k in whitelist}
    return MigrationPayload(
        config_path=config_path,
        pipeline_id=_infer_pipeline_id(conf, config_path),
        graph_id=graph_id,
        prompt_dict=prompt_dict,
        tool_keys=_extract_tool_keys(graph_conf),
        api_key=getattr(conf, "api_key", None),
    )


def _resolve_config_paths(config_dir: str, config_paths: Optional[Iterable[str]]) -> List[str]:
    if config_paths:
        resolved = [osp.abspath(path) for path in config_paths]
    else:
        pattern = osp.join(osp.abspath(config_dir), "*.yaml")
        resolved = sorted(glob.glob(pattern))
    return [path for path in resolved if osp.exists(path)]


def _ensure_prompt_set(
    conn: psycopg.Connection,
    pipeline_id: str,
    graph_id: str,
    set_name: str,
    description: str,
) -> str:
    with conn.cursor() as cur:
        cur.execute(
            """
            SELECT id FROM prompt_sets
            WHERE pipeline_id = %s AND name = %s
            ORDER BY updated_at DESC, created_at DESC
            LIMIT 1
            """,
            (pipeline_id, set_name),
        )
        row = cur.fetchone()
        if row is not None:
            return str(row[0])

        cur.execute(
            """
            INSERT INTO prompt_sets (pipeline_id, graph_id, name, description, is_active, list)
            VALUES (%s, %s, %s, %s, false, '')
            RETURNING id
            """,
            (pipeline_id, graph_id, set_name, description),
        )
        created = cur.fetchone()
        return str(created[0])


def _activate_prompt_set(conn: psycopg.Connection, pipeline_id: str, prompt_set_id: str) -> None:
    with conn.cursor() as cur:
        cur.execute(
            "UPDATE prompt_sets SET is_active = false, updated_at = now() WHERE pipeline_id = %s",
            (pipeline_id,),
        )
        cur.execute(
            "UPDATE prompt_sets SET is_active = true, updated_at = now() WHERE id = %s",
            (prompt_set_id,),
        )


def _run_migration(
    payloads: List[MigrationPayload],
    set_name: str,
    description: str,
    dry_run: bool,
    activate: bool,
) -> None:
    for payload in payloads:
        print(
            f"[PLAN] pipeline={payload.pipeline_id} graph={payload.graph_id} "
            f"prompts={len(payload.prompt_dict)} tools={len(payload.tool_keys)} "
            f"config={payload.config_path}"
        )
        if dry_run:
            continue

        manager = DBConfigManager()
        with psycopg.connect(manager.conn_str) as conn:
            prompt_set_id = _ensure_prompt_set(
                conn=conn,
                pipeline_id=payload.pipeline_id,
                graph_id=payload.graph_id,
                set_name=set_name,
                description=description,
            )
            conn.commit()

            manager.set_config(
                pipeline_id=payload.pipeline_id,
                graph_id=payload.graph_id,
                prompt_set_id=prompt_set_id,
                tool_list=payload.tool_keys,
                prompt_dict=payload.prompt_dict,
                api_key=payload.api_key,
            )

            if activate:
                _activate_prompt_set(
                    conn=conn,
                    pipeline_id=payload.pipeline_id,
                    prompt_set_id=prompt_set_id,
                )
                conn.commit()

            print(
                f"[DONE] pipeline={payload.pipeline_id} "
                f"prompt_set={prompt_set_id} activate={activate}"
            )


def main() -> None:
    date_str = dt.date.today().isoformat()
    parser = argparse.ArgumentParser(
        description="Import prompt definitions from pipeline YAML files into DB prompt_sets."
    )
    parser.add_argument(
        "--config-dir",
        default=osp.join(PROJECT_ROOT, "configs", "pipelines"),
        help="Directory containing pipeline YAML files.",
    )
    parser.add_argument(
        "--config",
        action="append",
        default=[],
        help="Specific pipeline config yaml path. Can be passed multiple times.",
    )
    parser.add_argument(
        "--pipeline-id",
        action="append",
        default=[],
        help="If provided, only migrate these pipeline IDs (repeatable).",
    )
    parser.add_argument(
        "--set-name",
        # default=f"migrated-{date_str}",
        default="default",
        help="Prompt set name to create/reuse under each pipeline.",
    )
    parser.add_argument(
        "--description",
        default="Migrated from pipeline YAML prompt files",
        help="Prompt set description.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Print what would be migrated without writing to DB.",
    )
    parser.add_argument(
        "--activate",
        action="store_true",
        help="Mark imported set active for each migrated pipeline.",
    )
    args = parser.parse_args()

    config_paths = _resolve_config_paths(args.config_dir, args.config)
    if not config_paths:
        raise SystemExit("No config files found. Provide --config or --config-dir.")

    requested_pipelines = {p.strip() for p in args.pipeline_id if p.strip()}

    payloads: List[MigrationPayload] = []
    for config_path in config_paths:
        payload = _collect_payload(config_path)
        if requested_pipelines and payload.pipeline_id not in requested_pipelines:
            continue
        if not payload.prompt_dict:
            print(f"[SKIP] no prompts found for config={config_path}")
            continue
        payloads.append(payload)

    if not payloads:
        raise SystemExit("No pipelines matched with prompt content to migrate.")

    _run_migration(
        payloads=payloads,
        set_name=args.set_name,
        description=args.description,
        dry_run=args.dry_run,
        activate=args.activate,
    )


if __name__ == "__main__":
    main()
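The script's `_infer_graph_id` keys off lowercase class (or `_target`) names to classify a graph config. A standalone sketch of the same heuristic, with toy config classes standing in for the project's real types (the class names here are assumptions for illustration):

```python
# Stand-in config classes; the real project types live in lang_agent.config.
class RoutingGraphConfig:
    pass

class ReactGraphConfig:
    pass


def infer_graph_id(graph_conf) -> str:
    # Same idea as _infer_graph_id in the migration script above:
    # match substrings of the lowercased class name.
    if graph_conf is None:
        return "unknown"
    name = graph_conf.__class__.__name__.lower()
    if "routing" in name:
        return "routing"
    if "react" in name:
        return "react"
    return "unknown"


print(infer_graph_id(RoutingGraphConfig()))  # routing
print(infer_graph_id(ReactGraphConfig()))    # react
print(infer_graph_id(None))                  # unknown
```

Substring matching keeps the migration tolerant of renamed config classes (`RoutingGraphConfig`, `RouteConfig`, etc.) at the cost of misclassifying any unrelated class that happens to contain "routing" or "react".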
17 scripts/shell_scripts/create_config.sh Normal file
@@ -0,0 +1,17 @@
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")"

cd "$PROJECT_ROOT"

source ~/.bashrc
conda init
conda activate lang

echo create blueberry config
python scripts/py_scripts/misc_tasks.py --save-path configs/pipelines/blueberry.yaml \
    react \
    --sys-prompt-f configs/prompts/blueberry.txt \
    --tool-manager-config.client-tool-manager.tool-keys

# echo create xiaozhan config
python scripts/py_scripts/misc_tasks.py --save-path configs/pipelines/xiaozhan.yaml
37 scripts/shell_scripts/download_images.sh Executable file
@@ -0,0 +1,37 @@
#!/bin/bash
# Script to download and package Docker images for offline use
# Run this on a machine with good Docker Hub access, then transfer images.tar to China

set -e

echo "=== Docker Image Downloader for Offline Use ==="
echo ""

# Images needed
IMAGES=(
    "node:20-alpine"
    "python:3.12-slim"
    "postgres:16-alpine"
    "nginx:alpine"
)

OUTPUT_FILE="images.tar"

echo "Pulling Docker images..."
for img in "${IMAGES[@]}"; do
    echo "  Pulling $img..."
    docker pull "$img"
done

echo ""
echo "Saving to $OUTPUT_FILE..."
docker save "${IMAGES[@]}" -o "$OUTPUT_FILE"

echo ""
echo "Done! File size:"
ls -lh "$OUTPUT_FILE"

echo ""
echo "To transfer to China machine and load:"
echo "  scp images.tar user@china-machine:/path/"
echo "  docker load < images.tar"
167 scripts/shell_scripts/install.sh Executable file
@@ -0,0 +1,167 @@
#!/bin/bash
# Installation script for LangChain Agent
# This script sets up and runs the entire application stack

set -e

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m' # No Color

# Configuration
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$(dirname "$SCRIPT_DIR")")"
ENV_FILE="$PROJECT_ROOT/.env"

echo -e "${GREEN}=== LangChain Agent Installation Script ===${NC}\n"

# Check for required tools
check_requirements() {
    echo -e "${YELLOW}Checking requirements...${NC}"

    if ! command -v docker &> /dev/null; then
        echo -e "${RED}Error: Docker is not installed. Please install Docker first.${NC}"
        exit 1
    fi

    if ! command -v docker-compose &> /dev/null && ! docker compose version &> /dev/null; then
        echo -e "${RED}Error: docker-compose is not installed. Please install docker-compose first.${NC}"
        exit 1
    fi

    echo -e "${GREEN}✓ All requirements met${NC}\n"
}

# Create .env file if it doesn't exist
create_env_file() {
    if [ ! -f "$ENV_FILE" ]; then
        echo -e "${YELLOW}Creating .env file...${NC}"
        cat > "$ENV_FILE" <<EOF
# Database Configuration
POSTGRES_DB=ai_conversations
POSTGRES_USER=myapp_user
POSTGRES_PASSWORD=secure_password_123
POSTGRES_PORT=5432

# Backend Configuration
BACKEND_PORT=8500

# Frontend Configuration
FRONTEND_PORT=8080

# Database Connection String (used by backend)
CONN_STR=postgresql://myapp_user:secure_password_123@postgres:5432/ai_conversations
EOF
        echo -e "${GREEN}✓ Created .env file at $ENV_FILE${NC}"
        echo -e "${YELLOW}  Please review and update the .env file with your preferred settings.${NC}\n"
    else
        echo -e "${GREEN}✓ .env file already exists${NC}\n"
    fi
}

# Build Docker images
build_images() {
    echo -e "${YELLOW}Building Docker images (including frontend)...${NC}"
    cd "$PROJECT_ROOT"

    # Check if docker-compose or docker compose
    if docker compose version &> /dev/null; then
        COMPOSE_CMD="docker compose"
    else
        COMPOSE_CMD="docker-compose"
    fi

    $COMPOSE_CMD -f docker/docker-compose.prod.yml build

    echo -e "${GREEN}✓ Docker images built successfully${NC}\n"
}

# Initialize database
init_database() {
    echo -e "${YELLOW}Initializing database...${NC}"

    # Wait for PostgreSQL to be ready
    echo "Waiting for PostgreSQL to start..."
    sleep 5

    # The SQL files in scripts/init_database/ will be automatically executed
    # by PostgreSQL's docker-entrypoint-initdb.d mechanism
    # We just need to wait a bit for it to complete

    echo -e "${GREEN}✓ Database initialization will be handled automatically by PostgreSQL container${NC}\n"
}

# Start services
start_services() {
    echo -e "${YELLOW}Starting services...${NC}"
    cd "$PROJECT_ROOT"

    # Check if docker-compose or docker compose
    if docker compose version &> /dev/null; then
        COMPOSE_CMD="docker compose"
    else
        COMPOSE_CMD="docker-compose"
    fi

    $COMPOSE_CMD -f docker/docker-compose.prod.yml up -d

    echo -e "${GREEN}✓ Services started${NC}\n"
}

# Show status
show_status() {
    # Load environment variables from .env if it exists
    if [ -f "$ENV_FILE" ]; then
        set -a
        source "$ENV_FILE"
        set +a
    fi

    echo -e "${GREEN}=== Installation Complete ===${NC}\n"
    echo -e "Services are starting up. Please wait a moment for them to be ready.\n"
    echo -e "Access points:"
    echo -e "  - Frontend:    http://localhost:${FRONTEND_PORT:-80}"
    echo -e "  - Backend API: http://localhost:${BACKEND_PORT:-8500}"
    echo -e "  - Database:    localhost:${POSTGRES_PORT:-5432}\n"
    echo -e "To view logs:"
    echo -e "  docker-compose -f docker/docker-compose.prod.yml logs -f\n"
    echo -e "To stop services:"
    echo -e "  docker-compose -f docker/docker-compose.prod.yml down\n"
    echo -e "To restart services:"
    echo -e "  docker-compose -f docker/docker-compose.prod.yml restart\n"
}

# Main execution
main() {
    check_requirements
    create_env_file
    build_images
    start_services
    init_database
    show_status

    echo -e "${YELLOW}Waiting for services to be healthy...${NC}"
    sleep 10

    # Load environment variables for health check
    if [ -f "$ENV_FILE" ]; then
        set -a
        source "$ENV_FILE"
        set +a
    fi

    # Check service health
    echo -e "\n${YELLOW}Checking service health...${NC}"
    sleep 5  # Give services a bit more time
    if curl -f http://localhost:${BACKEND_PORT:-8500}/health &> /dev/null; then
        echo -e "${GREEN}✓ Backend is healthy${NC}"
    else
        echo -e "${YELLOW}⚠ Backend is still starting up. Check logs with: docker-compose -f docker/docker-compose.prod.yml logs backend${NC}"
    fi
}

# Run main function
main
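install.sh picks `docker compose` (the CLI plugin) when available and falls back to the standalone `docker-compose` binary. The same selection logic, sketched in Python with an injectable lookup so it can be exercised without Docker installed (this approximates the shell's `docker compose version` probe with a simple binary-presence check, which is an assumption):

```python
import shutil


def pick_compose_cmd(which=shutil.which) -> str:
    """Return the compose invocation to use, mirroring install.sh.

    `which` is injectable for testing; by default it checks PATH.
    """
    # Prefer the Docker CLI plugin form ("docker compose").
    if which("docker"):
        return "docker compose"
    # Fall back to the legacy standalone binary.
    if which("docker-compose"):
        return "docker-compose"
    raise RuntimeError("Neither docker nor docker-compose is installed")
```

Checking the plugin first matters because on hosts that have both, `docker compose` is the maintained v2 implementation while `docker-compose` may be an older v1 install.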
46 tests/test_combined_app.py Normal file
@@ -0,0 +1,46 @@
import importlib
import os
import sys

from fastapi.testclient import TestClient

os.environ.setdefault("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")


def test_server_dashscope_import_is_cli_safe(monkeypatch):
    """
    Importing server_dashscope should not invoke tyro.cli at module import time.
    """
    import tyro

    monkeypatch.setattr(
        tyro,
        "cli",
        lambda *_args, **_kwargs: (_ for _ in ()).throw(
            AssertionError("tyro.cli must not run during module import")
        ),
    )
    sys.modules.pop("fastapi_server.server_dashscope", None)

    module = importlib.import_module("fastapi_server.server_dashscope")
    assert module.app is not None
    assert module.dashscope_router is not None


def test_combined_app_serves_front_and_dashscope_routes():
    from fastapi_server.combined import app

    client = TestClient(app)

    # front_apis route should be available.
    front_resp = client.get("/v1/pipelines/graphs")
    assert front_resp.status_code == 200, front_resp.text
    assert "available_graphs" in front_resp.json()

    # DashScope route should exist at the same path (missing auth should not be 404).
    dash_resp = client.post(
        "/api/v1/apps/blueberry/sessions/test-session/responses",
        json={"input": {"prompt": "hello"}, "stream": False},
    )
    assert dash_resp.status_code != 404, dash_resp.text
@@ -8,7 +8,10 @@ Instructions:
 - Set BASE_URL below to the server base URL you started.
 - Optionally set environment variables ALI_API_KEY and ALI_APP_ID.
 """
 
 import os
+import json
+import os.path as osp
+import uuid
 from dotenv import load_dotenv
 from loguru import logger
@@ -27,22 +30,86 @@ except Exception as e:
 
 
 # <<< Paste your running FastAPI base url here >>>
-BASE_URL = os.getenv("DS_BASE_URL", "http://127.0.0.1:8588/api/")
+BASE_URL = os.getenv("DS_BASE_URL", "http://127.0.0.1:8500/api/")
 
 
 # Params
-API_KEY = os.getenv("FAST_AUTH_KEYS", "test-key")
-APP_ID = os.getenv("ALI_APP_ID", "test-app")
+def _first_non_empty_csv_token(value: str) -> str:
+    parts = [p.strip() for p in (value or "").split(",") if p.strip()]
+    return parts[0] if parts else ""
+
+
+def _load_registry() -> dict:
+    project_root = osp.dirname(osp.dirname(osp.abspath(__file__)))
+    registry_path = os.getenv(
+        "FAST_PIPELINE_REGISTRY_FILE",
+        osp.join(project_root, "configs", "pipeline_registry.json"),
+    )
+    with open(registry_path, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+
+def _pick_api_key(registry: dict) -> str:
+    # For local server_dashscope testing, FAST_AUTH_KEYS is usually the server auth source.
+    fast_first = _first_non_empty_csv_token(os.getenv("FAST_AUTH_KEYS", ""))
+    ali_key = (os.getenv("ALI_API_KEY") or "").strip()
+
+    api_policies = registry.get("api_keys") or {}
+    if fast_first and (not api_policies or fast_first in api_policies):
+        return fast_first
+    if ali_key and (not api_policies or ali_key in api_policies):
+        return ali_key
+    if fast_first:
+        return fast_first
+    if ali_key:
+        return ali_key
+    raise RuntimeError(
+        "Missing API key. Set FAST_AUTH_KEYS or ALI_API_KEY in your environment."
+    )
+
+
+def _pick_app_id(api_key: str, registry: dict) -> str:
+    if api_key:
+        explicit = (registry.get("api_keys") or {}).get(api_key, {}).get("app_id")
+        if explicit:
+            return explicit
+
+    pipelines_obj = registry.get("pipelines")
+    if not isinstance(pipelines_obj, dict):
+        pipelines_obj = {}
+    pipeline_ids = [r for r in pipelines_obj.keys() if isinstance(r, str) and r]
+
+    if pipeline_ids:
+        return pipeline_ids[0]
+    return "default"
+
+
+def _warn_if_policy_disallows_app_id(api_key: str, app_id: str, registry: dict) -> None:
+    policy = (registry.get("api_keys") or {}).get(api_key, {})
+    if not isinstance(policy, dict):
+        return
+    allowed = policy.get("allowed_pipeline_ids")
+    if isinstance(allowed, list) and allowed and app_id not in allowed:
+        logger.bind(tag=TAG).warning(
+            f"app_id='{app_id}' is not in allowed_pipeline_ids for current API key; server may return 403."
+        )
+
+
+REGISTRY = _load_registry()
+API_KEY = _pick_api_key(REGISTRY)
+APP_ID = _pick_app_id(API_KEY, REGISTRY)
+_warn_if_policy_disallows_app_id(API_KEY, APP_ID, REGISTRY)
+SESSION_ID = str(uuid.uuid4())
 
 dialogue = [
     {"role": "system", "content": "You are a helpful assistant."},
     {"role": "user", "content": "Say 'the world is awesome and beautiful'."},
     {"role": "user", "content": "你叫什么名字"},
 ]
 
 call_params = {
     "api_key": API_KEY,
-    "app_id": "test_app",
-    "session_id": "123",
+    "app_id": APP_ID,
+    "session_id": SESSION_ID,
     "messages": dialogue,
     "stream": True,
 }
@@ -52,8 +119,11 @@ def main():
     # Point the SDK to our FastAPI implementation
     if BASE_URL and ("/api/" in BASE_URL):
-        dashscope.base_http_api_url = BASE_URL
+        # Some SDK paths rely on global api_key to build Authorization header.
+        dashscope.api_key = API_KEY
+        # dashscope.base_http_api_url = BASE_URL
     print(f"Using base_http_api_url = {dashscope.base_http_api_url}")
+    print(f"Using app_id = {APP_ID}")
 
     print("\nCalling Application.call(stream=True)...\n")
     responses = Application.call(**call_params)
@@ -72,7 +142,7 @@ def main():
             continue
         # SDK streaming returns cumulative text; compute the incremental delta
         if len(current_text) >= len(last_text):
-            delta = current_text[len(last_text):]
+            delta = current_text[len(last_text) :]
         else:
             # guard against the occasional shorter (rolled-back) chunk
            delta = current_text
@@ -89,7 +159,7 @@ def main():
         logger.bind(tag=TAG).error(
             f"code={responses.status_code}, message={responses.message}, 请参考文档:https://help.aliyun.com/zh/model-studio/developer-reference/error-code"
         )
-        u = "【阿里百练API服务响应异常】"
+        u = "【阿里百练API服务响应异常】"
     else:
         full_text = getattr(getattr(responses, "output", None), "text", "")
         logger.bind(tag=TAG).info(
@@ -99,11 +169,8 @@ def main():
         print("from non-stream: ", u)
     except Exception as e:
         logger.bind(tag=TAG).error(f"Error: {e}")
-        u = "【阿里百练API服务响应异常】"
-
+        u = "【阿里百练API服务响应异常】"
 
 
 if __name__ == "__main__":
     main()
351 tests/test_front_apis_registry.py Normal file
@@ -0,0 +1,351 @@
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from datetime import datetime, timedelta, timezone
|
||||
import importlib
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
os.environ.setdefault("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
|
||||
|
||||
try:
|
||||
front_apis = importlib.import_module("lang_agent.fastapi_server.front_apis")
|
||||
except ModuleNotFoundError:
|
||||
front_apis = importlib.import_module("fastapi_server.front_apis")
|
||||
|
||||
|
||||
def _fake_build_fn(
|
||||
pipeline_id: str,
|
||||
prompt_set: str,
|
||||
tool_keys,
|
||||
api_key: str,
|
||||
llm_name: str = "qwen-plus",
|
||||
pipeline_config_dir: str = "configs/pipelines",
|
||||
):
|
||||
out_dir = Path(pipeline_config_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_file = out_dir / f"{pipeline_id}.yaml"
|
||||
out_file.write_text(
|
||||
json.dumps(
|
||||
{
|
||||
"pipeline_id": pipeline_id,
|
||||
"prompt_set": prompt_set,
|
||||
"tool_keys": tool_keys,
|
||||
"api_key": api_key,
|
||||
"llm_name": llm_name,
|
||||
}
|
||||
),
|
||||
encoding="utf-8",
|
||||
)
|
||||
return {"path": str(out_file)}
|
||||
|
||||
|
||||
class _FakeCursor:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
self._result = []
|
||||
self._last_sql = ""
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def execute(self, sql, params=None):
|
||||
self._last_sql = sql
|
||||
query = " ".join(sql.split()).lower()
|
||||
params = params or ()
|
||||
|
||||
if "group by conversation_id, pipeline_id" in query:
|
||||
pipeline_id = params[0]
|
||||
limit = int(params[1])
|
||||
grouped = {}
|
||||
for row in self._rows:
|
||||
if row["pipeline_id"] != pipeline_id:
|
||||
continue
|
||||
conv_id = row["conversation_id"]
|
||||
if conv_id not in grouped:
|
||||
grouped[conv_id] = {
|
||||
"conversation_id": conv_id,
|
||||
"pipeline_id": row["pipeline_id"],
|
||||
"message_count": 0,
|
||||
"last_updated": row["created_at"],
|
||||
}
|
||||
grouped[conv_id]["message_count"] += 1
|
||||
if row["created_at"] > grouped[conv_id]["last_updated"]:
|
||||
grouped[conv_id]["last_updated"] = row["created_at"]
|
||||
values = sorted(grouped.values(), key=lambda x: x["last_updated"], reverse=True)
|
||||
self._result = values[:limit]
|
||||
return
|
||||
|
||||
if "select 1 from messages" in query:
|
||||
pipeline_id, conversation_id = params
|
||||
found = any(
|
||||
row["pipeline_id"] == pipeline_id
|
||||
and row["conversation_id"] == conversation_id
|
||||
for row in self._rows
|
||||
)
|
||||
self._result = [{"exists": 1}] if found else []
|
||||
return
|
||||
|
||||
if "order by sequence_number asc" in query:
|
||||
pipeline_id, conversation_id = params
|
||||
self._result = sorted(
|
||||
[
|
||||
{
|
||||
"message_type": row["message_type"],
|
||||
"content": row["content"],
|
||||
"sequence_number": row["sequence_number"],
|
||||
"created_at": row["created_at"],
|
||||
}
|
||||
for row in self._rows
|
||||
if row["pipeline_id"] == pipeline_id
|
||||
and row["conversation_id"] == conversation_id
|
||||
],
|
||||
key=lambda x: x["sequence_number"],
|
||||
)
|
||||
return
|
||||
|
||||
raise AssertionError(f"Unsupported SQL in test fake: {self._last_sql}")
|
||||
|
||||
def fetchall(self):
|
||||
return self._result
|
||||
|
||||
def fetchone(self):
|
||||
if not self._result:
|
||||
return None
|
||||
return self._result[0]
|
||||
|
||||
|
||||
class _FakeConnection:
|
||||
def __init__(self, rows):
|
||||
self._rows = rows
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def cursor(self, row_factory=None):
|
||||
return _FakeCursor(self._rows)
|
||||
|
||||
|
||||
def test_registry_route_lifecycle(monkeypatch, tmp_path):
    registry_path = tmp_path / "pipeline_registry.json"
    monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_path))
    monkeypatch.setitem(front_apis.GRAPH_BUILD_FNCS, "routing", _fake_build_fn)

    client = TestClient(front_apis.app)

    create_resp = client.post(
        "/v1/pipelines",
        json={
            "graph_id": "routing",
            "pipeline_id": "xiaozhan",
            "prompt_set_id": "default",
            "tool_keys": ["weather"],
            "api_key": "sk-test",
            "llm_name": "qwen-plus",
            "enabled": True,
        },
    )
    assert create_resp.status_code == 200, create_resp.text
    create_data = create_resp.json()
    assert create_data["pipeline_id"] == "xiaozhan"
    assert create_data["graph_id"] == "routing"
    assert create_data["llm_name"] == "qwen-plus"
    assert create_data["reload_required"] is False

    list_resp = client.get("/v1/pipelines")
    assert list_resp.status_code == 200, list_resp.text
    items = list_resp.json()["items"]
    assert len(items) == 1
    assert items[0]["pipeline_id"] == "xiaozhan"
    assert items[0]["graph_id"] == "routing"
    assert items[0]["llm_name"] == "qwen-plus"
    assert items[0]["enabled"] is True

    disable_resp = client.delete("/v1/pipelines/xiaozhan")
    assert disable_resp.status_code == 200, disable_resp.text
    disable_data = disable_resp.json()
    assert disable_data["pipeline_id"] == "xiaozhan"
    assert disable_data["enabled"] is False

    list_after = client.get("/v1/pipelines")
    assert list_after.status_code == 200, list_after.text
    items_after = list_after.json()["items"]
    assert len(items_after) == 1
    assert items_after[0]["enabled"] is False

    registry = json.loads(registry_path.read_text(encoding="utf-8"))
    pipeline = registry["pipelines"]["xiaozhan"]
    assert pipeline["graph_id"] == "routing"
    assert pipeline["enabled"] is False

def test_registry_api_key_policy_lifecycle(monkeypatch, tmp_path):
    registry_path = tmp_path / "pipeline_registry.json"
    monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_path))
    monkeypatch.setitem(front_apis.GRAPH_BUILD_FNCS, "routing", _fake_build_fn)

    client = TestClient(front_apis.app)

    create_resp = client.post(
        "/v1/pipelines",
        json={
            "graph_id": "routing",
            "pipeline_id": "blueberry",
            "prompt_set_id": "default",
            "tool_keys": [],
            "api_key": "sk-test",
            "llm_name": "qwen-plus",
            "enabled": True,
        },
    )
    assert create_resp.status_code == 200, create_resp.text

    upsert_resp = client.put(
        "/v1/pipelines/api-keys/sk-test-key",
        json={
            "default_pipeline_id": "blueberry",
            "allowed_pipeline_ids": ["blueberry"],
            "app_id": "blueberry",
        },
    )
    assert upsert_resp.status_code == 200, upsert_resp.text
    upsert_data = upsert_resp.json()
    assert upsert_data["api_key"] == "sk-test-key"
    assert upsert_data["default_pipeline_id"] == "blueberry"
    assert upsert_data["allowed_pipeline_ids"] == ["blueberry"]
    assert upsert_data["app_id"] == "blueberry"

    list_resp = client.get("/v1/pipelines/api-keys")
    assert list_resp.status_code == 200, list_resp.text
    list_data = list_resp.json()
    assert list_data["count"] == 1
    assert list_data["items"][0]["api_key"] == "sk-test-key"

    delete_resp = client.delete("/v1/pipelines/api-keys/sk-test-key")
    assert delete_resp.status_code == 200, delete_resp.text
    delete_data = delete_resp.json()
    assert delete_data["api_key"] == "sk-test-key"
    assert delete_data["status"] == "deleted"
    assert delete_data["reload_required"] is False

def test_pipeline_conversation_routes(monkeypatch):
    now = datetime.now(timezone.utc)
    rows = [
        {
            "conversation_id": "agent-a:conv-1",
            "pipeline_id": "agent-a",
            "message_type": "human",
            "content": "hello",
            "sequence_number": 1,
            "created_at": now - timedelta(seconds=30),
        },
        {
            "conversation_id": "agent-a:conv-1",
            "pipeline_id": "agent-a",
            "message_type": "ai",
            "content": "hi there",
            "sequence_number": 2,
            "created_at": now - timedelta(seconds=20),
        },
        {
            "conversation_id": "agent-a:conv-2",
            "pipeline_id": "agent-a",
            "message_type": "human",
            "content": "second thread",
            "sequence_number": 1,
            "created_at": now - timedelta(seconds=10),
        },
        {
            "conversation_id": "agent-b:conv-9",
            "pipeline_id": "agent-b",
            "message_type": "human",
            "content": "other pipeline",
            "sequence_number": 1,
            "created_at": now - timedelta(seconds=5),
        },
    ]

    monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
    monkeypatch.setattr(
        front_apis.psycopg,
        "connect",
        lambda _conn_str: _FakeConnection(rows),
    )

    client = TestClient(front_apis.app)

    list_resp = client.get("/v1/pipelines/agent-a/conversations")
    assert list_resp.status_code == 200, list_resp.text
    list_data = list_resp.json()
    assert list_data["pipeline_id"] == "agent-a"
    assert list_data["count"] == 2
    assert [item["conversation_id"] for item in list_data["items"]] == [
        "agent-a:conv-2",
        "agent-a:conv-1",
    ]
    assert all(item["pipeline_id"] == "agent-a" for item in list_data["items"])

    msg_resp = client.get("/v1/pipelines/agent-a/conversations/agent-a:conv-1/messages")
    assert msg_resp.status_code == 200, msg_resp.text
    msg_data = msg_resp.json()
    assert msg_data["pipeline_id"] == "agent-a"
    assert msg_data["conversation_id"] == "agent-a:conv-1"
    assert msg_data["count"] == 2
    assert [item["message_type"] for item in msg_data["items"]] == ["human", "ai"]
    assert [item["sequence_number"] for item in msg_data["items"]] == [1, 2]

def test_pipeline_conversation_messages_404(monkeypatch):
    rows = [
        {
            "conversation_id": "agent-b:conv-9",
            "pipeline_id": "agent-b",
            "message_type": "human",
            "content": "other pipeline",
            "sequence_number": 1,
            "created_at": datetime.now(timezone.utc),
        },
    ]
    monkeypatch.setenv("CONN_STR", "postgresql://dummy:dummy@localhost/dummy")
    monkeypatch.setattr(
        front_apis.psycopg,
        "connect",
        lambda _conn_str: _FakeConnection(rows),
    )

    client = TestClient(front_apis.app)
    resp = client.get("/v1/pipelines/agent-a/conversations/agent-b:conv-9/messages")
    assert resp.status_code == 404, resp.text
    assert "not found for pipeline 'agent-a'" in resp.json()["detail"]

def test_runtime_auth_info_prefers_registry_then_env(monkeypatch, tmp_path):
    registry_path = tmp_path / "pipeline_registry.json"
    registry_path.write_text(
        json.dumps(
            {
                "pipelines": {},
                "api_keys": {
                    "sk-from-registry": {"default_pipeline_id": "blueberry"},
                },
            }
        ),
        encoding="utf-8",
    )
    monkeypatch.setattr(front_apis, "PIPELINE_REGISTRY_PATH", str(registry_path))
    monkeypatch.setenv("FAST_AUTH_KEYS", "sk-from-env,other")

    client = TestClient(front_apis.app)
    resp = client.get("/v1/runtime-auth")
    assert resp.status_code == 200, resp.text
    data = resp.json()
    assert data["fast_api_key"] == "sk-from-registry"
    assert data["source"] == "pipeline_registry"

113 tests/test_migrate_yaml_prompts_to_db.py Normal file
@@ -0,0 +1,113 @@
import importlib.util
import sys
from pathlib import Path
from types import SimpleNamespace


def _load_module():
    project_root = Path(__file__).resolve().parents[1]
    script_path = project_root / "scripts" / "py_scripts" / "migrate_yaml_prompts_to_db.py"
    spec = importlib.util.spec_from_file_location("migrate_yaml_prompts_to_db", script_path)
    module = importlib.util.module_from_spec(spec)
    assert spec.loader is not None
    sys.modules[spec.name] = module
    spec.loader.exec_module(module)
    return module

def test_infer_pipeline_id_falls_back_to_filename():
    module = _load_module()
    conf = SimpleNamespace(
        pipeline_id=None,
        graph_config=SimpleNamespace(pipeline_id=None),
    )
    out = module._infer_pipeline_id(conf, "/tmp/blueberry.yaml")
    assert out == "blueberry"

def test_extract_prompt_dict_for_react_txt(tmp_path):
    module = _load_module()
    prompt_f = tmp_path / "sys.txt"
    prompt_f.write_text("hello react", encoding="utf-8")
    graph_conf = SimpleNamespace(sys_prompt_f=str(prompt_f))
    prompt_dict = module._extract_prompt_dict(graph_conf)
    assert prompt_dict == {"sys_prompt": "hello react"}

def test_extract_prompt_dict_for_routing_dir(tmp_path):
    module = _load_module()
    (tmp_path / "route_prompt.txt").write_text("route", encoding="utf-8")
    (tmp_path / "chat_prompt.txt").write_text("chat", encoding="utf-8")
    graph_conf = SimpleNamespace(sys_promp_dir=str(tmp_path))
    prompt_dict = module._extract_prompt_dict(graph_conf)
    assert prompt_dict["route_prompt"] == "route"
    assert prompt_dict["chat_prompt"] == "chat"

def test_collect_payload_routing_ignores_chatty_prompt_for_tool_node(tmp_path):
    module = _load_module()
    prompt_dir = tmp_path / "prompts"
    prompt_dir.mkdir()
    (prompt_dir / "route_prompt.txt").write_text("route", encoding="utf-8")
    (prompt_dir / "chat_prompt.txt").write_text("chat", encoding="utf-8")
    (prompt_dir / "tool_prompt.txt").write_text("tool", encoding="utf-8")
    (prompt_dir / "chatty_prompt.txt").write_text("chatty", encoding="utf-8")

    class RoutingConfig:
        pass

    class ToolNodeConfig:
        pass

    graph_conf = RoutingConfig()
    graph_conf.sys_promp_dir = str(prompt_dir)
    graph_conf.tool_node_config = ToolNodeConfig()
    graph_conf.tool_node_config.tool_prompt_f = str(prompt_dir / "tool_prompt.txt")

    conf = SimpleNamespace(
        pipeline_id=None,
        api_key="sk",
        graph_config=graph_conf,
    )

    module.load_tyro_conf = lambda _: conf
    payload = module._collect_payload(str(tmp_path / "xiaozhan.yaml"))
    assert payload.pipeline_id == "xiaozhan"
    assert set(payload.prompt_dict.keys()) == {"route_prompt", "chat_prompt", "tool_prompt"}
    assert "chatty_prompt" not in payload.prompt_dict

def test_collect_payload_routing_includes_chatty_prompt_for_chatty_node(tmp_path):
    module = _load_module()
    prompt_dir = tmp_path / "prompts"
    prompt_dir.mkdir()
    (prompt_dir / "route_prompt.txt").write_text("route", encoding="utf-8")
    (prompt_dir / "chat_prompt.txt").write_text("chat", encoding="utf-8")
    (prompt_dir / "tool_prompt.txt").write_text("tool", encoding="utf-8")
    (prompt_dir / "chatty_prompt.txt").write_text("chatty", encoding="utf-8")

    class RoutingConfig:
        pass

    class ChattyToolNodeConfig:
        pass

    graph_conf = RoutingConfig()
    graph_conf.sys_promp_dir = str(prompt_dir)
    graph_conf.tool_node_config = ChattyToolNodeConfig()
    graph_conf.tool_node_config.tool_prompt_f = str(prompt_dir / "tool_prompt.txt")
    graph_conf.tool_node_config.chatty_sys_prompt_f = str(
        prompt_dir / "chatty_prompt.txt"
    )

    conf = SimpleNamespace(
        pipeline_id="xiaozhan",
        api_key="sk",
        graph_config=graph_conf,
    )

    module.load_tyro_conf = lambda _: conf
    payload = module._collect_payload(str(tmp_path / "xiaozhan.yaml"))
    assert payload.pipeline_id == "xiaozhan"
    assert "chatty_prompt" in payload.prompt_dict

169 tests/test_server_dashscope_routes.py Normal file
@@ -0,0 +1,169 @@
import os
import signal
import subprocess
import time
from http import HTTPStatus

import pytest
import requests
from dotenv import load_dotenv

load_dotenv()


def _get_service_api_key() -> str | None:
    """Return the first API key from the comma-separated FAST_AUTH_KEYS env var, if any."""
    raw = os.getenv("FAST_AUTH_KEYS", "")
    parts = [p.strip() for p in raw.split(",") if p.strip()]
    return parts[0] if parts else None

def _wait_for_health(base_url: str, timeout: float = 20.0) -> None:
    """Poll the /health endpoint until the server is up or the timeout expires."""
    deadline = time.time() + timeout
    url = base_url.rstrip("/") + "/health"
    last_err = None
    while time.time() < deadline:
        try:
            resp = requests.get(url, timeout=2)
            if resp.status_code == HTTPStatus.OK:
                return
        except Exception as e:  # pragma: no cover - best-effort polling
            last_err = e
        time.sleep(0.5)
    raise RuntimeError(f"Server did not become healthy in time: last_err={last_err}")

@pytest.fixture(scope="module")
def dashscope_server():
    """
    Start a real uvicorn instance of server_dashscope for end-to-end routing tests.

    This mirrors how docker-compose runs `xiaozhan` (server_dashscope.py) so we
    exercise the full stack, including PipelineManager + registry routing.
    """
    env = os.environ.copy()
    # Ensure the registry file is picked up (the server falls back to this path
    # by default, but be explicit).
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    registry_path = os.path.join(project_root, "configs", "pipeline_registry.json")
    env.setdefault("FAST_PIPELINE_REGISTRY_FILE", registry_path)

    cmd = [
        "python",
        "-m",
        "uvicorn",
        "fastapi_server.server_dashscope:app",
        "--host",
        "127.0.0.1",
        "--port",
        "8588",
    ]
    proc = subprocess.Popen(cmd, env=env)

    base_url = "http://127.0.0.1:8588"
    try:
        _wait_for_health(base_url)
    except Exception:
        proc.terminate()
        proc.wait(timeout=10)
        raise

    yield base_url

    # Teardown: stop the server gracefully, then force-kill if needed.
    if proc.poll() is None:
        proc.send_signal(signal.SIGTERM)
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:  # pragma: no cover - best-effort cleanup
            proc.kill()
            proc.wait(timeout=10)

def _post_app_response(base_url: str, pipeline_id: str, body: dict, api_key: str):
    url = f"{base_url}/api/v1/apps/{pipeline_id}/sessions/test-session/responses"
    headers = {"Authorization": f"Bearer {api_key}"}
    resp = requests.post(url, json=body, headers=headers, timeout=20)
    return resp

def test_pipeline_selected_via_pipeline_id_body(dashscope_server):
    """
    When the client specifies `pipeline_id` in the body, it should be used as the
    selector and surfaced back in the JSON response.
    """
    base_url = dashscope_server
    api_key = _get_service_api_key()
    if not api_key:
        pytest.skip(
            "FAST_AUTH_KEYS is not set; cannot authenticate against server_dashscope"
        )
    body = {
        "input": {
            "prompt": "Who are you?",
            "session_id": "sess-1",
        },
        "pipeline_id": "blueberry",
        "stream": False,
    }

    resp = _post_app_response(
        base_url, pipeline_id="blueberry", body=body, api_key=api_key
    )
    assert resp.status_code == HTTPStatus.OK, resp.text
    data = resp.json()
    assert data.get("pipeline_id") == "blueberry"
    assert "text" in data.get("output", {})

def test_pipeline_selected_via_pipeline_id_body_blueberry(dashscope_server):
    """
    When the client specifies `pipeline_id` in the body, it should be used as the
    selector and surfaced back in the JSON response.
    """
    base_url = dashscope_server
    api_key = _get_service_api_key()
    if not api_key:
        pytest.skip(
            "FAST_AUTH_KEYS is not set; cannot authenticate against server_dashscope"
        )
    body = {
        "input": {
            "prompt": "hello from blueberry",
            "session_id": "sess-2",
        },
        "pipeline_id": "blueberry",
        "stream": False,
    }

    resp = _post_app_response(
        base_url, pipeline_id="blueberry", body=body, api_key=api_key
    )
    assert resp.status_code == HTTPStatus.OK, resp.text
    data = resp.json()
    assert data.get("pipeline_id") == "blueberry"
    assert "text" in data.get("output", {})

def test_pipeline_forbidden_for_api_key_when_not_allowed(dashscope_server):
    """
    The API key policy in pipeline_registry should prevent a key from using
    pipelines it is not explicitly allowed to access. Here we send an invalid
    key, so the request is rejected with 401 before any policy check runs.
    """
    base_url = dashscope_server
    body = {
        "input": {
            "prompt": "this should be forbidden",
            "session_id": "sess-3",
        },
        "pipeline_id": "blueberry",
        "stream": False,
    }

    # Use a guaranteed-wrong API key so we test 401 behavior regardless of
    # registry configuration.
    resp = _post_app_response(
        base_url, pipeline_id="blueberry", body=body, api_key="invalid-key-for-test"
    )
    assert resp.status_code == HTTPStatus.UNAUTHORIZED
    data = resp.json()
    assert data.get("detail") == "Invalid API key"

156 tests/test_server_pipeline_manager_refresh.py Normal file
@@ -0,0 +1,156 @@
import json
import time

import pytest
from fastapi import HTTPException

from lang_agent.components.server_pipeline_manager import ServerPipelineManager


class _DummyPipeline:
    def __init__(self, model: str):
        self.model = model


class _DummyConfig:
    def __init__(self, llm_name: str = "qwen-plus"):
        self.llm_name = llm_name

    def setup(self):
        return _DummyPipeline(model=self.llm_name)


def _write_registry(path, pipelines, api_keys=None):
    content = {"pipelines": pipelines, "api_keys": api_keys or {}}
    path.write_text(json.dumps(content, indent=2), encoding="utf-8")
    # Ensure mtime changes reliably on fast CI filesystems.
    time.sleep(0.01)

def test_refresh_registry_picks_up_new_pipeline(tmp_path):
    registry_path = tmp_path / "pipeline_registry.json"
    _write_registry(
        registry_path,
        pipelines={
            "default": {
                "enabled": True,
                "config_file": None,
                "llm_name": "qwen-plus",
            }
        },
    )
    manager = ServerPipelineManager(
        default_pipeline_id="default",
        default_config=_DummyConfig(),
    )
    manager.load_registry(str(registry_path))

    with pytest.raises(HTTPException) as exc_info:
        manager.resolve_pipeline_id(
            body={"pipeline_id": "blueberry"}, app_id=None, api_key="k1"
        )
    assert exc_info.value.status_code == 404

    _write_registry(
        registry_path,
        pipelines={
            "default": {
                "enabled": True,
                "config_file": None,
                "llm_name": "qwen-plus",
            },
            "blueberry": {
                "enabled": True,
                "config_file": None,
                "llm_name": "qwen-max",
            },
        },
    )
    changed = manager.refresh_registry_if_needed()
    assert changed is True

    resolved = manager.resolve_pipeline_id(
        body={"pipeline_id": "blueberry"}, app_id=None, api_key="k1"
    )
    assert resolved == "blueberry"

def test_refresh_registry_invalidates_cache_for_changed_pipeline(tmp_path):
    registry_path = tmp_path / "pipeline_registry.json"
    _write_registry(
        registry_path,
        pipelines={
            "blueberry": {
                "enabled": True,
                "config_file": None,
                "llm_name": "qwen-plus",
            }
        },
    )
    manager = ServerPipelineManager(
        default_pipeline_id="blueberry",
        default_config=_DummyConfig(),
    )
    manager.load_registry(str(registry_path))

    first_pipeline, first_model = manager.get_pipeline("blueberry")
    assert first_model == "qwen-plus"

    _write_registry(
        registry_path,
        pipelines={
            "blueberry": {
                "enabled": True,
                "config_file": None,
                "llm_name": "qwen-max",
            }
        },
    )
    changed = manager.refresh_registry_if_needed()
    assert changed is True

    second_pipeline, second_model = manager.get_pipeline("blueberry")
    assert second_model == "qwen-max"
    assert second_pipeline is not first_pipeline

def test_refresh_registry_applies_disabled_state_immediately(tmp_path):
    registry_path = tmp_path / "pipeline_registry.json"
    _write_registry(
        registry_path,
        pipelines={
            "blueberry": {
                "enabled": True,
                "config_file": None,
                "llm_name": "qwen-plus",
            }
        },
    )
    manager = ServerPipelineManager(
        default_pipeline_id="blueberry",
        default_config=_DummyConfig(),
    )
    manager.load_registry(str(registry_path))
    manager.get_pipeline("blueberry")

    _write_registry(
        registry_path,
        pipelines={
            "blueberry": {
                "enabled": False,
                "config_file": None,
                "llm_name": "qwen-plus",
            }
        },
    )
    changed = manager.refresh_registry_if_needed()
    assert changed is True

    with pytest.raises(HTTPException) as exc_info:
        manager.get_pipeline("blueberry")
    assert exc_info.value.status_code == 403