This commit is contained in:
jeremygan2021
2026-03-05 22:02:01 +08:00
parent 9558ea4b35
commit 2392d0d705
2 changed files with 30 additions and 51 deletions

41
main.py
View File

@@ -119,22 +119,21 @@ def process_message(msg, display, image_state, image_data_list, printer_uart=Non
if isinstance(msg, (bytes, bytearray)): if isinstance(msg, (bytes, bytearray)):
if image_state == IMAGE_STATE_RECEIVING: if image_state == IMAGE_STATE_RECEIVING:
try: try:
if len(image_data_list) < 3: if len(image_data_list) < 2:
# 异常情况,重置 # 异常情况,重置
return IMAGE_STATE_IDLE, None return IMAGE_STATE_IDLE, None
width = image_data_list[0] img_size = image_data_list[0]
height = image_data_list[1] current_offset = image_data_list[1]
current_offset = image_data_list[2]
# Stream directly to display # Stream directly to display
if display and display.tft: if display and display.tft:
x = (240 - width) // 2 x = (240 - img_size) // 2
y = (240 - height) // 2 y = (240 - img_size) // 2
display.show_image_chunk(x, y, width, height, msg, current_offset) display.show_image_chunk(x, y, img_size, img_size, msg, current_offset)
# Update offset # Update offset
image_data_list[2] += len(msg) image_data_list[1] += len(msg)
except Exception as e: except Exception as e:
print(f"Stream image error: {e}") print(f"Stream image error: {e}")
@@ -196,28 +195,22 @@ def process_message(msg, display, image_state, image_data_list, printer_uart=Non
try: try:
parts = msg.split(":") parts = msg.split(":")
size = int(parts[1]) size = int(parts[1])
img_size = int(parts[2]) if len(parts) > 2 else 64
model_name = parts[3] if len(parts) > 3 else "Unknown Model"
width = 64 print(f"Image start, size: {size}, img_size: {img_size}")
height = 64 convert_img.print_model_info(model_name)
if len(parts) >= 4:
width = int(parts[2])
height = int(parts[3])
elif len(parts) == 3:
width = int(parts[2])
height = int(parts[2]) # assume square
print(f"Image start, size: {size}, dim: {width}x{height}")
image_data_list.clear() image_data_list.clear()
image_data_list.append(width) # index 0 image_data_list.append(img_size) # Store metadata at index 0
image_data_list.append(height) # index 1 image_data_list.append(0) # Store current received bytes offset at index 1
image_data_list.append(0) # index 2: offset
# Prepare display for streaming # Prepare display for streaming
if display and display.tft: if display and display.tft:
# Clear screen area where image will be # Calculate position
# optional, but good practice if new image is smaller x = (240 - img_size) // 2
pass y = (240 - img_size) // 2
# Pre-set window (this will be done in first chunk call)
return IMAGE_STATE_RECEIVING, None return IMAGE_STATE_RECEIVING, None
except Exception as e: except Exception as e:

View File

@@ -235,17 +235,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
# 同步调用图片生成函数 # 同步调用图片生成函数
gen_result = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback) image_path = await asyncio.to_thread(generate_image, optimized_prompt, progress_callback)
image_path = None
width = 0
height = 0
if gen_result:
if isinstance(gen_result, tuple):
image_path, width, height = gen_result
else:
image_path = gen_result
task.result = image_path task.result = image_path
@@ -254,7 +244,7 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
await websocket.send_text("STATUS:COMPLETE:图片生成完成") await websocket.send_text("STATUS:COMPLETE:图片生成完成")
await asyncio.sleep(0.2) await asyncio.sleep(0.2)
await send_image_to_client(websocket, image_path, width, height) await send_image_to_client(websocket, image_path)
else: else:
task.status = "failed" task.status = "failed"
task.error = "图片生成失败" task.error = "图片生成失败"
@@ -277,17 +267,14 @@ async def start_async_image_generation(websocket: WebSocket, asr_text: str):
return task return task
async def send_image_to_client(websocket: WebSocket, image_path: str, width: int = 0, height: int = 0): async def send_image_to_client(websocket: WebSocket, image_path: str):
"""发送图片数据到客户端""" """发送图片数据到客户端"""
with open(image_path, 'rb') as f: with open(image_path, 'rb') as f:
image_data = f.read() image_data = f.read()
print(f"Sending image to ESP32, size: {len(image_data)} bytes, dim: {width}x{height}") print(f"Sending image to ESP32, size: {len(image_data)} bytes")
# Send start marker # Send start marker
if width > 0 and height > 0:
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{width}:{height}")
else:
await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}") await websocket.send_text(f"IMAGE_START:{len(image_data)}:{THUMB_SIZE}")
# Send binary data directly # Send binary data directly
@@ -607,16 +594,15 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
from PIL import Image from PIL import Image
img = Image.open(GENERATED_IMAGE_FILE) img = Image.open(GENERATED_IMAGE_FILE)
# 缩小到 fit THUMB_SIZE x THUMB_SIZE (保持比例) # 缩小到THUMB_SIZE x THUMB_SIZE
img.thumbnail((THUMB_SIZE, THUMB_SIZE), Image.Resampling.LANCZOS) img = img.resize((THUMB_SIZE, THUMB_SIZE), Image.LANCZOS)
width, height = img.size
# 转换为RGB565格式的原始数据 # 转换为RGB565格式的原始数据
# 每个像素2字节 (R5 G6 B5) # 每个像素2字节 (R5 G6 B5)
rgb565_data = bytearray() rgb565_data = bytearray()
for y in range(height): for y in range(THUMB_SIZE):
for x in range(width): for x in range(THUMB_SIZE):
r, g, b = img.getpixel((x, y))[:3] r, g, b = img.getpixel((x, y))[:3]
# 转换为RGB565 # 转换为RGB565
@@ -633,23 +619,23 @@ def generate_image(prompt, progress_callback=None, retry_count=0, max_retries=2)
with open(GENERATED_THUMB_FILE, 'wb') as f: with open(GENERATED_THUMB_FILE, 'wb') as f:
f.write(rgb565_data) f.write(rgb565_data)
print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes, dim: {width}x{height}") print(f"Thumbnail saved to {GENERATED_THUMB_FILE}, size: {len(rgb565_data)} bytes")
if progress_callback: if progress_callback:
progress_callback(100, "图片生成完成!") progress_callback(100, "图片生成完成!")
return GENERATED_THUMB_FILE, width, height return GENERATED_THUMB_FILE
except ImportError: except ImportError:
print("PIL not available, sending original image") print("PIL not available, sending original image")
if progress_callback: if progress_callback:
progress_callback(100, "图片生成完成!(原始格式)") progress_callback(100, "图片生成完成!(原始格式)")
return GENERATED_IMAGE_FILE, 0, 0 return GENERATED_IMAGE_FILE
except Exception as e: except Exception as e:
print(f"Error processing image: {e}") print(f"Error processing image: {e}")
if progress_callback: if progress_callback:
progress_callback(80, f"图片处理出错: {str(e)}") progress_callback(80, f"图片处理出错: {str(e)}")
return GENERATED_IMAGE_FILE, 0, 0 return GENERATED_IMAGE_FILE
except Exception as e: except Exception as e:
print(f"Error in generate_image: {e}") print(f"Error in generate_image: {e}")