Compare commits

...

7 Commits

Author SHA1 Message Date
jeremygan2021
c6b2a378d1 test
Some checks are pending
Deploy WebSocket Server / deploy (push) Waiting to run
2026-03-20 18:09:53 +08:00
jeremygan2021
88bb27569a mode bug
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 4s
2026-03-20 18:04:44 +08:00
jeremygan2021
5b91e90d45 Update doubao model to seedream-4.0
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 4s
2026-03-20 17:53:23 +08:00
jeremygan2021
c9550f8a0d t
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 21s
2026-03-05 22:19:48 +08:00
jeremygan2021
e728cd1075 t
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 18s
2026-03-05 22:07:02 +08:00
jeremygan2021
0774ba5c9e t
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 20s
2026-03-05 22:03:02 +08:00
jeremygan2021
2392d0d705 t 2026-03-05 22:02:01 +08:00
5 changed files with 957 additions and 324 deletions

41
main.py
View File

@@ -119,22 +119,21 @@ def process_message(msg, display, image_state, image_data_list, printer_uart=Non
if isinstance(msg, (bytes, bytearray)):
if image_state == IMAGE_STATE_RECEIVING:
try:
if len(image_data_list) < 3:
if len(image_data_list) < 2:
# 异常情况,重置
return IMAGE_STATE_IDLE, None
width = image_data_list[0]
height = image_data_list[1]
current_offset = image_data_list[2]
img_size = image_data_list[0]
current_offset = image_data_list[1]
# Stream directly to display
if display and display.tft:
x = (240 - width) // 2
y = (240 - height) // 2
display.show_image_chunk(x, y, width, height, msg, current_offset)
x = (240 - img_size) // 2
y = (240 - img_size) // 2
display.show_image_chunk(x, y, img_size, img_size, msg, current_offset)
# Update offset
image_data_list[2] += len(msg)
image_data_list[1] += len(msg)
except Exception as e:
print(f"Stream image error: {e}")
@@ -196,28 +195,22 @@ def process_message(msg, display, image_state, image_data_list, printer_uart=Non
try:
parts = msg.split(":")
size = int(parts[1])
img_size = int(parts[2]) if len(parts) > 2 else 64
model_name = parts[3] if len(parts) > 3 else "Unknown Model"
width = 64
height = 64
print(f"Image start, size: {size}, img_size: {img_size}")
convert_img.print_model_info(model_name)
if len(parts) >= 4:
width = int(parts[2])
height = int(parts[3])
elif len(parts) == 3:
width = int(parts[2])
height = int(parts[2]) # assume square
print(f"Image start, size: {size}, dim: {width}x{height}")
image_data_list.clear()
image_data_list.append(width) # index 0
image_data_list.append(height) # index 1
image_data_list.append(0) # index 2: offset
image_data_list.append(img_size) # Store metadata at index 0
image_data_list.append(0) # Store current received bytes offset at index 1
# Prepare display for streaming
if display and display.tft:
# Clear screen area where image will be
# optional, but good practice if new image is smaller
pass
# Calculate position
x = (240 - img_size) // 2
y = (240 - img_size) // 2
# Pre-set window (this will be done in first chunk call)
return IMAGE_STATE_RECEIVING, None
except Exception as e:

View File

@@ -8,5 +8,6 @@ services:
- "8811:8000"
volumes:
- ./output_images:/app/output_images
- ./media:/app/media
- ./.env:/app/.env
restart: unless-stopped

View File

@@ -10,6 +10,7 @@ from dashscope import ImageSynthesis, Generation
load_dotenv()
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
class ImageGenerator:
def __init__(self, provider="dashscope", model=None):
self.provider = provider
@@ -19,7 +20,7 @@ class ImageGenerator:
if provider == "doubao":
self.api_key = os.getenv("volcengine_API_KEY")
if not self.model:
self.model = "doubao-seedream-5-0-260128" # Default model from user input
self.model = "doubao-seedream-4.0"
elif provider == "dashscope":
self.api_key = os.getenv("DASHSCOPE_API_KEY")
if not self.model:
@@ -35,14 +36,15 @@ class ImageGenerator:
system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。
关键要求:
1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。
2. 画面必须清晰、线条粗壮,适合低分辨率热敏打印机打印。
2. 画面必须清晰、线条粗壮,适合低分辨率热敏打印机打印,用来生成标签贴纸
3. 绝对不要有复杂的阴影、渐变、黑白线条描述。
4. 背景必须是纯白 (White background)。
5. 提示词内容请使用中文描述,因为绘图模型对中文生成要更准确。
5. 提示词内容请使用英文描述,因为绘图模型对英文生成要更准确。
6. 尺寸比例遵循宽48mm:高30mm (约 1.6:1)。
7. 直接输出优化后的提示词,不要包含任何解释。
如果用户要求输入文字,则用```把文字包裹起来,文字是中文
"房子的旁边有一个小孩,黑白线稿画作,卡通形象, 文字:```中国人```在下方。"
如果用户要求输入文字,则用```把文字包裹起来,如果用户有中文文字,则用中文包裹起来。所有文字都是中文,描述都是英文。
example:
"A house with a child on the side, black and white line art, cartoon style, text:```中国人``` below."
"""
try:
@@ -52,17 +54,20 @@ class ImageGenerator:
# Currently using Qwen-Turbo for all providers for prompt optimization
# You can also decouple this if needed
response = Generation.call(
model='qwen3.5-plus',
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
model="qwen-plus",
prompt=f"{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:",
max_tokens=200,
temperature=0.8
temperature=0.8,
)
if response.status_code == 200:
if hasattr(response, 'output') and response.output and \
hasattr(response.output, 'choices') and response.output.choices and \
len(response.output.choices) > 0:
if (
hasattr(response, "output")
and response.output
and hasattr(response.output, "choices")
and response.output.choices
and len(response.output.choices) > 0
):
optimized = response.output.choices[0].message.content.strip()
print(f"Optimized prompt: {optimized}")
@@ -70,7 +75,11 @@ class ImageGenerator:
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
return optimized
elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'):
elif (
hasattr(response, "output")
and response.output
and hasattr(response.output, "text")
):
optimized = response.output.text.strip()
print(f"Optimized prompt (direct text): {optimized}")
if progress_callback:
@@ -82,7 +91,9 @@ class ImageGenerator:
progress_callback(0, "提示词优化响应格式错误")
return asr_text
else:
print(f"Prompt optimization failed: {response.code} - {response.message}")
print(
f"Prompt optimization failed: {response.code} - {response.message}"
)
if progress_callback:
progress_callback(0, f"提示词优化失败: {response.message}")
return asr_text
@@ -110,9 +121,7 @@ class ImageGenerator:
try:
response = ImageSynthesis.call(
model=self.model,
prompt=prompt,
size='1280*720'
model=self.model, prompt=prompt, size="1280*720"
)
if response.status_code == 200:
@@ -120,32 +129,36 @@ class ImageGenerator:
print("Error: response.output is None")
return None
task_status = response.output.get('task_status')
task_status = response.output.get("task_status")
if task_status == 'PENDING' or task_status == 'RUNNING':
if task_status == "PENDING" or task_status == "RUNNING":
print("Waiting for image generation to complete...")
if progress_callback:
progress_callback(45, "AI正在生成图片中...")
task_id = response.output.get('task_id')
task_id = response.output.get("task_id")
max_wait = 120
waited = 0
while waited < max_wait:
time.sleep(2)
waited += 2
task_result = ImageSynthesis.fetch(task_id)
if task_result.output.task_status == 'SUCCEEDED':
if task_result.output.task_status == "SUCCEEDED":
response.output = task_result.output
break
elif task_result.output.task_status == 'FAILED':
error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error'
elif task_result.output.task_status == "FAILED":
error_msg = (
task_result.output.message
if hasattr(task_result.output, "message")
else "Unknown error"
)
print(f"Image generation failed: {error_msg}")
if progress_callback:
progress_callback(35, f"图片生成失败: {error_msg}")
return None
if response.output.get('task_status') == 'SUCCEEDED':
image_url = response.output['results'][0]['url']
if response.output.get("task_status") == "SUCCEEDED":
image_url = response.output["results"][0]["url"]
print(f"Image generated, url: {image_url}")
return image_url
else:
@@ -170,7 +183,7 @@ class ImageGenerator:
url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
"Authorization": f"Bearer {self.api_key}",
}
data = {
"model": self.model,
@@ -181,7 +194,7 @@ class ImageGenerator:
# User's curl says "2K". I will stick to it or maybe check docs.
# Actually for thermal printer, we don't need 2K. But let's follow user example.
"stream": False,
"watermark": True
"watermark": True,
}
try:
@@ -207,7 +220,9 @@ class ImageGenerator:
print(f"Unexpected response format: {result}")
return None
else:
print(f"Doubao API failed with status {response.status_code}: {response.text}")
print(
f"Doubao API failed with status {response.status_code}: {response.text}"
)
if progress_callback:
progress_callback(35, f"图片生成失败: {response.status_code}")
return None

View File

@@ -1,4 +1,7 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
import uvicorn
import asyncio
import os
@@ -15,6 +18,7 @@ from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionRes
# from dashscope import Generation
import sys
# import os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import convert_img
@@ -28,7 +32,169 @@ dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
# provider="doubao" or "dashscope"
image_generator = ImageGenerator(provider="doubao")
app = FastAPI()
@asynccontextmanager
async def lifespan(app: FastAPI):
cleanup_old_media()
print("Media cleanup completed on startup")
yield
app = FastAPI(lifespan=lifespan)
app.mount("/media", StaticFiles(directory="media"), name="media")
# Admin API endpoints
@app.get("/admin")
async def admin_page():
with open(
os.path.join(os.path.dirname(__file__), "templates", "admin.html"), "r"
) as f:
return HTMLResponse(content=f.read())
@app.get("/api/admin/status")
async def get_admin_status():
return {"provider": image_generator.provider, "model": image_generator.model}
@app.post("/api/admin/switch")
async def switch_provider(request: dict):
global image_generator
provider = request.get("provider")
if provider not in ["doubao", "dashscope"]:
return {"success": False, "message": "Invalid provider"}
old_provider = image_generator.provider
old_model = image_generator.model
image_generator = ImageGenerator(provider=provider)
return {
"success": True,
"message": f"Switched from {old_provider}/{old_model} to {provider}/{image_generator.model}",
}
@app.post("/api/admin/model")
async def set_model(request: dict):
global image_generator
provider = request.get("provider")
model = request.get("model")
if not provider or not model:
return {"success": False, "message": "Provider and model required"}
if provider not in ["doubao", "dashscope"]:
return {"success": False, "message": "Invalid provider"}
image_generator = ImageGenerator(provider=provider, model=model)
return {"success": True, "message": f"Model set to {provider}/{model}"}
@app.post("/api/admin/test-generate")
async def test_generate(request: dict):
prompt = request.get("prompt")
if not prompt:
return {"success": False, "message": "Prompt is required"}
def progress_callback(progress, message):
print(f"Test generation progress: {progress}% - {message}")
image_url = image_generator.generate_image(prompt, progress_callback)
if image_url:
local_path = save_to_media(image_url)
return {
"success": True,
"image_url": image_url,
"local_path": local_path,
"message": "Image generated successfully",
}
else:
return {"success": False, "message": "Image generation failed"}
def save_to_media(image_url):
import urllib.request
timestamp = time.strftime("%Y%m%d_%H%M%S")
filename = f"image_{timestamp}.png"
filepath = os.path.join(MEDIA_FOLDER, filename)
try:
urllib.request.urlretrieve(image_url, filepath)
return filepath
except Exception as e:
print(f"Error saving to media: {e}")
return None
@app.get("/api/admin/images")
async def list_images():
images = []
if os.path.exists(MEDIA_FOLDER):
for f in sorted(os.listdir(MEDIA_FOLDER), reverse=True):
if f.endswith((".png", ".jpg", ".jpeg", ".gif", ".webp")):
filepath = os.path.join(MEDIA_FOLDER, f)
stat = os.stat(filepath)
images.append(
{
"name": f,
"path": filepath,
"size": stat.st_size,
"created": stat.st_ctime,
"url": f"/media/{f}",
}
)
return {"images": images}
@app.delete("/api/admin/images/{filename}")
async def delete_image(filename: str):
safe_name = os.path.basename(filename)
filepath = os.path.join(MEDIA_FOLDER, safe_name)
if os.path.exists(filepath):
os.remove(filepath)
return {"success": True, "message": f"Deleted {safe_name}"}
return {"success": False, "message": "File not found"}
@app.post("/api/admin/auto-delete")
async def set_auto_delete(request: dict):
global auto_delete_hours, auto_delete_enabled
hours = request.get("hours")
enabled = request.get("enabled")
if hours is not None:
auto_delete_hours = int(hours)
if enabled is not None:
auto_delete_enabled = bool(enabled)
return {
"success": True,
"message": f"Auto-delete set to {auto_delete_hours}h, enabled: {auto_delete_enabled}",
}
@app.get("/api/admin/auto-delete")
async def get_auto_delete():
return {"hours": auto_delete_hours, "enabled": auto_delete_enabled}
def cleanup_old_media():
if not auto_delete_enabled:
return
if not os.path.exists(MEDIA_FOLDER):
return
now = time.time()
for f in os.listdir(MEDIA_FOLDER):
if f.endswith((".png", ".jpg", ".jpeg", ".gif", ".webp")):
filepath = os.path.join(MEDIA_FOLDER, f)
age_hours = (now - os.stat(filepath).st_ctime) / 3600
if age_hours > auto_delete_hours:
print(f"Auto-deleting old image: {f}")
os.remove(filepath)
# 字体文件配置
FONT_FILE = "GB2312-16.bin"
@@ -43,6 +209,7 @@ font_cache = {}
font_md5 = {}
font_data_buffer = None
def calculate_md5(filepath):
"""计算文件的MD5哈希值"""
if not os.path.exists(filepath):
@@ -61,9 +228,9 @@ def get_font_data(unicode_val):
try:
char = chr(unicode_val)
gb_bytes = char.encode('gb2312')
gb_bytes = char.encode("gb2312")
if len(gb_bytes) == 2:
code = struct.unpack('>H', gb_bytes)[0]
code = struct.unpack(">H", gb_bytes)[0]
area = (code >> 8) - 0xA0
index = (code & 0xFF) - 0xA0
@@ -72,7 +239,7 @@ def get_font_data(unicode_val):
if font_data_buffer:
if offset + 32 <= len(font_data_buffer):
font_data = font_data_buffer[offset:offset+32]
font_data = font_data_buffer[offset : offset + 32]
font_cache[unicode_val] = font_data
return font_data
else:
@@ -123,6 +290,7 @@ def init_font_cache():
get_font_data(unicode_val)
print(f"Preloaded {len(font_cache)} high-frequency characters")
# 启动时初始化字体缓存
init_font_cache()
@@ -134,27 +302,37 @@ VOLUME_GAIN = 10.0
GENERATED_IMAGE_FILE = "generated_image.png"
GENERATED_THUMB_FILE = "generated_thumb.bin"
OUTPUT_DIR = "output_images"
MEDIA_FOLDER = "media"
if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)
if not os.path.exists(MEDIA_FOLDER):
os.makedirs(MEDIA_FOLDER)
image_counter = 0
auto_delete_hours = 24
auto_delete_enabled = True
def get_output_path():
global image_counter
image_counter += 1
timestamp = time.strftime("%Y%m%d_%H%M%S")
return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png")
THUMB_SIZE = 240
# 字体请求队列(用于重试机制)
font_request_queue = {}
FONT_RETRY_MAX = 3
# 图片生成任务管理
class ImageGenerationTask:
"""图片生成任务管理类"""
def __init__(self, task_id: str, asr_text: str, websocket: WebSocket):
self.task_id = task_id
self.asr_text = asr_text
@@ -165,6 +343,7 @@ class ImageGenerationTask:
self.result = None
self.error = None
# 存储活跃的图片生成任务
active_tasks = {}
task_counter = 0
@@ -209,8 +388,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
task.message = message
# 通过队列在主循环中发送消息
asyncio.run_coroutine_threadsafe(
progress_callback_async(progress, message),
loop
progress_callback_async(progress, message), loop
)
try:
@@ -220,12 +398,16 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await asyncio.sleep(0.2)
# 同步调用优化函数
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback)
optimized_prompt = await asyncio.to_thread(
optimize_prompt, asr_text, progress_callback
)
# 确保返回有效的提示词
if not optimized_prompt:
optimized_prompt = asr_text
print(f"Warning: optimize_prompt returned None, using original text: {asr_text}")
print(
f"Warning: optimize_prompt returned None, using original text: {asr_text}"
)
await websocket.send_text(f"PROMPT:{optimized_prompt}")
task.optimized_prompt = optimized_prompt
@@ -235,17 +417,9 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await asyncio.sleep(0.2)
# 同步调用图片生成函数
gen_result = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)
image_path = None
width = 0
height = 0
if gen_result:
if isinstance(gen_result, tuple):
image_path, width, height = gen_result
else:
image_path = gen_result
image_path = await asyncio.to_thread(
generate_image, optimized_prompt, progress_callback
)
task.result = image_path
@@ -254,7 +428,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await websocket.send_text("STATUS:COMPLETE:图片生成完成")
await asyncio.sleep(0.2)
await send_image_to_client(websocket, image_path, width, height)
await send_image_to_client(websocket, image_path)
else:
task.status = "failed"
task.error = "图片生成失败"
@@ -277,23 +451,25 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
return task
async def send_image_to_client(websocket: WebSocket, image_path: str, width: int = 0, height: int = 0):
async def send_image_to_client(websocket: WebSocket, image_path: str):
"""发送图片数据到客户端"""
with open(image_path, 'rb') as f:
with open(image_path, "rb") as f:
image_data = f.read()
print(f"Sending image to ESP32, size: {len(image_data)} bytes, dim: {width}x{height}")
print(f"Sending image to ESP32, size: {len(image_data)} bytes")
# Send start marker
if width > 0 and height > 0:
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{width}:{height}")
else:
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
model_name = f"{image_generator.provider}"
if image_generator.model:
model_name += f" {image_generator.model}"
await websocket.send_text(
f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}:{model_name}"
)
# Send binary data directly
chunk_size = 512 # Decreased chunk size for ESP32 memory stability
for i in range(0, len(image_data), chunk_size):
chunk = image_data[i:i+chunk_size]
chunk = image_data[i : i + chunk_size]
await websocket.send_bytes(chunk)
# Send end marker
@@ -318,7 +494,8 @@ async def send_font_batch_with_retry(websocket, code_list, retry_count=0):
if font_data:
import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
hex_data = binascii.hexlify(font_data).decode("utf-8")
response = f"FONT_DATA:{code_str}:{hex_data}"
await websocket.send_text(response)
success_codes.add(unicode_val)
@@ -332,9 +509,9 @@ async def send_font_batch_with_retry(websocket, code_list, retry_count=0):
if failed_codes and retry_count < FONT_RETRY_MAX:
req_key = f"retry_{retry_count}_{time.time()}"
font_request_queue[req_key] = {
'codes': failed_codes,
'retry': retry_count + 1,
'timestamp': time.time()
"codes": failed_codes,
"retry": retry_count + 1,
"timestamp": time.time(),
}
return len(success_codes), failed_codes
@@ -351,11 +528,13 @@ async def send_font_with_fragment(websocket, unicode_val):
chunk_size = FONT_CHUNK_SIZE
for i in range(0, total_size, chunk_size):
chunk = font_data[i:i+chunk_size]
chunk = font_data[i : i + chunk_size]
seq_num = i // chunk_size
# 构造二进制消息头: 2字节序列号 + 2字节总片数 + 数据
header = struct.pack('<HH', seq_num, (total_size + chunk_size - 1) // chunk_size)
header = struct.pack(
"<HH", seq_num, (total_size + chunk_size - 1) // chunk_size
)
payload = header + chunk
await websocket.send_bytes(payload)
@@ -384,7 +563,9 @@ async def handle_font_request(websocket, message_type, data):
code_list = codes_str.split(",")
print(f"Batch Font Request for {len(code_list)} chars")
success_count, failed = await send_font_batch_with_retry(websocket, code_list)
success_count, failed = await send_font_batch_with_retry(
websocket, code_list
)
print(f"Font batch: {success_count} success, {len(failed)} failed")
# 发送完成标记
@@ -409,7 +590,9 @@ async def handle_font_request(websocket, message_type, data):
print(f"Error sending font fragment: {e}")
return
elif message_type.startswith("GET_FONT_UNICODE:") or message_type.startswith("GET_FONT:"):
elif message_type.startswith("GET_FONT_UNICODE:") or message_type.startswith(
"GET_FONT:"
):
# 单个字体请求(兼容旧版)
try:
is_unicode = message_type.startswith("GET_FONT_UNICODE:")
@@ -439,12 +622,14 @@ async def handle_font_request(websocket, message_type, data):
if font_data:
import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
hex_data = binascii.hexlify(font_data).decode("utf-8")
response = f"FONT_DATA:{code_str}:{hex_data}"
await websocket.send_text(response)
except Exception as e:
print(f"Error handling font request: {e}")
class MyRecognitionCallback(RecognitionCallback):
def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop):
self.websocket = websocket
@@ -469,8 +654,7 @@ class MyRecognitionCallback(RecognitionCallback):
try:
if self.loop.is_running():
asyncio.run_coroutine_threadsafe(
self.websocket.send_text(f"ASR:{self.final_text}"),
self.loop
self.websocket.send_text(f"ASR:{self.final_text}"), self.loop
)
except Exception as e:
print(f"Failed to send final ASR result: {e}")
@@ -478,10 +662,9 @@ class MyRecognitionCallback(RecognitionCallback):
def on_error(self, result: RecognitionResult) -> None:
print(f"ASR Error: {result}")
def on_event(self, result: RecognitionResult) -> None:
if result.get_sentence():
text = result.get_sentence()['text']
text = result.get_sentence()["text"]
# 获取当前句子的结束状态
# 注意DashScope Python SDK 的 Result 结构可能需要根据版本调整
@@ -491,8 +674,8 @@ class MyRecognitionCallback(RecognitionCallback):
if self.sentence_list:
last_sentence = self.sentence_list[-1]
# 去掉句尾标点进行比较,因为流式结果可能标点不稳定
last_clean = last_sentence.rstrip('。,?!')
text_clean = text.rstrip('。,?!')
last_clean = last_sentence.rstrip("。,?!")
text_clean = text.rstrip("。,?!")
if text_clean.startswith(last_clean):
# 更新当前句子
@@ -526,23 +709,26 @@ class MyRecognitionCallback(RecognitionCallback):
# except Exception as e:
# print(f"Failed to send ASR result to client: {e}")
def process_chunk_32_to_16(chunk_bytes, gain=1.0):
processed_chunk = bytearray()
# Iterate 4 bytes at a time
for i in range(0, len(chunk_bytes), 4):
if i+3 < len(chunk_bytes):
if i + 3 < len(chunk_bytes):
# 取 chunk[i+2] 和 chunk[i+3] 组成 16-bit signed int
sample = struct.unpack_from('<h', chunk_bytes, i+2)[0]
sample = struct.unpack_from("<h", chunk_bytes, i + 2)[0]
# 放大音量
sample = int(sample * gain)
# 限幅 (Clamping) 防止溢出爆音
if sample > 32767: sample = 32767
elif sample < -32768: sample = -32768
if sample > 32767:
sample = 32767
elif sample < -32768:
sample = -32768
# 重新打包为 16-bit little-endian
processed_chunk.extend(struct.pack('<h', sample))
processed_chunk.extend(struct.pack("<h", sample))
return processed_chunk
@@ -573,8 +759,12 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
if retry_count < max_retries:
print(f"Retrying... ({retry_count + 1}/{max_retries})")
if progress_callback:
progress_callback(35, f"生成失败,正在重试 ({retry_count + 1}/{max_retries})...")
return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
progress_callback(
35, f"生成失败,正在重试 ({retry_count + 1}/{max_retries})..."
)
return generate_image(
prompt, progress_callback, retry_count + 1, max_retries
)
else:
return None
@@ -584,6 +774,7 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
progress_callback(70, "正在下载生成的图片...")
import urllib.request
try:
urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE)
print(f"Image saved to {GENERATED_IMAGE_FILE}")
@@ -596,6 +787,7 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
# Save to output dir
output_path = get_output_path()
import shutil
shutil.copy(GENERATED_IMAGE_FILE, output_path)
print(f"Image also saved to {output_path}")
@@ -605,18 +797,33 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
# Resize and convert to RGB565 (Reuse existing logic)
try:
from PIL import Image
img = Image.open(GENERATED_IMAGE_FILE)
# 缩小到 fit THUMB_SIZE x THUMB_SIZE (保持比例)
img.thumbnail((THUMB_SIZE, THUMB_SIZE), Image.Resampling.LANCZOS)
width, height = img.size
# 保持比例缩放
# Calculate aspect ratio
ratio = min(THUMB_SIZE / img.width, THUMB_SIZE / img.height)
new_width = int(img.width * ratio)
new_height = int(img.height * ratio)
resized_img = img.resize((new_width, new_height), Image.LANCZOS)
# Create black background
final_img = Image.new("RGB", (THUMB_SIZE, THUMB_SIZE), (0, 0, 0))
# Paste centered
x_offset = (THUMB_SIZE - new_width) // 2
y_offset = (THUMB_SIZE - new_height) // 2
final_img.paste(resized_img, (x_offset, y_offset))
img = final_img
# 转换为RGB565格式的原始数据
# 每个像素2字节 (R5 G6 B5)
rgb565_data = bytearray()
for y in range(height):
for x in range(width):
for y in range(THUMB_SIZE):
for x in range(THUMB_SIZE):
r, g, b = img.getpixel((x, y))[:3]
# 转换为RGB565
@@ -627,36 +834,41 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
# Pack as Big Endian (>H) which is standard for SPI displays
# RGB565: Red(5) Green(6) Blue(5)
rgb565 = (r5 << 11) | (g6 << 5) | b5
rgb565_data.extend(struct.pack('>H', rgb565))
rgb565_data.extend(struct.pack(">H", rgb565))
# 保存为.bin文件
with open(GENERATED_THUMB_FILE, 'wb') as f:
with open(GENERATED_THUMB_FILE, "wb") as f:
f.write(rgb565_data)
print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes, dim: {width}x{height}")
print(
f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes"
)
if progress_callback:
progress_callback(100, "图片生成完成!")
return GENERATED_THUMB_FILE, width, height
return GENERATED_THUMB_FILE
except ImportError:
print("PIL not available, sending original image")
if progress_callback:
progress_callback(100, "图片生成完成!(原始格式)")
return GENERATED_IMAGE_FILE, 0, 0
return GENERATED_IMAGE_FILE
except Exception as e:
print(f"Error processing image: {e}")
if progress_callback:
progress_callback(80, f"图片处理出错: {str(e)}")
return GENERATED_IMAGE_FILE, 0, 0
return GENERATED_IMAGE_FILE
except Exception as e:
print(f"Error in generate_image: {e}")
if retry_count < max_retries:
return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
return generate_image(
prompt, progress_callback, retry_count + 1, max_retries
)
return None
@app.websocket("/ws/audio")
async def websocket_endpoint(websocket: WebSocket):
global audio_buffer
@@ -674,7 +886,10 @@ async def websocket_endpoint(websocket: WebSocket):
try:
message = await websocket.receive()
except RuntimeError as e:
if "Cannot call \"receive\" once a disconnect message has been received" in str(e):
if (
'Cannot call "receive" once a disconnect message has been received'
in str(e)
):
print("Client disconnected (RuntimeError caught)")
break
raise e
@@ -692,10 +907,10 @@ async def websocket_endpoint(websocket: WebSocket):
try:
callback = MyRecognitionCallback(websocket, loop)
recognition = Recognition(
model='paraformer-realtime-v2',
format='pcm',
model="paraformer-realtime-v2",
format="pcm",
sample_rate=16000,
callback=callback
callback=callback,
)
recognition.start()
print("DashScope ASR started")
@@ -719,7 +934,9 @@ async def websocket_endpoint(websocket: WebSocket):
# 使用实时处理过的音频数据
processed_audio = processed_buffer
print(f"Processed audio size: {len(processed_audio)} bytes (Gain: {VOLUME_GAIN}x)")
print(
f"Processed audio size: {len(processed_audio)} bytes (Gain: {VOLUME_GAIN}x)"
)
# 获取ASR识别结果
asr_text = ""
@@ -737,11 +954,15 @@ async def websocket_endpoint(websocket: WebSocket):
cmd = [
"ffmpeg",
"-y", # 覆盖输出文件
"-f", "s16le", # 输入格式: signed 16-bit little endian
"-ar", "16000", # 输入采样率
"-ac", "1", # 输入声道数
"-i", RECORDING_RAW_FILE,
RECORDING_MP3_FILE
"-f",
"s16le", # 输入格式: signed 16-bit little endian
"-ar",
"16000", # 输入采样率
"-ac",
"1", # 输入声道数
"-i",
RECORDING_RAW_FILE,
RECORDING_MP3_FILE,
]
print(f"Running command: {' '.join(cmd)}")
@@ -749,18 +970,24 @@ async def websocket_endpoint(websocket: WebSocket):
process = await asyncio.create_subprocess_exec(
*cmd,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode, cmd, output=stdout, stderr=stderr)
raise subprocess.CalledProcessError(
process.returncode, cmd, output=stdout, stderr=stderr
)
print(f"Saved MP3 to {RECORDING_MP3_FILE}")
except subprocess.CalledProcessError as e:
print(f"Error converting to MP3: {e}")
# stderr might be bytes
error_msg = e.stderr.decode() if isinstance(e.stderr, bytes) else str(e.stderr)
error_msg = (
e.stderr.decode()
if isinstance(e.stderr, bytes)
else str(e.stderr)
)
print(f"FFmpeg stderr: {error_msg}")
except FileNotFoundError:
print("Error: ffmpeg not found. Please install ffmpeg.")
@@ -775,9 +1002,15 @@ async def websocket_endpoint(websocket: WebSocket):
try:
unique_chars = set(asr_text)
code_list = [str(ord(c)) for c in unique_chars]
print(f"Sending font data for {len(code_list)} characters...")
success_count, failed = await send_font_batch_with_retry(websocket, code_list)
print(f"Font data sent: {success_count} success, {len(failed)} failed")
print(
f"Sending font data for {len(code_list)} characters..."
)
success_count, failed = await send_font_batch_with_retry(
websocket, code_list
)
print(
f"Font data sent: {success_count} success, {len(failed)} failed"
)
except Exception as e:
print(f"Error sending font data: {e}")
@@ -798,7 +1031,9 @@ async def websocket_endpoint(websocket: WebSocket):
prompt_text = text.split(":", 1)[1]
print(f"Received GENERATE_IMAGE request: {prompt_text}")
if prompt_text:
asyncio.create_task(start_async_image_generation(websocket, prompt_text))
asyncio.create_task(
start_async_image_generation(websocket, prompt_text)
)
else:
await websocket.send_text("STATUS:ERROR:提示词为空")
@@ -807,15 +1042,19 @@ async def websocket_endpoint(websocket: WebSocket):
if os.path.exists(GENERATED_IMAGE_FILE):
try:
# Use convert_img logic to get TSPL commands
tspl_data = convert_img.image_to_tspl_commands(GENERATED_IMAGE_FILE)
tspl_data = convert_img.image_to_tspl_commands(
GENERATED_IMAGE_FILE
)
if tspl_data:
print(f"Sending printer data: {len(tspl_data)} bytes")
await websocket.send_text(f"PRINTER_DATA_START:{len(tspl_data)}")
await websocket.send_text(
f"PRINTER_DATA_START:{len(tspl_data)}"
)
# Send in chunks
chunk_size = 512
for i in range(0, len(tspl_data), chunk_size):
chunk = tspl_data[i:i+chunk_size]
chunk = tspl_data[i : i + chunk_size]
await websocket.send_bytes(chunk)
# Small delay to prevent overwhelming ESP32 buffer
await asyncio.sleep(0.01)
@@ -826,7 +1065,9 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.send_text("STATUS:ERROR:图片转换失败")
except Exception as e:
print(f"Error converting image for printer: {e}")
await websocket.send_text(f"STATUS:ERROR:打印出错: {str(e)}")
await websocket.send_text(
f"STATUS:ERROR:打印出错: {str(e)}"
)
else:
await websocket.send_text("STATUS:ERROR:没有可打印的图片")
@@ -834,20 +1075,37 @@ async def websocket_endpoint(websocket: WebSocket):
task_id = text.split(":", 1)[1].strip()
if task_id in active_tasks:
task = active_tasks[task_id]
await websocket.send_text(f"TASK_STATUS:{task_id}:{task.status}:{task.progress}:{task.message}")
await websocket.send_text(
f"TASK_STATUS:{task_id}:{task.status}:{task.progress}:{task.message}"
)
else:
await websocket.send_text(f"TASK_STATUS:{task_id}:unknown:0:任务不存在或已完成")
await websocket.send_text(
f"TASK_STATUS:{task_id}:unknown:0:任务不存在或已完成"
)
elif text.startswith("GET_FONTS_BATCH:") or text.startswith("GET_FONT") or text == "GET_FONT_MD5" or text == "GET_HIGH_FREQ":
elif (
text.startswith("GET_FONTS_BATCH:")
or text.startswith("GET_FONT")
or text == "GET_FONT_MD5"
or text == "GET_HIGH_FREQ"
):
# 使用新的统一字体处理函数
try:
if text.startswith("GET_FONTS_BATCH:"):
await handle_font_request(websocket, text, text.split(":", 1)[1])
await handle_font_request(
websocket, text, text.split(":", 1)[1]
)
elif text.startswith("GET_FONT_FRAGMENT:"):
await handle_font_request(websocket, text, text.split(":", 1)[1])
elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"):
await handle_font_request(
websocket, text, text.split(":", 1)[1]
)
elif text.startswith("GET_FONT_UNICODE:") or text.startswith(
"GET_FONT:"
):
parts = text.split(":", 1)
await handle_font_request(websocket, parts[0], parts[1] if len(parts) > 1 else "")
await handle_font_request(
websocket, parts[0], parts[1] if len(parts) > 1 else ""
)
else:
await handle_font_request(websocket, text, "")
except Exception as e:
@@ -884,6 +1142,7 @@ async def websocket_endpoint(websocket: WebSocket):
except:
pass
if __name__ == "__main__":
# Check API Key
if not dashscope.api_key:
@@ -893,6 +1152,7 @@ if __name__ == "__main__":
# 获取本机IP方便ESP32连接
import socket
hostname = socket.gethostname()
local_ip = socket.gethostbyname(hostname)
print(f"Server running on ws://{local_ip}:8000/ws/audio")

View File

@@ -0,0 +1,364 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Image Generator Admin</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: #1a1a2e; color: #eee; min-height: 100vh; padding: 20px; }
.container { max-width: 1100px; margin: 0 auto; }
h1 { text-align: center; margin-bottom: 30px; color: #00d4ff; }
.card { background: #16213e; border-radius: 12px; padding: 24px; margin-bottom: 20px; box-shadow: 0 4px 20px rgba(0,0,0,0.3); }
.card h2 { margin-bottom: 16px; color: #00d4ff; font-size: 18px; }
.current-status { background: #0f3460; padding: 16px; border-radius: 8px; margin-bottom: 16px; }
.current-status .label { color: #888; font-size: 14px; }
.current-status .value { color: #00d4ff; font-size: 20px; font-weight: bold; margin-top: 4px; }
.form-group { margin-bottom: 16px; }
.form-group label { display: block; margin-bottom: 8px; color: #ccc; }
select, input[type="number"] { width: 100%; padding: 12px; border-radius: 8px; border: 1px solid #333; background: #0f3460; color: #fff; font-size: 16px; }
select:focus, input:focus { outline: none; border-color: #00d4ff; }
input[type="checkbox"] { width: 20px; height: 20px; margin-right: 8px; }
.btn { display: inline-block; padding: 12px 24px; border-radius: 8px; border: none; cursor: pointer; font-size: 16px; font-weight: bold; transition: all 0.3s; }
.btn-primary { background: #00d4ff; color: #1a1a2e; }
.btn-primary:hover { background: #00b8e6; }
.btn-danger { background: #e74c3c; color: #fff; }
.btn-danger:hover { background: #c0392b; }
.btn-small { padding: 6px 12px; font-size: 14px; }
.model-list { margin-top: 16px; }
.model-item { background: #0f3460; padding: 12px 16px; border-radius: 8px; margin-bottom: 8px; display: flex; justify-content: space-between; align-items: center; }
.model-item.active { border: 2px solid #00d4ff; }
.model-item .name { font-weight: bold; }
.model-item .provider { color: #888; font-size: 14px; }
.test-section { margin-top: 20px; }
.test-input { width: 100%; padding: 12px; border-radius: 8px; border: 1px solid #333; background: #0f3460; color: #fff; font-size: 14px; resize: vertical; min-height: 80px; margin-bottom: 12px; }
.test-input:focus { outline: none; border-color: #00d4ff; }
.message { padding: 12px; border-radius: 8px; margin-top: 12px; display: none; }
.message.success { background: #27ae60; display: block; }
.message.error { background: #e74c3c; display: block; }
.loading { text-align: center; padding: 20px; display: none; }
.loading.show { display: block; }
.spinner { border: 3px solid #333; border-top: 3px solid #00d4ff; border-radius: 50%; width: 30px; height: 30px; animation: spin 1s linear infinite; margin: 0 auto; }
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
.auto-delete-settings { display: flex; gap: 16px; align-items: center; flex-wrap: wrap; }
.auto-delete-settings label { display: flex; align-items: center; color: #ccc; }
.auto-delete-settings input[type="number"] { width: 80px; }
.gallery { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 16px; margin-top: 16px; }
.gallery-item { background: #0f3460; border-radius: 8px; overflow: hidden; position: relative; }
.gallery-item img { width: 100%; height: 180px; object-fit: cover; display: block; }
.gallery-item .info { padding: 12px; }
.gallery-item .filename { font-size: 12px; color: #888; word-break: break-all; }
.gallery-item .size { font-size: 12px; color: #666; margin-top: 4px; }
.gallery-item .delete-btn { position: absolute; top: 8px; right: 8px; background: rgba(231,76,60,0.9); color: white; border: none; border-radius: 50%; width: 28px; height: 28px; cursor: pointer; font-size: 16px; line-height: 28px; }
.gallery-item .delete-btn:hover { background: #c0392b; }
.empty-gallery { text-align: center; padding: 40px; color: #666; }
.flex-row { display: flex; gap: 12px; align-items: center; flex-wrap: wrap; }
.tab-nav { display: flex; gap: 4px; margin-bottom: 20px; background: #0f3460; border-radius: 8px; padding: 4px; }
.tab-nav button { flex: 1; padding: 12px; border: none; background: transparent; color: #888; cursor: pointer; border-radius: 6px; font-size: 16px; transition: all 0.3s; }
.tab-nav button.active { background: #00d4ff; color: #1a1a2e; font-weight: bold; }
.tab-content { display: none; }
.tab-content.active { display: block; }
</style>
</head>
<body>
<div class="container">
<h1>AI Image Generator Admin</h1>
<div class="tab-nav">
<button class="active" onclick="showTab('settings')">设置</button>
<button onclick="showTab('gallery')">图片库</button>
</div>
<div id="tab-settings" class="tab-content active">
<div class="card">
<h2>当前状态</h2>
<div class="current-status">
<div class="label">当前 Provider</div>
<div class="value" id="currentProvider">加载中...</div>
<div class="label" style="margin-top: 12px;">当前模型</div>
<div class="value" id="currentModel">加载中...</div>
</div>
</div>
<div class="card">
<h2>切换 Provider</h2>
<div class="form-group">
<select id="providerSelect">
<option value="doubao">豆包 (Doubao)</option>
<option value="dashscope">阿里云 (DashScope)</option>
</select>
</div>
<button class="btn btn-primary" onclick="switchProvider()">切换 Provider</button>
</div>
<div class="card">
<h2>豆包模型</h2>
<div class="model-list">
<div class="model-item" data-provider="doubao" data-model="doubao-seedream-4.0">
<div>
<div class="name">doubao-seedream-4.0</div>
<div class="provider">豆包</div>
</div>
<button class="btn btn-primary" onclick="setModel('doubao', 'doubao-seedream-4.0')">使用</button>
</div>
<div class="model-item" data-provider="doubao" data-model="doubao-seedream-5-0-260128">
<div>
<div class="name">doubao-seedream-5-0-260128</div>
<div class="provider">豆包</div>
</div>
<button class="btn btn-primary" onclick="setModel('doubao', 'doubao-seedream-5-0-260128')">使用</button>
</div>
</div>
</div>
<div class="card">
<h2>阿里云模型 (DashScope)</h2>
<div class="model-list">
<div class="model-item" data-provider="dashscope" data-model="wanx2.0-t2i-turbo">
<div>
<div class="name">wanx2.0-t2i-turbo</div>
<div class="provider">阿里云</div>
</div>
<button class="btn btn-primary" onclick="setModel('dashscope', 'wanx2.0-t2i-turbo')">使用</button>
</div>
<div class="model-item" data-provider="dashscope" data-model="qwen-image-plus">
<div>
<div class="name">qwen-image-plus</div>
<div class="provider">阿里云</div>
</div>
<button class="btn btn-primary" onclick="setModel('dashscope', 'qwen-image-plus')">使用</button>
</div>
<div class="model-item" data-provider="dashscope" data-model="qwen-image-v1">
<div>
<div class="name">qwen-image-v1</div>
<div class="provider">阿里云</div>
</div>
<button class="btn btn-primary" onclick="setModel('dashscope', 'qwen-image-v1')">使用</button>
</div>
</div>
</div>
<div class="card">
<h2>测试图片生成</h2>
<textarea class="test-input" id="testPrompt" placeholder="输入提示词...">A cute cat, black and white line art, cartoon style</textarea>
<button class="btn btn-primary" onclick="testGenerate()">生成图片</button>
<div class="loading" id="loading">
<div class="spinner"></div>
<p style="margin-top: 10px;">生成中...</p>
</div>
<div class="message" id="message"></div>
<div id="resultArea" style="margin-top: 16px; display: none;">
<img id="resultImage" style="max-width: 100%; max-height: 300px; border-radius: 8px;">
</div>
</div>
</div>
<div id="tab-gallery" class="tab-content">
<div class="card">
<h2>图片库</h2>
<div class="auto-delete-settings">
<label><input type="checkbox" id="autoDeleteEnabled" onchange="updateAutoDelete()"> 自动删除</label>
<label><input type="number" id="autoDeleteHours" min="1" max="168" value="24" onchange="updateAutoDelete()"> 小时后删除</label>
<button class="btn btn-primary btn-small" onclick="loadGallery()">刷新</button>
<button class="btn btn-danger btn-small" onclick="deleteAllImages()">删除全部</button>
</div>
<div class="gallery" id="gallery"></div>
</div>
</div>
</div>
<script>
function showTab(tab) {
    // Switch the visible tab panel and highlight the matching nav button.
    // Note: the previous version read the implicit global `event`
    // (window.event), which is non-standard and unreliable across browsers;
    // instead, find the nav button whose inline onclick targets this tab.
    document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
    document.querySelectorAll('.tab-nav button').forEach(btn => {
        const handler = btn.getAttribute('onclick') || '';
        btn.classList.toggle('active', handler.includes("'" + tab + "'"));
    });
    document.getElementById('tab-' + tab).classList.add('active');
    // Lazy-load the gallery only when its tab is opened.
    if (tab === 'gallery') loadGallery();
}
async function loadStatus() {
    // Fetch the active provider/model from the backend and mirror it in the UI.
    try {
        const response = await fetch('/api/admin/status');
        const status = await response.json();
        const byId = id => document.getElementById(id);
        byId('currentProvider').textContent = status.provider;
        byId('currentModel').textContent = status.model;
        byId('providerSelect').value = status.provider;
        updateActiveModel(status.provider, status.model);
    } catch (err) {
        console.error('Failed to load status:', err);
    }
}
async function loadAutoDelete() {
    // Populate the auto-delete checkbox and hours input from server settings.
    try {
        const settings = await (await fetch('/api/admin/auto-delete')).json();
        document.getElementById('autoDeleteEnabled').checked = settings.enabled;
        document.getElementById('autoDeleteHours').value = settings.hours;
    } catch (err) {
        console.error('Failed to load auto-delete settings:', err);
    }
}
async function updateAutoDelete() {
    // Push the current auto-delete settings (enabled flag + hours) to the server.
    const payload = {
        enabled: document.getElementById('autoDeleteEnabled').checked,
        hours: parseInt(document.getElementById('autoDeleteHours').value)
    };
    try {
        await fetch('/api/admin/auto-delete', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify(payload)
        });
    } catch (err) {
        console.error('Failed to update auto-delete:', err);
    }
}
function updateActiveModel(provider, model) {
    // Highlight the model card matching (provider, model), disable its button,
    // and reset every other card to the selectable state.
    document.querySelectorAll('.model-item').forEach(item => {
        const isCurrent =
            item.dataset.provider === provider && item.dataset.model === model;
        const button = item.querySelector('button');
        item.classList.remove('active');
        if (isCurrent) {
            item.classList.add('active');
            button.textContent = '使用中';
            button.disabled = true;
        } else {
            button.textContent = '使用';
            button.disabled = false;
        }
    });
}
async function switchProvider() {
    // Ask the backend to switch the active provider; refresh status on success.
    const provider = document.getElementById('providerSelect').value;
    try {
        const response = await fetch('/api/admin/switch', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ provider })
        });
        const result = await response.json();
        showMessage(result.message, result.success);
        if (result.success) loadStatus();
    } catch (err) {
        showMessage('切换失败: ' + err.message, false);
    }
}
async function setModel(provider, model) {
    // Persist the chosen (provider, model) pair on the backend; refresh on success.
    try {
        const response = await fetch('/api/admin/model', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ provider, model })
        });
        const result = await response.json();
        showMessage(result.message, result.success);
        if (result.success) loadStatus();
    } catch (err) {
        showMessage('设置失败: ' + err.message, false);
    }
}
async function testGenerate() {
    // Run a one-off image generation from the prompt textarea and preview the result.
    const prompt = document.getElementById('testPrompt').value;
    if (!prompt) return;
    const loading = document.getElementById('loading');
    loading.classList.add('show');
    document.getElementById('message').style.display = 'none';
    document.getElementById('resultArea').style.display = 'none';
    try {
        const response = await fetch('/api/admin/test-generate', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ prompt })
        });
        const result = await response.json();
        if (result.success && result.image_url) {
            document.getElementById('resultImage').src = result.image_url;
            document.getElementById('resultArea').style.display = 'block';
            showMessage('生成成功', true);
        } else {
            showMessage(result.message || '生成失败', false);
        }
    } catch (err) {
        showMessage('生成失败: ' + err.message, false);
    } finally {
        // Always clear the spinner, whether generation succeeded or failed.
        loading.classList.remove('show');
    }
}
async function loadGallery() {
    // Fetch the image list and rebuild the gallery grid.
    // Builds DOM nodes with textContent instead of interpolating filenames
    // into an innerHTML string: the old template put server-supplied names
    // inside single-quoted inline onclick handlers, so a name containing a
    // quote, "<" or "&" would break the markup and was an XSS vector.
    try {
        const res = await fetch('/api/admin/images');
        const data = await res.json();
        const gallery = document.getElementById('gallery');
        gallery.innerHTML = '';
        if (!data.images || data.images.length === 0) {
            const empty = document.createElement('div');
            empty.className = 'empty-gallery';
            empty.textContent = '暂无图片';
            gallery.appendChild(empty);
            return;
        }
        for (const img of data.images) {
            const item = document.createElement('div');
            item.className = 'gallery-item';

            const delBtn = document.createElement('button');
            delBtn.className = 'delete-btn';
            delBtn.textContent = '×';
            delBtn.addEventListener('click', () => deleteImage(img.name));

            const picture = document.createElement('img');
            picture.src = img.url;
            picture.alt = img.name;
            picture.addEventListener('click', () => window.open(img.url, '_blank'));

            const info = document.createElement('div');
            info.className = 'info';
            const filename = document.createElement('div');
            filename.className = 'filename';
            filename.textContent = img.name;
            const size = document.createElement('div');
            size.className = 'size';
            size.textContent = formatSize(img.size);
            info.appendChild(filename);
            info.appendChild(size);

            item.appendChild(delBtn);
            item.appendChild(picture);
            item.appendChild(info);
            gallery.appendChild(item);
        }
    } catch (e) {
        console.error('Failed to load gallery:', e);
    }
}
async function deleteImage(filename) {
    // Delete one image after user confirmation, then refresh the grid.
    if (!confirm('确定要删除这张图片吗?')) return;
    try {
        const url = `/api/admin/images/${encodeURIComponent(filename)}`;
        const result = await (await fetch(url, { method: 'DELETE' })).json();
        if (result.success) {
            loadGallery();
        } else {
            alert(result.message);
        }
    } catch (err) {
        alert('删除失败: ' + err.message);
    }
}
async function deleteAllImages() {
    // Delete every image (with confirmation): fetch the current list,
    // remove each entry one by one, then refresh the gallery.
    if (!confirm('确定要删除所有图片吗?此操作不可恢复!')) return;
    try {
        const res = await fetch('/api/admin/images');
        const data = await res.json();
        // Guard against a missing `images` field so `for...of` never throws.
        for (const img of data.images || []) {
            await fetch(`/api/admin/images/${encodeURIComponent(img.name)}`, { method: 'DELETE' });
        }
        loadGallery();
    } catch (e) {
        alert('删除失败: ' + e.message);
    }
}
function formatSize(bytes) {
    // Human-readable byte count: plain bytes below 1 KiB, one-decimal
    // KB below 1 MiB, otherwise one-decimal MB.
    const KB = 1024;
    const MB = KB * 1024;
    if (bytes >= MB) return (bytes / MB).toFixed(1) + ' MB';
    if (bytes >= KB) return (bytes / KB).toFixed(1) + ' KB';
    return bytes + ' B';
}
function showMessage(msg, success) {
    // Show a status banner; the success/error class controls color and
    // makes the (default-hidden) .message element visible.
    const banner = document.getElementById('message');
    banner.textContent = msg;
    banner.className = success ? 'message success' : 'message error';
}
// Initial page load: populate provider/model status and auto-delete settings.
loadStatus();
loadAutoDelete();
</script>
</body>
</html>