This commit is contained in:
jeremygan2021
2026-03-03 22:11:26 +08:00
parent fc92a5feaf
commit 700bc55657
7 changed files with 138 additions and 124 deletions

View File

@@ -6,6 +6,9 @@ import network
import st7789py as st7789
from config import CURRENT_CONFIG
from audio import AudioPlayer, Microphone
# Define colors that might be missing in st7789py
DARKGREY = 0x4208
from display import Display
from websocket_client import WebSocketClient
import uselect
@@ -89,7 +92,7 @@ def draw_mic_icon(display, x, y, active=True):
if not display or not display.tft:
return
color = st7789.GREEN if active else st7789.DARKGREY
color = st7789.GREEN if active else DARKGREY
display.tft.fill_rect(x + 5, y, 10, 5, color)
display.tft.fill_rect(x + 3, y + 5, 14, 10, color)
@@ -137,7 +140,7 @@ def draw_progress_bar(display, x, y, width, height, progress, color=st7789.CYAN)
if not display or not display.tft:
return
display.tft.fill_rect(x, y, width, height, st7789.DARKGREY)
display.tft.fill_rect(x, y, width, height, DARKGREY)
if progress > 0:
bar_width = int(width * min(progress, 1.0))
display.tft.fill_rect(x, y, bar_width, height, color)
@@ -176,7 +179,7 @@ def render_confirm_screen(display, asr_text=""):
display.tft.fill_rect(0, 0, 240, 30, st7789.CYAN)
display.text("说完了吗?", 75, 8, st7789.BLACK)
display.tft.fill_rect(10, 50, 220, 80, st7789.DARKGREY)
display.tft.fill_rect(10, 50, 220, 80, DARKGREY)
display.text(asr_text if asr_text else "未识别到文字", 20, 75, st7789.WHITE)
display.tft.fill_rect(20, 150, 80, 30, st7789.GREEN)

View File

@@ -147,10 +147,13 @@ class WebSocketClient:
# Read payload
data = bytearray(length)
view = memoryview(data)
# Use smaller chunks for readinto to avoid memory allocation issues in MicroPython
pos = 0
while pos < length:
read_len = self.sock.readinto(view[pos:])
chunk_size = min(length - pos, 512)
chunk_view = memoryview(data)[pos:pos + chunk_size]
read_len = self.sock.readinto(chunk_view)
if read_len == 0:
return None
pos += read_len

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -43,6 +43,49 @@ def calculate_md5(filepath):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def get_font_data(unicode_val):
    """Return the 32-byte glyph data for one character, using a cache.

    The code point is encoded as GB2312; the two resulting bytes give an
    area (high byte - 0xA0) and an index (low byte - 0xA0), and the glyph
    lives at ``((area - 1) * 94 + (index - 1)) * 32`` in the font data.
    Successful lookups are stored in ``font_cache`` keyed by code point.

    Args:
        unicode_val: Unicode code point of the requested character.

    Returns:
        The 32 bytes of glyph data, or None when the character has no
        two-byte GB2312 encoding, the offset lies outside the font data,
        the font file cannot be located, or any other error occurs.
    """
    if unicode_val in font_cache:
        return font_cache[unicode_val]

    try:
        gb_bytes = chr(unicode_val).encode('gb2312')
        if len(gb_bytes) != 2:
            return None

        code = struct.unpack('>H', gb_bytes)[0]
        area = (code >> 8) - 0xA0
        index = (code & 0xFF) - 0xA0
        if area < 1 or index < 1:
            return None

        offset = ((area - 1) * 94 + (index - 1)) * 32

        if font_data_buffer:
            # Preloaded font data: serve the glyph straight from memory.
            if offset + 32 > len(font_data_buffer):
                return None
            font_data = font_data_buffer[offset:offset + 32]
            font_cache[unicode_val] = font_data
            return font_data

        # Fallback to file reading if the buffer is not available. Look in
        # the script directory, then one level up, then the working directory.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        candidates = (
            os.path.join(script_dir, FONT_FILE),
            os.path.join(script_dir, "..", FONT_FILE),
            FONT_FILE,
        )
        font_path = next((p for p in candidates if os.path.exists(p)), None)
        if font_path is None:
            return None

        with open(font_path, "rb") as f:
            f.seek(offset)
            font_data = f.read(32)
        if len(font_data) != 32:
            return None
        font_cache[unicode_val] = font_data
        return font_data
    except UnicodeEncodeError:
        # Expected for characters outside GB2312: there is simply no glyph.
        return None
    except Exception as e:
        # Stay best-effort (callers treat None as "no glyph"), but report the
        # failure instead of hiding it, and let KeyboardInterrupt/SystemExit
        # propagate.
        print(f"Error loading font data for {unicode_val!r}: {e}")
        return None
def init_font_cache():
"""初始化字体缓存和MD5"""
global font_cache, font_md5, font_data_buffer
@@ -93,7 +136,7 @@ def get_output_path():
timestamp = time.strftime("%Y%m%d_%H%M%S")
return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png")
THUMB_SIZE = 245
THUMB_SIZE = 240
# 字体请求队列(用于重试机制)
font_request_queue = {}
@@ -131,26 +174,49 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await websocket.send_text(f"TASK_ID:{task_id}")
def progress_callback(progress: int, message: str):
"""进度回调函数"""
# 获取当前事件循环
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 使用队列在线程和主事件循环之间传递消息
message_queue = asyncio.Queue()
async def progress_callback_async(progress: int, message: str):
"""异步进度回调"""
task.progress = progress
task.message = message
try:
asyncio.run_coroutine_threadsafe(
websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}"),
asyncio.get_event_loop()
)
await websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}")
except Exception as e:
print(f"Error sending progress: {e}")
def progress_callback(progress: int, message: str):
"""进度回调函数(在线程中调用)"""
task.progress = progress
task.message = message
# 通过队列在主循环中发送消息
asyncio.run_coroutine_threadsafe(
progress_callback_async(progress, message),
loop
)
try:
task.status = "optimizing"
await websocket.send_text("STATUS:OPTIMIZING:正在优化提示词...")
await asyncio.sleep(0.2)
# 同步调用优化函数
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback)
# 确保返回有效的提示词
if not optimized_prompt:
optimized_prompt = asr_text
print(f"Warning: optimize_prompt returned None, using original text: {asr_text}")
await websocket.send_text(f"PROMPT:{optimized_prompt}")
task.optimized_prompt = optimized_prompt
@@ -158,6 +224,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await websocket.send_text("STATUS:RENDERING:正在生成图片,请稍候...")
await asyncio.sleep(0.2)
# 同步调用图片生成函数
image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)
task.result = image_path
@@ -178,8 +245,11 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
task.status = "failed"
task.error = str(e)
print(f"Image generation task error: {e}")
try:
await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
await websocket.send_text("STATUS:ERROR:图片生成出错")
except Exception as ws_e:
print(f"Error sending error message: {ws_e}")
finally:
if task_id in active_tasks:
del active_tasks[task_id]
@@ -198,7 +268,7 @@ async def send_image_to_client(websocket: WebSocket, image_path: str):
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
# Send binary data directly
chunk_size = 4096 # Increased chunk size for binary
chunk_size = 512 # Decreased chunk size for ESP32 memory stability
for i in range(0, len(image_data), chunk_size):
chunk = image_data[i:i+chunk_size]
await websocket.send_bytes(chunk)
@@ -208,49 +278,7 @@ async def send_image_to_client(websocket: WebSocket, image_path: str):
print("Image sent to ESP32 (Binary)")
def get_font_data(unicode_val):
    """Return the 32-byte glyph data for one character, using a cache.

    The code point is encoded as GB2312; the two resulting bytes give an
    area (high byte - 0xA0) and an index (low byte - 0xA0), and the glyph
    lives at ``((area - 1) * 94 + (index - 1)) * 32`` in the font data.
    Successful lookups are stored in ``font_cache`` keyed by code point.

    Args:
        unicode_val: Unicode code point of the requested character.

    Returns:
        The 32 bytes of glyph data, or None when the character has no
        two-byte GB2312 encoding, the offset lies outside the font data,
        the font file cannot be located, or any other error occurs.
    """
    if unicode_val in font_cache:
        return font_cache[unicode_val]

    try:
        gb_bytes = chr(unicode_val).encode('gb2312')
        if len(gb_bytes) != 2:
            return None

        code = struct.unpack('>H', gb_bytes)[0]
        area = (code >> 8) - 0xA0
        index = (code & 0xFF) - 0xA0
        if area < 1 or index < 1:
            return None

        offset = ((area - 1) * 94 + (index - 1)) * 32

        if font_data_buffer:
            # Preloaded font data: serve the glyph straight from memory.
            if offset + 32 > len(font_data_buffer):
                return None
            font_data = font_data_buffer[offset:offset + 32]
            font_cache[unicode_val] = font_data
            return font_data

        # Fallback to file reading if the buffer is not available. Look in
        # the script directory, then one level up, then the working directory.
        script_dir = os.path.dirname(os.path.abspath(__file__))
        candidates = (
            os.path.join(script_dir, FONT_FILE),
            os.path.join(script_dir, "..", FONT_FILE),
            FONT_FILE,
        )
        font_path = next((p for p in candidates if os.path.exists(p)), None)
        if font_path is None:
            return None

        with open(font_path, "rb") as f:
            f.seek(offset)
            font_data = f.read(32)
        if len(font_data) != 32:
            return None
        font_cache[unicode_val] = font_data
        return font_data
    except UnicodeEncodeError:
        # Expected for characters outside GB2312: there is simply no glyph.
        return None
    except Exception as e:
        # Stay best-effort (callers treat None as "no glyph"), but report the
        # failure instead of hiding it, and let KeyboardInterrupt/SystemExit
        # propagate.
        print(f"Error loading font data for {unicode_val!r}: {e}")
        return None
def send_font_batch_with_retry(websocket, code_list, retry_count=0):
async def send_font_batch_with_retry(websocket, code_list, retry_count=0):
"""批量发送字体数据(带重试机制)"""
global font_request_queue
@@ -269,10 +297,7 @@ def send_font_batch_with_retry(websocket, code_list, retry_count=0):
import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
response = f"FONT_DATA:{code_str}:{hex_data}"
asyncio.run_coroutine_threadsafe(
websocket.send_text(response),
asyncio.get_event_loop()
)
await websocket.send_text(response)
success_codes.add(unicode_val)
else:
failed_codes.append(code_str)
@@ -336,7 +361,7 @@ async def handle_font_request(websocket, message_type, data):
code_list = codes_str.split(",")
print(f"Batch Font Request for {len(code_list)} chars")
success_count, failed = send_font_batch_with_retry(websocket, code_list)
success_count, failed = await send_font_batch_with_retry(websocket, code_list)
print(f"Font batch: {success_count} success, {len(failed)} failed")
# 发送完成标记
@@ -345,7 +370,7 @@ async def handle_font_request(websocket, message_type, data):
# 如果有失败的,进行重试
if failed:
await asyncio.sleep(0.5)
send_font_batch_with_retry(websocket, failed, retry_count=1)
await send_font_batch_with_retry(websocket, failed, retry_count=1)
except Exception as e:
print(f"Error handling batch font request: {e}")
@@ -402,20 +427,31 @@ class MyRecognitionCallback(RecognitionCallback):
self.websocket = websocket
self.loop = loop
self.final_text = "" # 保存最终识别结果
self.sentence_list = [] # 累积所有句子
def on_open(self) -> None:
    """Log the start of an ASR session and clear the accumulated text."""
    print("ASR Session started")
    # Every session starts from an empty sentence list and empty final text.
    self.sentence_list, self.final_text = [], ""
def on_close(self) -> None:
    """Log the end of an ASR session and finalize the combined text.

    When sentences were accumulated, they are concatenated (no separator)
    into ``self.final_text`` and the result is printed; otherwise
    ``self.final_text`` is left untouched.
    """
    print("ASR Session closed")
    if not self.sentence_list:
        return
    # On close, merge all sentences into the complete text.
    combined = "".join(self.sentence_list)
    self.final_text = combined
    print(f"Final combined ASR text: {self.final_text}")
def on_event(self, result: RecognitionResult) -> None:
if result.get_sentence():
text = result.get_sentence()['text']
print(f"ASR Result: {text}")
self.final_text = text # 保存识别结果
# 累积每一句识别结果
self.sentence_list.append(text)
# 同时更新 final_text 以便 Stop 时获取
self.final_text = "".join(self.sentence_list)
# 将识别结果发送回客户端
try:
if self.loop.is_running():
asyncio.run_coroutine_threadsafe(
self.websocket.send_text(f"ASR:{text}"),
self.loop
@@ -470,6 +506,10 @@ def optimize_prompt(asr_text, progress_callback=None):
)
if response.status_code == 200:
if hasattr(response, 'output') and response.output and \
hasattr(response.output, 'choices') and response.output.choices and \
len(response.output.choices) > 0:
optimized = response.output.choices[0].message.content.strip()
print(f"Optimized prompt: {optimized}")
@@ -477,6 +517,20 @@ def optimize_prompt(asr_text, progress_callback=None):
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
return optimized
elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'):
# Handle case where API returns text directly instead of choices
optimized = response.output.text.strip()
print(f"Optimized prompt (direct text): {optimized}")
if progress_callback:
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
return optimized
else:
print(f"Prompt optimization response format error: {response}")
if progress_callback:
progress_callback(0, "提示词优化响应格式错误")
return asr_text
else:
print(f"Prompt optimization failed: {response.code} - {response.message}")
if progress_callback:
@@ -748,7 +802,8 @@ async def websocket_endpoint(websocket: WebSocket):
# 先发送 ASR 文字到 ESP32 显示
await websocket.send_text(f"ASR:{asr_text}")
await start_async_image_generation(websocket, asr_text)
# 使用 create_task 异步执行,避免阻塞主循环处理字体请求
asyncio.create_task(start_async_image_generation(websocket, asr_text))
else:
print("No ASR text, skipping image generation")
@@ -778,53 +833,6 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"Font request error: {e}")
await websocket.send_text("FONT_BATCH_END:0:0")
# 计算偏移量
# GB2312 编码范围0xA1A1 - 0xFEFE
# 区码:高字节 - 0xA0
# 位码:低字节 - 0xA0
area = (code >> 8) - 0xA0
index = (code & 0xFF) - 0xA0
if area >= 1 and index >= 1:
offset = ((area - 1) * 94 + (index - 1)) * 32
# 读取字体文件
# 注意:这里为了简单,每次都打开文件。如果并发高,应该缓存文件句柄或内容。
# 假设字体文件在当前目录或上级目录
# Prioritize finding the file in the script's directory
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, FONT_FILE)
# Fallback: check one level up
if not os.path.exists(font_path):
font_path = os.path.join(script_dir, "..", FONT_FILE)
# Fallback: check current working directory
if not os.path.exists(font_path):
font_path = FONT_FILE
if os.path.exists(font_path):
print(f"Reading font from: {font_path} (Offset: {offset})")
with open(font_path, "rb") as f:
f.seek(offset)
font_data = f.read(32)
if len(font_data) == 32:
import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
# Return the original requested code (unicode or hex) so client can map it back
response = f"FONT_DATA:{target_code_str}:{hex_data}"
# print(f"Sending Font Response: {response[:30]}...")
await websocket.send_text(response)
else:
print(f"Error: Read {len(font_data)} bytes for font data (expected 32)")
else:
print(f"Font file not found: {font_path}")
else:
print(f"Invalid GB2312 code derived: {code:X} (Area: {area}, Index: {index})")
except Exception as e:
print(f"Error handling FONT request: {e}")
elif "bytes" in message:
# 接收音频数据并追加到缓冲区
data = message["bytes"]