# Listing metadata: 309 lines, 9.7 KiB, Python
import machine
|
||
import time
|
||
import struct
|
||
import gc
|
||
import network
|
||
import st7789py as st7789
|
||
from config import CURRENT_CONFIG
|
||
from audio import AudioPlayer, Microphone
|
||
from display import Display
|
||
from websocket_client import WebSocketClient
|
||
import uselect
|
||
import ujson
|
||
|
||
# WiFi credentials.
# NOTE(review): the password is hard-coded in source; consider moving it into config.
WIFI_SSID = "Tangledup-AI"
WIFI_PASS = "djt12345678"

# WebSocket endpoint of the audio server.
SERVER_IP = "6.6.6.88"
SERVER_PORT = 8000
SERVER_URL = "ws://{}:{}/ws/audio".format(SERVER_IP, SERVER_PORT)

# Image reception state
IMAGE_STATE_IDLE, IMAGE_STATE_RECEIVING = 0, 1

# Image dimensions (not referenced by the code visible in this listing).
IMG_WIDTH = IMG_HEIGHT = 120
|
||
|
||
|
||
def connect_wifi(max_retries=5):
    """Bring up the station interface and join the configured WiFi network.

    The interface is switched off and back on before the first attempt.
    Each attempt waits up to 30 seconds for the link to come up.

    Args:
        max_retries: Maximum number of connection attempts.

    Returns:
        True once the interface reports a connection; False when the
        interface could not be re-activated or every attempt failed.
    """
    sta = network.WLAN(network.STA_IF)

    # Power-cycle the interface before connecting.
    try:
        sta.active(False)
        time.sleep(2)
        sta.active(True)
        time.sleep(3)
    except Exception as err:
        print(f"WiFi init error: {err}")
        return False

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        retry_left = attempt < final_attempt
        try:
            if sta.isconnected():
                print('WiFi connected')
                return True

            print(f'Connecting to WiFi {WIFI_SSID}...')
            sta.connect(WIFI_SSID, WIFI_PASS)

            # Poll the link every 0.5 s until it is up or 30 s have elapsed.
            began = time.ticks_ms()
            while True:
                if sta.isconnected():
                    break
                if time.ticks_diff(time.ticks_ms(), began) > 30000:
                    print("WiFi timeout!")
                    break
                time.sleep(0.5)
                print(".", end="")

            if sta.isconnected():
                print('\nWiFi connected!')
                return True

            # Back off before the next attempt; nothing to do after the last one.
            if retry_left:
                print(f"\nRetry {attempt + 1}/{max_retries}...")
                sta.disconnect()
                time.sleep(3)
        except Exception as err:
            print(f"WiFi error: {err}")
            if retry_left:
                time.sleep(5)

    print("WiFi connection failed!")
    return False
|
||
|
||
|
||
def print_asr(text, display=None):
    """Log an ASR transcript and, when a screen is available, draw it.

    Args:
        text: Recognised text to print and show.
        display: Optional display wrapper. Drawing is skipped when it is
            falsy or its ``tft`` attribute is falsy.
    """
    print(f"ASR: {text}")
    if not display or not display.tft:
        return
    # Blank the text region, then draw the new transcript at its top-left.
    display.fill_rect(0, 40, 240, 160, st7789.BLACK)
    display.text(text, 0, 40, st7789.WHITE)
|
||
|
||
|
||
def _show_received_image(display, image_data_list):
    """Decode the buffered hex chunks and draw the image centred on screen.

    Args:
        display: Display wrapper; drawing is skipped when it is falsy or its
            ``tft`` attribute is falsy.
        image_data_list: Buffer filled by IMAGE_START / IMAGE_DATA. Slot 0 is
            the image edge length, the remaining items are hex strings. It is
            emptied as soon as the chunks have been joined.

    Any exception raised while decoding or drawing propagates to the caller.
    """
    img_size = image_data_list[0] if image_data_list else 64
    hex_data = "".join(image_data_list[1:])
    # Release the chunk list before decoding so both copies are not kept alive.
    image_data_list.clear()

    # Convert the hex string into raw bytes, then drop the (larger) hex text.
    img_data = bytes.fromhex(hex_data)
    del hex_data
    print(f"Image data len: {len(img_data)}")

    if display and display.tft:
        # Centre the square image within a 240x240 area.
        x = (240 - img_size) // 2
        y = (240 - img_size) // 2
        display.show_image(x, y, img_size, img_size, img_data)

        display.fill_rect(0, 0, 240, 30, st7789.WHITE)
        display.text("图片已生成!", 0, 5, st7789.BLACK)

    gc.collect()
    print("Image displayed")


def process_message(msg, display, image_state, image_data_list):
    """Handle one WebSocket message and return the new image_state.

    Recognised text messages: ``ASR:``, ``GENERATING_IMAGE:``, ``PROMPT:``,
    ``IMAGE_START:<size>[:<img_size>]``, ``IMAGE_DATA:<hex>``, ``IMAGE_END``
    and ``IMAGE_ERROR:``. Anything else (including non-string messages) is
    ignored and leaves the state unchanged.

    Args:
        msg: Message received from the WebSocket.
        display: Display wrapper, or a falsy value to skip all drawing.
        image_state: Current state, IMAGE_STATE_IDLE or IMAGE_STATE_RECEIVING.
        image_data_list: Mutable buffer shared across calls. After
            IMAGE_START it holds the image edge length followed by the hex
            chunks received so far.

    Returns:
        The image state to use for the next message.
    """
    if not isinstance(msg, str):
        return image_state

    # ASR transcript
    if msg.startswith("ASR:"):
        print_asr(msg[4:], display)

    # Image generation status
    elif msg.startswith("GENERATING_IMAGE:"):
        print(msg)
        if display and display.tft:
            display.fill_rect(0, 40, 240, 100, st7789.BLACK)
            display.text("正在生成图片...", 0, 40, st7789.YELLOW)

    # Optimised prompt
    elif msg.startswith("PROMPT:"):
        prompt = msg[7:]
        print(f"Optimized prompt: {prompt}")
        if display and display.tft:
            display.fill_rect(0, 60, 240, 40, st7789.BLACK)
            display.text("提示词: " + prompt[:20], 0, 60, st7789.CYAN)

    # Start of an image transfer
    elif msg.startswith("IMAGE_START:"):
        try:
            parts = msg.split(":")
            size = int(parts[1])
            img_size = int(parts[2]) if len(parts) > 2 else 64
        except (ValueError, IndexError) as e:
            print(f"IMAGE_START parse error: {e}")
            return image_state
        print(f"Image start, size: {size}, img_size: {img_size}")
        image_data_list.clear()
        image_data_list.append(img_size)  # slot 0 keeps the image edge length
        return IMAGE_STATE_RECEIVING

    # Image payload chunk
    elif msg.startswith("IMAGE_DATA:") and image_state == IMAGE_STATE_RECEIVING:
        # 11 == len("IMAGE_DATA:"); everything after the prefix is payload.
        image_data_list.append(msg[11:])

    # End of the image transfer
    elif msg == "IMAGE_END" and image_state == IMAGE_STATE_RECEIVING:
        try:
            print("Image received, processing...")
            _show_received_image(display, image_data_list)
        except Exception as e:
            print(f"Image process error: {e}")
        finally:
            # Never keep a partial image around, even when decoding failed.
            image_data_list.clear()
        return IMAGE_STATE_IDLE

    # Image generation failed
    elif msg.startswith("IMAGE_ERROR:"):
        print(msg)
        # Free any chunks buffered before the error was reported.
        image_data_list.clear()
        if display and display.tft:
            display.fill_rect(0, 40, 240, 100, st7789.BLACK)
            display.text("图片生成失败", 0, 40, st7789.RED)
        return IMAGE_STATE_IDLE

    return image_state
|
||
|
||
|
||
def main():
    """Run the push-to-talk loop; this function never returns.

    While the BOOT button (pin 0, active low) is held, microphone data is
    streamed to the server as binary WebSocket frames. On release,
    STOP_RECORDING is sent and the device waits up to 30 seconds for the
    server's replies, which are handed to ``process_message``.
    """
    print("\n=== ESP32 Audio ASR ===\n")

    boot_btn = machine.Pin(0, machine.Pin.IN, machine.Pin.PULL_UP)

    # Switch the backlight on when the board config names a pin for it.
    bl_pin = CURRENT_CONFIG.pins.get('bl')
    if bl_pin:
        try:
            bl = machine.Pin(bl_pin, machine.Pin.OUT)
            bl.on()
        except Exception as e:
            # Best effort: carry on without the backlight.
            print(f"Backlight init error: {e}")

    # NOTE(review): `speaker` is not used below; presumably constructed for
    # its hardware-initialisation side effect - confirm before removing.
    speaker = AudioPlayer()
    mic = Microphone()
    display = Display()

    if display.tft:
        display.init_ui()

    is_recording = False
    ws = None
    poller = None  # poll object bound to the current ws.sock, created lazily
    image_state = IMAGE_STATE_IDLE
    image_data_list = []

    def drop_ws():
        """Close the current WebSocket (if any) and forget it and its poller."""
        nonlocal ws, poller
        if ws:
            try:
                ws.close()
            except Exception as e:
                print(f"WS close error: {e}")
        ws = None
        poller = None

    def connect_ws():
        """Replace any existing WebSocket with a fresh one; True on success."""
        nonlocal ws
        drop_ws()
        try:
            print(f"Connecting to {SERVER_URL}")
            ws = WebSocketClient(SERVER_URL)
            print("WebSocket connected!")
            if display:
                display.set_ws(ws)
            return True
        except Exception as e:
            print(f"WS connection failed: {e}")
            return False

    def ws_readable(timeout_ms):
        """Poll the current socket for input, reusing one poll object per connection."""
        nonlocal poller
        if poller is None:
            poller = uselect.poll()
            poller.register(ws.sock, uselect.POLLIN)
        return poller.poll(timeout_ms)

    if connect_wifi():
        connect_ws()
    else:
        print("Running in offline mode")

    read_buf = bytearray(4096)

    while True:
        try:
            btn_val = boot_btn.value()

            if btn_val == 0:
                if not is_recording:
                    print(">>> Recording...")
                    is_recording = True
                    if display.tft:
                        display.fill(st7789.WHITE)

                    if ws is None or not ws.is_connected():
                        connect_ws()

                    if ws and ws.is_connected():
                        try:
                            ws.send("START_RECORDING")
                        except Exception as e:
                            print(f"Start recording error: {e}")
                            drop_ws()

                if mic.i2s:
                    num_read = mic.readinto(read_buf)
                    if num_read > 0:
                        if ws and ws.is_connected():
                            try:
                                ws.send(read_buf[:num_read], opcode=2)

                                # Non-blocking check for a server message.
                                if ws_readable(0):
                                    msg = ws.recv()
                                    image_state = process_message(msg, display, image_state, image_data_list)
                            except Exception as e:
                                # Close the socket instead of just dropping the
                                # reference, so it is not leaked.
                                print(f"Audio stream error: {e}")
                                drop_ws()

                # Skip the idle sleep below while the button is held.
                continue

            elif is_recording:
                print(">>> Stop")
                is_recording = False

                if display.tft:
                    display.init_ui()

                if ws:
                    try:
                        ws.send("STOP_RECORDING")

                        # Wait up to 30 seconds for the image generation result.
                        deadline = time.ticks_add(time.ticks_ms(), 30000)
                        # Set once the server announces an image, so a quiet
                        # period during generation does not end the wait.
                        awaiting_image = False
                        while time.ticks_diff(deadline, time.ticks_ms()) > 0:
                            if ws_readable(500):
                                msg = ws.recv()
                                prev_image_state = image_state
                                image_state = process_message(msg, display, image_state, image_data_list)
                                if isinstance(msg, str):
                                    if msg.startswith("GENERATING_IMAGE:"):
                                        awaiting_image = True
                                    elif msg.startswith("IMAGE_ERROR:"):
                                        awaiting_image = False
                                # Receiving -> idle means the image transfer finished.
                                if prev_image_state == IMAGE_STATE_RECEIVING and image_state == IMAGE_STATE_IDLE:
                                    break
                            elif image_state == IMAGE_STATE_IDLE and not awaiting_image:
                                # Nothing arrived and no image is in flight or pending.
                                break
                    except Exception as e:
                        print(f"Stop recording error: {e}")
                        drop_ws()

                gc.collect()

            time.sleep(0.01)

        except Exception as e:
            print(f"Error: {e}")
            time.sleep(1)
|
||
|
||
|
||
# Start the application only when this file is executed directly.
if __name__ == "__main__":
    main()
|