984 lines
41 KiB
Python
984 lines
41 KiB
Python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||
import uvicorn
|
||
import asyncio
|
||
import os
|
||
import subprocess
|
||
import struct
|
||
import base64
|
||
import time
|
||
import hashlib
|
||
import json
|
||
from dotenv import load_dotenv
|
||
import dashscope
|
||
from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
|
||
from dashscope import ImageSynthesis
|
||
from dashscope import Generation
|
||
|
||
import sys
|
||
# import os
|
||
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||
import convert_img
|
||
|
||
# Load environment variables from .env and hand the API key to the DashScope SDK
load_dotenv()
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
|
||
|
||
app = FastAPI()

# Bitmap font configuration: GB2312-encoded 16x16 glyphs, 32 bytes per glyph
FONT_FILE = "GB2312-16.bin"
FONT_CHUNK_SIZE = 512

# Characters preloaded into the glyph cache at startup (common Chinese text
# plus the strings this server itself sends to the device)
HIGH_FREQ_CHARS = "的一是在不了有和人这中大为上个国我以要他时来用们生到作地于出就分对成会可主发年动同工也能下过子说产种面而方后多定行学法所民得经十三之进着等部度家电力里如水化高自二理起小物现实加量都两体制机当使点从业本去把性好应开它合还因由其些然前外天政四日那社义事平形相全表间样与关各重新线内数正心反你明看原又么利比或但质气第向道命此变条只没结解问意建月公无系军很情者最立代想已通并提直题党程展五果料象员革位入常文总次品式活设及管特件长求老头基资边流路级少图山统接知较将组见计别她手角期根论运农指几九区强放决西被干做必战先回则任取据处队南给色光门即保治北造百规热领七海口东导器压志世金增争济阶油思术极交受联什认六共权收证改清己美再采转更单风切打白教速花带安场身车例真务具万每目至达走积示议声报斗完类八离华名确才科张信马节话米整空元况今集温传土许步群广石记需段研界拉林律叫且究观越织装影算低持音众书布复容儿须际商非验连断深近矿千周委素技备半办青省列习响约支般史感劳便团往酸历市克何除消构府称太准精值号率族维划选标写存候毛亲快效斯院查江型眼王按格养易置派层片始却专状育厂京识适属圆包火住调满县局照参红细引听该铁价严龙飞量迹AI贴纸生成连功败请试"

# Unicode code points for the high-frequency characters above
HIGH_FREQ_UNICODE = [ord(c) for c in HIGH_FREQ_CHARS]

# Glyph caches, populated by init_font_cache()
font_cache = {}           # unicode code point -> 32-byte glyph bitmap
font_md5 = {}             # NOTE(review): reassigned to a hex *string* in init_font_cache()
font_data_buffer = None   # entire font file loaded into memory, or None on failure
|
||
|
||
def calculate_md5(filepath):
    """Return the hex MD5 digest of *filepath*, or None if the file is missing."""
    if not os.path.exists(filepath):
        return None
    digest = hashlib.md5()
    with open(filepath, "rb") as fh:
        # Stream in 4 KiB pieces so large font files never sit fully in memory
        while piece := fh.read(4096):
            digest.update(piece)
    return digest.hexdigest()
|
||
|
||
|
||
def get_font_data(unicode_val):
    """Return the 32-byte 16x16 glyph bitmap for a unicode code point.

    Lookup order: the in-process cache, then the in-memory copy of the font
    file, then (only if the in-memory copy failed to load) a direct seek into
    the font file on disk. Successful lookups are cached.

    Returns None for characters outside GB2312 or on any lookup failure --
    best effort by design, a missing glyph must never break a request.
    """
    if unicode_val in font_cache:
        return font_cache[unicode_val]

    try:
        char = chr(unicode_val)
        gb_bytes = char.encode('gb2312')
        if len(gb_bytes) == 2:
            # GB2312 area/index -> linear file offset, 32 bytes per glyph
            code = struct.unpack('>H', gb_bytes)[0]
            area = (code >> 8) - 0xA0
            index = (code & 0xFF) - 0xA0

            if area >= 1 and index >= 1:
                offset = ((area - 1) * 94 + (index - 1)) * 32

                if font_data_buffer:
                    if offset + 32 <= len(font_data_buffer):
                        font_data = font_data_buffer[offset:offset + 32]
                        font_cache[unicode_val] = font_data
                        return font_data
                else:
                    # Fallback to file reading if the in-memory buffer failed
                    script_dir = os.path.dirname(os.path.abspath(__file__))
                    font_path = os.path.join(script_dir, FONT_FILE)
                    if not os.path.exists(font_path):
                        font_path = os.path.join(script_dir, "..", FONT_FILE)
                    if not os.path.exists(font_path):
                        font_path = FONT_FILE

                    if os.path.exists(font_path):
                        with open(font_path, "rb") as f:
                            f.seek(offset)
                            font_data = f.read(32)
                        if len(font_data) == 32:
                            font_cache[unicode_val] = font_data
                            return font_data
    except (ValueError, OverflowError, OSError, struct.error):
        # chr()/encode('gb2312') raise for non-GB2312 input; file access may
        # raise OSError. The original bare `except:` also swallowed
        # KeyboardInterrupt/SystemExit -- narrowed to the expected failures.
        pass
    return None
|
||
|
||
|
||
def init_font_cache():
    """Initialize the font MD5, in-memory font buffer and glyph cache.

    Side effects: sets ``font_md5`` (hex digest string), ``font_data_buffer``
    (whole file as bytes, or None on read failure) and warms ``font_cache``
    for HIGH_FREQ_UNICODE via get_font_data().
    """
    global font_cache, font_md5, font_data_buffer
    script_dir = os.path.dirname(os.path.abspath(__file__))
    font_path = os.path.join(script_dir, FONT_FILE)

    # Fall back to the parent directory if the font is not next to this file
    if not os.path.exists(font_path):
        font_path = os.path.join(script_dir, "..", FONT_FILE)

    if os.path.exists(font_path):
        font_md5 = calculate_md5(font_path)
        print(f"Font MD5: {font_md5}")

        # Load the whole font file into memory for fast glyph lookups
        try:
            with open(font_path, "rb") as f:
                font_data_buffer = f.read()
            print(f"Loaded font file into memory: {len(font_data_buffer)} bytes")
        except Exception as e:
            print(f"Error loading font file: {e}")
            font_data_buffer = None

    # Preload high-frequency glyphs into the cache (kept for fast access)
    for unicode_val in HIGH_FREQ_UNICODE:
        get_font_data(unicode_val)
    print(f"Preloaded {len(font_cache)} high-frequency characters")
|
||
|
||
# Initialize the font cache at import time, before any client connects
init_font_cache()
|
||
|
||
# Raw audio bytes received from the client (32-bit frames from the device mic)
audio_buffer = bytearray()
RECORDING_RAW_FILE = "received_audio.raw"
RECORDING_MP3_FILE = "received_audio.mp3"
VOLUME_GAIN = 10.0  # software gain applied while converting 32-bit -> 16-bit PCM
GENERATED_IMAGE_FILE = "generated_image.png"
GENERATED_THUMB_FILE = "generated_thumb.bin"
OUTPUT_DIR = "output_images"

if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

# Monotonically increasing suffix used by get_output_path()
image_counter = 0
|
||
|
||
def get_output_path():
    """Build a unique, timestamped path in OUTPUT_DIR for the next saved image."""
    global image_counter
    image_counter += 1
    stamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"image_{stamp}_{image_counter}.png"
    return os.path.join(OUTPUT_DIR, filename)
|
||
|
||
# Side length in pixels of the square RGB565 thumbnail sent to the display
THUMB_SIZE = 240

# Failed font requests queued for retry (see send_font_batch_with_retry)
font_request_queue = {}
FONT_RETRY_MAX = 3
|
||
|
||
# Image generation task bookkeeping
class ImageGenerationTask:
    """State holder for a single asynchronous image-generation job."""

    def __init__(self, task_id: str, asr_text: str, websocket: WebSocket):
        # Identity and input
        self.task_id = task_id        # unique job identifier
        self.asr_text = asr_text      # raw ASR transcript seeding the prompt
        self.websocket = websocket    # client connection for progress updates
        # Lifecycle: pending -> optimizing -> generating -> completed / failed
        self.status = "pending"
        self.progress = 0             # 0-100
        self.message = ""             # last human-readable progress message
        # Outcome
        self.result = None            # path of the generated file on success
        self.error = None             # error description on failure
|
||
|
||
# In-flight image-generation tasks, keyed by task_id; entries are removed
# when a task finishes (see start_async_image_generation's finally block)
active_tasks = {}
task_counter = 0
|
||
|
||
|
||
async def start_async_image_generation(websocket: WebSocket, asr_text: str):
    """Run the image-generation pipeline without blocking the WebSocket loop.

    Sends TASK_ID / STATUS / PROMPT / TASK_PROGRESS text frames to the client
    while the blocking DashScope calls run in worker threads, then streams the
    finished image with send_image_to_client().

    Returns:
        The ImageGenerationTask describing the (now finished) job.
    """
    global task_counter, active_tasks

    task_id = f"task_{task_counter}_{int(time.time() * 1000)}"
    task_counter += 1

    task = ImageGenerationTask(task_id, asr_text, websocket)
    active_tasks[task_id] = task

    print(f"Starting async image generation task: {task_id}")

    await websocket.send_text(f"TASK_ID:{task_id}")

    # Inside a coroutine a running loop is guaranteed, so the original
    # RuntimeError/new_event_loop fallback was dead code and is removed.
    # (An unused asyncio.Queue was removed here as well.)
    loop = asyncio.get_running_loop()

    async def progress_callback_async(progress: int, message: str):
        """Record progress on the task and push it to the client."""
        task.progress = progress
        task.message = message
        try:
            await websocket.send_text(f"TASK_PROGRESS:{task_id}:{progress}:{message}")
        except Exception as e:
            print(f"Error sending progress: {e}")

    def progress_callback(progress: int, message: str):
        """Thread-safe progress callback handed to the blocking helpers."""
        task.progress = progress
        task.message = message
        # Hop back onto the event loop: send_text must not run in this thread
        asyncio.run_coroutine_threadsafe(
            progress_callback_async(progress, message),
            loop
        )

    try:
        task.status = "optimizing"

        await websocket.send_text("STATUS:OPTIMIZING:正在优化提示词...")
        await asyncio.sleep(0.2)

        # Blocking prompt optimization runs in a worker thread
        optimized_prompt = await asyncio.to_thread(optimize_prompt, asr_text, progress_callback)

        # Fall back to the raw ASR text if optimization produced nothing
        if not optimized_prompt:
            optimized_prompt = asr_text
            print(f"Warning: optimize_prompt returned None, using original text: {asr_text}")

        await websocket.send_text(f"PROMPT:{optimized_prompt}")
        task.optimized_prompt = optimized_prompt

        task.status = "generating"
        await websocket.send_text("STATUS:RENDERING:正在生成图片,请稍候...")
        await asyncio.sleep(0.2)

        # Blocking image generation runs in a worker thread
        image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)

        task.result = image_path

        if image_path and os.path.exists(image_path):
            task.status = "completed"
            await websocket.send_text("STATUS:COMPLETE:图片生成完成")
            await asyncio.sleep(0.2)

            await send_image_to_client(websocket, image_path)
        else:
            task.status = "failed"
            task.error = "图片生成失败"
            await websocket.send_text("IMAGE_ERROR:图片生成失败")
            await websocket.send_text("STATUS:ERROR:图片生成失败")

    except Exception as e:
        task.status = "failed"
        task.error = str(e)
        print(f"Image generation task error: {e}")
        try:
            await websocket.send_text(f"IMAGE_ERROR:图片生成出错: {str(e)}")
            await websocket.send_text("STATUS:ERROR:图片生成出错")
        except Exception as ws_e:
            print(f"Error sending error message: {ws_e}")
    finally:
        # Tasks are only tracked while in flight
        if task_id in active_tasks:
            del active_tasks[task_id]

    return task
|
||
|
||
|
||
async def send_image_to_client(websocket: WebSocket, image_path: str):
    """Stream an image file to the client: IMAGE_START, raw chunks, IMAGE_END."""
    with open(image_path, 'rb') as fh:
        payload = fh.read()

    total = len(payload)
    print(f"Sending image to ESP32, size: {total} bytes")

    # Start marker carries the byte count and thumbnail side length
    await websocket.send_text(f"IMAGE_START:{total}:{THUMB_SIZE}")

    # Small chunks keep the ESP32's receive buffer from overflowing
    chunk_size = 512
    offset = 0
    while offset < total:
        await websocket.send_bytes(payload[offset:offset + chunk_size])
        offset += chunk_size

    await websocket.send_text("IMAGE_END")
    print("Image sent to ESP32 (Binary)")
|
||
|
||
|
||
async def send_font_batch_with_retry(websocket, code_list, retry_count=0):
    """Send glyph bitmaps for a batch of decimal unicode code point strings.

    Each available glyph is sent as a ``FONT_DATA:<code>:<hex>`` text frame.
    Codes whose glyphs cannot be produced are recorded in font_request_queue
    (up to FONT_RETRY_MAX attempts) so they can be retried later.

    Returns:
        Tuple of (number of glyphs sent, list of code strings that failed).
    """
    global font_request_queue

    success_codes = set()
    failed_codes = []

    for code_str in code_list:
        if not code_str:
            continue

        try:
            unicode_val = int(code_str)
            font_data = get_font_data(unicode_val)

            if font_data:
                # bytes.hex() yields the same lowercase hex string the old
                # per-iteration `import binascii; hexlify(...)` produced.
                hex_data = font_data.hex()
                response = f"FONT_DATA:{code_str}:{hex_data}"
                await websocket.send_text(response)
                success_codes.add(unicode_val)
            else:
                failed_codes.append(code_str)
        except Exception as e:
            print(f"Error processing font {code_str}: {e}")
            failed_codes.append(code_str)

    # Record failures so a later round can retry them
    if failed_codes and retry_count < FONT_RETRY_MAX:
        req_key = f"retry_{retry_count}_{time.time()}"
        font_request_queue[req_key] = {
            'codes': failed_codes,
            'retry': retry_count + 1,
            'timestamp': time.time()
        }

    return len(success_codes), failed_codes
|
||
|
||
|
||
async def send_font_with_fragment(websocket, unicode_val):
    """Send one glyph's bitmap as binary fragments.

    Each fragment is ``<u16 seq><u16 total_fragments>`` (little endian)
    followed by up to FONT_CHUNK_SIZE payload bytes. Returns False when no
    glyph is available, True once all fragments are sent.
    """
    glyph = get_font_data(unicode_val)
    if not glyph:
        return False

    total = len(glyph)
    fragment_count = (total + FONT_CHUNK_SIZE - 1) // FONT_CHUNK_SIZE

    for seq, start in enumerate(range(0, total, FONT_CHUNK_SIZE)):
        header = struct.pack('<HH', seq, fragment_count)
        await websocket.send_bytes(header + glyph[start:start + FONT_CHUNK_SIZE])

    return True
|
||
|
||
|
||
async def handle_font_request(websocket, message_type, data):
    """Serve the device's font protocol over the websocket.

    Supported message types:
      GET_FONT_MD5              -> reply "FONT_MD5:<md5>"
      GET_HIGH_FREQ             -> suggest a batch of hot glyph code points
      GET_FONTS_BATCH:<codes>   -> *data* is a comma-separated decimal list
      GET_FONT_FRAGMENT:<code>  -> one glyph via binary fragments
      GET_FONT_UNICODE:/GET_FONT: -> single glyph (legacy text protocol)
    """
    if message_type == "GET_FONT_MD5":
        # Report the font file checksum so the client can validate its cache
        await websocket.send_text(f"FONT_MD5:{font_md5}")
        return

    elif message_type == "GET_HIGH_FREQ":
        # Tell the client which high-frequency glyphs to request in a batch
        high_freq_list = HIGH_FREQ_UNICODE[:100]  # cap at 100 per round
        req_str = ",".join([str(c) for c in high_freq_list])
        await websocket.send_text(f"GET_FONTS_BATCH:{req_str}")
        return

    elif message_type.startswith("GET_FONTS_BATCH:"):
        # Batch glyph request
        try:
            codes_str = data
            code_list = codes_str.split(",")
            print(f"Batch Font Request for {len(code_list)} chars")

            success_count, failed = await send_font_batch_with_retry(websocket, code_list)
            print(f"Font batch: {success_count} success, {len(failed)} failed")

            # Completion marker with success/failure counts
            await websocket.send_text(f"FONT_BATCH_END:{success_count}:{len(failed)}")

            # One immediate retry round for anything that failed
            if failed:
                await asyncio.sleep(0.5)
                await send_font_batch_with_retry(websocket, failed, retry_count=1)

        except Exception as e:
            print(f"Error handling batch font request: {e}")
            await websocket.send_text("FONT_BATCH_END:0:0")
        return

    elif message_type.startswith("GET_FONT_FRAGMENT:"):
        # Binary fragmented transfer of a single glyph
        try:
            unicode_val = int(data)
            await send_font_with_fragment(websocket, unicode_val)
        except Exception as e:
            print(f"Error sending font fragment: {e}")
        return

    elif message_type.startswith("GET_FONT_UNICODE:") or message_type.startswith("GET_FONT:"):
        # Single glyph request (legacy clients). GET_FONT_UNICODE carries a
        # decimal unicode code point; GET_FONT carries a hex GB2312 code and
        # is resolved by reading the font file directly.
        try:
            is_unicode = message_type.startswith("GET_FONT_UNICODE:")
            code_str = data

            if is_unicode:
                unicode_val = int(code_str)
                font_data = get_font_data(unicode_val)
            else:
                # GB2312 area/index -> file offset, 32 bytes per glyph
                code = int(code_str, 16)
                area = (code >> 8) - 0xA0
                index = (code & 0xFF) - 0xA0
                if area >= 1 and index >= 1:
                    offset = ((area - 1) * 94 + (index - 1)) * 32
                    script_dir = os.path.dirname(os.path.abspath(__file__))
                    font_path = os.path.join(script_dir, FONT_FILE)
                    if not os.path.exists(font_path):
                        font_path = os.path.join(script_dir, "..", FONT_FILE)
                    if os.path.exists(font_path):
                        with open(font_path, "rb") as f:
                            f.seek(offset)
                            font_data = f.read(32)
                    else:
                        font_data = None
                else:
                    font_data = None

            if font_data:
                import binascii
                hex_data = binascii.hexlify(font_data).decode('utf-8')
                response = f"FONT_DATA:{code_str}:{hex_data}"
                await websocket.send_text(response)
        except Exception as e:
            print(f"Error handling font request: {e}")
|
||
|
||
class MyRecognitionCallback(RecognitionCallback):
    """DashScope streaming-ASR callback.

    The SDK invokes these hooks on its own worker thread, so anything sent
    to the websocket must be scheduled onto the main event loop with
    asyncio.run_coroutine_threadsafe.
    """

    def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop):
        self.websocket = websocket
        self.loop = loop
        self.final_text = ""      # latest combined transcript
        self.sentence_list = []   # accumulated sentences (last one may still be partial)
        self.last_send_time = 0   # throttle state for the (disabled) realtime send

    def on_open(self) -> None:
        # New session: reset all accumulated state
        print("ASR Session started")
        self.sentence_list = []
        self.final_text = ""
        self.last_send_time = 0

    def on_close(self) -> None:
        print("ASR Session closed")
        # Merge all sentences into the final transcript
        if self.sentence_list:
            self.final_text = "".join(self.sentence_list)
            print(f"Final combined ASR text: {self.final_text}")
            # Send the complete transcript once, hopping onto the event loop
            try:
                if self.loop.is_running():
                    asyncio.run_coroutine_threadsafe(
                        self.websocket.send_text(f"ASR:{self.final_text}"),
                        self.loop
                    )
            except Exception as e:
                print(f"Failed to send final ASR result: {e}")

    def on_event(self, result: RecognitionResult) -> None:
        # Streaming partial/updated sentence from the recognizer
        if result.get_sentence():
            text = result.get_sentence()['text']

            # NOTE(review): the DashScope result structure can vary between
            # SDK versions; only the sentence text is relied on here.

            # Dedup heuristic: if the new text extends the previous sentence,
            # treat it as an update of that sentence rather than a new one.
            if self.sentence_list:
                last_sentence = self.sentence_list[-1]
                # Strip trailing punctuation before comparing, since streaming
                # results may flip punctuation between updates
                last_clean = last_sentence.rstrip('。,?!')
                text_clean = text.rstrip('。,?!')

                if text_clean.startswith(last_clean):
                    # Same sentence, updated content
                    self.sentence_list[-1] = text
                elif last_clean.startswith(text_clean):
                    # New text is a shorter prefix (unlikely rollback); ignore
                    pass
                else:
                    # Genuinely new sentence
                    self.sentence_list.append(text)
            else:
                self.sentence_list.append(text)

            # Keep final_text current so STOP_RECORDING can read it
            self.final_text = "".join(self.sentence_list)
            print(f"ASR Update: {self.final_text}")

            # Realtime sending is intentionally disabled: the device only
            # wants the complete transcript after recording stops. The
            # throttled (500 ms) implementation is kept here for reference:
            # current_time = time.time()
            # if current_time - self.last_send_time > 0.5:
            #     self.last_send_time = current_time
            #     try:
            #         if self.loop.is_running():
            #             asyncio.run_coroutine_threadsafe(
            #                 self.websocket.send_text(f"ASR:{self.final_text}"),
            #                 self.loop
            #             )
            #     except Exception as e:
            #         print(f"Failed to send ASR result to client: {e}")
|
||
|
||
def process_chunk_32_to_16(chunk_bytes, gain=1.0):
    """Convert 32-bit audio frames to gained, clipped 16-bit PCM.

    Each 4-byte input frame carries the significant sample in its upper two
    bytes (offsets 2-3, little endian signed); the lower two bytes are
    discarded. Any trailing partial frame is ignored.

    Args:
        chunk_bytes: raw 32-bit frame buffer (length need not be a multiple of 4).
        gain: linear volume multiplier applied before clipping.

    Returns:
        bytearray of signed 16-bit little-endian PCM samples.
    """
    processed_chunk = bytearray()
    frame_count = len(chunk_bytes) // 4
    for i in range(frame_count):
        # Take the high half of the 32-bit frame as a signed 16-bit sample
        sample = struct.unpack_from('<h', chunk_bytes, i * 4 + 2)[0]

        # Apply gain, then clamp to the int16 range to avoid wrap-around pops
        sample = max(-32768, min(32767, int(sample * gain)))

        # Repack as 16-bit little-endian
        processed_chunk.extend(struct.pack('<h', sample))
    return processed_chunk
|
||
|
||
|
||
def optimize_prompt(asr_text, progress_callback=None):
    """Rewrite an ASR transcript into a line-art image prompt via qwen-turbo.

    Args:
        asr_text: raw speech-recognition transcript.
        progress_callback: optional callable ``(progress: int, message: str)``.

    Returns:
        The optimized prompt string, or *asr_text* unchanged on any API
        failure so generation can still proceed.
    """
    print(f"Optimizing prompt for: {asr_text}")

    if progress_callback:
        progress_callback(0, "正在准备优化提示词...")

    # Instructs the model to produce simple black-and-white line-art prompts
    # suitable for a low-resolution thermal printer (runtime string, verbatim)
    system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。
关键要求:
1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。
2. 画面必须清晰、线条粗壮,适合低分辨率热敏打印机打印。
3. 绝对不要有复杂的阴影、渐变、彩色描述。
4. 背景必须是纯白 (White background)。
5. 提示词内容请使用英文描述,因为绘图模型对英文理解更好,但在描述中强调 "black and white line art", "simple lines", "vector style"。
6. 尺寸比例遵循宽48mm:高30mm (约 1.6:1)。
7. 直接输出优化后的提示词,不要包含任何解释。"""

    try:
        if progress_callback:
            progress_callback(10, "正在调用AI优化提示词...")
        print(f"Calling AI with prompt: {system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:")

        response = Generation.call(
            model='qwen-turbo',
            prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
            max_tokens=200,
            temperature=0.8
        )

        if response.status_code == 200:
            # Preferred shape: output.choices[0].message.content
            if hasattr(response, 'output') and response.output and \
               hasattr(response.output, 'choices') and response.output.choices and \
               len(response.output.choices) > 0:

                optimized = response.output.choices[0].message.content.strip()
                print(f"Optimized prompt: {optimized}")

                if progress_callback:
                    progress_callback(30, f"提示词优化完成: {optimized[:50]}...")

                return optimized
            elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'):
                # Handle case where API returns text directly instead of choices
                optimized = response.output.text.strip()
                print(f"Optimized prompt (direct text): {optimized}")

                if progress_callback:
                    progress_callback(30, f"提示词优化完成: {optimized[:50]}...")

                return optimized
            else:
                # Unrecognized response shape: fall back to the raw transcript
                print(f"Prompt optimization response format error: {response}")
                if progress_callback:
                    progress_callback(0, "提示词优化响应格式错误")
                return asr_text
        else:
            # Non-200 status from the Generation API
            print(f"Prompt optimization failed: {response.code} - {response.message}")
            if progress_callback:
                progress_callback(0, f"提示词优化失败: {response.message}")
            return asr_text

    except Exception as e:
        print(f"Error optimizing prompt: {e}")
        if progress_callback:
            progress_callback(0, f"提示词优化出错: {str(e)}")
        return asr_text
|
||
|
||
|
||
def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2):
    """Generate an image via the Wanx text-to-image API and post-process it.

    Args:
        prompt: image generation prompt.
        progress_callback: optional callable ``(progress: int, message: str)``.
        retry_count: current retry attempt (used internally on recursion).
        max_retries: maximum number of retries after an API failure.

    Returns:
        Path of the RGB565 thumbnail (GENERATED_THUMB_FILE) on success, the
        raw PNG path if PIL is unavailable or post-processing fails, or None
        if generation failed entirely.
    """
    print(f"Generating image for prompt: {prompt}")

    if progress_callback:
        progress_callback(35, "正在请求AI生成图片...")

    try:
        if not prompt:
            print("Error: prompt is empty")
            if progress_callback:
                progress_callback(0, "提示词为空")
            return None

        response = ImageSynthesis.call(
            model='wanx2.0-t2i-turbo',
            prompt=prompt
        )

        if response.status_code == 200:
            if not response.output:
                print("Error: response.output is None")
                if progress_callback:
                    progress_callback(0, "API响应无效")
                return None

            task_status = response.output.get('task_status')

            # The API may answer asynchronously; poll until the task settles
            if task_status == 'PENDING' or task_status == 'RUNNING':
                print("Waiting for image generation to complete...")
                if progress_callback:
                    progress_callback(45, "AI正在生成图片中...")

                import time
                task_id = response.output.get('task_id')
                max_wait = 120  # seconds total, polling every 2s
                waited = 0
                while waited < max_wait:
                    time.sleep(2)
                    waited += 2
                    task_result = ImageSynthesis.fetch(task_id)
                    if task_result.output.task_status == 'SUCCEEDED':
                        response.output = task_result.output
                        break
                    elif task_result.output.task_status == 'FAILED':
                        error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error'
                        print(f"Image generation failed: {error_msg}")
                        if progress_callback:
                            progress_callback(35, f"图片生成失败: {error_msg}")
                        return None

            if response.output.get('task_status') == 'SUCCEEDED':
                image_url = response.output['results'][0]['url']
                print(f"Image generated, downloading from: {image_url}")

                if progress_callback:
                    progress_callback(70, "正在下载生成的图片...")

                import urllib.request
                urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE)
                print(f"Image saved to {GENERATED_IMAGE_FILE}")

                # Keep an archival copy in output_images/
                output_path = get_output_path()
                import shutil
                shutil.copy(GENERATED_IMAGE_FILE, output_path)
                print(f"Image also saved to {output_path}")

                if progress_callback:
                    progress_callback(80, "正在处理图片...")

                # Downscale and convert to raw RGB565 for the device display
                try:
                    from PIL import Image
                    img = Image.open(GENERATED_IMAGE_FILE)

                    # Shrink to THUMB_SIZE x THUMB_SIZE
                    img = img.resize((THUMB_SIZE, THUMB_SIZE), Image.LANCZOS)

                    # Raw RGB565 output: 2 bytes per pixel (R5 G6 B5)
                    rgb565_data = bytearray()

                    for y in range(THUMB_SIZE):
                        for x in range(THUMB_SIZE):
                            r, g, b = img.getpixel((x, y))[:3]

                            # Truncate each channel to its RGB565 width
                            r5 = (r >> 3) & 0x1F
                            g6 = (g >> 2) & 0x3F
                            b5 = (b >> 3) & 0x1F

                            # Pack as Big Endian (>H) which is standard for SPI displays
                            # RGB565: Red(5) Green(6) Blue(5)
                            rgb565 = (r5 << 11) | (g6 << 5) | b5
                            rgb565_data.extend(struct.pack('>H', rgb565))

                    # Persist the raw thumbnail as a .bin file
                    with open(GENERATED_THUMB_FILE, 'wb') as f:
                        f.write(rgb565_data)

                    print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes")

                    if progress_callback:
                        progress_callback(100, "图片生成完成!")

                    return GENERATED_THUMB_FILE

                except ImportError:
                    print("PIL not available, sending original image")
                    if progress_callback:
                        progress_callback(100, "图片生成完成!(原始格式)")
                    return GENERATED_IMAGE_FILE
                except Exception as e:
                    print(f"Error processing image: {e}")
                    if progress_callback:
                        progress_callback(80, f"图片处理出错: {str(e)}")
                    return GENERATED_IMAGE_FILE
        else:
            # Non-200 HTTP status from the synthesis API
            error_msg = f"{response.code} - {response.message}"
            print(f"Image generation failed: {error_msg}")

            # Retry with the same prompt until max_retries is exhausted
            if retry_count < max_retries:
                print(f"Retrying... ({retry_count + 1}/{max_retries})")
                if progress_callback:
                    progress_callback(35, f"图片生成失败,正在重试 ({retry_count + 1}/{max_retries})...")
                return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
            else:
                if progress_callback:
                    progress_callback(35, f"图片生成失败: {error_msg}")
                return None

    except Exception as e:
        print(f"Error generating image: {e}")

        # Retry with the same prompt until max_retries is exhausted
        if retry_count < max_retries:
            print(f"Retrying after error... ({retry_count + 1}/{max_retries})")
            if progress_callback:
                progress_callback(35, f"生成出错,正在重试 ({retry_count + 1}/{max_retries})...")
            return generate_image(prompt, progress_callback, retry_count + 1, max_retries)
        else:
            if progress_callback:
                progress_callback(35, f"图片生成出错: {str(e)}")
            return None
|
||
|
||
@app.websocket("/ws/audio")
async def websocket_endpoint(websocket: WebSocket):
    """Main device endpoint: text control frames plus binary audio frames.

    Text protocol (client -> server): START_RECORDING, STOP_RECORDING,
    GENERATE_IMAGE:<prompt>, PRINT_IMAGE, GET_TASK_STATUS:<id>, and the
    GET_FONT* family. Binary frames carry raw 32-bit audio samples.
    """
    global audio_buffer
    await websocket.accept()
    print("Client connected")

    recognition = None   # active DashScope streaming-ASR session, if any
    callback = None      # its callback object (holds the accumulated transcript)
    processed_buffer = bytearray()  # 16-bit PCM derived from the raw 32-bit frames
    loop = asyncio.get_running_loop()

    try:
        while True:
            # A frame is either a text command or binary audio data
            try:
                message = await websocket.receive()
            except RuntimeError as e:
                # starlette raises RuntimeError when receive() is called after
                # the disconnect message has already been consumed
                if "Cannot call \"receive\" once a disconnect message has been received" in str(e):
                    print("Client disconnected (RuntimeError caught)")
                    break
                raise e

            if "text" in message:
                text = message["text"]
                print(f"Received text: {text}")

                if text == "START_RECORDING":
                    print("Start recording...")
                    audio_buffer = bytearray()  # reset capture buffers
                    processed_buffer = bytearray()

                    # Start a realtime speech-recognition session
                    try:
                        callback = MyRecognitionCallback(websocket, loop)
                        recognition = Recognition(
                            model='paraformer-realtime-v2',
                            format='pcm',
                            sample_rate=16000,
                            callback=callback
                        )
                        recognition.start()
                        print("DashScope ASR started")
                    except Exception as e:
                        print(f"Failed to start ASR: {e}")
                        recognition = None
                        callback = None

                elif text == "STOP_RECORDING":
                    print(f"Stop recording. Total raw bytes: {len(audio_buffer)}")

                    # Stop the recognizer (its on_close merges the final text)
                    if recognition:
                        try:
                            recognition.stop()
                            print("DashScope ASR stopped")
                        except Exception as e:
                            print(f"Error stopping ASR: {e}")
                        recognition = None

                    # Audio was already gain-adjusted chunk by chunk on receive
                    processed_audio = processed_buffer

                    print(f"Processed audio size: {len(processed_audio)} bytes (Gain: {VOLUME_GAIN}x)")

                    # Collect the transcript accumulated by the ASR callback
                    asr_text = ""
                    if callback:
                        asr_text = callback.final_text
                        print(f"Final ASR text: {asr_text}")

                    # 2. Save the raw recording (16-bit PCM)
                    with open(RECORDING_RAW_FILE, "wb") as f:
                        f.write(processed_audio)

                    # 3. Convert to MP3 via the ffmpeg CLI (avoids the removed
                    #    audioop module on Python 3.13)
                    try:
                        # ffmpeg -y -f s16le -ar 16000 -ac 1 -i received_audio.raw received_audio.mp3
                        cmd = [
                            "ffmpeg",
                            "-y",             # overwrite output file
                            "-f", "s16le",    # input format: signed 16-bit little endian
                            "-ar", "16000",   # input sample rate
                            "-ac", "1",       # input channel count
                            "-i", RECORDING_RAW_FILE,
                            RECORDING_MP3_FILE
                        ]
                        print(f"Running command: {' '.join(cmd)}")

                        # Use asyncio.create_subprocess_exec instead of subprocess.run to avoid blocking the event loop
                        process = await asyncio.create_subprocess_exec(
                            *cmd,
                            stdout=asyncio.subprocess.PIPE,
                            stderr=asyncio.subprocess.PIPE
                        )
                        stdout, stderr = await process.communicate()

                        if process.returncode != 0:
                            raise subprocess.CalledProcessError(process.returncode, cmd, output=stdout, stderr=stderr)

                        print(f"Saved MP3 to {RECORDING_MP3_FILE}")
                    except subprocess.CalledProcessError as e:
                        print(f"Error converting to MP3: {e}")
                        # stderr might be bytes
                        error_msg = e.stderr.decode() if isinstance(e.stderr, bytes) else str(e.stderr)
                        print(f"FFmpeg stderr: {error_msg}")
                    except FileNotFoundError:
                        print("Error: ffmpeg not found. Please install ffmpeg.")
                    except Exception as e:
                        print(f"Error converting to MP3: {e}")

                    # 4. If we have a transcript, push its glyphs and text
                    if asr_text:
                        print(f"ASR result: {asr_text}")

                        # Proactively send glyph bitmaps for every distinct char
                        try:
                            unique_chars = set(asr_text)
                            code_list = [str(ord(c)) for c in unique_chars]
                            print(f"Sending font data for {len(code_list)} characters...")
                            success_count, failed = await send_font_batch_with_retry(websocket, code_list)
                            print(f"Font data sent: {success_count} success, {len(failed)} failed")
                        except Exception as e:
                            print(f"Error sending font data: {e}")

                        # Send the transcript for on-device display
                        await websocket.send_text(f"ASR:{asr_text}")

                        # Automatic image generation was removed; the client
                        # must send an explicit GENERATE_IMAGE command.
                    else:
                        print("No ASR text")
                        # Intentionally silent; could notify to unstick the UI:
                        # await websocket.send_text("ASR:")

                    print("Server processing finished.")

                elif text.startswith("GENERATE_IMAGE:"):
                    # Fire-and-forget: the task reports progress on its own
                    prompt_text = text.split(":", 1)[1]
                    print(f"Received GENERATE_IMAGE request: {prompt_text}")
                    if prompt_text:
                        asyncio.create_task(start_async_image_generation(websocket, prompt_text))
                    else:
                        await websocket.send_text("STATUS:ERROR:提示词为空")

                elif text == "PRINT_IMAGE":
                    print("Received PRINT_IMAGE request")
                    if os.path.exists(GENERATED_IMAGE_FILE):
                        try:
                            # Use convert_img logic to get TSPL commands
                            tspl_data = convert_img.image_to_tspl_commands(GENERATED_IMAGE_FILE)
                            if tspl_data:
                                print(f"Sending printer data: {len(tspl_data)} bytes")
                                await websocket.send_text(f"PRINTER_DATA_START:{len(tspl_data)}")

                                # Send in chunks
                                chunk_size = 512
                                for i in range(0, len(tspl_data), chunk_size):
                                    chunk = tspl_data[i:i+chunk_size]
                                    await websocket.send_bytes(chunk)
                                    # Small delay to prevent overwhelming ESP32 buffer
                                    await asyncio.sleep(0.01)

                                await websocket.send_text("PRINTER_DATA_END")
                                print("Printer data sent")
                            else:
                                await websocket.send_text("STATUS:ERROR:图片转换失败")
                        except Exception as e:
                            print(f"Error converting image for printer: {e}")
                            await websocket.send_text(f"STATUS:ERROR:打印出错: {str(e)}")
                    else:
                        await websocket.send_text("STATUS:ERROR:没有可打印的图片")

                elif text.startswith("GET_TASK_STATUS:"):
                    task_id = text.split(":", 1)[1].strip()
                    if task_id in active_tasks:
                        task = active_tasks[task_id]
                        await websocket.send_text(f"TASK_STATUS:{task_id}:{task.status}:{task.progress}:{task.message}")
                    else:
                        await websocket.send_text(f"TASK_STATUS:{task_id}:unknown:0:任务不存在或已完成")

                elif text.startswith("GET_FONTS_BATCH:") or text.startswith("GET_FONT") or text == "GET_FONT_MD5" or text == "GET_HIGH_FREQ":
                    # Route every font command through the unified handler
                    try:
                        if text.startswith("GET_FONTS_BATCH:"):
                            await handle_font_request(websocket, text, text.split(":", 1)[1])
                        elif text.startswith("GET_FONT_FRAGMENT:"):
                            await handle_font_request(websocket, text, text.split(":", 1)[1])
                        elif text.startswith("GET_FONT_UNICODE:") or text.startswith("GET_FONT:"):
                            parts = text.split(":", 1)
                            await handle_font_request(websocket, parts[0], parts[1] if len(parts) > 1 else "")
                        else:
                            await handle_font_request(websocket, text, "")
                    except Exception as e:
                        print(f"Font request error: {e}")
                        await websocket.send_text("FONT_BATCH_END:0:0")

            elif "bytes" in message:
                # Binary frame: raw audio -- buffer it and feed the recognizer
                data = message["bytes"]
                audio_buffer.extend(data)

                # Convert 32-bit frames to gained 16-bit PCM on the fly
                pcm_chunk = process_chunk_32_to_16(data, VOLUME_GAIN)
                processed_buffer.extend(pcm_chunk)

                if recognition:
                    try:
                        recognition.send_audio_frame(pcm_chunk)
                    except Exception as e:
                        print(f"Error sending audio frame to ASR: {e}")

    except WebSocketDisconnect:
        print("Client disconnected")
        # Best-effort: stop any still-running ASR session
        if recognition:
            try:
                recognition.stop()
            except:
                pass
    except Exception as e:
        print(f"Error: {e}")
        # Best-effort: stop any still-running ASR session
        if recognition:
            try:
                recognition.stop()
            except:
                pass
|
||
|
||
if __name__ == "__main__":
    # Print the LAN address so the ESP32 firmware can be pointed at it
    import socket

    host = socket.gethostname()
    addr = socket.gethostbyname(host)
    print(f"Server running on ws://{addr}:8000/ws/audio")

    uvicorn.run(app, host="0.0.0.0", port=8000)
|