AI生图
This commit is contained in:
2
.env
2
.env
@@ -4,7 +4,7 @@
|
|||||||
# DATABASE_URL=postgresql://luna:123luna@121.43.104.161:6432/luna
|
# DATABASE_URL=postgresql://luna:123luna@121.43.104.161:6432/luna
|
||||||
DATABASE_URL=postgresql://luna:123luna@6.6.6.66:5432/luna
|
DATABASE_URL=postgresql://luna:123luna@6.6.6.66:5432/luna
|
||||||
|
|
||||||
DASHSCOPE_API_KEY=sk-657968d48d0249099f3809f796f80a4f
|
DASHSCOPE_API_KEY=sk-a294f382488d46a1aa0d7cd8e750729b
|
||||||
|
|
||||||
# MQTT配置
|
# MQTT配置
|
||||||
MQTT_BROKER_HOST=luna-mqtt
|
MQTT_BROKER_HOST=luna-mqtt
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ MQTT_BROKER_PORT=1883
|
|||||||
MQTT_USERNAME=luna2025
|
MQTT_USERNAME=luna2025
|
||||||
MQTT_PASSWORD=123luna2021
|
MQTT_PASSWORD=123luna2021
|
||||||
|
|
||||||
DASHSCOPE_API_KEY=sk-657968d48d0249099f3809f796f80a4f
|
DASHSCOPE_API_KEY=sk-a294f382488d46a1aa0d7cd8e750729b
|
||||||
|
|
||||||
# 应用配置
|
# 应用配置
|
||||||
APP_NAME=墨水屏桌面屏幕系统
|
APP_NAME=墨水屏桌面屏幕系统
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ MQTT_BROKER_PORT=1883
|
|||||||
MQTT_USERNAME=luna2025
|
MQTT_USERNAME=luna2025
|
||||||
MQTT_PASSWORD=123luna2021
|
MQTT_PASSWORD=123luna2021
|
||||||
|
|
||||||
DASHSCOPE_API_KEY=sk-657968d48d0249099f3809f796f80a4f
|
DASHSCOPE_API_KEY=sk-a294f382488d46a1aa0d7cd*******
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
api/__pycache__/prompts.cpython-312.pyc
Normal file
BIN
api/__pycache__/prompts.cpython-312.pyc
Normal file
Binary file not shown.
80
api/ai.py
80
api/ai.py
@@ -1,7 +1,8 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
import httpx
|
||||||
from fastapi import APIRouter, HTTPException, Depends
|
from fastapi import APIRouter, HTTPException, Depends
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
import httpx
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
from config import settings
|
from config import settings
|
||||||
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest
|
from api.ai_schemas import AIGenerationRequest, AITaskResponse, AITaskResult, AITemplateGenerationRequest
|
||||||
@@ -13,16 +14,45 @@ logger = logging.getLogger(__name__)
|
|||||||
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
DASHSCOPE_API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation"
|
||||||
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
|
DASHSCOPE_TASK_URL = "https://dashscope.aliyuncs.com/api/v1/tasks"
|
||||||
|
|
||||||
|
async def _download_image(url: str) -> str:
|
||||||
|
"""
|
||||||
|
下载图片并保存到media目录
|
||||||
|
:param url: 图片URL
|
||||||
|
:return: 本地文件相对路径 (e.g., "/media/xxx.png")
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(url, timeout=30.0)
|
||||||
|
if response.status_code != 200:
|
||||||
|
logger.error(f"Failed to download image: {url}, status: {response.status_code}")
|
||||||
|
return url # 下载失败返回原URL
|
||||||
|
|
||||||
|
# 生成文件名
|
||||||
|
filename = f"{uuid.uuid4()}.png"
|
||||||
|
filepath = os.path.join(settings.media_dir, filename)
|
||||||
|
|
||||||
|
# 写入文件
|
||||||
|
with open(filepath, "wb") as f:
|
||||||
|
f.write(response.content)
|
||||||
|
|
||||||
|
# 返回可访问的URL路径
|
||||||
|
return f"/media/{filename}"
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error downloading image: {str(e)}")
|
||||||
|
return url
|
||||||
|
|
||||||
async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n: int, model: str):
|
async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n: int, model: str):
|
||||||
if not settings.dashscope_api_key:
|
if not settings.dashscope_api_key:
|
||||||
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
raise HTTPException(status_code=500, detail="DashScope API Key not configured")
|
||||||
|
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {settings.dashscope_api_key}",
|
"Authorization": f"Bearer {settings.dashscope_api_key}"
|
||||||
"X-DashScope-Async": "enable" # 确保异步任务提交
|
# "X-DashScope-Async": "enable" # 移除强制异步,因为某些账号/模型不支持
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# 构建请求体
|
# 构建请求体
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
@@ -55,7 +85,7 @@ async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n
|
|||||||
DASHSCOPE_API_URL,
|
DASHSCOPE_API_URL,
|
||||||
headers=headers,
|
headers=headers,
|
||||||
json=payload,
|
json=payload,
|
||||||
timeout=30.0
|
timeout=60.0 # 增加超时时间,同步请求可能较慢
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
@@ -68,23 +98,47 @@ async def _submit_dashscope_task(prompt: str, negative_prompt: str, size: str, n
|
|||||||
|
|
||||||
result = response.json()
|
result = response.json()
|
||||||
|
|
||||||
|
# 检查是异步任务返回还是同步结果返回
|
||||||
|
task_id = None
|
||||||
if "output" in result and "task_id" in result["output"]:
|
if "output" in result and "task_id" in result["output"]:
|
||||||
task_id = result["output"]["task_id"]
|
task_id = result["output"]["task_id"]
|
||||||
elif "task_id" in result:
|
elif "task_id" in result:
|
||||||
task_id = result["task_id"]
|
task_id = result["task_id"]
|
||||||
else:
|
|
||||||
if "output" in result and "task_id" in result["output"]:
|
|
||||||
task_id = result["output"]["task_id"]
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unexpected response structure: {result}")
|
|
||||||
task_id = result.get("output", {}).get("task_id")
|
|
||||||
|
|
||||||
|
# 如果没有task_id,检查是否有直接结果(同步模式)
|
||||||
if not task_id:
|
if not task_id:
|
||||||
raise HTTPException(status_code=500, detail="Failed to retrieve task_id from DashScope response")
|
if "output" in result and "choices" in result["output"]:
|
||||||
|
# 同步返回成功
|
||||||
|
# 提取结果以便直接返回
|
||||||
|
choices = result["output"]["choices"]
|
||||||
|
results = []
|
||||||
|
for choice in choices:
|
||||||
|
msg_content = choice.get("message", {}).get("content", [])
|
||||||
|
for item in msg_content:
|
||||||
|
if "image" in item:
|
||||||
|
# 下载图片到本地
|
||||||
|
local_url = await _download_image(item["image"])
|
||||||
|
results.append({"url": local_url, "origin_url": item["image"]})
|
||||||
|
|
||||||
|
return AITaskResponse(
|
||||||
|
request_id=result.get("request_id"),
|
||||||
|
status="SUCCEEDED",
|
||||||
|
results=results
|
||||||
|
)
|
||||||
|
|
||||||
|
# 既没有task_id也没有choices,可能是其他结构或错误
|
||||||
|
logger.warning(f"Unexpected response structure: {result}")
|
||||||
|
# 尝试从output.task_id再找一次
|
||||||
|
task_id = result.get("output", {}).get("task_id")
|
||||||
|
|
||||||
|
if not task_id:
|
||||||
|
# 依然找不到,报错
|
||||||
|
raise HTTPException(status_code=500, detail="Failed to retrieve task_id or results from DashScope response")
|
||||||
|
|
||||||
return AITaskResponse(
|
return AITaskResponse(
|
||||||
task_id=task_id,
|
task_id=task_id,
|
||||||
request_id=result.get("request_id")
|
request_id=result.get("request_id"),
|
||||||
|
status="PENDING" # 异步任务初始状态
|
||||||
)
|
)
|
||||||
|
|
||||||
except httpx.RequestError as e:
|
except httpx.RequestError as e:
|
||||||
|
|||||||
@@ -18,8 +18,11 @@ class AITemplateGenerationRequest(BaseModel):
|
|||||||
model: str = Field("wan2.6-t2i", description="使用的模型")
|
model: str = Field("wan2.6-t2i", description="使用的模型")
|
||||||
|
|
||||||
class AITaskResponse(BaseModel):
|
class AITaskResponse(BaseModel):
|
||||||
task_id: str = Field(..., description="任务ID")
|
task_id: Optional[str] = Field(None, description="任务ID (异步任务时存在)")
|
||||||
request_id: Optional[str] = Field(None, description="请求ID")
|
request_id: Optional[str] = Field(None, description="请求ID")
|
||||||
|
status: Optional[str] = Field(None, description="任务状态")
|
||||||
|
results: Optional[List[Dict[str, Any]]] = Field(None, description="同步生成的直接结果")
|
||||||
|
# results结构: [{"url": "/media/xxx.png", "origin_url": "https://..."}]
|
||||||
|
|
||||||
class AITaskResult(BaseModel):
|
class AITaskResult(BaseModel):
|
||||||
task_id: str
|
task_id: str
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ class Settings(BaseSettings):
|
|||||||
static_dir: str = "static"
|
static_dir: str = "static"
|
||||||
upload_dir: str = "static/uploads"
|
upload_dir: str = "static/uploads"
|
||||||
processed_dir: str = "static/processed"
|
processed_dir: str = "static/processed"
|
||||||
|
media_dir: str = "media" # 新增媒体文件目录
|
||||||
|
|
||||||
# 墨水屏配置
|
# 墨水屏配置
|
||||||
ink_width: int = 400
|
ink_width: int = 400
|
||||||
|
|||||||
61
main.py
61
main.py
@@ -1,6 +1,7 @@
|
|||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.openapi.utils import get_openapi
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import logging
|
import logging
|
||||||
@@ -41,6 +42,7 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
# 确保静态文件目录存在
|
# 确保静态文件目录存在
|
||||||
os.makedirs(settings.static_dir, exist_ok=True)
|
os.makedirs(settings.static_dir, exist_ok=True)
|
||||||
|
os.makedirs(settings.media_dir, exist_ok=True)
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@@ -57,19 +59,55 @@ app = FastAPI(
|
|||||||
description="用于管理墨水屏设备、内容和待办事项的API",
|
description="用于管理墨水屏设备、内容和待办事项的API",
|
||||||
version="1.0.0",
|
version="1.0.0",
|
||||||
lifespan=lifespan,
|
lifespan=lifespan,
|
||||||
openapi_components={
|
|
||||||
"securitySchemes": {
|
|
||||||
"APIKeyHeader": {
|
|
||||||
"type": "apiKey",
|
|
||||||
"in": "header",
|
|
||||||
"name": "X-API-Key",
|
|
||||||
"description": "API Key鉴权,请在下方输入正确的API Key"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
security=[{"APIKeyHeader": []}]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# 自定义OpenAPI模式以显示API Key鉴权按钮
|
||||||
|
def custom_openapi():
|
||||||
|
if app.openapi_schema:
|
||||||
|
return app.openapi_schema
|
||||||
|
|
||||||
|
openapi_schema = get_openapi(
|
||||||
|
title=app.title,
|
||||||
|
version=app.version,
|
||||||
|
description=app.description,
|
||||||
|
routes=app.routes,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加安全方案
|
||||||
|
if "components" not in openapi_schema:
|
||||||
|
openapi_schema["components"] = {}
|
||||||
|
|
||||||
|
security_scheme = {
|
||||||
|
"type": "apiKey",
|
||||||
|
"in": "header",
|
||||||
|
"name": "X-API-Key",
|
||||||
|
"description": "API Key鉴权,请在下方输入正确的API Key"
|
||||||
|
}
|
||||||
|
|
||||||
|
if "securitySchemes" not in openapi_schema["components"]:
|
||||||
|
openapi_schema["components"]["securitySchemes"] = {}
|
||||||
|
|
||||||
|
openapi_schema["components"]["securitySchemes"]["APIKeyHeader"] = security_scheme
|
||||||
|
|
||||||
|
# 添加全局安全要求
|
||||||
|
if "security" not in openapi_schema:
|
||||||
|
openapi_schema["security"] = []
|
||||||
|
|
||||||
|
# 避免重复添加
|
||||||
|
has_apikey_security = False
|
||||||
|
for security_req in openapi_schema["security"]:
|
||||||
|
if "APIKeyHeader" in security_req:
|
||||||
|
has_apikey_security = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not has_apikey_security:
|
||||||
|
openapi_schema["security"].append({"APIKeyHeader": []})
|
||||||
|
|
||||||
|
app.openapi_schema = openapi_schema
|
||||||
|
return app.openapi_schema
|
||||||
|
|
||||||
|
app.openapi = custom_openapi
|
||||||
|
|
||||||
# 添加CORS中间件
|
# 添加CORS中间件
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
@@ -90,6 +128,7 @@ app.add_middleware(SessionMiddleware, secret_key=settings.secret_key)
|
|||||||
|
|
||||||
# 挂载静态文件
|
# 挂载静态文件
|
||||||
app.mount("/static", StaticFiles(directory=settings.static_dir), name="static")
|
app.mount("/static", StaticFiles(directory=settings.static_dir), name="static")
|
||||||
|
app.mount("/media", StaticFiles(directory=settings.media_dir), name="media")
|
||||||
|
|
||||||
# 注册API路由
|
# 注册API路由
|
||||||
app.include_router(api_router, prefix="/api")
|
app.include_router(api_router, prefix="/api")
|
||||||
|
|||||||
BIN
media/d295c74a-84e7-4520-8cf1-62a763fde49e.png
Normal file
BIN
media/d295c74a-84e7-4520-8cf1-62a763fde49e.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 411 KiB |
BIN
test_image.png
BIN
test_image.png
Binary file not shown.
|
Before Width: | Height: | Size: 1.0 KiB |
Reference in New Issue
Block a user