278 lines
12 KiB
Python
278 lines
12 KiB
Python
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
||
import uvicorn
|
||
import asyncio
|
||
import os
|
||
import subprocess
|
||
import struct
|
||
from dotenv import load_dotenv
|
||
import dashscope
|
||
from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult
|
||
import json
|
||
|
||
# Load environment variables (e.g. DASHSCOPE_API_KEY) from a .env file, if present.
load_dotenv()
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
if not dashscope.api_key:
    # Surface a missing key at startup rather than as an opaque failure on the
    # first START_RECORDING. (The SDK may have other ways to find a key, so
    # this is only a warning.)
    print("Warning: DASHSCOPE_API_KEY is not set; DashScope ASR may fail to authenticate.")

app = FastAPI()

# Raw audio bytes received from the client.
# NOTE(review): this is module-global, so every WebSocket connection shares it.
audio_buffer = bytearray()

RECORDING_RAW_FILE = "received_audio.raw"  # processed 16-bit PCM dump
RECORDING_MP3_FILE = "received_audio.mp3"  # ffmpeg-encoded copy of the dump
VOLUME_GAIN = 10.0  # amplification factor applied to every sample
FONT_FILE = "GB2312-16.bin"  # glyph file served for GET_FONT requests
|
||
|
||
class MyRecognitionCallback(RecognitionCallback):
    """Relay DashScope real-time ASR events to a FastAPI WebSocket client.

    Results are pushed to the client by scheduling ``websocket.send_text``
    onto ``loop`` with ``asyncio.run_coroutine_threadsafe``, which is safe to
    call from a thread other than the loop's own. (Presumably the SDK invokes
    these hooks from its own worker thread; confirm against the DashScope docs.)
    """

    def __init__(self, websocket: WebSocket, loop: asyncio.AbstractEventLoop):
        """Remember the client connection and the event loop that owns it.

        Args:
            websocket: Accepted client connection that receives ``ASR:<text>``
                messages.
            loop: Running event loop the websocket belongs to.
        """
        self.websocket = websocket
        self.loop = loop

    def on_open(self) -> None:
        print("ASR Session started")

    def on_close(self) -> None:
        print("ASR Session closed")

    def on_error(self, result) -> None:
        """Log an ASR failure so it does not go unnoticed."""
        print(f"ASR error: {getattr(result, 'message', result)}")

    def on_event(self, result: RecognitionResult) -> None:
        """Forward the recognized sentence text to the client as ``ASR:<text>``."""
        sentence = result.get_sentence()
        if not sentence:
            return

        text = sentence['text']
        print(f"ASR Result: {text}")

        coro = self.websocket.send_text(f"ASR:{text}")
        try:
            future = asyncio.run_coroutine_threadsafe(coro, self.loop)
        except Exception as e:
            # Scheduling failed (e.g. the loop is closed). Close the coroutine
            # so Python does not warn that it was never awaited.
            coro.close()
            print(f"Failed to send ASR result to client: {e}")
            return
        # The send itself runs later on the loop; without this callback any
        # failure there (e.g. client already gone) would be silently dropped.
        future.add_done_callback(self._report_send_result)

    @staticmethod
    def _report_send_result(future) -> None:
        """Log the outcome of a scheduled send if it raised."""
        if future.cancelled():
            return
        exc = future.exception()
        if exc is not None:
            print(f"Failed to send ASR result to client: {exc}")
|
||
|
||
def process_chunk_32_to_16(chunk_bytes, gain=1.0):
    """Convert 32-bit little-endian PCM frames to amplified 16-bit PCM.

    Each 4-byte frame is read as two little-endian int16 halves. Only the
    upper half (bytes 2-3 of the frame, i.e. the high 16 bits of a 32-bit
    little-endian sample) is kept. That value is multiplied by ``gain``,
    truncated toward zero with ``int()``, and clamped to the int16 range so
    loud input saturates instead of wrapping around into audible pops.

    Args:
        chunk_bytes: Bytes-like buffer of 32-bit little-endian frames. A
            trailing partial frame (``len % 4`` bytes) is ignored; streaming
            callers must carry those bytes over to the next chunk themselves.
        gain: Volume multiplier applied to every sample.

    Returns:
        A bytearray of signed 16-bit little-endian samples, two bytes per
        complete input frame.
    """
    # struct.iter_unpack requires a whole number of frames.
    usable = len(chunk_bytes) - len(chunk_bytes) % 4
    samples = [
        max(-32768, min(32767, int(high * gain)))
        for _low, high in struct.iter_unpack('<hh', chunk_bytes[:usable])
    ]
    # One bulk pack instead of a struct.pack call per sample.
    return bytearray(struct.pack(f'<{len(samples)}h', *samples))
|
||
|
||
async def _stop_recognition(recognition) -> bool:
    """Stop a DashScope ``Recognition`` session without blocking the event loop.

    ``Recognition.stop()`` is a synchronous SDK call that may wait for the
    service to finish, so it runs in a worker thread. This keeps other
    connections, and the ASR result sends scheduled onto this loop, moving.

    Returns:
        True if ``stop()`` returned normally. Errors are logged, not raised.
    """
    try:
        await asyncio.to_thread(recognition.stop)
    except Exception as e:
        print(f"Error stopping ASR: {e}")
        return False
    print("DashScope ASR stopped")
    return True


async def _start_recognition(websocket, loop):
    """Create and start a real-time ASR session wired to ``websocket``.

    Returns:
        The started ``Recognition``, or None if it could not be started
        (the error is logged).
    """
    try:
        callback = MyRecognitionCallback(websocket, loop)
        recognition = Recognition(
            model='paraformer-realtime-v2',
            format='pcm',
            sample_rate=16000,
            callback=callback,
        )
        # start() is a synchronous SDK call; keep it off the event loop.
        await asyncio.to_thread(recognition.start)
    except Exception as e:
        print(f"Failed to start ASR: {e}")
        return None
    print("DashScope ASR started")
    return recognition


async def _convert_raw_to_mp3():
    """Encode RECORDING_RAW_FILE (16 kHz mono s16le) to RECORDING_MP3_FILE with ffmpeg.

    This shells out to the ffmpeg CLI rather than an in-process library. All
    failures are logged and swallowed, so a missing ffmpeg never breaks playback.
    """
    cmd = [
        "ffmpeg",
        "-y",            # overwrite the output file
        "-f", "s16le",   # input format: signed 16-bit little endian
        "-ar", "16000",  # input sample rate
        "-ac", "1",      # input channel count
        "-i", RECORDING_RAW_FILE,
        RECORDING_MP3_FILE,
    ]
    print(f"Running command: {' '.join(cmd)}")
    try:
        # Async subprocess so the event loop is not blocked while ffmpeg runs.
        process = await asyncio.create_subprocess_exec(
            *cmd,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE,
        )
        stdout, stderr = await process.communicate()
        if process.returncode != 0:
            raise subprocess.CalledProcessError(process.returncode, cmd, output=stdout, stderr=stderr)
        print(f"Saved MP3 to {RECORDING_MP3_FILE}")
    except subprocess.CalledProcessError as e:
        print(f"Error converting to MP3: {e}")
        # errors="replace": a non-UTF-8 byte in ffmpeg's output must not raise
        # from inside this handler and tear down the connection.
        error_msg = e.stderr.decode(errors="replace") if isinstance(e.stderr, bytes) else str(e.stderr)
        print(f"FFmpeg stderr: {error_msg}")
    except FileNotFoundError:
        print("Error: ffmpeg not found. Please install ffmpeg.")
    except Exception as e:
        print(f"Error converting to MP3: {e}")


async def _send_playback(websocket, audio, chunk_size=4096, delay=0.04):
    """Stream ``audio`` back to the client between START_/STOP_PLAYBACK markers.

    The audio is sent in ``chunk_size`` pieces with a ``delay`` pause after
    each one, so the receiving device is not flooded. At 16 kHz 16-bit
    (32000 B/s) a 4096-byte chunk is ~0.128 s of audio, so the default
    0.04 s pause sends at roughly 3x real time.
    """
    print("Sending audio back...")
    await websocket.send_text("START_PLAYBACK")
    for start in range(0, len(audio), chunk_size):
        await websocket.send_bytes(audio[start:start + chunk_size])
        await asyncio.sleep(delay)
    await websocket.send_text("STOP_PLAYBACK")
    print("Audio sent back finished.")


async def _finish_recording(websocket, processed_audio):
    """Save the processed recording, convert it to MP3 and play it back."""
    print(f"Processed audio size: {len(processed_audio)} bytes (Gain: {VOLUME_GAIN}x)")
    # Dump the raw processed 16-bit PCM.
    with open(RECORDING_RAW_FILE, "wb") as f:
        f.write(processed_audio)
    await _convert_raw_to_mp3()
    await _send_playback(websocket, processed_audio)


def _find_font_file():
    """Locate FONT_FILE next to this script, one level up, or in the CWD.

    Returns the first existing candidate, or FONT_FILE unchanged if none exists.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = (
        os.path.join(script_dir, FONT_FILE),
        os.path.join(script_dir, "..", FONT_FILE),
        FONT_FILE,
    )
    for path in candidates:
        if os.path.exists(path):
            return path
    return FONT_FILE


async def _handle_get_font(websocket, text):
    """Answer ``GET_FONT:0xA1A1`` with ``FONT_DATA:<code>:<64 hex chars>``.

    The glyph offset assumes FONT_FILE stores 32-byte glyphs for GB2312 in
    row-major area/index order, 94 glyphs per area:
    area = high byte - 0xA0, index = low byte - 0xA0, both in 1..94
    (codes 0xA1A1-0xFEFE). Invalid requests are logged and get no reply.
    """
    print(f"Font Request Received: {text}")
    try:
        hex_code = text.split(":")[1]
        code = int(hex_code, 16)
    except (IndexError, ValueError) as e:
        print(f"Error handling GET_FONT: {e}")
        return

    area = (code >> 8) - 0xA0
    index = (code & 0xFF) - 0xA0
    if not (1 <= area <= 94 and 1 <= index <= 94):
        print(f"Invalid GB2312 code: {hex_code} (Area: {area}, Index: {index})")
        return
    offset = ((area - 1) * 94 + (index - 1)) * 32

    # Opened per request for simplicity; cache the file contents if load grows.
    font_path = _find_font_file()
    if not os.path.exists(font_path):
        print(f"Font file not found: {font_path}")
        return

    print(f"Reading font from: {font_path} (Offset: {offset})")
    try:
        with open(font_path, "rb") as f:
            f.seek(offset)
            font_data = f.read(32)
    except OSError as e:
        print(f"Error handling GET_FONT: {e}")
        return

    if len(font_data) != 32:
        print(f"Error: Read {len(font_data)} bytes for font data (expected 32)")
        return

    response = f"FONT_DATA:{hex_code}:{font_data.hex()}"
    print(f"Sending Font Response: {response[:30]}...")
    await websocket.send_text(response)


@app.websocket("/ws/audio")
async def websocket_endpoint(websocket: WebSocket):
    """Handle one audio client: recording, live ASR, playback and font lookups.

    Text commands:
        START_RECORDING: reset buffers and start a DashScope ASR session.
        STOP_RECORDING: stop ASR, save and convert the recording, and stream it back.
        GET_FONT:<hex>: reply with a 32-byte GB2312 glyph.
    Binary messages are 32-bit little-endian PCM frames. They are converted
    to amplified 16-bit PCM, accumulated, and forwarded to ASR while a
    session is active.

    Any ASR session still open when the connection ends, for whatever reason,
    is stopped.
    """
    global audio_buffer
    await websocket.accept()
    print("Client connected")

    recognition = None
    processed_buffer = bytearray()
    # Bytes of an incomplete 4-byte frame carried over between messages,
    # so a split frame does not misalign the rest of the stream.
    pending = bytearray()
    loop = asyncio.get_running_loop()

    try:
        while True:
            try:
                message = await websocket.receive()
            except RuntimeError as e:
                if "Cannot call \"receive\" once a disconnect message has been received" in str(e):
                    print("Client disconnected (RuntimeError caught)")
                    break
                raise

            if message.get("type") == "websocket.disconnect":
                print("Client disconnected")
                break

            # ASGI servers may include both keys with one set to None.
            text = message.get("text")
            data = message.get("bytes")

            if text is not None:
                print(f"Received text: {text}")

                if text == "START_RECORDING":
                    print("Start recording...")
                    audio_buffer = bytearray()  # clear the buffers
                    processed_buffer = bytearray()
                    pending = bytearray()
                    if recognition is not None:
                        # A previous session was never stopped; close it
                        # instead of leaking it.
                        await _stop_recognition(recognition)
                    recognition = await _start_recognition(websocket, loop)

                elif text == "STOP_RECORDING":
                    print(f"Stop recording. Total raw bytes: {len(audio_buffer)}")
                    if recognition is not None:
                        await _stop_recognition(recognition)
                        recognition = None
                    await _finish_recording(websocket, processed_buffer)

                elif text.startswith("GET_FONT:"):
                    await _handle_get_font(websocket, text)

            elif data is not None:
                audio_buffer.extend(data)

                # Convert only whole frames; keep any remainder for next time.
                pending.extend(data)
                usable = len(pending) - len(pending) % 4
                pcm_chunk = process_chunk_32_to_16(pending[:usable], VOLUME_GAIN)
                del pending[:usable]
                processed_buffer.extend(pcm_chunk)

                if recognition is not None and pcm_chunk:
                    try:
                        recognition.send_audio_frame(pcm_chunk)
                    except Exception as e:
                        print(f"Error sending audio frame to ASR: {e}")

    except WebSocketDisconnect:
        print("Client disconnected")
    except Exception as e:
        print(f"Error: {e}")
    finally:
        # Runs on every exit path, including the disconnect `break`s above.
        if recognition is not None:
            await _stop_recognition(recognition)
|
||
|
||
def get_local_ip(probe_addr=("8.8.8.8", 80)):
    """Best-effort guess of this machine's LAN IPv4 address.

    Connecting a UDP socket sends no packets; it only asks the OS which local
    interface would route toward ``probe_addr``. If that fails (e.g. no route),
    fall back to resolving the hostname, and finally to 127.0.0.1. A failure
    here therefore never prevents the server from starting.
    """
    import socket

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(probe_addr)
            return probe.getsockname()[0]
    except OSError:
        pass  # no usable route; try hostname resolution next

    try:
        return socket.gethostbyname(socket.gethostname())
    except OSError:
        return "127.0.0.1"


if __name__ == "__main__":
    # Print a LAN address the ESP32 can use to connect to this server.
    local_ip = get_local_ip()
    print(f"Server running on ws://{local_ip}:8000/ws/audio")

    uvicorn.run(app, host="0.0.0.0", port=8000)
|