1
This commit is contained in:
9
main.py
9
main.py
@@ -6,6 +6,9 @@ import network
|
||||
import st7789py as st7789
|
||||
from config import CURRENT_CONFIG
|
||||
from audio import AudioPlayer, Microphone
|
||||
|
||||
# Define colors that might be missing in st7789py
|
||||
DARKGREY = 0x4208
|
||||
from display import Display
|
||||
from websocket_client import WebSocketClient
|
||||
import uselect
|
||||
@@ -89,7 +92,7 @@ def draw_mic_icon(display, x, y, active=True):
|
||||
if not display or not display.tft:
|
||||
return
|
||||
|
||||
color = st7789.GREEN if active else st7789.DARKGREY
|
||||
color = st7789.GREEN if active else DARKGREY
|
||||
|
||||
display.tft.fill_rect(x + 5, y, 10, 5, color)
|
||||
display.tft.fill_rect(x + 3, y + 5, 14, 10, color)
|
||||
@@ -137,7 +140,7 @@ def draw_progress_bar(display, x, y, width, height, progress, color=st7789.CYAN)
|
||||
if not display or not display.tft:
|
||||
return
|
||||
|
||||
display.tft.fill_rect(x, y, width, height, st7789.DARKGREY)
|
||||
display.tft.fill_rect(x, y, width, height, DARKGREY)
|
||||
if progress > 0:
|
||||
bar_width = int(width * min(progress, 1.0))
|
||||
display.tft.fill_rect(x, y, bar_width, height, color)
|
||||
@@ -176,7 +179,7 @@ def render_confirm_screen(display, asr_text=""):
|
||||
display.tft.fill_rect(0, 0, 240, 30, st7789.CYAN)
|
||||
display.text("说完了吗?", 75, 8, st7789.BLACK)
|
||||
|
||||
display.tft.fill_rect(10, 50, 220, 80, st7789.DARKGREY)
|
||||
display.tft.fill_rect(10, 50, 220, 80, DARKGREY)
|
||||
display.text(asr_text if asr_text else "未识别到文字", 20, 75, st7789.WHITE)
|
||||
|
||||
display.tft.fill_rect(20, 150, 80, 30, st7789.GREEN)
|
||||
|
||||
@@ -147,10 +147,13 @@ class WebSocketClient:
|
||||
|
||||
# Read payload
|
||||
data = bytearray(length)
|
||||
view = memoryview(data)
|
||||
|
||||
# Use smaller chunks for readinto to avoid memory allocation issues in MicroPython
|
||||
pos = 0
|
||||
while pos < length:
|
||||
read_len = self.sock.readinto(view[pos:])
|
||||
chunk_size = min(length - pos, 512)
|
||||
chunk_view = memoryview(data)[pos:pos + chunk_size]
|
||||
read_len = self.sock.readinto(chunk_view)
|
||||
if read_len == 0:
|
||||
return None
|
||||
pos += read_len
|
||||
|
||||
Binary file not shown.
BIN
websocket_server/generated_thumb.bin
Normal file
BIN
websocket_server/generated_thumb.bin
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -43,6 +43,49 @@ def calculate_md5(filepath):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def get_font_data(unicode_val):
|
||||
"""从字体文件获取单个字符数据(带缓存)"""
|
||||
if unicode_val in font_cache:
|
||||
return font_cache[unicode_val]
|
||||
|
||||
try:
|
||||
char = chr(unicode_val)
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
if font_data_buffer:
|
||||
if offset + 32 <= len(font_data_buffer):
|
||||
font_data = font_data_buffer[offset:offset+32]
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
else:
|
||||
# Fallback to file reading if buffer failed
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def init_font_cache():
|
||||
"""初始化字体缓存和MD5"""
|
||||
global font_cache, font_md5, font_data_buffer
|
||||
@@ -93,7 +136,7 @@ def get_output_path():
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png")
|
||||
|
||||
THUMB_SIZE = 245
|
||||
THUMB_SIZE = 240
|
||||
|
||||
# 字体请求队列(用于重试机制)
|
||||
font_request_queue = {}
|
||||
@@ -131,26 +174,49 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
|
||||
|
||||
await websocket.send_text(f"TASK_ID:{task_id}")
|
||||
|
||||
def progress_callback(progress: int, message: str):
|
||||
"""进度回调函数"""
|
||||
# 获取当前事件循环
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
# 使用队列在线程和主事件循环之间传递消息
|
||||
message_queue = asyncio.Queue()
|
||||
|
||||
async def progress_callback_async(progress: int, message: str):
|
||||
"""异步进度回调"""
|
||||
task.progress = progress
|
||||
task.message = message
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}"),
|
||||
asyncio.get_event_loop()
|
||||
)
|
||||
await websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}")
|
||||
except Exception as e:
|
||||
print(f"Error sending progress: {e}")
|
||||
|
||||
def progress_callback(progress: int, message: str):
|
||||
"""进度回调函数(在线程中调用)"""
|
||||
task.progress = progress
|
||||
task.message = message
|
||||
# 通过队列在主循环中发送消息
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
progress_callback_async(progress, message),
|
||||
loop
|
||||
)
|
||||
|
||||
try:
|
||||
task.status = "optimizing"
|
||||
|
||||
await websocket.send_text("STATUS:OPTIMIZING:正在优化提示词...")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# 同步调用优化函数
|
||||
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback)
|
||||
|
||||
# 确保返回有效的提示词
|
||||
if not optimized_prompt:
|
||||
optimized_prompt = asr_text
|
||||
print(f"Warning: optimize_prompt returned None, using original text: {asr_text}")
|
||||
|
||||
await websocket.send_text(f"PROMPT:{optimized_prompt}")
|
||||
task.optimized_prompt = optimized_prompt
|
||||
|
||||
@@ -158,6 +224,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
|
||||
await websocket.send_text("STATUS:RENDERING:正在生成图片,请稍候...")
|
||||
await asyncio.sleep(0.2)
|
||||
|
||||
# 同步调用图片生成函数
|
||||
image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)
|
||||
|
||||
task.result = image_path
|
||||
@@ -178,8 +245,11 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
|
||||
task.status = "failed"
|
||||
task.error = str(e)
|
||||
print(f"Image generation task error: {e}")
|
||||
await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
|
||||
await websocket.send_text("STATUS:ERROR:图片生成出错")
|
||||
try:
|
||||
await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
|
||||
await websocket.send_text("STATUS:ERROR:图片生成出错")
|
||||
except Exception as ws_e:
|
||||
print(f"Error sending error message: {ws_e}")
|
||||
finally:
|
||||
if task_id in active_tasks:
|
||||
del active_tasks[task_id]
|
||||
@@ -198,7 +268,7 @@ async def send_image_to_client(websocket: WebSocket, image_path: str):
|
||||
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
|
||||
|
||||
# Send binary data directly
|
||||
chunk_size = 4096 # Increased chunk size for binary
|
||||
chunk_size = 512 # Decreased chunk size for ESP32 memory stability
|
||||
for i in range(0, len(image_data), chunk_size):
|
||||
chunk = image_data[i:i+chunk_size]
|
||||
await websocket.send_bytes(chunk)
|
||||
@@ -208,49 +278,7 @@ async def send_image_to_client(websocket: WebSocket, image_path: str):
|
||||
print("Image sent to ESP32 (Binary)")
|
||||
|
||||
|
||||
def get_font_data(unicode_val):
|
||||
"""从字体文件获取单个字符数据(带缓存)"""
|
||||
if unicode_val in font_cache:
|
||||
return font_cache[unicode_val]
|
||||
|
||||
try:
|
||||
char = chr(unicode_val)
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
if font_data_buffer:
|
||||
if offset + 32 <= len(font_data_buffer):
|
||||
font_data = font_data_buffer[offset:offset+32]
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
else:
|
||||
# Fallback to file reading if buffer failed
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def send_font_batch_with_retry(websocket, code_list, retry_count=0):
|
||||
async def send_font_batch_with_retry(websocket, code_list, retry_count=0):
|
||||
"""批量发送字体数据(带重试机制)"""
|
||||
global font_request_queue
|
||||
|
||||
@@ -269,10 +297,7 @@ def send_font_batch_with_retry(websocket, code_list, retry_count=0):
|
||||
import binascii
|
||||
hex_data = binascii.hexlify(font_data).decode('utf-8')
|
||||
response = f"FONT_DATA:{code_str}:{hex_data}"
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
websocket.send_text(response),
|
||||
asyncio.get_event_loop()
|
||||
)
|
||||
await websocket.send_text(response)
|
||||
success_codes.add(unicode_val)
|
||||
else:
|
||||
failed_codes.append(code_str)
|
||||
@@ -336,7 +361,7 @@ async def handle_font_request(websocket, message_type, data):
|
||||
code_list = codes_str.split(",")
|
||||
print(f"Batch Font Request for {len(code_list)} chars")
|
||||
|
||||
success_count, failed = send_font_batch_with_retry(websocket, code_list)
|
||||
success_count, failed = await send_font_batch_with_retry(websocket, code_list)
|
||||
print(f"Font batch: {success_count} success, {len(failed)} failed")
|
||||
|
||||
# 发送完成标记
|
||||
@@ -345,7 +370,7 @@ async def handle_font_request(websocket, message_type, data):
|
||||
# 如果有失败的,进行重试
|
||||
if failed:
|
||||
await asyncio.sleep(0.5)
|
||||
send_font_batch_with_retry(websocket, failed, retry_count=1)
|
||||
await send_font_batch_with_retry(websocket, failed, retry_count=1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling batch font request: {e}")
|
||||
@@ -402,24 +427,35 @@ class MyRecognitionCallback(RecognitionCallback):
|
||||
self.websocket = websocket
|
||||
self.loop = loop
|
||||
self.final_text = "" # 保存最终识别结果
|
||||
self.sentence_list = [] # 累积所有句子
|
||||
|
||||
def on_open(self) -> None:
|
||||
print("ASR Session started")
|
||||
self.sentence_list = []
|
||||
self.final_text = ""
|
||||
|
||||
def on_close(self) -> None:
|
||||
print("ASR Session closed")
|
||||
# 关闭时将所有句子合并为完整文本
|
||||
if self.sentence_list:
|
||||
self.final_text = "".join(self.sentence_list)
|
||||
print(f"Final combined ASR text: {self.final_text}")
|
||||
|
||||
def on_event(self, result: RecognitionResult) -> None:
|
||||
if result.get_sentence():
|
||||
text = result.get_sentence()['text']
|
||||
print(f"ASR Result: {text}")
|
||||
self.final_text = text # 保存识别结果
|
||||
# 累积每一句识别结果
|
||||
self.sentence_list.append(text)
|
||||
# 同时更新 final_text 以便 Stop 时获取
|
||||
self.final_text = "".join(self.sentence_list)
|
||||
# 将识别结果发送回客户端
|
||||
try:
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.websocket.send_text(f"ASR:{text}"),
|
||||
self.loop
|
||||
)
|
||||
if self.loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self.websocket.send_text(f"ASR:{text}"),
|
||||
self.loop
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Failed to send ASR result to client: {e}")
|
||||
|
||||
@@ -470,13 +506,31 @@ def optimize_prompt(asr_text, progress_callback=None):
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
optimized = response.output.choices[0].message.content.strip()
|
||||
print(f"Optimized prompt: {optimized}")
|
||||
if hasattr(response, 'output') and response.output and \
|
||||
hasattr(response.output, 'choices') and response.output.choices and \
|
||||
len(response.output.choices) > 0:
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||
optimized = response.output.choices[0].message.content.strip()
|
||||
print(f"Optimized prompt: {optimized}")
|
||||
|
||||
return optimized
|
||||
if progress_callback:
|
||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||
|
||||
return optimized
|
||||
elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'):
|
||||
# Handle case where API returns text directly instead of choices
|
||||
optimized = response.output.text.strip()
|
||||
print(f"Optimized prompt (direct text): {optimized}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||
|
||||
return optimized
|
||||
else:
|
||||
print(f"Prompt optimization response format error: {response}")
|
||||
if progress_callback:
|
||||
progress_callback(0, "提示词优化响应格式错误")
|
||||
return asr_text
|
||||
else:
|
||||
print(f"Prompt optimization failed: {response.code} - {response.message}")
|
||||
if progress_callback:
|
||||
@@ -748,7 +802,8 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
# 先发送 ASR 文字到 ESP32 显示
|
||||
await websocket.send_text(f"ASR:{asr_text}")
|
||||
|
||||
await start_async_image_generation(websocket, asr_text)
|
||||
# 使用 create_task 异步执行,避免阻塞主循环处理字体请求
|
||||
asyncio.create_task(start_async_image_generation(websocket, asr_text))
|
||||
else:
|
||||
print("No ASR text, skipping image generation")
|
||||
|
||||
@@ -778,53 +833,6 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
print(f"Font request error: {e}")
|
||||
await websocket.send_text("FONT_BATCH_END:0:0")
|
||||
|
||||
# 计算偏移量
|
||||
# GB2312 编码范围:0xA1A1 - 0xFEFE
|
||||
# 区码:高字节 - 0xA0
|
||||
# 位码:低字节 - 0xA0
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
# 读取字体文件
|
||||
# 注意:这里为了简单,每次都打开文件。如果并发高,应该缓存文件句柄或内容。
|
||||
# 假设字体文件在当前目录或上级目录
|
||||
# Prioritize finding the file in the script's directory
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
|
||||
# Fallback: check one level up
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
|
||||
# Fallback: check current working directory
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
print(f"Reading font from: {font_path} (Offset: {offset})")
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
|
||||
if len(font_data) == 32:
|
||||
import binascii
|
||||
hex_data = binascii.hexlify(font_data).decode('utf-8')
|
||||
# Return the original requested code (unicode or hex) so client can map it back
|
||||
response = f"FONT_DATA:{target_code_str}:{hex_data}"
|
||||
# print(f"Sending Font Response: {response[:30]}...")
|
||||
await websocket.send_text(response)
|
||||
else:
|
||||
print(f"Error: Read {len(font_data)} bytes for font data (expected 32)")
|
||||
else:
|
||||
print(f"Font file not found: {font_path}")
|
||||
else:
|
||||
print(f"Invalid GB2312 code derived: {code:X} (Area: {area}, Index: {index})")
|
||||
except Exception as e:
|
||||
print(f"Error handling FONT request: {e}")
|
||||
|
||||
elif "bytes" in message:
|
||||
# 接收音频数据并追加到缓冲区
|
||||
data = message["bytes"]
|
||||
|
||||
Reference in New Issue
Block a user