From 64ff8ffbd4b6602367c8143eef1184a468b854e5 Mon Sep 17 00:00:00 2001 From: jeremygan2021 Date: Thu, 5 Mar 2026 21:43:39 +0800 Subject: [PATCH] t --- websocket_server/.env | 3 +- websocket_server/convert_img.py | 2 +- websocket_server/image_generator.py | 219 +++++++++++++++++++++ websocket_server/server.py | 287 +++++++++------------------- 4 files changed, 311 insertions(+), 200 deletions(-) create mode 100644 websocket_server/image_generator.py diff --git a/websocket_server/.env b/websocket_server/.env index 4e472f7..35be800 100644 --- a/websocket_server/.env +++ b/websocket_server/.env @@ -1 +1,2 @@ -DASHSCOPE_API_KEY=sk-a294f382488d46a1aa0d7cd8e750729b \ No newline at end of file +DASHSCOPE_API_KEY=sk-a294f382488d46a1aa0d7cd8e750729b、 +volcengine_API_KEY=db1f8b60-0ffc-473c-98da-40daa3a95df8 \ No newline at end of file diff --git a/websocket_server/convert_img.py b/websocket_server/convert_img.py index b62f20c..0b1cfb7 100644 --- a/websocket_server/convert_img.py +++ b/websocket_server/convert_img.py @@ -87,7 +87,7 @@ def image_to_tspl_commands(image_path): # GAP 2 mm, 0 mm cmds.extend(b"GAP 2 mm, 0 mm\r\n") # HOME - cmds.extend(b"HOME\r\n") + # cmds.extend(b"HOME\r\n") # 注释掉 HOME,防止每次打印都自动进纸一张 # 2. BITMAP # BITMAP x, y, width_bytes, height, mode, data diff --git a/websocket_server/image_generator.py b/websocket_server/image_generator.py new file mode 100644 index 0000000..5ce1bd3 --- /dev/null +++ b/websocket_server/image_generator.py @@ -0,0 +1,219 @@ +import os +import time +import json +import requests +from dotenv import load_dotenv +import dashscope +from dashscope import ImageSynthesis, Generation + +# Load environment variables +load_dotenv() +dashscope.api_key = os.getenv("DASHSCOPE_API_KEY") + +class ImageGenerator: + def __init__(self, provider="dashscope", model=None): + self.provider = provider + self.model = model + self.api_key = None + + if provider == "doubao": + self.api_key = os.getenv("volcengine_API_KEY") + if not self.model: + self.model = "doubao-seedream-5-0-260128" # Default model from user input + elif provider == "dashscope": + self.api_key = os.getenv("DASHSCOPE_API_KEY") + if not self.model: + self.model = "wanx2.0-t2i-turbo" + + def optimize_prompt(self, asr_text, progress_callback=None): + """Use LLM to optimize the prompt""" + print(f"Optimizing prompt for: {asr_text}") + + if progress_callback: + progress_callback(0, "正在准备优化提示词...") + + system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。 +关键要求: +1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。 +2. 画面必须清晰、线条粗壮,适合低分辨率热敏打印机打印。 +3. 绝对不要有复杂的阴影、渐变、黑白线条描述。 +4. 背景必须是纯白 (White background)。 +5. 提示词内容请使用英文描述,因为绘图模型对英文理解更好,但在描述中强调 "black and white line art", "simple lines", "vector style"。 +6. 尺寸比例遵循宽48mm:高30mm (约 1.6:1)。 +7. 直接输出优化后的提示词,不要包含任何解释。 +如果用户要求输入文字,则用```把文字包裹起来,文字是中文 +"black and white line art, Chinese:```中国人```" +""" + + try: + if progress_callback: + progress_callback(10, "正在调用AI优化提示词...") + + # Currently using Qwen-Turbo for all providers for prompt optimization + # You can also decouple this if needed + response = Generation.call( + model='qwen-turbo', + prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:', + max_tokens=200, + temperature=0.8 + ) + + if response.status_code == 200: + if hasattr(response, 'output') and response.output and \ + hasattr(response.output, 'choices') and response.output.choices and \ + len(response.output.choices) > 0: + + optimized = response.output.choices[0].message.content.strip() + print(f"Optimized prompt: {optimized}") + + if progress_callback: + progress_callback(30, f"提示词优化完成: {optimized[:50]}...") + + return optimized + elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'): + optimized = response.output.text.strip() + print(f"Optimized prompt (direct text): {optimized}") + if progress_callback: + progress_callback(30, f"提示词优化完成: {optimized[:50]}...") + return optimized + else: + print(f"Prompt optimization response format error: {response}") + if progress_callback: + progress_callback(0, "提示词优化响应格式错误") + return asr_text + else: + print(f"Prompt optimization failed: {response.code} - {response.message}") + if progress_callback: + progress_callback(0, f"提示词优化失败: {response.message}") + return asr_text + + except Exception as e: + print(f"Error optimizing prompt: {e}") + if progress_callback: + progress_callback(0, f"提示词优化出错: {str(e)}") + return asr_text + + def generate_image(self, prompt, progress_callback=None): + """Generate image based on provider""" + if self.provider == "dashscope": + return self._generate_dashscope(prompt, progress_callback) + elif self.provider == "doubao": + return self._generate_doubao(prompt, progress_callback) + else: + raise ValueError(f"Unknown provider: {self.provider}") + + def _generate_dashscope(self, prompt, progress_callback=None): + print(f"Generating image with DashScope for prompt: {prompt}") + + if progress_callback: + progress_callback(35, "正在请求DashScope生成图片...") + + try: + response = ImageSynthesis.call( + model=self.model, + prompt=prompt, + size='1280*720' + ) + + if response.status_code == 200: + if not response.output: + print("Error: response.output is None") + return None + + task_status = response.output.get('task_status') + + if task_status == 'PENDING' or task_status == 'RUNNING': + print("Waiting for image generation to complete...") + if progress_callback: + progress_callback(45, "AI正在生成图片中...") + + task_id = response.output.get('task_id') + max_wait = 120 + waited = 0 + while waited < max_wait: + time.sleep(2) + waited += 2 + task_result = ImageSynthesis.fetch(task_id) + if task_result.output.task_status == 'SUCCEEDED': + response.output = task_result.output + break + elif task_result.output.task_status == 'FAILED': + error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error' + print(f"Image generation failed: {error_msg}") + if progress_callback: + progress_callback(35, f"图片生成失败: {error_msg}") + return None + + if response.output.get('task_status') == 'SUCCEEDED': + image_url = response.output['results'][0]['url'] + print(f"Image generated, url: {image_url}") + return image_url + else: + error_msg = f"{response.code} - {response.message}" + print(f"Image generation failed: {error_msg}") + if progress_callback: + progress_callback(35, f"图片生成失败: {error_msg}") + return None + + except Exception as e: + print(f"Error generating image: {e}") + if progress_callback: + progress_callback(35, f"图片生成出错: {str(e)}") + return None + + def _generate_doubao(self, prompt, progress_callback=None): + print(f"Generating image with Doubao for prompt: {prompt}") + + if progress_callback: + progress_callback(35, "正在请求豆包生成图片...") + + url = "https://ark.cn-beijing.volces.com/api/v3/images/generations" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + data = { + "model": self.model, + "prompt": prompt, + "sequential_image_generation": "disabled", + "response_format": "url", + "size": "2K", # Doubao supports different sizes, user example used 2K. But we might want something smaller if possible to save bandwidth/time? + # User's curl says "2K". I will stick to it or maybe check docs. + # Actually for thermal printer, we don't need 2K. But let's follow user example. + "stream": False, + "watermark": True + } + + try: + response = requests.post(url, headers=headers, json=data, timeout=60) + + if response.status_code == 200: + result = response.json() + # Check format of result + # Typically OpenAI compatible or similar + # User example didn't show response format, but usually it's "data": [{"url": "..."}] + + if "data" in result and len(result["data"]) > 0: + image_url = result["data"][0]["url"] + print(f"Image generated, url: {image_url}") + return image_url + elif "error" in result: + error_msg = result["error"].get("message", "Unknown error") + print(f"Doubao API error: {error_msg}") + if progress_callback: + progress_callback(35, f"图片生成失败: {error_msg}") + return None + else: + print(f"Unexpected response format: {result}") + return None + else: + print(f"Doubao API failed with status {response.status_code}: {response.text}") + if progress_callback: + progress_callback(35, f"图片生成失败: {response.status_code}") + return None + + except Exception as e: + print(f"Error calling Doubao API: {e}") + if progress_callback: + progress_callback(35, f"图片生成出错: {str(e)}") + return None diff --git a/websocket_server/server.py b/websocket_server/server.py index c274994..1fa47ed 100644 --- a/websocket_server/server.py +++ b/websocket_server/server.py @@ -11,18 +11,23 @@ import json from dotenv import load_dotenv import dashscope from dashscope.audio.asr import Recognition, RecognitionCallback, RecognitionResult -from dashscope import ImageSynthesis -from dashscope import Generation +# from dashscope import ImageSynthesis +# from dashscope import Generation import sys # import os # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import convert_img +from image_generator import ImageGenerator # 加载环境变量 load_dotenv() dashscope.api_key = os.getenv("DASHSCOPE_API_KEY") +# Initialize image generator +# provider="doubao" or "dashscope" +image_generator = ImageGenerator(provider="doubao") + app = FastAPI() # 字体文件配置 @@ -526,82 +531,11 @@ def process_chunk_32_to_16(chunk_bytes, gain=1.0): def optimize_prompt(asr_text, progress_callback=None): """使用大模型优化提示词""" - print(f"Optimizing prompt for: {asr_text}") - - if progress_callback: - progress_callback(0, "正在准备优化提示词...") - - system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。 -关键要求: -1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。 -2. 画面必须清晰、线条粗壮,适合低分辨率热敏打印机打印。 -3. 绝对不要有复杂的阴影、渐变、黑白线条描述。 -4. 背景必须是纯白 (White background)。 -5. 提示词内容请使用英文描述,因为绘图模型对英文理解更好,但在描述中强调 "black and white line art", "simple lines", "vector style"。 -6. 尺寸比例遵循宽48mm:高30mm (约 1.6:1)。 -7. 直接输出优化后的提示词,不要包含任何解释。 -如果用户要求输入文字,则用双引号把文字包裹起来,文字是中文""" - - try: - if progress_callback: - progress_callback(10, "正在调用AI优化提示词...") - print(f"Calling AI with prompt: {system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:") - - response = Generation.call( - model='qwen-turbo', - prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:', - max_tokens=200, - temperature=0.8 - ) - - if response.status_code == 200: - if hasattr(response, 'output') and response.output and \ - hasattr(response.output, 'choices') and response.output.choices and \ - len(response.output.choices) > 0: - - optimized = response.output.choices[0].message.content.strip() - print(f"Optimized prompt: {optimized}") - - if progress_callback: - progress_callback(30, f"提示词优化完成: {optimized[:50]}...") - - return optimized - elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'): - # Handle case where API returns text directly instead of choices - optimized = response.output.text.strip() - print(f"Optimized prompt (direct text): {optimized}") - - if progress_callback: - progress_callback(30, f"提示词优化完成: {optimized[:50]}...") - - return optimized - else: - print(f"Prompt optimization response format error: {response}") - if progress_callback: - progress_callback(0, "提示词优化响应格式错误") - return asr_text - else: - print(f"Prompt optimization failed: {response.code} - {response.message}") - if progress_callback: - progress_callback(0, f"提示词优化失败: {response.message}") - return asr_text - - except Exception as e: - print(f"Error optimizing prompt: {e}") - if progress_callback: - progress_callback(0, f"提示词优化出错: {str(e)}") - return asr_text + return image_generator.optimize_prompt(asr_text, progress_callback) def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2): - """调用万相文生图API生成图片 - - Args: - prompt: 图像生成提示词 - progress_callback: 进度回调函数 (progress, message) - retry_count: 当前重试次数 - max_retries: 最大重试次数 - """ + """调用AI生图API生成图片""" print(f"Generating image for prompt: {prompt}") if progress_callback: @@ -614,139 +548,96 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2) progress_callback(0, "提示词为空") return None - response = ImageSynthesis.call( - model='wanx2.0-t2i-turbo', - prompt=prompt, - size='1280*720' - ) + # Call the generator + image_url = image_generator.generate_image(prompt, progress_callback) - if response.status_code == 200: - if not response.output: - print("Error: response.output is None") - if progress_callback: - progress_callback(0, "API响应无效") - return None - - task_status = response.output.get('task_status') - - if task_status == 'PENDING' or task_status == 'RUNNING': - print("Waiting for image generation to complete...") - if progress_callback: - progress_callback(45, "AI正在生成图片中...") - - import time - task_id = response.output.get('task_id') - max_wait = 120 - waited = 0 - while waited < max_wait: - time.sleep(2) - waited += 2 - task_result = ImageSynthesis.fetch(task_id) - if task_result.output.task_status == 'SUCCEEDED': - response.output = task_result.output - break - elif task_result.output.task_status == 'FAILED': - error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error' - print(f"Image generation failed: {error_msg}") - if progress_callback: - progress_callback(35, f"图片生成失败: {error_msg}") - return None - - if response.output.get('task_status') == 'SUCCEEDED': - image_url = response.output['results'][0]['url'] - print(f"Image generated, downloading from: {image_url}") - - if progress_callback: - progress_callback(70, "正在下载生成的图片...") - - import urllib.request - urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE) - print(f"Image saved to {GENERATED_IMAGE_FILE}") - - # 保存一份到 output_images 目录 - output_path = get_output_path() - import shutil - shutil.copy(GENERATED_IMAGE_FILE, output_path) - print(f"Image also saved to {output_path}") - - if progress_callback: - progress_callback(80, "正在处理图片...") - - # 缩放图片并转换为RGB565格式 - try: - from PIL import Image - img = Image.open(GENERATED_IMAGE_FILE) - - # 缩小到THUMB_SIZE x THUMB_SIZE - img = img.resize((THUMB_SIZE, THUMB_SIZE), Image.LANCZOS) - - # 转换为RGB565格式的原始数据 - # 每个像素2字节 (R5 G6 B5) - rgb565_data = bytearray() - - for y in range(THUMB_SIZE): - for x in range(THUMB_SIZE): - r, g, b = img.getpixel((x, y))[:3] - - # 转换为RGB565 - r5 = (r >> 3) & 0x1F - g6 = (g >> 2) & 0x3F - b5 = (b >> 3) & 0x1F - - # Pack as Big Endian (>H) which is standard for SPI displays - # RGB565: Red(5) Green(6) Blue(5) - rgb565 = (r5 << 11) | (g6 << 5) | b5 - rgb565_data.extend(struct.pack('>H', rgb565)) - - # 保存为.bin文件 - with open(GENERATED_THUMB_FILE, 'wb') as f: - f.write(rgb565_data) - - print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes") - - if progress_callback: - progress_callback(100, "图片生成完成!") - - return GENERATED_THUMB_FILE - - except ImportError: - print("PIL not available, sending original image") - if progress_callback: - progress_callback(100, "图片生成完成!(原始格式)") - return GENERATED_IMAGE_FILE - except Exception as e: - print(f"Error processing image: {e}") - if progress_callback: - progress_callback(80, f"图片处理出错: {str(e)}") - return GENERATED_IMAGE_FILE - else: - error_msg = f"{response.code} - {response.message}" - print(f"Image generation failed: {error_msg}") - - # 重试机制 + if not image_url: + # Retry logic if retry_count < max_retries: print(f"Retrying... ({retry_count + 1}/{max_retries})") if progress_callback: - progress_callback(35, f"图片生成失败,正在重试 ({retry_count + 1}/{max_retries})...") + progress_callback(35, f"生成失败,正在重试 ({retry_count + 1}/{max_retries})...") return generate_image(prompt, progress_callback, retry_count + 1, max_retries) else: - if progress_callback: - progress_callback(35, f"图片生成失败: {error_msg}") return None + + # Download and process + print(f"Image generated, downloading from: {image_url}") + if progress_callback: + progress_callback(70, "正在下载生成的图片...") + + import urllib.request + try: + urllib.request.urlretrieve(image_url, GENERATED_IMAGE_FILE) + print(f"Image saved to {GENERATED_IMAGE_FILE}") + except Exception as e: + print(f"Download error: {e}") + if progress_callback: + progress_callback(35, f"下载失败: {e}") + return None + + # Save to output dir + output_path = get_output_path() + import shutil + shutil.copy(GENERATED_IMAGE_FILE, output_path) + print(f"Image also saved to {output_path}") + + if progress_callback: + progress_callback(80, "正在处理图片...") + + # Resize and convert to RGB565 (Reuse existing logic) + try: + from PIL import Image + img = Image.open(GENERATED_IMAGE_FILE) + + # 缩小到THUMB_SIZE x THUMB_SIZE + img = img.resize((THUMB_SIZE, THUMB_SIZE), Image.LANCZOS) + + # 转换为RGB565格式的原始数据 + # 每个像素2字节 (R5 G6 B5) + rgb565_data = bytearray() + + for y in range(THUMB_SIZE): + for x in range(THUMB_SIZE): + r, g, b = img.getpixel((x, y))[:3] + + # 转换为RGB565 + r5 = (r >> 3) & 0x1F + g6 = (g >> 2) & 0x3F + b5 = (b >> 3) & 0x1F + + # Pack as Big Endian (>H) which is standard for SPI displays + # RGB565: Red(5) Green(6) Blue(5) + rgb565 = (r5 << 11) | (g6 << 5) | b5 + rgb565_data.extend(struct.pack('>H', rgb565)) + + # 保存为.bin文件 + with open(GENERATED_THUMB_FILE, 'wb') as f: + f.write(rgb565_data) + + print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes") + + if progress_callback: + progress_callback(100, "图片生成完成!") + + return GENERATED_THUMB_FILE + + except ImportError: + print("PIL not available, sending original image") + if progress_callback: + progress_callback(100, "图片生成完成!(原始格式)") + return GENERATED_IMAGE_FILE + except Exception as e: + print(f"Error processing image: {e}") + if progress_callback: + progress_callback(80, f"图片处理出错: {str(e)}") + return GENERATED_IMAGE_FILE except Exception as e: - print(f"Error generating image: {e}") - - # 重试机制 + print(f"Error in generate_image: {e}") if retry_count < max_retries: - print(f"Retrying after error... ({retry_count + 1}/{max_retries})") - if progress_callback: - progress_callback(35, f"生成出错,正在重试 ({retry_count + 1}/{max_retries})...") - return generate_image(prompt, progress_callback, retry_count + 1, max_retries) - else: - if progress_callback: - progress_callback(35, f"图片生成出错: {str(e)}") - return None + return generate_image(prompt, progress_callback, retry_count + 1, max_retries) + return None @app.websocket("/ws/audio") async def websocket_endpoint(websocket: WebSocket):