Update doubao model to seedream-4.0
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 4s

This commit is contained in:
jeremygan2021
2026-03-20 17:53:23 +08:00
parent c9550f8a0d
commit 5b91e90d45

View File

@@ -10,6 +10,7 @@ from dashscope import ImageSynthesis, Generation
load_dotenv() load_dotenv()
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY") dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
class ImageGenerator: class ImageGenerator:
def __init__(self, provider="dashscope", model=None): def __init__(self, provider="dashscope", model=None):
self.provider = provider self.provider = provider
@@ -19,7 +20,7 @@ class ImageGenerator:
if provider == "doubao": if provider == "doubao":
self.api_key = os.getenv("volcengine_API_KEY") self.api_key = os.getenv("volcengine_API_KEY")
if not self.model: if not self.model:
self.model = "doubao-seedream-5-0-260128" # Default model from user input self.model = "doubao-seedream-4.0"
elif provider == "dashscope": elif provider == "dashscope":
self.api_key = os.getenv("DASHSCOPE_API_KEY") self.api_key = os.getenv("DASHSCOPE_API_KEY")
if not self.model: if not self.model:
@@ -53,17 +54,20 @@ example:
# Currently using Qwen-Turbo for all providers for prompt optimization # Currently using Qwen-Turbo for all providers for prompt optimization
# You can also decouple this if needed # You can also decouple this if needed
response = Generation.call( response = Generation.call(
model='qwen-plus', model="qwen-plus",
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:', prompt=f"{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:",
max_tokens=200, max_tokens=200,
temperature=0.8 temperature=0.8,
) )
if response.status_code == 200: if response.status_code == 200:
if hasattr(response, 'output') and response.output and \ if (
hasattr(response.output, 'choices') and response.output.choices and \ hasattr(response, "output")
len(response.output.choices) > 0: and response.output
and hasattr(response.output, "choices")
and response.output.choices
and len(response.output.choices) > 0
):
optimized = response.output.choices[0].message.content.strip() optimized = response.output.choices[0].message.content.strip()
print(f"Optimized prompt: {optimized}") print(f"Optimized prompt: {optimized}")
@@ -71,7 +75,11 @@ example:
progress_callback(30, f"提示词优化完成: {optimized[:50]}...") progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
return optimized return optimized
elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'): elif (
hasattr(response, "output")
and response.output
and hasattr(response.output, "text")
):
optimized = response.output.text.strip() optimized = response.output.text.strip()
print(f"Optimized prompt (direct text): {optimized}") print(f"Optimized prompt (direct text): {optimized}")
if progress_callback: if progress_callback:
@@ -83,7 +91,9 @@ example:
progress_callback(0, "提示词优化响应格式错误") progress_callback(0, "提示词优化响应格式错误")
return asr_text return asr_text
else: else:
print(f"Prompt optimization failed: {response.code} - {response.message}") print(
f"Prompt optimization failed: {response.code} - {response.message}"
)
if progress_callback: if progress_callback:
progress_callback(0, f"提示词优化失败: {response.message}") progress_callback(0, f"提示词优化失败: {response.message}")
return asr_text return asr_text
@@ -111,9 +121,7 @@ example:
try: try:
response = ImageSynthesis.call( response = ImageSynthesis.call(
model=self.model, model=self.model, prompt=prompt, size="1280*720"
prompt=prompt,
size='1280*720'
) )
if response.status_code == 200: if response.status_code == 200:
@@ -121,32 +129,36 @@ example:
print("Error: response.output is None") print("Error: response.output is None")
return None return None
task_status = response.output.get('task_status') task_status = response.output.get("task_status")
if task_status == 'PENDING' or task_status == 'RUNNING': if task_status == "PENDING" or task_status == "RUNNING":
print("Waiting for image generation to complete...") print("Waiting for image generation to complete...")
if progress_callback: if progress_callback:
progress_callback(45, "AI正在生成图片中...") progress_callback(45, "AI正在生成图片中...")
task_id = response.output.get('task_id') task_id = response.output.get("task_id")
max_wait = 120 max_wait = 120
waited = 0 waited = 0
while waited < max_wait: while waited < max_wait:
time.sleep(2) time.sleep(2)
waited += 2 waited += 2
task_result = ImageSynthesis.fetch(task_id) task_result = ImageSynthesis.fetch(task_id)
if task_result.output.task_status == 'SUCCEEDED': if task_result.output.task_status == "SUCCEEDED":
response.output = task_result.output response.output = task_result.output
break break
elif task_result.output.task_status == 'FAILED': elif task_result.output.task_status == "FAILED":
error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error' error_msg = (
task_result.output.message
if hasattr(task_result.output, "message")
else "Unknown error"
)
print(f"Image generation failed: {error_msg}") print(f"Image generation failed: {error_msg}")
if progress_callback: if progress_callback:
progress_callback(35, f"图片生成失败: {error_msg}") progress_callback(35, f"图片生成失败: {error_msg}")
return None return None
if response.output.get('task_status') == 'SUCCEEDED': if response.output.get("task_status") == "SUCCEEDED":
image_url = response.output['results'][0]['url'] image_url = response.output["results"][0]["url"]
print(f"Image generated, url: {image_url}") print(f"Image generated, url: {image_url}")
return image_url return image_url
else: else:
@@ -171,18 +183,18 @@ example:
url = "https://ark.cn-beijing.volces.com/api/v3/images/generations" url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
headers = { headers = {
"Content-Type": "application/json", "Content-Type": "application/json",
"Authorization": f"Bearer {self.api_key}" "Authorization": f"Bearer {self.api_key}",
} }
data = { data = {
"model": self.model, "model": self.model,
"prompt": prompt, "prompt": prompt,
"sequential_image_generation": "disabled", "sequential_image_generation": "disabled",
"response_format": "url", "response_format": "url",
"size": "2K", # Doubao supports different sizes, user example used 2K. But we might want something smaller if possible to save bandwidth/time? "size": "2K", # Doubao supports different sizes, user example used 2K. But we might want something smaller if possible to save bandwidth/time?
# User's curl says "2K". I will stick to it or maybe check docs. # User's curl says "2K". I will stick to it or maybe check docs.
# Actually for thermal printer, we don't need 2K. But let's follow user example. # Actually for thermal printer, we don't need 2K. But let's follow user example.
"stream": False, "stream": False,
"watermark": True "watermark": True,
} }
try: try:
@@ -199,16 +211,18 @@ example:
print(f"Image generated, url: {image_url}") print(f"Image generated, url: {image_url}")
return image_url return image_url
elif "error" in result: elif "error" in result:
error_msg = result["error"].get("message", "Unknown error") error_msg = result["error"].get("message", "Unknown error")
print(f"Doubao API error: {error_msg}") print(f"Doubao API error: {error_msg}")
if progress_callback: if progress_callback:
progress_callback(35, f"图片生成失败: {error_msg}") progress_callback(35, f"图片生成失败: {error_msg}")
return None return None
else: else:
print(f"Unexpected response format: {result}") print(f"Unexpected response format: {result}")
return None return None
else: else:
print(f"Doubao API failed with status {response.status_code}: {response.text}") print(
f"Doubao API failed with status {response.status_code}: {response.text}"
)
if progress_callback: if progress_callback:
progress_callback(35, f"图片生成失败: {response.status_code}") progress_callback(35, f"图片生成失败: {response.status_code}")
return None return None