1
This commit is contained in:
@@ -43,6 +43,49 @@ def calculate_md5(filepath):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def get_font_data(unicode_val):
    """Fetch the glyph data for a single character (with caching).

    The character is encoded as GB2312; its two bytes are turned into a
    zone ("area") and position ("index") by subtracting 0xA0 from each,
    and the glyph is the 32-byte record at
    ``((area - 1) * 94 + (index - 1)) * 32``.

    The record is sliced from the module-level ``font_data_buffer`` when
    that buffer is non-empty; otherwise it is read from ``FONT_FILE``,
    looked up next to this script, then one directory up, then relative
    to the current working directory. Successful lookups are stored in
    the module-level ``font_cache``.

    Args:
        unicode_val: Unicode code point of the requested character.

    Returns:
        The 32 bytes of glyph data, or ``None`` when the code point is
        invalid, has no two-byte GB2312 encoding, lies outside the
        available font data, or the font file cannot be read.
    """
    if unicode_val in font_cache:
        return font_cache[unicode_val]

    try:
        gb_bytes = chr(unicode_val).encode('gb2312')
    except (TypeError, ValueError, OverflowError):
        # Invalid code point, or a character GB2312 cannot represent
        # (UnicodeEncodeError is a ValueError subclass).
        return None

    # Single-byte results (ASCII) have no entry in the double-byte table.
    if len(gb_bytes) != 2:
        return None

    code = struct.unpack('>H', gb_bytes)[0]
    area = (code >> 8) - 0xA0
    index = (code & 0xFF) - 0xA0
    if area < 1 or index < 1:
        return None

    offset = ((area - 1) * 94 + (index - 1)) * 32

    if font_data_buffer:
        # A loaded buffer is authoritative: an out-of-range offset is a
        # miss and does not fall through to the file.
        if offset + 32 > len(font_data_buffer):
            return None
        font_data = font_data_buffer[offset:offset + 32]
        font_cache[unicode_val] = font_data
        return font_data

    # Fallback to file reading if the buffer is empty or was never loaded.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    font_path = os.path.join(script_dir, FONT_FILE)
    if not os.path.exists(font_path):
        font_path = os.path.join(script_dir, "..", FONT_FILE)
    if not os.path.exists(font_path):
        font_path = FONT_FILE
    if not os.path.exists(font_path):
        return None

    try:
        with open(font_path, "rb") as f:
            f.seek(offset)
            font_data = f.read(32)
    except OSError:
        # Best-effort lookup: an unreadable font file is just a miss.
        return None

    # A short read means the offset is past the end of the font file.
    if len(font_data) != 32:
        return None

    font_cache[unicode_val] = font_data
    return font_data
|
||||
|
||||
|
||||
def init_font_cache():
|
||||
"""初始化字体缓存和MD5"""
|
||||
global font_cache, font_md5, font_data_buffer
|
||||
@@ -93,7 +136,7 @@ def get_output_path():
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png")
|
||||
|
||||
THUMB_SIZE = 245
|
||||
THUMB_SIZE = 240
|
||||
|
||||
# 字体请求队列(用于重试机制)
|
||||
font_request_queue = {}
|
||||
@@ -131,26 +174,49 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
|
||||
|
||||
await websocket.send_text(f"TASK_ID:{task_id}")
|
||||
|
||||
def progress_callback(progress: int, message: str):
|
||||
"""进度回调函数"""
|
||||
# 获取当前事件循环
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 使用队列在线程和主事件循环之间传递消息
|
||||
message_queue = asyncio.Queue()
|
||||
|
||||
async def progress_callback_async(progress: int, message: str):
|
||||
"""异步进度回调"""
|
||||
task.progress = progress
|
||||
task.message = message
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}"),
|
||||
asyncio.get_event_loop()
|
||||
)
|
||||
await websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}")
|
||||
except Exception as e:
|
||||
print(f"Error sending progress: {e}")
|
||||
|
||||
def progress_callback(progress: int, message: str):
|
||||
"""进度回调函数(在线程中调用)"""
|
||||
task.progress = progress
|
||||
task.message = message
|
||||
# 通过队列在主循环中发送消息
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
progress_callback_async(progress, message),
|
||||
loop
|
||||
)
|
||||
|
||||
try:
|
||||
task.status = "optimizing"
|
||||
|
||||
await websocket.send_text("STATUS:OPTIMIZING:正在优化提示词...")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# 同步调用优化函数
|
||||
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback)
|
||||
|
||||
# 确保返回有效的提示词
|
||||
if not optimized_prompt:
|
||||
optimized_prompt = asr_text
|
||||
print(f"Warning: optimize_prompt returned None, using original text: {asr_text}")
|
||||
|
||||
await websocket.send_text(f"PROMPT:{optimized_prompt}")
|
||||
task.optimized_prompt = optimized_prompt
|
||||
|
||||
@@ -158,6 +224,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
|
||||
await websocket.send_text("STATUS:RENDERING:正在生成图片,请稍候...")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# 同步调用图片生成函数
|
||||
image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)
|
||||
|
||||
task.result = image_path
|
||||
@@ -178,8 +245,11 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
|
||||
task.status = "failed"
|
||||
task.error = str(e)
|
||||
print(f"Image generation task error: {e}")
|
||||
await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
|
||||
await websocket.send_text("STATUS:ERROR:图片生成出错")
|
||||
try:
|
||||
await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
|
||||
await websocket.send_text("STATUS:ERROR:图片生成出错")
|
||||
except Exception as ws_e:
|
||||
print(f"Error sending error message: {ws_e}")
|
||||
finally:
|
||||
if task_id in active_tasks:
|
||||
del active_tasks[task_id]
|
||||
@@ -198,7 +268,7 @@ async def send_image_to_client(websocket: WebSocket, image_path: str):
|
||||
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
|
||||
|
||||
# Send binary data directly
|
||||
chunk_size = 4096 # Increased chunk size for binary
|
||||
chunk_size = 512 # Decreased chunk size for ESP32 memory stability
|
||||
for i in range(0, len(image_data), chunk_size):
|
||||
chunk = image_data[i:i+chunk_size]
|
||||
await websocket.send_bytes(chunk)
|
||||
@@ -208,49 +278,7 @@ async def send_image_to_client(websocket: WebSocket, image_path: str):
|
||||
print("Image sent to ESP32 (Binary)")
|
||||
|
||||
|
||||
def get_font_data(unicode_val):
    """Fetch the glyph data for a single character (with caching).

    The character is encoded as GB2312; its two bytes are turned into a
    zone ("area") and position ("index") by subtracting 0xA0 from each,
    and the glyph is the 32-byte record at
    ``((area - 1) * 94 + (index - 1)) * 32``.

    The record is sliced from the module-level ``font_data_buffer`` when
    that buffer is non-empty; otherwise it is read from ``FONT_FILE``,
    looked up next to this script, then one directory up, then relative
    to the current working directory. Successful lookups are stored in
    the module-level ``font_cache``.

    Args:
        unicode_val: Unicode code point of the requested character.

    Returns:
        The 32 bytes of glyph data, or ``None`` when the code point is
        invalid, has no two-byte GB2312 encoding, lies outside the
        available font data, or the font file cannot be read.
    """
    if unicode_val in font_cache:
        return font_cache[unicode_val]

    try:
        gb_bytes = chr(unicode_val).encode('gb2312')
    except (TypeError, ValueError, OverflowError):
        # Invalid code point, or a character GB2312 cannot represent
        # (UnicodeEncodeError is a ValueError subclass).
        return None

    # Single-byte results (ASCII) have no entry in the double-byte table.
    if len(gb_bytes) != 2:
        return None

    code = struct.unpack('>H', gb_bytes)[0]
    area = (code >> 8) - 0xA0
    index = (code & 0xFF) - 0xA0
    if area < 1 or index < 1:
        return None

    offset = ((area - 1) * 94 + (index - 1)) * 32

    if font_data_buffer:
        # A loaded buffer is authoritative: an out-of-range offset is a
        # miss and does not fall through to the file.
        if offset + 32 > len(font_data_buffer):
            return None
        font_data = font_data_buffer[offset:offset + 32]
        font_cache[unicode_val] = font_data
        return font_data

    # Fallback to file reading if the buffer is empty or was never loaded.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    font_path = os.path.join(script_dir, FONT_FILE)
    if not os.path.exists(font_path):
        font_path = os.path.join(script_dir, "..", FONT_FILE)
    if not os.path.exists(font_path):
        font_path = FONT_FILE
    if not os.path.exists(font_path):
        return None

    try:
        with open(font_path, "rb") as f:
            f.seek(offset)
            font_data = f.read(32)
    except OSError:
        # Best-effort lookup: an unreadable font file is just a miss.
        return None

    # A short read means the offset is past the end of the font file.
    if len(font_data) != 32:
        return None

    font_cache[unicode_val] = font_data
    return font_data
|
||||
|
||||
|
||||
def send_font_batch_with_retry(websocket, code_list, retry_count=0):
|
||||
async def send_font_batch_with_retry(websocket, code_list, retry_count=0):
|
||||
"""批量发送字体数据(带重试机制)"""
|
||||
global font_request_queue
|
||||
|
||||
@@ -269,10 +297,7 @@ def send_font_batch_with_retry(websocket, code_list, retry_count=0):
|
||||
import binascii
|
||||
hex_data = binascii.hexlify(font_data).decode('utf-8')
|
||||
response = f"FONT_DATA:{code_str}:{hex_data}"
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
websocket.send_text(response),
|
||||
asyncio.get_event_loop()
|
||||
)
|
||||
await websocket.send_text(response)
|
||||
success_codes.add(unicode_val)
|
||||
else:
|
||||
failed_codes.append(code_str)
|
||||
@@ -336,7 +361,7 @@ async def handle_font_request(websocket, message_type, data):
|
||||
code_list = codes_str.split(",")
|
||||
print(f"Batch Font Request for {len(code_list)} chars")
|
||||
|
||||
success_count, failed = send_font_batch_with_retry(websocket, code_list)
|
||||
success_count, failed = await send_font_batch_with_retry(websocket, code_list)
|
||||
print(f"Font batch: {success_count} success, {len(failed)} failed")
|
||||
|
||||
# 发送完成标记
|
||||
@@ -345,7 +370,7 @@ async def handle_font_request(websocket, message_type, data):
|
||||
# 如果有失败的,进行重试
|
||||
if failed:
|
||||
await asyncio.sleep(0.5)
|
||||
send_font_batch_with_retry(websocket, failed, retry_count=1)
|
||||
await send_font_batch_with_retry(websocket, failed, retry_count=1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling batch font request: {e}")
|
||||
@@ -402,24 +427,35 @@ class MyRecognitionCallback(RecognitionCallback):
|
||||
self.websocket = websocket
|
||||
self.loop = loop
|
||||
self.final_text = "" # 保存最终识别结果
|
||||
self.sentence_list = [] # 累积所有句子
|
||||
|
||||
def on_open(self) -> None:
|
||||
print("ASR Session started")
|
||||
self.sentence_list = []
|
||||
self.final_text = ""
|
||||
|
||||
def on_close(self) -> None:
|
||||
print("ASR Session closed")
|
||||
# 关闭时将所有句子合并为完整文本
|
||||
if self.sentence_list:
|
||||
self.final_text = "".join(self.sentence_list)
|
||||
print(f"Final combined ASR text: {self.final_text}")
|
||||
|
||||
def on_event(self, result: RecognitionResult) -> None:
|
||||
if result.get_sentence():
|
||||
text = result.get_sentence()['text']
|
||||
print(f"ASR Result: {text}")
|
||||
self.final_text = text # 保存识别结果
|
||||
# 累积每一句识别结果
|
||||
self.sentence_list.append(text)
|
||||
# 同时更新 final_text 以便 Stop 时获取
|
||||
self.final_text = "".join(self.sentence_list)
|
||||
# 将识别结果发送回客户端
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.websocket.send_text(f"ASR:{text}"),
|
||||
self.loop
|
||||
)
|
||||
if self.loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.websocket.send_text(f"ASR:{text}"),
|
||||
self.loop
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to send ASR result to client: {e}")
|
||||
|
||||
@@ -470,13 +506,31 @@ def optimize_prompt(asr_text, progress_callback=None):
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
optimized = response.output.choices[0].message.content.strip()
|
||||
print(f"Optimized prompt: {optimized}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||
|
||||
return optimized
|
||||
if hasattr(response, 'output') and response.output and \
|
||||
hasattr(response.output, 'choices') and response.output.choices and \
|
||||
len(response.output.choices) > 0:
|
||||
|
||||
optimized = response.output.choices[0].message.content.strip()
|
||||
print(f"Optimized prompt: {optimized}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||
|
||||
return optimized
|
||||
elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'):
|
||||
# Handle case where API returns text directly instead of choices
|
||||
optimized = response.output.text.strip()
|
||||
print(f"Optimized prompt (direct text): {optimized}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||
|
||||
return optimized
|
||||
else:
|
||||
print(f"Prompt optimization response format error: {response}")
|
||||
if progress_callback:
|
||||
progress_callback(0, "提示词优化响应格式错误")
|
||||
return asr_text
|
||||
else:
|
||||
print(f"Prompt optimization failed: {response.code} - {response.message}")
|
||||
if progress_callback:
|
||||
@@ -748,7 +802,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
# 先发送 ASR 文字到 ESP32 显示
|
||||
await websocket.send_text(f"ASR:{asr_text}")
|
||||
|
||||
await start_async_image_generation(websocket, asr_text)
|
||||
# 使用 create_task 异步执行,避免阻塞主循环处理字体请求
|
||||
asyncio.create_task(start_async_image_generation(websocket, asr_text))
|
||||
else:
|
||||
print("No ASR text, skipping image generation")
|
||||
|
||||
@@ -777,53 +832,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
except Exception as e:
|
||||
print(f"Font request error: {e}")
|
||||
await websocket.send_text("FONT_BATCH_END:0:0")
|
||||
|
||||
# 计算偏移量
|
||||
# GB2312 编码范围:0xA1A1 - 0xFEFE
|
||||
# 区码:高字节 - 0xA0
|
||||
# 位码:低字节 - 0xA0
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
# 读取字体文件
|
||||
# 注意:这里为了简单,每次都打开文件。如果并发高,应该缓存文件句柄或内容。
|
||||
# 假设字体文件在当前目录或上级目录
|
||||
# Prioritize finding the file in the script's directory
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
|
||||
# Fallback: check one level up
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
|
||||
# Fallback: check current working directory
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
print(f"Reading font from: {font_path} (Offset: {offset})")
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
|
||||
if len(font_data) == 32:
|
||||
import binascii
|
||||
hex_data = binascii.hexlify(font_data).decode('utf-8')
|
||||
# Return the original requested code (unicode or hex) so client can map it back
|
||||
response = f"FONT_DATA:{target_code_str}:{hex_data}"
|
||||
# print(f"Sending Font Response: {response[:30]}...")
|
||||
await websocket.send_text(response)
|
||||
else:
|
||||
print(f"Error: Read {len(font_data)} bytes for font data (expected 32)")
|
||||
else:
|
||||
print(f"Font file not found: {font_path}")
|
||||
else:
|
||||
print(f"Invalid GB2312 code derived: {code:X} (Area: {area}, Index: {index})")
|
||||
except Exception as e:
|
||||
print(f"Error handling FONT request: {e}")
|
||||
|
||||
elif "bytes" in message:
|
||||
# 接收音频数据并追加到缓冲区
|
||||
|
||||
Reference in New Issue
Block a user