t
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 20s

This commit is contained in:
jeremygan2021
2026-03-05 21:43:39 +08:00
parent b79d45cf34
commit 64ff8ffbd4
4 changed files with 311 additions and 200 deletions

View File

@@ -0,0 +1,219 @@
import os
import time
import json
import requests
from dotenv import load_dotenv
import dashscope
from dashscope import ImageSynthesis, Generation
# Load environment variables
load_dotenv()
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
class ImageGenerator:
def __init__(self, provider="dashscope", model=None):
self.provider = provider
self.model = model
self.api_key = None
if provider == "doubao":
self.api_key = os.getenv("volcengine_API_KEY")
if not self.model:
self.model = "doubao-seedream-5-0-260128" # Default model from user input
elif provider == "dashscope":
self.api_key = os.getenv("DASHSCOPE_API_KEY")
if not self.model:
self.model = "wanx2.0-t2i-turbo"
def optimize_prompt(self, asr_text, progress_callback=None):
"""Use LLM to optimize the prompt"""
print(f"Optimizing prompt for: {asr_text}")
if progress_callback:
progress_callback(0, "正在准备优化提示词...")
system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。
关键要求:
1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。
2. 画面必须清晰、线条粗壮,适合低分辨率热敏打印机打印。
3. 绝对不要有复杂的阴影、渐变、黑白线条描述。
4. 背景必须是纯白 (White background)。
5. 提示词内容请使用英文描述,因为绘图模型对英文理解更好,但在描述中强调 "black and white line art", "simple lines", "vector style"
6. 尺寸比例遵循宽48mm:高30mm (约 1.6:1)。
7. 直接输出优化后的提示词,不要包含任何解释。
如果用户要求输入文字,则用```把文字包裹起来,文字是中文
"black and white line art Chinese:```中国人```"
"""
try:
if progress_callback:
progress_callback(10, "正在调用AI优化提示词...")
# Currently using Qwen-Turbo for all providers for prompt optimization
# You can also decouple this if needed
response = Generation.call(
model='qwen-turbo',
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
max_tokens=200,
temperature=0.8
)
if response.status_code == 200:
if hasattr(response, 'output') and response.output and \
hasattr(response.output, 'choices') and response.output.choices and \
len(response.output.choices) > 0:
optimized = response.output.choices[0].message.content.strip()
print(f"Optimized prompt: {optimized}")
if progress_callback:
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
return optimized
elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'):
optimized = response.output.text.strip()
print(f"Optimized prompt (direct text): {optimized}")
if progress_callback:
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
return optimized
else:
print(f"Prompt optimization response format error: {response}")
if progress_callback:
progress_callback(0, "提示词优化响应格式错误")
return asr_text
else:
print(f"Prompt optimization failed: {response.code} - {response.message}")
if progress_callback:
progress_callback(0, f"提示词优化失败: {response.message}")
return asr_text
except Exception as e:
print(f"Error optimizing prompt: {e}")
if progress_callback:
progress_callback(0, f"提示词优化出错: {str(e)}")
return asr_text
def generate_image(self, prompt, progress_callback=None):
"""Generate image based on provider"""
if self.provider == "dashscope":
return self._generate_dashscope(prompt, progress_callback)
elif self.provider == "doubao":
return self._generate_doubao(prompt, progress_callback)
else:
raise ValueError(f"Unknown provider: {self.provider}")
def _generate_dashscope(self, prompt, progress_callback=None):
print(f"Generating image with DashScope for prompt: {prompt}")
if progress_callback:
progress_callback(35, "正在请求DashScope生成图片...")
try:
response = ImageSynthesis.call(
model=self.model,
prompt=prompt,
size='1280*720'
)
if response.status_code == 200:
if not response.output:
print("Error: response.output is None")
return None
task_status = response.output.get('task_status')
if task_status == 'PENDING' or task_status == 'RUNNING':
print("Waiting for image generation to complete...")
if progress_callback:
progress_callback(45, "AI正在生成图片中...")
task_id = response.output.get('task_id')
max_wait = 120
waited = 0
while waited < max_wait:
time.sleep(2)
waited += 2
task_result = ImageSynthesis.fetch(task_id)
if task_result.output.task_status == 'SUCCEEDED':
response.output = task_result.output
break
elif task_result.output.task_status == 'FAILED':
error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error'
print(f"Image generation failed: {error_msg}")
if progress_callback:
progress_callback(35, f"图片生成失败: {error_msg}")
return None
if response.output.get('task_status') == 'SUCCEEDED':
image_url = response.output['results'][0]['url']
print(f"Image generated, url: {image_url}")
return image_url
else:
error_msg = f"{response.code} - {response.message}"
print(f"Image generation failed: {error_msg}")
if progress_callback:
progress_callback(35, f"图片生成失败: {error_msg}")
return None
except Exception as e:
print(f"Error generating image: {e}")
if progress_callback:
progress_callback(35, f"图片生成出错: {str(e)}")
return None
def _generate_doubao(self, prompt, progress_callback=None):
print(f"Generating image with Doubao for prompt: {prompt}")
if progress_callback:
progress_callback(35, "正在请求豆包生成图片...")
url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}"
}
data = {
"model": self.model,
"prompt": prompt,
"sequential_image_generation": "disabled",
"response_format": "url",
"size": "2K", # Doubao supports different sizes, user example used 2K. But we might want something smaller if possible to save bandwidth/time?
# User's curl says "2K". I will stick to it or maybe check docs.
# Actually for thermal printer, we don't need 2K. But let's follow user example.
"stream": False,
"watermark": True
}
try:
response = requests.post(url, headers=headers, json=data, timeout=60)
if response.status_code == 200:
result = response.json()
# Check format of result
# Typically OpenAI compatible or similar
# User example didn't show response format, but usually it's "data": [{"url": "..."}]
if "data" in result and len(result["data"]) > 0:
image_url = result["data"][0]["url"]
print(f"Image generated, url: {image_url}")
return image_url
elif "error" in result:
error_msg = result["error"].get("message", "Unknown error")
print(f"Doubao API error: {error_msg}")
if progress_callback:
progress_callback(35, f"图片生成失败: {error_msg}")
return None
else:
print(f"Unexpected response format: {result}")
return None
else:
print(f"Doubao API failed with status {response.status_code}: {response.text}")
if progress_callback:
progress_callback(35, f"图片生成失败: {response.status_code}")
return None
except Exception as e:
print(f"Error calling Doubao API: {e}")
if progress_callback:
progress_callback(35, f"图片生成出错: {str(e)}")
return None