This commit is contained in:
jeremygan2021
2026-03-03 21:59:57 +08:00
parent c87d5deedf
commit fc92a5feaf
7 changed files with 963 additions and 389 deletions

View File

@@ -31,6 +31,7 @@ HIGH_FREQ_UNICODE = [ord(c) for c in HIGH_FREQ_CHARS]
# 字体缓存
font_cache = {}
font_md5 = {}
font_data_buffer = None
def calculate_md5(filepath):
"""计算文件的MD5哈希值"""
@@ -44,7 +45,7 @@ def calculate_md5(filepath):
def init_font_cache():
"""初始化字体缓存和MD5"""
global font_cache, font_md5
global font_cache, font_md5, font_data_buffer
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, FONT_FILE)
@@ -55,24 +56,18 @@ def init_font_cache():
font_md5 = calculate_md5(font_path)
print(f"Font MD5: {font_md5}")
# 加载高频字到缓
# 加载整个字体文件到内
try:
with open(font_path, "rb") as f:
font_data_buffer = f.read()
print(f"Loaded font file into memory: {len(font_data_buffer)} bytes")
except Exception as e:
print(f"Error loading font file: {e}")
font_data_buffer = None
# 预加载高频字到缓存 (仍然保留以便快速访问)
for unicode_val in HIGH_FREQ_UNICODE:
try:
char = chr(unicode_val)
gb_bytes = char.encode('gb2312')
if len(gb_bytes) == 2:
code = struct.unpack('>H', gb_bytes)[0]
area = (code >> 8) - 0xA0
index = (code & 0xFF) - 0xA0
if area >= 1 and index >= 1:
offset = ((area - 1) * 94 + (index - 1)) * 32
with open(font_path, "rb") as f:
f.seek(offset)
font_data = f.read(32)
if len(font_data) == 32:
font_cache[unicode_val] = font_data
except:
pass
get_font_data(unicode_val)
print(f"Preloaded {len(font_cache)} high-frequency characters")
# 启动时初始化字体缓存
@@ -104,6 +99,114 @@ THUMB_SIZE = 245
font_request_queue = {}
FONT_RETRY_MAX = 3
# 图片生成任务管理
class ImageGenerationTask:
"""图片生成任务管理类"""
def __init__(self, task_id: str, asr_text: str, websocket: WebSocket):
self.task_id = task_id
self.asr_text = asr_text
self.websocket = websocket
self.status = "pending" # pending, optimizing, generating, completed, failed
self.progress = 0
self.message = ""
self.result = None
self.error = None
# 存储活跃的图片生成任务
active_tasks = {}
task_counter = 0
async def start_async_image_generation(websocket: WebSocket, asr_text: str):
"""异步启动图片生成任务不阻塞WebSocket连接"""
global task_counter, active_tasks
task_id = f"task_{task_counter}_{int(time.time() * 1000)}"
task_counter += 1
task = ImageGenerationTask(task_id, asr_text, websocket)
active_tasks[task_id] = task
print(f"Starting async image generation task: {task_id}")
await websocket.send_text(f"TASK_ID:{task_id}")
def progress_callback(progress: int, message: str):
"""进度回调函数"""
task.progress = progress
task.message = message
try:
asyncio.run_coroutine_threadsafe(
websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}"),
asyncio.get_event_loop()
)
except Exception as e:
print(f"Error sending progress: {e}")
try:
task.status = "optimizing"
await websocket.send_text("STATUS:OPTIMIZING:正在优化提示词...")
await asyncio.sleep(0.2)
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback)
await websocket.send_text(f"PROMPT:{optimized_prompt}")
task.optimized_prompt = optimized_prompt
task.status = "generating"
await websocket.send_text("STATUS:RENDERING:正在生成图片,请稍候...")
await asyncio.sleep(0.2)
image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)
task.result = image_path
if image_path and os.path.exists(image_path):
task.status = "completed"
await websocket.send_text("STATUS:COMPLETE:图片生成完成")
await asyncio.sleep(0.2)
await send_image_to_client(websocket, image_path)
else:
task.status = "failed"
task.error = "图片生成失败"
await websocket.send_text("IMAGE_ERROR:图片生成失败")
await websocket.send_text("STATUS:ERROR:图片生成失败")
except Exception as e:
task.status = "failed"
task.error = str(e)
print(f"Image generation task error: {e}")
await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
await websocket.send_text("STATUS:ERROR:图片生成出错")
finally:
if task_id in active_tasks:
del active_tasks[task_id]
return task
async def send_image_to_client(websocket: WebSocket, image_path: str):
"""发送图片数据到客户端"""
with open(image_path, 'rb') as f:
image_data = f.read()
print(f"Sending image to ESP32, size: {len(image_data)} bytes")
# Send start marker
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
# Send binary data directly
chunk_size = 4096 # Increased chunk size for binary
for i in range(0, len(image_data), chunk_size):
chunk = image_data[i:i+chunk_size]
await websocket.send_bytes(chunk)
# Send end marker
await websocket.send_text("IMAGE_END")
print("Image sent to ESP32 (Binary)")
def get_font_data(unicode_val):
"""从字体文件获取单个字符数据(带缓存)"""
@@ -121,20 +224,27 @@ def get_font_data(unicode_val):
if area >= 1 and index >= 1:
offset = ((area - 1) * 94 + (index - 1)) * 32
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, FONT_FILE)
if not os.path.exists(font_path):
font_path = os.path.join(script_dir, "..", FONT_FILE)
if not os.path.exists(font_path):
font_path = FONT_FILE
if os.path.exists(font_path):
with open(font_path, "rb") as f:
f.seek(offset)
font_data = f.read(32)
if len(font_data) == 32:
font_cache[unicode_val] = font_data
return font_data
if font_data_buffer:
if offset + 32 <= len(font_data_buffer):
font_data = font_data_buffer[offset:offset+32]
font_cache[unicode_val] = font_data
return font_data
else:
# Fallback to file reading if buffer failed
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, FONT_FILE)
if not os.path.exists(font_path):
font_path = os.path.join(script_dir, "..", FONT_FILE)
if not os.path.exists(font_path):
font_path = FONT_FILE
if os.path.exists(font_path):
with open(font_path, "rb") as f:
f.seek(offset)
font_data = f.read(32)
if len(font_data) == 32:
font_cache[unicode_val] = font_data
return font_data
except:
pass
return None
@@ -333,10 +443,13 @@ def process_chunk_32_to_16(chunk_bytes, gain=1.0):
return processed_chunk
def optimize_prompt(asr_text):
def optimize_prompt(asr_text, progress_callback=None):
"""使用大模型优化提示词"""
print(f"Optimizing prompt for: {asr_text}")
if progress_callback:
progress_callback(0, "正在准备优化提示词...")
system_prompt = """你是一个AI图像提示词优化专家。将用户简短的语音识别结果转化为详细的、适合AI图像生成的英文提示词。
要求:
1. 保留核心内容和主要元素
@@ -346,6 +459,9 @@ def optimize_prompt(asr_text):
5. 不要添加多余解释,直接输出优化后的提示词"""
try:
if progress_callback:
progress_callback(10, "正在调用AI优化提示词...")
response = Generation.call(
model='qwen-turbo',
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
@@ -356,31 +472,76 @@ def optimize_prompt(asr_text):
if response.status_code == 200:
optimized = response.output.choices[0].message.content.strip()
print(f"Optimized prompt: {optimized}")
if progress_callback:
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
return optimized
else:
print(f"Prompt optimization failed: {response.code} - {response.message}")
if progress_callback:
progress_callback(0, f"提示词优化失败: {response.message}")
return asr_text
except Exception as e:
print(f"Error optimizing prompt: {e}")
if progress_callback:
progress_callback(0, f"提示词优化出错: {str(e)}")
return asr_text
def generate_image(prompt, websocket=None):
"""调用万相文生图API生成图片"""
def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2):
"""调用万相文生图API生成图片
Args:
prompt: 图像生成提示词
progress_callback: 进度回调函数 (progress, message)
retry_count: 当前重试次数
max_retries: 最大重试次数
"""
print(f"Generating image for prompt: {prompt}")
if progress_callback:
progress_callback(35, "正在请求AI生成图片...")
try:
response = ImageSynthesis.call(
model='wan2.6-t2i',
prompt=prompt,
size='512x512',
n=1
model='wanx2.0-t2i-turbo',
prompt=prompt
)
if response.status_code == 200:
image_url = response.output['results'][0]['url']
print(f"Image generated, downloading from: {image_url}")
task_status = response.output.get('task_status')
if task_status == 'PENDING' or task_status == 'RUNNING':
print("Waiting for image generation to complete...")
if progress_callback:
progress_callback(45, "AI正在生成图片中...")
import time
task_id = response.output.get('task_id')
max_wait = 120
waited = 0
while waited < max_wait:
time.sleep(2)
waited += 2
task_result = ImageSynthesis.fetch(task_id)
if task_result.output.task_status == 'SUCCEEDED':
response.output = task_result.output
break
elif task_result.output.task_status == 'FAILED':
error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error'
print(f"Image generation failed: {error_msg}")
if progress_callback:
progress_callback(35, f"图片生成失败: {error_msg}")
return None
if response.output.get('task_status') == 'SUCCEEDED':
image_url = response.output['results'][0]['url']
print(f"Image generated, downloading from: {image_url}")
if progress_callback:
progress_callback(70, "正在下载生成的图片...")
import urllib.request
urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE)
@@ -392,6 +553,9 @@ def generate_image(prompt, websocket=None):
shutil.copy(GENERATED_IMAGE_FILE, output_path)
print(f"Image also saved to {output_path}")
if progress_callback:
progress_callback(80, "正在处理图片...")
# 缩放图片并转换为RGB565格式
try:
from PIL import Image
@@ -422,21 +586,50 @@ def generate_image(prompt, websocket=None):
f.write(rgb565_data)
print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes")
if progress_callback:
progress_callback(100, "图片生成完成!")
return GENERATED_THUMB_FILE
except ImportError:
print("PIL not available, sending original image")
if progress_callback:
progress_callback(100, "图片生成完成!(原始格式)")
return GENERATED_IMAGE_FILE
except Exception as e:
print(f"Error processing image: {e}")
if progress_callback:
progress_callback(80, f"图片处理出错: {str(e)}")
return GENERATED_IMAGE_FILE
else:
print(f"Image generation failed: {response.code} - {response.message}")
return None
error_msg = f"{response.code} - {response.message}"
print(f"Image generation failed: {error_msg}")
# 重试机制
if retry_count < max_retries:
print(f"Retrying... ({retry_count + 1}/{max_retries})")
if progress_callback:
progress_callback(35, f"图片生成失败,正在重试 ({retry_count + 1}/{max_retries})...")
return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
else:
if progress_callback:
progress_callback(35, f"图片生成失败: {error_msg}")
return None
except Exception as e:
print(f"Error generating image: {e}")
return None
# 重试机制
if retry_count < max_retries:
print(f"Retrying after error... ({retry_count + 1}/{max_retries})")
if progress_callback:
progress_callback(35, f"生成出错,正在重试 ({retry_count + 1}/{max_retries})...")
return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
else:
if progress_callback:
progress_callback(35, f"图片生成出错: {str(e)}")
return None
@app.websocket("/ws/audio")
async def websocket_endpoint(websocket: WebSocket):
@@ -554,132 +747,36 @@ async def websocket_endpoint(websocket: WebSocket):
# 先发送 ASR 文字到 ESP32 显示
await websocket.send_text(f"ASR:{asr_text}")
await websocket.send_text("GENERATING_IMAGE:正在优化提示词...")
# 等待一会让 ESP32 显示文字
await asyncio.sleep(0.5)
# 优化提示词
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text)
await websocket.send_text(f"PROMPT:{optimized_prompt}")
await websocket.send_text("GENERATING_IMAGE:正在生成图片,请稍候...")
# 调用文生图API
image_path = await asyncio.to_thread(generate_image, optimized_prompt)
if image_path and os.path.exists(image_path):
# 读取图片并发送回ESP32
with open(image_path, 'rb') as f:
image_data = f.read()
print(f"Sending image to ESP32, size: {len(image_data)} bytes")
# 使用hex编码发送每个字节2个字符
image_hex = image_data.hex()
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
# 分片发送图片数据
chunk_size = 1024
for i in range(0, len(image_hex), chunk_size):
chunk = image_hex[i:i+chunk_size]
await websocket.send_text(f"IMAGE_DATA:{chunk}")
await websocket.send_text("IMAGE_END")
print("Image sent to ESP32")
else:
await websocket.send_text("IMAGE_ERROR:图片生成失败")
await start_async_image_generation(websocket, asr_text)
else:
print("No ASR text, skipping image generation")
print("Server processing finished.")
elif text.startswith("GET_FONTS_BATCH:"):
# Format: GET_FONTS_BATCH:code1,code2,code3 (decimal unicode)
elif text.startswith("GET_TASK_STATUS:"):
task_id = text.split(":", 1)[1].strip()
if task_id in active_tasks:
task = active_tasks[task_id]
await websocket.send_text(f"TASK_STATUS:{task_id}:{task.status}:{task.progress}:{task.message}")
else:
await websocket.send_text(f"TASK_STATUS:{task_id}:unknown:0:任务不存在或已完成")
elif text.startswith("GET_FONTS_BATCH:") or text.startswith("GET_FONT") or text == "GET_FONT_MD5" or text == "GET_HIGH_FREQ":
# 使用新的统一字体处理函数
try:
codes_str = text.split(":")[1]
code_list = codes_str.split(",")
print(f"Batch Font Request for {len(code_list)} chars: {code_list}")
for code_str in code_list:
if not code_str: continue
try:
unicode_val = int(code_str)
char = chr(unicode_val)
gb_bytes = char.encode('gb2312')
if len(gb_bytes) == 2:
code = struct.unpack('>H', gb_bytes)[0]
else:
print(f"Character {char} is not a valid 2-byte GB2312 char")
# Send empty/dummy? Or just skip.
# Better to send something so client doesn't wait forever if it counts responses.
# But client probably uses a set of missing chars.
continue
# Calc offset
area = (code >> 8) - 0xA0
index = (code & 0xFF) - 0xA0
if area >= 1 and index >= 1:
offset = ((area - 1) * 94 + (index - 1)) * 32
# Read font file
# Optimization: Open file once outside loop?
# For simplicity, keep it here, OS caching helps.
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, FONT_FILE)
if not os.path.exists(font_path):
font_path = os.path.join(script_dir, "..", FONT_FILE)
if not os.path.exists(font_path):
font_path = FONT_FILE
if os.path.exists(font_path):
with open(font_path, "rb") as f:
f.seek(offset)
font_data = f.read(32)
if len(font_data) == 32:
import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
response = f"FONT_DATA:{code_str}:{hex_data}"
await websocket.send_text(response)
# Small yield to let network flush?
# await asyncio.sleep(0.001)
except Exception as e:
print(f"Error processing batch item {code_str}: {e}")
# Send a completion marker
await websocket.send_text("FONT_BATCH_END")
except Exception as e:
print(f"Error handling BATCH FONT request: {e}")
await websocket.send_text("FONT_BATCH_END") # Ensure we unblock client
elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"):
# 格式: GET_FONT_UNICODE:12345 (decimal) or GET_FONT:0xA1A1 (hex)
try:
is_unicode = text.startswith("GET_FONT_UNICODE:")
code_str = text.split(":")[1]
target_code_str = code_str # Used for response
if is_unicode:
unicode_val = int(code_str)
char = chr(unicode_val)
try:
gb_bytes = char.encode('gb2312')
if len(gb_bytes) == 2:
code = struct.unpack('>H', gb_bytes)[0]
else:
print(f"Character {char} is not a valid 2-byte GB2312 char")
continue
except Exception as e:
print(f"Failed to encode {char} to gb2312: {e}")
continue
if text.startswith("GET_FONTS_BATCH:"):
await handle_font_request(websocket, text, text.split(":", 1)[1])
elif text.startswith("GET_FONT_FRAGMENT:"):
await handle_font_request(websocket, text, text.split(":", 1)[1])
elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"):
parts = text.split(":", 1)
await handle_font_request(websocket, parts[0], parts[1] if len(parts) > 1 else "")
else:
code = int(code_str, 16)
await handle_font_request(websocket, text, "")
except Exception as e:
print(f"Font request error: {e}")
await websocket.send_text("FONT_BATCH_END:0:0")
# 计算偏移量
# GB2312 编码范围0xA1A1 - 0xFEFE