def show_image(self, x, y, width, height, rgb565_data):
    """Draw a raw RGB565 image on the TFT at the given position.

    Args:
        x: Left edge of the image in pixels (must be >= 0).
        y: Top edge of the image in pixels (must be >= 0).
        width: Image width in pixels.
        height: Image height in pixels.
        rgb565_data: Bytes-like buffer holding two bytes per pixel.

    Returns:
        True when the buffer was handed to the driver without error,
        False when there is no display, the arguments are inconsistent,
        or the driver raised.
    """
    if not self.tft:
        return False

    # A negative origin or empty rectangle cannot be a valid blit target;
    # refuse it here instead of passing it to the driver.
    if x < 0 or y < 0 or width <= 0 or height <= 0:
        print(f"Show image error: bad geometry x={x} y={y} w={width} h={height}")
        return False

    # RGB565 is two bytes per pixel, so a buffer of any other length does
    # not describe a width x height image.
    expected_len = width * height * 2
    if len(rgb565_data) != expected_len:
        print(f"Show image error: expected {expected_len} bytes, got {len(rgb565_data)}")
        return False

    try:
        self.tft.blit_buffer(rgb565_data, x, y, width, height)
    except Exception as e:
        # Best effort: a failed draw must not take down the main loop.
        print(f"Show image error: {e}")
        return False
    return True
def process_message(msg, display, image_state, image_data_list):
    """Handle one WebSocket message and return the new image state.

    Text messages are dispatched on their prefix:

    * ``ASR:<text>``               - show the recognised text.
    * ``GENERATING_IMAGE:<text>``  - show a "generating" notice.
    * ``PROMPT:<text>``            - show the first 20 chars of the prompt.
    * ``IMAGE_START:<bytes>[:<side>]`` - begin buffering a square RGB565
      image of ``side`` pixels (64 when omitted).
    * ``IMAGE_DATA:<hex>``         - one hex-encoded chunk (only while
      receiving).
    * ``IMAGE_END``                - decode, validate and draw the image.
    * ``IMAGE_ERROR:<text>``       - show a failure notice and reset.

    Args:
        msg: The received message; anything that is not ``str`` is ignored.
        display: Display wrapper, or None when no screen is attached.
        image_state: Current ``IMAGE_STATE_*`` value.
        image_data_list: Caller-owned scratch list. While receiving, item 0
            is a ``(declared_bytes, side)`` tuple and the remaining items
            are the hex chunks in arrival order. It is mutated in place.

    Returns:
        The ``IMAGE_STATE_*`` value to use for the next message.
    """
    screen_px = 240  # panel is addressed as 240x240 throughout this function

    if not isinstance(msg, str):
        return image_state

    if msg.startswith("ASR:"):
        print_asr(msg[4:], display)
        return image_state

    if msg.startswith("GENERATING_IMAGE:"):
        print(msg)
        if display and display.tft:
            display.fill_rect(0, 40, 240, 100, st7789.BLACK)
            display.text("正在生成图片...", 0, 40, st7789.YELLOW)
        return image_state

    if msg.startswith("PROMPT:"):
        prompt = msg[7:]
        print(f"Optimized prompt: {prompt}")
        if display and display.tft:
            display.fill_rect(0, 60, 240, 40, st7789.BLACK)
            display.text("提示词: " + prompt[:20], 0, 60, st7789.CYAN)
        return image_state

    if msg.startswith("IMAGE_START:"):
        try:
            parts = msg.split(":")
            size = int(parts[1])
            img_size = int(parts[2]) if len(parts) > 2 else 64
        except (ValueError, IndexError) as e:
            print(f"IMAGE_START parse error: {e}")
            return image_state

        # Refuse images that cannot fit before buffering any data: an
        # oversized side would give a negative centring offset below and
        # waste RAM on chunks that can never be drawn.
        if not 0 < img_size <= screen_px:
            print(f"IMAGE_START rejected: img_size {img_size} does not fit {screen_px}px screen")
            image_data_list.clear()
            return IMAGE_STATE_IDLE

        print(f"Image start, size: {size}, img_size: {img_size}")
        image_data_list.clear()
        # Keep the declared byte count so IMAGE_END can detect truncation.
        image_data_list.append((size, img_size))
        return IMAGE_STATE_RECEIVING

    if msg.startswith("IMAGE_DATA:"):
        if image_state == IMAGE_STATE_RECEIVING:
            image_data_list.append(msg[len("IMAGE_DATA:"):])
        return image_state

    if msg == "IMAGE_END":
        if image_state != IMAGE_STATE_RECEIVING:
            return image_state
        try:
            print("Image received, processing...")
            if not image_data_list:
                print("Image process error: no IMAGE_START header")
                return IMAGE_STATE_IDLE

            declared_len, img_size = image_data_list[0]
            hex_data = "".join(image_data_list[1:])
            # Drop the chunk list before decoding to lower peak memory.
            image_data_list.clear()
            img_data = bytes.fromhex(hex_data)
            hex_data = None

            print(f"Image data len: {len(img_data)}")

            # RGB565 is two bytes per pixel; both the geometry and the
            # byte count announced by the sender must agree with what
            # actually arrived, otherwise the frame is corrupt.
            expected_len = img_size * img_size * 2
            if len(img_data) != expected_len or len(img_data) != declared_len:
                print(
                    f"Image size mismatch: got {len(img_data)}, "
                    f"declared {declared_len}, expected {expected_len}"
                )
            elif display and display.tft:
                # Centre the image on the screen.
                x = (screen_px - img_size) // 2
                y = (screen_px - img_size) // 2
                display.show_image(x, y, img_size, img_size, img_data)

                display.fill_rect(0, 0, 240, 30, st7789.WHITE)
                display.text("图片已生成!", 0, 5, st7789.BLACK)
                print("Image displayed")
        except Exception as e:
            # Best effort: a bad frame must not kill the receive loop.
            print(f"Image process error: {e}")
        finally:
            image_data_list.clear()
            gc.collect()
        return IMAGE_STATE_IDLE

    if msg.startswith("IMAGE_ERROR:"):
        print(msg)
        # Release any partially buffered image.
        image_data_list.clear()
        if display and display.tft:
            display.fill_rect(0, 40, 240, 100, st7789.BLACK)
            display.text("图片生成失败", 0, 40, st7789.RED)
        return IMAGE_STATE_IDLE

    return image_state
b/websocket_server/__pycache__/server.cpython-312.pyc differ diff --git a/websocket_server/received_audio.mp3 b/websocket_server/received_audio.mp3 index 10b03fd..bd5948a 100644 Binary files a/websocket_server/received_audio.mp3 and b/websocket_server/received_audio.mp3 differ diff --git a/websocket_server/received_audio.raw b/websocket_server/received_audio.raw index cb0bbd1..ac6dcae 100644 Binary files a/websocket_server/received_audio.raw and b/websocket_server/received_audio.raw differ diff --git a/websocket_server/requirements.txt b/websocket_server/requirements.txt index c69469c..f2b1c2f 100644 --- a/websocket_server/requirements.txt +++ b/websocket_server/requirements.txt @@ -4,3 +4,4 @@ websockets pydub dashscope python-dotenv +Pillow diff --git a/websocket_server/server.py b/websocket_server/server.py index 915edee..fb5d663 100644 --- a/websocket_server/server.py +++ b/websocket_server/server.py @@ -5,11 +5,14 @@ import os import subprocess import struct import base64 +import time +import hashlib +import json from dotenv import load_dotenv import dashscope from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult from dashscope import ImageSynthesis -import json +from dashscope import Generation # 加载环境变量 load_dotenv() @@ -17,14 +20,272 @@ dashscope.api_key = os.getenv("DASHSCOPE_API_KEY") app = FastAPI() +# 字体文件配置 +FONT_FILE = "GB2312-16.bin" +FONT_CHUNK_SIZE = 512 +HIGH_FREQ_CHARS = "的一是在不了有和人这中大为上个国我以要他时来用们生到作地于出就分对成会可主发年动同工也能下过子说产种面而方后多定行学法所民得经十三之进着等部度家电力里如水化高自二理起小物现实加量都两体制机当使点从业本去把性好应开它合还因由其些然前外天政四日那社义事平形相全表间样与关各重新线内数正心反你明看原又么利比或但质气第向道命此变条只没结解问意建月公无系军很情者最立代想已通并提直题党程展五果料象员革位入常文总次品式活设及管特件长求老头基资边流路级少图山统接知较将组见计别她手角期根论运农指几九区强放决西被干做必战先回则任取据处队南给色光门即保治北造百规热领七海口东导器压志世金增争济阶油思术极交受联什认六共权收证改清己美再采转更单风切打白教速花带安场身车例真务具万每目至达走积示议声报斗完类八离华名确才科张信马节话米整空元况今集温传土许步群广石记需段研界拉林律叫且究观越织装影算低持音众书布复容儿须际商非验连断深难近矿千周委素技备半办青省列习响约支般史感劳便团往酸历市克何除消构府称太准精值号率族维划选标写存候毛亲快效斯院查江型眼王按格养易置派层片始却专状育厂京识适属圆包火住调满县局照参红细引听该铁价严龙飞" + +# 高频字对应的Unicode码点列表 +HIGH_FREQ_UNICODE = [ord(c) 
def calculate_md5(filepath):
    """Return the hex MD5 digest of a file's contents.

    The file is read in 4 KiB chunks so large files are not loaded into
    memory at once.

    Args:
        filepath: Path of the file to hash.

    Returns:
        The 32-character hexadecimal digest, or None when the file does
        not exist.
    """
    digest = hashlib.md5()
    try:
        # Open directly instead of checking existence first: the file could
        # disappear between the check and the open.
        with open(filepath, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                digest.update(chunk)
    except FileNotFoundError:
        return None
    return digest.hexdigest()
def get_font_data(unicode_val):
    """Return the 32-byte glyph record for a code point, using the cache.

    The code point is encoded as GB2312; the two resulting bytes give the
    row ("area") and cell ("index") that locate a 32-byte record in the
    font file. Successful reads are stored in ``font_cache``.

    The font file is looked up next to this script, then in the parent
    directory, then at ``FONT_FILE`` as given.

    Args:
        unicode_val: Unicode code point of the character.

    Returns:
        The 32 glyph bytes, or None when the character has no two-byte
        GB2312 encoding, the font file cannot be opened, or the record
        lies beyond the end of the file.
    """
    cached = font_cache.get(unicode_val)
    if cached is not None:
        return cached

    try:
        gb_bytes = chr(unicode_val).encode('gb2312')
    except (UnicodeEncodeError, ValueError, TypeError, OverflowError):
        # Not a valid code point, or not representable in GB2312.
        return None
    if len(gb_bytes) != 2:
        return None

    area = gb_bytes[0] - 0xA0
    index = gb_bytes[1] - 0xA0
    if not (1 <= area <= 94 and 1 <= index <= 94):
        return None
    # 94 cells per row, 32 bytes per glyph record.
    offset = ((area - 1) * 94 + (index - 1)) * 32

    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = (
        os.path.join(script_dir, FONT_FILE),
        os.path.join(script_dir, "..", FONT_FILE),
        FONT_FILE,
    )
    for font_path in candidates:
        try:
            with open(font_path, "rb") as f:
                f.seek(offset)
                font_data = f.read(32)
        except OSError:
            # Missing or unreadable at this location; try the next one.
            continue
        if len(font_data) != 32:
            # Record lies past the end of the font file.
            return None
        font_cache[unicode_val] = font_data
        return font_data
    return None
def optimize_prompt(asr_text):
    """Expand recognised speech into an English text-to-image prompt.

    The text is sent to the ``qwen-turbo`` model together with a system
    instruction asking for a concise but detailed English prompt. Any
    failure falls back to the original text so image generation can still
    proceed.

    Args:
        asr_text: Text produced by speech recognition.

    Returns:
        The optimised prompt, or ``asr_text`` unchanged when it is empty or
        blank, the request fails, or the model returns nothing.
    """
    # Nothing to optimise: skip the network round trip.
    if not asr_text or not asr_text.strip():
        return asr_text

    print(f"Optimizing prompt for: {asr_text}")

    system_prompt = """你是一个AI图像提示词优化专家。将用户简短的语音识别结果转化为详细的、适合AI图像生成的英文提示词。
要求:
1. 保留核心内容和主要元素
2. 添加适合AI绘画的描述词(风格、光线、氛围等)
3. 用英文输出
4. 简洁但描述详细
5. 不要添加多余解释,直接输出优化后的提示词"""

    try:
        # result_format='message' is what makes the SDK populate
        # output.choices, which is read below; with the plain-text format
        # the reply is delivered in a different field (see DashScope docs).
        response = Generation.call(
            model='qwen-turbo',
            messages=[
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': f'用户语音识别结果:{asr_text}'},
            ],
            result_format='message',
            max_tokens=200,
            temperature=0.8,
        )

        if response.status_code != 200:
            print(f"Prompt optimization failed: {response.code} - {response.message}")
            return asr_text

        optimized = (response.output.choices[0].message.content or "").strip()
        if not optimized:
            print("Prompt optimization returned empty text")
            return asr_text

        print(f"Optimized prompt: {optimized}")
        return optimized

    except Exception as e:
        # Best effort: never block image generation on the optimiser.
        print(f"Error optimizing prompt: {e}")
        return asr_text
await asyncio.to_thread(optimize_prompt, asr_text) + + await websocket.send_text(f"PROMPT:{optimized_prompt}") + await websocket.send_text("GENERATING_IMAGE:正在生成图片,请稍候...") + # 调用文生图API - image_path = await asyncio.to_thread(generate_image, asr_text) + image_path = await asyncio.to_thread(generate_image, optimized_prompt) if image_path and os.path.exists(image_path): # 读取图片并发送回ESP32 @@ -270,14 +575,14 @@ async def websocket_endpoint(websocket: WebSocket): print(f"Sending image to ESP32, size: {len(image_data)} bytes") - # 将图片转换为base64发送 - image_b64 = base64.b64encode(image_data).decode('utf-8') - await websocket.send_text(f"IMAGE_START:{len(image_data)}") + # 使用hex编码发送(每个字节2个字符) + image_hex = image_data.hex() + await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}") # 分片发送图片数据 - chunk_size = 4096 - for i in range(0, len(image_b64), chunk_size): - chunk = image_b64[i:i+chunk_size] + chunk_size = 1024 + for i in range(0, len(image_hex), chunk_size): + chunk = image_hex[i:i+chunk_size] await websocket.send_text(f"IMAGE_DATA:{chunk}") await websocket.send_text("IMAGE_END")