import logging
import os
import uuid
from typing import Any, Dict, List
from urllib.parse import quote

import httpx
from fastapi import APIRouter, HTTPException, Depends

from config import settings
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest
from api.prompts import get_prompt

router = APIRouter()
logger = logging.getLogger(__name__)

DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"


def _require_api_key() -> str:
    """Return the configured DashScope API key.

    Raises:
        HTTPException: 500 when ``settings.dashscope_api_key`` is empty.
    """
    if not settings.dashscope_api_key:
        raise HTTPException(status_code=500, detail="DashScope API Key not configured")
    return settings.dashscope_api_key


async def _download_image(url: str) -> str:
    """Download an image and store it in the media directory.

    This is best-effort: any failure is logged and the original URL is
    returned, so callers always get a usable URL back.

    Args:
        url: Remote image URL.

    Returns:
        The local relative path (e.g. ``"/media/xxx.png"``) on success,
        otherwise ``url`` unchanged.
    """
    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(url, timeout=30.0)

        if response.status_code != 200:
            logger.error("Failed to download image: %s, status: %s", url, response.status_code)
            return url  # download failed: fall back to the original URL

        # The file is always saved with a ".png" extension, whatever the
        # response's actual content type is.
        filename = f"{uuid.uuid4()}.png"
        # Create the target directory if needed; otherwise a missing directory
        # would make every download silently fall back to the remote URL.
        os.makedirs(settings.media_dir, exist_ok=True)
        filepath = os.path.join(settings.media_dir, filename)

        with open(filepath, "wb") as f:
            f.write(response.content)

        return f"/media/{filename}"
    except Exception as e:
        # Deliberately broad: a failed download must never fail the request.
        logger.error("Error downloading image: %s", e)
        return url


async def _collect_sync_results(choices: List[Dict[str, Any]]) -> List[Dict[str, str]]:
    """Download every image referenced by a synchronous DashScope response.

    Args:
        choices: The ``output.choices`` list of the response; each choice may
            carry ``message.content`` items, some of which have an ``image`` key.

    Returns:
        One ``{"url": ..., "origin_url": ...}`` dict per image, where ``url``
        is whatever ``_download_image`` returned and ``origin_url`` is the
        URL from the response.
    """
    results = []
    for choice in choices:
        content_items = choice.get("message", {}).get("content", [])
        for item in content_items:
            if "image" in item:
                local_url = await _download_image(item["image"])
                results.append({"url": local_url, "origin_url": item["image"]})
    return results


async def _submit_dashscope_task(
    prompt: str, negative_prompt: str, size: str, n: int, model: str
) -> AITaskResponse:
    """Submit an image generation request to DashScope.

    The response is handled in one of two shapes: an asynchronous one that
    carries a ``task_id`` (returned with status ``"PENDING"``), or a
    synchronous one that carries ``output.choices`` (images are downloaded
    and returned with status ``"SUCCEEDED"``).

    Args:
        prompt: Text prompt sent as the single user message.
        negative_prompt: Optional negative prompt; only sent when truthy.
        size: Value for the ``size`` request parameter.
        n: Value for the ``n`` request parameter.
        model: DashScope model name.

    Raises:
        HTTPException: 500 when the API key is missing, the request fails at
            the transport level, or the response has neither a task id nor
            choices; the upstream status code when DashScope answers non-200.
    """
    api_key = _require_api_key()

    # The "X-DashScope-Async: enable" header is intentionally not sent:
    # forcing async mode is not supported by some accounts/models.
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }

    payload = {
        "model": model,
        "input": {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"text": prompt},
                    ],
                }
            ]
        },
        "parameters": {
            "prompt_extend": True,
            "watermark": False,
            "n": n,
            "size": size,
        },
    }
    if negative_prompt:
        payload["parameters"]["negative_prompt"] = negative_prompt

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                DASHSCOPE_API_URL,
                headers=headers,
                json=payload,
                timeout=60.0,  # generous timeout: synchronous requests can be slow
            )
    except httpx.RequestError as e:
        logger.error("Request error: %s", e)
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error("DashScope API error: %s", response.text)
        try:
            error_detail = response.json()
        except ValueError:
            # Body is not JSON; forward the raw text instead.
            error_detail = {"message": response.text}
        raise HTTPException(status_code=response.status_code, detail=error_detail)

    result = response.json()

    # Tolerate a missing or null "output" object.
    output = result.get("output")
    if not isinstance(output, dict):
        output = {}

    # Async mode: the task id is looked up under "output" first, then at top level.
    task_id = output.get("task_id") or result.get("task_id")
    if task_id:
        return AITaskResponse(
            task_id=task_id,
            request_id=result.get("request_id"),
            status="PENDING",  # initial state of an asynchronous task
        )

    # Sync mode: results are returned directly in the response.
    if "choices" in output:
        results = await _collect_sync_results(output["choices"] or [])
        return AITaskResponse(
            request_id=result.get("request_id"),
            status="SUCCEEDED",
            results=results,
        )

    # Neither a task id nor choices: unknown structure.
    logger.warning("Unexpected response structure: %s", result)
    raise HTTPException(
        status_code=500,
        detail="Failed to retrieve task_id or results from DashScope response",
    )


# NOTE: `description` pins the text published in the OpenAPI schema, which
# FastAPI would otherwise derive from the function docstring.
@router.post(
    "/generate",
    response_model=AITaskResponse,
    summary="提交AI图片生成任务",
    description="提交AI图片生成任务,使用阿里云DashScope服务",
)
async def generate_image(request: AIGenerationRequest):
    """Submit an AI image generation task to DashScope.

    Forwards the request's prompt, negative prompt, size, n and model to
    ``_submit_dashscope_task`` and returns its result.
    """
    return await _submit_dashscope_task(
        prompt=request.prompt,
        negative_prompt=request.negative_prompt,
        size=request.size,
        n=request.n,
        model=request.model,
    )


# NOTE: `description` pins the text published in the OpenAPI schema, which
# FastAPI would otherwise derive from the function docstring.
@router.post(
    "/generate_from_template",
    response_model=AITaskResponse,
    summary="使用模板提交AI图片生成任务",
    description="使用预设模板提交AI图片生成任务",
)
async def generate_image_from_template(request: AITemplateGenerationRequest):
    """Submit an AI image generation task built from a preset prompt template.

    The prompt is produced by ``get_prompt(template_id, **params)``.

    Raises:
        HTTPException: 400 when ``get_prompt`` raises ``ValueError``;
            404 when it returns ``None`` for the given template id.
    """
    try:
        # `params` may be None/empty; unpacking None would raise TypeError.
        prompt = get_prompt(request.template_id, **(request.params or {}))
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e

    if prompt is None:
        raise HTTPException(status_code=404, detail=f"未找到模板ID: {request.template_id}")

    return await _submit_dashscope_task(
        prompt=prompt,
        negative_prompt=request.negative_prompt,
        size=request.size,
        n=request.n,
        model=request.model,
    )


# NOTE: `description` pins the text published in the OpenAPI schema, which
# FastAPI would otherwise derive from the function docstring.
@router.get(
    "/tasks/{task_id}",
    response_model=AITaskResult,
    summary="查询AI任务结果",
    description="查询AI生成任务的结果",
)
async def get_task_result(task_id: str):
    """Query the result of an AI generation task.

    The code reads the DashScope task-query response as::

        {
            "request_id": "...",
            "output": {"task_id": "...", "task_status": "SUCCEEDED", "results": [...]},
            "usage": ...
        }

    Args:
        task_id: Task identifier taken from the URL path.

    Raises:
        HTTPException: 500 when the API key is missing or the request fails
            at the transport level; the upstream status code when DashScope
            answers non-200.
    """
    api_key = _require_api_key()
    headers = {"Authorization": f"Bearer {api_key}"}

    # task_id is caller-controlled: percent-encode it so characters such as
    # "?" or "#" cannot alter the upstream URL.
    task_url = f"{DASHSCOPE_TASK_URL}/{quote(task_id, safe='')}"

    try:
        async with httpx.AsyncClient() as client:
            response = await client.get(task_url, headers=headers, timeout=30.0)
    except httpx.RequestError as e:
        logger.error("Request error: %s", e)
        raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}") from e

    if response.status_code != 200:
        logger.error("DashScope Task API error: %s", response.text)
        raise HTTPException(status_code=response.status_code, detail="Failed to fetch task status")

    result = response.json()

    # Tolerate a missing or null "output" object when reading the status.
    task_status = (result.get("output") or {}).get("task_status", "UNKNOWN")

    return AITaskResult(
        task_id=task_id,
        status=task_status,
        code=result.get("code"),
        message=result.get("message"),
        output=result.get("output"),
        usage=result.get("usage"),
    )