1169 lines
44 KiB
Python
1169 lines
44 KiB
Python
# Fix typo: `qyditic` corrected to `pydantic`.
|
||
from fastapi import FastAPI, HTTPException, Depends, Request, UploadFile, File, Form
|
||
from pydantic import BaseModel, validator
|
||
import json
|
||
import psycopg2.pool
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from fastapi.security import APIKeyHeader
|
||
import uuid
|
||
import os
|
||
from typing import Optional, List
|
||
from pydantic import BaseModel
|
||
from sms import SMS
|
||
from oss_service import oss_service
|
||
from config import settings
|
||
|
||
class SMSRequest(BaseModel):
    """Request body accepted by the SMS-sending endpoint.

    Fields not declared here are still accepted because the model is
    configured with ``extra = "allow"``.
    """

    # Destination phone number (the only required field).
    phone_number: str
    # Optional code value; omitted from the request when left as None.
    code: str | None = None
    # Template identifier; defaults to the value configured in ``settings``.
    template_code: str | None = settings.sms_template_code
    # SMS signature name; defaults to the value configured in ``settings``.
    sign_name: str | None = settings.sms_sign_name

    class Config:
        # Keep undeclared request fields instead of rejecting or dropping them.
        extra = "allow"
|
||
|
||
|
||
|
||
import os
|
||
# Shared secret that clients must send in the ``x-api-key`` request header.
# The ``API_KEY`` environment variable overrides the built-in value so a
# deployment can rotate the key without editing source code.
# NOTE(review): the literal fallback is kept only for backward compatibility;
# consider removing it once every environment sets ``API_KEY``.
API_KEY = os.getenv("API_KEY", "123tangledup-ai")

# FastAPI security scheme that extracts the ``x-api-key`` header value.
api_key_header = APIKeyHeader(name="x-api-key")
|
||
|
||
def get_current_user(api_key: str = Depends(api_key_header)):
    """Authenticate a request by its ``x-api-key`` header.

    Args:
        api_key: Header value supplied through the ``api_key_header`` dependency.

    Returns:
        A dict ``{"api_key": api_key}`` describing the authenticated caller.

    Raises:
        HTTPException: 401 when the supplied key does not match ``API_KEY``.
    """
    # Local import keeps this block self-contained.
    import secrets

    # Compare as bytes in constant time so response timing does not reveal how
    # many leading characters of the key were correct; bytes are used because
    # compare_digest rejects non-ASCII str values.
    key_matches = isinstance(api_key, str) and secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    )
    if not key_matches:
        raise HTTPException(status_code=401, detail="Invalid API Key")
    return {"api_key": api_key}
|
||
|
||
# Origins allowed to call this API from a browser.
# NOTE(review): the trailing '*' entry allows every origin, which makes the
# explicit entries redundant; combined with allow_credentials=True this is very
# permissive -- confirm that it is intended.
_cors_origins = [
    'http://localhost:5173',
    'https://data.tangledup-ai.com',
    'https://sms.tangledup-ai.com',
    '*',
]

app = FastAPI()

# Register CORS handling for all methods and headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=['*'],
    allow_headers=['*'],
)
|
||
|
||
# Database connection pool configuration.
# Every connection setting can be overridden with the matching libpq-style
# environment variable (PGHOST, PGPORT, PGDATABASE, PGUSER, PGPASSWORD,
# PGSSLMODE); the literals below are the previous hard-coded values, kept as
# fallbacks so existing deployments keep working.
# NOTE(review): credentials should not live in source control -- remove the
# fallbacks once the environment provides them.
pool = psycopg2.pool.SimpleConnectionPool(
    minconn=1,
    maxconn=10,
    host=os.getenv("PGHOST", "47.101.218.42"),
    port=os.getenv("PGPORT", "5432"),
    database=os.getenv("PGDATABASE", "jeremygan2021"),
    user=os.getenv("PGUSER", "jeremygan2021"),
    password=os.getenv("PGPASSWORD", "qweasdzxc1"),
    sslmode=os.getenv("PGSSLMODE", "disable"),
)
|
||
|
||
class User(BaseModel):
    """Request body for registering a user; every field is required."""

    # Phone number; stored in the ``phone_number`` column by the register endpoint.
    phone: str
    # Display name; stored in the ``user_name`` column.
    name: str
    # Free-form details text.
    user_details: str
    # Avatar reference (stored as given).
    avatar: str
    # Contact e-mail (stored as given; no format validation here).
    email: str
    # Points balance.
    points: int
|
||
|
||
class UserUpdate(BaseModel):
    """Partial-update body for a user.

    Every field is optional; a field left as None is not written by the
    update endpoint.
    """

    name: Optional[str] = None
    user_details: Optional[str] = None
    avatar: Optional[str] = None
    email: Optional[str] = None
    points: Optional[int] = None
|
||
|
||
|
||
@app.post('/register/', tags=["用户管理"])
def register(user: User, current_user: dict = Depends(get_current_user)):
    """Create a new row in ``users`` for the given phone number.

    Args:
        user: Registration payload.
        current_user: Result of the API-key dependency (unused in the body).

    Returns:
        ``{'message': 'User registered successfully'}`` on success.

    Raises:
        HTTPException: 400 when a user with this phone number already exists.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # Only existence matters here, so do not fetch the whole row.
            cur.execute('SELECT 1 FROM users WHERE phone_number = %s', (user.phone,))
            if cur.fetchone():
                raise HTTPException(status_code=400, detail='User already exists')
            try:
                cur.execute(
                    'INSERT INTO users (phone_number, user_name, user_details, avatar, email, points, create_date) '
                    'VALUES (%s, %s, %s, %s, %s, %s, NOW())',
                    (user.phone, user.name, user.user_details, user.avatar, user.email, user.points),
                )
            except psycopg2.IntegrityError as exc:
                # A concurrent request may insert the same phone number between
                # the check above and this insert. SQLSTATE 23505 is
                # unique_violation; report it like the explicit check does
                # instead of surfacing a 500.
                if exc.pgcode == '23505':
                    raise HTTPException(status_code=400, detail='User already exists') from exc
                raise
            conn.commit()
            return {'message': 'User registered successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.get('/user/{phone}', tags=["用户管理"])
def get_user(phone: str, current_user: dict = Depends(get_current_user)):
    """Return the full ``users`` row for *phone* as a column-name -> value dict.

    Raises:
        HTTPException: 404 when no user has this phone number.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (phone,))
            row = cur.fetchone()
            if not row:
                raise HTTPException(status_code=404, detail='User not found')
            # Pair each selected column name with its value from the row.
            column_names = [column[0] for column in cur.description]
            return {name: value for name, value in zip(column_names, row)}
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
@app.get('/agents/{phone}', tags=["用户管理"])
def get_user(phone: str, current_user: dict = Depends(get_current_user)):
    """Return every ``agent_cards`` row owned by *phone* as a list of dicts.

    NOTE(review): this function reuses the name ``get_user`` of the
    ``/user/{phone}`` handler defined earlier in the file, so the module-level
    name now refers to this function; consider renaming.

    Raises:
        HTTPException: 404 when the phone number has no agent cards.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM agent_cards WHERE phone_number = %s', (phone,))
            rows = cur.fetchall()
            if not rows:
                raise HTTPException(status_code=404, detail='No agents found for this user')
            # Return all agent rows as a JSON array of objects.
            column_names = [column[0] for column in cur.description]
            agents = []
            for row in rows:
                agents.append(dict(zip(column_names, row)))
            return agents
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
|
||
|
||
class AgentCard(BaseModel):
    """Request body for creating an agent card."""

    # Owner's phone number; the create endpoint requires it to be registered.
    phone_number: str
    card_info: str
    agent_avatar_url: str
    agent_prompt: str
    # Checked for uniqueness together with phone_number by the create endpoint.
    agent_name: str
    is_publish: bool
    # Required by the schema; the create endpoint visible in this file stores
    # NOW() instead of this value.
    create_date: str
    voice_type: str
    temperature: float
    # Agent kind; defaults to 0 (dynamic agents are stored with type 1).
    type: int = 0
|
||
|
||
|
||
class AgentCardUpdate(BaseModel):
    """Partial-update body for an agent card; all fields are optional."""

    card_info: Optional[str] = None
    agent_avatar_url: Optional[str] = None
    agent_prompt: Optional[str] = None
    agent_name: Optional[str] = None
    is_publish: Optional[bool] = None
    # Accepted by the schema; the update endpoints visible in this file do
    # not write it.
    create_date: Optional[str] = None
    voice_type: Optional[str] = None
    temperature: Optional[float] = None
    type: Optional[int] = None
|
||
|
||
|
||
class VideoItem(BaseModel):
    """One video entry attached to a dynamic agent."""

    # Location of the video.
    url: str
    # Emotion label associated with the video.
    emotion: str
|
||
|
||
class DynamicAgentCreate(AgentCard):
    """Request body for creating a dynamic agent.

    Extends ``AgentCard`` with a video list and knowledge-base settings.
    """

    videos: List[VideoItem]
    kb_id: str
    kb_config: dict

    # Dynamic agents default to type 1. The create endpoint also writes the
    # literal 1 into the database regardless of this value.
    type: int = 1

    @validator('videos')
    def validate_videos(cls, v):
        """Reject video lists that are empty or longer than seven items."""
        video_count = len(v)
        if video_count < 1 or video_count > 7:
            raise ValueError('Must provide between 1 and 7 videos')
        return v
|
||
|
||
class DynamicAgentUpdate(BaseModel):
    """Partial-update body for a dynamic agent; all fields are optional."""

    # Fields written to the agent_cards table by the update endpoint.
    card_info: Optional[str] = None
    agent_avatar_url: Optional[str] = None
    agent_prompt: Optional[str] = None
    agent_name: Optional[str] = None
    is_publish: Optional[bool] = None
    voice_type: Optional[str] = None
    temperature: Optional[float] = None
    # Fields written to the dynamic_agent_details table by the update endpoint.
    videos: Optional[List[VideoItem]] = None
    kb_id: Optional[str] = None
    kb_config: Optional[dict] = None

    @validator('videos')
    def validate_videos(cls, v):
        """When a video list is supplied, require between one and seven items."""
        if v is None:
            return v
        if len(v) < 1 or len(v) > 7:
            raise ValueError('Must provide between 1 and 7 videos')
        return v
|
||
|
||
@app.post('/new_agent/', tags=["Agent管理"])
def create_agent_card(agent_card: AgentCard, current_user: dict = Depends(get_current_user)):
    """Insert a new agent card for a registered phone number.

    Returns:
        ``{'message': 'Agent card created successfully'}`` on success.

    Raises:
        HTTPException: 400 when the phone number is not registered, or when the
            same owner already has a card with this agent name.
    """
    owner = agent_card.phone_number
    insert_sql = (
        'INSERT INTO agent_cards (phone_number, card_info, agent_avatar_url, agent_prompt, agent_name, '
        'is_publish, create_date, voice_type, temperature, type) '
        'VALUES (%s, %s, %s, %s, %s, %s, NOW(), %s, %s, %s)'
    )
    # create_date is set by the database (NOW()), not taken from the payload.
    insert_values = (
        owner,
        agent_card.card_info,
        agent_card.agent_avatar_url,
        agent_card.agent_prompt,
        agent_card.agent_name,
        agent_card.is_publish,
        agent_card.voice_type,
        agent_card.temperature,
        agent_card.type,
    )

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # The owner must already exist in the users table.
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (owner,))
            owner_row = cur.fetchone()
            if not owner_row:
                raise HTTPException(status_code=400, detail='Phone number not registered')

            # Agent names are unique per owner.
            cur.execute(
                'SELECT * FROM agent_cards WHERE phone_number = %s AND agent_name = %s',
                (owner, agent_card.agent_name),
            )
            existing_card = cur.fetchone()
            if existing_card:
                raise HTTPException(status_code=400, detail='Agent card already exists')

            cur.execute(insert_sql, insert_values)
            conn.commit()
            return {'message': 'Agent card created successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
class ToolCreate(BaseModel):
    """Request body for creating a tool."""

    # Tool name; the create endpoint rejects a name that already exists.
    name: str
    description: Optional[str] = None
    # Parameter definition; stored as JSON text by the create endpoint.
    parameters: dict
|
||
|
||
class ToolUpdate(BaseModel):
    """Partial-update body for a tool; a field left as None is not changed."""

    description: Optional[str] = None
    parameters: Optional[dict] = None
|
||
|
||
class ToolResponse(ToolCreate):
    """Tool representation that adds its identifier and creation time to ``ToolCreate``."""

    tool_id: int
    # Creation time as text, when available.
    created_at: Optional[str] = None
|
||
|
||
class AgentToolAssignment(BaseModel):
    """Request body listing the tool ids to associate with an agent."""

    # Ids of the tools to assign; the assignment endpoint replaces the agent's
    # existing associations with this list.
    tool_ids: List[int]
|
||
|
||
@app.post('/tools/', tags=["工具管理"])
def create_tool(tool: ToolCreate, current_user: dict = Depends(get_current_user)):
    """Insert a new tool and return its generated id.

    Returns:
        ``{'message': 'Tool created successfully', 'tool_id': <new id>}``.

    Raises:
        HTTPException: 400 when a tool with the same name already exists.
    """
    insert_sql = '''
        INSERT INTO tools (name, description, parameters)
        VALUES (%s, %s, %s)
        RETURNING tool_id
    '''
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # Tool names must be unique.
            cur.execute('SELECT * FROM tools WHERE name = %s', (tool.name,))
            duplicate = cur.fetchone()
            if duplicate:
                raise HTTPException(status_code=400, detail='Tool with this name already exists')

            # The parameters dict is serialised to JSON text for storage.
            serialised_parameters = json.dumps(tool.parameters)
            cur.execute(insert_sql, (tool.name, tool.description, serialised_parameters))

            new_row = cur.fetchone()
            tool_id = new_row[0]
            conn.commit()
            return {'message': 'Tool created successfully', 'tool_id': tool_id}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.get('/tools/', tags=["工具管理"])
def list_tools(current_user: dict = Depends(get_current_user)):
    """Return every row of ``tools`` ordered by ``tool_id`` as a list of dicts.

    In each dict ``parameters`` is decoded from JSON when the driver returned
    it as text, and ``created_at`` is converted to a string.
    """

    def as_tool_dict(column_names, row):
        # Build one JSON-ready dict from a raw result row.
        record = dict(zip(column_names, row))
        raw_parameters = record['parameters']
        # The driver may already hand back a dict; only decode text values.
        if isinstance(raw_parameters, str):
            record['parameters'] = json.loads(raw_parameters)
        record['created_at'] = str(record['created_at'])
        return record

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM tools ORDER BY tool_id')
            rows = cur.fetchall()
            column_names = [column[0] for column in cur.description]
            return [as_tool_dict(column_names, row) for row in rows]
    finally:
        pool.putconn(conn)
|
||
|
||
@app.get('/tools/{tool_id}', tags=["工具管理"])
def get_tool(tool_id: int, current_user: dict = Depends(get_current_user)):
    """Return one tool as a dict.

    ``parameters`` is decoded from JSON when returned as text and
    ``created_at`` is converted to a string.

    Raises:
        HTTPException: 404 when no tool has this id.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM tools WHERE tool_id = %s', (tool_id,))
            row = cur.fetchone()
            if not row:
                raise HTTPException(status_code=404, detail='Tool not found')

            column_names = [column[0] for column in cur.description]
            record = {name: value for name, value in zip(column_names, row)}
            raw_parameters = record['parameters']
            if isinstance(raw_parameters, str):
                record['parameters'] = json.loads(raw_parameters)
            record['created_at'] = str(record['created_at'])
            return record
    finally:
        pool.putconn(conn)
|
||
|
||
@app.put('/tools/{tool_id}', tags=["工具管理"])
def update_tool(tool_id: int, tool_update: ToolUpdate, current_user: dict = Depends(get_current_user)):
    """Update the description and/or parameters of an existing tool.

    Returns:
        ``{'message': 'Tool updated successfully'}`` on success.

    Raises:
        HTTPException: 404 when the tool does not exist; 400 when the payload
            contains no field to update.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM tools WHERE tool_id = %s', (tool_id,))
            existing = cur.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail='Tool not found')

            # Collect (SET clause, value) pairs for the fields that were supplied.
            pending = []
            if tool_update.description is not None:
                pending.append(('description = %s', tool_update.description))
            if tool_update.parameters is not None:
                pending.append(('parameters = %s', json.dumps(tool_update.parameters)))

            if not pending:
                raise HTTPException(status_code=400, detail='No fields to update')

            set_clause = ', '.join(clause for clause, _ in pending)
            # The tool id fills the trailing WHERE placeholder.
            values = [value for _, value in pending] + [tool_id]

            cur.execute(f"UPDATE tools SET {set_clause} WHERE tool_id = %s", values)
            conn.commit()
            return {'message': 'Tool updated successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.delete('/tools/{tool_id}', tags=["工具管理"])
def delete_tool(tool_id: int, current_user: dict = Depends(get_current_user)):
    """Delete a tool by id.

    Returns:
        ``{'message': 'Tool deleted successfully'}`` on success.

    Raises:
        HTTPException: 404 when no row was deleted.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('DELETE FROM tools WHERE tool_id = %s', (tool_id,))
            deleted_rows = cur.rowcount
            # Zero affected rows means the id did not exist.
            if deleted_rows == 0:
                raise HTTPException(status_code=404, detail='Tool not found')
            conn.commit()
            return {'message': 'Tool deleted successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.post('/agent/{card_id}/tools', tags=["Agent工具关联"])
def assign_tools_to_agent(card_id: int, assignment: AgentToolAssignment, current_user: dict = Depends(get_current_user)):
    """Replace the set of tools associated with an agent.

    All existing ``agent_tools`` rows for the agent are deleted and one row is
    inserted per distinct requested tool id ("replace all" semantics). An empty
    ``tool_ids`` list clears the agent's tools.

    Returns:
        ``{'message': 'Tools assigned successfully'}`` on success.

    Raises:
        HTTPException: 404 when the agent does not exist; 400 when any
            requested tool id does not exist.
    """
    # De-duplicate while preserving request order. The existence check below
    # counts distinct ids, so inserting the raw list would add duplicate link
    # rows (or violate a uniqueness constraint on the link table, if one
    # exists) whenever the request repeats an id.
    tool_ids = list(dict.fromkeys(assignment.tool_ids))

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # Check agent exists
            cur.execute('SELECT * FROM agent_cards WHERE card_id = %s', (card_id,))
            if not cur.fetchone():
                raise HTTPException(status_code=404, detail='Agent not found')

            # Verify tools exist
            if tool_ids:
                cur.execute('SELECT count(*) FROM tools WHERE tool_id = ANY(%s)', (tool_ids,))
                found = cur.fetchone()[0]
                if found != len(tool_ids):
                    raise HTTPException(status_code=400, detail='One or more tool IDs are invalid')

            # Replace all: drop the current associations, then insert the new ones.
            cur.execute('DELETE FROM agent_tools WHERE agent_card_id = %s', (card_id,))
            if tool_ids:
                cur.executemany(
                    'INSERT INTO agent_tools (agent_card_id, tool_id) VALUES (%s, %s)',
                    [(card_id, tool_id) for tool_id in tool_ids],
                )

            conn.commit()
            return {'message': 'Tools assigned successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.get('/agent/{card_id}/tools', tags=["Agent工具关联"])
def get_agent_tools(card_id: int, current_user: dict = Depends(get_current_user)):
    """Return the tools linked to an agent through ``agent_tools``.

    Each tool is a dict; ``parameters`` is decoded from JSON when returned as
    text and ``created_at`` is converted to a string. An agent with no linked
    tools yields an empty list.
    """
    tools_sql = """
        SELECT t.*
        FROM tools t
        JOIN agent_tools at ON t.tool_id = at.tool_id
        WHERE at.agent_card_id = %s
    """

    def as_tool_dict(column_names, row):
        # Build one JSON-ready dict from a raw result row.
        record = dict(zip(column_names, row))
        if isinstance(record['parameters'], str):
            record['parameters'] = json.loads(record['parameters'])
        record['created_at'] = str(record['created_at'])
        return record

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute(tools_sql, (card_id,))
            rows = cur.fetchall()
            column_names = [column[0] for column in cur.description]
            return [as_tool_dict(column_names, row) for row in rows]
    finally:
        pool.putconn(conn)
|
||
|
||
@app.post('/dynamic_agent/', tags=["Dynamic agent 管理"])
def create_dynamic_agent(agent_data: DynamicAgentCreate, current_user: dict = Depends(get_current_user)):
    """Create a dynamic agent: one ``agent_cards`` row plus its details row.

    The card is always stored with ``type`` 1, and both inserts are committed
    together.

    Returns:
        ``{'message': 'Dynamic agent created successfully', 'card_id': <new id>}``.

    Raises:
        HTTPException: 400 when the phone number is not registered, or when the
            owner already has a card with this agent name.
    """
    card_sql = '''
        INSERT INTO agent_cards
        (phone_number, card_info, agent_avatar_url, agent_prompt, agent_name, is_publish, create_date, voice_type, temperature, type)
        VALUES (%s, %s, %s, %s, %s, %s, NOW(), %s, %s, 1)
        RETURNING card_id
    '''
    details_sql = '''
        INSERT INTO dynamic_agent_details (card_id, videos, kb_id, kb_config)
        VALUES (%s, %s, %s, %s)
    '''
    owner = agent_data.phone_number

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # 1. The owner must be a registered user.
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (owner,))
            owner_row = cur.fetchone()
            if not owner_row:
                raise HTTPException(status_code=400, detail='Phone number not registered')

            # 2. Agent names are unique per owner.
            cur.execute(
                'SELECT * FROM agent_cards WHERE phone_number = %s AND agent_name = %s',
                (owner, agent_data.agent_name),
            )
            duplicate = cur.fetchone()
            if duplicate:
                raise HTTPException(status_code=400, detail='Agent card already exists')

            # 3. Insert the base card; RETURNING yields the generated id.
            card_values = (
                owner,
                agent_data.card_info,
                agent_data.agent_avatar_url,
                agent_data.agent_prompt,
                agent_data.agent_name,
                agent_data.is_publish,
                agent_data.voice_type,
                agent_data.temperature,
            )
            cur.execute(card_sql, card_values)
            card_id = cur.fetchone()[0]

            # 4. Insert the dynamic details; videos and kb_config are stored as JSON text.
            videos_json = json.dumps([video.dict() for video in agent_data.videos])
            kb_config_json = json.dumps(agent_data.kb_config)
            cur.execute(details_sql, (card_id, videos_json, agent_data.kb_id, kb_config_json))

            conn.commit()
            return {'message': 'Dynamic agent created successfully', 'card_id': card_id}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.get('/dynamic_agent/{card_id}', tags=["Dynamic agent 管理"])
def get_dynamic_agent(card_id: int, current_user: dict = Depends(get_current_user)):
    """Return one agent card with its dynamic details and associated tools.

    The result is the ``agent_cards`` row extended with ``videos``, ``kb_id``
    and ``kb_config`` (None when no details row exists, because of the LEFT
    JOIN) and a ``tools`` list.

    Raises:
        HTTPException: 404 when no agent card has this id.
    """
    agent_sql = '''
        SELECT a.*, d.videos, d.kb_id, d.kb_config
        FROM agent_cards a
        LEFT JOIN dynamic_agent_details d ON a.card_id = d.card_id
        WHERE a.card_id = %s
    '''
    tools_sql = '''
        SELECT t.*
        FROM tools t
        JOIN agent_tools at ON t.tool_id = at.tool_id
        WHERE at.agent_card_id = %s
    '''

    def as_tool_dict(column_names, row):
        # Build one JSON-ready tool dict from a raw result row.
        record = dict(zip(column_names, row))
        if isinstance(record['parameters'], str):
            record['parameters'] = json.loads(record['parameters'])
        record['created_at'] = str(record['created_at'])
        return record

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute(agent_sql, (card_id,))
            agent_row = cur.fetchone()
            if not agent_row:
                raise HTTPException(status_code=404, detail='Agent not found')

            agent_columns = [column[0] for column in cur.description]
            agent = dict(zip(agent_columns, agent_row))

            # Fetch the tools linked to this agent.
            cur.execute(tools_sql, (card_id,))
            tool_rows = cur.fetchall()
            tool_columns = [column[0] for column in cur.description]
            agent['tools'] = [as_tool_dict(tool_columns, row) for row in tool_rows]

            return agent
    finally:
        pool.putconn(conn)
|
||
|
||
@app.delete('/agent/{phone_number}/{agent_name}', tags=["Agent管理"])
def delete_agent_card(phone_number: str, agent_name: str, current_user: dict = Depends(get_current_user)):
    """Delete the agent card identified by owner phone number and agent name.

    Returns:
        ``{'message': 'Agent card deleted successfully'}`` on success.

    Raises:
        HTTPException: 404 when no matching card was deleted.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                'DELETE FROM agent_cards WHERE phone_number = %s AND agent_name = %s',
                (phone_number, agent_name),
            )
            deleted_rows = cur.rowcount
            # Zero affected rows means there was no such card.
            if deleted_rows == 0:
                raise HTTPException(status_code=404, detail='Agent card not found')
            conn.commit()
            return {'message': 'Agent card deleted successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.put('/agent/{phone_number}/{agent_name}', tags=["Agent管理"])
def update_agent_card(phone_number: str, agent_name: str, agent_card: AgentCardUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update the agent card identified by owner phone and agent name.

    Only payload fields that are not None are written. ``create_date`` is not
    among the columns this endpoint updates.

    Returns:
        ``{'message': 'Agent card updated successfully'}`` on success.

    Raises:
        HTTPException: 404 when the card does not exist; 400 when the payload
            contains no field to update.
    """
    # Payload attributes share their names with the columns they update.
    updatable_columns = (
        'card_info',
        'agent_avatar_url',
        'agent_prompt',
        'agent_name',
        'is_publish',
        'voice_type',
        'temperature',
        'type',
    )

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute(
                'SELECT * FROM agent_cards WHERE phone_number = %s AND agent_name = %s',
                (phone_number, agent_name),
            )
            existing = cur.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail='Agent card not found')

            # Build the dynamic SET clause and its parameters.
            assignments = []
            values = []
            for column in updatable_columns:
                new_value = getattr(agent_card, column)
                if new_value is not None:
                    assignments.append(f'{column} = %s')
                    values.append(new_value)

            if not assignments:
                raise HTTPException(status_code=400, detail='No fields to update')

            # The path values fill the two WHERE placeholders.
            values.extend([phone_number, agent_name])

            set_clause = ', '.join(assignments)
            update_query = f'''
                UPDATE agent_cards SET
                {set_clause}
                WHERE phone_number = %s AND agent_name = %s
            '''

            cur.execute(update_query, values)
            conn.commit()
            return {'message': 'Agent card updated successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
|
||
|
||
@app.put('/dynamic_agent/{card_id}', tags=["Dynamic agent 管理"])
def update_dynamic_agent(card_id: int, agent_update: DynamicAgentUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update a dynamic agent across its two tables.

    Base fields go to ``agent_cards``; ``videos``, ``kb_id`` and ``kb_config``
    go to ``dynamic_agent_details``. Only fields that are not None are written,
    and both updates are committed together.

    Returns:
        ``{'message': 'Dynamic agent updated successfully'}`` on success.

    Raises:
        HTTPException: 404 when the agent does not exist; 400 when the payload
            contains no field to update.
    """
    # Payload attributes share their names with the agent_cards columns.
    base_columns = (
        'card_info',
        'agent_avatar_url',
        'agent_prompt',
        'agent_name',
        'is_publish',
        'voice_type',
        'temperature',
    )

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # 1. Verify existence
            cur.execute('SELECT * FROM agent_cards WHERE card_id = %s', (card_id,))
            existing = cur.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail='Agent not found')

            # 2. Update agent_cards table (base fields)
            update_fields = []
            params = []
            for column in base_columns:
                new_value = getattr(agent_update, column)
                if new_value is not None:
                    update_fields.append(f'{column} = %s')
                    params.append(new_value)

            if update_fields:
                params.append(card_id)
                cur.execute(f"UPDATE agent_cards SET {', '.join(update_fields)} WHERE card_id = %s", params)

            # 3. Update dynamic_agent_details table.
            # This assumes the details row already exists; when it does not,
            # the UPDATE below matches no rows and nothing is written.
            dynamic_fields = []
            dynamic_params = []

            if agent_update.videos is not None:
                dynamic_fields.append('videos = %s')
                dynamic_params.append(json.dumps([video.dict() for video in agent_update.videos]))
            if agent_update.kb_id is not None:
                dynamic_fields.append('kb_id = %s')
                dynamic_params.append(agent_update.kb_id)
            if agent_update.kb_config is not None:
                dynamic_fields.append('kb_config = %s')
                dynamic_params.append(json.dumps(agent_update.kb_config))

            if dynamic_fields:
                dynamic_params.append(card_id)
                cur.execute(f"UPDATE dynamic_agent_details SET {', '.join(dynamic_fields)} WHERE card_id = %s", dynamic_params)

            # Nothing was supplied for either table.
            if not (update_fields or dynamic_fields):
                raise HTTPException(status_code=400, detail='No fields to update')

            conn.commit()
            return {'message': 'Dynamic agent updated successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.delete('/dynamic_agent/{card_id}', tags=["Dynamic agent 管理"])
def delete_dynamic_agent(card_id: int, current_user: dict = Depends(get_current_user)):
    """Delete an agent card by id.

    Only the ``agent_cards`` row is deleted here; removal of the related
    ``dynamic_agent_details`` and ``agent_tools`` rows is left to cascading
    deletes in the database (presumably ON DELETE CASCADE -- verify against
    the schema).

    Returns:
        ``{'message': 'Dynamic agent deleted successfully'}`` on success.

    Raises:
        HTTPException: 404 when no row was deleted.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('DELETE FROM agent_cards WHERE card_id = %s', (card_id,))
            deleted_rows = cur.rowcount
            if deleted_rows == 0:
                raise HTTPException(status_code=404, detail='Agent not found')
            conn.commit()
            return {'message': 'Dynamic agent deleted successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.get('/dynamic_agents/{phone_number}', tags=["Dynamic agent 管理"])
def list_dynamic_agents(phone_number: str, current_user: dict = Depends(get_current_user)):
    """List the dynamic agents (``type = 1``) owned by a phone number.

    Each entry is the ``agent_cards`` row extended with ``videos``, ``kb_id``
    and ``kb_config`` from ``dynamic_agent_details`` (None when no details row
    exists) and a ``tools`` list. Agents are ordered by ``card_id`` descending.
    An owner with no dynamic agents yields an empty list.

    The tools of all listed agents are loaded with a single query instead of
    one query per agent.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('''
                SELECT a.*, d.videos, d.kb_id, d.kb_config
                FROM agent_cards a
                LEFT JOIN dynamic_agent_details d ON a.card_id = d.card_id
                WHERE a.phone_number = %s AND a.type = 1
                ORDER BY a.card_id DESC
            ''', (phone_number,))
            columns = [desc[0] for desc in cur.description]
            agents = [dict(zip(columns, row)) for row in cur.fetchall()]
            if not agents:
                return []

            # One bucket per agent so agents without tools still get an empty list.
            tools_by_card = {agent['card_id']: [] for agent in agents}

            # Fetch the tools of every listed agent in one round trip. The
            # leading column says which agent each tool row belongs to.
            cur.execute('''
                SELECT at.agent_card_id, t.*
                FROM tools t
                JOIN agent_tools at ON t.tool_id = at.tool_id
                WHERE at.agent_card_id = ANY(%s)
            ''', (list(tools_by_card),))
            tool_columns = [desc[0] for desc in cur.description][1:]
            for owner_card_id, *tool_values in cur.fetchall():
                tool = dict(zip(tool_columns, tool_values))
                # The driver may already return a dict; only decode text values.
                if isinstance(tool['parameters'], str):
                    tool['parameters'] = json.loads(tool['parameters'])
                tool['created_at'] = str(tool['created_at'])
                tools_by_card[owner_card_id].append(tool)

            for agent in agents:
                agent['tools'] = tools_by_card[agent['card_id']]
            return agents
    finally:
        pool.putconn(conn)
|
||
|
||
@app.get('/agent_id/{agent_id}', tags=["Agent管理"])
def get_agent(agent_id: str, current_user: dict = Depends(get_current_user)):
    """Return one ``agent_cards`` row, looked up by ``card_id``, as a dict.

    Args:
        agent_id: Card id taken from the URL path (text form of an integer).
        current_user: Result of the API-key dependency (unused in the body).

    Raises:
        HTTPException: 404 when the id is not numeric or no card has this id.
    """
    # Other endpoints in this file treat card_id as an integer; a non-numeric
    # value can never match and would otherwise make the database reject the
    # query, surfacing as a 500 instead of a 404.
    try:
        card_id = int(agent_id)
    except ValueError:
        raise HTTPException(status_code=404, detail='Agent card not found') from None

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM agent_cards WHERE card_id = %s', (card_id,))
            agent = cur.fetchone()
            if not agent:
                raise HTTPException(status_code=404, detail='Agent card not found')

            columns = [desc[0] for desc in cur.description]
            return dict(zip(columns, agent))
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
@app.put('/agent_id/{agent_id}', tags=["Agent管理"])
def update_agent_card_by_id(agent_id: str, agent_card: AgentCardUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update the agent card with the given ``card_id``.

    Only payload fields that are not None are written. ``create_date`` is not
    among the columns this endpoint updates.

    Args:
        agent_id: Card id taken from the URL path (text form of an integer).
        agent_card: Fields to change.
        current_user: Result of the API-key dependency (unused in the body).

    Returns:
        ``{'message': 'Agent card updated successfully'}`` on success.

    Raises:
        HTTPException: 404 when the id is not numeric or the card does not
            exist; 400 when the payload contains no field to update.
    """
    # Other endpoints in this file treat card_id as an integer; a non-numeric
    # value can never match and would otherwise make the database reject the
    # query, surfacing as a 500 instead of a 404.
    try:
        card_id = int(agent_id)
    except ValueError:
        raise HTTPException(status_code=404, detail='Agent card not found') from None

    # Payload attributes share their names with the columns they update.
    updatable_columns = (
        'card_info',
        'agent_avatar_url',
        'agent_prompt',
        'agent_name',
        'is_publish',
        'voice_type',
        'temperature',
        'type',
    )

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM agent_cards WHERE card_id = %s', (card_id,))
            if not cur.fetchone():
                raise HTTPException(status_code=404, detail='Agent card not found')

            # Build the dynamic SET clause and its parameters.
            update_fields = []
            params = []
            for column in updatable_columns:
                value = getattr(agent_card, column)
                if value is not None:
                    update_fields.append(f'{column} = %s')
                    params.append(value)

            if not update_fields:
                raise HTTPException(status_code=400, detail='No fields to update')

            # The card id fills the trailing WHERE placeholder.
            params.append(card_id)

            update_query = f'''
                UPDATE agent_cards SET
                {', '.join(update_fields)}
                WHERE card_id = %s
            '''

            cur.execute(update_query, params)
            conn.commit()
            return {'message': 'Agent card updated successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
|
||
|
||
@app.delete('/agent_id/{agent_id}', tags=["Agent管理"])
def delete_agent_card_by_id(agent_id: str, current_user: dict = Depends(get_current_user)):
    """Delete the agent card with the given ``card_id``.

    Args:
        agent_id: Card id taken from the URL path (text form of an integer).
        current_user: Result of the API-key dependency (unused in the body).

    Returns:
        ``{'message': 'Agent card deleted successfully'}`` on success.

    Raises:
        HTTPException: 404 when the id is not numeric or no card has this id.
    """
    # Other endpoints in this file treat card_id as an integer; a non-numeric
    # value can never match and would otherwise make the database reject the
    # query, surfacing as a 500 instead of a 404.
    try:
        card_id = int(agent_id)
    except ValueError:
        raise HTTPException(status_code=404, detail='Agent card not found') from None

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            # The affected-row count already tells whether the card existed,
            # so no separate existence query is needed.
            cur.execute('DELETE FROM agent_cards WHERE card_id = %s', (card_id,))
            if cur.rowcount == 0:
                raise HTTPException(status_code=404, detail='Agent card not found')
            conn.commit()
            return {'message': 'Agent card deleted successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
# 图片上传目录配置
|
||
#UPLOAD_DIR = "/mnt/server/userImage/call_avator"
|
||
# Directory where uploaded images are written.
UPLOAD_DIR = "./"
# Create the directory if needed; exist_ok avoids the race between checking
# for the directory and creating it.
os.makedirs(UPLOAD_DIR, exist_ok=True)

# MIME types accepted by the image upload endpoint.
ALLOWED_IMAGE_TYPES = ["image/jpeg", "image/png", "image/gif"]
|
||
|
||
# File extensions accepted for each allowed image MIME type. The first entry
# is the fallback used when the client-supplied name has no matching extension.
_IMAGE_EXTENSIONS = {
    "image/jpeg": (".jpg", ".jpeg"),
    "image/png": (".png",),
    "image/gif": (".gif",),
}


@app.post('/upload_image/', tags=["文件上传"])
async def upload_image(file: UploadFile = File(...), current_user: dict = Depends(get_current_user)):
    """Store an uploaded image under ``UPLOAD_DIR`` with a random unique name.

    The stored extension is the lower-cased extension of the client filename
    when it matches the declared content type; otherwise a canonical extension
    for that content type is used. This keeps a mislabelled or missing
    filename from crashing the request or dictating the stored extension.

    Args:
        file: The uploaded file.
        current_user: Result of the API-key dependency (unused in the body).

    Returns:
        A dict with ``message``, ``file_id`` (the generated UUID string) and
        ``file_path`` (where the file was written).

    Raises:
        HTTPException: 400 when the content type is not an allowed image type;
            500 when reading or writing the file fails.
    """
    # Validate the declared content type.
    if file.content_type not in ALLOWED_IMAGE_TYPES:
        raise HTTPException(status_code=400, detail="Only image files are allowed")

    # ``filename`` may be missing; treat that as "no extension".
    file_ext = os.path.splitext(file.filename or "")[1].lower()
    accepted_exts = _IMAGE_EXTENSIONS.get(file.content_type)
    if accepted_exts and file_ext not in accepted_exts:
        file_ext = accepted_exts[0]

    # Generate a unique file name so uploads never overwrite each other.
    unique_id = str(uuid.uuid4())
    file_path = os.path.join(UPLOAD_DIR, f"{unique_id}{file_ext}")

    try:
        contents = await file.read()
        with open(file_path, "wb") as buffer:
            buffer.write(contents)
    except Exception as e:
        # Boundary handler: report any read/write failure as a 500.
        raise HTTPException(status_code=500, detail=f"Error uploading file: {str(e)}") from e

    return {"message": "Image uploaded successfully", "file_id": unique_id, "file_path": file_path}
|
||
|
||
@app.put('/user/{phone}', tags=["用户管理"])
def update_user(phone: str, user_update: UserUpdate, current_user: dict = Depends(get_current_user)):
    """Partially update the user identified by phone number.

    Only payload fields that are not None are written.

    Returns:
        ``{'message': 'User updated successfully'}`` on success.

    Raises:
        HTTPException: 404 when the user does not exist; 400 when the payload
            contains no field to update.
    """
    # (payload attribute, database column) pairs; only ``name`` maps to a
    # differently named column.
    attribute_to_column = (
        ('name', 'user_name'),
        ('user_details', 'user_details'),
        ('avatar', 'avatar'),
        ('email', 'email'),
        ('points', 'points'),
    )

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute('SELECT * FROM users WHERE phone_number = %s', (phone,))
            existing = cur.fetchone()
            if not existing:
                raise HTTPException(status_code=404, detail='User not found')

            # Build the dynamic SET clause and its parameters.
            assignments = []
            values = []
            for attribute, column in attribute_to_column:
                new_value = getattr(user_update, attribute)
                if new_value is not None:
                    assignments.append(f'{column} = %s')
                    values.append(new_value)

            if not assignments:
                raise HTTPException(status_code=400, detail='No fields to update')

            # The phone number fills the trailing WHERE placeholder.
            values.append(phone)

            set_clause = ', '.join(assignments)
            update_query = f'''
                UPDATE users SET
                {set_clause}
                WHERE phone_number = %s
            '''

            cur.execute(update_query, values)
            conn.commit()
            return {'message': 'User updated successfully'}
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
@app.post("/api/send-sms", tags=["短信服务"])
async def send_sms(sms_request: SMSRequest):
    """Send an SMS and record the attempt in ``sms_records``.

    ``phone_number``, ``template_code`` and ``sign_name`` are taken from the
    request; every other non-None field (including extra fields) is forwarded
    to the SMS client as a template parameter. Saving the record is best
    effort: a failure to obtain a connection or to insert is logged and does
    not change the response.

    NOTE(review): unlike the other endpoints in this file, this one does not
    require the API-key dependency -- confirm that is intended.

    Returns:
        ``{"status": "success", ...}`` with the client result under ``data``,
        or ``{"status": "failed", "message": ...}`` when sending failed.
    """
    sms = SMS()

    # Collect all request fields, declared and extra.
    request_data = sms_request.dict()

    # Pull out the fixed parameters.
    phone_number = request_data.pop('phone_number')
    template_code = request_data.pop('template_code')
    sign_name = request_data.pop('sign_name')

    # Whatever remains and is not None becomes a dynamic template parameter.
    dynamic_params = {k: v for k, v in request_data.items() if v is not None}

    print(f"Sending SMS with template_code: {template_code}, sign_name: {sign_name}, params: {dynamic_params}")

    # NOTE(review): ``self`` is passed explicitly here; this only works if
    # ``SMS.main`` is not a regular bound method -- verify against sms.py.
    result = sms.main(
        self=sms,
        phone_number=phone_number,
        template_code=template_code,
        sign_name=sign_name,
        **dynamic_params
    )

    # Save the attempt to the database. The connection is acquired inside the
    # guarded section so that a pool failure cannot turn an SMS that was
    # already sent into an error response.
    conn = None
    try:
        conn = pool.getconn()
        with conn.cursor() as cur:
            status = 'success' if result and result.get('success') else 'failed'
            biz_id = result.get('biz_id') if result else None
            error_message = result.get('error_message') if result else 'Unknown error'

            # Template parameters are stored as a JSON string (empty string when
            # there are none).
            if dynamic_params:
                saved_param = json.dumps(dynamic_params, ensure_ascii=False)
            else:
                saved_param = ""

            cur.execute("""
                INSERT INTO sms_records (phone_number, template_code, template_param, sign_name, status, biz_id, error_message)
                VALUES (%s, %s, %s, %s, %s, %s, %s)
            """, (phone_number, template_code, saved_param, sign_name, status, biz_id, error_message))
            conn.commit()
    except Exception as e:
        # Deliberately best effort: record keeping must not fail the request.
        print(f"Error saving SMS record: {e}")
    finally:
        if conn is not None:
            pool.putconn(conn)

    if not result or not result.get('success'):
        return {"status": "failed", "message": result.get('error_message') if result else "发送失败"}

    return {"status": "success", "message": "短信发送请求已处理", "data": result}
|
||
|
||
class ConversationCreate(BaseModel):
    """Request body for creating a conversation."""

    # Phone number of a registered user; omitted for anonymous visitors.
    user_phone: Optional[str] = None
    # Identifier of an anonymous visitor; omitted for registered users.
    visitor_key: Optional[str] = None
    # Identifier of the agent card the conversation is attached to.
    agent_card_id: int
|
||
|
||
class MessageCreate(BaseModel):
    """Request body for appending a message to a conversation."""

    # Conversation the message belongs to.
    conversation_id: int
    # Author of the message.
    sender: str
    # Message text.
    content: str
    # Value stored in the message's "order" column.
    order: int
|
||
|
||
@app.post('/conversations/', tags=["对话管理"])
def create_conversation(conversation: ConversationCreate, current_user: dict = Depends(get_current_user)):
    """Create a conversation for a registered user or an anonymous visitor.

    Args:
        conversation: Conversation to create; exactly one of ``user_phone``
            and ``visitor_key`` must be set.
        current_user: Result of the ``get_current_user`` dependency.

    Returns:
        ``{'conversation_id': <id of the inserted row>}``.

    Raises:
        HTTPException: 400 when both or neither of ``user_phone`` and
            ``visitor_key`` are given, or when ``visitor_key`` is not a valid
            UUID; 404 when the agent card does not exist.
    """
    insert_sql = '''
        INSERT INTO conversations (user_phone, visitor_key, agent_card_id)
        VALUES (%s, %s, %s)
        RETURNING conversation_id
    '''
    phone = conversation.user_phone
    visitor = conversation.visitor_key
    card_id = conversation.agent_card_id

    conn = pool.getconn()
    try:
        with conn.cursor() as cursor:
            # Exactly one of user / visitor must identify the caller.
            if phone and visitor:
                raise HTTPException(status_code=400, detail='Cannot specify both user_phone and visitor_key')
            if not (phone or visitor):
                raise HTTPException(status_code=400, detail='Must specify either user_phone or visitor_key')

            # The referenced agent card has to exist.
            cursor.execute('SELECT * FROM agent_cards WHERE card_id = %s', (card_id,))
            card_row = cursor.fetchone()
            if not card_row:
                raise HTTPException(status_code=404, detail='Agent card not found')

            # A visitor key must parse as a UUID.
            if visitor:
                try:
                    uuid.UUID(visitor)
                except ValueError:
                    raise HTTPException(status_code=400, detail='visitor_key must be a valid UUID format')

            # Insert the conversation and hand back its generated id.
            cursor.execute(insert_sql, (phone, visitor, card_id))
            new_id = cursor.fetchone()[0]
            conn.commit()
            return {'conversation_id': new_id}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.post('/messages/', tags=["对话管理"])
def create_message(message: MessageCreate, current_user: dict = Depends(get_current_user)):
    """Store a message in an existing conversation.

    Args:
        message: Message to store.
        current_user: Result of the ``get_current_user`` dependency.

    Returns:
        ``{'message_id': <id of the inserted row>}``.

    Raises:
        HTTPException: 404 when the conversation does not exist; 400 when
            ``sender`` is neither ``'user'`` nor ``'agent'``.
    """
    insert_sql = '''
        INSERT INTO messages (conversation_id, sender, content, "order")
        VALUES (%s, %s, %s, %s)
        RETURNING message_id
    '''
    conn = pool.getconn()
    try:
        with conn.cursor() as cursor:
            # The target conversation has to exist.
            cursor.execute('SELECT * FROM conversations WHERE conversation_id = %s', (message.conversation_id,))
            conversation_row = cursor.fetchone()
            if not conversation_row:
                raise HTTPException(status_code=404, detail='Conversation not found')

            # Only the two known sender roles are accepted.
            if message.sender not in ('user', 'agent'):
                raise HTTPException(status_code=400, detail="Sender must be either 'user' or 'agent'")

            # Insert the message and hand back its generated id.
            row_values = (message.conversation_id, message.sender, message.content, message.order)
            cursor.execute(insert_sql, row_values)
            new_id = cursor.fetchone()[0]
            conn.commit()
            return {'message_id': new_id}
    finally:
        pool.putconn(conn)
|
||
|
||
@app.get("/api/sms-records", tags=["短信服务"])
async def get_all_sms_records():
    """Return every row of ``sms_records``, ordered by ``created_at`` descending.

    Returns:
        ``{"status": "success", "data": [...]}`` where each item maps column
        name to value and ``created_at`` (when present) is converted to ``str``.
    """
    conn = pool.getconn()
    try:
        with conn.cursor() as cursor:
            cursor.execute('SELECT * FROM sms_records ORDER BY created_at DESC')
            rows = cursor.fetchall()
            column_names = [column[0] for column in cursor.description]

            def to_record(row):
                # Pair each value with its column name; stringify the timestamp.
                record = dict(zip(column_names, row))
                if 'created_at' in record:
                    record['created_at'] = str(record['created_at'])
                return record

            return {"status": "success", "data": [to_record(row) for row in rows]}
    finally:
        pool.putconn(conn)
|
||
|
||
|
||
# 阿里云OSS相关API
|
||
@app.post("/test/", tags=["阿里云OSS"])
async def test_oss():
    """Placeholder OSS test endpoint.

    Returns the fixed literal ``"string"``; no OSS call is made here.
    """
    placeholder = "string"
    return placeholder
|
||
|
||
|
||
@app.post("/upload", tags=["阿里云OSS"])
async def upload_file(
    file: UploadFile = File(...),
    folder: str = "uploads"
):
    """Upload a single file to OSS.

    Args:
        file: The uploaded file.
        folder: Destination folder (optional).

    Returns:
        The result of ``oss_service.upload_file``. If any step raises, a
        failure dict with ``success=False``, the error text in ``message`` and
        ``error_code="InternalError"`` is returned instead.
    """
    try:
        payload = await file.read()
        key = oss_service.generate_object_key(file.filename, folder)
        return oss_service.upload_file(
            file_content=payload,
            object_key=key,
            content_type=file.content_type
        )
    except Exception as e:
        # Report the failure in the same shape as a regular upload result.
        failure = {
            "success": False,
            "message": f"上传失败: {str(e)}",
            "object_key": None,
            "file_url": None,
            "etag": None,
            "request_id": None,
            "error_code": "InternalError"
        }
        return failure
|
||
|
||
|
||
@app.post("/upload/multiple", tags=["阿里云OSS"])
async def upload_multiple_files(
    files: List[UploadFile] = File(...),
    folder: str = "uploads"
):
    """Upload several files to OSS, one after another.

    Args:
        files: The uploaded files.
        folder: Destination folder (optional).

    Returns:
        One entry per input file, in input order: the result of
        ``oss_service.upload_file``, or a failure dict with ``success=False``
        and ``error_code="InternalError"`` if that file's upload raised.
    """

    async def upload_one(upload):
        # A failure of one file must not abort the remaining uploads.
        try:
            payload = await upload.read()
            key = oss_service.generate_object_key(upload.filename, folder)
            return oss_service.upload_file(
                file_content=payload,
                object_key=key,
                content_type=upload.content_type
            )
        except Exception as e:
            return {
                "success": False,
                "message": f"上传失败: {str(e)}",
                "object_key": None,
                "file_url": None,
                "etag": None,
                "request_id": None,
                "error_code": "InternalError"
            }

    # Awaited sequentially so results keep the order of ``files``.
    return [await upload_one(upload) for upload in files]
|
||
|
||
|
||
@app.get("/files", tags=["阿里云OSS"])
async def list_files(
    prefix: str = "recordings/",
    max_keys: int = 100
):
    """List files stored in OSS.

    Args:
        prefix: Object-key prefix to filter by (optional).
        max_keys: Maximum number of entries to return; must be in 1..1000.

    Returns:
        The result of ``oss_service.list_files`` with a ``message`` field
        guaranteed to be present. When ``max_keys`` is out of range, a failure
        dict with ``success=False``, an explanatory ``message``, an empty
        ``files`` list and ``count=0`` is returned without calling OSS.
    """
    # Reject out-of-range page sizes up front. The upper bound was already
    # enforced; zero or negative values are now refused the same way instead
    # of being forwarded to the storage service.
    if max_keys > 1000:
        range_error = "max_keys不能超过1000"
    elif max_keys < 1:
        range_error = "max_keys必须大于0"
    else:
        range_error = None

    if range_error is not None:
        return {
            "success": False,
            "message": range_error,
            "files": [],
            "count": 0
        }

    result = oss_service.list_files(prefix=prefix, max_keys=max_keys)

    # Make sure the response always carries a message field.
    if "message" not in result:
        result["message"] = "获取文件列表成功"

    return result
|
||
|
||
|
||
@app.get("/files/{object_key:path}/info", tags=["阿里云OSS"])
async def get_file_info(object_key: str):
    """Fetch information about one OSS object.

    Args:
        object_key: OSS object key.

    Returns:
        The result of ``oss_service.get_file_info``; when it carries no
        ``message`` field, one is added based on its ``success`` flag.
    """
    info = oss_service.get_file_info(object_key)

    if "message" in info:
        return info

    # No message supplied by the service: derive one from the success flag.
    info["message"] = "获取文件信息成功" if info["success"] else "获取文件信息失败"
    return info
|
||
|
||
|
||
@app.delete("/files/{object_key:path}", tags=["阿里云OSS"])
async def delete_file(object_key: str):
    """Delete one OSS object.

    Args:
        object_key: OSS object key.

    Returns:
        The result of ``oss_service.delete_file``, with ``error_code`` set to
        ``None`` when the service did not supply one.
    """
    outcome = oss_service.delete_file(object_key)

    if "error_code" in outcome:
        return outcome

    # Guarantee the field exists so every response has the same keys.
    outcome["error_code"] = None
    return outcome