Files
ESP32_GDEY042T81_server/api/ai.py
jeremygan2021 82bba110ee AI图片
2026-03-02 12:32:45 +08:00

158 lines
5.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import logging
from typing import Dict, Any
from urllib.parse import quote

import httpx
from fastapi import APIRouter, HTTPException, Depends

from config import settings
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult

# Module logger and the router that the application mounts these endpoints on.
logger = logging.getLogger(__name__)
router = APIRouter()

# DashScope REST endpoints: task submission, and per-task status lookup
# (the task id is appended as the final path segment).
DASHSCOPE_API_URL = (
    "https://dashscope.aliyuncs.com/api/v1/services/aigc/"
    "multimodal-generation/generation"
)
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
def _build_generation_payload(request: AIGenerationRequest) -> Dict[str, Any]:
    """Build the DashScope JSON request body for an image-generation request."""
    parameters: Dict[str, Any] = {
        "prompt_extend": True,
        "watermark": False,
        "n": request.n,
        "size": request.size,
    }
    # Only send negative_prompt when the caller supplied a non-empty one.
    if request.negative_prompt:
        parameters["negative_prompt"] = request.negative_prompt
    return {
        "model": request.model,
        "input": {
            "messages": [
                {"role": "user", "content": [{"text": request.prompt}]},
            ],
        },
        "parameters": parameters,
    }


def _extract_task_id(result: Any) -> Any:
    """Return the task id from a DashScope submit response, or None if absent.

    ``output.task_id`` is preferred; a top-level ``task_id`` is accepted as a
    fallback. Non-dict bodies and a null/non-dict ``output`` yield None instead
    of raising.
    """
    if not isinstance(result, dict):
        return None
    output = result.get("output")
    if isinstance(output, dict) and output.get("task_id"):
        return output["task_id"]
    return result.get("task_id")


@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务")
async def generate_image(request: AIGenerationRequest):
    """Submit an AI image-generation task to Alibaba Cloud DashScope.
    \f
    Args:
        request: Prompt, model, size, count, and optional negative prompt.

    Returns:
        AITaskResponse carrying the DashScope ``task_id`` (poll it via
        ``GET /tasks/{task_id}``) and the upstream ``request_id``.

    Raises:
        HTTPException: 500 if the API key is missing, the upstream request fails
            at the transport level, the response is not JSON, or it has no
            task id. A non-200 upstream status is re-raised with that same
            status and the upstream error body as ``detail``.
    """
    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {settings.dashscope_api_key}",
        # Request asynchronous task submission (response carries a task id).
        # NOTE(review): confirm this endpoint honours async mode for the
        # configured model; some DashScope models use a separate async endpoint.
        "X-DashScope-Async": "enable",
    }
    payload = _build_generation_payload(request)

    # Only the network call can raise httpx.RequestError; keep the try minimal.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                DASHSCOPE_API_URL,
                headers=headers,
                json=payload,
                timeout=30.0,
            )
    except httpx.RequestError as e:
        logger.error(f"Request error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error(f"DashScope API error: {response.text}")
        try:
            error_detail = response.json()
        except ValueError:  # error body is not valid JSON
            error_detail = {"message": response.text}
        raise HTTPException(status_code=response.status_code, detail=error_detail)

    try:
        result = response.json()
    except ValueError as e:
        logger.error(f"DashScope returned a non-JSON response: {response.text}")
        raise HTTPException(
            status_code=500, detail="Invalid JSON in DashScope response"
        ) from e

    task_id = _extract_task_id(result)
    if not task_id:
        logger.warning(f"Unexpected response structure: {result}")
        raise HTTPException(status_code=500, detail="Failed to retrieve task_id from DashScope response")

    return AITaskResponse(
        task_id=task_id,
        request_id=result.get("request_id"),
    )
@router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果")
async def get_task_result(task_id: str):
    """Query DashScope for the status and result of a generation task.
    \f
    Args:
        task_id: Task id previously returned by ``POST /generate``. It comes
            from the client URL and is treated as untrusted.

    Returns:
        AITaskResult with the task status (``"UNKNOWN"`` when DashScope omits
        it), error code/message if any, and the raw ``output``/``usage``.

    Raises:
        HTTPException: 400 for a task id that cannot be a single path segment.
            500 if the API key is missing, the transport fails, or the
            response is not a JSON object. A non-200 upstream status is
            re-raised with that same status.
    """
    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")

    # Percent-encode the untrusted id so it can only address one path segment
    # under DASHSCOPE_TASK_URL (no injected "/", "?" or "#"). "." and ".." are
    # unreserved and survive quoting, and URL normalization may treat them as
    # dot-segments, so reject them outright.
    if task_id in ("", ".", ".."):
        raise HTTPException(status_code=400, detail="Invalid task_id")
    url = f"{DASHSCOPE_TASK_URL}/{quote(task_id, safe='')}"

    headers = {
        "Authorization": f"Bearer {settings.dashscope_api_key}"
    }

    # Only the network call can raise httpx.RequestError; keep the try minimal.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers, timeout=30.0)
    except httpx.RequestError as e:
        logger.error(f"Request error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error(f"DashScope Task API error: {response.text}")
        raise HTTPException(status_code=response.status_code, detail="Failed to fetch task status")

    try:
        result = response.json()
    except ValueError:
        result = None
    if not isinstance(result, dict):
        logger.error(f"DashScope Task API returned an unexpected body: {response.text}")
        raise HTTPException(status_code=500, detail="Invalid DashScope task response")

    # Expected shape:
    # {"request_id": ..., "output": {"task_id": ..., "task_status": ..., ...}, "usage": ...}
    # "output" may be missing or null; treat that as an empty mapping for lookups.
    output = result.get("output")
    output_fields = output if isinstance(output, dict) else {}
    task_status = output_fields.get("task_status", "UNKNOWN")

    return AITaskResult(
        task_id=task_id,
        status=task_status,
        # Request-level errors carry code/message at the top level. Failed tasks
        # report them inside "output" (per DashScope task-query docs), so fall
        # back to those when the top-level fields are absent.
        code=result.get("code") or output_fields.get("code"),
        message=result.get("message") or output_fields.get("message"),
        output=output,
        usage=result.get("usage"),
    )