mode bug
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 4s

This commit is contained in:
jeremygan2021
2026-03-20 18:04:44 +08:00
parent 5b91e90d45
commit 88bb27569a
2 changed files with 842 additions and 222 deletions

View File

@@ -1,4 +1,7 @@
from fastapi import FastAPI, WebSocket, WebSocketDisconnect from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from contextlib import asynccontextmanager
import uvicorn import uvicorn
import asyncio import asyncio
import os import os
@@ -15,6 +18,7 @@ from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionRes
# from dashscope import Generation # from dashscope import Generation
import sys import sys
# import os # import os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import convert_img import convert_img
@@ -28,7 +32,169 @@ dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
# provider="doubao" or "dashscope" # provider="doubao" or "dashscope"
image_generator = ImageGenerator(provider="doubao") image_generator = ImageGenerator(provider="doubao")
app = FastAPI()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan handler: run the media cleanup once at startup.

    Everything before the ``yield`` executes while the application starts;
    nothing is executed after it, so shutdown needs no teardown here.
    """
    # Startup phase: purge expired media before serving requests.
    cleanup_old_media()
    print("Media cleanup completed on startup")

    # Hand control over to the running application.
    yield
app = FastAPI(lifespan=lifespan)

# StaticFiles verifies that its directory exists when it is constructed, and
# the media folder is only created further down in this module. Create it here
# first so a fresh deployment does not fail at import time.
os.makedirs("media", exist_ok=True)
app.mount("/media", StaticFiles(directory="media"), name="media")
# Admin API endpoints
@app.get("/admin")
async def admin_page():
    """Serve the admin dashboard.

    Reads ``templates/admin.html`` from the directory containing this module
    and returns its contents as an HTML response.
    """
    template_path = os.path.join(
        os.path.dirname(__file__), "templates", "admin.html"
    )
    # Decode explicitly as UTF-8 instead of relying on the platform default
    # encoding: the admin template is expected to be UTF-8 (it declares
    # <meta charset="UTF-8">) and contains non-ASCII text.
    with open(template_path, "r", encoding="utf-8") as f:
        return HTMLResponse(content=f.read())
@app.get("/api/admin/status")
async def get_admin_status():
    """Report the provider and model of the currently active image generator."""
    generator = image_generator
    status = {
        "provider": generator.provider,
        "model": generator.model,
    }
    return status
@app.post("/api/admin/switch")
async def switch_provider(request: dict):
    """Replace the global image generator with one for the requested provider.

    The request body must contain a ``provider`` key whose value is either
    ``"doubao"`` or ``"dashscope"``. Returns a dict with a ``success`` flag
    and a human-readable ``message``.
    """
    global image_generator

    requested = request.get("provider")
    if requested in ["doubao", "dashscope"]:
        # Remember the outgoing configuration for the confirmation message.
        previous_provider = image_generator.provider
        previous_model = image_generator.model

        image_generator = ImageGenerator(provider=requested)
        summary = f"Switched from {previous_provider}/{previous_model} to {requested}/{image_generator.model}"
        return {"success": True, "message": summary}

    return {"success": False, "message": "Invalid provider"}
@app.post("/api/admin/model")
async def set_model(request: dict):
    """Replace the global image generator with an explicit provider/model pair.

    The request body must contain non-empty ``provider`` and ``model`` values,
    and the provider must be ``"doubao"`` or ``"dashscope"``. Returns a dict
    with a ``success`` flag and a ``message``.
    """
    global image_generator

    provider_name = request.get("provider")
    model_name = request.get("model")

    # Both fields are mandatory.
    if not (provider_name and model_name):
        return {"success": False, "message": "Provider and model required"}

    if provider_name not in ["doubao", "dashscope"]:
        return {"success": False, "message": "Invalid provider"}

    image_generator = ImageGenerator(provider=provider_name, model=model_name)
    confirmation = f"Model set to {provider_name}/{model_name}"
    return {"success": True, "message": confirmation}
@app.post("/api/admin/test-generate")
async def test_generate(request: dict):
    """Generate a test image from the admin page and store a copy in the media folder.

    The request body must contain a non-empty ``prompt``. On success the
    response carries the remote ``image_url`` and the ``local_path`` returned
    by ``save_to_media`` (which is ``None`` if the download failed). On failure
    the response has ``success`` set to ``False`` and an explanatory message.
    """
    prompt = request.get("prompt")
    if not prompt:
        return {"success": False, "message": "Prompt is required"}

    def progress_callback(progress, message):
        print(f"Test generation progress: {progress}% - {message}")

    # Generation and download are synchronous calls. Run them in worker
    # threads so this endpoint does not stall the event loop, which also
    # serves the WebSocket clients.
    image_url = await asyncio.to_thread(
        image_generator.generate_image, prompt, progress_callback
    )
    if not image_url:
        return {"success": False, "message": "Image generation failed"}

    local_path = await asyncio.to_thread(save_to_media, image_url)
    return {
        "success": True,
        "image_url": image_url,
        "local_path": local_path,
        "message": "Image generated successfully",
    }
def save_to_media(image_url):
    """Download ``image_url`` into the media folder under a timestamped name.

    The file is stored as ``image_<YYYYmmdd_HHMMSS>.png`` inside
    ``MEDIA_FOLDER``. Returns the local path on success; if the download
    raises, the error is printed and ``None`` is returned.
    """
    from urllib.request import urlretrieve

    stamp = time.strftime("%Y%m%d_%H%M%S")
    destination = os.path.join(MEDIA_FOLDER, f"image_{stamp}.png")
    try:
        urlretrieve(image_url, destination)
    except Exception as exc:
        print(f"Error saving to media: {exc}")
        return None
    else:
        return destination
@app.get("/api/admin/images")
async def list_images():
    """List the image files in the media folder, in reverse name order.

    Returns ``{"images": [...]}`` where each entry holds the file name, its
    path, its size in bytes, its ``st_ctime`` timestamp and the URL under
    ``/media``. The list is empty when the media folder does not exist.
    """
    image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".webp")

    if not os.path.exists(MEDIA_FOLDER):
        return {"images": []}

    names = os.listdir(MEDIA_FOLDER)
    names.sort(reverse=True)

    entries = []
    for name in names:
        # Skip anything that is not an image by extension.
        if not name.endswith(image_suffixes):
            continue
        full_path = os.path.join(MEDIA_FOLDER, name)
        info = os.stat(full_path)
        entry = {
            "name": name,
            "path": full_path,
            "size": info.st_size,
            "created": info.st_ctime,
            "url": f"/media/{name}",
        }
        entries.append(entry)

    return {"images": entries}
@app.delete("/api/admin/images/{filename}")
async def delete_image(filename: str):
    """Delete a single file from the media folder.

    Only the base name of ``filename`` is used, so the target always lies
    directly inside ``MEDIA_FOLDER``. Returns a dict with a ``success`` flag
    and a ``message``; a missing file yields ``success: False``.
    """
    safe_name = os.path.basename(filename)
    filepath = os.path.join(MEDIA_FOLDER, safe_name)

    # Require a regular file: os.path.exists() is also true for directories
    # (e.g. a name of ".."), for which os.remove() would raise.
    if not os.path.isfile(filepath):
        return {"success": False, "message": "File not found"}

    try:
        os.remove(filepath)
    except FileNotFoundError:
        # The file disappeared between the check and the removal.
        return {"success": False, "message": "File not found"}
    return {"success": True, "message": f"Deleted {safe_name}"}
@app.post("/api/admin/auto-delete")
async def set_auto_delete(request: dict):
    """Update the auto-delete settings for stored media.

    The request body may contain ``hours`` (age threshold, a non-negative
    integer) and/or ``enabled`` (a boolean; the strings "1", "true", "yes"
    and "on" are treated as true, any other string as false). Keys that are
    absent or ``None`` leave the corresponding setting unchanged.

    Returns a dict with a ``success`` flag and a ``message``. Invalid input
    is reported with ``success: False`` and neither setting is modified.
    """
    global auto_delete_hours, auto_delete_enabled

    hours = request.get("hours")
    enabled = request.get("enabled")

    # Validate everything before touching the globals so a bad request cannot
    # leave the settings half-updated.
    new_hours = auto_delete_hours
    if hours is not None:
        try:
            new_hours = int(hours)
        except (TypeError, ValueError):
            return {"success": False, "message": "Hours must be an integer"}
        # A negative threshold would make the cleanup delete every image.
        if new_hours < 0:
            return {"success": False, "message": "Hours must not be negative"}

    new_enabled = auto_delete_enabled
    if enabled is not None:
        if isinstance(enabled, str):
            # bool("false") is True, so interpret strings explicitly.
            new_enabled = enabled.strip().lower() in ("1", "true", "yes", "on")
        else:
            new_enabled = bool(enabled)

    auto_delete_hours = new_hours
    auto_delete_enabled = new_enabled
    return {
        "success": True,
        "message": f"Auto-delete set to {auto_delete_hours}h, enabled: {auto_delete_enabled}",
    }
@app.get("/api/admin/auto-delete")
async def get_auto_delete():
    """Return the current auto-delete threshold (hours) and enabled flag."""
    settings = {
        "hours": auto_delete_hours,
        "enabled": auto_delete_enabled,
    }
    return settings
def cleanup_old_media(folder=None, max_age_hours=None, enabled=None):
    """Delete image files whose age exceeds the auto-delete threshold.

    Args:
        folder: Directory to scan. Defaults to ``MEDIA_FOLDER``.
        max_age_hours: Files older than this many hours are removed.
            Defaults to ``auto_delete_hours``.
        enabled: Whether the cleanup may run at all. Defaults to
            ``auto_delete_enabled``.

    Only files ending in .png, .jpg, .jpeg, .gif or .webp are considered.
    Age is measured from the file's ``st_ctime``. A file that cannot be
    inspected or removed is reported and skipped, so one bad entry does not
    abort the sweep (or application startup, where this is called).
    """
    if enabled is None:
        enabled = auto_delete_enabled
    if not enabled:
        return

    if folder is None:
        folder = MEDIA_FOLDER
    if max_age_hours is None:
        max_age_hours = auto_delete_hours

    # isdir rather than exists: a plain file at this path cannot be listed.
    if not os.path.isdir(folder):
        return

    image_suffixes = (".png", ".jpg", ".jpeg", ".gif", ".webp")
    now = time.time()
    for f in os.listdir(folder):
        if not f.endswith(image_suffixes):
            continue
        filepath = os.path.join(folder, f)
        try:
            age_hours = (now - os.stat(filepath).st_ctime) / 3600
            if age_hours > max_age_hours:
                print(f"Auto-deleting old image: {f}")
                os.remove(filepath)
        except OSError as e:
            # E.g. the file vanished or is not removable; keep going.
            print(f"Could not clean up {f}: {e}")
# 字体文件配置 # 字体文件配置
FONT_FILE = "GB2312-16.bin" FONT_FILE = "GB2312-16.bin"
@@ -43,6 +209,7 @@ font_cache = {}
font_md5 = {} font_md5 = {}
font_data_buffer = None font_data_buffer = None
def calculate_md5(filepath): def calculate_md5(filepath):
"""计算文件的MD5哈希值""" """计算文件的MD5哈希值"""
if not os.path.exists(filepath): if not os.path.exists(filepath):
@@ -61,9 +228,9 @@ def get_font_data(unicode_val):
try: try:
char = chr(unicode_val) char = chr(unicode_val)
gb_bytes = char.encode('gb2312') gb_bytes = char.encode("gb2312")
if len(gb_bytes) == 2: if len(gb_bytes) == 2:
code = struct.unpack('>H', gb_bytes)[0] code = struct.unpack(">H", gb_bytes)[0]
area = (code >> 8) - 0xA0 area = (code >> 8) - 0xA0
index = (code & 0xFF) - 0xA0 index = (code & 0xFF) - 0xA0
@@ -123,6 +290,7 @@ def init_font_cache():
get_font_data(unicode_val) get_font_data(unicode_val)
print(f"Preloaded {len(font_cache)} high-frequency characters") print(f"Preloaded {len(font_cache)} high-frequency characters")
# 启动时初始化字体缓存 # 启动时初始化字体缓存
init_font_cache() init_font_cache()
@@ -134,27 +302,37 @@ VOLUME_GAIN = 10.0
GENERATED_IMAGE_FILE = "generated_image.png" GENERATED_IMAGE_FILE = "generated_image.png"
GENERATED_THUMB_FILE = "generated_thumb.bin" GENERATED_THUMB_FILE = "generated_thumb.bin"
OUTPUT_DIR = "output_images" OUTPUT_DIR = "output_images"
MEDIA_FOLDER = "media"
if not os.path.exists(OUTPUT_DIR): if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR) os.makedirs(OUTPUT_DIR)
if not os.path.exists(MEDIA_FOLDER):
os.makedirs(MEDIA_FOLDER)
image_counter = 0 image_counter = 0
auto_delete_hours = 24
auto_delete_enabled = True
def get_output_path(): def get_output_path():
global image_counter global image_counter
image_counter += 1 image_counter += 1
timestamp = time.strftime("%Y%m%d_%H%M%S") timestamp = time.strftime("%Y%m%d_%H%M%S")
return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png") return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png")
THUMB_SIZE = 240 THUMB_SIZE = 240
# 字体请求队列(用于重试机制) # 字体请求队列(用于重试机制)
font_request_queue = {} font_request_queue = {}
FONT_RETRY_MAX = 3 FONT_RETRY_MAX = 3
# 图片生成任务管理 # 图片生成任务管理
class ImageGenerationTask: class ImageGenerationTask:
"""图片生成任务管理类""" """图片生成任务管理类"""
def __init__(self, task_id: str, asr_text: str, websocket: WebSocket): def __init__(self, task_id: str, asr_text: str, websocket: WebSocket):
self.task_id = task_id self.task_id = task_id
self.asr_text = asr_text self.asr_text = asr_text
@@ -165,6 +343,7 @@ class ImageGenerationTask:
self.result = None self.result = None
self.error = None self.error = None
# 存储活跃的图片生成任务 # 存储活跃的图片生成任务
active_tasks = {} active_tasks = {}
task_counter = 0 task_counter = 0
@@ -209,8 +388,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
task.message = message task.message = message
# 通过队列在主循环中发送消息 # 通过队列在主循环中发送消息
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
progress_callback_async(progress, message), progress_callback_async(progress, message), loop
loop
) )
try: try:
@@ -220,12 +398,16 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
# 同步调用优化函数 # 同步调用优化函数
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback) optimized_prompt = await asyncio.to_thread(
optimize_prompt, asr_text, progress_callback
)
# 确保返回有效的提示词 # 确保返回有效的提示词
if not optimized_prompt: if not optimized_prompt:
optimized_prompt = asr_text optimized_prompt = asr_text
print(f"Warning: optimize_prompt returned None, using original text: {asr_text}") print(
f"Warning: optimize_prompt returned None, using original text: {asr_text}"
)
await websocket.send_text(f"PROMPT:{optimized_prompt}") await websocket.send_text(f"PROMPT:{optimized_prompt}")
task.optimized_prompt = optimized_prompt task.optimized_prompt = optimized_prompt
@@ -235,7 +417,9 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
# 同步调用图片生成函数 # 同步调用图片生成函数
image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback) image_path = await asyncio.to_thread(
generate_image, optimized_prompt, progress_callback
)
task.result = image_path task.result = image_path
@@ -269,7 +453,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
async def send_image_to_client(websocket: WebSocket, image_path: str): async def send_image_to_client(websocket: WebSocket, image_path: str):
"""发送图片数据到客户端""" """发送图片数据到客户端"""
with open(image_path, 'rb') as f: with open(image_path, "rb") as f:
image_data = f.read() image_data = f.read()
print(f"Sending image to ESP32, size: {len(image_data)} bytes") print(f"Sending image to ESP32, size: {len(image_data)} bytes")
@@ -278,7 +462,9 @@ async def send_image_to_client(websocket: WebSocket, image_path: str):
model_name = f"{image_generator.provider}" model_name = f"{image_generator.provider}"
if image_generator.model: if image_generator.model:
model_name += f" {image_generator.model}" model_name += f" {image_generator.model}"
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}:{model_name}") await websocket.send_text(
f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}:{model_name}"
)
# Send binary data directly # Send binary data directly
chunk_size = 512 # Decreased chunk size for ESP32 memory stability chunk_size = 512 # Decreased chunk size for ESP32 memory stability
@@ -308,7 +494,8 @@ async def send_font_batch_with_retry(websocket, code_list, retry_count=0):
if font_data: if font_data:
import binascii import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
hex_data = binascii.hexlify(font_data).decode("utf-8")
response = f"FONT_DATA:{code_str}:{hex_data}" response = f"FONT_DATA:{code_str}:{hex_data}"
await websocket.send_text(response) await websocket.send_text(response)
success_codes.add(unicode_val) success_codes.add(unicode_val)
@@ -322,9 +509,9 @@ async def send_font_batch_with_retry(websocket, code_list, retry_count=0):
if failed_codes and retry_count < FONT_RETRY_MAX: if failed_codes and retry_count < FONT_RETRY_MAX:
req_key = f"retry_{retry_count}_{time.time()}" req_key = f"retry_{retry_count}_{time.time()}"
font_request_queue[req_key] = { font_request_queue[req_key] = {
'codes': failed_codes, "codes": failed_codes,
'retry': retry_count + 1, "retry": retry_count + 1,
'timestamp': time.time() "timestamp": time.time(),
} }
return len(success_codes), failed_codes return len(success_codes), failed_codes
@@ -345,7 +532,9 @@ async def send_font_with_fragment(websocket, unicode_val):
seq_num = i // chunk_size seq_num = i // chunk_size
# 构造二进制消息头: 2字节序列号 + 2字节总片数 + 数据 # 构造二进制消息头: 2字节序列号 + 2字节总片数 + 数据
header = struct.pack('<HH', seq_num, (total_size + chunk_size - 1) // chunk_size) header = struct.pack(
"<HH", seq_num, (total_size + chunk_size - 1) // chunk_size
)
payload = header + chunk payload = header + chunk
await websocket.send_bytes(payload) await websocket.send_bytes(payload)
@@ -374,7 +563,9 @@ async def handle_font_request(websocket, message_type, data):
code_list = codes_str.split(",") code_list = codes_str.split(",")
print(f"Batch Font Request for {len(code_list)} chars") print(f"Batch Font Request for {len(code_list)} chars")
success_count, failed = await send_font_batch_with_retry(websocket, code_list) success_count, failed = await send_font_batch_with_retry(
websocket, code_list
)
print(f"Font batch: {success_count} success, {len(failed)} failed") print(f"Font batch: {success_count} success, {len(failed)} failed")
# 发送完成标记 # 发送完成标记
@@ -399,7 +590,9 @@ async def handle_font_request(websocket, message_type, data):
print(f"Error sending font fragment: {e}") print(f"Error sending font fragment: {e}")
return return
elif message_type.startswith("GET_FONT_UNICODE:") or message_type.startswith("GET_FONT:"): elif message_type.startswith("GET_FONT_UNICODE:") or message_type.startswith(
"GET_FONT:"
):
# 单个字体请求(兼容旧版) # 单个字体请求(兼容旧版)
try: try:
is_unicode = message_type.startswith("GET_FONT_UNICODE:") is_unicode = message_type.startswith("GET_FONT_UNICODE:")
@@ -429,12 +622,14 @@ async def handle_font_request(websocket, message_type, data):
if font_data: if font_data:
import binascii import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
hex_data = binascii.hexlify(font_data).decode("utf-8")
response = f"FONT_DATA:{code_str}:{hex_data}" response = f"FONT_DATA:{code_str}:{hex_data}"
await websocket.send_text(response) await websocket.send_text(response)
except Exception as e: except Exception as e:
print(f"Error handling font request: {e}") print(f"Error handling font request: {e}")
class MyRecognitionCallback(RecognitionCallback): class MyRecognitionCallback(RecognitionCallback):
def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop): def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop):
self.websocket = websocket self.websocket = websocket
@@ -459,8 +654,7 @@ class MyRecognitionCallback(RecognitionCallback):
try: try:
if self.loop.is_running(): if self.loop.is_running():
asyncio.run_coroutine_threadsafe( asyncio.run_coroutine_threadsafe(
self.websocket.send_text(f"ASR:{self.final_text}"), self.websocket.send_text(f"ASR:{self.final_text}"), self.loop
self.loop
) )
except Exception as e: except Exception as e:
print(f"Failed to send final ASR result: {e}") print(f"Failed to send final ASR result: {e}")
@@ -468,10 +662,9 @@ class MyRecognitionCallback(RecognitionCallback):
def on_error(self, result: RecognitionResult) -> None: def on_error(self, result: RecognitionResult) -> None:
print(f"ASR Error: {result}") print(f"ASR Error: {result}")
def on_event(self, result: RecognitionResult) -> None: def on_event(self, result: RecognitionResult) -> None:
if result.get_sentence(): if result.get_sentence():
text = result.get_sentence()['text'] text = result.get_sentence()["text"]
# 获取当前句子的结束状态 # 获取当前句子的结束状态
# 注意DashScope Python SDK 的 Result 结构可能需要根据版本调整 # 注意DashScope Python SDK 的 Result 结构可能需要根据版本调整
@@ -481,8 +674,8 @@ class MyRecognitionCallback(RecognitionCallback):
if self.sentence_list: if self.sentence_list:
last_sentence = self.sentence_list[-1] last_sentence = self.sentence_list[-1]
# 去掉句尾标点进行比较,因为流式结果可能标点不稳定 # 去掉句尾标点进行比较,因为流式结果可能标点不稳定
last_clean = last_sentence.rstrip('。,?!') last_clean = last_sentence.rstrip("。,?!")
text_clean = text.rstrip('。,?!') text_clean = text.rstrip("。,?!")
if text_clean.startswith(last_clean): if text_clean.startswith(last_clean):
# 更新当前句子 # 更新当前句子
@@ -516,23 +709,26 @@ class MyRecognitionCallback(RecognitionCallback):
# except Exception as e: # except Exception as e:
# print(f"Failed to send ASR result to client: {e}") # print(f"Failed to send ASR result to client: {e}")
def process_chunk_32_to_16(chunk_bytes, gain=1.0): def process_chunk_32_to_16(chunk_bytes, gain=1.0):
processed_chunk = bytearray() processed_chunk = bytearray()
# Iterate 4 bytes at a time # Iterate 4 bytes at a time
for i in range(0, len(chunk_bytes), 4): for i in range(0, len(chunk_bytes), 4):
if i + 3 < len(chunk_bytes): if i + 3 < len(chunk_bytes):
# 取 chunk[i+2] 和 chunk[i+3] 组成 16-bit signed int # 取 chunk[i+2] 和 chunk[i+3] 组成 16-bit signed int
sample = struct.unpack_from('<h', chunk_bytes, i+2)[0] sample = struct.unpack_from("<h", chunk_bytes, i + 2)[0]
# 放大音量 # 放大音量
sample = int(sample * gain) sample = int(sample * gain)
# 限幅 (Clamping) 防止溢出爆音 # 限幅 (Clamping) 防止溢出爆音
if sample > 32767: sample = 32767 if sample > 32767:
elif sample < -32768: sample = -32768 sample = 32767
elif sample < -32768:
sample = -32768
# 重新打包为 16-bit little-endian # 重新打包为 16-bit little-endian
processed_chunk.extend(struct.pack('<h', sample)) processed_chunk.extend(struct.pack("<h", sample))
return processed_chunk return processed_chunk
@@ -563,8 +759,12 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
if retry_count < max_retries: if retry_count < max_retries:
print(f"Retrying... ({retry_count + 1}/{max_retries})") print(f"Retrying... ({retry_count + 1}/{max_retries})")
if progress_callback: if progress_callback:
progress_callback(35, f"生成失败,正在重试 ({retry_count + 1}/{max_retries})...") progress_callback(
return generate_image(prompt, progress_callback, retry_count + 1, max_retries) 35, f"生成失败,正在重试 ({retry_count + 1}/{max_retries})..."
)
return generate_image(
prompt, progress_callback, retry_count + 1, max_retries
)
else: else:
return None return None
@@ -574,6 +774,7 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
progress_callback(70, "正在下载生成的图片...") progress_callback(70, "正在下载生成的图片...")
import urllib.request import urllib.request
try: try:
urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE) urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE)
print(f"Image saved to {GENERATED_IMAGE_FILE}") print(f"Image saved to {GENERATED_IMAGE_FILE}")
@@ -586,6 +787,7 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
# Save to output dir # Save to output dir
output_path = get_output_path() output_path = get_output_path()
import shutil import shutil
shutil.copy(GENERATED_IMAGE_FILE, output_path) shutil.copy(GENERATED_IMAGE_FILE, output_path)
print(f"Image also saved to {output_path}") print(f"Image also saved to {output_path}")
@@ -595,6 +797,7 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
# Resize and convert to RGB565 (Reuse existing logic) # Resize and convert to RGB565 (Reuse existing logic)
try: try:
from PIL import Image from PIL import Image
img = Image.open(GENERATED_IMAGE_FILE) img = Image.open(GENERATED_IMAGE_FILE)
# 保持比例缩放 # 保持比例缩放
@@ -631,13 +834,15 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
# Pack as Big Endian (>H) which is standard for SPI displays # Pack as Big Endian (>H) which is standard for SPI displays
# RGB565: Red(5) Green(6) Blue(5) # RGB565: Red(5) Green(6) Blue(5)
rgb565 = (r5 << 11) | (g6 << 5) | b5 rgb565 = (r5 << 11) | (g6 << 5) | b5
rgb565_data.extend(struct.pack('>H', rgb565)) rgb565_data.extend(struct.pack(">H", rgb565))
# 保存为.bin文件 # 保存为.bin文件
with open(GENERATED_THUMB_FILE, 'wb') as f: with open(GENERATED_THUMB_FILE, "wb") as f:
f.write(rgb565_data) f.write(rgb565_data)
print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes") print(
f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes"
)
if progress_callback: if progress_callback:
progress_callback(100, "图片生成完成!") progress_callback(100, "图片生成完成!")
@@ -658,9 +863,12 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
except Exception as e: except Exception as e:
print(f"Error in generate_image: {e}") print(f"Error in generate_image: {e}")
if retry_count < max_retries: if retry_count < max_retries:
return generate_image(prompt, progress_callback, retry_count + 1, max_retries) return generate_image(
prompt, progress_callback, retry_count + 1, max_retries
)
return None return None
@app.websocket("/ws/audio") @app.websocket("/ws/audio")
async def websocket_endpoint(websocket: WebSocket): async def websocket_endpoint(websocket: WebSocket):
global audio_buffer global audio_buffer
@@ -678,7 +886,10 @@ async def websocket_endpoint(websocket: WebSocket):
try: try:
message = await websocket.receive() message = await websocket.receive()
except RuntimeError as e: except RuntimeError as e:
if "Cannot call \"receive\" once a disconnect message has been received" in str(e): if (
'Cannot call "receive" once a disconnect message has been received'
in str(e)
):
print("Client disconnected (RuntimeError caught)") print("Client disconnected (RuntimeError caught)")
break break
raise e raise e
@@ -696,10 +907,10 @@ async def websocket_endpoint(websocket: WebSocket):
try: try:
callback = MyRecognitionCallback(websocket, loop) callback = MyRecognitionCallback(websocket, loop)
recognition = Recognition( recognition = Recognition(
model='paraformer-realtime-v2', model="paraformer-realtime-v2",
format='pcm', format="pcm",
sample_rate=16000, sample_rate=16000,
callback=callback callback=callback,
) )
recognition.start() recognition.start()
print("DashScope ASR started") print("DashScope ASR started")
@@ -723,7 +934,9 @@ async def websocket_endpoint(websocket: WebSocket):
# 使用实时处理过的音频数据 # 使用实时处理过的音频数据
processed_audio = processed_buffer processed_audio = processed_buffer
print(f"Processed audio size: {len(processed_audio)} bytes (Gain: {VOLUME_GAIN}x)") print(
f"Processed audio size: {len(processed_audio)} bytes (Gain: {VOLUME_GAIN}x)"
)
# 获取ASR识别结果 # 获取ASR识别结果
asr_text = "" asr_text = ""
@@ -741,11 +954,15 @@ async def websocket_endpoint(websocket: WebSocket):
cmd = [ cmd = [
"ffmpeg", "ffmpeg",
"-y", # 覆盖输出文件 "-y", # 覆盖输出文件
"-f", "s16le", # 输入格式: signed 16-bit little endian "-f",
"-ar", "16000", # 输入采样率 "s16le", # 输入格式: signed 16-bit little endian
"-ac", "1", # 输入声道数 "-ar",
"-i", RECORDING_RAW_FILE, "16000", # 输入采样率
RECORDING_MP3_FILE "-ac",
"1", # 输入声道数
"-i",
RECORDING_RAW_FILE,
RECORDING_MP3_FILE,
] ]
print(f"Running command: {' '.join(cmd)}") print(f"Running command: {' '.join(cmd)}")
@@ -753,18 +970,24 @@ async def websocket_endpoint(websocket: WebSocket):
process = await asyncio.create_subprocess_exec( process = await asyncio.create_subprocess_exec(
*cmd, *cmd,
stdout=asyncio.subprocess.PIPE, stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE stderr=asyncio.subprocess.PIPE,
) )
stdout, stderr = await process.communicate() stdout, stderr = await process.communicate()
if process.returncode != 0: if process.returncode != 0:
raise subprocess.CalledProcessError(process.returncode, cmd, output=stdout, stderr=stderr) raise subprocess.CalledProcessError(
process.returncode, cmd, output=stdout, stderr=stderr
)
print(f"Saved MP3 to {RECORDING_MP3_FILE}") print(f"Saved MP3 to {RECORDING_MP3_FILE}")
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
print(f"Error converting to MP3: {e}") print(f"Error converting to MP3: {e}")
# stderr might be bytes # stderr might be bytes
error_msg = e.stderr.decode() if isinstance(e.stderr, bytes) else str(e.stderr) error_msg = (
e.stderr.decode()
if isinstance(e.stderr, bytes)
else str(e.stderr)
)
print(f"FFmpeg stderr: {error_msg}") print(f"FFmpeg stderr: {error_msg}")
except FileNotFoundError: except FileNotFoundError:
print("Error: ffmpeg not found. Please install ffmpeg.") print("Error: ffmpeg not found. Please install ffmpeg.")
@@ -779,9 +1002,15 @@ async def websocket_endpoint(websocket: WebSocket):
try: try:
unique_chars = set(asr_text) unique_chars = set(asr_text)
code_list = [str(ord(c)) for c in unique_chars] code_list = [str(ord(c)) for c in unique_chars]
print(f"Sending font data for {len(code_list)} characters...") print(
success_count, failed = await send_font_batch_with_retry(websocket, code_list) f"Sending font data for {len(code_list)} characters..."
print(f"Font data sent: {success_count} success, {len(failed)} failed") )
success_count, failed = await send_font_batch_with_retry(
websocket, code_list
)
print(
f"Font data sent: {success_count} success, {len(failed)} failed"
)
except Exception as e: except Exception as e:
print(f"Error sending font data: {e}") print(f"Error sending font data: {e}")
@@ -802,7 +1031,9 @@ async def websocket_endpoint(websocket: WebSocket):
prompt_text = text.split(":", 1)[1] prompt_text = text.split(":", 1)[1]
print(f"Received GENERATE_IMAGE request: {prompt_text}") print(f"Received GENERATE_IMAGE request: {prompt_text}")
if prompt_text: if prompt_text:
asyncio.create_task(start_async_image_generation(websocket, prompt_text)) asyncio.create_task(
start_async_image_generation(websocket, prompt_text)
)
else: else:
await websocket.send_text("STATUS:ERROR:提示词为空") await websocket.send_text("STATUS:ERROR:提示词为空")
@@ -811,10 +1042,14 @@ async def websocket_endpoint(websocket: WebSocket):
if os.path.exists(GENERATED_IMAGE_FILE): if os.path.exists(GENERATED_IMAGE_FILE):
try: try:
# Use convert_img logic to get TSPL commands # Use convert_img logic to get TSPL commands
tspl_data = convert_img.image_to_tspl_commands(GENERATED_IMAGE_FILE) tspl_data = convert_img.image_to_tspl_commands(
GENERATED_IMAGE_FILE
)
if tspl_data: if tspl_data:
print(f"Sending printer data: {len(tspl_data)} bytes") print(f"Sending printer data: {len(tspl_data)} bytes")
await websocket.send_text(f"PRINTER_DATA_START:{len(tspl_data)}") await websocket.send_text(
f"PRINTER_DATA_START:{len(tspl_data)}"
)
# Send in chunks # Send in chunks
chunk_size = 512 chunk_size = 512
@@ -830,7 +1065,9 @@ async def websocket_endpoint(websocket: WebSocket):
await websocket.send_text("STATUS:ERROR:图片转换失败") await websocket.send_text("STATUS:ERROR:图片转换失败")
except Exception as e: except Exception as e:
print(f"Error converting image for printer: {e}") print(f"Error converting image for printer: {e}")
await websocket.send_text(f"STATUS:ERROR:打印出错: {str(e)}") await websocket.send_text(
f"STATUS:ERROR:打印出错: {str(e)}"
)
else: else:
await websocket.send_text("STATUS:ERROR:没有可打印的图片") await websocket.send_text("STATUS:ERROR:没有可打印的图片")
@@ -838,20 +1075,37 @@ async def websocket_endpoint(websocket: WebSocket):
task_id = text.split(":", 1)[1].strip() task_id = text.split(":", 1)[1].strip()
if task_id in active_tasks: if task_id in active_tasks:
task = active_tasks[task_id] task = active_tasks[task_id]
await websocket.send_text(f"TASK_STATUS:{task_id}:{task.status}:{task.progress}:{task.message}") await websocket.send_text(
f"TASK_STATUS:{task_id}:{task.status}:{task.progress}:{task.message}"
)
else: else:
await websocket.send_text(f"TASK_STATUS:{task_id}:unknown:0:任务不存在或已完成") await websocket.send_text(
f"TASK_STATUS:{task_id}:unknown:0:任务不存在或已完成"
)
elif text.startswith("GET_FONTS_BATCH:") or text.startswith("GET_FONT") or text == "GET_FONT_MD5" or text == "GET_HIGH_FREQ": elif (
text.startswith("GET_FONTS_BATCH:")
or text.startswith("GET_FONT")
or text == "GET_FONT_MD5"
or text == "GET_HIGH_FREQ"
):
# 使用新的统一字体处理函数 # 使用新的统一字体处理函数
try: try:
if text.startswith("GET_FONTS_BATCH:"): if text.startswith("GET_FONTS_BATCH:"):
await handle_font_request(websocket, text, text.split(":", 1)[1]) await handle_font_request(
websocket, text, text.split(":", 1)[1]
)
elif text.startswith("GET_FONT_FRAGMENT:"): elif text.startswith("GET_FONT_FRAGMENT:"):
await handle_font_request(websocket, text, text.split(":", 1)[1]) await handle_font_request(
elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"): websocket, text, text.split(":", 1)[1]
)
elif text.startswith("GET_FONT_UNICODE:") or text.startswith(
"GET_FONT:"
):
parts = text.split(":", 1) parts = text.split(":", 1)
await handle_font_request(websocket, parts[0], parts[1] if len(parts) > 1 else "") await handle_font_request(
websocket, parts[0], parts[1] if len(parts) > 1 else ""
)
else: else:
await handle_font_request(websocket, text, "") await handle_font_request(websocket, text, "")
except Exception as e: except Exception as e:
@@ -888,6 +1142,7 @@ async def websocket_endpoint(websocket: WebSocket):
except: except:
pass pass
if __name__ == "__main__": if __name__ == "__main__":
# Check API Key # Check API Key
if not dashscope.api_key: if not dashscope.api_key:
@@ -897,6 +1152,7 @@ if __name__ == "__main__":
# 获取本机IP方便ESP32连接 # 获取本机IP方便ESP32连接
import socket import socket
hostname = socket.gethostname() hostname = socket.gethostname()
local_ip = socket.gethostbyname(hostname) local_ip = socket.gethostbyname(hostname)
print(f"Server running on ws://{local_ip}:8000/ws/audio") print(f"Server running on ws://{local_ip}:8000/ws/audio")

View File

@@ -0,0 +1,364 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>AI Image Generator Admin</title>
<style>
* { box-sizing: border-box; margin: 0; padding: 0; }
body { font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; background: #1a1a2e; color: #eee; min-height: 100vh; padding: 20px; }
.container { max-width: 1100px; margin: 0 auto; }
h1 { text-align: center; margin-bottom: 30px; color: #00d4ff; }
.card { background: #16213e; border-radius: 12px; padding: 24px; margin-bottom: 20px; box-shadow: 0 4px 20px rgba(0,0,0,0.3); }
.card h2 { margin-bottom: 16px; color: #00d4ff; font-size: 18px; }
.current-status { background: #0f3460; padding: 16px; border-radius: 8px; margin-bottom: 16px; }
.current-status .label { color: #888; font-size: 14px; }
.current-status .value { color: #00d4ff; font-size: 20px; font-weight: bold; margin-top: 4px; }
.form-group { margin-bottom: 16px; }
.form-group label { display: block; margin-bottom: 8px; color: #ccc; }
select, input[type="number"] { width: 100%; padding: 12px; border-radius: 8px; border: 1px solid #333; background: #0f3460; color: #fff; font-size: 16px; }
select:focus, input:focus { outline: none; border-color: #00d4ff; }
input[type="checkbox"] { width: 20px; height: 20px; margin-right: 8px; }
.btn { display: inline-block; padding: 12px 24px; border-radius: 8px; border: none; cursor: pointer; font-size: 16px; font-weight: bold; transition: all 0.3s; }
.btn-primary { background: #00d4ff; color: #1a1a2e; }
.btn-primary:hover { background: #00b8e6; }
.btn-danger { background: #e74c3c; color: #fff; }
.btn-danger:hover { background: #c0392b; }
.btn-small { padding: 6px 12px; font-size: 14px; }
.model-list { margin-top: 16px; }
.model-item { background: #0f3460; padding: 12px 16px; border-radius: 8px; margin-bottom: 8px; display: flex; justify-content: space-between; align-items: center; }
.model-item.active { border: 2px solid #00d4ff; }
.model-item .name { font-weight: bold; }
.model-item .provider { color: #888; font-size: 14px; }
.test-section { margin-top: 20px; }
.test-input { width: 100%; padding: 12px; border-radius: 8px; border: 1px solid #333; background: #0f3460; color: #fff; font-size: 14px; resize: vertical; min-height: 80px; margin-bottom: 12px; }
.test-input:focus { outline: none; border-color: #00d4ff; }
.message { padding: 12px; border-radius: 8px; margin-top: 12px; display: none; }
.message.success { background: #27ae60; display: block; }
.message.error { background: #e74c3c; display: block; }
.loading { text-align: center; padding: 20px; display: none; }
.loading.show { display: block; }
.spinner { border: 3px solid #333; border-top: 3px solid #00d4ff; border-radius: 50%; width: 30px; height: 30px; animation: spin 1s linear infinite; margin: 0 auto; }
@keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(360deg); } }
.auto-delete-settings { display: flex; gap: 16px; align-items: center; flex-wrap: wrap; }
.auto-delete-settings label { display: flex; align-items: center; color: #ccc; }
.auto-delete-settings input[type="number"] { width: 80px; }
.gallery { display: grid; grid-template-columns: repeat(auto-fill, minmax(200px, 1fr)); gap: 16px; margin-top: 16px; }
.gallery-item { background: #0f3460; border-radius: 8px; overflow: hidden; position: relative; }
.gallery-item img { width: 100%; height: 180px; object-fit: cover; display: block; }
.gallery-item .info { padding: 12px; }
.gallery-item .filename { font-size: 12px; color: #888; word-break: break-all; }
.gallery-item .size { font-size: 12px; color: #666; margin-top: 4px; }
.gallery-item .delete-btn { position: absolute; top: 8px; right: 8px; background: rgba(231,76,60,0.9); color: white; border: none; border-radius: 50%; width: 28px; height: 28px; cursor: pointer; font-size: 16px; line-height: 28px; }
.gallery-item .delete-btn:hover { background: #c0392b; }
.empty-gallery { text-align: center; padding: 40px; color: #666; }
.flex-row { display: flex; gap: 12px; align-items: center; flex-wrap: wrap; }
.tab-nav { display: flex; gap: 4px; margin-bottom: 20px; background: #0f3460; border-radius: 8px; padding: 4px; }
.tab-nav button { flex: 1; padding: 12px; border: none; background: transparent; color: #888; cursor: pointer; border-radius: 6px; font-size: 16px; transition: all 0.3s; }
.tab-nav button.active { background: #00d4ff; color: #1a1a2e; font-weight: bold; }
.tab-content { display: none; }
.tab-content.active { display: block; }
</style>
</head>
<body>
<div class="container">
<h1>AI Image Generator Admin</h1>
<div class="tab-nav">
<button class="active" onclick="showTab('settings')">设置</button>
<button onclick="showTab('gallery')">图片库</button>
</div>
<div id="tab-settings" class="tab-content active">
<div class="card">
<h2>当前状态</h2>
<div class="current-status">
<div class="label">当前 Provider</div>
<div class="value" id="currentProvider">加载中...</div>
<div class="label" style="margin-top: 12px;">当前模型</div>
<div class="value" id="currentModel">加载中...</div>
</div>
</div>
<div class="card">
<h2>切换 Provider</h2>
<div class="form-group">
<select id="providerSelect">
<option value="doubao">豆包 (Doubao)</option>
<option value="dashscope">阿里云 (DashScope)</option>
</select>
</div>
<button class="btn btn-primary" onclick="switchProvider()">切换 Provider</button>
</div>
<div class="card">
<h2>豆包模型</h2>
<div class="model-list">
<div class="model-item" data-provider="doubao" data-model="doubao-seedream-4.0">
<div>
<div class="name">doubao-seedream-4.0</div>
<div class="provider">豆包</div>
</div>
<button class="btn btn-primary" onclick="setModel('doubao', 'doubao-seedream-4.0')">使用</button>
</div>
<div class="model-item" data-provider="doubao" data-model="doubao-seedream-5-0-260128">
<div>
<div class="name">doubao-seedream-5-0-260128</div>
<div class="provider">豆包</div>
</div>
<button class="btn btn-primary" onclick="setModel('doubao', 'doubao-seedream-5-0-260128')">使用</button>
</div>
</div>
</div>
<div class="card">
<h2>阿里云模型 (DashScope)</h2>
<div class="model-list">
<div class="model-item" data-provider="dashscope" data-model="wanx2.0-t2i-turbo">
<div>
<div class="name">wanx2.0-t2i-turbo</div>
<div class="provider">阿里云</div>
</div>
<button class="btn btn-primary" onclick="setModel('dashscope', 'wanx2.0-t2i-turbo')">使用</button>
</div>
<div class="model-item" data-provider="dashscope" data-model="qwen-image-plus">
<div>
<div class="name">qwen-image-plus</div>
<div class="provider">阿里云</div>
</div>
<button class="btn btn-primary" onclick="setModel('dashscope', 'qwen-image-plus')">使用</button>
</div>
<div class="model-item" data-provider="dashscope" data-model="qwen-image-v1">
<div>
<div class="name">qwen-image-v1</div>
<div class="provider">阿里云</div>
</div>
<button class="btn btn-primary" onclick="setModel('dashscope', 'qwen-image-v1')">使用</button>
</div>
</div>
</div>
<div class="card">
<h2>测试图片生成</h2>
<textarea class="test-input" id="testPrompt" placeholder="输入提示词...">A cute cat, black and white line art, cartoon style</textarea>
<button class="btn btn-primary" onclick="testGenerate()">生成图片</button>
<div class="loading" id="loading">
<div class="spinner"></div>
<p style="margin-top: 10px;">生成中...</p>
</div>
<div class="message" id="message"></div>
<div id="resultArea" style="margin-top: 16px; display: none;">
<img id="resultImage" style="max-width: 100%; max-height: 300px; border-radius: 8px;">
</div>
</div>
</div>
<div id="tab-gallery" class="tab-content">
<div class="card">
<h2>图片库</h2>
<div class="auto-delete-settings">
<label><input type="checkbox" id="autoDeleteEnabled" onchange="updateAutoDelete()"> 自动删除</label>
<label><input type="number" id="autoDeleteHours" min="1" max="168" value="24" onchange="updateAutoDelete()"> 小时后删除</label>
<button class="btn btn-primary btn-small" onclick="loadGallery()">刷新</button>
<button class="btn btn-danger btn-small" onclick="deleteAllImages()">删除全部</button>
</div>
<div class="gallery" id="gallery"></div>
</div>
</div>
</div>
<script>
/**
 * Show one admin tab and highlight its navigation button.
 *
 * @param {string} tab - Tab key; the panel with id "tab-<key>" becomes visible.
 */
function showTab(tab) {
    const panel = document.getElementById('tab-' + tab);
    // Unknown key: keep the current view instead of blanking every panel.
    if (!panel) return;
    document.querySelectorAll('.tab-content').forEach(el => el.classList.remove('active'));
    panel.classList.add('active');
    // Identify the nav button from its inline handler instead of the deprecated
    // implicit global `event`, so programmatic calls (no click) also work.
    const marker = "showTab('" + tab + "')";
    document.querySelectorAll('.tab-nav button').forEach(btn => {
        const handler = btn.getAttribute('onclick') || '';
        btn.classList.toggle('active', handler.includes(marker));
    });
    // The gallery is loaded lazily, each time its tab is opened.
    if (tab === 'gallery') loadGallery();
}
/**
 * Fetch the active provider/model from /api/admin/status and mirror it in
 * the status card, the provider dropdown and the model list.
 * Errors are only logged to the console.
 */
async function loadStatus() {
    const setText = (id, text) => {
        document.getElementById(id).textContent = text;
    };
    try {
        const response = await fetch('/api/admin/status');
        const { provider, model } = await response.json();
        setText('currentProvider', provider);
        setText('currentModel', model);
        document.getElementById('providerSelect').value = provider;
        updateActiveModel(provider, model);
    } catch (err) {
        console.error('Failed to load status:', err);
    }
}
/**
 * Read the auto-delete settings from /api/admin/auto-delete and copy them
 * into the checkbox and the hours field. Errors are only logged.
 */
async function loadAutoDelete() {
    try {
        const response = await fetch('/api/admin/auto-delete');
        const settings = await response.json();
        const enabledBox = document.getElementById('autoDeleteEnabled');
        enabledBox.checked = settings.enabled;
        const hoursField = document.getElementById('autoDeleteHours');
        hoursField.value = settings.hours;
    } catch (err) {
        console.error('Failed to load auto-delete settings:', err);
    }
}
/**
 * Persist the auto-delete checkbox and hours field via POST /api/admin/auto-delete.
 *
 * The hours value is normalised before sending: an empty or non-numeric field
 * would otherwise parse to NaN, which JSON.stringify serialises as `null`.
 * The value is clamped to the input's own min/max attributes and written back
 * so the form shows what was actually submitted.
 */
async function updateAutoDelete() {
    const enabled = document.getElementById('autoDeleteEnabled').checked;
    const hoursInput = document.getElementById('autoDeleteHours');
    // Fall back to the limits/default declared in the markup (1..168, 24).
    const minHours = Number.parseInt(hoursInput.min, 10) || 1;
    const maxHours = Number.parseInt(hoursInput.max, 10) || 168;
    const fallbackHours = Number.parseInt(hoursInput.defaultValue, 10) || 24;
    const parsed = Number.parseInt(hoursInput.value, 10);
    const wanted = Number.isNaN(parsed) ? fallbackHours : parsed;
    const hours = Math.min(maxHours, Math.max(minHours, wanted));
    hoursInput.value = hours;
    try {
        const res = await fetch('/api/admin/auto-delete', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ enabled, hours })
        });
        // fetch() only rejects on network errors; surface HTTP failures too.
        if (!res.ok) {
            console.error('Failed to update auto-delete: HTTP ' + res.status);
        }
    } catch (e) {
        console.error('Failed to update auto-delete:', e);
    }
}
/**
 * Mark the model row matching (provider, model) as active and disable its
 * button; every other row is reset to a clickable "use" button.
 *
 * @param {string} provider - Provider id compared with each row's data-provider.
 * @param {string} model - Model id compared with each row's data-model.
 */
function updateActiveModel(provider, model) {
    for (const row of document.querySelectorAll('.model-item')) {
        const isCurrent = row.dataset.provider === provider && row.dataset.model === model;
        row.classList.toggle('active', isCurrent);
        const useBtn = row.querySelector('button');
        useBtn.textContent = isCurrent ? '使用中' : '使用';
        useBtn.disabled = isCurrent;
    }
}
/**
 * POST the provider chosen in the dropdown to /api/admin/switch, show the
 * server's message, and reload the status on success.
 */
async function switchProvider() {
    const provider = document.getElementById('providerSelect').value;
    const requestInit = {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ provider })
    };
    try {
        const response = await fetch('/api/admin/switch', requestInit);
        const reply = await response.json();
        showMessage(reply.message, reply.success);
        if (reply.success) {
            loadStatus();
        }
    } catch (err) {
        showMessage('切换失败: ' + err.message, false);
    }
}
/**
 * POST a provider/model pair to /api/admin/model, show the server's message,
 * and reload the status on success.
 *
 * @param {string} provider - Provider id sent in the request body.
 * @param {string} model - Model id sent in the request body.
 */
async function setModel(provider, model) {
    const requestInit = {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify({ provider, model })
    };
    try {
        const response = await fetch('/api/admin/model', requestInit);
        const reply = await response.json();
        showMessage(reply.message, reply.success);
        if (reply.success) {
            loadStatus();
        }
    } catch (err) {
        showMessage('设置失败: ' + err.message, false);
    }
}
/**
 * Send the test prompt to /api/admin/test-generate and display the resulting
 * image or the error message. A spinner is shown while the request runs.
 */
async function testGenerate() {
    const prompt = document.getElementById('testPrompt').value;
    if (!prompt) return;
    const loading = document.getElementById('loading');
    const message = document.getElementById('message');
    const resultArea = document.getElementById('resultArea');
    loading.classList.add('show');
    // Hide the previous message by dropping its state class. An inline
    // `display: none` would override the .success/.error rules and keep every
    // later message invisible.
    message.className = 'message';
    message.style.display = '';
    resultArea.style.display = 'none';
    try {
        const res = await fetch('/api/admin/test-generate', {
            method: 'POST',
            headers: { 'Content-Type': 'application/json' },
            body: JSON.stringify({ prompt })
        });
        const data = await res.json();
        if (data.success && data.image_url) {
            document.getElementById('resultImage').src = data.image_url;
            resultArea.style.display = 'block';
            showMessage('生成成功', true);
        } else {
            showMessage(data.message || '生成失败', false);
        }
    } catch (e) {
        showMessage('生成失败: ' + e.message, false);
    } finally {
        // Always remove the spinner, whether the request succeeded or not.
        loading.classList.remove('show');
    }
}
/**
 * Build one gallery card for an image record.
 *
 * Nodes are created through the DOM API (textContent / properties) rather than
 * an HTML string, so names or URLs containing quotes or markup cannot inject
 * HTML or break the click handlers.
 *
 * @param {{name: string, url: string, size: number}} img - Image record as read
 *     by the gallery (name, url and size fields).
 * @returns {HTMLDivElement} The assembled ".gallery-item" element.
 */
function buildGalleryItem(img) {
    const item = document.createElement('div');
    item.className = 'gallery-item';

    const removeBtn = document.createElement('button');
    removeBtn.className = 'delete-btn';
    removeBtn.textContent = '×';
    removeBtn.addEventListener('click', () => deleteImage(img.name));

    const thumb = document.createElement('img');
    thumb.src = img.url;
    thumb.alt = img.name;
    thumb.addEventListener('click', () => window.open(img.url, '_blank'));

    const info = document.createElement('div');
    info.className = 'info';
    const filename = document.createElement('div');
    filename.className = 'filename';
    filename.textContent = img.name;
    const size = document.createElement('div');
    size.className = 'size';
    size.textContent = formatSize(img.size);
    info.append(filename, size);

    item.append(removeBtn, thumb, info);
    return item;
}

/**
 * Fetch the image list from /api/admin/images and render it into #gallery.
 * Shows a placeholder when the list is empty; errors are only logged.
 */
async function loadGallery() {
    try {
        const res = await fetch('/api/admin/images');
        const data = await res.json();
        const gallery = document.getElementById('gallery');
        if (!data.images || data.images.length === 0) {
            // Static markup only: no server data is interpolated here.
            gallery.innerHTML = '<div class="empty-gallery">暂无图片</div>';
            return;
        }
        const fragment = document.createDocumentFragment();
        for (const img of data.images) {
            fragment.appendChild(buildGalleryItem(img));
        }
        gallery.textContent = '';
        gallery.appendChild(fragment);
    } catch (e) {
        console.error('Failed to load gallery:', e);
    }
}
/**
 * Ask for confirmation, then DELETE one image through the admin API.
 * Reloads the gallery on success; otherwise alerts the server's message.
 *
 * @param {string} filename - Image name; URL-encoded into the request path.
 */
async function deleteImage(filename) {
    const approved = confirm('确定要删除这张图片吗?');
    if (!approved) {
        return;
    }
    try {
        const endpoint = `/api/admin/images/${encodeURIComponent(filename)}`;
        const response = await fetch(endpoint, { method: 'DELETE' });
        const outcome = await response.json();
        if (!outcome.success) {
            alert(outcome.message);
            return;
        }
        loadGallery();
    } catch (err) {
        alert('删除失败: ' + err.message);
    }
}
/**
 * Ask for confirmation, then delete every image listed by /api/admin/images,
 * one DELETE request at a time.
 *
 * Each response is checked for its `success` flag, and failures are counted
 * and reported once at the end rather than being ignored. The gallery is
 * refreshed even when something fails, so the view matches what is left on
 * the server.
 */
async function deleteAllImages() {
    if (!confirm('确定要删除所有图片吗?此操作不可恢复!')) return;
    try {
        const res = await fetch('/api/admin/images');
        const data = await res.json();
        // Tolerate a response without an image list.
        const images = Array.isArray(data.images) ? data.images : [];
        let failed = 0;
        for (const img of images) {
            try {
                const delRes = await fetch(`/api/admin/images/${encodeURIComponent(img.name)}`, { method: 'DELETE' });
                const result = await delRes.json();
                if (!result.success) failed += 1;
            } catch (e) {
                // Keep going: one bad file should not block the rest.
                failed += 1;
            }
        }
        if (failed > 0) {
            alert('删除失败: ' + failed + ' 张图片未能删除');
        }
    } catch (e) {
        alert('删除失败: ' + e.message);
    } finally {
        loadGallery();
    }
}
/**
 * Format a byte count as a short label: plain bytes below 1024,
 * otherwise KB or MB with one decimal place.
 *
 * @param {number} bytes - Size in bytes.
 * @returns {string} Label such as "512 B", "1.5 KB" or "2.0 MB".
 */
function formatSize(bytes) {
    const KIB = 1024;
    const MIB = KIB * KIB;
    if (bytes < KIB) {
        return `${bytes} B`;
    }
    const [divisor, unit] = bytes < MIB ? [KIB, 'KB'] : [MIB, 'MB'];
    return `${(bytes / divisor).toFixed(1)} ${unit}`;
}
/**
 * Show a status message in the #message box.
 *
 * @param {string} msg - Text to display (assigned via textContent).
 * @param {boolean} success - Truthy selects the "success" style, falsy "error".
 */
function showMessage(msg, success) {
    const el = document.getElementById('message');
    el.textContent = msg;
    el.className = 'message ' + (success ? 'success' : 'error');
    // Clear any inline `display` left by earlier hide logic. Inline styles beat
    // the .success/.error class rules and would keep the box hidden.
    el.style.display = '';
}
// Initial page load: populate the provider/model status card and the
// auto-delete form. Both calls are async and are not awaited here.
loadStatus();
loadAutoDelete();
</script>
</body>
</html>