Update doubao model to seedream-4.0
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 4s
All checks were successful
Deploy WebSocket Server / deploy (push) Successful in 4s
This commit is contained in:
@@ -10,16 +10,17 @@ from dashscope import ImageSynthesis, Generation
|
|||||||
load_dotenv()
|
load_dotenv()
|
||||||
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
|
dashscope.api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||||
|
|
||||||
|
|
||||||
class ImageGenerator:
|
class ImageGenerator:
|
||||||
def __init__(self, provider="dashscope", model=None):
|
def __init__(self, provider="dashscope", model=None):
|
||||||
self.provider = provider
|
self.provider = provider
|
||||||
self.model = model
|
self.model = model
|
||||||
self.api_key = None
|
self.api_key = None
|
||||||
|
|
||||||
if provider == "doubao":
|
if provider == "doubao":
|
||||||
self.api_key = os.getenv("volcengine_API_KEY")
|
self.api_key = os.getenv("volcengine_API_KEY")
|
||||||
if not self.model:
|
if not self.model:
|
||||||
self.model = "doubao-seedream-5-0-260128" # Default model from user input
|
self.model = "doubao-seedream-4.0"
|
||||||
elif provider == "dashscope":
|
elif provider == "dashscope":
|
||||||
self.api_key = os.getenv("DASHSCOPE_API_KEY")
|
self.api_key = os.getenv("DASHSCOPE_API_KEY")
|
||||||
if not self.model:
|
if not self.model:
|
||||||
@@ -28,10 +29,10 @@ class ImageGenerator:
|
|||||||
def optimize_prompt(self, asr_text, progress_callback=None):
|
def optimize_prompt(self, asr_text, progress_callback=None):
|
||||||
"""Use LLM to optimize the prompt"""
|
"""Use LLM to optimize the prompt"""
|
||||||
print(f"Optimizing prompt for: {asr_text}")
|
print(f"Optimizing prompt for: {asr_text}")
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(0, "正在准备优化提示词...")
|
progress_callback(0, "正在准备优化提示词...")
|
||||||
|
|
||||||
system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。
|
system_prompt = """你是一个AI图像提示词优化专家。你的任务是将用户的语音识别结果转化为适合生成"黑白线稿"的提示词。
|
||||||
关键要求:
|
关键要求:
|
||||||
1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。
|
1. 风格必须是:简单的黑白线稿、简笔画、图标风格 (Line art, Sketch, Icon style)。
|
||||||
@@ -49,29 +50,36 @@ example:
|
|||||||
try:
|
try:
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(10, "正在调用AI优化提示词...")
|
progress_callback(10, "正在调用AI优化提示词...")
|
||||||
|
|
||||||
# Currently using Qwen-Turbo for all providers for prompt optimization
|
# Currently using Qwen-Turbo for all providers for prompt optimization
|
||||||
# You can also decouple this if needed
|
# You can also decouple this if needed
|
||||||
response = Generation.call(
|
response = Generation.call(
|
||||||
model='qwen-plus',
|
model="qwen-plus",
|
||||||
prompt=f'{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:',
|
prompt=f"{system_prompt}\n\n用户语音识别结果:{asr_text}\n\n优化后的提示词:",
|
||||||
max_tokens=200,
|
max_tokens=200,
|
||||||
temperature=0.8
|
temperature=0.8,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
if hasattr(response, 'output') and response.output and \
|
if (
|
||||||
hasattr(response.output, 'choices') and response.output.choices and \
|
hasattr(response, "output")
|
||||||
len(response.output.choices) > 0:
|
and response.output
|
||||||
|
and hasattr(response.output, "choices")
|
||||||
|
and response.output.choices
|
||||||
|
and len(response.output.choices) > 0
|
||||||
|
):
|
||||||
optimized = response.output.choices[0].message.content.strip()
|
optimized = response.output.choices[0].message.content.strip()
|
||||||
print(f"Optimized prompt: {optimized}")
|
print(f"Optimized prompt: {optimized}")
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
progress_callback(30, f"提示词优化完成: {optimized[:50]}...")
|
||||||
|
|
||||||
return optimized
|
return optimized
|
||||||
elif hasattr(response, 'output') and response.output and hasattr(response.output, 'text'):
|
elif (
|
||||||
|
hasattr(response, "output")
|
||||||
|
and response.output
|
||||||
|
and hasattr(response.output, "text")
|
||||||
|
):
|
||||||
optimized = response.output.text.strip()
|
optimized = response.output.text.strip()
|
||||||
print(f"Optimized prompt (direct text): {optimized}")
|
print(f"Optimized prompt (direct text): {optimized}")
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
@@ -83,11 +91,13 @@ example:
|
|||||||
progress_callback(0, "提示词优化响应格式错误")
|
progress_callback(0, "提示词优化响应格式错误")
|
||||||
return asr_text
|
return asr_text
|
||||||
else:
|
else:
|
||||||
print(f"Prompt optimization failed: {response.code} - {response.message}")
|
print(
|
||||||
|
f"Prompt optimization failed: {response.code} - {response.message}"
|
||||||
|
)
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(0, f"提示词优化失败: {response.message}")
|
progress_callback(0, f"提示词优化失败: {response.message}")
|
||||||
return asr_text
|
return asr_text
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error optimizing prompt: {e}")
|
print(f"Error optimizing prompt: {e}")
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
@@ -105,48 +115,50 @@ example:
|
|||||||
|
|
||||||
def _generate_dashscope(self, prompt, progress_callback=None):
|
def _generate_dashscope(self, prompt, progress_callback=None):
|
||||||
print(f"Generating image with DashScope for prompt: {prompt}")
|
print(f"Generating image with DashScope for prompt: {prompt}")
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(35, "正在请求DashScope生成图片...")
|
progress_callback(35, "正在请求DashScope生成图片...")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = ImageSynthesis.call(
|
response = ImageSynthesis.call(
|
||||||
model=self.model,
|
model=self.model, prompt=prompt, size="1280*720"
|
||||||
prompt=prompt,
|
|
||||||
size='1280*720'
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
if not response.output:
|
if not response.output:
|
||||||
print("Error: response.output is None")
|
print("Error: response.output is None")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
task_status = response.output.get('task_status')
|
task_status = response.output.get("task_status")
|
||||||
|
|
||||||
if task_status == 'PENDING' or task_status == 'RUNNING':
|
if task_status == "PENDING" or task_status == "RUNNING":
|
||||||
print("Waiting for image generation to complete...")
|
print("Waiting for image generation to complete...")
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(45, "AI正在生成图片中...")
|
progress_callback(45, "AI正在生成图片中...")
|
||||||
|
|
||||||
task_id = response.output.get('task_id')
|
task_id = response.output.get("task_id")
|
||||||
max_wait = 120
|
max_wait = 120
|
||||||
waited = 0
|
waited = 0
|
||||||
while waited < max_wait:
|
while waited < max_wait:
|
||||||
time.sleep(2)
|
time.sleep(2)
|
||||||
waited += 2
|
waited += 2
|
||||||
task_result = ImageSynthesis.fetch(task_id)
|
task_result = ImageSynthesis.fetch(task_id)
|
||||||
if task_result.output.task_status == 'SUCCEEDED':
|
if task_result.output.task_status == "SUCCEEDED":
|
||||||
response.output = task_result.output
|
response.output = task_result.output
|
||||||
break
|
break
|
||||||
elif task_result.output.task_status == 'FAILED':
|
elif task_result.output.task_status == "FAILED":
|
||||||
error_msg = task_result.output.message if hasattr(task_result.output, 'message') else 'Unknown error'
|
error_msg = (
|
||||||
|
task_result.output.message
|
||||||
|
if hasattr(task_result.output, "message")
|
||||||
|
else "Unknown error"
|
||||||
|
)
|
||||||
print(f"Image generation failed: {error_msg}")
|
print(f"Image generation failed: {error_msg}")
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(35, f"图片生成失败: {error_msg}")
|
progress_callback(35, f"图片生成失败: {error_msg}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if response.output.get('task_status') == 'SUCCEEDED':
|
if response.output.get("task_status") == "SUCCEEDED":
|
||||||
image_url = response.output['results'][0]['url']
|
image_url = response.output["results"][0]["url"]
|
||||||
print(f"Image generated, url: {image_url}")
|
print(f"Image generated, url: {image_url}")
|
||||||
return image_url
|
return image_url
|
||||||
else:
|
else:
|
||||||
@@ -155,7 +167,7 @@ example:
|
|||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(35, f"图片生成失败: {error_msg}")
|
progress_callback(35, f"图片生成失败: {error_msg}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error generating image: {e}")
|
print(f"Error generating image: {e}")
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
@@ -164,55 +176,57 @@ example:
|
|||||||
|
|
||||||
def _generate_doubao(self, prompt, progress_callback=None):
|
def _generate_doubao(self, prompt, progress_callback=None):
|
||||||
print(f"Generating image with Doubao for prompt: {prompt}")
|
print(f"Generating image with Doubao for prompt: {prompt}")
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(35, "正在请求豆包生成图片...")
|
progress_callback(35, "正在请求豆包生成图片...")
|
||||||
|
|
||||||
url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
|
url = "https://ark.cn-beijing.volces.com/api/v3/images/generations"
|
||||||
headers = {
|
headers = {
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
}
|
}
|
||||||
data = {
|
data = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"sequential_image_generation": "disabled",
|
"sequential_image_generation": "disabled",
|
||||||
"response_format": "url",
|
"response_format": "url",
|
||||||
"size": "2K", # Doubao supports different sizes, user example used 2K. But we might want something smaller if possible to save bandwidth/time?
|
"size": "2K", # Doubao supports different sizes, user example used 2K. But we might want something smaller if possible to save bandwidth/time?
|
||||||
# User's curl says "2K". I will stick to it or maybe check docs.
|
# User's curl says "2K". I will stick to it or maybe check docs.
|
||||||
# Actually for thermal printer, we don't need 2K. But let's follow user example.
|
# Actually for thermal printer, we don't need 2K. But let's follow user example.
|
||||||
"stream": False,
|
"stream": False,
|
||||||
"watermark": True
|
"watermark": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(url, headers=headers, json=data, timeout=60)
|
response = requests.post(url, headers=headers, json=data, timeout=60)
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
result = response.json()
|
result = response.json()
|
||||||
# Check format of result
|
# Check format of result
|
||||||
# Typically OpenAI compatible or similar
|
# Typically OpenAI compatible or similar
|
||||||
# User example didn't show response format, but usually it's "data": [{"url": "..."}]
|
# User example didn't show response format, but usually it's "data": [{"url": "..."}]
|
||||||
|
|
||||||
if "data" in result and len(result["data"]) > 0:
|
if "data" in result and len(result["data"]) > 0:
|
||||||
image_url = result["data"][0]["url"]
|
image_url = result["data"][0]["url"]
|
||||||
print(f"Image generated, url: {image_url}")
|
print(f"Image generated, url: {image_url}")
|
||||||
return image_url
|
return image_url
|
||||||
elif "error" in result:
|
elif "error" in result:
|
||||||
error_msg = result["error"].get("message", "Unknown error")
|
error_msg = result["error"].get("message", "Unknown error")
|
||||||
print(f"Doubao API error: {error_msg}")
|
print(f"Doubao API error: {error_msg}")
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(35, f"图片生成失败: {error_msg}")
|
progress_callback(35, f"图片生成失败: {error_msg}")
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
print(f"Unexpected response format: {result}")
|
print(f"Unexpected response format: {result}")
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
print(f"Doubao API failed with status {response.status_code}: {response.text}")
|
print(
|
||||||
|
f"Doubao API failed with status {response.status_code}: {response.text}"
|
||||||
|
)
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(35, f"图片生成失败: {response.status_code}")
|
progress_callback(35, f"图片生成失败: {response.status_code}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error calling Doubao API: {e}")
|
print(f"Error calling Doubao API: {e}")
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
|
|||||||
Reference in New Issue
Block a user