This commit is contained in:
jeremygan2021
2026-03-03 21:27:07 +08:00
parent 2470013ef3
commit c87d5deedf
7 changed files with 460 additions and 31 deletions

View File

@@ -83,3 +83,13 @@ class Display:
self.tft.fill_rect(100, 240 - bar_height, 40, bar_height - last_bar_height, color) self.tft.fill_rect(100, 240 - bar_height, 40, bar_height - last_bar_height, color)
return bar_height return bar_height
def show_image(self, x, y, width, height, rgb565_data):
"""在指定位置显示RGB565格式的图片数据"""
if not self.tft: return
try:
# 将字节数据转换为适合blit_buffer的格式
self.tft.blit_buffer(rgb565_data, x, y, width, height)
except Exception as e:
print(f"Show image error: {e}")

141
main.py
View File

@@ -9,6 +9,7 @@ from audio import AudioPlayer, Microphone
from display import Display from display import Display
from websocket_client import WebSocketClient from websocket_client import WebSocketClient
import uselect import uselect
import ujson
WIFI_SSID = "Tangledup-AI" WIFI_SSID = "Tangledup-AI"
WIFI_PASS = "djt12345678" WIFI_PASS = "djt12345678"
@@ -16,15 +17,22 @@ SERVER_IP = "6.6.6.88"
SERVER_PORT = 8000 SERVER_PORT = 8000
SERVER_URL = f"ws://{SERVER_IP}:{SERVER_PORT}/ws/audio" SERVER_URL = f"ws://{SERVER_IP}:{SERVER_PORT}/ws/audio"
# 图片接收状态
IMAGE_STATE_IDLE = 0
IMAGE_STATE_RECEIVING = 1
def connect_wifi(max_retries=3): IMG_WIDTH = 120
IMG_HEIGHT = 120
def connect_wifi(max_retries=5):
wlan = network.WLAN(network.STA_IF) wlan = network.WLAN(network.STA_IF)
try: try:
wlan.active(False) wlan.active(False)
time.sleep(1) time.sleep(2)
wlan.active(True) wlan.active(True)
time.sleep(1) time.sleep(3)
except Exception as e: except Exception as e:
print(f"WiFi init error: {e}") print(f"WiFi init error: {e}")
return False return False
@@ -40,23 +48,25 @@ def connect_wifi(max_retries=3):
start_time = time.ticks_ms() start_time = time.ticks_ms()
while not wlan.isconnected(): while not wlan.isconnected():
if time.ticks_diff(time.ticks_ms(), start_time) > 20000: if time.ticks_diff(time.ticks_ms(), start_time) > 30000:
print("WiFi timeout!") print("WiFi timeout!")
break break
time.sleep(0.5) time.sleep(0.5)
print(".", end="")
if wlan.isconnected(): if wlan.isconnected():
print('WiFi connected!') print('\nWiFi connected!')
return True return True
if attempt < max_retries - 1: if attempt < max_retries - 1:
print(f"\nRetry {attempt + 1}/{max_retries}...")
wlan.disconnect() wlan.disconnect()
time.sleep(2) time.sleep(3)
except Exception as e: except Exception as e:
print(f"WiFi error: {e}") print(f"WiFi error: {e}")
if attempt < max_retries - 1: if attempt < max_retries - 1:
time.sleep(3) time.sleep(5)
print("WiFi connection failed!") print("WiFi connection failed!")
return False return False
@@ -69,6 +79,98 @@ def print_asr(text, display=None):
display.text(text, 0, 40, st7789.WHITE) display.text(text, 0, 40, st7789.WHITE)
def process_message(msg, display, image_state, image_data_list):
"""处理WebSocket消息返回新的image_state"""
if not isinstance(msg, str):
return image_state
# 处理ASR消息
if msg.startswith("ASR:"):
print_asr(msg[4:], display)
# 处理图片生成状态消息
elif msg.startswith("GENERATING_IMAGE:"):
print(msg)
if display and display.tft:
display.fill_rect(0, 40, 240, 100, st7789.BLACK)
display.text("正在生成图片...", 0, 40, st7789.YELLOW)
# 处理提示词优化消息
elif msg.startswith("PROMPT:"):
prompt = msg[7:]
print(f"Optimized prompt: {prompt}")
if display and display.tft:
display.fill_rect(0, 60, 240, 40, st7789.BLACK)
display.text("提示词: " + prompt[:20], 0, 60, st7789.CYAN)
# 处理图片开始消息
elif msg.startswith("IMAGE_START:"):
try:
parts = msg.split(":")
size = int(parts[1])
img_size = int(parts[2]) if len(parts) > 2 else 64
print(f"Image start, size: {size}, img_size: {img_size}")
image_data_list.clear()
image_data_list.append(img_size) # 保存图片尺寸
return IMAGE_STATE_RECEIVING
except Exception as e:
print(f"IMAGE_START parse error: {e}")
return image_state
# 处理图片数据消息
elif msg.startswith("IMAGE_DATA:") and image_state == IMAGE_STATE_RECEIVING:
try:
data = msg.split(":", 1)[1]
image_data_list.append(data)
except:
pass
# 处理图片结束消息
elif msg == "IMAGE_END" and image_state == IMAGE_STATE_RECEIVING:
try:
print("Image received, processing...")
# 获取图片尺寸
img_size = image_data_list[0] if image_data_list else 64
hex_data = "".join(image_data_list[1:])
image_data_list.clear()
# 将hex字符串转换为字节数据
img_data = bytes.fromhex(hex_data)
print(f"Image data len: {len(img_data)}")
# 在屏幕中心显示图片
if display and display.tft:
# 计算居中位置
x = (240 - img_size) // 2
y = (240 - img_size) // 2
# 显示图片
display.show_image(x, y, img_size, img_size, img_data)
display.fill_rect(0, 0, 240, 30, st7789.WHITE)
display.text("图片已生成!", 0, 5, st7789.BLACK)
gc.collect()
print("Image displayed")
except Exception as e:
print(f"Image process error: {e}")
return IMAGE_STATE_IDLE
# 处理图片错误消息
elif msg.startswith("IMAGE_ERROR:"):
print(msg)
if display and display.tft:
display.fill_rect(0, 40, 240, 100, st7789.BLACK)
display.text("图片生成失败", 0, 40, st7789.RED)
return IMAGE_STATE_IDLE
return image_state
def main(): def main():
print("\n=== ESP32 Audio ASR ===\n") print("\n=== ESP32 Audio ASR ===\n")
@@ -91,6 +193,8 @@ def main():
is_recording = False is_recording = False
ws = None ws = None
image_state = IMAGE_STATE_IDLE
image_data_list = []
def connect_ws(): def connect_ws():
nonlocal ws nonlocal ws
@@ -151,8 +255,7 @@ def main():
events = poller.poll(0) events = poller.poll(0)
if events: if events:
msg = ws.recv() msg = ws.recv()
if isinstance(msg, str) and msg.startswith("ASR:"): image_state = process_message(msg, display, image_state, image_data_list)
print_asr(msg[4:], display)
except: except:
ws = None ws = None
@@ -170,16 +273,26 @@ def main():
try: try:
ws.send("STOP_RECORDING") ws.send("STOP_RECORDING")
t_wait = time.ticks_add(time.ticks_ms(), 500) # 等待更长时间以接收图片生成结果
t_wait = time.ticks_add(time.ticks_ms(), 30000) # 等待30秒
prev_image_state = image_state
while time.ticks_diff(t_wait, time.ticks_ms()) > 0: while time.ticks_diff(t_wait, time.ticks_ms()) > 0:
poller = uselect.poll() poller = uselect.poll()
poller.register(ws.sock, uselect.POLLIN) poller.register(ws.sock, uselect.POLLIN)
events = poller.poll(100) events = poller.poll(500)
if events: if events:
msg = ws.recv() msg = ws.recv()
if isinstance(msg, str) and msg.startswith("ASR:"): prev_image_state = image_state
print_asr(msg[4:], display) image_state = process_message(msg, display, image_state, image_data_list)
except: # 如果之前在接收图片,现在停止了,说明图片接收完成
if prev_image_state == IMAGE_STATE_RECEIVING and image_state == IMAGE_STATE_IDLE:
break
else:
# 检查是否还在接收图片
if image_state == IMAGE_STATE_IDLE:
break
except Exception as e:
print(f"Stop recording error: {e}")
ws = None ws = None
gc.collect() gc.collect()

Binary file not shown.

Binary file not shown.

View File

@@ -4,3 +4,4 @@ websockets
pydub pydub
dashscope dashscope
python-dotenv python-dotenv
Pillow

View File

@@ -5,11 +5,14 @@ import os
import subprocess import subprocess
import struct import struct
import base64 import base64
import time
import hashlib
import json
from dotenv import load_dotenv from dotenv import load_dotenv
import dashscope import dashscope
from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
from dashscope import ImageSynthesis from dashscope import ImageSynthesis
import json from dashscope import Generation
# 加载环境变量 # 加载环境变量
load_dotenv() load_dotenv()
@@ -17,14 +20,272 @@ dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
app = FastAPI() app = FastAPI()
# 字体文件配置
FONT_FILE = "GB2312-16.bin"
FONT_CHUNK_SIZE = 512
HIGH_FREQ_CHARS = "的一是在不了有和人这中大为上个国我以要他时来用们生到作地于出就分对成会可主发年动同工也能下过子说产种面而方后多定行学法所民得经十三之进着等部度家电力里如水化高自二理起小物现实加量都两体制机当使点从业本去把性好应开它合还因由其些然前外天政四日那社义事平形相全表间样与关各重新线内数正心反你明看原又么利比或但质气第向道命此变条只没结解问意建月公无系军很情者最立代想已通并提直题党程展五果料象员革位入常文总次品式活设及管特件长求老头基资边流路级少图山统接知较将组见计别她手角期根论运农指几九区强放决西被干做必战先回则任取据处队南给色光门即保治北造百规热领七海口东导器压志世金增争济阶油思术极交受联什认六共权收证改清己美再采转更单风切打白教速花带安场身车例真务具万每目至达走积示议声报斗完类八离华名确才科张信马节话米整空元况今集温传土许步群广石记需段研界拉林律叫且究观越织装影算低持音众书布复容儿须际商非验连断深难近矿千周委素技备半办青省列习响约支般史感劳便团往酸历市克何除消构府称太准精值号率族维划选标写存候毛亲快效斯院查江型眼王按格养易置派层片始却专状育厂京识适属圆包火住调满县局照参红细引听该铁价严龙飞"
# 高频字对应的Unicode码点列表
HIGH_FREQ_UNICODE = [ord(c) for c in HIGH_FREQ_CHARS]
# 字体缓存
font_cache = {}
font_md5 = {}
def calculate_md5(filepath):
"""计算文件的MD5哈希值"""
if not os.path.exists(filepath):
return None
hash_md5 = hashlib.md5()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()
def init_font_cache():
"""初始化字体缓存和MD5"""
global font_cache, font_md5
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, FONT_FILE)
if not os.path.exists(font_path):
font_path = os.path.join(script_dir, "..", FONT_FILE)
if os.path.exists(font_path):
font_md5 = calculate_md5(font_path)
print(f"Font MD5: {font_md5}")
# 预加载高频字到缓存
for unicode_val in HIGH_FREQ_UNICODE:
try:
char = chr(unicode_val)
gb_bytes = char.encode('gb2312')
if len(gb_bytes) == 2:
code = struct.unpack('>H', gb_bytes)[0]
area = (code >> 8) - 0xA0
index = (code & 0xFF) - 0xA0
if area >= 1 and index >= 1:
offset = ((area - 1) * 94 + (index - 1)) * 32
with open(font_path, "rb") as f:
f.seek(offset)
font_data = f.read(32)
if len(font_data) == 32:
font_cache[unicode_val] = font_data
except:
pass
print(f"Preloaded {len(font_cache)} high-frequency characters")
# 启动时初始化字体缓存
init_font_cache()
# 存储接收到的音频数据 # 存储接收到的音频数据
audio_buffer = bytearray() audio_buffer = bytearray()
RECORDING_RAW_FILE = "received_audio.raw" RECORDING_RAW_FILE = "received_audio.raw"
RECORDING_MP3_FILE = "received_audio.mp3" RECORDING_MP3_FILE = "received_audio.mp3"
VOLUME_GAIN = 10.0 # 放大倍数 VOLUME_GAIN = 10.0
FONT_FILE = "GB2312-16.bin"
GENERATED_IMAGE_FILE = "generated_image.png" GENERATED_IMAGE_FILE = "generated_image.png"
GENERATED_THUMB_FILE = "generated_thumb.bin" GENERATED_THUMB_FILE = "generated_thumb.bin"
OUTPUT_DIR = "output_images"
if not os.path.exists(OUTPUT_DIR):
os.makedirs(OUTPUT_DIR)
image_counter = 0
def get_output_path():
global image_counter
image_counter += 1
timestamp = time.strftime("%Y%m%d_%H%M%S")
return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png")
THUMB_SIZE = 245
# 字体请求队列(用于重试机制)
font_request_queue = {}
FONT_RETRY_MAX = 3
def get_font_data(unicode_val):
"""从字体文件获取单个字符数据(带缓存)"""
if unicode_val in font_cache:
return font_cache[unicode_val]
try:
char = chr(unicode_val)
gb_bytes = char.encode('gb2312')
if len(gb_bytes) == 2:
code = struct.unpack('>H', gb_bytes)[0]
area = (code >> 8) - 0xA0
index = (code & 0xFF) - 0xA0
if area >= 1 and index >= 1:
offset = ((area - 1) * 94 + (index - 1)) * 32
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, FONT_FILE)
if not os.path.exists(font_path):
font_path = os.path.join(script_dir, "..", FONT_FILE)
if not os.path.exists(font_path):
font_path = FONT_FILE
if os.path.exists(font_path):
with open(font_path, "rb") as f:
f.seek(offset)
font_data = f.read(32)
if len(font_data) == 32:
font_cache[unicode_val] = font_data
return font_data
except:
pass
return None
def send_font_batch_with_retry(websocket, code_list, retry_count=0):
"""批量发送字体数据(带重试机制)"""
global font_request_queue
success_codes = set()
failed_codes = []
for code_str in code_list:
if not code_str:
continue
try:
unicode_val = int(code_str)
font_data = get_font_data(unicode_val)
if font_data:
import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
response = f"FONT_DATA:{code_str}:{hex_data}"
asyncio.run_coroutine_threadsafe(
websocket.send_text(response),
asyncio.get_event_loop()
)
success_codes.add(unicode_val)
else:
failed_codes.append(code_str)
except Exception as e:
print(f"Error processing font {code_str}: {e}")
failed_codes.append(code_str)
# 记录失败的请求用于重试
if failed_codes and retry_count < FONT_RETRY_MAX:
req_key = f"retry_{retry_count}_{time.time()}"
font_request_queue[req_key] = {
'codes': failed_codes,
'retry': retry_count + 1,
'timestamp': time.time()
}
return len(success_codes), failed_codes
async def send_font_with_fragment(websocket, unicode_val):
"""使用二进制分片方式发送字体数据"""
font_data = get_font_data(unicode_val)
if not font_data:
return False
# 分片发送
total_size = len(font_data)
chunk_size = FONT_CHUNK_SIZE
for i in range(0, total_size, chunk_size):
chunk = font_data[i:i+chunk_size]
seq_num = i // chunk_size
# 构造二进制消息头: 2字节序列号 + 2字节总片数 + 数据
header = struct.pack('<HH', seq_num, (total_size + chunk_size - 1) // chunk_size)
payload = header + chunk
await websocket.send_bytes(payload)
return True
async def handle_font_request(websocket, message_type, data):
"""处理字体请求"""
if message_type == "GET_FONT_MD5":
# 发送字体文件MD5
await websocket.send_text(f"FONT_MD5:{font_md5}")
return
elif message_type == "GET_HIGH_FREQ":
# 批量获取高频字
high_freq_list = HIGH_FREQ_UNICODE[:100] # 限制每次100个
req_str = ",".join([str(c) for c in high_freq_list])
await websocket.send_text(f"GET_FONTS_BATCH:{req_str}")
return
elif message_type.startswith("GET_FONTS_BATCH:"):
# 批量请求字体
try:
codes_str = data
code_list = codes_str.split(",")
print(f"Batch Font Request for {len(code_list)} chars")
success_count, failed = send_font_batch_with_retry(websocket, code_list)
print(f"Font batch: {success_count} success, {len(failed)} failed")
# 发送完成标记
await websocket.send_text(f"FONT_BATCH_END:{success_count}:{len(failed)}")
# 如果有失败的,进行重试
if failed:
await asyncio.sleep(0.5)
send_font_batch_with_retry(websocket, failed, retry_count=1)
except Exception as e:
print(f"Error handling batch font request: {e}")
await websocket.send_text("FONT_BATCH_END:0:0")
return
elif message_type.startswith("GET_FONT_FRAGMENT:"):
# 二进制分片传输请求
try:
unicode_val = int(data)
await send_font_with_fragment(websocket, unicode_val)
except Exception as e:
print(f"Error sending font fragment: {e}")
return
elif message_type.startswith("GET_FONT_UNICODE:") or message_type.startswith("GET_FONT:"):
# 单个字体请求(兼容旧版)
try:
is_unicode = message_type.startswith("GET_FONT_UNICODE:")
code_str = data
if is_unicode:
unicode_val = int(code_str)
font_data = get_font_data(unicode_val)
else:
code = int(code_str, 16)
area = (code >> 8) - 0xA0
index = (code & 0xFF) - 0xA0
if area >= 1 and index >= 1:
offset = ((area - 1) * 94 + (index - 1)) * 32
script_dir = os.path.dirname(os.path.abspath(__file__))
font_path = os.path.join(script_dir, FONT_FILE)
if not os.path.exists(font_path):
font_path = os.path.join(script_dir, "..", FONT_FILE)
if os.path.exists(font_path):
with open(font_path, "rb") as f:
f.seek(offset)
font_data = f.read(32)
else:
font_data = None
else:
font_data = None
if font_data:
import binascii
hex_data = binascii.hexlify(font_data).decode('utf-8')
response = f"FONT_DATA:{code_str}:{hex_data}"
await websocket.send_text(response)
except Exception as e:
print(f"Error handling font request: {e}")
class MyRecognitionCallback(RecognitionCallback): class MyRecognitionCallback(RecognitionCallback):
def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop): def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop):
@@ -72,13 +333,46 @@ def process_chunk_32_to_16(chunk_bytes, gain=1.0):
return processed_chunk return processed_chunk
def optimize_prompt(asr_text):
"""使用大模型优化提示词"""
print(f"Optimizing prompt for: {asr_text}")
system_prompt = """你是一个AI图像提示词优化专家。将用户简短的语音识别结果转化为详细的、适合AI图像生成的英文提示词。
要求:
1. 保留核心内容和主要元素
2. 添加适合AI绘画的描述词风格、光线、氛围等
3. 用英文输出
4. 简洁但描述详细
5. 不要添加多余解释,直接输出优化后的提示词"""
try:
response = Generation.call(
model='qwen-turbo',
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
max_tokens=200,
temperature=0.8
)
if response.status_code == 200:
optimized = response.output.choices[0].message.content.strip()
print(f"Optimized prompt: {optimized}")
return optimized
else:
print(f"Prompt optimization failed: {response.code} - {response.message}")
return asr_text
except Exception as e:
print(f"Error optimizing prompt: {e}")
return asr_text
def generate_image(prompt, websocket=None): def generate_image(prompt, websocket=None):
"""调用万相文生图API生成图片""" """调用万相文生图API生成图片"""
print(f"Generating image for prompt: {prompt}") print(f"Generating image for prompt: {prompt}")
try: try:
response = ImageSynthesis.call( response = ImageSynthesis.call(
model='wanx-v1.0-text-to-image', model='wan2.6-t2i',
prompt=prompt, prompt=prompt,
size='512x512', size='512x512',
n=1 n=1
@@ -92,21 +386,26 @@ def generate_image(prompt, websocket=None):
urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE) urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE)
print(f"Image saved to {GENERATED_IMAGE_FILE}") print(f"Image saved to {GENERATED_IMAGE_FILE}")
# 保存一份到 output_images 目录
output_path = get_output_path()
import shutil
shutil.copy(GENERATED_IMAGE_FILE, output_path)
print(f"Image also saved to {output_path}")
# 缩放图片并转换为RGB565格式 # 缩放图片并转换为RGB565格式
try: try:
from PIL import Image from PIL import Image
img = Image.open(GENERATED_IMAGE_FILE) img = Image.open(GENERATED_IMAGE_FILE)
# 缩小到120x120 (屏幕是240x240但需要考虑内存限制) # 缩小到THUMB_SIZE x THUMB_SIZE
thumb_size = 120 img = img.resize((THUMB_SIZE, THUMB_SIZE), Image.LANCZOS)
img = img.resize((thumb_size, thumb_size), Image.LANCZOS)
# 转换为RGB565格式的原始数据 # 转换为RGB565格式的原始数据
# 每个像素2字节 (R5 G6 B5) # 每个像素2字节 (R5 G6 B5)
rgb565_data = bytearray() rgb565_data = bytearray()
for y in range(thumb_size): for y in range(THUMB_SIZE):
for x in range(thumb_size): for x in range(THUMB_SIZE):
r, g, b = img.getpixel((x, y))[:3] r, g, b = img.getpixel((x, y))[:3]
# 转换为RGB565 # 转换为RGB565
@@ -255,13 +554,19 @@ async def websocket_endpoint(websocket: WebSocket):
# 先发送 ASR 文字到 ESP32 显示 # 先发送 ASR 文字到 ESP32 显示
await websocket.send_text(f"ASR:{asr_text}") await websocket.send_text(f"ASR:{asr_text}")
await websocket.send_text("GENERATING_IMAGE:正在生成图片,请稍候...") await websocket.send_text("GENERATING_IMAGE:正在优化提示词...")
# 等待一会让 ESP32 显示文字 # 等待一会让 ESP32 显示文字
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
# 优化提示词
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text)
await websocket.send_text(f"PROMPT:{optimized_prompt}")
await websocket.send_text("GENERATING_IMAGE:正在生成图片,请稍候...")
# 调用文生图API # 调用文生图API
image_path = await asyncio.to_thread(generate_image, asr_text) image_path = await asyncio.to_thread(generate_image, optimized_prompt)
if image_path and os.path.exists(image_path): if image_path and os.path.exists(image_path):
# 读取图片并发送回ESP32 # 读取图片并发送回ESP32
@@ -270,14 +575,14 @@ async def websocket_endpoint(websocket: WebSocket):
print(f"Sending image to ESP32, size: {len(image_data)} bytes") print(f"Sending image to ESP32, size: {len(image_data)} bytes")
# 将图片转换为base64发送 # 使用hex编码发送每个字节2个字符
image_b64 = base64.b64encode(image_data).decode('utf-8') image_hex = image_data.hex()
await websocket.send_text(f"IMAGE_START:{len(image_data)}") await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
# 分片发送图片数据 # 分片发送图片数据
chunk_size = 4096 chunk_size = 1024
for i in range(0, len(image_b64), chunk_size): for i in range(0, len(image_hex), chunk_size):
chunk = image_b64[i:i+chunk_size] chunk = image_hex[i:i+chunk_size]
await websocket.send_text(f"IMAGE_DATA:{chunk}") await websocket.send_text(f"IMAGE_DATA:{chunk}")
await websocket.send_text("IMAGE_END") await websocket.send_text("IMAGE_END")