Files
V2_micropython/websocket_server/image_generator.py
jeremygan2021 64ff8ffbd4
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 20s
t
2026-03-05 21:43:39 +08:00

220 lines
9.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import time
import json
import requests
from dotenv import load_dotenv
import dashscope
from dashscope import ImageSynthesis, Generation
# Populate os.environ from a local .env file (if present), then register the
# DashScope key on the SDK module itself. The DashScope calls in this module
# (Generation.call / ImageSynthesis.call / ImageSynthesis.fetch) pass no
# api_key argument, so they rely on this module-level setting.
load_dotenv()
dashscope.api_key = os.environ.get("DASHSCOPE_API_KEY")
class ImageGenerator:
    """Turn speech-recognition text into an image URL via DashScope or Doubao.

    Methods report progress through an optional ``progress_callback(percent,
    message)``. Failures are signalled by a fallback return value
    (``optimize_prompt`` returns the raw ASR text, ``generate_image`` returns
    ``None``) rather than by raising; the one exception is an unknown provider,
    which makes ``generate_image`` raise ``ValueError``.
    """

    # Polling budget for asynchronous DashScope image tasks, in seconds.
    DASHSCOPE_POLL_INTERVAL = 2
    DASHSCOPE_MAX_WAIT = 120
    # DashScope task states after which no result will ever arrive.
    _DASHSCOPE_FAILED_STATES = ("FAILED", "CANCELED", "UNKNOWN")

    def __init__(self, provider="dashscope", model=None):
        """Configure which backend and model to use.

        Args:
            provider: ``"dashscope"`` or ``"doubao"``. Other values are accepted
                here, but ``generate_image`` will raise ``ValueError`` for them.
            model: Image model name. When falsy, a provider-specific default is used.
        """
        self.provider = provider
        self.model = model
        self.api_key = None
        if provider == "doubao":
            # Sent as a Bearer token to the Volcengine Ark REST endpoint.
            self.api_key = os.getenv("volcengine_API_KEY")
            if not self.model:
                self.model = "doubao-seedream-5-0-260128"  # Default model from user input
        elif provider == "dashscope":
            # Stored for reference only: the DashScope SDK calls below pass no
            # api_key and rely on the module-level dashscope.api_key instead.
            self.api_key = os.getenv("DASHSCOPE_API_KEY")
            if not self.model:
                self.model = "wanx2.0-t2i-turbo"

    @staticmethod
    def _notify(progress_callback, percent, message):
        """Forward a progress update to *progress_callback* when one was supplied."""
        if progress_callback:
            progress_callback(percent, message)

    @staticmethod
    def _report_failure(progress_callback, error_msg):
        """Log an image-generation failure and surface it at the 35% progress step."""
        print(f"Image generation failed: {error_msg}")
        ImageGenerator._notify(progress_callback, 35, f"图片生成失败: {error_msg}")

    def optimize_prompt(self, asr_text, progress_callback=None):
        """Use an LLM (qwen-turbo) to rewrite ASR text into a line-art image prompt.

        Args:
            asr_text: Raw speech-recognition text from the user.
            progress_callback: Optional ``callable(percent, message)``.

        Returns:
            The optimized prompt string, or ``asr_text`` unchanged when the LLM
            call fails, returns an unexpected shape, or returns empty text.
        """
        print(f"Optimizing prompt for: {asr_text}")
        self._notify(progress_callback, 0, "正在准备优化提示词...")
        # NOTE(review): rule 3 forbids "黑白线条描述" (black-and-white line
        # descriptions), which contradicts rules 1 and 5. "彩色" (colour) was
        # probably intended — confirm with the author before changing the prompt.
        system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。
关键要求:
1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。
2. 画面必须清晰、线条粗壮,适合低分辨率热敏打印机打印。
3. 绝对不要有复杂的阴影、渐变、黑白线条描述。
4. 背景必须是纯白 (White background)。
5. 提示词内容请使用英文描述,因为绘图模型对英文理解更好,但在描述中强调 "black and white line art", "simple lines", "vector style"
6. 尺寸比例遵循宽48mm:高30mm (约 1.6:1)。
7. 直接输出优化后的提示词,不要包含任何解释。
如果用户要求输入文字,则用```把文字包裹起来,文字是中文
"black and white line art Chinese:```中国人```"
"""
        try:
            self._notify(progress_callback, 10, "正在调用AI优化提示词...")
            # Currently using Qwen-Turbo for all providers for prompt optimization
            # You can also decouple this if needed
            response = Generation.call(
                model='qwen-turbo',
                prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
                max_tokens=200,
                temperature=0.8
            )
            if response.status_code != 200:
                print(f"Prompt optimization failed: {response.code} - {response.message}")
                self._notify(progress_callback, 0, f"提示词优化失败: {response.message}")
                return asr_text

            output = getattr(response, 'output', None)
            choices = getattr(output, 'choices', None) if output else None
            text = getattr(output, 'text', None) if output else None
            if choices:
                # "message"-style result: take the first choice's content.
                optimized = choices[0].message.content.strip()
                label = "Optimized prompt"
            elif text is not None:
                # "text"-style result: the completion is on output.text.
                optimized = text.strip()
                label = "Optimized prompt (direct text)"
            else:
                print(f"Prompt optimization response format error: {response}")
                self._notify(progress_callback, 0, "提示词优化响应格式错误")
                return asr_text

            if not optimized:
                # An empty completion would become an empty image prompt; keep the ASR text.
                print("Prompt optimization returned empty text, falling back to ASR text")
                self._notify(progress_callback, 0, "提示词优化结果为空")
                return asr_text

            print(f"{label}: {optimized}")
            self._notify(progress_callback, 30, f"提示词优化完成: {optimized[:50]}...")
            return optimized
        except Exception as e:
            # Best-effort step: any failure falls back to the raw ASR text.
            print(f"Error optimizing prompt: {e}")
            self._notify(progress_callback, 0, f"提示词优化出错: {str(e)}")
            return asr_text

    def generate_image(self, prompt, progress_callback=None):
        """Generate an image with the configured provider.

        Returns:
            The generated image URL, or ``None`` on failure.

        Raises:
            ValueError: If ``self.provider`` is not ``"dashscope"`` or ``"doubao"``.
        """
        if self.provider == "dashscope":
            return self._generate_dashscope(prompt, progress_callback)
        elif self.provider == "doubao":
            return self._generate_doubao(prompt, progress_callback)
        else:
            raise ValueError(f"Unknown provider: {self.provider}")

    def _wait_for_dashscope_task(self, task_id, progress_callback=None):
        """Poll an async DashScope task; return its output on success, else report and return None."""
        print("Waiting for image generation to complete...")
        self._notify(progress_callback, 45, "AI正在生成图片中...")
        # Wall-clock deadline, so slow fetch calls count against the budget too.
        deadline = time.monotonic() + self.DASHSCOPE_MAX_WAIT
        while time.monotonic() < deadline:
            time.sleep(self.DASHSCOPE_POLL_INTERVAL)
            task_result = ImageSynthesis.fetch(task_id)
            output = getattr(task_result, 'output', None)
            status = output.get('task_status') if output else None
            if status == 'SUCCEEDED':
                return output
            if status in self._DASHSCOPE_FAILED_STATES:
                self._report_failure(progress_callback, output.get('message') or 'Unknown error')
                return None
            # PENDING/RUNNING, or a fetch that returned no output: keep polling.
        self._report_failure(progress_callback, f"timed out after {self.DASHSCOPE_MAX_WAIT}s")
        return None

    def _generate_dashscope(self, prompt, progress_callback=None):
        """Generate a 1280*720 image via DashScope ImageSynthesis; return its URL or None."""
        print(f"Generating image with DashScope for prompt: {prompt}")
        self._notify(progress_callback, 35, "正在请求DashScope生成图片...")
        try:
            response = ImageSynthesis.call(
                model=self.model,
                prompt=prompt,
                size='1280*720'
            )
            if response.status_code != 200:
                self._report_failure(progress_callback, f"{response.code} - {response.message}")
                return None

            output = response.output
            if not output:
                print("Error: response.output is None")
                self._notify(progress_callback, 35, "图片生成失败: empty response output")
                return None

            # The call may hand back a task that is still in progress; wait for it.
            if output.get('task_status') in ('PENDING', 'RUNNING'):
                output = self._wait_for_dashscope_task(output.get('task_id'), progress_callback)
                if output is None:
                    return None

            status = output.get('task_status')
            if status != 'SUCCEEDED':
                # Previously this fell through and returned None without any report.
                self._report_failure(progress_callback, output.get('message') or f"task status: {status}")
                return None

            results = output.get('results') or []
            image_url = results[0].get('url') if results else None
            if not image_url:
                detail = results[0].get('message') if results else None
                self._report_failure(progress_callback, detail or 'No image URL in response')
                return None

            print(f"Image generated, url: {image_url}")
            return image_url
        except Exception as e:
            print(f"Error generating image: {e}")
            self._notify(progress_callback, 35, f"图片生成出错: {str(e)}")
            return None

    @staticmethod
    def _doubao_error_detail(response):
        """Best-effort extraction of ``error.message`` from an Ark error body; None if absent."""
        try:
            body = response.json()
        except ValueError:
            return None
        error = body.get('error') if isinstance(body, dict) else None
        return error.get('message') if isinstance(error, dict) else None

    def _generate_doubao(self, prompt, progress_callback=None):
        """Generate an image via the Volcengine Ark (Doubao) REST API; return its URL or None."""
        print(f"Generating image with Doubao for prompt: {prompt}")
        self._notify(progress_callback, 35, "正在请求豆包生成图片...")
        if not self.api_key:
            # Fail fast instead of sending "Bearer None" and waiting for a 401.
            error_msg = "volcengine_API_KEY is not set"
            print(f"Doubao API error: {error_msg}")
            self._notify(progress_callback, 35, f"图片生成失败: {error_msg}")
            return None

        url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {self.api_key}"
        }
        data = {
            "model": self.model,
            "prompt": prompt,
            "sequential_image_generation": "disabled",
            "response_format": "url",
            # Follows the user's example request. A thermal printer does not need
            # 2K, so a smaller size could save bandwidth and time if supported.
            "size": "2K",
            "stream": False,
            "watermark": True
        }
        try:
            response = requests.post(url, headers=headers, json=data, timeout=60)
            if response.status_code != 200:
                print(f"Doubao API failed with status {response.status_code}: {response.text}")
                detail = self._doubao_error_detail(response)
                suffix = f" - {detail}" if detail else ""
                self._notify(progress_callback, 35, f"图片生成失败: {response.status_code}{suffix}")
                return None

            result = response.json()
            # Assumed OpenAI-style shape: {"data": [{"url": "..."}]} — the
            # user's example did not show the response, so verify against Ark docs.
            items = result.get("data") or []
            if items and "url" in items[0]:
                image_url = items[0]["url"]
                print(f"Image generated, url: {image_url}")
                return image_url
            if "error" in result:
                error_msg = result["error"].get("message", "Unknown error")
                print(f"Doubao API error: {error_msg}")
                self._notify(progress_callback, 35, f"图片生成失败: {error_msg}")
                return None
            print(f"Unexpected response format: {result}")
            self._notify(progress_callback, 35, "图片生成失败: unexpected response format")
            return None
        except Exception as e:
            print(f"Error calling Doubao API: {e}")
            self._notify(progress_callback, 35, f"图片生成出错: {str(e)}")
            return None