1
This commit is contained in:
@@ -18,7 +18,8 @@ NON_CAMERA.pins = {
|
||||
'sck': 9, # SPI CLK / SCK
|
||||
'dc': 46, # Data/Command
|
||||
'rst': 11, # Reset
|
||||
'cs': 12 # Chip Select
|
||||
'cs': 12, # Chip Select
|
||||
'btn': 0 # Boot按键
|
||||
}
|
||||
NON_CAMERA.audio = {
|
||||
'enabled': True,
|
||||
|
||||
183
font.py
183
font.py
@@ -2,29 +2,39 @@ import framebuf
|
||||
import struct
|
||||
import time
|
||||
import binascii
|
||||
import gc
|
||||
|
||||
class Font:
|
||||
def __init__(self, ws=None):
    """Create a font renderer with an empty glyph cache.

    Args:
        ws: Optional WebSocket client kept on the instance; other methods
            of this class use it to request glyph bitmaps.
    """
    self.ws = ws
    # Glyph bitmaps keyed by Unicode code point (assigned once; the
    # duplicate assignment left over from the diff was a dead store).
    self.cache = {}
    # Code points checked by text() before it re-requests a glyph.
    self.pending_requests = set()
    # Per-code-point retry counter, compared against max_retries.
    self.retry_count = {}
    self.max_retries = 3
|
||||
|
||||
def set_ws(self, ws):
    """Replace the WebSocket client stored on this instance."""
    self.ws = ws
|
||||
|
||||
def clear_cache(self):
    """Empty the glyph cache in place, then run a garbage collection."""
    glyphs = self.cache
    glyphs.clear()
    gc.collect()
|
||||
|
||||
def get_cache_size(self):
    """Return the number of entries currently held in the glyph cache."""
    entry_count = len(self.cache)
    return entry_count
|
||||
|
||||
def text(self, tft, text, x, y, color, bg=0x0000):
|
||||
"""
|
||||
Draw text on ST7789 display using WebSocket to fetch fonts
|
||||
"""
|
||||
# Pre-calculate color bytes
|
||||
"""在ST7789显示器上绘制文本"""
|
||||
if not text:
|
||||
return
|
||||
|
||||
color_bytes = struct.pack(">H", color)
|
||||
bg_bytes = struct.pack(">H", bg)
|
||||
|
||||
# Create LUT for current color/bg
|
||||
lut = [bytearray(16) for _ in range(256)]
|
||||
for i in range(256):
|
||||
for bit in range(8):
|
||||
# bit 7 is first pixel (leftmost)
|
||||
# target index: (7-bit)*2
|
||||
val = (i >> bit) & 1
|
||||
idx = (7 - bit) * 2
|
||||
if val:
|
||||
@@ -36,7 +46,6 @@ class Font:
|
||||
|
||||
initial_x = x
|
||||
|
||||
# 1. Identify missing fonts
|
||||
missing_codes = set()
|
||||
for char in text:
|
||||
if ord(char) > 127:
|
||||
@@ -44,12 +53,8 @@ class Font:
|
||||
if code not in self.cache:
|
||||
missing_codes.add(code)
|
||||
|
||||
# 2. Batch request missing fonts
|
||||
if missing_codes and self.ws:
|
||||
# Convert to list for consistent order/string
|
||||
missing_list = list(missing_codes)
|
||||
# Limit batch size? Maybe 20 chars at a time?
|
||||
# For short ASR result, usually < 20 chars.
|
||||
|
||||
req_str = ",".join([str(c) for c in missing_list])
|
||||
print(f"Batch requesting fonts: {req_str}")
|
||||
@@ -59,15 +64,12 @@ class Font:
|
||||
except Exception as e:
|
||||
print(f"Batch font request failed: {e}")
|
||||
|
||||
# 3. Draw text
|
||||
for char in text:
|
||||
# Handle newlines
|
||||
if char == '\n':
|
||||
x = initial_x
|
||||
y += 16
|
||||
continue
|
||||
|
||||
# Boundary check
|
||||
if x + 16 > tft.width:
|
||||
x = initial_x
|
||||
y += 16
|
||||
@@ -77,121 +79,79 @@ class Font:
|
||||
is_chinese = False
|
||||
buf_data = None
|
||||
|
||||
# Check if it's Chinese (or non-ASCII)
|
||||
if ord(char) > 127:
|
||||
code = ord(char)
|
||||
if code in self.cache:
|
||||
buf_data = self.cache[code]
|
||||
is_chinese = True
|
||||
else:
|
||||
# Still missing after batch request?
|
||||
# Could be timeout or invalid char.
|
||||
pass
|
||||
if code in self.pending_requests:
|
||||
retry = self.retry_count.get(code, 0)
|
||||
if retry < self.max_retries:
|
||||
self.retry_count[code] = retry + 1
|
||||
self._request_single_font(code)
|
||||
|
||||
if is_chinese and buf_data:
|
||||
# Draw Chinese character (16x16)
|
||||
self._draw_bitmap(tft, buf_data, x, y, 16, 16, lut)
|
||||
x += 16
|
||||
else:
|
||||
# Draw ASCII (8x16) using built-in framebuf font (8x8 actually)
|
||||
# If char is not ASCII, replace with '?' to avoid framebuf errors
|
||||
if ord(char) > 127:
|
||||
char = '?'
|
||||
self._draw_ascii(tft, char, x, y, color, bg)
|
||||
x += 8
|
||||
|
||||
def _request_single_font(self, code):
    """Ask the server for one glyph bitmap (best effort).

    Args:
        code: Code point of the glyph; sent as ``GET_FONT_UNICODE:<code>``.

    Does nothing when no WebSocket client is set. A failed send is
    reported and swallowed so that drawing can continue.
    """
    if not self.ws:
        return
    try:
        self.ws.send(f"GET_FONT_UNICODE:{code}")
    except Exception as e:
        # Still best effort, but no longer silent; unlike a bare
        # ``except`` this lets KeyboardInterrupt/SystemExit propagate.
        print(f"Font request failed for {code}: {e}")
|
||||
|
||||
def _wait_for_fonts(self, target_codes):
|
||||
"""
|
||||
Blocking wait for a set of font codes.
|
||||
Buffers other messages to self.ws.unread_messages.
|
||||
"""
|
||||
"""等待字体数据返回"""
|
||||
if not self.ws or not target_codes:
|
||||
return
|
||||
|
||||
start = time.ticks_ms()
|
||||
self.local_deferred = []
|
||||
|
||||
# 2 seconds timeout for batch
|
||||
while time.ticks_diff(time.ticks_ms(), start) < 2000 and target_codes:
|
||||
|
||||
# Check unread_messages first?
|
||||
# Actually ws.recv() in our modified client already checks unread_messages.
|
||||
# But wait, if we put something BACK into unread_messages, we need to be careful not to read it again immediately if we are looping?
|
||||
# No, we only put NON-FONT messages back. We are looking for FONT messages.
|
||||
# So if we pop a non-font message, we put it back?
|
||||
# If we put it back at head, we will read it again next loop! Infinite loop!
|
||||
#
|
||||
# Solution: We should NOT use ws.recv() which pops from unread.
|
||||
# We should assume unread_messages might contain what we need?
|
||||
#
|
||||
# Actually, `ws.recv()` pops from `unread_messages`.
|
||||
# If we get a message that is NOT what we want, we should store it in a temporary list, and push them all back at the end?
|
||||
# Or append to `unread_messages` (if it's a queue).
|
||||
# But `unread_messages` is used as a LIFO or FIFO?
|
||||
# pop(0) -> FIFO.
|
||||
# If we append, it goes to end.
|
||||
# So:
|
||||
# 1. recv() -> gets msg.
|
||||
# 2. Is it font?
|
||||
# Yes -> process.
|
||||
# No -> append to `temp_buffer`.
|
||||
# 3. After function finishes (or timeout), extend `unread_messages` with `temp_buffer`?
|
||||
# Wait, `unread_messages` should be preserved order.
|
||||
# If we had [A, B] in unread.
|
||||
# recv() gets A. Not font. Temp=[A].
|
||||
# recv() gets B. Not font. Temp=[A, B].
|
||||
# recv() gets network C (Font). Process.
|
||||
# End.
|
||||
# Restore: unread = Temp + unread? (unread is empty now).
|
||||
# So unread becomes [A, B]. Correct.
|
||||
|
||||
import uselect
|
||||
|
||||
# Fast check if we can read
|
||||
# But we want to block until SOMETHING arrives.
|
||||
|
||||
# If unread_messages is not empty, we should process them first.
|
||||
# But we can't peak easily without modifying recv again.
|
||||
# Let's just use recv() and handle the buffering logic here.
|
||||
|
||||
while time.ticks_diff(time.ticks_ms(), start) < 3000 and target_codes:
|
||||
try:
|
||||
# Use a poller for the socket part to implement timeout
|
||||
# But recv() handles logic.
|
||||
# If unread_messages is empty, we poll socket.
|
||||
|
||||
can_read = False
|
||||
if self.ws.unread_messages:
|
||||
if hasattr(self.ws, 'unread_messages') and self.ws.unread_messages:
|
||||
can_read = True
|
||||
else:
|
||||
import uselect
|
||||
poller = uselect.poll()
|
||||
poller.register(self.ws.sock, uselect.POLLIN)
|
||||
events = poller.poll(100) # 100ms
|
||||
events = poller.poll(100)
|
||||
if events:
|
||||
can_read = True
|
||||
|
||||
if can_read:
|
||||
msg = self.ws.recv() # This will pop from unread or read from sock
|
||||
msg = self.ws.recv()
|
||||
if msg is None:
|
||||
# Socket closed or error?
|
||||
# Or just timeout in recv (but we polled).
|
||||
continue
|
||||
|
||||
if isinstance(msg, str):
|
||||
if msg == "FONT_BATCH_END":
|
||||
# Batch complete. Mark remaining as failed.
|
||||
# We need to iterate over a copy because we are modifying target_codes?
|
||||
# Actually we just clear it.
|
||||
# But wait, target_codes is passed by reference (set).
|
||||
# If we clear it, loop breaks.
|
||||
# But we also want to mark cache as None for missing ones.
|
||||
temp_missing = list(target_codes)
|
||||
for c in temp_missing:
|
||||
print(f"Batch missing/failed: {c}")
|
||||
self.cache[c] = None # Cache failure
|
||||
if msg.startswith("FONT_BATCH_END:"):
|
||||
parts = msg[15:].split(":")
|
||||
success = int(parts[0]) if len(parts) > 0 else 0
|
||||
failed = int(parts[1]) if len(parts) > 1 else 0
|
||||
|
||||
if failed > 0:
|
||||
temp_missing = list(target_codes)
|
||||
for c in temp_missing:
|
||||
if c not in self.cache:
|
||||
print(f"Font failed after retries: {c}")
|
||||
self.cache[c] = None
|
||||
if c in target_codes:
|
||||
target_codes.remove(c)
|
||||
|
||||
target_codes.clear()
|
||||
|
||||
elif msg.startswith("FONT_DATA:"):
|
||||
# General font data handler
|
||||
parts = msg.split(":")
|
||||
if len(parts) >= 3:
|
||||
try:
|
||||
@@ -205,60 +165,39 @@ class Font:
|
||||
self.cache[c] = d
|
||||
if c in target_codes:
|
||||
target_codes.remove(c)
|
||||
# print(f"Batch loaded: {c}")
|
||||
if c in self.retry_count:
|
||||
del self.retry_count[c]
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
# Other message, e.g. START_PLAYBACK
|
||||
self.local_deferred.append(msg)
|
||||
|
||||
elif msg is not None:
|
||||
# Binary message? Buffer it too.
|
||||
self.local_deferred.append(msg)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Wait font error: {e}")
|
||||
|
||||
# End of wait. Restore deferred messages.
|
||||
if self.local_deferred:
|
||||
# We want new_list = local_deferred + old_list
|
||||
self.ws.unread_messages = self.local_deferred + self.ws.unread_messages
|
||||
if hasattr(self.ws, 'unread_messages'):
|
||||
self.ws.unread_messages = self.local_deferred + self.ws.unread_messages
|
||||
self.local_deferred = []
|
||||
|
||||
def _wait_for_font(self, target_code_str):
    """Do nothing; placeholder for the single-glyph wait entry point.

    NOTE(review): the original comments suggest batch waiting replaced
    this method and that it may be deprecated - confirm before removing.
    """
    return None
|
||||
|
||||
def _draw_bitmap(self, tft, bitmap, x, y, w, h, lut):
    """Expand a bitmap through ``lut`` and blit the result to the display.

    Args:
        tft: Display object providing ``blit_buffer(buf, x, y, w, h)``.
        bitmap: Iterable of byte values; each one indexes ``lut``.
        x, y: Destination position, passed through to ``blit_buffer``.
        w, h: Glyph size, passed through to ``blit_buffer``.
        lut: Table mapping a byte value to its bytes-like pixel chunk.
    """
    # One table lookup per source byte, joined into a single buffer.
    pixels = b''.join([lut[value] for value in bitmap])
    tft.blit_buffer(pixels, x, y, w, h)
|
||||
|
||||
def _draw_ascii(self, tft, char, x, y, color, bg):
|
||||
# Use framebuf for ASCII
|
||||
"""绘制ASCII字符"""
|
||||
w, h = 8, 8
|
||||
buf = bytearray(w * h // 8)
|
||||
fb = framebuf.FrameBuffer(buf, w, h, framebuf.MONO_VLSB)
|
||||
fb.fill(0)
|
||||
fb.text(char, 0, 0, 1)
|
||||
|
||||
# Since framebuf.text is 8x8, we center it vertically in 16px height
|
||||
# Drawing pixel by pixel is slow but compatible
|
||||
# To optimize, we can build a small buffer
|
||||
|
||||
# Create a 8x16 RGB565 buffer
|
||||
rgb_buf = bytearray(8 * 16 * 2)
|
||||
# Fill with background
|
||||
bg_high, bg_low = bg >> 8, bg & 0xFF
|
||||
color_high, color_low = color >> 8, color & 0xFF
|
||||
|
||||
@@ -266,14 +205,10 @@ class Font:
|
||||
rgb_buf[i] = bg_high
|
||||
rgb_buf[i+1] = bg_low
|
||||
|
||||
# Draw the 8x8 character into the buffer (centered)
|
||||
# MONO_VLSB: each byte is a column of 8 pixels
|
||||
for col in range(8): # 0..7
|
||||
for col in range(8):
|
||||
byte = buf[col]
|
||||
for row in range(8): # 0..7
|
||||
for row in range(8):
|
||||
if (byte >> row) & 1:
|
||||
# Calculate position in rgb_buf
|
||||
# Target: x=col, y=row+4
|
||||
pos = ((row + 4) * 8 + col) * 2
|
||||
rgb_buf[pos] = color_high
|
||||
rgb_buf[pos+1] = color_low
|
||||
|
||||
587
main.py
587
main.py
@@ -17,15 +17,27 @@ SERVER_IP = "6.6.6.88"
|
||||
SERVER_PORT = 8000
|
||||
SERVER_URL = f"ws://{SERVER_IP}:{SERVER_PORT}/ws/audio"
|
||||
|
||||
# 图片接收状态
|
||||
IMAGE_STATE_IDLE = 0
|
||||
IMAGE_STATE_RECEIVING = 1
|
||||
|
||||
UI_SCREEN_RECORDING = 1
|
||||
UI_SCREEN_CONFIRM = 2
|
||||
UI_SCREEN_RESULT = 3
|
||||
|
||||
BOOT_SHORT_MS = 500
|
||||
BOOT_LONG_MS = 2000
|
||||
BOOT_EXTRA_LONG_MS = 5000
|
||||
|
||||
IMG_WIDTH = 120
|
||||
IMG_HEIGHT = 120
|
||||
|
||||
_last_btn_state = None
|
||||
_btn_release_time = 0
|
||||
_btn_press_time = 0
|
||||
|
||||
|
||||
def connect_wifi(max_retries=5):
|
||||
"""连接WiFi网络"""
|
||||
wlan = network.WLAN(network.STA_IF)
|
||||
|
||||
try:
|
||||
@@ -72,38 +84,183 @@ def connect_wifi(max_retries=5):
|
||||
return False
|
||||
|
||||
|
||||
def print_asr(text, display=None):
    """Log an ASR result and show it on the display when one is usable.

    Args:
        text: Recognised text to log and draw.
        display: Optional display wrapper; skipped when it is falsy or its
            ``tft`` attribute is falsy.
    """
    print(f"ASR: {text}")
    if not display or not display.tft:
        return
    # Blank the text area before drawing the new result.
    display.fill_rect(0, 40, 240, 160, st7789.BLACK)
    display.text(text, 0, 40, st7789.WHITE)
|
||||
def draw_mic_icon(display, x, y, active=True):
    """Draw a microphone icon built from five filled rectangles.

    Args:
        display: Display wrapper; nothing is drawn when it or its ``tft``
            is falsy.
        x, y: Top-left origin of the icon.
        active: Green when true, dark grey otherwise.
    """
    if not (display and display.tft):
        return

    tint = st7789.GREEN if active else st7789.DARKGREY

    # (dx, dy, width, height) for each rectangle, in drawing order.
    rects = (
        (5, 0, 10, 5),
        (3, 5, 14, 10),
        (8, 15, 4, 8),
        (6, 23, 8, 2),
        (8, 25, 4, 3),
    )
    for dx, dy, rect_w, rect_h in rects:
        display.tft.fill_rect(x + dx, y + dy, rect_w, rect_h, tint)
|
||||
|
||||
|
||||
def draw_loading_spinner(display, x, y, angle, color=st7789.WHITE):
    """Plot eight dots on a circle, rotated by ``angle`` degrees.

    Args:
        display: Display wrapper; nothing is drawn when it or its ``tft``
            is falsy.
        x, y: Top-left origin; the circle centre is 10 px right and below.
        angle: Rotation of the dot pattern, in degrees.
        color: Pixel colour for every dot.

    Nothing is erased here: a caller that wants to remove a previous frame
    has to redraw it in the background colour.
    """
    if not (display and display.tft):
        return

    import math

    rotation = math.radians(angle)
    cx = x + 10
    cy = y + 10
    ring = 8

    # Dots are 45 degrees apart.
    for step in range(8):
        phi = math.radians(step * 45) + rotation
        dot_x = int(cx + ring * math.cos(phi))
        dot_y = int(cy + ring * math.sin(phi))
        display.tft.pixel(dot_x, dot_y, color)
|
||||
|
||||
def draw_check_icon(display, x, y):
    """Draw a green check mark made of two line segments.

    Args:
        display: Display wrapper; nothing is drawn when it or its ``tft``
            is falsy.
        x, y: Origin of the icon.
    """
    if not (display and display.tft):
        return

    # First the short segment, then the long one from the shared corner.
    corner_x, corner_y = x + 3, y + 8
    display.tft.line(x, y + 5, corner_x, corner_y, st7789.GREEN)
    display.tft.line(corner_x, corner_y, x + 10, y, st7789.GREEN)
|
||||
|
||||
|
||||
def draw_progress_bar(display, x, y, width, height, progress, color=st7789.CYAN):
    """Draw a horizontal progress bar over a dark-grey track.

    Args:
        display: Display wrapper; nothing is drawn when it or its ``tft``
            is falsy.
        x, y: Top-left corner of the bar.
        width, height: Size of the full track.
        progress: Filled fraction; values above 1.0 are capped, values that
            are not greater than 0 draw only the track.
        color: Fill colour of the completed part.
    """
    if not (display and display.tft):
        return

    # Track first, so a shrinking bar overwrites the previous fill.
    display.tft.fill_rect(x, y, width, height, st7789.DARKGREY)
    if not progress > 0:
        return

    fraction = min(progress, 1.0)
    filled = int(width * fraction)
    display.tft.fill_rect(x, y, filled, height, color)
|
||||
|
||||
|
||||
def render_recording_screen(display, asr_text="", audio_level=0):
    """Render the recording screen.

    Args:
        display: Display wrapper; nothing is drawn when it or its ``tft``
            is falsy.
        asr_text: Recognised text so far; only the first 20 characters are
            shown.
        audio_level: Input level; drawn as a bar of twice that width,
            capped at 200 px, when positive.
    """
    if not (display and display.tft):
        return

    display.tft.fill(st7789.BLACK)

    # Title bar.
    display.tft.fill_rect(0, 0, 240, 30, st7789.WHITE)
    display.text("语音识别", 80, 8, st7789.BLACK)

    draw_mic_icon(display, 105, 50, True)

    # Level meter.
    if audio_level > 0:
        meter_px = min(int(audio_level * 2), 200)
        display.tft.fill_rect(20, 100, meter_px, 10, st7789.GREEN)

    if asr_text:
        preview = asr_text[:20]
        display.text(preview, 20, 130, st7789.WHITE)

    # Button hint at the bottom.
    display.tft.fill_rect(60, 200, 120, 25, st7789.RED)
    display.text("松开停止", 85, 205, st7789.WHITE)
|
||||
|
||||
|
||||
def render_confirm_screen(display, asr_text=""):
    """Render the confirmation screen with the recognised text.

    Args:
        display: Display wrapper; nothing is drawn when it or its ``tft``
            is falsy.
        asr_text: Recognised text; a fallback message is shown when empty.
    """
    if not (display and display.tft):
        return

    display.tft.fill(st7789.BLACK)

    # Title bar.
    display.tft.fill_rect(0, 0, 240, 30, st7789.CYAN)
    display.text("说完了吗?", 75, 8, st7789.BLACK)

    # Text panel.
    shown = asr_text if asr_text else "未识别到文字"
    display.tft.fill_rect(10, 50, 220, 80, st7789.DARKGREY)
    display.text(shown, 20, 75, st7789.WHITE)

    # Left button hint.
    display.tft.fill_rect(20, 150, 80, 30, st7789.GREEN)
    display.text("短按确认", 30, 158, st7789.BLACK)

    # Right button hint.
    display.tft.fill_rect(140, 150, 80, 30, st7789.RED)
    display.text("长按重录", 155, 158, st7789.WHITE)
|
||||
|
||||
|
||||
def render_result_screen(display, status="", prompt="", image_received=False):
    """Render the result screen for the given generation status.

    Args:
        display: Display wrapper; nothing is drawn when it or its ``tft``
            is falsy.
        status: "OPTIMIZING", "RENDERING", "COMPLETE" or "ERROR"; any other
            value draws no status section.
        prompt: Prompt text; shown in a panel, cut to 25 characters plus
            "..." when longer.
        image_received: Selects the "complete" section when the status is
            neither "OPTIMIZING" nor "RENDERING".

    The whole screen is cleared on every call.
    """
    if not (display and display.tft):
        return

    display.tft.fill(st7789.BLACK)

    # Header.
    display.tft.fill_rect(0, 0, 240, 30, st7789.WHITE)
    display.text("AI 生成中", 80, 8, st7789.BLACK)

    if status == "OPTIMIZING":
        display.text("正在思考...", 80, 60, st7789.CYAN)
        display.text("优化提示词中", 70, 80, st7789.CYAN)
        draw_progress_bar(display, 40, 110, 160, 6, 0.3, st7789.CYAN)
    elif status == "RENDERING":
        display.text("正在绘画...", 80, 60, st7789.YELLOW)
        display.text("AI作画中", 85, 80, st7789.YELLOW)
        draw_progress_bar(display, 40, 110, 160, 6, 0.7, st7789.YELLOW)
    elif status == "COMPLETE" or image_received:
        display.text("生成完成!", 80, 50, st7789.GREEN)
        draw_check_icon(display, 110, 80)
    elif status == "ERROR":
        display.text("生成失败", 80, 50, st7789.RED)

    if prompt:
        if len(prompt) > 25:
            prompt_line = prompt[:25] + "..."
        else:
            prompt_line = prompt
        display.tft.fill_rect(10, 140, 220, 50, 0x2124)  # dark grey panel
        display.text("提示词:", 15, 145, st7789.CYAN)
        display.text(prompt_line, 15, 165, st7789.WHITE)

    # Button hint at the bottom.
    display.tft.fill_rect(60, 210, 120, 25, st7789.BLUE)
    display.text("返回录音", 90, 215, st7789.WHITE)
|
||||
|
||||
def process_message(msg, display, image_state, image_data_list):
|
||||
"""处理WebSocket消息,返回新的image_state"""
|
||||
if not isinstance(msg, str):
|
||||
return image_state
|
||||
"""处理WebSocket消息"""
|
||||
# Handle binary image data
|
||||
if isinstance(msg, (bytes, bytearray)):
|
||||
if image_state == IMAGE_STATE_RECEIVING:
|
||||
image_data_list.append(msg)
|
||||
# Optional: Update progress bar or indicator
|
||||
return image_state, None
|
||||
return image_state, None
|
||||
|
||||
if not isinstance(msg, str):
|
||||
return image_state, None
|
||||
|
||||
status_info = None
|
||||
|
||||
# 处理ASR消息
|
||||
if msg.startswith("ASR:"):
|
||||
print_asr(msg[4:], display)
|
||||
return image_state, ("asr", msg[4:])
|
||||
|
||||
elif msg.startswith("STATUS:"):
|
||||
parts = msg[7:].split(":", 1)
|
||||
status_type = parts[0]
|
||||
status_text = parts[1] if len(parts) > 1 else ""
|
||||
print(f"Status: {status_type} - {status_text}")
|
||||
return image_state, ("status", status_type, status_text)
|
||||
|
||||
# 处理图片生成状态消息
|
||||
elif msg.startswith("GENERATING_IMAGE:"):
|
||||
print(msg)
|
||||
if display and display.tft:
|
||||
display.fill_rect(0, 40, 240, 100, st7789.BLACK)
|
||||
display.text("正在生成图片...", 0, 40, st7789.YELLOW)
|
||||
# Deprecated by STATUS:RENDERING but kept for compatibility
|
||||
return image_state, None
|
||||
|
||||
# 处理提示词优化消息
|
||||
elif msg.startswith("PROMPT:"):
|
||||
prompt = msg[7:]
|
||||
print(f"Optimized prompt: {prompt}")
|
||||
if display and display.tft:
|
||||
display.fill_rect(0, 60, 240, 40, st7789.BLACK)
|
||||
display.text("提示词: " + prompt[:20], 0, 60, st7789.CYAN)
|
||||
return image_state, ("prompt", prompt)
|
||||
|
||||
# 处理图片开始消息
|
||||
elif msg.startswith("IMAGE_START:"):
|
||||
try:
|
||||
parts = msg.split(":")
|
||||
@@ -111,64 +268,120 @@ def process_message(msg, display, image_state, image_data_list):
|
||||
img_size = int(parts[2]) if len(parts) > 2 else 64
|
||||
print(f"Image start, size: {size}, img_size: {img_size}")
|
||||
image_data_list.clear()
|
||||
image_data_list.append(img_size) # 保存图片尺寸
|
||||
return IMAGE_STATE_RECEIVING
|
||||
image_data_list.append(img_size) # Store metadata at index 0
|
||||
return IMAGE_STATE_RECEIVING, None
|
||||
except Exception as e:
|
||||
print(f"IMAGE_START parse error: {e}")
|
||||
return image_state
|
||||
return image_state, None
|
||||
|
||||
# 处理图片数据消息
|
||||
# Deprecated text-based IMAGE_DATA handling
|
||||
elif msg.startswith("IMAGE_DATA:") and image_state == IMAGE_STATE_RECEIVING:
|
||||
try:
|
||||
data = msg.split(":", 1)[1]
|
||||
image_data_list.append(data)
|
||||
# Convert hex to bytes immediately if using old protocol, but we switched to binary
|
||||
# Keep this just in case server rolls back? No, let's assume binary.
|
||||
pass
|
||||
except:
|
||||
pass
|
||||
|
||||
# 处理图片结束消息
|
||||
elif msg == "IMAGE_END" and image_state == IMAGE_STATE_RECEIVING:
|
||||
try:
|
||||
print("Image received, processing...")
|
||||
|
||||
# 获取图片尺寸
|
||||
img_size = image_data_list[0] if image_data_list else 64
|
||||
hex_data = "".join(image_data_list[1:])
|
||||
# Combine all binary chunks (skipping metadata at index 0)
|
||||
img_data = b"".join(image_data_list[1:])
|
||||
image_data_list.clear()
|
||||
|
||||
# 将hex字符串转换为字节数据
|
||||
img_data = bytes.fromhex(hex_data)
|
||||
|
||||
print(f"Image data len: {len(img_data)}")
|
||||
|
||||
# 在屏幕中心显示图片
|
||||
if display and display.tft:
|
||||
# 计算居中位置
|
||||
x = (240 - img_size) // 2
|
||||
y = (240 - img_size) // 2
|
||||
|
||||
# 显示图片
|
||||
display.show_image(x, y, img_size, img_size, img_data)
|
||||
|
||||
display.fill_rect(0, 0, 240, 30, st7789.WHITE)
|
||||
display.text("图片已生成!", 0, 5, st7789.BLACK)
|
||||
# Overlay success message slightly
|
||||
display.tft.fill_rect(0, 0, 240, 30, st7789.WHITE)
|
||||
display.text("图片已生成!", 70, 5, st7789.BLACK)
|
||||
|
||||
gc.collect()
|
||||
print("Image displayed")
|
||||
return IMAGE_STATE_IDLE, ("image_done",)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Image process error: {e}")
|
||||
import sys
|
||||
sys.print_exception(e)
|
||||
|
||||
return IMAGE_STATE_IDLE
|
||||
return IMAGE_STATE_IDLE, None
|
||||
|
||||
# 处理图片错误消息
|
||||
elif msg.startswith("IMAGE_ERROR:"):
|
||||
print(msg)
|
||||
if display and display.tft:
|
||||
display.fill_rect(0, 40, 240, 100, st7789.BLACK)
|
||||
display.text("图片生成失败", 0, 40, st7789.RED)
|
||||
return IMAGE_STATE_IDLE
|
||||
return IMAGE_STATE_IDLE, ("error", msg[12:])
|
||||
|
||||
return image_state
|
||||
return image_state, None
|
||||
|
||||
|
||||
def print_asr(text, display=None):
    """Print an ASR result and mirror it on the display if available.

    Args:
        text: Recognised text.
        display: Optional display wrapper; ignored when it is falsy or its
            ``tft`` attribute is falsy.
    """
    print(f"ASR: {text}")
    has_screen = display and display.tft
    if not has_screen:
        return
    # Clear the result area, then draw the text at its top-left corner.
    display.fill_rect(0, 40, 240, 160, st7789.BLACK)
    display.text(text, 0, 40, st7789.WHITE)
|
||||
|
||||
|
||||
def get_boot_button_action(boot_btn):
    """Poll the Boot button and classify a press at the moment of release.

    Meant to be called repeatedly; press tracking lives in the module
    globals ``_last_btn_state`` and ``_btn_press_time``.

    Args:
        boot_btn: Object whose ``value()`` returns 0 while the button is
            held and 1 otherwise.

    Returns:
        0: nothing to report (idle, still held, or released before
           ``BOOT_SHORT_MS``)
        1: released after ``BOOT_SHORT_MS`` up to ``BOOT_LONG_MS``
        2: released after ``BOOT_LONG_MS`` up to ``BOOT_EXTRA_LONG_MS``
        3: released after ``BOOT_EXTRA_LONG_MS`` or longer

    NOTE(review): the previous docstring described 1 as a press shorter
    than 500 ms and 2 as 2-5 s, which does not match these thresholds -
    confirm whether presses under ``BOOT_SHORT_MS`` should really be
    ignored.
    """
    global _last_btn_state, _btn_release_time, _btn_press_time

    level = boot_btn.value()
    now = time.ticks_ms()

    if level == 0:
        # Stamp the start time only on the transition into "held".
        if _last_btn_state != 0:
            _last_btn_state = 0
            _btn_press_time = now
        return 0

    if level == 1 and _last_btn_state == 0:
        # Release edge: classify by how long the button was held.
        held_ms = time.ticks_diff(now, _btn_press_time)
        _last_btn_state = 1

        if held_ms < BOOT_SHORT_MS:
            return 0
        if held_ms < BOOT_LONG_MS:
            return 1
        if held_ms < BOOT_EXTRA_LONG_MS:
            return 2
        return 3

    if _last_btn_state is None:
        # First poll: just record the current level.
        _last_btn_state = level
        _btn_release_time = now

    return 0
|
||||
|
||||
|
||||
def check_memory(silent=False):
    """Report heap usage as a percentage.

    Args:
        silent: When true, skip the log line.

    Returns:
        Used heap as a percentage of used + free, or 0 when both are zero.
    """
    free = gc.mem_free()
    # Sample the allocated size once so the total and the numerator come
    # from the same snapshot.
    used = gc.mem_alloc()
    total = used + free
    usage = (used / total) * 100 if total > 0 else 0
    if not silent:
        print(f"Memory: {free} free, {usage:.1f}% used")
    return usage
|
||||
|
||||
|
||||
def main():
|
||||
@@ -191,12 +404,18 @@ def main():
|
||||
if display.tft:
|
||||
display.init_ui()
|
||||
|
||||
ui_screen = UI_SCREEN_RECORDING
|
||||
is_recording = False
|
||||
ws = None
|
||||
image_state = IMAGE_STATE_IDLE
|
||||
image_data_list = []
|
||||
current_asr_text = ""
|
||||
current_prompt = ""
|
||||
current_status = ""
|
||||
image_generation_done = False
|
||||
confirm_waiting = False
|
||||
|
||||
def connect_ws():
|
||||
def connect_ws(force=False):
|
||||
nonlocal ws
|
||||
try:
|
||||
if ws:
|
||||
@@ -205,16 +424,24 @@ def main():
|
||||
pass
|
||||
ws = None
|
||||
|
||||
try:
|
||||
print(f"Connecting to {SERVER_URL}")
|
||||
ws = WebSocketClient(SERVER_URL)
|
||||
print("WebSocket connected!")
|
||||
if display:
|
||||
display.set_ws(ws)
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"WS connection failed: {e}")
|
||||
return False
|
||||
retry_count = 0
|
||||
max_retries = 3
|
||||
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
print(f"Connecting to {SERVER_URL} (attempt {retry_count + 1})")
|
||||
ws = WebSocketClient(SERVER_URL)
|
||||
print("WebSocket connected!")
|
||||
if display:
|
||||
display.set_ws(ws)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"WS connection failed: {e}")
|
||||
retry_count += 1
|
||||
time.sleep(1)
|
||||
|
||||
return False
|
||||
|
||||
if connect_wifi():
|
||||
connect_ws()
|
||||
@@ -222,27 +449,162 @@ def main():
|
||||
print("Running in offline mode")
|
||||
|
||||
read_buf = bytearray(4096)
|
||||
last_audio_level = 0
|
||||
memory_check_counter = 0
|
||||
spinner_angle = 0
|
||||
last_spinner_time = 0
|
||||
|
||||
while True:
|
||||
try:
|
||||
btn_val = boot_btn.value()
|
||||
memory_check_counter += 1
|
||||
|
||||
if btn_val == 0:
|
||||
if not is_recording:
|
||||
print(">>> Recording...")
|
||||
is_recording = True
|
||||
if memory_check_counter >= 300:
|
||||
memory_check_counter = 0
|
||||
if check_memory(silent=True) > 80:
|
||||
gc.collect()
|
||||
print("Memory high, cleaned")
|
||||
|
||||
# Spinner Animation
|
||||
if ui_screen == UI_SCREEN_RESULT and not image_generation_done and current_status in ["OPTIMIZING", "RENDERING"]:
|
||||
now = time.ticks_ms()
|
||||
if time.ticks_diff(now, last_spinner_time) > 100:
|
||||
if display.tft:
|
||||
display.fill(st7789.WHITE)
|
||||
# Clear previous spinner (draw in BLACK)
|
||||
draw_loading_spinner(display, 110, 80, spinner_angle, st7789.BLACK)
|
||||
|
||||
if ws is None or not ws.is_connected():
|
||||
connect_ws()
|
||||
spinner_angle = (spinner_angle + 45) % 360
|
||||
|
||||
# Draw new spinner
|
||||
color = st7789.CYAN if current_status == "OPTIMIZING" else st7789.YELLOW
|
||||
draw_loading_spinner(display, 110, 80, spinner_angle, color)
|
||||
|
||||
last_spinner_time = now
|
||||
|
||||
btn_action = get_boot_button_action(boot_btn)
|
||||
|
||||
if btn_action == 1:
|
||||
if is_recording:
|
||||
print(">>> Stop recording")
|
||||
if ws and ws.is_connected():
|
||||
try:
|
||||
ws.send("START_RECORDING")
|
||||
ws.send("STOP_RECORDING")
|
||||
except:
|
||||
ws = None
|
||||
|
||||
is_recording = False
|
||||
ui_screen = UI_SCREEN_RESULT
|
||||
image_generation_done = False
|
||||
|
||||
if display.tft:
|
||||
render_result_screen(display, "OPTIMIZING", current_asr_text, False)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
elif ui_screen == UI_SCREEN_RECORDING:
|
||||
if not is_recording:
|
||||
print(">>> Recording...")
|
||||
is_recording = True
|
||||
confirm_waiting = False
|
||||
current_asr_text = ""
|
||||
current_prompt = ""
|
||||
current_status = ""
|
||||
image_generation_done = False
|
||||
|
||||
if display.tft:
|
||||
render_recording_screen(display, "", 0)
|
||||
|
||||
if ws is None or not ws.is_connected():
|
||||
connect_ws()
|
||||
|
||||
if ws and ws.is_connected():
|
||||
try:
|
||||
ws.send("START_RECORDING")
|
||||
except:
|
||||
ws = None
|
||||
|
||||
elif ui_screen == UI_SCREEN_CONFIRM:
|
||||
print(">>> Confirm and generate")
|
||||
if ws and ws.is_connected():
|
||||
try:
|
||||
ws.send("STOP_RECORDING")
|
||||
except:
|
||||
ws = None
|
||||
|
||||
is_recording = False
|
||||
ui_screen = UI_SCREEN_RESULT
|
||||
image_generation_done = False
|
||||
|
||||
if display.tft:
|
||||
render_result_screen(display, "OPTIMIZING", current_asr_text, False)
|
||||
|
||||
time.sleep(0.5)
|
||||
|
||||
elif ui_screen == UI_SCREEN_RESULT:
|
||||
print(">>> Back to recording")
|
||||
ui_screen = UI_SCREEN_RECORDING
|
||||
is_recording = False
|
||||
current_asr_text = ""
|
||||
current_prompt = ""
|
||||
current_status = ""
|
||||
image_generation_done = False
|
||||
confirm_waiting = False
|
||||
|
||||
if display.tft:
|
||||
render_recording_screen(display, "", 0)
|
||||
|
||||
elif btn_action == 2:
|
||||
if is_recording:
|
||||
print(">>> Stop recording (long press)")
|
||||
if ws and ws.is_connected():
|
||||
try:
|
||||
ws.send("STOP_RECORDING")
|
||||
except:
|
||||
ws = None
|
||||
|
||||
is_recording = False
|
||||
|
||||
if ui_screen == UI_SCREEN_RECORDING or is_recording == False:
|
||||
if current_asr_text:
|
||||
print(">>> Generate image with ASR text")
|
||||
ui_screen = UI_SCREEN_RESULT
|
||||
image_generation_done = False
|
||||
|
||||
if display.tft:
|
||||
render_result_screen(display, "OPTIMIZING", current_asr_text, False)
|
||||
|
||||
time.sleep(0.5)
|
||||
else:
|
||||
print(">>> Re-record")
|
||||
current_asr_text = ""
|
||||
confirm_waiting = False
|
||||
ui_screen = UI_SCREEN_RECORDING
|
||||
|
||||
if display.tft:
|
||||
render_recording_screen(display, "", 0)
|
||||
|
||||
elif ui_screen == UI_SCREEN_CONFIRM:
|
||||
print(">>> Re-record")
|
||||
current_asr_text = ""
|
||||
confirm_waiting = False
|
||||
ui_screen = UI_SCREEN_RECORDING
|
||||
|
||||
if display.tft:
|
||||
render_recording_screen(display, "", 0)
|
||||
|
||||
elif ui_screen == UI_SCREEN_RESULT:
|
||||
print(">>> Generate image (manual)")
|
||||
if ws and ws.is_connected():
|
||||
try:
|
||||
ws.send("START_RECORDING")
|
||||
is_recording = True
|
||||
ui_screen = UI_SCREEN_RECORDING
|
||||
except:
|
||||
ws = None
|
||||
|
||||
elif btn_action == 3:
|
||||
print(">>> Config mode")
|
||||
|
||||
if is_recording and btn_action == 0:
|
||||
if mic.i2s:
|
||||
num_read = mic.readinto(read_buf)
|
||||
if num_read > 0:
|
||||
@@ -255,48 +617,73 @@ def main():
|
||||
events = poller.poll(0)
|
||||
if events:
|
||||
msg = ws.recv()
|
||||
image_state = process_message(msg, display, image_state, image_data_list)
|
||||
image_state, event_data = process_message(msg, display, image_state, image_data_list)
|
||||
|
||||
if event_data:
|
||||
if event_data[0] == "asr":
|
||||
current_asr_text = event_data[1]
|
||||
if display.tft:
|
||||
render_recording_screen(display, current_asr_text, last_audio_level)
|
||||
|
||||
elif event_data[0] == "status":
|
||||
current_status = event_data[1]
|
||||
status_text = event_data[2] if len(event_data) > 2 else ""
|
||||
if display.tft:
|
||||
render_result_screen(display, current_status, current_prompt, image_generation_done)
|
||||
|
||||
elif event_data[0] == "prompt":
|
||||
current_prompt = event_data[1]
|
||||
|
||||
elif event_data[0] == "image_done":
|
||||
image_generation_done = True
|
||||
if display.tft:
|
||||
render_result_screen(display, "COMPLETE", current_prompt, True)
|
||||
|
||||
elif event_data[0] == "error":
|
||||
if display.tft:
|
||||
render_result_screen(display, "ERROR", current_prompt, False)
|
||||
|
||||
except:
|
||||
ws = None
|
||||
|
||||
if ui_screen == UI_SCREEN_RESULT and ws and ws.is_connected():
|
||||
try:
|
||||
poller = uselect.poll()
|
||||
poller.register(ws.sock, uselect.POLLIN)
|
||||
events = poller.poll(100)
|
||||
if events:
|
||||
msg = ws.recv()
|
||||
if msg:
|
||||
image_state, event_data = process_message(msg, display, image_state, image_data_list)
|
||||
|
||||
if event_data:
|
||||
if event_data[0] == "asr":
|
||||
current_asr_text = event_data[1]
|
||||
|
||||
elif event_data[0] == "status":
|
||||
current_status = event_data[1]
|
||||
status_text = event_data[2] if len(event_data) > 2 else ""
|
||||
if display.tft:
|
||||
render_result_screen(display, current_status, current_prompt, image_generation_done)
|
||||
|
||||
elif event_data[0] == "prompt":
|
||||
current_prompt = event_data[1]
|
||||
if display.tft:
|
||||
render_result_screen(display, current_status, current_prompt, image_generation_done)
|
||||
|
||||
elif event_data[0] == "image_done":
|
||||
image_generation_done = True
|
||||
if display.tft:
|
||||
render_result_screen(display, "COMPLETE", current_prompt, True)
|
||||
|
||||
elif event_data[0] == "error":
|
||||
if display.tft:
|
||||
render_result_screen(display, "ERROR", current_prompt, False)
|
||||
except:
|
||||
pass
|
||||
|
||||
continue
|
||||
|
||||
elif is_recording:
|
||||
print(">>> Stop")
|
||||
is_recording = False
|
||||
|
||||
if display.tft:
|
||||
display.init_ui()
|
||||
|
||||
if ws:
|
||||
try:
|
||||
ws.send("STOP_RECORDING")
|
||||
|
||||
# 等待更长时间以接收图片生成结果
|
||||
t_wait = time.ticks_add(time.ticks_ms(), 30000) # 等待30秒
|
||||
prev_image_state = image_state
|
||||
while time.ticks_diff(t_wait, time.ticks_ms()) > 0:
|
||||
poller = uselect.poll()
|
||||
poller.register(ws.sock, uselect.POLLIN)
|
||||
events = poller.poll(500)
|
||||
if events:
|
||||
msg = ws.recv()
|
||||
prev_image_state = image_state
|
||||
image_state = process_message(msg, display, image_state, image_data_list)
|
||||
# 如果之前在接收图片,现在停止了,说明图片接收完成
|
||||
if prev_image_state == IMAGE_STATE_RECEIVING and image_state == IMAGE_STATE_IDLE:
|
||||
break
|
||||
else:
|
||||
# 检查是否还在接收图片
|
||||
if image_state == IMAGE_STATE_IDLE:
|
||||
break
|
||||
except Exception as e:
|
||||
print(f"Stop recording error: {e}")
|
||||
ws = None
|
||||
|
||||
gc.collect()
|
||||
|
||||
time.sleep(0.01)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -31,6 +31,7 @@ HIGH_FREQ_UNICODE = [ord(c) for c in HIGH_FREQ_CHARS]
|
||||
# 字体缓存
|
||||
font_cache = {}
|
||||
font_md5 = {}
|
||||
font_data_buffer = None
|
||||
|
||||
def calculate_md5(filepath):
|
||||
"""计算文件的MD5哈希值"""
|
||||
@@ -44,7 +45,7 @@ def calculate_md5(filepath):
|
||||
|
||||
def init_font_cache():
|
||||
"""初始化字体缓存和MD5"""
|
||||
global font_cache, font_md5
|
||||
global font_cache, font_md5, font_data_buffer
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
|
||||
@@ -55,24 +56,18 @@ def init_font_cache():
|
||||
font_md5 = calculate_md5(font_path)
|
||||
print(f"Font MD5: {font_md5}")
|
||||
|
||||
# 预加载高频字到缓存
|
||||
# 加载整个字体文件到内存
|
||||
try:
|
||||
with open(font_path, "rb") as f:
|
||||
font_data_buffer = f.read()
|
||||
print(f"Loaded font file into memory: {len(font_data_buffer)} bytes")
|
||||
except Exception as e:
|
||||
print(f"Error loading font file: {e}")
|
||||
font_data_buffer = None
|
||||
|
||||
# 预加载高频字到缓存 (仍然保留以便快速访问)
|
||||
for unicode_val in HIGH_FREQ_UNICODE:
|
||||
try:
|
||||
char = chr(unicode_val)
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
except:
|
||||
pass
|
||||
get_font_data(unicode_val)
|
||||
print(f"Preloaded {len(font_cache)} high-frequency characters")
|
||||
|
||||
# 启动时初始化字体缓存
|
||||
@@ -104,6 +99,114 @@ THUMB_SIZE = 245
|
||||
font_request_queue = {}
|
||||
FONT_RETRY_MAX = 3
|
||||
|
||||
# 图片生成任务管理
|
||||
class ImageGenerationTask:
    """Bookkeeping record for a single image-generation job.

    The object only stores state; it performs no work itself. Code that
    drives the job updates these attributes as the job advances.
    """

    def __init__(self, task_id: str, asr_text: str, websocket: WebSocket):
        """Create a task in the ``pending`` state.

        Args:
            task_id: Identifier used as the key for this task.
            asr_text: Speech-recognition text the image is generated from.
            websocket: Connection on which results are reported.
        """
        self.task_id = task_id
        self.asr_text = asr_text
        self.websocket = websocket
        # One of: pending, optimizing, generating, completed, failed
        self.status = "pending"
        self.progress = 0
        self.message = ""
        # Filled in from outside once prompt optimisation has finished.
        # Declared here so the attribute exists on every task, including
        # tasks that fail before a prompt is produced.
        self.optimized_prompt = None
        self.result = None
        self.error = None
|
||||
|
||||
# 存储活跃的图片生成任务
|
||||
active_tasks = {}
|
||||
task_counter = 0
|
||||
|
||||
|
||||
async def start_async_image_generation(websocket: WebSocket, asr_text: str):
    """Run one image-generation job and stream its status to the client.

    The two slow steps (``optimize_prompt`` and ``generate_image``) are run
    through ``asyncio.to_thread`` so the event loop stays free while they
    execute. The coroutine itself only returns once the job has finished
    or failed.

    Args:
        websocket: Connection that receives the TASK_ID / STATUS / PROMPT /
            TASK_PROGRESS / IMAGE_* messages.
        asr_text: Speech-recognition text used as the raw prompt.

    Returns:
        The ``ImageGenerationTask`` describing the finished (or failed) job.
        It has already been removed from ``active_tasks`` on return.
    """
    global task_counter, active_tasks

    task_id = f"task_{task_counter}_{int(time.time() * 1000)}"
    task_counter += 1

    task = ImageGenerationTask(task_id, asr_text, websocket)
    active_tasks[task_id] = task

    print(f"Starting async image generation task: {task_id}")

    await websocket.send_text(f"TASK_ID:{task_id}")

    # progress_callback is invoked from the worker threads started by
    # asyncio.to_thread. Those threads have no event loop of their own, so
    # the loop must be captured here, on the loop's thread, and handed to
    # run_coroutine_threadsafe explicitly.
    loop = asyncio.get_running_loop()

    def progress_callback(progress: int, message: str):
        """Record progress on the task and forward it to the client."""
        task.progress = progress
        task.message = message
        coro = websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}")
        try:
            asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception as e:
            # The coroutine was never scheduled; close it so it does not
            # trigger a "never awaited" warning.
            coro.close()
            print(f"Error sending progress: {e}")

    try:
        task.status = "optimizing"

        await websocket.send_text("STATUS:OPTIMIZING:正在优化提示词...")
        await asyncio.sleep(0.2)

        optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback)

        await websocket.send_text(f"PROMPT:{optimized_prompt}")
        task.optimized_prompt = optimized_prompt

        task.status = "generating"
        await websocket.send_text("STATUS:RENDERING:正在生成图片,请稍候...")
        await asyncio.sleep(0.2)

        image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)

        task.result = image_path

        if image_path and os.path.exists(image_path):
            task.status = "completed"
            await websocket.send_text("STATUS:COMPLETE:图片生成完成")
            await asyncio.sleep(0.2)

            await send_image_to_client(websocket, image_path)
        else:
            task.status = "failed"
            task.error = "图片生成失败"
            await websocket.send_text("IMAGE_ERROR:图片生成失败")
            await websocket.send_text("STATUS:ERROR:图片生成失败")

    except Exception as e:
        task.status = "failed"
        task.error = str(e)
        print(f"Image generation task error: {e}")
        await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
        await websocket.send_text("STATUS:ERROR:图片生成出错")
    finally:
        # The task is only tracked while it is running.
        active_tasks.pop(task_id, None)

    return task
|
||||
|
||||
|
||||
async def send_image_to_client(websocket: WebSocket, image_path: str):
    """Stream the file at ``image_path`` to the client over ``websocket``.

    Wire sequence: one ``IMAGE_START:<byte count>:<THUMB_SIZE>`` text
    message, the file contents as consecutive binary messages of at most
    4096 bytes each, then one ``IMAGE_END`` text message.
    """
    with open(image_path, 'rb') as src:
        payload = src.read()

    total = len(payload)
    print(f"Sending image to ESP32, size: {total} bytes")

    # Announce the transfer
    await websocket.send_text(f"IMAGE_START:{total}:{THUMB_SIZE}")

    # Raw binary frames, one fixed-size slice at a time
    step = 4096
    offset = 0
    while offset < total:
        await websocket.send_bytes(payload[offset:offset + step])
        offset += step

    # Close the transfer
    await websocket.send_text("IMAGE_END")
    print("Image sent to ESP32 (Binary)")
|
||||
|
||||
|
||||
def get_font_data(unicode_val):
|
||||
"""从字体文件获取单个字符数据(带缓存)"""
|
||||
@@ -121,20 +224,27 @@ def get_font_data(unicode_val):
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
if font_data_buffer:
|
||||
if offset + 32 <= len(font_data_buffer):
|
||||
font_data = font_data_buffer[offset:offset+32]
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
else:
|
||||
# Fallback to file reading if buffer failed
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
@@ -333,10 +443,13 @@ def process_chunk_32_to_16(chunk_bytes, gain=1.0):
|
||||
return processed_chunk
|
||||
|
||||
|
||||
def optimize_prompt(asr_text):
|
||||
def optimize_prompt(asr_text, progress_callback=None):
|
||||
"""使用大模型优化提示词"""
|
||||
print(f"Optimizing prompt for: {asr_text}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(0, "正在准备优化提示词...")
|
||||
|
||||
system_prompt = """你是一个AI图像提示词优化专家。将用户简短的语音识别结果转化为详细的、适合AI图像生成的英文提示词。
|
||||
要求:
|
||||
1. 保留核心内容和主要元素
|
||||
@@ -346,6 +459,9 @@ def optimize_prompt(asr_text):
|
||||
5. 不要添加多余解释,直接输出优化后的提示词"""
|
||||
|
||||
try:
|
||||
if progress_callback:
|
||||
progress_callback(10, "正在调用AI优化提示词...")
|
||||
|
||||
response = Generation.call(
|
||||
model='qwen-turbo',
|
||||
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
|
||||
@@ -356,31 +472,76 @@ def optimize_prompt(asr_text):
|
||||
if response.status_code == 200:
|
||||
optimized = response.output.choices[0].message.content.strip()
|
||||
print(f"Optimized prompt: {optimized}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||
|
||||
return optimized
|
||||
else:
|
||||
print(f"Prompt optimization failed: {response.code} - {response.message}")
|
||||
if progress_callback:
|
||||
progress_callback(0, f"提示词优化失败: {response.message}")
|
||||
return asr_text
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error optimizing prompt: {e}")
|
||||
if progress_callback:
|
||||
progress_callback(0, f"提示词优化出错: {str(e)}")
|
||||
return asr_text
|
||||
|
||||
|
||||
def generate_image(prompt, websocket=None):
|
||||
"""调用万相文生图API生成图片"""
|
||||
def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2):
|
||||
"""调用万相文生图API生成图片
|
||||
|
||||
Args:
|
||||
prompt: 图像生成提示词
|
||||
progress_callback: 进度回调函数 (progress, message)
|
||||
retry_count: 当前重试次数
|
||||
max_retries: 最大重试次数
|
||||
"""
|
||||
print(f"Generating image for prompt: {prompt}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(35, "正在请求AI生成图片...")
|
||||
|
||||
try:
|
||||
response = ImageSynthesis.call(
|
||||
model='wan2.6-t2i',
|
||||
prompt=prompt,
|
||||
size='512x512',
|
||||
n=1
|
||||
model='wanx2.0-t2i-turbo',
|
||||
prompt=prompt
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
image_url = response.output['results'][0]['url']
|
||||
print(f"Image generated, downloading from: {image_url}")
|
||||
task_status = response.output.get('task_status')
|
||||
|
||||
if task_status == 'PENDING' or task_status == 'RUNNING':
|
||||
print("Waiting for image generation to complete...")
|
||||
if progress_callback:
|
||||
progress_callback(45, "AI正在生成图片中...")
|
||||
|
||||
import time
|
||||
task_id = response.output.get('task_id')
|
||||
max_wait = 120
|
||||
waited = 0
|
||||
while waited < max_wait:
|
||||
time.sleep(2)
|
||||
waited += 2
|
||||
task_result = ImageSynthesis.fetch(task_id)
|
||||
if task_result.output.task_status == 'SUCCEEDED':
|
||||
response.output = task_result.output
|
||||
break
|
||||
elif task_result.output.task_status == 'FAILED':
|
||||
error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error'
|
||||
print(f"Image generation failed: {error_msg}")
|
||||
if progress_callback:
|
||||
progress_callback(35, f"图片生成失败: {error_msg}")
|
||||
return None
|
||||
|
||||
if response.output.get('task_status') == 'SUCCEEDED':
|
||||
image_url = response.output['results'][0]['url']
|
||||
print(f"Image generated, downloading from: {image_url}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(70, "正在下载生成的图片...")
|
||||
|
||||
import urllib.request
|
||||
urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE)
|
||||
@@ -392,6 +553,9 @@ def generate_image(prompt, websocket=None):
|
||||
shutil.copy(GENERATED_IMAGE_FILE, output_path)
|
||||
print(f"Image also saved to {output_path}")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(80, "正在处理图片...")
|
||||
|
||||
# 缩放图片并转换为RGB565格式
|
||||
try:
|
||||
from PIL import Image
|
||||
@@ -422,21 +586,50 @@ def generate_image(prompt, websocket=None):
|
||||
f.write(rgb565_data)
|
||||
|
||||
print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes")
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(100, "图片生成完成!")
|
||||
|
||||
return GENERATED_THUMB_FILE
|
||||
|
||||
except ImportError:
|
||||
print("PIL not available, sending original image")
|
||||
if progress_callback:
|
||||
progress_callback(100, "图片生成完成!(原始格式)")
|
||||
return GENERATED_IMAGE_FILE
|
||||
except Exception as e:
|
||||
print(f"Error processing image: {e}")
|
||||
if progress_callback:
|
||||
progress_callback(80, f"图片处理出错: {str(e)}")
|
||||
return GENERATED_IMAGE_FILE
|
||||
else:
|
||||
print(f"Image generation failed: {response.code} - {response.message}")
|
||||
return None
|
||||
error_msg = f"{response.code} - {response.message}"
|
||||
print(f"Image generation failed: {error_msg}")
|
||||
|
||||
# 重试机制
|
||||
if retry_count < max_retries:
|
||||
print(f"Retrying... ({retry_count + 1}/{max_retries})")
|
||||
if progress_callback:
|
||||
progress_callback(35, f"图片生成失败,正在重试 ({retry_count + 1}/{max_retries})...")
|
||||
return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
|
||||
else:
|
||||
if progress_callback:
|
||||
progress_callback(35, f"图片生成失败: {error_msg}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error generating image: {e}")
|
||||
return None
|
||||
|
||||
# 重试机制
|
||||
if retry_count < max_retries:
|
||||
print(f"Retrying after error... ({retry_count + 1}/{max_retries})")
|
||||
if progress_callback:
|
||||
progress_callback(35, f"生成出错,正在重试 ({retry_count + 1}/{max_retries})...")
|
||||
return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
|
||||
else:
|
||||
if progress_callback:
|
||||
progress_callback(35, f"图片生成出错: {str(e)}")
|
||||
return None
|
||||
|
||||
@app.websocket("/ws/audio")
|
||||
async def websocket_endpoint(websocket: WebSocket):
|
||||
@@ -554,132 +747,36 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
# 先发送 ASR 文字到 ESP32 显示
|
||||
await websocket.send_text(f"ASR:{asr_text}")
|
||||
await websocket.send_text("GENERATING_IMAGE:正在优化提示词...")
|
||||
|
||||
# 等待一会让 ESP32 显示文字
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# 优化提示词
|
||||
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text)
|
||||
|
||||
await websocket.send_text(f"PROMPT:{optimized_prompt}")
|
||||
await websocket.send_text("GENERATING_IMAGE:正在生成图片,请稍候...")
|
||||
|
||||
# 调用文生图API
|
||||
image_path = await asyncio.to_thread(generate_image, optimized_prompt)
|
||||
|
||||
if image_path and os.path.exists(image_path):
|
||||
# 读取图片并发送回ESP32
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
|
||||
print(f"Sending image to ESP32, size: {len(image_data)} bytes")
|
||||
|
||||
# 使用hex编码发送(每个字节2个字符)
|
||||
image_hex = image_data.hex()
|
||||
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
|
||||
|
||||
# 分片发送图片数据
|
||||
chunk_size = 1024
|
||||
for i in range(0, len(image_hex), chunk_size):
|
||||
chunk = image_hex[i:i+chunk_size]
|
||||
await websocket.send_text(f"IMAGE_DATA:{chunk}")
|
||||
|
||||
await websocket.send_text("IMAGE_END")
|
||||
print("Image sent to ESP32")
|
||||
else:
|
||||
await websocket.send_text("IMAGE_ERROR:图片生成失败")
|
||||
await start_async_image_generation(websocket, asr_text)
|
||||
else:
|
||||
print("No ASR text, skipping image generation")
|
||||
|
||||
print("Server processing finished.")
|
||||
|
||||
elif text.startswith("GET_FONTS_BATCH:"):
|
||||
# Format: GET_FONTS_BATCH:code1,code2,code3 (decimal unicode)
|
||||
elif text.startswith("GET_TASK_STATUS:"):
|
||||
task_id = text.split(":", 1)[1].strip()
|
||||
if task_id in active_tasks:
|
||||
task = active_tasks[task_id]
|
||||
await websocket.send_text(f"TASK_STATUS:{task_id}:{task.status}:{task.progress}:{task.message}")
|
||||
else:
|
||||
await websocket.send_text(f"TASK_STATUS:{task_id}:unknown:0:任务不存在或已完成")
|
||||
|
||||
elif text.startswith("GET_FONTS_BATCH:") or text.startswith("GET_FONT") or text == "GET_FONT_MD5" or text == "GET_HIGH_FREQ":
|
||||
# 使用新的统一字体处理函数
|
||||
try:
|
||||
codes_str = text.split(":")[1]
|
||||
code_list = codes_str.split(",")
|
||||
print(f"Batch Font Request for {len(code_list)} chars: {code_list}")
|
||||
|
||||
for code_str in code_list:
|
||||
if not code_str: continue
|
||||
|
||||
try:
|
||||
unicode_val = int(code_str)
|
||||
char = chr(unicode_val)
|
||||
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
else:
|
||||
print(f"Character {char} is not a valid 2-byte GB2312 char")
|
||||
# Send empty/dummy? Or just skip.
|
||||
# Better to send something so client doesn't wait forever if it counts responses.
|
||||
# But client probably uses a set of missing chars.
|
||||
continue
|
||||
|
||||
# Calc offset
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
# Read font file
|
||||
# Optimization: Open file once outside loop?
|
||||
# For simplicity, keep it here, OS caching helps.
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
import binascii
|
||||
hex_data = binascii.hexlify(font_data).decode('utf-8')
|
||||
response = f"FONT_DATA:{code_str}:{hex_data}"
|
||||
await websocket.send_text(response)
|
||||
# Small yield to let network flush?
|
||||
# await asyncio.sleep(0.001)
|
||||
except Exception as e:
|
||||
print(f"Error processing batch item {code_str}: {e}")
|
||||
|
||||
# Send a completion marker
|
||||
await websocket.send_text("FONT_BATCH_END")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling BATCH FONT request: {e}")
|
||||
await websocket.send_text("FONT_BATCH_END") # Ensure we unblock client
|
||||
|
||||
elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"):
|
||||
# 格式: GET_FONT_UNICODE:12345 (decimal) or GET_FONT:0xA1A1 (hex)
|
||||
try:
|
||||
is_unicode = text.startswith("GET_FONT_UNICODE:")
|
||||
code_str = text.split(":")[1]
|
||||
|
||||
target_code_str = code_str # Used for response
|
||||
|
||||
if is_unicode:
|
||||
unicode_val = int(code_str)
|
||||
char = chr(unicode_val)
|
||||
try:
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
else:
|
||||
print(f"Character {char} is not a valid 2-byte GB2312 char")
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"Failed to encode {char} to gb2312: {e}")
|
||||
continue
|
||||
if text.startswith("GET_FONTS_BATCH:"):
|
||||
await handle_font_request(websocket, text, text.split(":", 1)[1])
|
||||
elif text.startswith("GET_FONT_FRAGMENT:"):
|
||||
await handle_font_request(websocket, text, text.split(":", 1)[1])
|
||||
elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"):
|
||||
parts = text.split(":", 1)
|
||||
await handle_font_request(websocket, parts[0], parts[1] if len(parts) > 1 else "")
|
||||
else:
|
||||
code = int(code_str, 16)
|
||||
await handle_font_request(websocket, text, "")
|
||||
except Exception as e:
|
||||
print(f"Font request error: {e}")
|
||||
await websocket.send_text("FONT_BATCH_END:0:0")
|
||||
|
||||
# 计算偏移量
|
||||
# GB2312 编码范围:0xA1A1 - 0xFEFE
|
||||
|
||||
154
websocket_server/test_generated_thumb.bin
Normal file
154
websocket_server/test_generated_thumb.bin
Normal file
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user