1
This commit is contained in:
@@ -31,6 +31,7 @@ HIGH_FREQ_UNICODE = [ord(c) for c in HIGH_FREQ_CHARS]
|
||||
# 字体缓存
|
||||
font_cache = {}
|
||||
font_md5 = {}
|
||||
font_data_buffer = None
|
||||
|
||||
def calculate_md5(filepath):
|
||||
"""计算文件的MD5哈希值"""
|
||||
@@ -44,7 +45,7 @@ def calculate_md5(filepath):
|
||||
|
||||
def init_font_cache():
|
||||
"""初始化字体缓存和MD5"""
|
||||
global font_cache, font_md5
|
||||
global font_cache, font_md5, font_data_buffer
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
|
||||
@@ -55,24 +56,18 @@ def init_font_cache():
|
||||
font_md5 = calculate_md5(font_path)
|
||||
print(f"Font MD5: {font_md5}")
|
||||
|
||||
# 预加载高频字到缓存
|
||||
# 加载整个字体文件到内存
|
||||
try:
|
||||
with open(font_path, "rb") as f:
|
||||
font_data_buffer = f.read()
|
||||
print(f"Loaded font file into memory: {len(font_data_buffer)} bytes")
|
||||
except Exception as e:
|
||||
print(f"Error loading font file: {e}")
|
||||
font_data_buffer = None
|
||||
|
||||
# 预加载高频字到缓存 (仍然保留以便快速访问)
|
||||
for unicode_val in HIGH_FREQ_UNICODE:
|
||||
try:
|
||||
char = chr(unicode_val)
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
except:
|
||||
pass
|
||||
get_font_data(unicode_val)
|
||||
print(f"Preloaded {len(font_cache)} high-frequency characters")
|
||||
|
||||
# 启动时初始化字体缓存
|
||||
@@ -104,6 +99,114 @@ THUMB_SIZE = 245
|
||||
font_request_queue = {}
|
||||
FONT_RETRY_MAX = 3
|
||||
|
||||
# 图片生成任务管理
|
||||
class ImageGenerationTask:
|
||||
"""图片生成任务管理类"""
|
||||
def __init__(self, task_id: str, asr_text: str, websocket: WebSocket):
|
||||
self.task_id = task_id
|
||||
self.asr_text = asr_text
|
||||
self.websocket = websocket
|
||||
self.status = "pending" # pending, optimizing, generating, completed, failed
|
||||
self.progress = 0
|
||||
self.message = ""
|
||||
self.result = None
|
||||
self.error = None
|
||||
|
||||
# 存储活跃的图片生成任务
|
||||
active_tasks = {}
|
||||
task_counter = 0
|
||||
|
||||
|
||||
async def start_async_image_generation(websocket: WebSocket, asr_text: str):
|
||||
"""异步启动图片生成任务,不阻塞WebSocket连接"""
|
||||
global task_counter, active_tasks
|
||||
|
||||
task_id = f"task_{task_counter}_{int(time.time() * 1000)}"
|
||||
task_counter += 1
|
||||
|
||||
task = ImageGenerationTask(task_id, asr_text, websocket)
|
||||
active_tasks[task_id] = task
|
||||
|
||||
print(f"Starting async image generation task: {task_id}")
|
||||
|
||||
await websocket.send_text(f"TASK_ID:{task_id}")
|
||||
|
||||
def progress_callback(progress: int, message: str):
|
||||
"""进度回调函数"""
|
||||
task.progress = progress
|
||||
task.message = message
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}"),
|
||||
asyncio.get_event_loop()
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error sending progress: {e}")
|
||||
|
||||
try:
|
||||
task.status = "optimizing"
|
||||
|
||||
await websocket.send_text("STATUS:OPTIMIZING:正在优化提示词...")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback)
|
||||
|
||||
await websocket.send_text(f"PROMPT:{optimized_prompt}")
|
||||
task.optimized_prompt = optimized_prompt
|
||||
|
||||
task.status = "generating"
|
||||
await websocket.send_text("STATUS:RENDERING:正在生成图片,请稍候...")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)
|
||||
|
||||
task.result = image_path
|
||||
|
||||
if image_path and os.path.exists(image_path):
|
||||
task.status = "completed"
|
||||
await websocket.send_text("STATUS:COMPLETE:图片生成完成")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
await send_image_to_client(websocket, image_path)
|
||||
else:
|
||||
task.status = "failed"
|
||||
task.error = "图片生成失败"
|
||||
await websocket.send_text("IMAGE_ERROR:图片生成失败")
|
||||
await websocket.send_text("STATUS:ERROR:图片生成失败")
|
||||
|
||||
except Exception as e:
|
||||
task.status = "failed"
|
||||
task.error = str(e)
|
||||
print(f"Image generation task error: {e}")
|
||||
await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
|
||||
await websocket.send_text("STATUS:ERROR:图片生成出错")
|
||||
finally:
|
||||
if task_id in active_tasks:
|
||||
del active_tasks[task_id]
|
||||
|
||||
return task
|
||||
|
||||
|
||||
async def send_image_to_client(websocket: WebSocket, image_path: str):
|
||||
"""发送图片数据到客户端"""
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
|
||||
print(f"Sending image to ESP32, size: {len(image_data)} bytes")
|
||||
|
||||
# Send start marker
|
||||
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
|
||||
|
||||
# Send binary data directly
|
||||
chunk_size = 4096 # Increased chunk size for binary
|
||||
for i in range(0, len(image_data), chunk_size):
|
||||
chunk = image_data[i:i+chunk_size]
|
||||
await websocket.send_bytes(chunk)
|
||||
|
||||
# Send end marker
|
||||
await websocket.send_text("IMAGE_END")
|
||||
print("Image sent to ESP32 (Binary)")
|
||||
|
||||
|
||||
def get_font_data(unicode_val):
|
||||
"""从字体文件获取单个字符数据(带缓存)"""
|
||||
@@ -121,20 +224,27 @@ def get_font_data(unicode_val):
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
if font_data_buffer:
|
||||
if offset + 32 <= len(font_data_buffer):
|
||||
font_data = font_data_buffer[offset:offset+32]
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
else:
|
||||
# Fallback to file reading if buffer failed
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
@@ -333,10 +443,13 @@ def process_chunk_32_to_16(chunk_bytes, gain=1.0):
|
||||
return processed_chunk
|
||||
|
||||
|
||||
def optimize_prompt(asr_text):
|
||||
def optimize_prompt(asr_text, progress_callback=None):
|
||||
"""使用大模型优化提示词"""
|
||||
print(f"Optimizing prompt for: {asr_text}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(0, "正在准备优化提示词...")
|
||||
|
||||
system_prompt = """你是一个AI图像提示词优化专家。将用户简短的语音识别结果转化为详细的、适合AI图像生成的英文提示词。
|
||||
要求:
|
||||
1. 保留核心内容和主要元素
|
||||
@@ -346,6 +459,9 @@ def optimize_prompt(asr_text):
|
||||
5. 不要添加多余解释,直接输出优化后的提示词"""
|
||||
|
||||
try:
|
||||
if progress_callback:
|
||||
progress_callback(10, "正在调用AI优化提示词...")
|
||||
|
||||
response = Generation.call(
|
||||
model='qwen-turbo',
|
||||
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
|
||||
@@ -356,31 +472,76 @@ def optimize_prompt(asr_text):
|
||||
if response.status_code == 200:
|
||||
optimized = response.output.choices[0].message.content.strip()
|
||||
print(f"Optimized prompt: {optimized}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||
|
||||
return optimized
|
||||
else:
|
||||
print(f"Prompt optimization failed: {response.code} - {response.message}")
|
||||
if progress_callback:
|
||||
progress_callback(0, f"提示词优化失败: {response.message}")
|
||||
return asr_text
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error optimizing prompt: {e}")
|
||||
if progress_callback:
|
||||
progress_callback(0, f"提示词优化出错: {str(e)}")
|
||||
return asr_text
|
||||
|
||||
|
||||
def generate_image(prompt, websocket=None):
|
||||
"""调用万相文生图API生成图片"""
|
||||
def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2):
|
||||
"""调用万相文生图API生成图片
|
||||
|
||||
Args:
|
||||
prompt: 图像生成提示词
|
||||
progress_callback: 进度回调函数 (progress, message)
|
||||
retry_count: 当前重试次数
|
||||
max_retries: 最大重试次数
|
||||
"""
|
||||
print(f"Generating image for prompt: {prompt}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(35, "正在请求AI生成图片...")
|
||||
|
||||
try:
|
||||
response = ImageSynthesis.call(
|
||||
model='wan2.6-t2i',
|
||||
prompt=prompt,
|
||||
size='512x512',
|
||||
n=1
|
||||
model='wanx2.0-t2i-turbo',
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
image_url = response.output['results'][0]['url']
|
||||
print(f"Image generated, downloading from: {image_url}")
|
||||
task_status = response.output.get('task_status')
|
||||
|
||||
if task_status == 'PENDING' or task_status == 'RUNNING':
|
||||
print("Waiting for image generation to complete...")
|
||||
if progress_callback:
|
||||
progress_callback(45, "AI正在生成图片中...")
|
||||
|
||||
import time
|
||||
task_id = response.output.get('task_id')
|
||||
max_wait = 120
|
||||
waited = 0
|
||||
while waited < max_wait:
|
||||
time.sleep(2)
|
||||
waited += 2
|
||||
task_result = ImageSynthesis.fetch(task_id)
|
||||
if task_result.output.task_status == 'SUCCEEDED':
|
||||
response.output = task_result.output
|
||||
break
|
||||
elif task_result.output.task_status == 'FAILED':
|
||||
error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error'
|
||||
print(f"Image generation failed: {error_msg}")
|
||||
if progress_callback:
|
||||
progress_callback(35, f"图片生成失败: {error_msg}")
|
||||
return None
|
||||
|
||||
if response.output.get('task_status') == 'SUCCEEDED':
|
||||
image_url = response.output['results'][0]['url']
|
||||
print(f"Image generated, downloading from: {image_url}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(70, "正在下载生成的图片...")
|
||||
|
||||
import urllib.request
|
||||
urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE)
|
||||
@@ -392,6 +553,9 @@ def generate_image(prompt, websocket=None):
|
||||
shutil.copy(GENERATED_IMAGE_FILE, output_path)
|
||||
print(f"Image also saved to {output_path}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(80, "正在处理图片...")
|
||||
|
||||
# 缩放图片并转换为RGB565格式
|
||||
try:
|
||||
from PIL import Image
|
||||
@@ -422,21 +586,50 @@ def generate_image(prompt, websocket=None):
|
||||
f.write(rgb565_data)
|
||||
|
||||
print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(100, "图片生成完成!")
|
||||
|
||||
return GENERATED_THUMB_FILE
|
||||
|
||||
except ImportError:
|
||||
print("PIL not available, sending original image")
|
||||
if progress_callback:
|
||||
progress_callback(100, "图片生成完成!(原始格式)")
|
||||
return GENERATED_IMAGE_FILE
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {e}")
|
||||
if progress_callback:
|
||||
progress_callback(80, f"图片处理出错: {str(e)}")
|
||||
return GENERATED_IMAGE_FILE
|
||||
else:
|
||||
print(f"Image generation failed: {response.code} - {response.message}")
|
||||
return None
|
||||
error_msg = f"{response.code} - {response.message}"
|
||||
print(f"Image generation failed: {error_msg}")
|
||||
|
||||
# 重试机制
|
||||
if retry_count < max_retries:
|
||||
print(f"Retrying... ({retry_count + 1}/{max_retries})")
|
||||
if progress_callback:
|
||||
progress_callback(35, f"图片生成失败,正在重试 ({retry_count + 1}/{max_retries})...")
|
||||
return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
|
||||
else:
|
||||
if progress_callback:
|
||||
progress_callback(35, f"图片生成失败: {error_msg}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating image: {e}")
|
||||
return None
|
||||
|
||||
# 重试机制
|
||||
if retry_count < max_retries:
|
||||
print(f"Retrying after error... ({retry_count + 1}/{max_retries})")
|
||||
if progress_callback:
|
||||
progress_callback(35, f"生成出错,正在重试 ({retry_count + 1}/{max_retries})...")
|
||||
return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
|
||||
else:
|
||||
if progress_callback:
|
||||
progress_callback(35, f"图片生成出错: {str(e)}")
|
||||
return None
|
||||
|
||||
@app.websocket("/ws/audio")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
@@ -554,132 +747,36 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
# 先发送 ASR 文字到 ESP32 显示
|
||||
await websocket.send_text(f"ASR:{asr_text}")
|
||||
await websocket.send_text("GENERATING_IMAGE:正在优化提示词...")
|
||||
|
||||
# 等待一会让 ESP32 显示文字
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# 优化提示词
|
||||
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text)
|
||||
|
||||
await websocket.send_text(f"PROMPT:{optimized_prompt}")
|
||||
await websocket.send_text("GENERATING_IMAGE:正在生成图片,请稍候...")
|
||||
|
||||
# 调用文生图API
|
||||
image_path = await asyncio.to_thread(generate_image, optimized_prompt)
|
||||
|
||||
if image_path and os.path.exists(image_path):
|
||||
# 读取图片并发送回ESP32
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
|
||||
print(f"Sending image to ESP32, size: {len(image_data)} bytes")
|
||||
|
||||
# 使用hex编码发送(每个字节2个字符)
|
||||
image_hex = image_data.hex()
|
||||
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
|
||||
|
||||
# 分片发送图片数据
|
||||
chunk_size = 1024
|
||||
for i in range(0, len(image_hex), chunk_size):
|
||||
chunk = image_hex[i:i+chunk_size]
|
||||
await websocket.send_text(f"IMAGE_DATA:{chunk}")
|
||||
|
||||
await websocket.send_text("IMAGE_END")
|
||||
print("Image sent to ESP32")
|
||||
else:
|
||||
await websocket.send_text("IMAGE_ERROR:图片生成失败")
|
||||
await start_async_image_generation(websocket, asr_text)
|
||||
else:
|
||||
print("No ASR text, skipping image generation")
|
||||
|
||||
print("Server processing finished.")
|
||||
|
||||
elif text.startswith("GET_FONTS_BATCH:"):
|
||||
# Format: GET_FONTS_BATCH:code1,code2,code3 (decimal unicode)
|
||||
elif text.startswith("GET_TASK_STATUS:"):
|
||||
task_id = text.split(":", 1)[1].strip()
|
||||
if task_id in active_tasks:
|
||||
task = active_tasks[task_id]
|
||||
await websocket.send_text(f"TASK_STATUS:{task_id}:{task.status}:{task.progress}:{task.message}")
|
||||
else:
|
||||
await websocket.send_text(f"TASK_STATUS:{task_id}:unknown:0:任务不存在或已完成")
|
||||
|
||||
elif text.startswith("GET_FONTS_BATCH:") or text.startswith("GET_FONT") or text == "GET_FONT_MD5" or text == "GET_HIGH_FREQ":
|
||||
# 使用新的统一字体处理函数
|
||||
try:
|
||||
codes_str = text.split(":")[1]
|
||||
code_list = codes_str.split(",")
|
||||
print(f"Batch Font Request for {len(code_list)} chars: {code_list}")
|
||||
|
||||
for code_str in code_list:
|
||||
if not code_str: continue
|
||||
|
||||
try:
|
||||
unicode_val = int(code_str)
|
||||
char = chr(unicode_val)
|
||||
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
else:
|
||||
print(f"Character {char} is not a valid 2-byte GB2312 char")
|
||||
# Send empty/dummy? Or just skip.
|
||||
# Better to send something so client doesn't wait forever if it counts responses.
|
||||
# But client probably uses a set of missing chars.
|
||||
continue
|
||||
|
||||
# Calc offset
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
# Read font file
|
||||
# Optimization: Open file once outside loop?
|
||||
# For simplicity, keep it here, OS caching helps.
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
import binascii
|
||||
hex_data = binascii.hexlify(font_data).decode('utf-8')
|
||||
response = f"FONT_DATA:{code_str}:{hex_data}"
|
||||
await websocket.send_text(response)
|
||||
# Small yield to let network flush?
|
||||
# await asyncio.sleep(0.001)
|
||||
except Exception as e:
|
||||
print(f"Error processing batch item {code_str}: {e}")
|
||||
|
||||
# Send a completion marker
|
||||
await websocket.send_text("FONT_BATCH_END")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling BATCH FONT request: {e}")
|
||||
await websocket.send_text("FONT_BATCH_END") # Ensure we unblock client
|
||||
|
||||
elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"):
|
||||
# 格式: GET_FONT_UNICODE:12345 (decimal) or GET_FONT:0xA1A1 (hex)
|
||||
try:
|
||||
is_unicode = text.startswith("GET_FONT_UNICODE:")
|
||||
code_str = text.split(":")[1]
|
||||
|
||||
target_code_str = code_str # Used for response
|
||||
|
||||
if is_unicode:
|
||||
unicode_val = int(code_str)
|
||||
char = chr(unicode_val)
|
||||
try:
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
else:
|
||||
print(f"Character {char} is not a valid 2-byte GB2312 char")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Failed to encode {char} to gb2312: {e}")
|
||||
continue
|
||||
if text.startswith("GET_FONTS_BATCH:"):
|
||||
await handle_font_request(websocket, text, text.split(":", 1)[1])
|
||||
elif text.startswith("GET_FONT_FRAGMENT:"):
|
||||
await handle_font_request(websocket, text, text.split(":", 1)[1])
|
||||
elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"):
|
||||
parts = text.split(":", 1)
|
||||
await handle_font_request(websocket, parts[0], parts[1] if len(parts) > 1 else "")
|
||||
else:
|
||||
code = int(code_str, 16)
|
||||
await handle_font_request(websocket, text, "")
|
||||
except Exception as e:
|
||||
print(f"Font request error: {e}")
|
||||
await websocket.send_text("FONT_BATCH_END:0:0")
|
||||
|
||||
# 计算偏移量
|
||||
# GB2312 编码范围:0xA1A1 - 0xFEFE
|
||||
|
||||
Reference in New Issue
Block a user