This commit is contained in:
jeremygan2021
2026-03-03 21:59:57 +08:00
parent c87d5deedf
commit fc92a5feaf
7 changed files with 963 additions and 389 deletions

585
main.py
View File

@@ -17,15 +17,27 @@ SERVER_IP = "6.6.6.88"
SERVER_PORT = 8000
SERVER_URL = f"ws://{SERVER_IP}:{SERVER_PORT}/ws/audio"
# 图片接收状态
IMAGE_STATE_IDLE = 0
IMAGE_STATE_RECEIVING = 1
UI_SCREEN_RECORDING = 1
UI_SCREEN_CONFIRM = 2
UI_SCREEN_RESULT = 3
BOOT_SHORT_MS = 500
BOOT_LONG_MS = 2000
BOOT_EXTRA_LONG_MS = 5000
IMG_WIDTH = 120
IMG_HEIGHT = 120
_last_btn_state = None
_btn_release_time = 0
_btn_press_time = 0
def connect_wifi(max_retries=5):
"""连接WiFi网络"""
wlan = network.WLAN(network.STA_IF)
try:
@@ -72,38 +84,183 @@ def connect_wifi(max_retries=5):
return False
def print_asr(text, display=None):
print(f"ASR: {text}")
if display and display.tft:
display.fill_rect(0, 40, 240, 160, st7789.BLACK)
display.text(text, 0, 40, st7789.WHITE)
def draw_mic_icon(display, x, y, active=True):
"""绘制麦克风图标"""
if not display or not display.tft:
return
color = st7789.GREEN if active else st7789.DARKGREY
display.tft.fill_rect(x + 5, y, 10, 5, color)
display.tft.fill_rect(x + 3, y + 5, 14, 10, color)
display.tft.fill_rect(x + 8, y + 15, 4, 8, color)
display.tft.fill_rect(x + 6, y + 23, 8, 2, color)
display.tft.fill_rect(x + 8, y + 25, 4, 3, color)
def draw_loading_spinner(display, x, y, angle, color=st7789.WHITE):
"""绘制旋转加载图标"""
if not display or not display.tft:
return
import math
rad = math.radians(angle)
# Clear previous (simple erase)
# This is tricky without a buffer, so we just draw over.
# For better performance we should remember previous pos.
center_x = x + 10
center_y = y + 10
radius = 8
for i in range(8):
theta = math.radians(i * 45) + rad
px = int(center_x + radius * math.cos(theta))
py = int(center_y + radius * math.sin(theta))
# Brightness based on angle (simulated by color or size)
# Here we just draw dots
display.tft.pixel(px, py, color)
def draw_check_icon(display, x, y):
"""绘制勾选图标"""
if not display or not display.tft:
return
display.tft.line(x, y + 5, x + 3, y + 8, st7789.GREEN)
display.tft.line(x + 3, y + 8, x + 10, y, st7789.GREEN)
def draw_progress_bar(display, x, y, width, height, progress, color=st7789.CYAN):
"""绘制进度条"""
if not display or not display.tft:
return
display.tft.fill_rect(x, y, width, height, st7789.DARKGREY)
if progress > 0:
bar_width = int(width * min(progress, 1.0))
display.tft.fill_rect(x, y, bar_width, height, color)
def render_recording_screen(display, asr_text="", audio_level=0):
"""渲染录音界面"""
if not display or not display.tft:
return
display.tft.fill(st7789.BLACK)
display.tft.fill_rect(0, 0, 240, 30, st7789.WHITE)
display.text("语音识别", 80, 8, st7789.BLACK)
draw_mic_icon(display, 105, 50, True)
if audio_level > 0:
bar_width = min(int(audio_level * 2), 200)
display.tft.fill_rect(20, 100, bar_width, 10, st7789.GREEN)
if asr_text:
display.text(asr_text[:20], 20, 130, st7789.WHITE)
display.tft.fill_rect(60, 200, 120, 25, st7789.RED)
display.text("松开停止", 85, 205, st7789.WHITE)
def render_confirm_screen(display, asr_text=""):
"""渲染确认界面"""
if not display or not display.tft:
return
display.tft.fill(st7789.BLACK)
display.tft.fill_rect(0, 0, 240, 30, st7789.CYAN)
display.text("说完了吗?", 75, 8, st7789.BLACK)
display.tft.fill_rect(10, 50, 220, 80, st7789.DARKGREY)
display.text(asr_text if asr_text else "未识别到文字", 20, 75, st7789.WHITE)
display.tft.fill_rect(20, 150, 80, 30, st7789.GREEN)
display.text("短按确认", 30, 158, st7789.BLACK)
display.tft.fill_rect(140, 150, 80, 30, st7789.RED)
display.text("长按重录", 155, 158, st7789.WHITE)
def render_result_screen(display, status="", prompt="", image_received=False):
"""渲染结果界面"""
if not display or not display.tft:
return
# Only clear if we are starting a new state or it's the first render
# But for simplicity we clear all for now. Optimizing this requires state tracking.
display.tft.fill(st7789.BLACK)
# Header
display.tft.fill_rect(0, 0, 240, 30, st7789.WHITE)
display.text("AI 生成中", 80, 8, st7789.BLACK)
if status == "OPTIMIZING":
display.text("正在思考...", 80, 60, st7789.CYAN)
display.text("优化提示词中", 70, 80, st7789.CYAN)
draw_progress_bar(display, 40, 110, 160, 6, 0.3, st7789.CYAN)
# Spinner will be drawn by main loop
elif status == "RENDERING":
display.text("正在绘画...", 80, 60, st7789.YELLOW)
display.text("AI作画中", 85, 80, st7789.YELLOW)
draw_progress_bar(display, 40, 110, 160, 6, 0.7, st7789.YELLOW)
# Spinner will be drawn by main loop
elif status == "COMPLETE" or image_received:
display.text("生成完成!", 80, 50, st7789.GREEN)
draw_check_icon(display, 110, 80)
elif status == "ERROR":
display.text("生成失败", 80, 50, st7789.RED)
if prompt:
display.tft.fill_rect(10, 140, 220, 50, 0x2124) # Dark Grey
display.text("提示词:", 15, 145, st7789.CYAN)
display.text(prompt[:25] + "..." if len(prompt) > 25 else prompt, 15, 165, st7789.WHITE)
display.tft.fill_rect(60, 210, 120, 25, st7789.BLUE)
display.text("返回录音", 90, 215, st7789.WHITE)
def process_message(msg, display, image_state, image_data_list):
"""处理WebSocket消息返回新的image_state"""
"""处理WebSocket消息"""
# Handle binary image data
if isinstance(msg, (bytes, bytearray)):
if image_state == IMAGE_STATE_RECEIVING:
image_data_list.append(msg)
# Optional: Update progress bar or indicator
return image_state, None
return image_state, None
if not isinstance(msg, str):
return image_state
return image_state, None
status_info = None
# 处理ASR消息
if msg.startswith("ASR:"):
print_asr(msg[4:], display)
return image_state, ("asr", msg[4:])
elif msg.startswith("STATUS:"):
parts = msg[7:].split(":", 1)
status_type = parts[0]
status_text = parts[1] if len(parts) > 1 else ""
print(f"Status: {status_type} - {status_text}")
return image_state, ("status", status_type, status_text)
# 处理图片生成状态消息
elif msg.startswith("GENERATING_IMAGE:"):
print(msg)
if display and display.tft:
display.fill_rect(0, 40, 240, 100, st7789.BLACK)
display.text("正在生成图片...", 0, 40, st7789.YELLOW)
# Deprecated by STATUS:RENDERING but kept for compatibility
return image_state, None
# 处理提示词优化消息
elif msg.startswith("PROMPT:"):
prompt = msg[7:]
print(f"Optimized prompt: {prompt}")
if display and display.tft:
display.fill_rect(0, 60, 240, 40, st7789.BLACK)
display.text("提示词: " + prompt[:20], 0, 60, st7789.CYAN)
return image_state, ("prompt", prompt)
# 处理图片开始消息
elif msg.startswith("IMAGE_START:"):
try:
parts = msg.split(":")
@@ -111,64 +268,120 @@ def process_message(msg, display, image_state, image_data_list):
img_size = int(parts[2]) if len(parts) > 2 else 64
print(f"Image start, size: {size}, img_size: {img_size}")
image_data_list.clear()
image_data_list.append(img_size) # 保存图片尺寸
return IMAGE_STATE_RECEIVING
image_data_list.append(img_size) # Store metadata at index 0
return IMAGE_STATE_RECEIVING, None
except Exception as e:
print(f"IMAGE_START parse error: {e}")
return image_state
return image_state, None
# 处理图片数据消息
# Deprecated text-based IMAGE_DATA handling
elif msg.startswith("IMAGE_DATA:") and image_state == IMAGE_STATE_RECEIVING:
try:
data = msg.split(":", 1)[1]
image_data_list.append(data)
# Convert hex to bytes immediately if using old protocol, but we switched to binary
# Keep this just in case server rolls back? No, let's assume binary.
pass
except:
pass
# 处理图片结束消息
elif msg == "IMAGE_END" and image_state == IMAGE_STATE_RECEIVING:
try:
print("Image received, processing...")
# 获取图片尺寸
img_size = image_data_list[0] if image_data_list else 64
hex_data = "".join(image_data_list[1:])
# Combine all binary chunks (skipping metadata at index 0)
img_data = b"".join(image_data_list[1:])
image_data_list.clear()
# 将hex字符串转换为字节数据
img_data = bytes.fromhex(hex_data)
print(f"Image data len: {len(img_data)}")
# 在屏幕中心显示图片
if display and display.tft:
# 计算居中位置
x = (240 - img_size) // 2
y = (240 - img_size) // 2
# 显示图片
display.show_image(x, y, img_size, img_size, img_data)
display.fill_rect(0, 0, 240, 30, st7789.WHITE)
display.text("图片已生成!", 0, 5, st7789.BLACK)
# Overlay success message slightly
display.tft.fill_rect(0, 0, 240, 30, st7789.WHITE)
display.text("图片已生成!", 70, 5, st7789.BLACK)
gc.collect()
print("Image displayed")
return IMAGE_STATE_IDLE, ("image_done",)
except Exception as e:
print(f"Image process error: {e}")
import sys
sys.print_exception(e)
return IMAGE_STATE_IDLE
return IMAGE_STATE_IDLE, None
# 处理图片错误消息
elif msg.startswith("IMAGE_ERROR:"):
print(msg)
if display and display.tft:
display.fill_rect(0, 40, 240, 100, st7789.BLACK)
display.text("图片生成失败", 0, 40, st7789.RED)
return IMAGE_STATE_IDLE
return IMAGE_STATE_IDLE, ("error", msg[12:])
return image_state
return image_state, None
def print_asr(text, display=None):
"""打印ASR结果"""
print(f"ASR: {text}")
if display and display.tft:
display.fill_rect(0, 40, 240, 160, st7789.BLACK)
display.text(text, 0, 40, st7789.WHITE)
def get_boot_button_action(boot_btn):
"""获取Boot按键动作类型
返回:
0: 无动作
1: 短按 (<500ms)
2: 长按 (2-5秒)
3: 超长按 (>5秒)
"""
global _last_btn_state, _btn_release_time, _btn_press_time
current_value = boot_btn.value()
current_time = time.ticks_ms()
if current_value == 0:
if _last_btn_state != 0:
_last_btn_state = 0
_btn_press_time = current_time
return 0
if current_value == 1 and _last_btn_state == 0:
press_duration = time.ticks_diff(current_time, _btn_press_time)
_last_btn_state = 1
if press_duration < BOOT_SHORT_MS:
return 0
elif press_duration < BOOT_LONG_MS:
return 1
elif press_duration < BOOT_EXTRA_LONG_MS:
return 2
else:
return 3
if _last_btn_state is None:
_last_btn_state = current_value
_btn_release_time = current_time
return 0
def check_memory(silent=False):
"""检查内存使用情况
Args:
silent: 是否静默模式(不打印日志)
"""
free = gc.mem_free()
total = gc.mem_alloc() + free
usage = (gc.mem_alloc() / total) * 100 if total > 0 else 0
if not silent:
print(f"Memory: {free} free, {usage:.1f}% used")
return usage
def main():
@@ -191,12 +404,18 @@ def main():
if display.tft:
display.init_ui()
ui_screen = UI_SCREEN_RECORDING
is_recording = False
ws = None
image_state = IMAGE_STATE_IDLE
image_data_list = []
current_asr_text = ""
current_prompt = ""
current_status = ""
image_generation_done = False
confirm_waiting = False
def connect_ws():
def connect_ws(force=False):
nonlocal ws
try:
if ws:
@@ -205,16 +424,24 @@ def main():
pass
ws = None
try:
print(f"Connecting to {SERVER_URL}")
ws = WebSocketClient(SERVER_URL)
print("WebSocket connected!")
if display:
display.set_ws(ws)
return True
except Exception as e:
print(f"WS connection failed: {e}")
return False
retry_count = 0
max_retries = 3
while retry_count < max_retries:
try:
print(f"Connecting to {SERVER_URL} (attempt {retry_count + 1})")
ws = WebSocketClient(SERVER_URL)
print("WebSocket connected!")
if display:
display.set_ws(ws)
return True
except Exception as e:
print(f"WS connection failed: {e}")
retry_count += 1
time.sleep(1)
return False
if connect_wifi():
connect_ws()
@@ -222,27 +449,162 @@ def main():
print("Running in offline mode")
read_buf = bytearray(4096)
last_audio_level = 0
memory_check_counter = 0
spinner_angle = 0
last_spinner_time = 0
while True:
try:
btn_val = boot_btn.value()
memory_check_counter += 1
if btn_val == 0:
if not is_recording:
print(">>> Recording...")
is_recording = True
if memory_check_counter >= 300:
memory_check_counter = 0
if check_memory(silent=True) > 80:
gc.collect()
print("Memory high, cleaned")
# Spinner Animation
if ui_screen == UI_SCREEN_RESULT and not image_generation_done and current_status in ["OPTIMIZING", "RENDERING"]:
now = time.ticks_ms()
if time.ticks_diff(now, last_spinner_time) > 100:
if display.tft:
display.fill(st7789.WHITE)
# Clear previous spinner (draw in BLACK)
draw_loading_spinner(display, 110, 80, spinner_angle, st7789.BLACK)
spinner_angle = (spinner_angle + 45) % 360
# Draw new spinner
color = st7789.CYAN if current_status == "OPTIMIZING" else st7789.YELLOW
draw_loading_spinner(display, 110, 80, spinner_angle, color)
if ws is None or not ws.is_connected():
connect_ws()
last_spinner_time = now
btn_action = get_boot_button_action(boot_btn)
if btn_action == 1:
if is_recording:
print(">>> Stop recording")
if ws and ws.is_connected():
try:
ws.send("STOP_RECORDING")
except:
ws = None
is_recording = False
ui_screen = UI_SCREEN_RESULT
image_generation_done = False
if display.tft:
render_result_screen(display, "OPTIMIZING", current_asr_text, False)
time.sleep(0.5)
elif ui_screen == UI_SCREEN_RECORDING:
if not is_recording:
print(">>> Recording...")
is_recording = True
confirm_waiting = False
current_asr_text = ""
current_prompt = ""
current_status = ""
image_generation_done = False
if display.tft:
render_recording_screen(display, "", 0)
if ws is None or not ws.is_connected():
connect_ws()
if ws and ws.is_connected():
try:
ws.send("START_RECORDING")
except:
ws = None
elif ui_screen == UI_SCREEN_CONFIRM:
print(">>> Confirm and generate")
if ws and ws.is_connected():
try:
ws.send("STOP_RECORDING")
except:
ws = None
is_recording = False
ui_screen = UI_SCREEN_RESULT
image_generation_done = False
if display.tft:
render_result_screen(display, "OPTIMIZING", current_asr_text, False)
time.sleep(0.5)
elif ui_screen == UI_SCREEN_RESULT:
print(">>> Back to recording")
ui_screen = UI_SCREEN_RECORDING
is_recording = False
current_asr_text = ""
current_prompt = ""
current_status = ""
image_generation_done = False
confirm_waiting = False
if display.tft:
render_recording_screen(display, "", 0)
elif btn_action == 2:
if is_recording:
print(">>> Stop recording (long press)")
if ws and ws.is_connected():
try:
ws.send("STOP_RECORDING")
except:
ws = None
is_recording = False
if ui_screen == UI_SCREEN_RECORDING or is_recording == False:
if current_asr_text:
print(">>> Generate image with ASR text")
ui_screen = UI_SCREEN_RESULT
image_generation_done = False
if display.tft:
render_result_screen(display, "OPTIMIZING", current_asr_text, False)
time.sleep(0.5)
else:
print(">>> Re-record")
current_asr_text = ""
confirm_waiting = False
ui_screen = UI_SCREEN_RECORDING
if display.tft:
render_recording_screen(display, "", 0)
elif ui_screen == UI_SCREEN_CONFIRM:
print(">>> Re-record")
current_asr_text = ""
confirm_waiting = False
ui_screen = UI_SCREEN_RECORDING
if display.tft:
render_recording_screen(display, "", 0)
elif ui_screen == UI_SCREEN_RESULT:
print(">>> Generate image (manual)")
if ws and ws.is_connected():
try:
ws.send("START_RECORDING")
is_recording = True
ui_screen = UI_SCREEN_RECORDING
except:
ws = None
elif btn_action == 3:
print(">>> Config mode")
if is_recording and btn_action == 0:
if mic.i2s:
num_read = mic.readinto(read_buf)
if num_read > 0:
@@ -255,48 +617,73 @@ def main():
events = poller.poll(0)
if events:
msg = ws.recv()
image_state = process_message(msg, display, image_state, image_data_list)
image_state, event_data = process_message(msg, display, image_state, image_data_list)
if event_data:
if event_data[0] == "asr":
current_asr_text = event_data[1]
if display.tft:
render_recording_screen(display, current_asr_text, last_audio_level)
elif event_data[0] == "status":
current_status = event_data[1]
status_text = event_data[2] if len(event_data) > 2 else ""
if display.tft:
render_result_screen(display, current_status, current_prompt, image_generation_done)
elif event_data[0] == "prompt":
current_prompt = event_data[1]
elif event_data[0] == "image_done":
image_generation_done = True
if display.tft:
render_result_screen(display, "COMPLETE", current_prompt, True)
elif event_data[0] == "error":
if display.tft:
render_result_screen(display, "ERROR", current_prompt, False)
except:
ws = None
if ui_screen == UI_SCREEN_RESULT and ws and ws.is_connected():
try:
poller = uselect.poll()
poller.register(ws.sock, uselect.POLLIN)
events = poller.poll(100)
if events:
msg = ws.recv()
if msg:
image_state, event_data = process_message(msg, display, image_state, image_data_list)
if event_data:
if event_data[0] == "asr":
current_asr_text = event_data[1]
elif event_data[0] == "status":
current_status = event_data[1]
status_text = event_data[2] if len(event_data) > 2 else ""
if display.tft:
render_result_screen(display, current_status, current_prompt, image_generation_done)
elif event_data[0] == "prompt":
current_prompt = event_data[1]
if display.tft:
render_result_screen(display, current_status, current_prompt, image_generation_done)
elif event_data[0] == "image_done":
image_generation_done = True
if display.tft:
render_result_screen(display, "COMPLETE", current_prompt, True)
elif event_data[0] == "error":
if display.tft:
render_result_screen(display, "ERROR", current_prompt, False)
except:
pass
continue
elif is_recording:
print(">>> Stop")
is_recording = False
if display.tft:
display.init_ui()
if ws:
try:
ws.send("STOP_RECORDING")
# 等待更长时间以接收图片生成结果
t_wait = time.ticks_add(time.ticks_ms(), 30000) # 等待30秒
prev_image_state = image_state
while time.ticks_diff(t_wait, time.ticks_ms()) > 0:
poller = uselect.poll()
poller.register(ws.sock, uselect.POLLIN)
events = poller.poll(500)
if events:
msg = ws.recv()
prev_image_state = image_state
image_state = process_message(msg, display, image_state, image_data_list)
# 如果之前在接收图片,现在停止了,说明图片接收完成
if prev_image_state == IMAGE_STATE_RECEIVING and image_state == IMAGE_STATE_IDLE:
break
else:
# 检查是否还在接收图片
if image_state == IMAGE_STATE_IDLE:
break
except Exception as e:
print(f"Stop recording error: {e}")
ws = None
gc.collect()
time.sleep(0.01)
except Exception as e: