diff --git a/fastAPI_tarot.py b/fastAPI_tarot.py index bb08eea..8f6e7c7 100644 --- a/fastAPI_tarot.py +++ b/fastAPI_tarot.py @@ -127,7 +127,48 @@ app.openapi = custom_openapi app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") +import re + # ------------------- 辅助函数 ------------------- +def is_english(text: str) -> bool: + """ + 简单判断是否为英文 prompt。 + 如果有中文字符则认为不是英文。 + """ + for char in text: + if '\u4e00' <= char <= '\u9fff': + return False + return True + +def translate_to_sam3_prompt(text: str) -> str: + """ + 使用大模型将非英文提示词翻译为适合 SAM3 的英文提示词 + """ + print(f"正在翻译提示词: {text}") + try: + messages = [ + { + "role": "user", + "content": [ + {"text": f"请将以下描述翻译成简洁、精准的英文,用于图像分割模型(SAM)的提示词。直接返回英文,不要包含任何解释或其他文字。\n\n输入: {text}"} + ] + } + ] + response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages) + + if response.status_code == 200: + translated_text = response.output.choices[0].message.content[0]['text'].strip() + # 去除可能的 markdown 标记或引号 + translated_text = translated_text.replace('"', '').replace("'", "").strip() + print(f"翻译结果: {translated_text}") + return translated_text + else: + print(f"翻译失败: {response.code} - {response.message}") + return text # Fallback to original + except Exception as e: + print(f"翻译异常: {e}") + return text + def order_points(pts): """ 对四个坐标点进行排序:左上,右上,右下,左下 @@ -178,10 +219,10 @@ def load_image_from_url(url: str) -> Image.Image: except Exception as e: raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") -def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True) -> list[dict]: +def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True, cutout: bool = False) -> list[dict]: """ - 根据 mask 和 box 进行透视矫正并裁剪出独立的对象图片 (保留透明背景) - 返回包含文件名和元数据的列表 + 根据 mask 和 box 进行处理并保存独立的对象图片 + :param cutout: 如果为 True,则进行轮廓抠图(透明背景);否则进行透视矫正(主要用于卡片) """ saved_objects = [] # Convert image to numpy array (RGB) @@ -193,74 +234,126 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE mask_np = mask.cpu().numpy().squeeze() else: mask_np = mask.squeeze() + + # Handle box conversion + if isinstance(box, torch.Tensor): + box_np = box.cpu().numpy() + else: + box_np = np.array(box) - # Ensure mask is uint8 binary for OpenCV + # Ensure mask is uint8 binary for OpenCV/Pillow if mask_np.dtype == bool: mask_uint8 = (mask_np * 255).astype(np.uint8) else: mask_uint8 = (mask_np > 0.5).astype(np.uint8) * 255 + + if cutout: + # --- 轮廓抠图模式 (透明背景) --- + # 1. 准备 RGBA 原图 + if image.mode != "RGBA": + img_rgba = image.convert("RGBA") + else: + img_rgba = image.copy() - # Find contours - contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) - if not contours: - continue + # 2. 准备 Alpha Mask + # mask_uint8 已经是 0-255,可以直接作为 Alpha 通道的基础 + # 将 Mask 缩放到和图片一样大 (一般是一样的,但以防万一) + if mask_uint8.shape != img_rgba.size[::-1]: # size is (w, h), shape is (h, w) + # 如果尺寸不一致可能需要 resize,这里假设尺寸一致 + pass + + mask_img = Image.fromarray(mask_uint8, mode='L') - # Get largest contour - c = max(contours, key=cv2.contourArea) - - # Approximate contour to polygon - peri = cv2.arcLength(c, True) - approx = cv2.approxPolyDP(c, 0.04 * peri, True) - - # If we have 4 points, use them. If not, fallback to minAreaRect - if len(approx) == 4: - pts = approx.reshape(4, 2) + # 3. 将 Mask 应用到 Alpha 通道 + # 创建一个新的空白透明图 + cutout_img = Image.new("RGBA", img_rgba.size, (0, 0, 0, 0)) + # 将原图粘贴上去,使用 mask 作为 mask + cutout_img.paste(image.convert("RGB"), (0, 0), mask=mask_img) + + # 4. Crop to Box + # box is [x1, y1, x2, y2] + x1, y1, x2, y2 = map(int, box_np) + # 边界检查 + w, h = cutout_img.size + x1 = max(0, x1); y1 = max(0, y1) + x2 = min(w, x2); y2 = min(h, y2) + + # 避免无效 crop + if x2 > x1 and y2 > y1: + final_img = cutout_img.crop((x1, y1, x2, y2)) + else: + final_img = cutout_img # Fallback + + # Save + prefix = "cutout" + is_rotated = False # 抠图模式下不进行自动旋转 + + filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png" + save_path = os.path.join(output_dir, filename) + final_img.save(save_path) + + saved_objects.append({ + "filename": filename, + "is_rotated_by_algorithm": is_rotated, + "note": "Mask cutout applied. Background removed." + }) + else: - rect = cv2.minAreaRect(c) - pts = cv2.boxPoints(rect) + # --- 透视矫正模式 (原有逻辑) --- + # Find contours + contours, _ = cv2.findContours(mask_uint8, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if not contours: + continue + + # Get largest contour + c = max(contours, key=cv2.contourArea) - # Apply perspective transform - # 注意:这里我们只变换RGB部分,Alpha通道需要额外处理或者直接应用同样的变换 - # 为了简单,我们直接对原图(假设不带Alpha)进行变换 - # 如果需要保留背景透明,需要先将原图转为RGBA,再做变换 - - # Check if original image has Alpha - if img_arr.shape[2] == 4: - warped = four_point_transform(img_arr, pts) - else: - # Add alpha channel from mask? - # 透视变换后的矩形本身就是去掉了背景的,所以不需要额外的Mask Alpha - # 但是为了保持一致性,我们可以给变换后的图加一个全不透明的Alpha,或者保留RGB - warped = four_point_transform(img_arr, pts) - - # Check orientation (Portrait vs Landscape) - h, w = warped.shape[:2] - is_rotated = False - - # Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx) - if is_tarot and w > h: - # Rotate 90 degrees clockwise - warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE) - is_rotated = True + # Approximate contour to polygon + peri = cv2.arcLength(c, True) + approx = cv2.approxPolyDP(c, 0.04 * peri, True) - # Convert back to PIL - pil_warped = Image.fromarray(warped) - - # Save - prefix = "tarot" if is_tarot else "segment" - filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png" - save_path = os.path.join(output_dir, filename) - pil_warped.save(save_path) - - # 正逆位判断逻辑 (基于几何只能做到这一步,无法区分上下颠倒) - # 这里我们假设长边垂直为正位,如果做了旋转则标记 - # 真正的正逆位需要OCR或图像识别 - - saved_objects.append({ - "filename": filename, - "is_rotated_by_algorithm": is_rotated, - "note": "Geometric correction applied. True upright/reversed requires content analysis." - }) + # If we have 4 points, use them. If not, fallback to minAreaRect + if len(approx) == 4: + pts = approx.reshape(4, 2) + else: + rect = cv2.minAreaRect(c) + pts = cv2.boxPoints(rect) + + # Apply perspective transform + # Check if original image has Alpha + if img_arr.shape[2] == 4: + warped = four_point_transform(img_arr, pts) + else: + warped = four_point_transform(img_arr, pts) + + # Check orientation (Portrait vs Landscape) + h, w = warped.shape[:2] + is_rotated = False + + # Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx) + if is_tarot and w > h: + # Rotate 90 degrees clockwise + warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE) + is_rotated = True + + # Convert back to PIL + pil_warped = Image.fromarray(warped) + + # Save + prefix = "tarot" if is_tarot else "segment" + filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png" + save_path = os.path.join(output_dir, filename) + pil_warped.save(save_path) + + # 正逆位判断逻辑 (基于几何只能做到这一步,无法区分上下颠倒) + # 这里我们假设长边垂直为正位,如果做了旋转则标记 + # 真正的正逆位需要OCR或图像识别 + + saved_objects.append({ + "filename": filename, + "is_rotated_by_algorithm": is_rotated, + "note": "Geometric correction applied. True upright/reversed requires content analysis." + }) return saved_objects @@ -405,11 +498,28 @@ async def segment( prompt: str = Form(...), file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None), - save_segment_images: bool = Form(False) + save_segment_images: bool = Form(False), + cutout: bool = Form(False) ): if not file and not image_url: raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") + # ------------------- Prompt 翻译/优化 ------------------- + final_prompt = prompt + if not is_english(prompt): + # 使用 Qwen 进行翻译 + try: + # 构建消息请求翻译 + # 注意:translate_to_sam3_prompt 是一个阻塞调用,可能会增加耗时 + # 但对于分割任务来说是可以接受的 + translated = translate_to_sam3_prompt(prompt) + if translated: + final_prompt = translated + except Exception as e: + print(f"Prompt翻译失败,使用原始Prompt: {e}") + + print(f"最终使用的 Prompt: {final_prompt}") + try: if file: image = Image.open(file.file).convert("RGB") @@ -422,7 +532,7 @@ async def segment( try: inference_state = processor.set_image(image) - output = processor.set_text_prompt(state=inference_state, prompt=prompt) + output = processor.set_text_prompt(state=inference_state, prompt=final_prompt) masks, boxes, scores = output["masks"], output["boxes"], output["scores"] except Exception as e: raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") @@ -442,21 +552,26 @@ async def segment( output_dir = os.path.join(RESULT_IMAGE_DIR, request_id) os.makedirs(output_dir, exist_ok=True) - saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir, is_tarot=False) + # 传递 cutout 参数 + saved_objects = crop_and_save_objects( + image, + masks, + boxes, + output_dir=output_dir, + is_tarot=False, + cutout=cutout + ) for obj in saved_objects: fname = obj["filename"] seg_url = str(request.url_for("static", path=f"results/{request_id}/{fname}")) saved_segments_info.append({ "url": seg_url, - "filename": fname + "filename": fname, + "note": obj.get("note", "") }) except Exception as e: - # Log error but don't fail the whole request if segmentation saving fails? - # Or fail it? Let's fail it to be safe or include error in response. - # Given simple requirement, I'll let it fail or just print. print(f"Error saving segments: {e}") - # We can optionally raise HTTPException here too. raise HTTPException(status_code=500, detail=f"保存分割图片失败: {str(e)}") response_content = { @@ -691,6 +806,18 @@ async def segment_face( if not file and not image_url: raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") + # ------------------- Prompt 翻译/优化 ------------------- + final_prompt = prompt + if not is_english(prompt): + try: + translated = translate_to_sam3_prompt(prompt) + if translated: + final_prompt = translated + except Exception as e: + print(f"Prompt翻译失败,使用原始Prompt: {e}") + + print(f"Face Segment 最终使用的 Prompt: {final_prompt}") + # 1. 加载图片 try: if file: @@ -709,7 +836,7 @@ async def segment_face( result = human_analysis_service.process_face_segmentation_and_analysis( processor=processor, image=image, - prompt=prompt, + prompt=final_prompt, output_base_dir=RESULT_IMAGE_DIR ) except Exception as e: diff --git a/static/results/seg_37fe4609bc984323a0284b404456f966.jpg b/static/results/seg_37fe4609bc984323a0284b404456f966.jpg new file mode 100644 index 0000000..9e75b44 Binary files /dev/null and b/static/results/seg_37fe4609bc984323a0284b404456f966.jpg differ