diff --git a/main.py b/main.py index 5b0c66d..3d0cabf 100644 --- a/main.py +++ b/main.py @@ -6,6 +6,9 @@ import network import st7789py as st7789 from config import CURRENT_CONFIG from audio import AudioPlayer, Microphone + +# Define colors that might be missing in st7789py +DARKGREY = 0x4208 from display import Display from websocket_client import WebSocketClient import uselect @@ -89,7 +92,7 @@ def draw_mic_icon(display, x, y, active=True): if not display or not display.tft: return - color = st7789.GREEN if active else st7789.DARKGREY + color = st7789.GREEN if active else DARKGREY display.tft.fill_rect(x + 5, y, 10, 5, color) display.tft.fill_rect(x + 3, y + 5, 14, 10, color) @@ -137,7 +140,7 @@ def draw_progress_bar(display, x, y, width, height, progress, color=st7789.CYAN) if not display or not display.tft: return - display.tft.fill_rect(x, y, width, height, st7789.DARKGREY) + display.tft.fill_rect(x, y, width, height, DARKGREY) if progress > 0: bar_width = int(width * min(progress, 1.0)) display.tft.fill_rect(x, y, bar_width, height, color) @@ -176,7 +179,7 @@ def render_confirm_screen(display, asr_text=""): display.tft.fill_rect(0, 0, 240, 30, st7789.CYAN) display.text("说完了吗?", 75, 8, st7789.BLACK) - display.tft.fill_rect(10, 50, 220, 80, st7789.DARKGREY) + display.tft.fill_rect(10, 50, 220, 80, DARKGREY) display.text(asr_text if asr_text else "未识别到文字", 20, 75, st7789.WHITE) display.tft.fill_rect(20, 150, 80, 30, st7789.GREEN) diff --git a/websocket_client.py b/websocket_client.py index cb6d08e..cdadab1 100644 --- a/websocket_client.py +++ b/websocket_client.py @@ -147,10 +147,13 @@ class WebSocketClient: # Read payload data = bytearray(length) - view = memoryview(data) + + # Use smaller chunks for readinto to avoid memory allocation issues in MicroPython pos = 0 while pos < length: - read_len = self.sock.readinto(view[pos:]) + chunk_size = min(length - pos, 512) + chunk_view = memoryview(data)[pos:pos + chunk_size] + read_len = self.sock.readinto(chunk_view) if read_len == 0: return None pos += read_len diff --git a/websocket_server/__pycache__/server.cpython-312.pyc b/websocket_server/__pycache__/server.cpython-312.pyc index 1c29fce..77709ee 100644 Binary files a/websocket_server/__pycache__/server.cpython-312.pyc and b/websocket_server/__pycache__/server.cpython-312.pyc differ diff --git a/websocket_server/generated_thumb.bin b/websocket_server/generated_thumb.bin new file mode 100644 index 0000000..cebd8cf Binary files /dev/null and b/websocket_server/generated_thumb.bin differ diff --git a/websocket_server/received_audio.mp3 b/websocket_server/received_audio.mp3 index bd5948a..05835ee 100644 Binary files a/websocket_server/received_audio.mp3 and b/websocket_server/received_audio.mp3 differ diff --git a/websocket_server/received_audio.raw b/websocket_server/received_audio.raw index ac6dcae..d864a90 100644 Binary files a/websocket_server/received_audio.raw and b/websocket_server/received_audio.raw differ diff --git a/websocket_server/server.py b/websocket_server/server.py index 4cab5c0..424e2ad 100644 --- a/websocket_server/server.py +++ b/websocket_server/server.py @@ -43,6 +43,49 @@ def calculate_md5(filepath): hash_md5.update(chunk) return hash_md5.hexdigest() + +def get_font_data(unicode_val): + """从字体文件获取单个字符数据(带缓存)""" + if unicode_val in font_cache: + return font_cache[unicode_val] + + try: + char = chr(unicode_val) + gb_bytes = char.encode('gb2312') + if len(gb_bytes) == 2: + code = struct.unpack('>H', gb_bytes)[0] + area = (code >> 8) - 0xA0 + index = (code & 0xFF) - 0xA0 + + if area >= 1 and index >= 1: + offset = ((area - 1) * 94 + (index - 1)) * 32 + + if font_data_buffer: + if offset + 32 <= len(font_data_buffer): + font_data = font_data_buffer[offset:offset+32] + font_cache[unicode_val] = font_data + return font_data + else: + # Fallback to file reading if buffer failed + script_dir = os.path.dirname(os.path.abspath(__file__)) + font_path = os.path.join(script_dir, FONT_FILE) + if not os.path.exists(font_path): + font_path = os.path.join(script_dir, "..", FONT_FILE) + if not os.path.exists(font_path): + font_path = FONT_FILE + + if os.path.exists(font_path): + with open(font_path, "rb") as f: + f.seek(offset) + font_data = f.read(32) + if len(font_data) == 32: + font_cache[unicode_val] = font_data + return font_data + except: + pass + return None + + def init_font_cache(): """初始化字体缓存和MD5""" global font_cache, font_md5, font_data_buffer @@ -93,7 +136,7 @@ def get_output_path(): timestamp = time.strftime("%Y%m%d_%H%M%S") return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png") -THUMB_SIZE = 245 +THUMB_SIZE = 240 # 字体请求队列(用于重试机制) font_request_queue = {} @@ -131,26 +174,49 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str): await websocket.send_text(f"TASK_ID:{task_id}") - def progress_callback(progress: int, message: str): - """进度回调函数""" + # 获取当前事件循环 + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # 使用队列在线程和主事件循环之间传递消息 + message_queue = asyncio.Queue() + + async def progress_callback_async(progress: int, message: str): + """异步进度回调""" task.progress = progress task.message = message try: - asyncio.run_coroutine_threadsafe( - websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}"), - asyncio.get_event_loop() - ) + await websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}") except Exception as e: print(f"Error sending progress: {e}") + def progress_callback(progress: int, message: str): + """进度回调函数(在线程中调用)""" + task.progress = progress + task.message = message + # 通过队列在主循环中发送消息 + asyncio.run_coroutine_threadsafe( + progress_callback_async(progress, message), + loop + ) + try: task.status = "optimizing" await websocket.send_text("STATUS:OPTIMIZING:正在优化提示词...") await asyncio.sleep(0.2) + # 同步调用优化函数 optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback) + # 确保返回有效的提示词 + if not optimized_prompt: + optimized_prompt = asr_text + print(f"Warning: optimize_prompt returned None, using original text: {asr_text}") + await websocket.send_text(f"PROMPT:{optimized_prompt}") task.optimized_prompt = optimized_prompt @@ -158,6 +224,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str): await websocket.send_text("STATUS:RENDERING:正在生成图片,请稍候...") await asyncio.sleep(0.2) + # 同步调用图片生成函数 image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback) task.result = image_path @@ -178,8 +245,11 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str): task.status = "failed" task.error = str(e) print(f"Image generation task error: {e}") - await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}") - await websocket.send_text("STATUS:ERROR:图片生成出错") + try: + await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}") + await websocket.send_text("STATUS:ERROR:图片生成出错") + except Exception as ws_e: + print(f"Error sending error message: {ws_e}") finally: if task_id in active_tasks: del active_tasks[task_id] @@ -198,7 +268,7 @@ async def send_image_to_client(websocket: WebSocket, image_path: str): await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}") # Send binary data directly - chunk_size = 4096 # Increased chunk size for binary + chunk_size = 512 # Decreased chunk size for ESP32 memory stability for i in range(0, len(image_data), chunk_size): chunk = image_data[i:i+chunk_size] await websocket.send_bytes(chunk) @@ -208,49 +278,7 @@ async def send_image_to_client(websocket: WebSocket, image_path: str): print("Image sent to ESP32 (Binary)") -def get_font_data(unicode_val): - """从字体文件获取单个字符数据(带缓存)""" - if unicode_val in font_cache: - return font_cache[unicode_val] - - try: - char = chr(unicode_val) - gb_bytes = char.encode('gb2312') - if len(gb_bytes) == 2: - code = struct.unpack('>H', gb_bytes)[0] - area = (code >> 8) - 0xA0 - index = (code & 0xFF) - 0xA0 - - if area >= 1 and index >= 1: - offset = ((area - 1) * 94 + (index - 1)) * 32 - - if font_data_buffer: - if offset + 32 <= len(font_data_buffer): - font_data = font_data_buffer[offset:offset+32] - font_cache[unicode_val] = font_data - return font_data - else: - # Fallback to file reading if buffer failed - script_dir = os.path.dirname(os.path.abspath(__file__)) - font_path = os.path.join(script_dir, FONT_FILE) - if not os.path.exists(font_path): - font_path = os.path.join(script_dir, "..", FONT_FILE) - if not os.path.exists(font_path): - font_path = FONT_FILE - - if os.path.exists(font_path): - with open(font_path, "rb") as f: - f.seek(offset) - font_data = f.read(32) - if len(font_data) == 32: - font_cache[unicode_val] = font_data - return font_data - except: - pass - return None - - -def send_font_batch_with_retry(websocket, code_list, retry_count=0): +async def send_font_batch_with_retry(websocket, code_list, retry_count=0): """批量发送字体数据(带重试机制)""" global font_request_queue @@ -269,10 +297,7 @@ def send_font_batch_with_retry(websocket, code_list, retry_count=0): import binascii hex_data = binascii.hexlify(font_data).decode('utf-8') response = f"FONT_DATA:{code_str}:{hex_data}" - asyncio.run_coroutine_threadsafe( - websocket.send_text(response), - asyncio.get_event_loop() - ) + await websocket.send_text(response) success_codes.add(unicode_val) else: failed_codes.append(code_str) @@ -336,7 +361,7 @@ async def handle_font_request(websocket, message_type, data): code_list = codes_str.split(",") print(f"Batch Font Request for {len(code_list)} chars") - success_count, failed = send_font_batch_with_retry(websocket, code_list) + success_count, failed = await send_font_batch_with_retry(websocket, code_list) print(f"Font batch: {success_count} success, {len(failed)} failed") # 发送完成标记 @@ -345,7 +370,7 @@ async def handle_font_request(websocket, message_type, data): # 如果有失败的,进行重试 if failed: await asyncio.sleep(0.5) - send_font_batch_with_retry(websocket, failed, retry_count=1) + await send_font_batch_with_retry(websocket, failed, retry_count=1) except Exception as e: print(f"Error handling batch font request: {e}") @@ -402,24 +427,35 @@ class MyRecognitionCallback(RecognitionCallback): self.websocket = websocket self.loop = loop self.final_text = "" # 保存最终识别结果 + self.sentence_list = [] # 累积所有句子 def on_open(self) -> None: print("ASR Session started") + self.sentence_list = [] + self.final_text = "" def on_close(self) -> None: print("ASR Session closed") + # 关闭时将所有句子合并为完整文本 + if self.sentence_list: + self.final_text = "".join(self.sentence_list) + print(f"Final combined ASR text: {self.final_text}") def on_event(self, result: RecognitionResult) -> None: if result.get_sentence(): text = result.get_sentence()['text'] print(f"ASR Result: {text}") - self.final_text = text # 保存识别结果 + # 累积每一句识别结果 + self.sentence_list.append(text) + # 同时更新 final_text 以便 Stop 时获取 + self.final_text = "".join(self.sentence_list) # 将识别结果发送回客户端 try: - asyncio.run_coroutine_threadsafe( - self.websocket.send_text(f"ASR:{text}"), - self.loop - ) + if self.loop.is_running(): + asyncio.run_coroutine_threadsafe( + self.websocket.send_text(f"ASR:{text}"), + self.loop + ) except Exception as e: print(f"Failed to send ASR result to client: {e}") @@ -470,13 +506,31 @@ def optimize_prompt(asr_text, progress_callback=None): ) if response.status_code == 200: - optimized = response.output.choices[0].message.content.strip() - print(f"Optimized prompt: {optimized}") - - if progress_callback: - progress_callback(30, f"提示词优化完成: {optimized[:50]}...") - - return optimized + if hasattr(response, 'output') and response.output and \ + hasattr(response.output, 'choices') and response.output.choices and \ + len(response.output.choices) > 0: + + optimized = response.output.choices[0].message.content.strip() + print(f"Optimized prompt: {optimized}") + + if progress_callback: + progress_callback(30, f"提示词优化完成: {optimized[:50]}...") + + return optimized + elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'): + # Handle case where API returns text directly instead of choices + optimized = response.output.text.strip() + print(f"Optimized prompt (direct text): {optimized}") + + if progress_callback: + progress_callback(30, f"提示词优化完成: {optimized[:50]}...") + + return optimized + else: + print(f"Prompt optimization response format error: {response}") + if progress_callback: + progress_callback(0, "提示词优化响应格式错误") + return asr_text else: print(f"Prompt optimization failed: {response.code} - {response.message}") if progress_callback: @@ -748,7 +802,8 @@ async def websocket_endpoint(websocket: WebSocket): # 先发送 ASR 文字到 ESP32 显示 await websocket.send_text(f"ASR:{asr_text}") - await start_async_image_generation(websocket, asr_text) + # 使用 create_task 异步执行,避免阻塞主循环处理字体请求 + asyncio.create_task(start_async_image_generation(websocket, asr_text)) else: print("No ASR text, skipping image generation") @@ -777,53 +832,6 @@ async def websocket_endpoint(websocket: WebSocket): except Exception as e: print(f"Font request error: {e}") await websocket.send_text("FONT_BATCH_END:0:0") - - # 计算偏移量 - # GB2312 编码范围:0xA1A1 - 0xFEFE - # 区码:高字节 - 0xA0 - # 位码:低字节 - 0xA0 - area = (code >> 8) - 0xA0 - index = (code & 0xFF) - 0xA0 - - if area >= 1 and index >= 1: - offset = ((area - 1) * 94 + (index - 1)) * 32 - - # 读取字体文件 - # 注意:这里为了简单,每次都打开文件。如果并发高,应该缓存文件句柄或内容。 - # 假设字体文件在当前目录或上级目录 - # Prioritize finding the file in the script's directory - script_dir = os.path.dirname(os.path.abspath(__file__)) - font_path = os.path.join(script_dir, FONT_FILE) - - # Fallback: check one level up - if not os.path.exists(font_path): - font_path = os.path.join(script_dir, "..", FONT_FILE) - - # Fallback: check current working directory - if not os.path.exists(font_path): - font_path = FONT_FILE - - if os.path.exists(font_path): - print(f"Reading font from: {font_path} (Offset: {offset})") - with open(font_path, "rb") as f: - f.seek(offset) - font_data = f.read(32) - - if len(font_data) == 32: - import binascii - hex_data = binascii.hexlify(font_data).decode('utf-8') - # Return the original requested code (unicode or hex) so client can map it back - response = f"FONT_DATA:{target_code_str}:{hex_data}" - # print(f"Sending Font Response: {response[:30]}...") - await websocket.send_text(response) - else: - print(f"Error: Read {len(font_data)} bytes for font data (expected 32)") - else: - print(f"Font file not found: {font_path}") - else: - print(f"Invalid GB2312 code derived: {code:X} (Area: {area}, Index: {index})") - except Exception as e: - print(f"Error handling FONT request: {e}") elif "bytes" in message: # 接收音频数据并追加到缓冲区