1137 lines
43 KiB
Python
1137 lines
43 KiB
Python
# 修正拼写错误,将 qyditic 改为 pydantic
|
|
from fastapi import FastAPI, HTTPException, Depends, Request, UploadFile, File, Form
|
|
from pydantic import BaseModel, validator
|
|
import json
|
|
import psycopg2.pool
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from fastapi.security import APIKeyHeader
|
|
import uuid
|
|
import os
|
|
from typing import Optional, List
|
|
from pydantic import BaseModel
|
|
from sms import SMS
|
|
from oss_service import oss_service
|
|
from config import settings
|
|
|
|
class SMSRequest(BaseModel):
    """Request body for ``POST /api/send-sms``.

    ``template_code`` and ``sign_name`` default to the values configured in
    ``settings`` when the client omits them.
    """

    # Destination phone number.
    phone_number: str
    # Passed to the SMS helper as its ``template_param`` argument.
    code: str
    template_code: str | None = settings.sms_template_code
    sign_name: str | None = settings.sms_sign_name
|
|
|
|
|
|
|
|
import os
|
|
# Shared secret expected in the ``x-api-key`` header of protected endpoints.
# Read from the environment so it can be rotated without a code change; the
# previous hard-coded value stays as the fallback for backward compatibility.
# TODO(review): remove the fallback once every deployment sets API_KEY.
API_KEY = os.environ.get("API_KEY", "123tangledup-ai")

# Extracts the ``x-api-key`` request header for ``get_current_user``.
api_key_header = APIKeyHeader(name="x-api-key")
|
|
|
|
def get_current_user(api_key: str = Depends(api_key_header)):
    """FastAPI dependency that authorises a request by its API key.

    Args:
        api_key: Value of the ``x-api-key`` header.

    Returns:
        ``{"api_key": api_key}`` when the key matches ``API_KEY``.

    Raises:
        HTTPException: 401 if the key does not match.
    """
    import secrets  # local import keeps this block self-contained

    # Constant-time comparison so response timing does not reveal how much of
    # the key was correct. Comparing bytes avoids compare_digest's TypeError on
    # non-ASCII str input.
    if not secrets.compare_digest(api_key.encode("utf-8"), API_KEY.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return {"api_key": api_key}
|
|
|
|
app = FastAPI()

# Browser CORS policy.
# NOTE(review): because '*' is in allow_origins, every origin is accepted and the
# explicit entries are redundant. Combined with allow_credentials=True this may
# let any site make credentialed requests (Starlette echoes the request Origin in
# that case) — confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        'http://localhost:5173',
        'https://data.tangledup-ai.com',
        'https://sms.tangledup-ai.com',
        '*',
    ],
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
|
|
|
|
# Database connection pool.
# ThreadedConnectionPool is required because FastAPI runs the plain ``def``
# endpoints in this module on a worker thread pool, and psycopg2's
# SimpleConnectionPool must not be shared across threads.
# Connection settings can be overridden through environment variables; the
# previous literals remain the defaults.
# TODO(review): move the credentials out of source control entirely.
pool = psycopg2.pool.ThreadedConnectionPool(
    minconn=1,
    maxconn=10,
    host=os.environ.get("DB_HOST", "47.101.218.42"),
    port=os.environ.get("DB_PORT", "5432"),
    database=os.environ.get("DB_NAME", "jeremygan2021"),
    user=os.environ.get("DB_USER", "jeremygan2021"),
    password=os.environ.get("DB_PASSWORD", "qweasdzxc1"),
    sslmode=os.environ.get("DB_SSLMODE", "disable"),
)
|
|
|
|
class User(BaseModel):
    """Registration payload for ``POST /register/``."""

    phone: str  # stored in users.phone_number
    name: str  # stored in users.user_name
    user_details: str
    avatar: str
    email: str
    points: int
|
|
|
|
class UserUpdate(BaseModel):
    """Partial update for ``PUT /user/{phone}``; ``None`` fields are left unchanged."""

    name: Optional[str] = None  # written to users.user_name
    user_details: Optional[str] = None
    avatar: Optional[str] = None
    email: Optional[str] = None
    points: Optional[int] = None
|
|
|
|
|
|
@app.post('/register/', tags=["用户管理"])
def register(user: User, current_user: dict = Depends(get_current_user)):
    """Create a ``users`` row keyed by the user's phone number.

    Returns:
        A success message.

    Raises:
        HTTPException: 400 if a user with ``user.phone`` already exists.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (user.phone,))
            already_registered = cur.fetchone() is not None
            if already_registered:
                raise HTTPException(status_code=400, detail='User already exists')

            row_values = (
                user.phone,
                user.name,
                user.user_details,
                user.avatar,
                user.email,
                user.points,
            )
            cur.execute('INSERT INTO users (phone_number, user_name, user_details, avatar, email, points, create_date) VALUES (%s, %s, %s, %s, %s, %s, NOW())', row_values)
        conn.commit()
        return {'message': 'User registered successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.get('/user/{phone}', tags=["用户管理"])
def get_user(phone: str, current_user: dict = Depends(get_current_user)):
    """Return the full ``users`` row for ``phone`` as a column -> value dict.

    Raises:
        HTTPException: 404 if no user has that phone number.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (phone,))
            row = cur.fetchone()
            if row is None:
                raise HTTPException(status_code=404, detail='User not found')
            column_names = (col[0] for col in cur.description)
            return dict(zip(column_names, row))
    finally:
        pool.putconn(conn)
|
|
|
|
|
|
# NOTE(review): this function reuses the name ``get_user`` and so shadows the
# ``GET /user/{phone}`` handler above at module level. Both routes are still
# registered (the decorator has already run), but consider renaming this one.
@app.get('/agents/{phone}', tags=["用户管理"])
def get_user(phone: str, current_user: dict = Depends(get_current_user)):
    """List every ``agent_cards`` row owned by ``phone``.

    Returns:
        A list of column -> value dicts, one per agent card.

    Raises:
        HTTPException: 404 if the user owns no agent cards.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM agent_cards WHERE phone_number = %s', (phone,))
            rows = cur.fetchall()
            if not rows:
                raise HTTPException(status_code=404, detail='No agents found for this user')
            names = [col[0] for col in cur.description]
            return [dict(zip(names, r)) for r in rows]
    finally:
        pool.putconn(conn)
|
|
|
|
|
|
|
|
|
|
class AgentCard(BaseModel):
    """Payload for ``POST /new_agent/``.

    ``type`` distinguishes agent kinds: it defaults to 0 here, while the
    dynamic-agent models and endpoints use 1.
    """

    phone_number: str  # owner; must already exist in ``users``
    card_info: str
    agent_avatar_url: str
    agent_prompt: str
    agent_name: str  # checked for uniqueness per owner on create
    is_publish: bool
    # NOTE(review): required, yet create_agent_card stores NOW() instead of this value.
    create_date: str
    voice_type: str
    temperature: float
    type: int = 0
|
|
|
|
|
|
class AgentCardUpdate(BaseModel):
    """Partial update for an agent card; fields left as ``None`` are not changed."""

    card_info: Optional[str] = None
    agent_avatar_url: Optional[str] = None
    agent_prompt: Optional[str] = None
    agent_name: Optional[str] = None
    is_publish: Optional[bool] = None
    # NOTE(review): accepted, but the update endpoints in this file never apply it.
    create_date: Optional[str] = None
    voice_type: Optional[str] = None
    temperature: Optional[float] = None
    type: Optional[int] = None
|
|
|
|
|
|
class VideoItem(BaseModel):
    """A video URL paired with an ``emotion`` label, as stored for dynamic agents."""

    url: str
    emotion: str
|
|
|
|
class DynamicAgentCreate(AgentCard):
    """Payload for ``POST /dynamic_agent/``.

    Extends ``AgentCard`` with the fields stored in ``dynamic_agent_details``.
    """

    videos: List[VideoItem]  # 1 to 7 entries, enforced by validate_videos
    kb_id: str
    kb_config: dict

    # Dynamic agents default to type 1 (create_dynamic_agent also hard-codes 1 in SQL).
    type: int = 1

    @validator('videos')
    def validate_videos(cls, v):
        """Reject a ``videos`` list whose length is outside 1..7."""
        count = len(v)
        if count < 1 or count > 7:
            raise ValueError('Must provide between 1 and 7 videos')
        return v
|
|
|
|
class DynamicAgentUpdate(BaseModel):
    """Partial update for ``PUT /dynamic_agent/{card_id}``.

    Base fields are written to ``agent_cards``; ``videos``, ``kb_id`` and
    ``kb_config`` are written to ``dynamic_agent_details``. ``None`` means
    "leave unchanged".
    """

    card_info: Optional[str] = None
    agent_avatar_url: Optional[str] = None
    agent_prompt: Optional[str] = None
    agent_name: Optional[str] = None
    is_publish: Optional[bool] = None
    voice_type: Optional[str] = None
    temperature: Optional[float] = None
    videos: Optional[List[VideoItem]] = None
    kb_id: Optional[str] = None
    kb_config: Optional[dict] = None

    @validator('videos')
    def validate_videos(cls, v):
        """Accept ``None``; otherwise require between 1 and 7 videos."""
        if v is None:
            return v
        if not 1 <= len(v) <= 7:
            raise ValueError('Must provide between 1 and 7 videos')
        return v
|
|
|
|
@app.post('/new_agent/', tags=["Agent管理"])
def create_agent_card(agent_card: AgentCard, current_user: dict = Depends(get_current_user)):
    """Insert a new agent card for an existing user.

    The payload's ``create_date`` is ignored; the row is stamped with ``NOW()``.

    Raises:
        HTTPException: 400 if the owner's phone number is not registered, or
            the owner already has an agent with the same name.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # The owner must be a registered user.
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (agent_card.phone_number,))
            if cur.fetchone() is None:
                raise HTTPException(status_code=400, detail='Phone number not registered')

            # Agent names are unique per owner.
            cur.execute('SELECT * FROM agent_cards WHERE phone_number = %s AND agent_name = %s', (agent_card.phone_number, agent_card.agent_name))
            if cur.fetchone() is not None:
                raise HTTPException(status_code=400, detail='Agent card already exists')

            insert_args = (
                agent_card.phone_number,
                agent_card.card_info,
                agent_card.agent_avatar_url,
                agent_card.agent_prompt,
                agent_card.agent_name,
                agent_card.is_publish,
                agent_card.voice_type,
                agent_card.temperature,
                agent_card.type,
            )
            cur.execute('INSERT INTO agent_cards (phone_number, card_info, agent_avatar_url, agent_prompt, agent_name, is_publish, create_date, voice_type, temperature, type) VALUES (%s, %s, %s, %s, %s, %s, NOW(), %s, %s, %s)', insert_args)
        conn.commit()
        return {'message': 'Agent card created successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
|
|
class ToolCreate(BaseModel):
    """Payload for ``POST /tools/``; ``parameters`` is stored as JSON text."""

    name: str  # checked for uniqueness on create
    description: Optional[str] = None
    parameters: dict
|
|
|
|
class ToolUpdate(BaseModel):
    """Partial update for ``PUT /tools/{tool_id}``; ``None`` leaves a field unchanged."""

    description: Optional[str] = None
    parameters: Optional[dict] = None
|
|
|
|
class ToolResponse(ToolCreate):
    """A tool as read back from the database, including its generated id."""

    # NOTE(review): not referenced by any endpoint in the visible part of this file.
    tool_id: int
    created_at: Optional[str] = None
|
|
|
|
class AgentToolAssignment(BaseModel):
    """Body of ``POST /agent/{card_id}/tools``: the full list of tool ids to link."""

    tool_ids: List[int]
|
|
|
|
@app.post('/tools/', tags=["工具管理"])
def create_tool(tool: ToolCreate, current_user: dict = Depends(get_current_user)):
    """Register a new tool and return its generated ``tool_id``.

    Raises:
        HTTPException: 400 if a tool with the same name already exists.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # Tool names must be unique.
            cur.execute('SELECT * FROM tools WHERE name = %s', (tool.name,))
            if cur.fetchone() is not None:
                raise HTTPException(status_code=400, detail='Tool with this name already exists')

            params_json = json.dumps(tool.parameters)
            cur.execute('''
                INSERT INTO tools (name, description, parameters)
                VALUES (%s, %s, %s)
                RETURNING tool_id
            ''', (tool.name, tool.description, params_json))
            (new_id,) = cur.fetchone()
        conn.commit()
        return {'message': 'Tool created successfully', 'tool_id': new_id}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.get('/tools/', tags=["工具管理"])
def list_tools(current_user: dict = Depends(get_current_user)):
    """Return all tools ordered by ``tool_id``.

    ``parameters`` is JSON-decoded when the driver returns it as a string, and
    ``created_at`` is converted to ``str``.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM tools ORDER BY tool_id')
            rows = cur.fetchall()
            names = [col[0] for col in cur.description]

            tools = []
            for row in rows:
                record = dict(zip(names, row))
                raw_params = record['parameters']
                # A jsonb column is usually decoded by psycopg2 already; a text
                # column comes back as a string and is decoded here.
                if isinstance(raw_params, str):
                    record['parameters'] = json.loads(raw_params)
                record['created_at'] = str(record['created_at'])
                tools.append(record)
            return tools
    finally:
        pool.putconn(conn)
|
|
|
|
@app.get('/tools/{tool_id}', tags=["工具管理"])
def get_tool(tool_id: int, current_user: dict = Depends(get_current_user)):
    """Return one tool, with string ``parameters`` decoded and ``created_at`` as ``str``.

    Raises:
        HTTPException: 404 if the tool does not exist.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM tools WHERE tool_id = %s', (tool_id,))
            row = cur.fetchone()
            if row is None:
                raise HTTPException(status_code=404, detail='Tool not found')

            record = dict(zip((col[0] for col in cur.description), row))
            if isinstance(record['parameters'], str):
                record['parameters'] = json.loads(record['parameters'])
            record['created_at'] = str(record['created_at'])
            return record
    finally:
        pool.putconn(conn)
|
|
|
|
@app.put('/tools/{tool_id}', tags=["工具管理"])
def update_tool(tool_id: int, tool_update: ToolUpdate, current_user: dict = Depends(get_current_user)):
    """Update a tool's description and/or parameters.

    Raises:
        HTTPException: 404 if the tool does not exist; 400 if neither field
            was supplied.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM tools WHERE tool_id = %s', (tool_id,))
            if cur.fetchone() is None:
                raise HTTPException(status_code=404, detail='Tool not found')

            # Column -> new value, in SET-clause order (dicts keep insertion order).
            changes = {}
            if tool_update.description is not None:
                changes['description'] = tool_update.description
            if tool_update.parameters is not None:
                changes['parameters'] = json.dumps(tool_update.parameters)

            if not changes:
                raise HTTPException(status_code=400, detail='No fields to update')

            # Column names come from the fixed keys above, never from user input.
            set_clause = ', '.join(f'{column} = %s' for column in changes)
            cur.execute(f"UPDATE tools SET {set_clause} WHERE tool_id = %s", [*changes.values(), tool_id])
        conn.commit()
        return {'message': 'Tool updated successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.delete('/tools/{tool_id}', tags=["工具管理"])
def delete_tool(tool_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a tool by id.

    Raises:
        HTTPException: 404 if no row was deleted.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('DELETE FROM tools WHERE tool_id = %s', (tool_id,))
            deleted = cur.rowcount
            if deleted == 0:
                raise HTTPException(status_code=404, detail='Tool not found')
        conn.commit()
        return {'message': 'Tool deleted successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.post('/agent/{card_id}/tools', tags=["Agent工具关联"])
def assign_tools_to_agent(card_id: int, assignment: AgentToolAssignment, current_user: dict = Depends(get_current_user)):
    """Replace the set of tools linked to an agent card.

    All existing ``agent_tools`` rows for the card are deleted, then one row is
    inserted per distinct tool id in ``assignment.tool_ids``. An empty list
    therefore clears every assignment. Duplicate ids in the request are
    collapsed (first occurrence kept).

    Raises:
        HTTPException: 404 if the agent card does not exist; 400 if any tool id
            does not exist.
    """
    # De-duplicate while preserving request order. The validation below counts
    # distinct ids, so inserting repeats would create duplicate links, or fail
    # if agent_tools has a uniqueness constraint on the pair.
    tool_ids = list(dict.fromkeys(assignment.tool_ids))

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM agent_cards WHERE card_id = %s', (card_id,))
            if not cur.fetchone():
                raise HTTPException(status_code=404, detail='Agent not found')

            if tool_ids:
                cur.execute('SELECT count(*) FROM tools WHERE tool_id = ANY(%s)', (tool_ids,))
                if cur.fetchone()[0] != len(tool_ids):
                    raise HTTPException(status_code=400, detail='One or more tool IDs are invalid')

            # "Replace all" semantics: clear existing links, then insert the new set.
            cur.execute('DELETE FROM agent_tools WHERE agent_card_id = %s', (card_id,))

            if tool_ids:
                # mogrify escapes each tuple, so joining the fragments is safe.
                values_sql = ','.join(
                    cur.mogrify("(%s,%s)", (card_id, tid)).decode('utf-8') for tid in tool_ids
                )
                cur.execute("INSERT INTO agent_tools (agent_card_id, tool_id) VALUES " + values_sql)
        conn.commit()
        return {'message': 'Tools assigned successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.get('/agent/{card_id}/tools', tags=["Agent工具关联"])
def get_agent_tools(card_id: int, current_user: dict = Depends(get_current_user)):
    """List the tools linked to an agent card through ``agent_tools``.

    Returns an empty list (not 404) when the card has no tools or does not
    exist. String ``parameters`` are JSON-decoded and ``created_at`` is
    converted to ``str``.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute("""
                SELECT t.*
                FROM tools t
                JOIN agent_tools at ON t.tool_id = at.tool_id
                WHERE at.agent_card_id = %s
            """, (card_id,))
            names = [col[0] for col in cur.description]

            def _normalise(row):
                # Turn one result row into the JSON-friendly tool dict.
                tool = dict(zip(names, row))
                if isinstance(tool['parameters'], str):
                    tool['parameters'] = json.loads(tool['parameters'])
                tool['created_at'] = str(tool['created_at'])
                return tool

            return [_normalise(row) for row in cur.fetchall()]
    finally:
        pool.putconn(conn)
|
|
|
|
@app.post('/dynamic_agent/', tags=["Dynamic agent 管理"])
def create_dynamic_agent(agent_data: DynamicAgentCreate, current_user: dict = Depends(get_current_user)):
    """Create a dynamic agent in a single transaction.

    Inserts an ``agent_cards`` row with ``type = 1`` and a matching
    ``dynamic_agent_details`` row.

    Returns:
        A success message and the new ``card_id``.

    Raises:
        HTTPException: 400 if the owner is not registered or already has an
            agent with the same name.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # The owner must exist.
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (agent_data.phone_number,))
            if cur.fetchone() is None:
                raise HTTPException(status_code=400, detail='Phone number not registered')

            # Agent names are unique per owner.
            cur.execute('SELECT * FROM agent_cards WHERE phone_number = %s AND agent_name = %s', (agent_data.phone_number, agent_data.agent_name))
            if cur.fetchone() is not None:
                raise HTTPException(status_code=400, detail='Agent card already exists')

            card_values = (
                agent_data.phone_number,
                agent_data.card_info,
                agent_data.agent_avatar_url,
                agent_data.agent_prompt,
                agent_data.agent_name,
                agent_data.is_publish,
                agent_data.voice_type,
                agent_data.temperature,
            )
            # type is hard-coded to 1 in the SQL; RETURNING yields the new card id.
            cur.execute('''
                INSERT INTO agent_cards
                (phone_number, card_info, agent_avatar_url, agent_prompt, agent_name, is_publish, create_date, voice_type, temperature, type)
                VALUES (%s, %s, %s, %s, %s, %s, NOW(), %s, %s, 1)
                RETURNING card_id
            ''', card_values)
            (card_id,) = cur.fetchone()

            # Serialise the nested structures for the details row.
            detail_values = (
                card_id,
                json.dumps([video.dict() for video in agent_data.videos]),
                agent_data.kb_id,
                json.dumps(agent_data.kb_config),
            )
            cur.execute('''
                INSERT INTO dynamic_agent_details (card_id, videos, kb_id, kb_config)
                VALUES (%s, %s, %s, %s)
            ''', detail_values)
        conn.commit()
        return {'message': 'Dynamic agent created successfully', 'card_id': card_id}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.get('/dynamic_agent/{card_id}', tags=["Dynamic agent 管理"])
def get_dynamic_agent(card_id: int, current_user: dict = Depends(get_current_user)):
    """Return one agent card merged with its dynamic details and linked tools.

    ``videos``, ``kb_id`` and ``kb_config`` come from a LEFT JOIN, so they are
    ``None`` when the card has no ``dynamic_agent_details`` row. The card's
    ``type`` is not checked.

    Raises:
        HTTPException: 404 if the card does not exist.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('''
                SELECT a.*, d.videos, d.kb_id, d.kb_config
                FROM agent_cards a
                LEFT JOIN dynamic_agent_details d ON a.card_id = d.card_id
                WHERE a.card_id = %s
            ''', (card_id,))
            agent_row = cur.fetchone()
            if agent_row is None:
                raise HTTPException(status_code=404, detail='Agent not found')
            agent = dict(zip([c[0] for c in cur.description], agent_row))

            cur.execute('''
                SELECT t.*
                FROM tools t
                JOIN agent_tools at ON t.tool_id = at.tool_id
                WHERE at.agent_card_id = %s
            ''', (card_id,))
            tool_names = [c[0] for c in cur.description]
            linked_tools = []
            for tool_row in cur.fetchall():
                tool = dict(zip(tool_names, tool_row))
                if isinstance(tool['parameters'], str):
                    tool['parameters'] = json.loads(tool['parameters'])
                tool['created_at'] = str(tool['created_at'])
                linked_tools.append(tool)

            agent['tools'] = linked_tools
            return agent
    finally:
        pool.putconn(conn)
|
|
|
|
@app.delete('/agent/{phone_number}/{agent_name}', tags=["Agent管理"])
def delete_agent_card(phone_number: str, agent_name: str, current_user: dict = Depends(get_current_user)):
    """Delete the agent card(s) matching owner phone number and agent name.

    Raises:
        HTTPException: 404 if nothing was deleted.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('DELETE FROM agent_cards WHERE phone_number = %s AND agent_name = %s', (phone_number, agent_name))
            removed = cur.rowcount
            if removed == 0:
                raise HTTPException(status_code=404, detail='Agent card not found')
        conn.commit()
        return {'message': 'Agent card deleted successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.put('/agent/{phone_number}/{agent_name}', tags=["Agent管理"])
def update_agent_card(phone_number: str, agent_name: str, agent_card: AgentCardUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update the agent card identified by owner phone and agent name.

    Only fields of ``agent_card`` that are not ``None`` are written;
    ``create_date`` is never applied.

    Raises:
        HTTPException: 404 if the card does not exist; 400 if no updatable
            field was supplied, or if a rename would collide with another agent
            of the same owner.
    """
    # Updatable columns, in SET-clause order. They match the AgentCardUpdate
    # attribute names and are fixed here, so interpolating them is safe.
    updatable = ('card_info', 'agent_avatar_url', 'agent_prompt', 'agent_name',
                 'is_publish', 'voice_type', 'temperature', 'type')
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM agent_cards WHERE phone_number = %s AND agent_name = %s', (phone_number, agent_name))
            if not cur.fetchone():
                raise HTTPException(status_code=404, detail='Agent card not found')

            changes = {
                column: getattr(agent_card, column)
                for column in updatable
                if getattr(agent_card, column) is not None
            }
            if not changes:
                raise HTTPException(status_code=400, detail='No fields to update')

            # create_agent_card enforces unique names per owner. Apply the same
            # rule on rename so (phone_number, agent_name) keeps addressing a
            # single card for the name-based endpoints.
            new_name = changes.get('agent_name')
            if new_name is not None and new_name != agent_name:
                cur.execute('SELECT 1 FROM agent_cards WHERE phone_number = %s AND agent_name = %s', (phone_number, new_name))
                if cur.fetchone():
                    raise HTTPException(status_code=400, detail='Agent card already exists')

            set_clause = ', '.join(f'{column} = %s' for column in changes)
            cur.execute(
                f'UPDATE agent_cards SET {set_clause} WHERE phone_number = %s AND agent_name = %s',
                [*changes.values(), phone_number, agent_name],
            )
        conn.commit()
        return {'message': 'Agent card updated successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
|
|
|
|
|
|
@app.put('/dynamic_agent/{card_id}', tags=["Dynamic agent 管理"])
def update_dynamic_agent(card_id: int, agent_update: DynamicAgentUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update a dynamic agent across ``agent_cards`` and ``dynamic_agent_details``.

    Fields left as ``None`` are not changed. Both tables are updated in the
    same transaction, so an error leaves neither table modified.

    Raises:
        HTTPException:
            - 404 if the agent card does not exist.
            - 404 if detail fields were supplied but the card has no
              ``dynamic_agent_details`` row.
            - 400 if no field was supplied.
            - 400 if a rename would collide with another agent of the same owner.
    """
    # Base columns in SET-clause order. They match DynamicAgentUpdate attribute
    # names and are fixed here, so interpolating them is safe.
    base_columns = ('card_info', 'agent_avatar_url', 'agent_prompt', 'agent_name',
                    'is_publish', 'voice_type', 'temperature')
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT phone_number, agent_name FROM agent_cards WHERE card_id = %s', (card_id,))
            current = cur.fetchone()
            if not current:
                raise HTTPException(status_code=404, detail='Agent not found')
            owner_phone, current_name = current

            base_changes = {
                column: getattr(agent_update, column)
                for column in base_columns
                if getattr(agent_update, column) is not None
            }

            detail_changes = {}
            if agent_update.videos is not None:
                detail_changes['videos'] = json.dumps([v.dict() for v in agent_update.videos])
            if agent_update.kb_id is not None:
                detail_changes['kb_id'] = agent_update.kb_id
            if agent_update.kb_config is not None:
                detail_changes['kb_config'] = json.dumps(agent_update.kb_config)

            if not base_changes and not detail_changes:
                raise HTTPException(status_code=400, detail='No fields to update')

            # Keep agent names unique per owner, as create_dynamic_agent does.
            new_name = base_changes.get('agent_name')
            if new_name is not None and new_name != current_name:
                cur.execute('SELECT 1 FROM agent_cards WHERE phone_number = %s AND agent_name = %s', (owner_phone, new_name))
                if cur.fetchone():
                    raise HTTPException(status_code=400, detail='Agent card already exists')

            if base_changes:
                set_clause = ', '.join(f'{column} = %s' for column in base_changes)
                cur.execute(f"UPDATE agent_cards SET {set_clause} WHERE card_id = %s", [*base_changes.values(), card_id])

            if detail_changes:
                set_clause = ', '.join(f'{column} = %s' for column in detail_changes)
                cur.execute(f"UPDATE dynamic_agent_details SET {set_clause} WHERE card_id = %s", [*detail_changes.values(), card_id])
                # Without this check a missing details row would silently drop
                # the requested changes while still reporting success.
                if cur.rowcount == 0:
                    raise HTTPException(status_code=404, detail='Dynamic agent details not found')
        conn.commit()
        return {'message': 'Dynamic agent updated successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.delete('/dynamic_agent/{card_id}', tags=["Dynamic agent 管理"])
def delete_dynamic_agent(card_id: int, current_user: dict = Depends(get_current_user)):
    """Delete the agent card with ``card_id``.

    Only ``agent_cards`` is touched here. Removal of the matching
    ``dynamic_agent_details`` and ``agent_tools`` rows is expected to come from
    ON DELETE CASCADE foreign keys (schema not visible here — confirm). The
    card's ``type`` is not checked.

    Raises:
        HTTPException: 404 if no card has that id.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('DELETE FROM agent_cards WHERE card_id = %s', (card_id,))
            nothing_deleted = cur.rowcount == 0
            if nothing_deleted:
                raise HTTPException(status_code=404, detail='Agent not found')
        conn.commit()
        return {'message': 'Dynamic agent deleted successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.get('/dynamic_agents/{phone_number}', tags=["Dynamic agent 管理"])
def list_dynamic_agents(phone_number: str, current_user: dict = Depends(get_current_user)):
    """List a user's dynamic agents (``type = 1``), highest ``card_id`` first.

    Each entry is the ``agent_cards`` row merged with its
    ``dynamic_agent_details`` columns (``videos``, ``kb_id``, ``kb_config``;
    ``None`` when no details row exists), plus a ``tools`` list. String
    ``parameters`` in each tool are JSON-decoded and ``created_at`` is
    converted to ``str``.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('''
                SELECT a.*, d.videos, d.kb_id, d.kb_config
                FROM agent_cards a
                LEFT JOIN dynamic_agent_details d ON a.card_id = d.card_id
                WHERE a.phone_number = %s AND a.type = 1
                ORDER BY a.card_id DESC
            ''', (phone_number,))
            agent_columns = [desc[0] for desc in cur.description]
            result = [dict(zip(agent_columns, row)) for row in cur.fetchall()]
            if not result:
                return result

            # Fetch the tools of every listed agent in one query (instead of
            # one query per agent), then group them by owning card id.
            cur.execute('''
                SELECT at.agent_card_id AS _owner_card_id, t.*
                FROM tools t
                JOIN agent_tools at ON t.tool_id = at.tool_id
                WHERE at.agent_card_id = ANY(%s)
            ''', ([agent['card_id'] for agent in result],))
            tool_columns = [desc[0] for desc in cur.description]

            tools_by_card = {}
            for row in cur.fetchall():
                td = dict(zip(tool_columns, row))
                owner = td.pop('_owner_card_id')
                if isinstance(td['parameters'], str):
                    td['parameters'] = json.loads(td['parameters'])
                td['created_at'] = str(td['created_at'])
                tools_by_card.setdefault(owner, []).append(td)

            for agent in result:
                agent['tools'] = tools_by_card.get(agent['card_id'], [])
            return result
    finally:
        pool.putconn(conn)
|
|
|
|
@app.get('/agent_id/{agent_id}', tags=["Agent管理"])
def get_agent(agent_id: int, current_user: dict = Depends(get_current_user)):
    """Return the ``agent_cards`` row whose ``card_id`` equals ``agent_id``.

    ``agent_id`` is declared ``int``, matching the ``card_id: int`` used by the
    other card-id endpoints. A non-numeric id is therefore rejected by request
    validation (422) instead of reaching PostgreSQL as text and surfacing as an
    unhandled database error (500).

    Raises:
        HTTPException: 404 if no card has that id.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM agent_cards WHERE card_id = %s', (agent_id,))
            agent = cur.fetchone()
            if not agent:
                raise HTTPException(status_code=404, detail='Agent card not found')

            columns = [desc[0] for desc in cur.description]
            return dict(zip(columns, agent))
    finally:
        pool.putconn(conn)
|
|
|
|
|
|
@app.put('/agent_id/{agent_id}', tags=["Agent管理"])
def update_agent_card_by_id(agent_id: int, agent_card: AgentCardUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update the agent card whose ``card_id`` equals ``agent_id``.

    Only non-``None`` fields of ``agent_card`` are written; ``create_date`` is
    never applied. ``agent_id`` is an ``int`` so non-numeric ids are rejected
    by validation rather than failing inside PostgreSQL.

    Raises:
        HTTPException: 404 if the card does not exist; 400 if no updatable
            field was supplied, or if a rename would collide with another agent
            of the same owner.
    """
    # Updatable columns, in SET-clause order. They match AgentCardUpdate
    # attribute names and are fixed here, so interpolating them is safe.
    updatable = ('card_info', 'agent_avatar_url', 'agent_prompt', 'agent_name',
                 'is_publish', 'voice_type', 'temperature', 'type')
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT phone_number, agent_name FROM agent_cards WHERE card_id = %s', (agent_id,))
            current = cur.fetchone()
            if not current:
                raise HTTPException(status_code=404, detail='Agent card not found')
            owner_phone, current_name = current

            changes = {
                column: getattr(agent_card, column)
                for column in updatable
                if getattr(agent_card, column) is not None
            }
            if not changes:
                raise HTTPException(status_code=400, detail='No fields to update')

            # Keep agent names unique per owner, as create_agent_card does.
            new_name = changes.get('agent_name')
            if new_name is not None and new_name != current_name:
                cur.execute('SELECT 1 FROM agent_cards WHERE phone_number = %s AND agent_name = %s', (owner_phone, new_name))
                if cur.fetchone():
                    raise HTTPException(status_code=400, detail='Agent card already exists')

            set_clause = ', '.join(f'{column} = %s' for column in changes)
            cur.execute(f'UPDATE agent_cards SET {set_clause} WHERE card_id = %s', [*changes.values(), agent_id])
        conn.commit()
        return {'message': 'Agent card updated successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
|
|
|
|
|
|
@app.delete('/agent_id/{agent_id}', tags=["Agent管理"])
def delete_agent_card_by_id(agent_id: int, current_user: dict = Depends(get_current_user)):
    """Delete the agent card whose ``card_id`` equals ``agent_id``.

    ``agent_id`` is an ``int`` so non-numeric ids are rejected by validation
    rather than failing inside PostgreSQL.

    Raises:
        HTTPException: 404 if no card has that id.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # DELETE's rowcount already says whether the card existed, so a
            # separate existence SELECT beforehand is redundant.
            cur.execute('DELETE FROM agent_cards WHERE card_id = %s', (agent_id,))
            if cur.rowcount == 0:
                raise HTTPException(status_code=404, detail='Agent card not found')
        conn.commit()
        return {'message': 'Agent card deleted successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
|
|
# Image upload directory configuration.
# Alternative location (commented out in the original):
# UPLOAD_DIR = "/mnt/server/userImage/call_avator"
UPLOAD_DIR = "./"
# exist_ok avoids the race between checking for the directory and creating it.
os.makedirs(UPLOAD_DIR, exist_ok=True)

# Content types accepted by the image upload endpoint.
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/gif"]
|
|
|
|
@app.post('/upload_image/', tags=["文件上传"])
async def upload_image(file: UploadFile = File(...), current_user: dict = Depends(get_current_user)):
    """Save an uploaded JPEG/PNG/GIF under ``UPLOAD_DIR`` with a random name.

    The stored extension is constrained to the declared content type, so a
    file such as ``evil.html`` declared as ``image/png`` is stored as ``.png``.

    Returns:
        A success message, the generated ``file_id`` (UUID4 string) and the
        ``file_path`` the image was written to.

    Raises:
        HTTPException: 400 if the declared content type is not allowed; 500 if
            saving the file fails.
    """
    # Acceptable extensions per allowed content type; the first entry is used
    # when the client's filename has no extension or one that does not match.
    extensions_by_type = {
        "image/jpeg": (".jpg", ".jpeg"),
        "image/png": (".png",),
        "image/gif": (".gif",),
    }

    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Only image files are allowed")

    try:
        # ``filename`` may be None; treat that as "no extension".
        client_ext = os.path.splitext(file.filename or "")[1].lower()
        allowed_exts = extensions_by_type.get(file.content_type, ())
        if client_ext in allowed_exts:
            file_ext = client_ext
        elif allowed_exts:
            file_ext = allowed_exts[0]
        else:
            file_ext = ""

        unique_id = str(uuid.uuid4())
        file_path = os.path.join(UPLOAD_DIR, f"{unique_id}{file_ext}")

        with open(file_path, "wb") as buffer:
            buffer.write(await file.read())

        return {"message": "Image uploaded successfully", "file_id": unique_id, "file_path": file_path}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error uploading file: {str(e)}")
|
|
|
|
@app.put('/user/{phone}', tags=["用户管理"])
def update_user(phone: str, user_update: UserUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update the user identified by ``phone``.

    Only non-``None`` fields of ``user_update`` are written; ``name`` is stored
    in the ``user_name`` column.

    Raises:
        HTTPException: 404 if the user does not exist; 400 if no field was supplied.
    """
    # (model attribute, users column) pairs, in SET-clause order.
    column_for_field = (
        ('name', 'user_name'),
        ('user_details', 'user_details'),
        ('avatar', 'avatar'),
        ('email', 'email'),
        ('points', 'points'),
    )
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (phone,))
            if cur.fetchone() is None:
                raise HTTPException(status_code=404, detail='User not found')

            assignments = []
            values = []
            for field, column in column_for_field:
                new_value = getattr(user_update, field)
                if new_value is not None:
                    assignments.append(f'{column} = %s')
                    values.append(new_value)

            if not assignments:
                raise HTTPException(status_code=400, detail='No fields to update')

            values.append(phone)
            set_clause = ', '.join(assignments)
            cur.execute(f'''
                UPDATE users SET
                {set_clause}
                WHERE phone_number = %s
            ''', values)
        conn.commit()
        return {'message': 'User updated successfully'}
    finally:
        pool.putconn(conn)
|
|
|
|
|
|
@app.post("/api/send-sms", tags=["短信服务"])
async def send_sms(sms_request: SMSRequest):
    """Send an SMS and record the attempt in the ``sms_records`` table.

    The provider call (``SMS().main``) and the psycopg2 insert are both
    blocking. They run via ``asyncio.to_thread``, because calling them directly
    inside this ``async def`` would stall the event loop for every other
    request. A failure while saving the record is printed and otherwise ignored.

    Returns:
        ``{"status": "failed", "message": ...}`` when the provider result is
        missing or not successful, otherwise
        ``{"status": "success", "message": ..., "data": result}``.
    """
    # NOTE(review): no API-key dependency, so any caller can trigger SMS sends
    # (cost/abuse risk). Consider Depends(get_current_user) or rate limiting.
    import asyncio

    sms = SMS()
    print(f"Sending SMS with template_code: {sms_request.template_code}, sign_name: {sms_request.sign_name}")
    # Same call as before, including the explicit ``self=sms`` keyword.
    result = await asyncio.to_thread(
        sms.main,
        self=sms,
        phone_number=sms_request.phone_number,
        template_param=sms_request.code,
        template_code=sms_request.template_code,
        sign_name=sms_request.sign_name,
    )

    def _save_record():
        # Best-effort insert of one sms_records row describing ``result``.
        conn = pool.getconn()
        try:
            with conn.cursor() as cur:
                status = 'success' if result and result.get('success') else 'failed'
                biz_id = result.get('biz_id') if result else None
                error_message = result.get('error_message') if result else 'Unknown error'

                cur.execute("""
                    INSERT INTO sms_records (phone_number, template_code, template_param, sign_name, status, biz_id, error_message)
                    VALUES (%s, %s, %s, %s, %s, %s, %s)
                """, (sms_request.phone_number, sms_request.template_code, sms_request.code, sms_request.sign_name, status, biz_id, error_message))
                conn.commit()
        except Exception as e:
            print(f"Error saving SMS record: {e}")
        finally:
            pool.putconn(conn)

    await asyncio.to_thread(_save_record)

    if not result or not result.get('success'):
        return {"status": "failed", "message": result.get('error_message') if result else "发送失败"}

    return {"status": "success", "message": "短信发送请求已处理", "data": result}
|
|
|
|
class ConversationCreate(BaseModel):
    """Payload for ``POST /conversations/``.

    The endpoint requires exactly one of ``user_phone`` or ``visitor_key``;
    ``visitor_key`` must be a UUID string.
    """

    user_phone: Optional[str] = None
    visitor_key: Optional[str] = None
    agent_card_id: int
|
|
|
|
class MessageCreate(BaseModel):
    """Payload for ``POST /messages/``."""

    conversation_id: int
    sender: str  # the endpoint accepts only 'user' or 'agent'
    content: str
    order: int  # stored in the quoted "order" column
|
|
|
|
@app.post('/conversations/', tags=["对话管理"])
def create_conversation(conversation: ConversationCreate, current_user: dict = Depends(get_current_user)):
    """Start a conversation between an agent card and either a user or a visitor.

    Returns:
        ``{'conversation_id': <new id>}``.

    Raises:
        HTTPException:
            - 400 if both or neither of ``user_phone``/``visitor_key`` are given.
            - 404 if the agent card does not exist.
            - 400 if ``visitor_key`` is not a valid UUID.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            has_user = bool(conversation.user_phone)
            has_visitor = bool(conversation.visitor_key)
            # Exactly one identity must be supplied.
            if has_user and has_visitor:
                raise HTTPException(status_code=400, detail='Cannot specify both user_phone and visitor_key')
            if not (has_user or has_visitor):
                raise HTTPException(status_code=400, detail='Must specify either user_phone or visitor_key')

            cur.execute('SELECT * FROM agent_cards WHERE card_id = %s', (conversation.agent_card_id,))
            if cur.fetchone() is None:
                raise HTTPException(status_code=404, detail='Agent card not found')

            if has_visitor:
                try:
                    uuid.UUID(conversation.visitor_key)
                except ValueError:
                    raise HTTPException(status_code=400, detail='visitor_key must be a valid UUID format')

            cur.execute('''
                INSERT INTO conversations (user_phone, visitor_key, agent_card_id)
                VALUES (%s, %s, %s)
                RETURNING conversation_id
            ''', (conversation.user_phone, conversation.visitor_key, conversation.agent_card_id))
            (conversation_id,) = cur.fetchone()
        conn.commit()
        return {'conversation_id': conversation_id}
    finally:
        pool.putconn(conn)
|
|
|
|
@app.post('/messages/', tags=["对话管理"])
def create_message(message: MessageCreate, current_user: dict = Depends(get_current_user)):
    """Append a message to an existing conversation.

    Returns:
        ``{'message_id': <id of the inserted row>}``.

    Raises:
        HTTPException(400): ``sender`` is neither ``'user'`` nor ``'agent'``.
        HTTPException(404): ``conversation_id`` matches no row in ``conversations``.
    """
    # Cheap body validation first: no pooled connection or DB round-trip is
    # spent on a request that will be rejected anyway.
    if message.sender not in ('user', 'agent'):
        raise HTTPException(status_code=400, detail="Sender must be either 'user' or 'agent'")

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # Existence probe only; there is no need to fetch the whole row.
            cur.execute('SELECT 1 FROM conversations WHERE conversation_id = %s', (message.conversation_id,))
            if cur.fetchone() is None:
                raise HTTPException(status_code=404, detail='Conversation not found')

            # ORDER is an SQL keyword, hence the quoted "order" column name.
            cur.execute('''
                INSERT INTO messages (conversation_id, sender, content, "order")
                VALUES (%s, %s, %s, %s)
                RETURNING message_id
            ''', (message.conversation_id, message.sender, message.content, message.order))
            message_id = cur.fetchone()[0]
        conn.commit()
        return {'message_id': message_id}
    finally:
        pool.putconn(conn)
|
|
|
|
def _fetch_all_sms_records():
    """Synchronously load every ``sms_records`` row, newest first, as the response payload.

    NOTE(review): ``pool`` is a psycopg2 ``SimpleConnectionPool``, which psycopg2
    documents as not shareable across threads, yet FastAPI already calls it from
    worker threads for the sync endpoints. Consider ``ThreadedConnectionPool``.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM sms_records ORDER BY created_at DESC')
            columns = [desc[0] for desc in cur.description]
            result = []
            for row in cur.fetchall():
                record = dict(zip(columns, row))
                # Stringify the timestamp for the response, but keep a NULL as
                # None instead of turning it into the literal string "None".
                if record.get('created_at') is not None:
                    record['created_at'] = str(record['created_at'])
                result.append(record)
        return {"status": "success", "data": result}
    finally:
        pool.putconn(conn)


@app.get("/api/sms-records", tags=["短信服务"])
async def get_all_sms_records():
    """Return all SMS send records, newest first.

    Response shape: ``{"status": "success", "data": [{column: value, ...}, ...]}``
    with ``created_at`` rendered via ``str()`` (None when NULL).

    NOTE(review): unlike the conversation endpoints, this route has no
    ``get_current_user`` dependency, so any caller can read every SMS record.
    Confirm that this is intended.
    """
    # psycopg2 calls block. Run them in a worker thread so the event loop keeps
    # serving other requests while the query runs.
    from fastapi.concurrency import run_in_threadpool

    return await run_in_threadpool(_fetch_all_sms_records)
|
|
|
|
|
|
# 阿里云OSS相关API
|
|
@app.post("/test/", tags=["阿里云OSS"])
async def test_oss():
    """Placeholder "OSS connectivity test" endpoint.

    Despite the name, it never contacts OSS; it always replies with the fixed
    string ``"string"``.
    """
    placeholder_reply = "string"
    return placeholder_reply
|
|
|
|
|
|
@app.post("/upload", tags=["阿里云OSS"])
async def upload_file(
    file: UploadFile = File(...),
    folder: str = "uploads"
):
    """Upload a single file to OSS.

    Args:
        file: The uploaded file (multipart form field ``file``).
        folder: Folder used when generating the object key (default ``"uploads"``).

    Returns:
        On success, the dict returned by ``oss_service.upload_file``. On any
        exception, an error dict with ``success=False``, the error text in
        ``message`` and ``error_code="InternalError"``. Failures are reported in
        the body (HTTP 200), not as an HTTP error status.
    """
    # oss_service.upload_file is synchronous (its result is used without
    # ``await``). Run it in a worker thread so a slow upload does not block the
    # event loop for every other request.
    from fastapi.concurrency import run_in_threadpool

    try:
        file_content = await file.read()
        object_key = oss_service.generate_object_key(file.filename, folder)
        return await run_in_threadpool(
            oss_service.upload_file,
            file_content=file_content,
            object_key=object_key,
            content_type=file.content_type,
        )
    except Exception as e:
        return {
            "success": False,
            "message": f"上传失败: {str(e)}",
            "object_key": None,
            "file_url": None,
            "etag": None,
            "request_id": None,
            "error_code": "InternalError",
        }
|
|
|
|
|
|
@app.post("/upload/multiple", tags=["阿里云OSS"])
async def upload_multiple_files(
    files: List[UploadFile] = File(...),
    folder: str = "uploads"
):
    """Upload several files to OSS, one after another.

    Args:
        files: The uploaded files (multipart form field ``files``).
        folder: Folder used when generating each object key (default ``"uploads"``).

    Returns:
        A list with one entry per input file, in input order. Each entry is the
        dict from ``oss_service.upload_file``, or an error dict
        (``success=False``, ``error_code="InternalError"``) if that file failed.
        A failure does not stop the remaining uploads.
    """
    # oss_service.upload_file is synchronous (its result is used without
    # ``await``). Offload each call so uploads don't block the event loop.
    from fastapi.concurrency import run_in_threadpool

    results = []
    for file in files:
        try:
            file_content = await file.read()
            object_key = oss_service.generate_object_key(file.filename, folder)
            result = await run_in_threadpool(
                oss_service.upload_file,
                file_content=file_content,
                object_key=object_key,
                content_type=file.content_type,
            )
            results.append(result)
        except Exception as e:
            results.append({
                "success": False,
                "message": f"上传失败: {str(e)}",
                "object_key": None,
                "file_url": None,
                "etag": None,
                "request_id": None,
                "error_code": "InternalError",
            })
    return results
|
|
|
|
|
|
@app.get("/files", tags=["阿里云OSS"])
async def list_files(
    prefix: str = "recordings/",
    max_keys: int = 100
):
    """List OSS objects under a key prefix.

    Args:
        prefix: Object-key prefix to filter by (default ``"recordings/"``).
        max_keys: Maximum number of objects to return; must be in 1..1000.

    Returns:
        The dict from ``oss_service.list_files``, always with a ``message`` key.
        For an out-of-range ``max_keys``, a dict with ``success=False``,
        ``files=[]`` and ``count=0`` is returned without calling OSS.
    """
    from fastapi.concurrency import run_in_threadpool

    # OSS ListObjects only accepts max-keys in the range 1..1000.
    if max_keys > 1000:
        return {
            "success": False,
            "message": "max_keys不能超过1000",
            "files": [],
            "count": 0
        }
    if max_keys < 1:
        return {
            "success": False,
            "message": "max_keys必须大于0",
            "files": [],
            "count": 0
        }

    # oss_service.list_files is synchronous (its result is used directly), so
    # keep the network call off the event loop.
    result = await run_in_threadpool(oss_service.list_files, prefix=prefix, max_keys=max_keys)

    # Guarantee a message field in the response.
    result.setdefault("message", "获取文件列表成功")
    return result
|
|
|
|
|
|
@app.get("/files/{object_key:path}/info", tags=["阿里云OSS"])
async def get_file_info(object_key: str):
    """Return metadata for one OSS object.

    Args:
        object_key: OSS object key (may contain ``/``).

    Returns:
        The dict from ``oss_service.get_file_info``, always with a ``message``
        key. When the service supplied none, the message reflects its
        ``success`` flag; a missing flag is treated as failure.
    """
    # oss_service.get_file_info is synchronous (its result is used directly),
    # so keep the network call off the event loop.
    from fastapi.concurrency import run_in_threadpool

    result = await run_in_threadpool(oss_service.get_file_info, object_key)

    if "message" not in result:
        # .get() so a result without "success" yields a failure message
        # instead of a KeyError / HTTP 500.
        result["message"] = "获取文件信息成功" if result.get("success") else "获取文件信息失败"
    return result
|
|
|
|
|
|
@app.delete("/files/{object_key:path}", tags=["阿里云OSS"])
async def delete_file(object_key: str):
    """Delete one OSS object.

    Args:
        object_key: OSS object key (may contain ``/``).

    Returns:
        The dict from ``oss_service.delete_file``, always with an
        ``error_code`` key (None when the service did not set one).

    NOTE(review): this route has no ``get_current_user`` dependency, so any
    caller can delete objects. Confirm that this is intended.
    """
    # oss_service.delete_file is synchronous (its result is used directly), so
    # keep the network call off the event loop.
    from fastapi.concurrency import run_in_threadpool

    result = await run_in_threadpool(oss_service.delete_file, object_key)

    # Guarantee an error_code field in the response.
    result.setdefault("error_code", None)
    return result