1
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -4,3 +4,4 @@ websockets
|
||||
pydub
|
||||
dashscope
|
||||
python-dotenv
|
||||
Pillow
|
||||
|
||||
@@ -5,11 +5,14 @@ import os
|
||||
import subprocess
|
||||
import struct
|
||||
import base64
|
||||
import time
|
||||
import hashlib
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
import dashscope
|
||||
from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
|
||||
from dashscope import ImageSynthesis
|
||||
import json
|
||||
from dashscope import Generation
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
@@ -17,14 +20,272 @@ dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# 字体文件配置
|
||||
FONT_FILE = "GB2312-16.bin"
|
||||
FONT_CHUNK_SIZE = 512
|
||||
HIGH_FREQ_CHARS = "的一是在不了有和人这中大为上个国我以要他时来用们生到作地于出就分对成会可主发年动同工也能下过子说产种面而方后多定行学法所民得经十三之进着等部度家电力里如水化高自二理起小物现实加量都两体制机当使点从业本去把性好应开它合还因由其些然前外天政四日那社义事平形相全表间样与关各重新线内数正心反你明看原又么利比或但质气第向道命此变条只没结解问意建月公无系军很情者最立代想已通并提直题党程展五果料象员革位入常文总次品式活设及管特件长求老头基资边流路级少图山统接知较将组见计别她手角期根论运农指几九区强放决西被干做必战先回则任取据处队南给色光门即保治北造百规热领七海口东导器压志世金增争济阶油思术极交受联什认六共权收证改清己美再采转更单风切打白教速花带安场身车例真务具万每目至达走积示议声报斗完类八离华名确才科张信马节话米整空元况今集温传土许步群广石记需段研界拉林律叫且究观越织装影算低持音众书布复容儿须际商非验连断深难近矿千周委素技备半办青省列习响约支般史感劳便团往酸历市克何除消构府称太准精值号率族维划选标写存候毛亲快效斯院查江型眼王按格养易置派层片始却专状育厂京识适属圆包火住调满县局照参红细引听该铁价严龙飞"
|
||||
|
||||
# 高频字对应的Unicode码点列表
|
||||
HIGH_FREQ_UNICODE = [ord(c) for c in HIGH_FREQ_CHARS]
|
||||
|
||||
# 字体缓存
|
||||
font_cache = {}
|
||||
font_md5 = {}
|
||||
|
||||
def calculate_md5(filepath):
|
||||
"""计算文件的MD5哈希值"""
|
||||
if not os.path.exists(filepath):
|
||||
return None
|
||||
hash_md5 = hashlib.md5()
|
||||
with open(filepath, "rb") as f:
|
||||
for chunk in iter(lambda: f.read(4096), b""):
|
||||
hash_md5.update(chunk)
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
def init_font_cache():
|
||||
"""初始化字体缓存和MD5"""
|
||||
global font_cache, font_md5
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
|
||||
if os.path.exists(font_path):
|
||||
font_md5 = calculate_md5(font_path)
|
||||
print(f"Font MD5: {font_md5}")
|
||||
|
||||
# 预加载高频字到缓存
|
||||
for unicode_val in HIGH_FREQ_UNICODE:
|
||||
try:
|
||||
char = chr(unicode_val)
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
except:
|
||||
pass
|
||||
print(f"Preloaded {len(font_cache)} high-frequency characters")
|
||||
|
||||
# 启动时初始化字体缓存
|
||||
init_font_cache()
|
||||
|
||||
# 存储接收到的音频数据
|
||||
audio_buffer = bytearray()
|
||||
RECORDING_RAW_FILE = "received_audio.raw"
|
||||
RECORDING_MP3_FILE = "received_audio.mp3"
|
||||
VOLUME_GAIN = 10.0 # 放大倍数
|
||||
FONT_FILE = "GB2312-16.bin"
|
||||
VOLUME_GAIN = 10.0
|
||||
GENERATED_IMAGE_FILE = "generated_image.png"
|
||||
GENERATED_THUMB_FILE = "generated_thumb.bin"
|
||||
OUTPUT_DIR = "output_images"
|
||||
|
||||
if not os.path.exists(OUTPUT_DIR):
|
||||
os.makedirs(OUTPUT_DIR)
|
||||
|
||||
image_counter = 0
|
||||
|
||||
def get_output_path():
|
||||
global image_counter
|
||||
image_counter += 1
|
||||
timestamp = time.strftime("%Y%m%d_%H%M%S")
|
||||
return os.path.join(OUTPUT_DIR, f"image_{timestamp}_{image_counter}.png")
|
||||
|
||||
THUMB_SIZE = 245
|
||||
|
||||
# 字体请求队列(用于重试机制)
|
||||
font_request_queue = {}
|
||||
FONT_RETRY_MAX = 3
|
||||
|
||||
|
||||
def get_font_data(unicode_val):
|
||||
"""从字体文件获取单个字符数据(带缓存)"""
|
||||
if unicode_val in font_cache:
|
||||
return font_cache[unicode_val]
|
||||
|
||||
try:
|
||||
char = chr(unicode_val)
|
||||
gb_bytes = char.encode('gb2312')
|
||||
if len(gb_bytes) == 2:
|
||||
code = struct.unpack('>H', gb_bytes)[0]
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = FONT_FILE
|
||||
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
if len(font_data) == 32:
|
||||
font_cache[unicode_val] = font_data
|
||||
return font_data
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
def send_font_batch_with_retry(websocket, code_list, retry_count=0):
|
||||
"""批量发送字体数据(带重试机制)"""
|
||||
global font_request_queue
|
||||
|
||||
success_codes = set()
|
||||
failed_codes = []
|
||||
|
||||
for code_str in code_list:
|
||||
if not code_str:
|
||||
continue
|
||||
|
||||
try:
|
||||
unicode_val = int(code_str)
|
||||
font_data = get_font_data(unicode_val)
|
||||
|
||||
if font_data:
|
||||
import binascii
|
||||
hex_data = binascii.hexlify(font_data).decode('utf-8')
|
||||
response = f"FONT_DATA:{code_str}:{hex_data}"
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
websocket.send_text(response),
|
||||
asyncio.get_event_loop()
|
||||
)
|
||||
success_codes.add(unicode_val)
|
||||
else:
|
||||
failed_codes.append(code_str)
|
||||
except Exception as e:
|
||||
print(f"Error processing font {code_str}: {e}")
|
||||
failed_codes.append(code_str)
|
||||
|
||||
# 记录失败的请求用于重试
|
||||
if failed_codes and retry_count < FONT_RETRY_MAX:
|
||||
req_key = f"retry_{retry_count}_{time.time()}"
|
||||
font_request_queue[req_key] = {
|
||||
'codes': failed_codes,
|
||||
'retry': retry_count + 1,
|
||||
'timestamp': time.time()
|
||||
}
|
||||
|
||||
return len(success_codes), failed_codes
|
||||
|
||||
|
||||
async def send_font_with_fragment(websocket, unicode_val):
|
||||
"""使用二进制分片方式发送字体数据"""
|
||||
font_data = get_font_data(unicode_val)
|
||||
if not font_data:
|
||||
return False
|
||||
|
||||
# 分片发送
|
||||
total_size = len(font_data)
|
||||
chunk_size = FONT_CHUNK_SIZE
|
||||
|
||||
for i in range(0, total_size, chunk_size):
|
||||
chunk = font_data[i:i+chunk_size]
|
||||
seq_num = i // chunk_size
|
||||
|
||||
# 构造二进制消息头: 2字节序列号 + 2字节总片数 + 数据
|
||||
header = struct.pack('<HH', seq_num, (total_size + chunk_size - 1) // chunk_size)
|
||||
payload = header + chunk
|
||||
|
||||
await websocket.send_bytes(payload)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def handle_font_request(websocket, message_type, data):
|
||||
"""处理字体请求"""
|
||||
if message_type == "GET_FONT_MD5":
|
||||
# 发送字体文件MD5
|
||||
await websocket.send_text(f"FONT_MD5:{font_md5}")
|
||||
return
|
||||
|
||||
elif message_type == "GET_HIGH_FREQ":
|
||||
# 批量获取高频字
|
||||
high_freq_list = HIGH_FREQ_UNICODE[:100] # 限制每次100个
|
||||
req_str = ",".join([str(c) for c in high_freq_list])
|
||||
await websocket.send_text(f"GET_FONTS_BATCH:{req_str}")
|
||||
return
|
||||
|
||||
elif message_type.startswith("GET_FONTS_BATCH:"):
|
||||
# 批量请求字体
|
||||
try:
|
||||
codes_str = data
|
||||
code_list = codes_str.split(",")
|
||||
print(f"Batch Font Request for {len(code_list)} chars")
|
||||
|
||||
success_count, failed = send_font_batch_with_retry(websocket, code_list)
|
||||
print(f"Font batch: {success_count} success, {len(failed)} failed")
|
||||
|
||||
# 发送完成标记
|
||||
await websocket.send_text(f"FONT_BATCH_END:{success_count}:{len(failed)}")
|
||||
|
||||
# 如果有失败的,进行重试
|
||||
if failed:
|
||||
await asyncio.sleep(0.5)
|
||||
send_font_batch_with_retry(websocket, failed, retry_count=1)
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error handling batch font request: {e}")
|
||||
await websocket.send_text("FONT_BATCH_END:0:0")
|
||||
return
|
||||
|
||||
elif message_type.startswith("GET_FONT_FRAGMENT:"):
|
||||
# 二进制分片传输请求
|
||||
try:
|
||||
unicode_val = int(data)
|
||||
await send_font_with_fragment(websocket, unicode_val)
|
||||
except Exception as e:
|
||||
print(f"Error sending font fragment: {e}")
|
||||
return
|
||||
|
||||
elif message_type.startswith("GET_FONT_UNICODE:") or message_type.startswith("GET_FONT:"):
|
||||
# 单个字体请求(兼容旧版)
|
||||
try:
|
||||
is_unicode = message_type.startswith("GET_FONT_UNICODE:")
|
||||
code_str = data
|
||||
|
||||
if is_unicode:
|
||||
unicode_val = int(code_str)
|
||||
font_data = get_font_data(unicode_val)
|
||||
else:
|
||||
code = int(code_str, 16)
|
||||
area = (code >> 8) - 0xA0
|
||||
index = (code & 0xFF) - 0xA0
|
||||
if area >= 1 and index >= 1:
|
||||
offset = ((area - 1) * 94 + (index - 1)) * 32
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
font_path = os.path.join(script_dir, FONT_FILE)
|
||||
if not os.path.exists(font_path):
|
||||
font_path = os.path.join(script_dir, "..", FONT_FILE)
|
||||
if os.path.exists(font_path):
|
||||
with open(font_path, "rb") as f:
|
||||
f.seek(offset)
|
||||
font_data = f.read(32)
|
||||
else:
|
||||
font_data = None
|
||||
else:
|
||||
font_data = None
|
||||
|
||||
if font_data:
|
||||
import binascii
|
||||
hex_data = binascii.hexlify(font_data).decode('utf-8')
|
||||
response = f"FONT_DATA:{code_str}:{hex_data}"
|
||||
await websocket.send_text(response)
|
||||
except Exception as e:
|
||||
print(f"Error handling font request: {e}")
|
||||
|
||||
class MyRecognitionCallback(RecognitionCallback):
|
||||
def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop):
|
||||
@@ -72,13 +333,46 @@ def process_chunk_32_to_16(chunk_bytes, gain=1.0):
|
||||
return processed_chunk
|
||||
|
||||
|
||||
def optimize_prompt(asr_text):
|
||||
"""使用大模型优化提示词"""
|
||||
print(f"Optimizing prompt for: {asr_text}")
|
||||
|
||||
system_prompt = """你是一个AI图像提示词优化专家。将用户简短的语音识别结果转化为详细的、适合AI图像生成的英文提示词。
|
||||
要求:
|
||||
1. 保留核心内容和主要元素
|
||||
2. 添加适合AI绘画的描述词(风格、光线、氛围等)
|
||||
3. 用英文输出
|
||||
4. 简洁但描述详细
|
||||
5. 不要添加多余解释,直接输出优化后的提示词"""
|
||||
|
||||
try:
|
||||
response = Generation.call(
|
||||
model='qwen-turbo',
|
||||
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
|
||||
max_tokens=200,
|
||||
temperature=0.8
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
optimized = response.output.choices[0].message.content.strip()
|
||||
print(f"Optimized prompt: {optimized}")
|
||||
return optimized
|
||||
else:
|
||||
print(f"Prompt optimization failed: {response.code} - {response.message}")
|
||||
return asr_text
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error optimizing prompt: {e}")
|
||||
return asr_text
|
||||
|
||||
|
||||
def generate_image(prompt, websocket=None):
|
||||
"""调用万相文生图API生成图片"""
|
||||
print(f"Generating image for prompt: {prompt}")
|
||||
|
||||
try:
|
||||
response = ImageSynthesis.call(
|
||||
model='wanx-v1.0-text-to-image',
|
||||
model='wan2.6-t2i',
|
||||
prompt=prompt,
|
||||
size='512x512',
|
||||
n=1
|
||||
@@ -92,21 +386,26 @@ def generate_image(prompt, websocket=None):
|
||||
urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE)
|
||||
print(f"Image saved to {GENERATED_IMAGE_FILE}")
|
||||
|
||||
# 保存一份到 output_images 目录
|
||||
output_path = get_output_path()
|
||||
import shutil
|
||||
shutil.copy(GENERATED_IMAGE_FILE, output_path)
|
||||
print(f"Image also saved to {output_path}")
|
||||
|
||||
# 缩放图片并转换为RGB565格式
|
||||
try:
|
||||
from PIL import Image
|
||||
img = Image.open(GENERATED_IMAGE_FILE)
|
||||
|
||||
# 缩小到120x120 (屏幕是240x240,但需要考虑内存限制)
|
||||
thumb_size = 120
|
||||
img = img.resize((thumb_size, thumb_size), Image.LANCZOS)
|
||||
# 缩小到THUMB_SIZE x THUMB_SIZE
|
||||
img = img.resize((THUMB_SIZE, THUMB_SIZE), Image.LANCZOS)
|
||||
|
||||
# 转换为RGB565格式的原始数据
|
||||
# 每个像素2字节 (R5 G6 B5)
|
||||
rgb565_data = bytearray()
|
||||
|
||||
for y in range(thumb_size):
|
||||
for x in range(thumb_size):
|
||||
for y in range(THUMB_SIZE):
|
||||
for x in range(THUMB_SIZE):
|
||||
r, g, b = img.getpixel((x, y))[:3]
|
||||
|
||||
# 转换为RGB565
|
||||
@@ -255,13 +554,19 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
# 先发送 ASR 文字到 ESP32 显示
|
||||
await websocket.send_text(f"ASR:{asr_text}")
|
||||
await websocket.send_text("GENERATING_IMAGE:正在生成图片,请稍候...")
|
||||
await websocket.send_text("GENERATING_IMAGE:正在优化提示词...")
|
||||
|
||||
# 等待一会让 ESP32 显示文字
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# 优化提示词
|
||||
optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text)
|
||||
|
||||
await websocket.send_text(f"PROMPT:{optimized_prompt}")
|
||||
await websocket.send_text("GENERATING_IMAGE:正在生成图片,请稍候...")
|
||||
|
||||
# 调用文生图API
|
||||
image_path = await asyncio.to_thread(generate_image, asr_text)
|
||||
image_path = await asyncio.to_thread(generate_image, optimized_prompt)
|
||||
|
||||
if image_path and os.path.exists(image_path):
|
||||
# 读取图片并发送回ESP32
|
||||
@@ -270,14 +575,14 @@ async def websocket_endpoint(websocket: WebSocket):
|
||||
|
||||
print(f"Sending image to ESP32, size: {len(image_data)} bytes")
|
||||
|
||||
# 将图片转换为base64发送
|
||||
image_b64 = base64.b64encode(image_data).decode('utf-8')
|
||||
await websocket.send_text(f"IMAGE_START:{len(image_data)}")
|
||||
# 使用hex编码发送(每个字节2个字符)
|
||||
image_hex = image_data.hex()
|
||||
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
|
||||
|
||||
# 分片发送图片数据
|
||||
chunk_size = 4096
|
||||
for i in range(0, len(image_b64), chunk_size):
|
||||
chunk = image_b64[i:i+chunk_size]
|
||||
chunk_size = 1024
|
||||
for i in range(0, len(image_hex), chunk_size):
|
||||
chunk = image_hex[i:i+chunk_size]
|
||||
await websocket.send_text(f"IMAGE_DATA:{chunk}")
|
||||
|
||||
await websocket.send_text("IMAGE_END")
|
||||
|
||||
Reference in New Issue
Block a user