603 lines
22 KiB
Python
603 lines
22 KiB
Python
import machine
|
||
import time
|
||
import struct
|
||
import gc
|
||
import network
|
||
import st7789py as st7789
|
||
from config import CURRENT_CONFIG, SERVER_URL, ttl_tx, ttl_rx
|
||
from audio import AudioPlayer, Microphone
|
||
import convert_img
|
||
|
||
# Define colors that might be missing in st7789py
|
||
DARKGREY = 0x4208
|
||
from display import Display
|
||
from websocket_client import WebSocketClient
|
||
import uselect
|
||
import ujson
|
||
|
||
# WiFi credentials used by connect_wifi().
# NOTE(review): these are hard-coded in source; consider moving them to the
# config module so they are not committed with the code.
WIFI_SSID = "Tangledup-AI"
WIFI_PASS = "djt12345678"

# Transfer state threaded through process_message(): which kind of binary
# payload (if any) is currently expected on the WebSocket.
IMAGE_STATE_IDLE, IMAGE_STATE_RECEIVING, PRINTER_STATE_RECEIVING = 0, 1, 2

# Identifiers of the UI screens the main loop switches between.
UI_SCREEN_HOME, UI_SCREEN_RECORDING, UI_SCREEN_CONFIRM, UI_SCREEN_RESULT = 0, 1, 2, 3

# BOOT button hold-duration thresholds in milliseconds, used by
# get_boot_button_action() to classify a press on release.
BOOT_SHORT_MS = 100
BOOT_LONG_MS = 2000
BOOT_EXTRA_LONG_MS = 5000

# Image dimensions in pixels (not referenced by the code visible in this file).
IMG_WIDTH = IMG_HEIGHT = 120

# Button tracking state owned by get_boot_button_action().
_last_btn_state = None
_btn_release_time = _btn_press_time = 0
|
||
|
||
|
||
def connect_wifi(display=None, max_retries=5):
    """Connect the station interface to the configured WiFi network.

    The WLAN interface is switched off and on again, then up to
    ``max_retries`` connection attempts are made, each waiting at most
    30 seconds for the link to come up. Progress is printed and, when a
    display with a truthy ``tft`` attribute is supplied, rendered on screen.

    Args:
        display: Optional display object. It is only drawn to when its
            ``tft`` attribute is truthy.
        max_retries: Maximum number of connection attempts.

    Returns:
        True as soon as the interface reports that it is connected; False if
        the interface could not be initialised or every attempt failed.
    """
    wlan = network.WLAN(network.STA_IF)

    def screen_ready():
        # Re-evaluated on every use, exactly like the inline checks it replaces.
        return display and display.tft

    # Restart the interface so each run starts from a known state.
    try:
        wlan.active(False)
        time.sleep(2)
        wlan.active(True)
        time.sleep(3)
    except Exception as e:
        print(f"WiFi init error: {e}")
        if screen_ready():
            display.render_wifi_status(False)
        return False

    if screen_ready():
        display.render_wifi_connecting()

    for attempt in range(max_retries):
        try:
            if wlan.isconnected():
                print('WiFi connected')
                if screen_ready():
                    display.render_wifi_status(True)
                    time.sleep(1.5)
                return True

            print(f'Connecting to WiFi {WIFI_SSID}...')
            wlan.connect(WIFI_SSID, WIFI_PASS)

            start_time = time.ticks_ms()
            last_spin_time = start_time
            spinner_angle = 0
            while not wlan.isconnected():
                if time.ticks_diff(time.ticks_ms(), start_time) > 30000:
                    print("WiFi timeout!")
                    break
                time.sleep(0.1)
                print(".", end="")

                # Simple loading animation, redrawn at most every 200 ms.
                # An elapsed-time check is used instead of sampling
                # ``ticks_ms() % 200``: with a ~100 ms loop period the modulo
                # test depends on the loop's phase and can skip the redraw
                # for long stretches.
                if screen_ready():
                    now = time.ticks_ms()
                    if time.ticks_diff(now, last_spin_time) >= 200:
                        display.draw_loading_spinner(120, 150, spinner_angle, st7789.CYAN)
                        spinner_angle = (spinner_angle + 45) % 360
                        last_spin_time = now

            if wlan.isconnected():
                print('\nWiFi connected!')
                if screen_ready():
                    display.render_wifi_status(True)
                    time.sleep(1.5)
                return True

            # Not the last attempt: drop the pending connection and pause
            # before trying again.
            if attempt < max_retries - 1:
                print(f"\nRetry {attempt + 1}/{max_retries}...")
                wlan.disconnect()
                time.sleep(3)
                if screen_ready():
                    display.text(f"重试 {attempt + 1}/{max_retries}...", 80, 180, st7789.YELLOW, wait=False)

        except Exception as e:
            print(f"WiFi error: {e}")
            if attempt < max_retries - 1:
                time.sleep(5)

    print("WiFi connection failed!")
    if screen_ready():
        display.render_wifi_status(False)
        time.sleep(3)
    return False
|
||
|
||
|
||
|
||
|
||
|
||
def process_message(msg, display, image_state, image_data_list, printer_uart=None):
    """Handle one WebSocket frame and return the new state plus a UI event.

    Binary frames are payload for the transfer selected by ``image_state``:
    image pixels are streamed to the display, printer data is forwarded to
    ``printer_uart`` in 128-byte chunks. Text frames are protocol commands
    (``ASR:``, ``STATUS:``, ``PROMPT:``, ``IMAGE_START:``, ``IMAGE_END``,
    ``IMAGE_ERROR:``, ``PRINTER_DATA_START:``, ``PRINTER_DATA_END``) or font
    data consumed by ``display.font.handle_message``.

    Args:
        msg: Received frame; ``str`` for text, ``bytes``/``bytearray`` for
            binary. Any other type is ignored.
        display: Display object or None. Drawn to only when ``display.tft``
            is truthy.
        image_state: Current transfer state (one of the ``*_STATE_*``
            constants).
        image_data_list: Mutable list holding ``[width, height, offset]`` of
            the image being received. Updated in place.
        printer_uart: Optional UART-like object with ``write()``.

    Returns:
        Tuple ``(new_state, event)`` where ``event`` is None or a tuple whose
        first item names the event: ``"font_update"``, ``"asr"``,
        ``"printer_start"``, ``"printer_done"``, ``"status"``, ``"prompt"``,
        ``"image_done"`` or ``"error"``.
    """
    # ---- Binary frames: payload for the active transfer ----
    if isinstance(msg, (bytes, bytearray)):
        if image_state == IMAGE_STATE_RECEIVING:
            if len(image_data_list) < 3:
                # Header missing: abnormal, so reset the transfer.
                return IMAGE_STATE_IDLE, None
            try:
                width = image_data_list[0]
                height = image_data_list[1]
                current_offset = image_data_list[2]

                # Stream straight to the display, centred on a 240x240 area.
                if display and display.tft:
                    x = (240 - width) // 2
                    y = (240 - height) // 2
                    display.show_image_chunk(x, y, width, height, msg, current_offset)

                # Advance the write offset for the next chunk.
                image_data_list[2] += len(msg)
            except Exception as e:
                print(f"Stream image error: {e}")
            return image_state, None

        if image_state == PRINTER_STATE_RECEIVING and printer_uart:
            chunk_size = 128
            # memoryview slices do not copy, which avoids allocating a new
            # bytes object for every chunk.
            view = memoryview(msg)
            for start in range(0, len(msg), chunk_size):
                printer_uart.write(view[start:start + chunk_size])
                # Short pause between chunks, as in the original pacing.
                time.sleep_ms(5)

        return image_state, None

    if not isinstance(msg, str):
        return image_state, None

    # ---- Text frames ----
    # Font data is offered to the display's font handler first.
    if display and hasattr(display, 'font') and display.font.handle_message(msg):
        return image_state, ("font_update",)

    if msg.startswith("ASR:"):
        text = msg[4:]
        print_asr(text, display)
        return image_state, ("asr", text)

    if msg.startswith("PRINTER_DATA_START:"):
        print("Start receiving printer data...")
        return PRINTER_STATE_RECEIVING, ("printer_start",)

    if msg == "PRINTER_DATA_END":
        print("Printer data received completely")
        # Terminate the print job with a trailing CR/LF.
        if printer_uart:
            printer_uart.write(b'\r\n')
        return IMAGE_STATE_IDLE, ("printer_done",)

    if msg.startswith("STATUS:"):
        # Format: STATUS:<type>[:<text>]; the text may itself contain ':'.
        parts = msg[7:].split(":", 1)
        status_type = parts[0]
        status_text = parts[1] if len(parts) > 1 else ""
        print(f"Status: {status_type} - {status_text}")
        return image_state, ("status", status_type, status_text)

    if msg.startswith("PROMPT:"):
        prompt = msg[7:]
        print(f"Optimized prompt: {prompt}")
        return image_state, ("prompt", prompt)

    if msg.startswith("IMAGE_START:"):
        # Format: IMAGE_START:<size>[:<width>[:<height>]]
        try:
            parts = msg.split(":")
            size = int(parts[1])

            width = height = 64
            if len(parts) >= 4:
                width = int(parts[2])
                height = int(parts[3])
            elif len(parts) == 3:
                width = height = int(parts[2])  # assume square

            # A non-positive dimension cannot describe an image; treat it
            # like any other malformed header instead of entering the
            # receiving state with unusable geometry.
            if width <= 0 or height <= 0:
                raise ValueError(f"invalid dimensions {width}x{height}")

            print(f"Image start, size: {size}, dim: {width}x{height}")
            image_data_list.clear()
            image_data_list.append(width)   # index 0
            image_data_list.append(height)  # index 1
            image_data_list.append(0)       # index 2: offset
            return IMAGE_STATE_RECEIVING, None
        except Exception as e:
            print(f"IMAGE_START parse error: {e}")
            return image_state, None

    if msg == "IMAGE_END" and image_state == IMAGE_STATE_RECEIVING:
        print("Image received completely")
        image_data_list.clear()
        gc.collect()
        return IMAGE_STATE_IDLE, ("image_done",)

    if msg.startswith("IMAGE_ERROR:"):
        print(msg)
        return IMAGE_STATE_IDLE, ("error", msg[12:])

    # Everything else is ignored, including the deprecated text frames
    # "GENERATING_IMAGE:" (superseded by STATUS:RENDERING) and "IMAGE_DATA:".
    return image_state, None
|
||
|
||
|
||
def print_asr(text, display=None):
    """Print a speech-recognition result to the console.

    Args:
        text: Recognised text to print after the ``ASR:`` prefix.
        display: Accepted for call compatibility; not used here.
    """
    print("ASR:", text)
|
||
|
||
|
||
def get_boot_button_action(boot_btn):
    """Classify a completed press of the BOOT button.

    Intended to be polled repeatedly. A pin reading of 0 is treated as
    "pressed": the press start time is latched on the first such reading.
    On the first reading of 1 after that, the hold duration is measured and
    classified against the BOOT_*_MS thresholds. State is kept in the
    module-level ``_last_btn_state`` / ``_btn_press_time`` globals.

    Args:
        boot_btn: Pin-like object whose ``value()`` returns 0 or 1.

    Returns:
        0: no action (idle, still held, or released before BOOT_SHORT_MS)
        1: short press (BOOT_SHORT_MS <= hold < BOOT_LONG_MS)
        2: long press (BOOT_LONG_MS <= hold < BOOT_EXTRA_LONG_MS)
        3: extra-long press (hold >= BOOT_EXTRA_LONG_MS)
    """
    global _last_btn_state, _btn_release_time, _btn_press_time

    level = boot_btn.value()
    now = time.ticks_ms()

    # Button down: remember when the press began, report nothing yet.
    if level == 0:
        if _last_btn_state != 0:
            _last_btn_state = 0
            _btn_press_time = now
        return 0

    # Button just released: classify by how long it was held.
    if level == 1 and _last_btn_state == 0:
        held_ms = time.ticks_diff(now, _btn_press_time)
        _last_btn_state = 1
        # The index of the first threshold not reached is the action code
        # (index 0 means the press was too short to count).
        thresholds = (BOOT_SHORT_MS, BOOT_LONG_MS, BOOT_EXTRA_LONG_MS)
        for action, limit in enumerate(thresholds):
            if held_ms < limit:
                return action
        return 3

    # First call with the button up: record the initial state.
    if _last_btn_state is None:
        _last_btn_state = level
        _btn_release_time = now

    return 0
|
||
|
||
|
||
def check_memory(silent=False):
    """Report heap usage as a percentage.

    Args:
        silent: When True, nothing is printed.

    Returns:
        Percentage of the heap currently allocated
        (``mem_alloc / (mem_alloc + mem_free) * 100``), or 0 when the
        reported total is not positive.
    """
    free = gc.mem_free()
    # Sample the allocated size once so the numerator and the total come
    # from the same reading.
    used = gc.mem_alloc()
    total = used + free
    usage = (used / total) * 100 if total > 0 else 0
    if not silent:
        print(f"Memory: {free} free, {usage:.1f}% used")
    return usage
|
||
|
||
|
||
def main():
    """Initialise the hardware, connect to the server and run the UI loop.

    Startup: sets up the BOOT button, the optional backlight pin, the audio
    objects, the display and the printer UART, then tries WiFi and the
    WebSocket server and switches to the recording screen (also when the
    device stays offline).

    Main loop (never returns):
      * RECORDING screen: holding BOOT streams microphone data to the
        server; releasing it moves to the CONFIRM screen.
      * CONFIRM screen: a short press requests image generation, a long
        press returns to recording.
      * RESULT screen: a short press returns to recording, a long press
        requests a print.
      * While not recording, the WebSocket is polled and each frame is
        passed to process_message(); the returned events update the UI.
    """
    print("\n=== ESP32 Audio ASR ===\n")

    boot_btn = machine.Pin(0, machine.Pin.IN, machine.Pin.PULL_UP)

    # Turn the backlight on when the board config declares a 'bl' pin.
    bl_pin = CURRENT_CONFIG.pins.get('bl')
    if bl_pin:
        try:
            bl = machine.Pin(bl_pin, machine.Pin.OUT)
            bl.on()
        except Exception as e:
            # Best effort: the device can still run without the backlight.
            print(f"Backlight init failed: {e}")

    # NOTE(review): `speaker` is not used below; it is kept because
    # constructing AudioPlayer may configure audio hardware - confirm.
    speaker = AudioPlayer()
    mic = Microphone()
    display = Display()

    # Printer UART.
    printer_uart = machine.UART(1, baudrate=115200, tx=ttl_tx, rx=ttl_rx)

    if display.tft:
        display.init_ui()
        display.render_home_screen()
        time.sleep(2)

    ui_screen = UI_SCREEN_HOME
    is_recording = False
    ws = None
    image_state = IMAGE_STATE_IDLE
    image_data_list = []
    current_asr_text = ""
    current_prompt = ""
    current_status = ""
    image_generation_done = False
    confirm_waiting = False
    recording_stop_time = 0

    def drop_ws():
        """Close the current WebSocket (best effort) and forget it."""
        nonlocal ws
        try:
            if ws:
                ws.close()
        except Exception:
            # The connection is being discarded anyway.
            pass
        ws = None

    def connect_ws(force=False):
        """(Re)connect to SERVER_URL, trying up to three times.

        Any existing connection is closed first. ``force`` is accepted but
        currently unused. Returns True on success, False otherwise.
        """
        nonlocal ws
        drop_ws()

        retry_count = 0
        max_retries = 3

        while retry_count < max_retries:
            try:
                print(f"Connecting to {SERVER_URL} (attempt {retry_count + 1})")
                if display and display.tft:
                    display.tft.fill_rect(0, 220, 240, 20, st7789.BLACK)
                    display.text(f"连接服务器...({retry_count+1})", 60, 220, st7789.CYAN, wait=False)

                ws = WebSocketClient(SERVER_URL)
                print("WebSocket connected!")
                if display:
                    # Hand the connection to the display as well.
                    display.set_ws(ws)

                return True
            except Exception as e:
                print(f"WS connection failed: {e}")
                retry_count += 1
                time.sleep(1)

        if display and display.tft:
            display.text("服务器连接失败", 60, 220, st7789.RED, wait=False)
            time.sleep(2)
        return False

    def send_ws(payload, opcode=None):
        """Send ``payload`` if connected; return True when it was sent.

        On a send error the connection is closed and dropped (previously it
        was only set to None, leaving the socket open), so the next
        recording start reconnects.
        """
        if not (ws and ws.is_connected()):
            return False
        try:
            if opcode is None:
                ws.send(payload)
            else:
                ws.send(payload, opcode=opcode)
            return True
        except Exception as e:
            print(f"WS send failed: {e}")
            drop_ws()
            return False

    if connect_wifi(display):
        connect_ws()
        # Go to the recording screen once the connection phase is over.
        # NOTE(review): the result of connect_ws() is not checked here.
        ui_screen = UI_SCREEN_RECORDING
        if display.tft:
            display.render_recording_screen("", 0, False)
    else:
        print("Running in offline mode")
        # Still show the recording screen when offline (it cannot be used).
        ui_screen = UI_SCREEN_RECORDING
        if display.tft:
            display.render_recording_screen("离线模式", 0, False)

    read_buf = bytearray(4096)
    memory_check_counter = 0
    spinner_angle = 0
    last_spinner_time = 0

    while True:
        try:
            # Every 300 iterations, collect garbage if the heap is >80% used.
            memory_check_counter += 1
            if memory_check_counter >= 300:
                memory_check_counter = 0
                if check_memory(silent=True) > 80:
                    gc.collect()
                    print("Memory high, cleaned")

            # Spinner animation while the server is optimising/rendering and
            # no image data is being streamed to the screen.
            if (ui_screen == UI_SCREEN_RESULT
                    and not image_generation_done
                    and current_status in ("OPTIMIZING", "RENDERING")
                    and image_state != IMAGE_STATE_RECEIVING):
                now = time.ticks_ms()
                if time.ticks_diff(now, last_spinner_time) > 100:
                    if display.tft:
                        # Erase the previous spinner by redrawing it in black.
                        display.draw_loading_spinner(110, 80, spinner_angle, st7789.BLACK)

                        spinner_angle = (spinner_angle + 45) % 360

                        # Draw the spinner at its new angle.
                        color = st7789.CYAN if current_status == "OPTIMIZING" else st7789.YELLOW
                        display.draw_loading_spinner(110, 80, spinner_angle, color)

                    last_spinner_time = now

            btn_action = get_boot_button_action(boot_btn)

            # Stop showing the "waiting" state if no ASR result arrived
            # within 2 seconds of the recording ending.
            if ui_screen == UI_SCREEN_CONFIRM and confirm_waiting:
                if time.ticks_diff(time.ticks_ms(), recording_stop_time) > 2000:
                    confirm_waiting = False
                    if display.tft:
                        display.render_confirm_screen("", waiting=False)

            # Hold-to-record: press starts, release stops.
            if ui_screen == UI_SCREEN_RECORDING:
                if boot_btn.value() == 0 and not is_recording:
                    print(">>> Start recording (Hold)")
                    is_recording = True
                    confirm_waiting = False
                    current_asr_text = ""
                    current_prompt = ""
                    current_status = ""
                    image_generation_done = False
                    if display.tft:
                        display.render_recording_screen("", 0, True)
                    if ws is None or not ws.is_connected():
                        connect_ws()
                    send_ws("START_RECORDING")
                elif boot_btn.value() == 1 and is_recording:
                    print(">>> Stop recording (Release)")
                    send_ws("STOP_RECORDING")
                    is_recording = False
                    ui_screen = UI_SCREEN_CONFIRM
                    image_generation_done = False

                    # Start the ASR wait timer.
                    confirm_waiting = True
                    recording_stop_time = time.ticks_ms()

                    if display.tft:
                        display.render_confirm_screen(current_asr_text, waiting=True)
                    # Consume the release so it is not also handled as a
                    # press on the CONFIRM screen below.
                    btn_action = 0

            if btn_action == 1:
                if ui_screen == UI_SCREEN_CONFIRM:
                    print(">>> Confirm and generate")
                    send_ws(f"GENERATE_IMAGE:{current_asr_text}")
                    is_recording = False
                    ui_screen = UI_SCREEN_RESULT
                    image_generation_done = False
                    if display.tft:
                        display.render_result_screen("OPTIMIZING", current_asr_text, False)
                    time.sleep(0.5)
                elif ui_screen == UI_SCREEN_RESULT:
                    # Short press on the result screen: record again.
                    print(">>> Re-record (Short Press)")
                    current_asr_text = ""
                    confirm_waiting = False
                    ui_screen = UI_SCREEN_RECORDING
                    is_recording = False
                    image_generation_done = False
                    if display.tft:
                        display.render_recording_screen("", 0, False)
                    time.sleep(0.5)

            elif btn_action == 2:
                if ui_screen == UI_SCREEN_CONFIRM:
                    print(">>> Re-record")
                    current_asr_text = ""
                    confirm_waiting = False
                    ui_screen = UI_SCREEN_RECORDING
                    is_recording = False
                    image_generation_done = False
                    if display.tft:
                        display.render_recording_screen("", 0, False)
                    time.sleep(0.5)
                elif ui_screen == UI_SCREEN_RESULT:
                    # Long press on the result screen: request a print.
                    print(">>> Print Image (Long Press)")
                    if send_ws("PRINT_IMAGE") and display.tft:
                        display.render_result_screen("PRINTING", "正在请求打印...", True)
                    time.sleep(0.5)

            elif btn_action == 3:
                # Only logged; no config mode is implemented here.
                print(">>> Config mode")

            # Stream microphone data as binary frames while recording.
            if is_recording and btn_action == 0:
                if mic.i2s:
                    num_read = mic.readinto(read_buf)
                    if num_read > 0:
                        # Nothing is received while recording, so the audio
                        # stream is not interrupted.
                        send_ws(read_buf[:num_read], opcode=2)

            # Receive server messages only while not recording.
            if ui_screen in (UI_SCREEN_CONFIRM, UI_SCREEN_RESULT, UI_SCREEN_RECORDING) and not is_recording:
                if ws and ws.is_connected():
                    try:
                        # Wait up to 100 ms for incoming data.
                        poller = uselect.poll()
                        poller.register(ws.sock, uselect.POLLIN)
                        events = poller.poll(100)
                        if events:
                            msg = ws.recv()
                            if msg:
                                image_state, event_data = process_message(
                                    msg, display, image_state, image_data_list, printer_uart)

                                if event_data:
                                    event = event_data[0]

                                    if event == "asr":
                                        current_asr_text = event_data[1]
                                        print(f"Received ASR: {current_asr_text}")
                                        confirm_waiting = False

                                        # An ASR result moves the UI to (or
                                        # refreshes) the CONFIRM screen.
                                        if ui_screen == UI_SCREEN_RECORDING or ui_screen == UI_SCREEN_CONFIRM:
                                            ui_screen = UI_SCREEN_CONFIRM
                                            if display.tft:
                                                display.render_confirm_screen(current_asr_text, waiting=False)

                                    elif event == "font_update":
                                        # Nothing to redraw here.
                                        pass

                                    elif event == "status":
                                        current_status = event_data[1]
                                        if display.tft and ui_screen == UI_SCREEN_RESULT:
                                            display.render_result_screen(current_status, current_prompt, image_generation_done)

                                    elif event == "prompt":
                                        current_prompt = event_data[1]
                                        if display.tft and ui_screen == UI_SCREEN_RESULT:
                                            display.render_result_screen(current_status, current_prompt, image_generation_done)

                                    elif event == "image_done":
                                        image_generation_done = True
                                        if display.tft and ui_screen == UI_SCREEN_RESULT:
                                            display.render_result_screen("COMPLETE", current_prompt, True)

                                    elif event == "error":
                                        if display.tft and ui_screen == UI_SCREEN_RESULT:
                                            display.render_result_screen("ERROR", current_prompt, False)

                                    elif event == "printer_start":
                                        if display.tft and ui_screen == UI_SCREEN_RESULT:
                                            display.render_result_screen("PRINTING", "正在打印...", True)

                                    elif event == "printer_done":
                                        if display.tft and ui_screen == UI_SCREEN_RESULT:
                                            display.render_result_screen("COMPLETE", "打印完成", True)
                                            time.sleep(1.0)
                    except Exception as e:
                        print(f"WS Recv Error: {e}")

            time.sleep(0.01)

        except Exception as e:
            # Keep the device running: log, pause briefly, continue.
            print(f"Error: {e}")
            time.sleep(1)


if __name__ == '__main__':
    main()
|