180 lines
6.1 KiB
Python
180 lines
6.1 KiB
Python
from fastapi import APIRouter, HTTPException, Depends
|
||
from typing import Dict, Any
|
||
import httpx
|
||
import json
|
||
import logging
|
||
from config import settings
|
||
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest
|
||
from api.prompts import get_prompt
|
||
|
||
# FastAPI router that the AI image-generation endpoints below register on.
router = APIRouter()
logger = logging.getLogger(__name__)

# Root of the DashScope REST API; both endpoint URLs below are derived from it.
_DASHSCOPE_BASE_URL = "https://dashscope.aliyuncs.com/api/v1"
# Endpoint used to submit multimodal image-generation requests.
DASHSCOPE_API_URL = f"{_DASHSCOPE_BASE_URL}/services/aigc/multimodal-generation/generation"
# Endpoint used to query a submitted task: GET {DASHSCOPE_TASK_URL}/{task_id}.
DASHSCOPE_TASK_URL = f"{_DASHSCOPE_BASE_URL}/tasks"
|
||
|
||
async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n: int, model: str):
    """Submit an asynchronous image-generation task to DashScope.

    Args:
        prompt: Text sent as the single user message of the request.
        negative_prompt: Optional negative prompt; a falsy value (``""`` or
            ``None``) leaves it out of ``parameters``.
        size: Output size, forwarded verbatim as ``parameters.size``.
        n: Number of images requested (``parameters.n``).
        model: DashScope model name.

    Returns:
        AITaskResponse carrying the DashScope ``task_id`` and ``request_id``.

    Raises:
        HTTPException: 500 when the API key is not configured, the HTTP
            request fails at the transport level, the success body is not
            valid JSON, or no task id can be found in it; DashScope's own
            status code (with its parsed error body) on a non-200 answer.
    """
    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {settings.dashscope_api_key}",
        # Ask DashScope to run the job asynchronously and return a task id.
        "X-DashScope-Async": "enable",
    }

    # Build the request body.
    payload = {
        "model": model,
        "input": {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "text": prompt,
                        }
                    ],
                }
            ]
        },
        "parameters": {
            "prompt_extend": True,
            "watermark": False,
            "n": n,
            "size": size,
        },
    }

    if negative_prompt:
        payload["parameters"]["negative_prompt"] = negative_prompt

    # Only the network call can raise httpx.RequestError; keep the try tight.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                DASHSCOPE_API_URL,
                headers=headers,
                json=payload,
                timeout=30.0,
            )
    except httpx.RequestError as e:
        logger.error("Request error: %s", e)
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error("DashScope API error: %s", response.text)
        try:
            error_detail = response.json()
        except ValueError:
            # Body is not JSON (e.g. a gateway error page): forward the raw text.
            error_detail = {"message": response.text}
        raise HTTPException(status_code=response.status_code, detail=error_detail)

    try:
        result = response.json()
    except ValueError as e:
        logger.error("DashScope returned a non-JSON body: %s", response.text)
        raise HTTPException(status_code=500, detail="Invalid JSON in DashScope response") from e

    # The task id is expected under ``output``; a top-level ``task_id`` is
    # accepted as a fallback. isinstance guards stop a null ``output`` or a
    # non-object body from raising AttributeError.
    task_id = None
    if isinstance(result, dict):
        output = result.get("output")
        if isinstance(output, dict):
            task_id = output.get("task_id")
        if not task_id:
            task_id = result.get("task_id")

    if not task_id:
        logger.warning("Unexpected response structure: %s", result)
        raise HTTPException(status_code=500, detail="Failed to retrieve task_id from DashScope response")

    return AITaskResponse(
        task_id=task_id,
        request_id=result.get("request_id"),
    )
|
||
|
||
@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务")
async def generate_image(request: AIGenerationRequest):
    """
    Submit an AI image-generation task to the Alibaba Cloud DashScope service.
    """
    # Forward the request fields unchanged to the shared submission helper.
    submission = dict(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        size=request.size,
        n=request.n,
        model=request.model,
    )
    return await _submit_dashscope_task(**submission)
|
||
|
||
@router.post("/generate_from_template", response_model=AITaskResponse, summary="使用模板提交AI图片生成任务")
async def generate_image_from_template(request: AITemplateGenerationRequest):
    """
    Submit an AI image-generation task whose prompt comes from a preset template.

    The prompt is rendered with ``get_prompt(request.template_id, **params)``
    and then submitted through ``_submit_dashscope_task``.

    Raises:
        HTTPException: 400 when ``get_prompt`` raises ``ValueError`` for the
            supplied parameters; 404 when it returns ``None`` (unknown
            template id); plus anything ``_submit_dashscope_task`` raises.
    """
    # NOTE(review): ``params`` is presumably an optional dict; ``or {}`` stops a
    # None value from failing ``**`` unpacking with an opaque TypeError.
    template_params = request.params or {}

    # Only the template rendering is expected to raise ValueError.
    try:
        prompt = get_prompt(request.template_id, **template_params)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    if prompt is None:
        raise HTTPException(status_code=404, detail=f"未找到模板ID: {request.template_id}")

    return await _submit_dashscope_task(
        prompt=prompt,
        negative_prompt=request.negative_prompt,
        size=request.size,
        n=request.n,
        model=request.model,
    )
|
||
|
||
|
||
@router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果")
async def get_task_result(task_id: str):
    """
    Query the result of an AI generation task from DashScope.

    Issues ``GET {DASHSCOPE_TASK_URL}/{task_id}`` and maps the answer onto
    ``AITaskResult``. A DashScope task-query response looks like::

        {
            "request_id": "...",
            "output": {
                "task_id": "...",
                "task_status": "SUCCEEDED",
                "results": [...]
            },
            "usage": ...
        }

    Raises:
        HTTPException: 500 when the API key is not configured, the HTTP
            request fails at the transport level, or the body is not a JSON
            object; DashScope's status code on a non-200 answer.
    """
    from urllib.parse import quote

    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")

    headers = {
        "Authorization": f"Bearer {settings.dashscope_api_key}",
    }

    # task_id comes from the client: percent-encode it so characters such as
    # '?', '#' or '/' cannot change the authenticated upstream request.
    # Alphanumerics and '-', '_', '.', '~' pass through unchanged.
    task_url = f"{DASHSCOPE_TASK_URL}/{quote(task_id, safe='')}"

    # Only the network call can raise httpx.RequestError; keep the try tight.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                task_url,
                headers=headers,
                timeout=30.0,
            )
    except httpx.RequestError as e:
        logger.error("Request error: %s", e)
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error("DashScope Task API error: %s", response.text)
        raise HTTPException(status_code=response.status_code, detail="Failed to fetch task status")

    try:
        result = response.json()
    except ValueError as e:
        logger.error("DashScope Task API returned a non-JSON body: %s", response.text)
        raise HTTPException(status_code=500, detail="Invalid JSON in DashScope task response") from e

    if not isinstance(result, dict):
        logger.error("Unexpected DashScope task response: %s", result)
        raise HTTPException(status_code=500, detail="Unexpected DashScope task response")

    output = result.get("output")
    # Guard against a null/non-object ``output`` instead of raising AttributeError.
    task_status = output.get("task_status", "UNKNOWN") if isinstance(output, dict) else "UNKNOWN"

    return AITaskResult(
        task_id=task_id,
        status=task_status,
        code=result.get("code"),
        message=result.get("message"),
        output=output,
        usage=result.get("usage"),
    )
|