AI图片
This commit is contained in:
@@ -1,9 +1,10 @@
|
||||
from fastapi import APIRouter
|
||||
from api import devices, contents, todos
|
||||
from api import devices, contents, todos, ai
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
# 注册所有路由,并添加全局安全要求
|
||||
api_router.include_router(devices.router, prefix="/devices")
|
||||
api_router.include_router(contents.router, prefix="/contents")
|
||||
api_router.include_router(todos.router, prefix="/todos")
|
||||
api_router.include_router(todos.router, prefix="/todos")
|
||||
api_router.include_router(ai.router, prefix="/ai", tags=["AI生成"])
|
||||
Binary file not shown.
Binary file not shown.
BIN
api/__pycache__/ai.cpython-312.pyc
Normal file
BIN
api/__pycache__/ai.cpython-312.pyc
Normal file
Binary file not shown.
BIN
api/__pycache__/ai_schemas.cpython-312.pyc
Normal file
BIN
api/__pycache__/ai_schemas.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
157
api/ai.py
Normal file
157
api/ai.py
Normal file
@@ -0,0 +1,157 @@
|
||||
from fastapi import APIRouter, HTTPException, Depends
|
||||
from typing import Dict, Any
|
||||
import httpx
|
||||
import json
|
||||
import logging
|
||||
from config import settings
|
||||
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult
|
||||
|
||||
router = APIRouter()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
||||
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
|
||||
|
||||
@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务")
|
||||
async def generate_image(request: AIGenerationRequest):
|
||||
"""
|
||||
提交AI图片生成任务,使用阿里云DashScope服务
|
||||
"""
|
||||
if not settings.dashscope_api_key:
|
||||
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {settings.dashscope_api_key}",
|
||||
"X-DashScope-Async": "enable" # 确保异步任务提交
|
||||
}
|
||||
|
||||
# 构建请求体
|
||||
payload = {
|
||||
"model": request.model,
|
||||
"input": {
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"text": request.prompt
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"parameters": {
|
||||
"prompt_extend": True,
|
||||
"watermark": False,
|
||||
"n": request.n,
|
||||
"size": request.size
|
||||
}
|
||||
}
|
||||
|
||||
if request.negative_prompt:
|
||||
payload["parameters"]["negative_prompt"] = request.negative_prompt
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
DASHSCOPE_API_URL,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"DashScope API error: {response.text}")
|
||||
try:
|
||||
error_detail = response.json()
|
||||
except:
|
||||
error_detail = {"message": response.text}
|
||||
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
||||
|
||||
result = response.json()
|
||||
# 检查是否有task_id,因为如果是同步返回可能没有task_id,但在这种模型通常是异步的
|
||||
# 这里的curl示例似乎是直接返回task_id
|
||||
|
||||
if "output" in result and "task_id" in result["output"]:
|
||||
task_id = result["output"]["task_id"]
|
||||
elif "task_id" in result: # 或者是这种结构
|
||||
task_id = result["task_id"]
|
||||
else:
|
||||
# 有些情况下直接返回 output.task_id
|
||||
# 根据文档 https://help.aliyun.com/zh/dashscope/developer-reference/api-details-10
|
||||
# 异步提交返回结构通常包含 output.task_id
|
||||
if "output" in result and "task_id" in result["output"]:
|
||||
task_id = result["output"]["task_id"]
|
||||
else:
|
||||
# 如果是同步返回,可能直接给结果,或者结构不同
|
||||
# 但 wan2.6-t2i 通常是异步任务
|
||||
logger.warning(f"Unexpected response structure: {result}")
|
||||
# 尝试直接取,或者抛错
|
||||
# 假设是标准异步结构
|
||||
task_id = result.get("output", {}).get("task_id")
|
||||
|
||||
if not task_id:
|
||||
raise HTTPException(status_code=500, detail="Failed to retrieve task_id from DashScope response")
|
||||
|
||||
return AITaskResponse(
|
||||
task_id=task_id,
|
||||
request_id=result.get("request_id")
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")
|
||||
|
||||
@router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果")
|
||||
async def get_task_result(task_id: str):
|
||||
"""
|
||||
查询AI生成任务的结果
|
||||
"""
|
||||
if not settings.dashscope_api_key:
|
||||
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {settings.dashscope_api_key}"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
f"{DASHSCOPE_TASK_URL}/{task_id}",
|
||||
headers=headers,
|
||||
timeout=30.0
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"DashScope Task API error: {response.text}")
|
||||
raise HTTPException(status_code=response.status_code, detail="Failed to fetch task status")
|
||||
|
||||
result = response.json()
|
||||
|
||||
# 构建返回结果
|
||||
# DashScope 任务查询返回结构:
|
||||
# {
|
||||
# "request_id": "...",
|
||||
# "output": {
|
||||
# "task_id": "...",
|
||||
# "task_status": "SUCCEEDED",
|
||||
# "results": [...]
|
||||
# },
|
||||
# "usage": ...
|
||||
# }
|
||||
|
||||
task_status = result.get("output", {}).get("task_status", "UNKNOWN")
|
||||
|
||||
return AITaskResult(
|
||||
task_id=task_id,
|
||||
status=task_status,
|
||||
code=result.get("code"),
|
||||
message=result.get("message"),
|
||||
output=result.get("output"),
|
||||
usage=result.get("usage")
|
||||
)
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Request error: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")
|
||||
30
api/ai_schemas.py
Normal file
30
api/ai_schemas.py
Normal file
@@ -0,0 +1,30 @@
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Optional, List, Dict, Any
|
||||
|
||||
# AI生成相关模型
|
||||
class AIGenerationRequest(BaseModel):
|
||||
prompt: str = Field(..., description="生成图片的提示词")
|
||||
negative_prompt: Optional[str] = Field(None, description="反向提示词")
|
||||
size: str = Field("1024*1024", description="图片尺寸")
|
||||
n: int = Field(1, description="生成数量", ge=1, le=4)
|
||||
model: str = Field("wan2.6-t2i", description="使用的模型")
|
||||
|
||||
class AITemplateGenerationRequest(BaseModel):
|
||||
template_id: str = Field(..., description="提示词模板ID")
|
||||
params: Dict[str, str] = Field(default_factory=dict, description="提示词参数")
|
||||
negative_prompt: Optional[str] = Field(None, description="反向提示词")
|
||||
size: str = Field("1024*1024", description="图片尺寸")
|
||||
n: int = Field(1, description="生成数量", ge=1, le=4)
|
||||
model: str = Field("wan2.6-t2i", description="使用的模型")
|
||||
|
||||
class AITaskResponse(BaseModel):
|
||||
task_id: str = Field(..., description="任务ID")
|
||||
request_id: Optional[str] = Field(None, description="请求ID")
|
||||
|
||||
class AITaskResult(BaseModel):
|
||||
task_id: str
|
||||
status: str
|
||||
code: Optional[str] = None
|
||||
message: Optional[str] = None
|
||||
output: Optional[Dict[str, Any]] = None
|
||||
usage: Optional[Dict[str, Any]] = None
|
||||
31
api/prompts.py
Normal file
31
api/prompts.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# 预设提示词模板库
|
||||
# 使用Python的format语法进行参数替换,例如 {keyword}
|
||||
|
||||
PROMPTS = {
|
||||
"flower_shop": "一间有着精致窗户的花店,漂亮的木质门,摆放着{flower_type},风格为{style}",
|
||||
"landscape": "宏伟的{season}自然风景,包含{element},光线为{lighting},高分辨率,写实风格",
|
||||
"cyberpunk_city": "赛博朋克风格的未来城市,霓虹灯闪烁,{weather}天气,街道上有{vehicle}",
|
||||
"portrait": "一张{gender}的肖像照,{expression}表情,背景是{background},专业摄影布光",
|
||||
"default": "一间有着精致窗户的花店,漂亮的木质门,摆放着花朵"
|
||||
}
|
||||
|
||||
def get_prompt(template_id: str, **kwargs) -> str:
|
||||
"""
|
||||
获取并格式化提示词
|
||||
:param template_id: 提示词模板ID
|
||||
:param kwargs: 替换参数
|
||||
:return: 格式化后的提示词
|
||||
"""
|
||||
template = PROMPTS.get(template_id)
|
||||
if not template:
|
||||
return None
|
||||
|
||||
try:
|
||||
# 使用 safe_substitute 避免参数缺失报错?
|
||||
# Python的format如果缺参数会报错,这里直接用format,让调用者负责提供完整参数
|
||||
# 或者我们可以在这里处理默认值
|
||||
return template.format(**kwargs)
|
||||
except KeyError as e:
|
||||
# 如果缺少参数,抛出异常或返回包含未替换占位符的字符串
|
||||
# 这里为了简单,如果出错,我们尽量保留原样或者报错
|
||||
raise ValueError(f"缺少提示词参数: {e}")
|
||||
Reference in New Issue
Block a user