Files
ESP32_GDEY042T81_server/api/ai.py
jeremygan2021 37b2cf6ba6 AI图片
2026-03-02 12:32:53 +08:00

180 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import json
import logging
from typing import Any, Dict
from urllib.parse import quote

import httpx
from fastapi import APIRouter, Depends, HTTPException

from config import settings
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest
from api.prompts import get_prompt
# Router that the AI endpoints below register themselves on, plus the module logger.
router = APIRouter()
logger = logging.getLogger(__name__)

# DashScope REST endpoints, all rooted at the same v1 API base:
#   - DASHSCOPE_API_URL: submission endpoint for multimodal-generation tasks
#   - DASHSCOPE_TASK_URL: task-status endpoint, queried as f"{DASHSCOPE_TASK_URL}/{task_id}"
_DASHSCOPE_API_BASE = "https://dashscope.aliyuncs.com/api/v1"
DASHSCOPE_API_URL = f"{_DASHSCOPE_API_BASE}/services/aigc/multimodal-generation/generation"
DASHSCOPE_TASK_URL = f"{_DASHSCOPE_API_BASE}/tasks"
async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n: int, model: str):
    """Submit an image-generation task to DashScope and return its task handle.

    The prompt is sent as a single user message to ``DASHSCOPE_API_URL`` with the
    ``X-DashScope-Async: enable`` header. The task id is then read from the JSON reply.

    Args:
        prompt: Text prompt sent as the user message content.
        negative_prompt: Negative prompt; added to ``parameters`` only when truthy.
        size: Forwarded as ``parameters.size``.
        n: Forwarded as ``parameters.n``.
        model: DashScope model name.

    Returns:
        AITaskResponse with the DashScope ``task_id`` and ``request_id``.

    Raises:
        HTTPException:
            - 500 if the API key is not configured.
            - 500 if the HTTP request fails at the transport level.
            - 500 if the reply is not JSON.
            - 500 if no task id can be found in the reply.
            - The upstream status code, with its error body, for any non-200 reply.
    """
    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {settings.dashscope_api_key}",
        "X-DashScope-Async": "enable",  # submit as an asynchronous task
    }
    payload = {
        "model": model,
        "input": {
            "messages": [
                {
                    "role": "user",
                    "content": [{"text": prompt}],
                },
            ],
        },
        "parameters": {
            "prompt_extend": True,
            "watermark": False,
            "n": n,
            "size": size,
        },
    }
    if negative_prompt:
        payload["parameters"]["negative_prompt"] = negative_prompt

    # Only the network call can raise httpx.RequestError, so the try body is kept minimal.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                DASHSCOPE_API_URL,
                headers=headers,
                json=payload,
                timeout=30.0,
            )
    except httpx.RequestError as e:
        logger.error("Request error: %s", e)
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error("DashScope API error: %s", response.text)
        try:
            error_detail = response.json()
        except ValueError:  # body is not JSON (json.JSONDecodeError subclasses ValueError)
            error_detail = {"message": response.text}
        raise HTTPException(status_code=response.status_code, detail=error_detail)

    try:
        result = response.json()
    except ValueError as e:
        logger.error("DashScope returned a non-JSON response: %s", response.text)
        raise HTTPException(status_code=500, detail="Invalid JSON response from DashScope") from e

    # Look for the task id under output.task_id first, then fall back to a top-level task_id.
    # Non-dict values are guarded against so a malformed reply produces a clean 500.
    output = result.get("output") if isinstance(result, dict) else None
    task_id = output.get("task_id") if isinstance(output, dict) else None
    if not task_id and isinstance(result, dict):
        task_id = result.get("task_id")
    if not task_id:
        logger.warning("Unexpected response structure: %s", result)
        raise HTTPException(status_code=500, detail="Failed to retrieve task_id from DashScope response")

    return AITaskResponse(
        task_id=task_id,
        request_id=result.get("request_id"),
    )
@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务")
async def generate_image(request: AIGenerationRequest):
    """
    Submit an AI image-generation task using the Alibaba Cloud DashScope service.
    """
    # Forward the request fields one-to-one to the shared submission helper.
    submission = dict(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        size=request.size,
        n=request.n,
        model=request.model,
    )
    return await _submit_dashscope_task(**submission)
@router.post("/generate_from_template", response_model=AITaskResponse, summary="使用模板提交AI图片生成任务")
async def generate_image_from_template(request: AITemplateGenerationRequest):
    """
    Submit an AI image-generation task whose prompt is rendered from a preset template.

    The prompt is produced by ``get_prompt(request.template_id, **request.params)``.
    The remaining generation settings are forwarded unchanged.

    Raises:
        HTTPException:
            - 400 if ``get_prompt`` raises ValueError.
            - 404 if ``get_prompt`` returns None (unknown template id).
            - Any error raised while submitting the task to DashScope.
    """
    # NOTE(review): the schema of `params` is not visible here. `or {}` tolerates a
    # None value and is a no-op when params is already a dict.
    template_params = request.params or {}

    # Only get_prompt is expected to raise ValueError, so only it sits inside the try.
    try:
        prompt = get_prompt(request.template_id, **template_params)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    if prompt is None:
        raise HTTPException(status_code=404, detail=f"未找到模板ID: {request.template_id}")

    return await _submit_dashscope_task(
        prompt=prompt,
        negative_prompt=request.negative_prompt,
        size=request.size,
        n=request.n,
        model=request.model,
    )
@router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果")
async def get_task_result(task_id: str):
    """
    Query the status and result of a DashScope AI generation task.

    Args:
        task_id: Task id taken from the request path.

    Returns:
        AITaskResult. ``status`` is read from ``output.task_status`` and defaults to
        "UNKNOWN" when that field is absent. ``code``, ``message``, ``output`` and
        ``usage`` are copied from the top level of the DashScope reply.

    Raises:
        HTTPException:
            - 500 if the API key is not configured.
            - 500 if the transport fails.
            - 500 if the reply is not a JSON object.
            - The upstream status code for any non-200 reply.
    """
    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")

    headers = {"Authorization": f"Bearer {settings.dashscope_api_key}"}

    # task_id comes straight from the client's URL. Percent-encode it, including "/",
    # so it cannot add path segments, a query string or a fragment to the upstream URL.
    task_url = f"{DASHSCOPE_TASK_URL}/{quote(task_id, safe='')}"

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(task_url, headers=headers, timeout=30.0)
    except httpx.RequestError as e:
        logger.error("Request error: %s", e)
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error("DashScope Task API error: %s", response.text)
        raise HTTPException(status_code=response.status_code, detail="Failed to fetch task status")

    # The body is expected to be a JSON object shaped like:
    #   {"request_id": ..., "output": {"task_id": ..., "task_status": "SUCCEEDED",
    #    "results": [...]}, "usage": ...}
    try:
        result = response.json()
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        result = None
    if not isinstance(result, dict):
        logger.error("Unexpected DashScope Task API response: %s", response.text)
        raise HTTPException(status_code=500, detail="Failed to fetch task status")

    output = result.get("output")
    # Guard against "output": null (or any non-object) instead of crashing on .get().
    task_status = output.get("task_status", "UNKNOWN") if isinstance(output, dict) else "UNKNOWN"

    return AITaskResult(
        task_id=task_id,
        status=task_status,
        code=result.get("code"),
        message=result.get("message"),
        output=output,
        usage=result.get("usage"),
    )