diff --git a/api/ai.py b/api/ai.py index 4da2ca6..485423d 100644 --- a/api/ai.py +++ b/api/ai.py @@ -4,7 +4,8 @@ import httpx import json import logging from config import settings -from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult +from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest +from api.prompts import get_prompt router = APIRouter() logger = logging.getLogger(__name__) @@ -12,11 +13,7 @@ logger = logging.getLogger(__name__) DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation" DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks" -@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务") -async def generate_image(request: AIGenerationRequest): - """ - 提交AI图片生成任务,使用阿里云DashScope服务 - """ +async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n: int, model: str): if not settings.dashscope_api_key: raise HTTPException(status_code=500, detail="DashScope API Key not configured") @@ -28,14 +25,14 @@ async def generate_image(request: AIGenerationRequest): # 构建请求体 payload = { - "model": request.model, + "model": model, "input": { "messages": [ { "role": "user", "content": [ { - "text": request.prompt + "text": prompt } ] } @@ -44,13 +41,13 @@ async def generate_image(request: AIGenerationRequest): "parameters": { "prompt_extend": True, "watermark": False, - "n": request.n, - "size": request.size + "n": n, + "size": size } } - if request.negative_prompt: - payload["parameters"]["negative_prompt"] = request.negative_prompt + if negative_prompt: + payload["parameters"]["negative_prompt"] = negative_prompt try: async with httpx.AsyncClient() as client: @@ -70,25 +67,16 @@ async def generate_image(request: AIGenerationRequest): raise HTTPException(status_code=response.status_code, detail=error_detail) result = response.json() - # 检查是否有task_id,因为如果是同步返回可能没有task_id,但在这种模型通常是异步的 - # 这里的curl示例似乎是直接返回task_id if "output" in result and "task_id" in result["output"]: task_id = result["output"]["task_id"] - elif "task_id" in result: # 或者是这种结构 + elif "task_id" in result: task_id = result["task_id"] else: - # 有些情况下直接返回 output.task_id - # 根据文档 https://help.aliyun.com/zh/dashscope/developer-reference/api-details-10 - # 异步提交返回结构通常包含 output.task_id if "output" in result and "task_id" in result["output"]: task_id = result["output"]["task_id"] else: - # 如果是同步返回,可能直接给结果,或者结构不同 - # 但 wan2.6-t2i 通常是异步任务 logger.warning(f"Unexpected response structure: {result}") - # 尝试直接取,或者抛错 - # 假设是标准异步结构 task_id = result.get("output", {}).get("task_id") if not task_id: @@ -103,6 +91,40 @@ async def generate_image(request: AIGenerationRequest): logger.error(f"Request error: {str(e)}") raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") +@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务") +async def generate_image(request: AIGenerationRequest): + """ + 提交AI图片生成任务,使用阿里云DashScope服务 + """ + return await _submit_dashscope_task( + prompt=request.prompt, + negative_prompt=request.negative_prompt, + size=request.size, + n=request.n, + model=request.model + ) + +@router.post("/generate_from_template", response_model=AITaskResponse, summary="使用模板提交AI图片生成任务") +async def generate_image_from_template(request: AITemplateGenerationRequest): + """ + 使用预设模板提交AI图片生成任务 + """ + try: + prompt = get_prompt(request.template_id, **request.params) + if prompt is None: + raise HTTPException(status_code=404, detail=f"未找到模板ID: {request.template_id}") + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + + return await _submit_dashscope_task( + prompt=prompt, + negative_prompt=request.negative_prompt, + size=request.size, + n=request.n, + model=request.model + ) + + @router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果") async def get_task_result(task_id: str): """