AI生图
This commit is contained in:
80
api/ai.py
80
api/ai.py
@@ -1,7 +1,8 @@
|
||||
import os
|
||||
import uuid
|
||||
import httpx
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, Any
|
||||
import httpx
|
||||
import json
|
||||
import logging
|
||||
from config import settings
|
||||
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest
|
||||
@@ -13,16 +14,45 @@ logger = logging.getLogger(__name__)
|
||||
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
||||
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
|
||||
|
||||
async def _download_image(url: str) -> str:
|
||||
"""
|
||||
下载图片并保存到media目录
|
||||
:param url: 图片URL
|
||||
:return: 本地文件相对路径 (e.g., "/media/xxx.png")
|
||||
"""
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url, timeout=30.0)
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Failed to download image: {url}, status: {response.status_code}")
|
||||
return url # 下载失败返回原URL
|
||||
|
||||
# 生成文件名
|
||||
filename = f"{uuid.uuid4()}.png"
|
||||
filepath = os.path.join(settings.media_dir, filename)
|
||||
|
||||
# 写入文件
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
# 返回可访问的URL路径
|
||||
return f"/media/{filename}"
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error downloading image: {str(e)}")
|
||||
return url
|
||||
|
||||
async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n: int, model: str):
|
||||
if not settings.dashscope_api_key:
|
||||
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {settings.dashscope_api_key}",
|
||||
"X-DashScope-Async": "enable" # 确保异步任务提交
|
||||
"Authorization": f"Bearer {settings.dashscope_api_key}"
|
||||
# "X-DashScope-Async": "enable" # 移除强制异步,因为某些账号/模型不支持
|
||||
}
|
||||
|
||||
|
||||
# 构建请求体
|
||||
payload = {
|
||||
"model": model,
|
||||
@@ -55,7 +85,7 @@ async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n
|
||||
DASHSCOPE_API_URL,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=30.0
|
||||
timeout=60.0 # 增加超时时间,同步请求可能较慢
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
@@ -68,23 +98,47 @@ async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n
|
||||
|
||||
result = response.json()
|
||||
|
||||
# 检查是异步任务返回还是同步结果返回
|
||||
task_id = None
|
||||
if "output" in result and "task_id" in result["output"]:
|
||||
task_id = result["output"]["task_id"]
|
||||
elif "task_id" in result:
|
||||
task_id = result["task_id"]
|
||||
else:
|
||||
if "output" in result and "task_id" in result["output"]:
|
||||
task_id = result["output"]["task_id"]
|
||||
else:
|
||||
logger.warning(f"Unexpected response structure: {result}")
|
||||
task_id = result.get("output", {}).get("task_id")
|
||||
|
||||
# 如果没有task_id,检查是否有直接结果(同步模式)
|
||||
if not task_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve task_id from DashScope response")
|
||||
if "output" in result and "choices" in result["output"]:
|
||||
# 同步返回成功
|
||||
# 提取结果以便直接返回
|
||||
choices = result["output"]["choices"]
|
||||
results = []
|
||||
for choice in choices:
|
||||
msg_content = choice.get("message", {}).get("content", [])
|
||||
for item in msg_content:
|
||||
if "image" in item:
|
||||
# 下载图片到本地
|
||||
local_url = await _download_image(item["image"])
|
||||
results.append({"url": local_url, "origin_url": item["image"]})
|
||||
|
||||
return AITaskResponse(
|
||||
request_id=result.get("request_id"),
|
||||
status="SUCCEEDED",
|
||||
results=results
|
||||
)
|
||||
|
||||
# 既没有task_id也没有choices,可能是其他结构或错误
|
||||
logger.warning(f"Unexpected response structure: {result}")
|
||||
# 尝试从output.task_id再找一次
|
||||
task_id = result.get("output", {}).get("task_id")
|
||||
|
||||
if not task_id:
|
||||
# 依然找不到,报错
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve task_id or results from DashScope response")
|
||||
|
||||
return AITaskResponse(
|
||||
task_id=task_id,
|
||||
request_id=result.get("request_id")
|
||||
request_id=result.get("request_id"),
|
||||
status="PENDING" # 异步任务初始状态
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
|
||||
Reference in New Issue
Block a user