Files
ESP32_GDEY042T81_server/api/ai.py
jeremygan2021 9620a4138d AI生图
2026-03-02 12:53:36 +08:00

234 lines
8.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import uuid
import httpx
from fastapi import APIRouter, HTTPException, Depends
from typing import Dict, Any
import logging
from config import settings
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest
from api.prompts import get_prompt
# Router that the AI image-generation endpoints in this module are registered on.
router = APIRouter()
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# DashScope endpoint that image-generation requests are POSTed to.
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
# Base URL for task-status lookups; the task id is appended as a path segment.
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
async def _download_image(url: str) -> str:
    """
    Download an image and store it in the media directory.

    The download is best-effort: on a non-200 response or on any exception
    (network or filesystem) the problem is logged and the original URL is
    returned unchanged, so the caller always receives a usable link.

    :param url: Remote image URL.
    :return: Local path of the form "/media/<uuid>.png" on success,
             otherwise the original ``url``.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url, timeout=30.0)

        if response.status_code != 200:
            logger.error(f"Failed to download image: {url}, status: {response.status_code}")
            return url  # Download failed: fall back to the remote URL.

        # Create the media directory on first use. Without this, a missing
        # directory makes every write fail and the function would silently
        # return the remote URL forever.
        if settings.media_dir:
            os.makedirs(settings.media_dir, exist_ok=True)

        # A random UUID avoids collisions; the file is always named ".png".
        filename = f"{uuid.uuid4()}.png"
        filepath = os.path.join(settings.media_dir, filename)
        with open(filepath, "wb") as f:
            f.write(response.content)

        # Path under which the saved file is expected to be served.
        return f"/media/{filename}"
    except Exception as e:
        # Deliberate catch-all: a failed download must never break the request.
        logger.error(f"Error downloading image: {str(e)}")
        return url
async def _collect_choice_images(choices) -> list:
    """Download every image referenced in a synchronous-mode ``choices`` list."""
    images = []
    for choice in choices:
        for item in choice.get("message", {}).get("content", []):
            if "image" in item:
                # Store a local copy; keep the remote URL alongside it.
                local_url = await _download_image(item["image"])
                images.append({"url": local_url, "origin_url": item["image"]})
    return images


async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n: int, model: str):
    """
    Submit an image-generation request to DashScope.

    The response is handled in two shapes:

    * asynchronous - a ``task_id`` is present (under ``output`` or at the top
      level); a response with status "PENDING" is returned.
    * synchronous - ``output.choices`` is present; every referenced image is
      downloaded and a response with status "SUCCEEDED" is returned.

    :param prompt: Text prompt sent as the user message.
    :param negative_prompt: Optional negative prompt; omitted from the request when falsy.
    :param size: Value forwarded as the ``size`` parameter.
    :param n: Value forwarded as the ``n`` parameter.
    :param model: DashScope model name.
    :return: An ``AITaskResponse`` describing the pending task or the finished results.
    :raises HTTPException: 500 if the API key is missing, the request fails, or
        the response has neither a task id nor choices; the upstream status
        code if DashScope answers with a non-200 status.
    """
    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {settings.dashscope_api_key}",
        # The "X-DashScope-Async: enable" header is intentionally not sent:
        # some accounts/models do not support forced asynchronous mode.
    }

    # Build the request body.
    payload = {
        "model": model,
        "input": {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "text": prompt
                        }
                    ]
                }
            ]
        },
        "parameters": {
            "prompt_extend": True,
            "watermark": False,
            "n": n,
            "size": size
        }
    }
    if negative_prompt:
        payload["parameters"]["negative_prompt"] = negative_prompt

    # Only the HTTP call can raise httpx.RequestError, so only it is guarded.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                DASHSCOPE_API_URL,
                headers=headers,
                json=payload,
                timeout=60.0  # Longer timeout: synchronous generation can be slow.
            )
    except httpx.RequestError as e:
        logger.error(f"Request error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error(f"DashScope API error: {response.text}")
        try:
            error_detail = response.json()
        except ValueError:
            # Body is not valid JSON; pass the raw text along instead.
            error_detail = {"message": response.text}
        raise HTTPException(status_code=response.status_code, detail=error_detail)

    result = response.json()

    # "output" may be missing or null; treat anything that is not a dict as empty.
    output = result.get("output")
    if not isinstance(output, dict):
        output = {}

    # Asynchronous mode: a task id under "output" or at the top level.
    task_id = output.get("task_id") or result.get("task_id")
    if task_id:
        return AITaskResponse(
            task_id=task_id,
            request_id=result.get("request_id"),
            status="PENDING"  # Initial state of an asynchronous task.
        )

    # Synchronous mode: results are returned directly in "choices".
    choices = output.get("choices")
    if choices is not None:
        results = await _collect_choice_images(choices)
        return AITaskResponse(
            request_id=result.get("request_id"),
            status="SUCCEEDED",
            results=results
        )

    # Neither a task id nor choices: unknown structure or an error payload.
    logger.warning(f"Unexpected response structure: {result}")
    raise HTTPException(status_code=500, detail="Failed to retrieve task_id or results from DashScope response")
@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务")
async def generate_image(request: AIGenerationRequest):
    """
    Submit an AI image-generation task using the Alibaba Cloud DashScope service.
    """
    # Forward the request fields unchanged to the shared submission helper.
    task_args = {
        "prompt": request.prompt,
        "negative_prompt": request.negative_prompt,
        "size": request.size,
        "n": request.n,
        "model": request.model,
    }
    return await _submit_dashscope_task(**task_args)
@router.post("/generate_from_template", response_model=AITaskResponse, summary="使用模板提交AI图片生成任务")
async def generate_image_from_template(request: AITemplateGenerationRequest):
    """
    Submit an AI image-generation task using a preset prompt template.
    """
    # Render the template; a ValueError from rendering is reported as 400.
    try:
        rendered_prompt = get_prompt(request.template_id, **request.params)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc))

    # A None result means the template id is unknown.
    if rendered_prompt is None:
        raise HTTPException(status_code=404, detail=f"未找到模板ID: {request.template_id}")

    task_args = {
        "prompt": rendered_prompt,
        "negative_prompt": request.negative_prompt,
        "size": request.size,
        "n": request.n,
        "model": request.model,
    }
    return await _submit_dashscope_task(**task_args)
@router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果")
async def get_task_result(task_id: str):
    """
    Query the result of an AI generation task.

    :param task_id: Task id to look up at DashScope.
    :return: An ``AITaskResult`` built from the ``output``, ``usage``, ``code``
             and ``message`` fields of the DashScope response; the status is
             "UNKNOWN" when the response carries no ``output.task_status``.
    :raises HTTPException: 500 if the API key is missing or the request fails;
        the upstream status code if DashScope answers with a non-200 status.
    """
    from urllib.parse import quote

    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")

    headers = {
        "Authorization": f"Bearer {settings.dashscope_api_key}"
    }

    # Percent-encode the id so characters such as "?", "#" or "/" cannot
    # change the path or query of the upstream request.
    encoded_task_id = quote(task_id, safe="")
    task_url = f"{DASHSCOPE_TASK_URL}/{encoded_task_id}"

    # Only the HTTP call can raise httpx.RequestError, so only it is guarded.
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                task_url,
                headers=headers,
                timeout=30.0
            )
    except httpx.RequestError as e:
        logger.error(f"Request error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error(f"DashScope Task API error: {response.text}")
        raise HTTPException(status_code=response.status_code, detail="Failed to fetch task status")

    result = response.json()

    # Expected shape of the task-query response:
    # {
    #   "request_id": "...",
    #   "output": {
    #     "task_id": "...",
    #     "task_status": "SUCCEEDED",
    #     "results": [...]
    #   },
    #   "usage": ...
    # }
    output = result.get("output")
    # "output" may be missing or null; only read the status from a dict.
    if isinstance(output, dict):
        task_status = output.get("task_status", "UNKNOWN")
    else:
        task_status = "UNKNOWN"

    return AITaskResult(
        task_id=task_id,
        status=task_status,
        code=result.get("code"),
        message=result.get("message"),
        output=output,
        usage=result.get("usage")
    )