AI图片
This commit is contained in:
2
.env
2
.env
@@ -4,6 +4,8 @@
|
|||||||
# DATABASE_URL=postgresql://luna:123luna@121.43.104.161:6432/luna
|
# DATABASE_URL=postgresql://luna:123luna@121.43.104.161:6432/luna
|
||||||
DATABASE_URL=postgresql://luna:123luna@6.6.6.66:5432/luna
|
DATABASE_URL=postgresql://luna:123luna@6.6.6.66:5432/luna
|
||||||
|
|
||||||
|
DASHSCOPE_API_KEY=sk-657968d48d0249099f3809f796f80a4f
|
||||||
|
|
||||||
# MQTT配置
|
# MQTT配置
|
||||||
MQTT_BROKER_HOST=luna-mqtt
|
MQTT_BROKER_HOST=luna-mqtt
|
||||||
MQTT_BROKER_PORT=1883
|
MQTT_BROKER_PORT=1883
|
||||||
|
|||||||
@@ -9,6 +9,8 @@ MQTT_BROKER_PORT=1883
|
|||||||
MQTT_USERNAME=luna2025
|
MQTT_USERNAME=luna2025
|
||||||
MQTT_PASSWORD=123luna2021
|
MQTT_PASSWORD=123luna2021
|
||||||
|
|
||||||
|
DASHSCOPE_API_KEY=sk-657968d48d0249099f3809f796f80a4f
|
||||||
|
|
||||||
# 应用配置
|
# 应用配置
|
||||||
APP_NAME=墨水屏桌面屏幕系统
|
APP_NAME=墨水屏桌面屏幕系统
|
||||||
DEBUG=false
|
DEBUG=false
|
||||||
|
|||||||
@@ -9,6 +9,10 @@ MQTT_BROKER_PORT=1883
|
|||||||
MQTT_USERNAME=luna2025
|
MQTT_USERNAME=luna2025
|
||||||
MQTT_PASSWORD=123luna2021
|
MQTT_PASSWORD=123luna2021
|
||||||
|
|
||||||
|
DASHSCOPE_API_KEY=sk-657968d48d0249099f3809f796f80a4f
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 应用配置
|
# 应用配置
|
||||||
APP_NAME=墨水屏桌面屏幕系统
|
APP_NAME=墨水屏桌面屏幕系统
|
||||||
DEBUG=false
|
DEBUG=false
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,9 +1,10 @@
|
|||||||
from fastapi import APIRouter
|
from fastapi import APIRouter
|
||||||
from api import devices, contents, todos
|
from api import devices, contents, todos, ai
|
||||||
|
|
||||||
api_router = APIRouter()
|
api_router = APIRouter()
|
||||||
|
|
||||||
# 注册所有路由,并添加全局安全要求
|
# 注册所有路由,并添加全局安全要求
|
||||||
api_router.include_router(devices.router, prefix="/devices")
|
api_router.include_router(devices.router, prefix="/devices")
|
||||||
api_router.include_router(contents.router, prefix="/contents")
|
api_router.include_router(contents.router, prefix="/contents")
|
||||||
api_router.include_router(todos.router, prefix="/todos")
|
api_router.include_router(todos.router, prefix="/todos")
|
||||||
|
api_router.include_router(ai.router, prefix="/ai", tags=["AI生成"])
|
||||||
Binary file not shown.
Binary file not shown.
BIN
api/__pycache__/ai.cpython-312.pyc
Normal file
BIN
api/__pycache__/ai.cpython-312.pyc
Normal file
Binary file not shown.
BIN
api/__pycache__/ai_schemas.cpython-312.pyc
Normal file
BIN
api/__pycache__/ai_schemas.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
157
api/ai.py
Normal file
157
api/ai.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
from fastapi import APIRouter, HTTPException, Depends
|
||||||
|
from typing import Dict, Any
|
||||||
|
import httpx
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from config import settings
|
||||||
|
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
||||||
|
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
|
||||||
|
|
||||||
|
@router.post("/generate", response_model=AITaskResponse, summary="提交AI图片生成任务")
|
||||||
|
async def generate_image(request: AIGenerationRequest):
|
||||||
|
"""
|
||||||
|
提交AI图片生成任务,使用阿里云DashScope服务
|
||||||
|
"""
|
||||||
|
if not settings.dashscope_api_key:
|
||||||
|
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
"Authorization": f"Bearer {settings.dashscope_api_key}",
|
||||||
|
"X-DashScope-Async": "enable" # 确保异步任务提交
|
||||||
|
}
|
||||||
|
|
||||||
|
# 构建请求体
|
||||||
|
payload = {
|
||||||
|
"model": request.model,
|
||||||
|
"input": {
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"text": request.prompt
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"parameters": {
|
||||||
|
"prompt_extend": True,
|
||||||
|
"watermark": False,
|
||||||
|
"n": request.n,
|
||||||
|
"size": request.size
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if request.negative_prompt:
|
||||||
|
payload["parameters"]["negative_prompt"] = request.negative_prompt
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
DASHSCOPE_API_URL,
|
||||||
|
headers=headers,
|
||||||
|
json=payload,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"DashScope API error: {response.text}")
|
||||||
|
try:
|
||||||
|
error_detail = response.json()
|
||||||
|
except:
|
||||||
|
error_detail = {"message": response.text}
|
||||||
|
raise HTTPException(status_code=response.status_code, detail=error_detail)
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
# 检查是否有task_id,因为如果是同步返回可能没有task_id,但在这种模型通常是异步的
|
||||||
|
# 这里的curl示例似乎是直接返回task_id
|
||||||
|
|
||||||
|
if "output" in result and "task_id" in result["output"]:
|
||||||
|
task_id = result["output"]["task_id"]
|
||||||
|
elif "task_id" in result: # 或者是这种结构
|
||||||
|
task_id = result["task_id"]
|
||||||
|
else:
|
||||||
|
# 有些情况下直接返回 output.task_id
|
||||||
|
# 根据文档 https://help.aliyun.com/zh/dashscope/developer-reference/api-details-10
|
||||||
|
# 异步提交返回结构通常包含 output.task_id
|
||||||
|
if "output" in result and "task_id" in result["output"]:
|
||||||
|
task_id = result["output"]["task_id"]
|
||||||
|
else:
|
||||||
|
# 如果是同步返回,可能直接给结果,或者结构不同
|
||||||
|
# 但 wan2.6-t2i 通常是异步任务
|
||||||
|
logger.warning(f"Unexpected response structure: {result}")
|
||||||
|
# 尝试直接取,或者抛错
|
||||||
|
# 假设是标准异步结构
|
||||||
|
task_id = result.get("output", {}).get("task_id")
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to retrieve task_id from DashScope response")
|
||||||
|
|
||||||
|
return AITaskResponse(
|
||||||
|
task_id=task_id,
|
||||||
|
request_id=result.get("request_id")
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")
|
||||||
|
|
||||||
|
@router.get("/tasks/{task_id}", response_model=AITaskResult, summary="查询AI任务结果")
|
||||||
|
async def get_task_result(task_id: str):
|
||||||
|
"""
|
||||||
|
查询AI生成任务的结果
|
||||||
|
"""
|
||||||
|
if not settings.dashscope_api_key:
|
||||||
|
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {settings.dashscope_api_key}"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
f"{DASHSCOPE_TASK_URL}/{task_id}",
|
||||||
|
headers=headers,
|
||||||
|
timeout=30.0
|
||||||
|
)
|
||||||
|
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"DashScope Task API error: {response.text}")
|
||||||
|
raise HTTPException(status_code=response.status_code, detail="Failed to fetch task status")
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
|
||||||
|
# 构建返回结果
|
||||||
|
# DashScope 任务查询返回结构:
|
||||||
|
# {
|
||||||
|
# "request_id": "...",
|
||||||
|
# "output": {
|
||||||
|
# "task_id": "...",
|
||||||
|
# "task_status": "SUCCEEDED",
|
||||||
|
# "results": [...]
|
||||||
|
# },
|
||||||
|
# "usage": ...
|
||||||
|
# }
|
||||||
|
|
||||||
|
task_status = result.get("output", {}).get("task_status", "UNKNOWN")
|
||||||
|
|
||||||
|
return AITaskResult(
|
||||||
|
task_id=task_id,
|
||||||
|
status=task_status,
|
||||||
|
code=result.get("code"),
|
||||||
|
message=result.get("message"),
|
||||||
|
output=result.get("output"),
|
||||||
|
usage=result.get("usage")
|
||||||
|
)
|
||||||
|
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.error(f"Request error: {str(e)}")
|
||||||
|
raise HTTPException(status_code=500, detail=f"Request failed: {str(e)}")
|
||||||
30
api/ai_schemas.py
Normal file
30
api/ai_schemas.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
|
||||||
|
# AI生成相关模型
|
||||||
|
class AIGenerationRequest(BaseModel):
|
||||||
|
prompt: str = Field(..., description="生成图片的提示词")
|
||||||
|
negative_prompt: Optional[str] = Field(None, description="反向提示词")
|
||||||
|
size: str = Field("1024*1024", description="图片尺寸")
|
||||||
|
n: int = Field(1, description="生成数量", ge=1, le=4)
|
||||||
|
model: str = Field("wan2.6-t2i", description="使用的模型")
|
||||||
|
|
||||||
|
class AITemplateGenerationRequest(BaseModel):
|
||||||
|
template_id: str = Field(..., description="提示词模板ID")
|
||||||
|
params: Dict[str, str] = Field(default_factory=dict, description="提示词参数")
|
||||||
|
negative_prompt: Optional[str] = Field(None, description="反向提示词")
|
||||||
|
size: str = Field("1024*1024", description="图片尺寸")
|
||||||
|
n: int = Field(1, description="生成数量", ge=1, le=4)
|
||||||
|
model: str = Field("wan2.6-t2i", description="使用的模型")
|
||||||
|
|
||||||
|
class AITaskResponse(BaseModel):
|
||||||
|
task_id: str = Field(..., description="任务ID")
|
||||||
|
request_id: Optional[str] = Field(None, description="请求ID")
|
||||||
|
|
||||||
|
class AITaskResult(BaseModel):
|
||||||
|
task_id: str
|
||||||
|
status: str
|
||||||
|
code: Optional[str] = None
|
||||||
|
message: Optional[str] = None
|
||||||
|
output: Optional[Dict[str, Any]] = None
|
||||||
|
usage: Optional[Dict[str, Any]] = None
|
||||||
31
api/prompts.py
Normal file
31
api/prompts.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
# 预设提示词模板库
|
||||||
|
# 使用Python的format语法进行参数替换,例如 {keyword}
|
||||||
|
|
||||||
|
PROMPTS = {
|
||||||
|
"flower_shop": "一间有着精致窗户的花店,漂亮的木质门,摆放着{flower_type},风格为{style}",
|
||||||
|
"landscape": "宏伟的{season}自然风景,包含{element},光线为{lighting},高分辨率,写实风格",
|
||||||
|
"cyberpunk_city": "赛博朋克风格的未来城市,霓虹灯闪烁,{weather}天气,街道上有{vehicle}",
|
||||||
|
"portrait": "一张{gender}的肖像照,{expression}表情,背景是{background},专业摄影布光",
|
||||||
|
"default": "一间有着精致窗户的花店,漂亮的木质门,摆放着花朵"
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_prompt(template_id: str, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
获取并格式化提示词
|
||||||
|
:param template_id: 提示词模板ID
|
||||||
|
:param kwargs: 替换参数
|
||||||
|
:return: 格式化后的提示词
|
||||||
|
"""
|
||||||
|
template = PROMPTS.get(template_id)
|
||||||
|
if not template:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 使用 safe_substitute 避免参数缺失报错?
|
||||||
|
# Python的format如果缺参数会报错,这里直接用format,让调用者负责提供完整参数
|
||||||
|
# 或者我们可以在这里处理默认值
|
||||||
|
return template.format(**kwargs)
|
||||||
|
except KeyError as e:
|
||||||
|
# 如果缺少参数,抛出异常或返回包含未替换占位符的字符串
|
||||||
|
# 这里为了简单,如果出错,我们尽量保留原样或者报错
|
||||||
|
raise ValueError(f"缺少提示词参数: {e}")
|
||||||
@@ -34,6 +34,9 @@ class Settings(BaseSettings):
|
|||||||
# 管理员配置
|
# 管理员配置
|
||||||
admin_username: str = "admin"
|
admin_username: str = "admin"
|
||||||
admin_password: str = "123456"
|
admin_password: str = "123456"
|
||||||
|
|
||||||
|
# DashScope配置
|
||||||
|
dashscope_api_key: Optional[str] = None
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
env_file = ".env"
|
env_file = ".env"
|
||||||
|
|||||||
Reference in New Issue
Block a user