AI图片
This commit is contained in:
66
api/ai.py
66
api/ai.py
@@ -4,7 +4,8 @@ import httpx
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from config import settings
|
from config import settings
|
||||||
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult
|
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest
|
||||||
|
from api.prompts import get_prompt
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -12,11 +13,7 @@ logger = logging.getLogger(__name__)
|
|||||||
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
||||||
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
|
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
|
||||||
|
|
||||||
@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务")
|
async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n: int, model: str):
|
||||||
async def generate_image(request: AIGenerationRequest):
|
|
||||||
"""
|
|
||||||
提交AI图片生成任务,使用阿里云DashScope服务
|
|
||||||
"""
|
|
||||||
if not settings.dashscope_api_key:
|
if not settings.dashscope_api_key:
|
||||||
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
||||||
|
|
||||||
@@ -28,14 +25,14 @@ async def generate_image(request: AIGenerationRequest):
|
|||||||
|
|
||||||
# 构建请求体
|
# 构建请求体
|
||||||
payload = {
|
payload = {
|
||||||
"model": request.model,
|
"model": model,
|
||||||
"input": {
|
"input": {
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{
|
||||||
"text": request.prompt
|
"text": prompt
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
@@ -44,13 +41,13 @@ async def generate_image(request: AIGenerationRequest):
|
|||||||
"parameters": {
|
"parameters": {
|
||||||
"prompt_extend": True,
|
"prompt_extend": True,
|
||||||
"watermark": False,
|
"watermark": False,
|
||||||
"n": request.n,
|
"n": n,
|
||||||
"size": request.size
|
"size": size
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if request.negative_prompt:
|
if negative_prompt:
|
||||||
payload["parameters"]["negative_prompt"] = request.negative_prompt
|
payload["parameters"]["negative_prompt"] = negative_prompt
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
@@ -70,25 +67,16 @@ async def generate_image(request: AIGenerationRequest):
|
|||||||
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
# 检查是否有task_id,因为如果是同步返回可能没有task_id,但在这种模型通常是异步的
|
|
||||||
# 这里的curl示例似乎是直接返回task_id
|
|
||||||
|
|
||||||
if "output" in result and "task_id" in result["output"]:
|
if "output" in result and "task_id" in result["output"]:
|
||||||
task_id = result["output"]["task_id"]
|
task_id = result["output"]["task_id"]
|
||||||
elif "task_id" in result: # 或者是这种结构
|
elif "task_id" in result:
|
||||||
task_id = result["task_id"]
|
task_id = result["task_id"]
|
||||||
else:
|
else:
|
||||||
# 有些情况下直接返回 output.task_id
|
|
||||||
# 根据文档 https://help.aliyun.com/zh/dashscope/developer-reference/api-details-10
|
|
||||||
# 异步提交返回结构通常包含 output.task_id
|
|
||||||
if "output" in result and "task_id" in result["output"]:
|
if "output" in result and "task_id" in result["output"]:
|
||||||
task_id = result["output"]["task_id"]
|
task_id = result["output"]["task_id"]
|
||||||
else:
|
else:
|
||||||
# 如果是同步返回,可能直接给结果,或者结构不同
|
|
||||||
# 但 wan2.6-t2i 通常是异步任务
|
|
||||||
logger.warning(f"Unexpected response structure: {result}")
|
logger.warning(f"Unexpected response structure: {result}")
|
||||||
# 尝试直接取,或者抛错
|
|
||||||
# 假设是标准异步结构
|
|
||||||
task_id = result.get("output", {}).get("task_id")
|
task_id = result.get("output", {}).get("task_id")
|
||||||
|
|
||||||
if not task_id:
|
if not task_id:
|
||||||
@@ -103,6 +91,40 @@ async def generate_image(request: AIGenerationRequest):
|
|||||||
logger.error(f"Request error: {str(e)}")
|
logger.error(f"Request error: {str(e)}")
|
||||||
raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")
|
||||||
|
|
||||||
|
@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务")
|
||||||
|
async def generate_image(request: AIGenerationRequest):
|
||||||
|
"""
|
||||||
|
提交AI图片生成任务,使用阿里云DashScope服务
|
||||||
|
"""
|
||||||
|
return await _submit_dashscope_task(
|
||||||
|
prompt=request.prompt,
|
||||||
|
negative_prompt=request.negative_prompt,
|
||||||
|
size=request.size,
|
||||||
|
n=request.n,
|
||||||
|
model=request.model
|
||||||
|
)
|
||||||
|
|
||||||
|
@router.post("/generate_from_template", response_model=AITaskResponse, summary="使用模板提交AI图片生成任务")
|
||||||
|
async def generate_image_from_template(request: AITemplateGenerationRequest):
|
||||||
|
"""
|
||||||
|
使用预设模板提交AI图片生成任务
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
prompt = get_prompt(request.template_id, **request.params)
|
||||||
|
if prompt is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"未找到模板ID: {request.template_id}")
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
return await _submit_dashscope_task(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=request.negative_prompt,
|
||||||
|
size=request.size,
|
||||||
|
n=request.n,
|
||||||
|
model=request.model
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果")
|
@router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果")
|
||||||
async def get_task_result(task_id: str):
|
async def get_task_result(task_id: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user