diff --git a/websocket_server/image_generator.py b/websocket_server/image_generator.py index eca0258..923ec31 100644 --- a/websocket_server/image_generator.py +++ b/websocket_server/image_generator.py @@ -10,16 +10,17 @@ from dashscope import ImageSynthesis, Generation load_dotenv() dashscope.api_key = os.getenv("DASHSCOPE_API_KEY") + class ImageGenerator: def __init__(self, provider="dashscope", model=None): self.provider = provider self.model = model self.api_key = None - + if provider == "doubao": self.api_key = os.getenv("volcengine_API_KEY") if not self.model: - self.model = "doubao-seedream-5-0-260128" # Default model from user input + self.model = "doubao-seedream-4.0" elif provider == "dashscope": self.api_key = os.getenv("DASHSCOPE_API_KEY") if not self.model: @@ -28,10 +29,10 @@ class ImageGenerator: def optimize_prompt(self, asr_text, progress_callback=None): """Use LLM to optimize the prompt""" print(f"Optimizing prompt for: {asr_text}") - + if progress_callback: progress_callback(0, "正在准备优化提示词...") - + system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。 关键要求: 1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。 @@ -49,29 +50,36 @@ example: try: if progress_callback: progress_callback(10, "正在调用AI优化提示词...") - + # Currently using Qwen-Turbo for all providers for prompt optimization # You can also decouple this if needed response = Generation.call( - model='qwen-plus', - prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:', + model="qwen-plus", + prompt=f"{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:", max_tokens=200, - temperature=0.8 + temperature=0.8, ) - + if response.status_code == 200: - if hasattr(response, 'output') and response.output and \ - hasattr(response.output, 'choices') and response.output.choices and \ - len(response.output.choices) > 0: - + if ( + hasattr(response, "output") + and response.output + and hasattr(response.output, "choices") + and response.output.choices + and len(response.output.choices) > 0 + ): optimized = response.output.choices[0].message.content.strip() print(f"Optimized prompt: {optimized}") - + if progress_callback: progress_callback(30, f"提示词优化完成: {optimized[:50]}...") - + return optimized - elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'): + elif ( + hasattr(response, "output") + and response.output + and hasattr(response.output, "text") + ): optimized = response.output.text.strip() print(f"Optimized prompt (direct text): {optimized}") if progress_callback: @@ -83,11 +91,13 @@ example: progress_callback(0, "提示词优化响应格式错误") return asr_text else: - print(f"Prompt optimization failed: {response.code} - {response.message}") + print( + f"Prompt optimization failed: {response.code} - {response.message}" + ) if progress_callback: progress_callback(0, f"提示词优化失败: {response.message}") return asr_text - + except Exception as e: print(f"Error optimizing prompt: {e}") if progress_callback: @@ -105,48 +115,50 @@ example: def _generate_dashscope(self, prompt, progress_callback=None): print(f"Generating image with DashScope for prompt: {prompt}") - + if progress_callback: progress_callback(35, "正在请求DashScope生成图片...") - + try: response = ImageSynthesis.call( - model=self.model, - prompt=prompt, - size='1280*720' + model=self.model, prompt=prompt, size="1280*720" ) - + if response.status_code == 200: if not response.output: print("Error: response.output is None") return None - - task_status = response.output.get('task_status') - - if task_status == 'PENDING' or task_status == 'RUNNING': + + task_status = response.output.get("task_status") + + if task_status == "PENDING" or task_status == "RUNNING": print("Waiting for image generation to complete...") if progress_callback: progress_callback(45, "AI正在生成图片中...") - - task_id = response.output.get('task_id') + + task_id = response.output.get("task_id") max_wait = 120 waited = 0 while waited < max_wait: time.sleep(2) waited += 2 task_result = ImageSynthesis.fetch(task_id) - if task_result.output.task_status == 'SUCCEEDED': + if task_result.output.task_status == "SUCCEEDED": response.output = task_result.output break - elif task_result.output.task_status == 'FAILED': - error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error' + elif task_result.output.task_status == "FAILED": + error_msg = ( + task_result.output.message + if hasattr(task_result.output, "message") + else "Unknown error" + ) print(f"Image generation failed: {error_msg}") if progress_callback: progress_callback(35, f"图片生成失败: {error_msg}") return None - - if response.output.get('task_status') == 'SUCCEEDED': - image_url = response.output['results'][0]['url'] + + if response.output.get("task_status") == "SUCCEEDED": + image_url = response.output["results"][0]["url"] print(f"Image generated, url: {image_url}") return image_url else: @@ -155,7 +167,7 @@ example: if progress_callback: progress_callback(35, f"图片生成失败: {error_msg}") return None - + except Exception as e: print(f"Error generating image: {e}") if progress_callback: @@ -164,55 +176,57 @@ example: def _generate_doubao(self, prompt, progress_callback=None): print(f"Generating image with Doubao for prompt: {prompt}") - + if progress_callback: progress_callback(35, "正在请求豆包生成图片...") - + url = "https://ark.cn-beijing.volces.com/api/v3/images/generations" headers = { "Content-Type": "application/json", - "Authorization": f"Bearer {self.api_key}" + "Authorization": f"Bearer {self.api_key}", } data = { "model": self.model, "prompt": prompt, "sequential_image_generation": "disabled", "response_format": "url", - "size": "2K", # Doubao supports different sizes, user example used 2K. But we might want something smaller if possible to save bandwidth/time? - # User's curl says "2K". I will stick to it or maybe check docs. - # Actually for thermal printer, we don't need 2K. But let's follow user example. + "size": "2K", # Doubao supports different sizes, user example used 2K. But we might want something smaller if possible to save bandwidth/time? + # User's curl says "2K". I will stick to it or maybe check docs. + # Actually for thermal printer, we don't need 2K. But let's follow user example. "stream": False, - "watermark": True + "watermark": True, } - + try: response = requests.post(url, headers=headers, json=data, timeout=60) - + if response.status_code == 200: result = response.json() # Check format of result # Typically OpenAI compatible or similar # User example didn't show response format, but usually it's "data": [{"url": "..."}] - + if "data" in result and len(result["data"]) > 0: image_url = result["data"][0]["url"] print(f"Image generated, url: {image_url}") return image_url elif "error" in result: - error_msg = result["error"].get("message", "Unknown error") - print(f"Doubao API error: {error_msg}") - if progress_callback: - progress_callback(35, f"图片生成失败: {error_msg}") - return None + error_msg = result["error"].get("message", "Unknown error") + print(f"Doubao API error: {error_msg}") + if progress_callback: + progress_callback(35, f"图片生成失败: {error_msg}") + return None else: print(f"Unexpected response format: {result}") return None else: - print(f"Doubao API failed with status {response.status_code}: {response.text}") + print( + f"Doubao API failed with status {response.status_code}: {response.text}" + ) if progress_callback: progress_callback(35, f"图片生成失败: {response.status_code}") return None - + except Exception as e: print(f"Error calling Doubao API: {e}") if progress_callback: