diff --git a/fastAPI_tarot.py b/fastAPI_tarot.py index 9bfb740..bb08eea 100644 --- a/fastAPI_tarot.py +++ b/fastAPI_tarot.py @@ -23,6 +23,7 @@ from PIL import Image from sam3.model_builder import build_sam3_image_model from sam3.model.sam3_image_processor import Sam3Processor from sam3.visualization_utils import plot_results +import human_analysis_service # 引入新服务 # ------------------- 配置与路径 ------------------- STATIC_DIR = "static" @@ -92,7 +93,7 @@ async def lifespan(app: FastAPI): app = FastAPI( lifespan=lifespan, title="SAM3 Segmentation API", - description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-speed`", + description="## 🔒 受 API Key 保护\n请点击右上角 **Authorize** 并输入: `123quant-*****`", ) # 手动添加 OpenAPI 安全配置,让 Docs 里的锁头生效 @@ -177,7 +178,7 @@ def load_image_from_url(url: str) -> Image.Image: except Exception as e: raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") -def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR) -> list[dict]: +def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RESULT_IMAGE_DIR, is_tarot: bool = True) -> list[dict]: """ 根据 mask 和 box 进行透视矫正并裁剪出独立的对象图片 (保留透明背景) 返回包含文件名和元数据的列表 @@ -237,7 +238,7 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE is_rotated = False # Enforce Portrait for Tarot cards (Standard 7x12 cm ratio approx) - if w > h: + if is_tarot and w > h: # Rotate 90 degrees clockwise warped = cv2.rotate(warped, cv2.ROTATE_90_CLOCKWISE) is_rotated = True @@ -246,7 +247,8 @@ def crop_and_save_objects(image: Image.Image, masks, boxes, output_dir: str = RE pil_warped = Image.fromarray(warped) # Save - filename = f"tarot_{uuid.uuid4().hex}_{i}.png" + prefix = "tarot" if is_tarot else "segment" + filename = f"{prefix}_{uuid.uuid4().hex}_{i}.png" save_path = os.path.join(output_dir, filename) pil_warped.save(save_path) @@ -272,23 +274,73 @@ def generate_and_save_result(image: Image.Image, inference_state, output_dir: st def recognize_card_with_qwen(image_path: str) -> dict: """ - 调用 Qwen-VL 识别塔罗牌 + 调用 Qwen-VL 识别塔罗牌 (采用正逆位对比策略) """ try: - # 确保路径是绝对路径并加上 file:// 前缀 + # 确保路径是绝对路径 abs_path = os.path.abspath(image_path) file_url = f"file://{abs_path}" - messages = [ - { - "role": "user", - "content": [ - {"image": file_url}, - {"text": "这是一张塔罗牌。请识别它的名字(中文),并判断它是正位还是逆位。请以JSON格式返回,包含 'name' 和 'position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。"} - ] - } - ] + # ------------------------------------------------- + # 优化策略:生成一张旋转180度的对比图 + # 让 AI 做选择题而不是判断题,大幅提高准确率 + # ------------------------------------------------- + try: + # 1. 打开原图 + img = Image.open(abs_path) + + # 2. 生成旋转图 (180度) + rotated_img = img.rotate(180) + + # 3. 保存旋转图 + dir_name = os.path.dirname(abs_path) + file_name = os.path.basename(abs_path) + rotated_name = f"rotated_{file_name}" + rotated_path = os.path.join(dir_name, rotated_name) + rotated_img.save(rotated_path) + + rotated_file_url = f"file://{rotated_path}" + + # 4. 构建对比 Prompt + # 发送两张图:图1=原图, 图2=旋转图 + # 询问 AI 哪一张是“正位” + messages = [ + { + "role": "user", + "content": [ + {"image": file_url}, # 图1 (原图) + {"image": rotated_file_url}, # 图2 (旋转180度) + {"text": """这是一张塔罗牌的两个方向: +图1:原始方向 +图2:旋转180度后的方向 + +请仔细对比两张图片的牌面内容(文字方向、人物站立方向、图案逻辑): +1. 识别这张牌的名字(中文)。 +2. 判断哪一张图片展示了正确的“正位”(Upright)状态。 + - 如果图1是正位,说明原图就是正位。 + - 如果图2是正位,说明原图是逆位。 + +请以JSON格式返回,包含 'name' 和 'position' 两个字段。 +例如:{'name': '愚者', 'position': '正位'} 或 {'name': '倒吊人', 'position': '逆位'}。 +不要包含Markdown代码块标记。"""} + ] + } + ] + + except Exception as e: + print(f"对比图生成失败,回退到单图模式: {e}") + # 回退到旧的单图模式 + messages = [ + { + "role": "user", + "content": [ + {"image": file_url}, + {"text": "这是一张塔罗牌。请识别它的名字(中文),并判断它是正位还是逆位。请以JSON格式返回,包含 'name' 和 'position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。"} + ] + } + ] + # 调用模型 response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages) if response.status_code == 200: @@ -352,7 +404,8 @@ async def segment( request: Request, prompt: str = Form(...), file: Optional[UploadFile] = File(None), - image_url: Optional[str] = Form(None) + image_url: Optional[str] = Form(None), + save_segment_images: bool = Form(False) ): if not file and not image_url: raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") @@ -380,13 +433,43 @@ async def segment( raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") file_url = request.url_for("static", path=f"results/{filename}") + + # New logic for saving segments + saved_segments_info = [] + if save_segment_images: + try: + request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" + output_dir = os.path.join(RESULT_IMAGE_DIR, request_id) + os.makedirs(output_dir, exist_ok=True) + + saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir, is_tarot=False) + + for obj in saved_objects: + fname = obj["filename"] + seg_url = str(request.url_for("static", path=f"results/{request_id}/{fname}")) + saved_segments_info.append({ + "url": seg_url, + "filename": fname + }) + except Exception as e: + # Log error but don't fail the whole request if segmentation saving fails? + # Or fail it? Let's fail it to be safe or include error in response. + # Given simple requirement, I'll let it fail or just print. + print(f"Error saving segments: {e}") + # We can optionally raise HTTPException here too. + raise HTTPException(status_code=500, detail=f"保存分割图片失败: {str(e)}") - return JSONResponse(content={ + response_content = { "status": "success", "result_image_url": str(file_url), "detected_count": len(masks), "scores": scores.tolist() if torch.is_tensor(scores) else scores - }) + } + + if save_segment_images: + response_content["segmented_images"] = saved_segments_info + + return JSONResponse(content=response_content) @app.post("/segment_tarot", dependencies=[Depends(verify_api_key)]) async def segment_tarot( @@ -592,6 +675,64 @@ async def recognize_tarot( "scores": scores.tolist() if torch.is_tensor(scores) else scores }) +@app.post("/segment_face", dependencies=[Depends(verify_api_key)]) +async def segment_face( + request: Request, + file: Optional[UploadFile] = File(None), + image_url: Optional[str] = Form(None), + prompt: str = Form("face and hair") # 默认提示词包含头发 +): + """ + 人脸/头部检测与属性分析接口 (新功能) + 1. 调用 SAM3 分割出头部区域 (含头发) + 2. 裁剪并保存 + 3. 调用 Qwen-VL 识别性别和年龄 + """ + if not file and not image_url: + raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") + + # 1. 加载图片 + try: + if file: + image = Image.open(file.file).convert("RGB") + elif image_url: + image = load_image_from_url(image_url) + except Exception as e: + raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") + + processor = request.app.state.processor + + # 2. 调用独立服务进行处理 + try: + # 传入 processor 和 image + # 注意:Result Image Dir 我们直接复用 RESULT_IMAGE_DIR + result = human_analysis_service.process_face_segmentation_and_analysis( + processor=processor, + image=image, + prompt=prompt, + output_base_dir=RESULT_IMAGE_DIR + ) + except Exception as e: + # 打印详细错误堆栈以便调试 + import traceback + traceback.print_exc() + raise HTTPException(status_code=500, detail=f"处理失败: {str(e)}") + + # 3. 补全 URL (因为 service 层不知道 request context) + if result["status"] == "success": + # 处理全图可视化的 URL + if result.get("full_visualization"): + full_vis_rel_path = result["full_visualization"] + result["full_visualization"] = str(request.url_for("static", path=full_vis_rel_path)) + + for item in result["results"]: + # item["relative_path"] 是相对路径,如 results/xxx/xxx.jpg + # 我们需要将其转换为完整 URL + relative_path = item.pop("relative_path") # 移除相对路径字段,只返回 URL + item["url"] = str(request.url_for("static", path=relative_path)) + + return JSONResponse(content=result) + if __name__ == "__main__": import uvicorn # 注意:如果你的文件名不是 fastAPI_tarot.py,请修改下面第一个参数 diff --git a/human_analysis_service.py b/human_analysis_service.py new file mode 100644 index 0000000..d61e537 --- /dev/null +++ b/human_analysis_service.py @@ -0,0 +1,231 @@ +import os +import uuid +import time +import requests +import numpy as np +import json +import torch +import cv2 +from PIL import Image +from dashscope import MultiModalConversation + +# 配置 (与 fastAPI_tarot.py 保持一致或通过参数传入) +# 这里的常量可以根据需要调整,或者从主文件传入 +QWEN_MODEL = 'qwen-vl-max' + +def load_image_from_url(url: str) -> Image.Image: + """ + 从 URL 下载图片并转换为 RGB 格式 + """ + try: + headers = {'User-Agent': 'Mozilla/5.0'} + response = requests.get(url, headers=headers, stream=True, timeout=10) + response.raise_for_status() + image = Image.open(response.raw).convert("RGB") + return image + except Exception as e: + raise Exception(f"无法下载图片: {str(e)}") + +def crop_head_with_padding(image: Image.Image, box, padding_ratio=0.1) -> Image.Image: + """ + 根据 bounding box 裁剪图片,并添加一定的 padding + box格式: [x1, y1, x2, y2] + """ + img_w, img_h = image.size + x1, y1, x2, y2 = box + + w = x2 - x1 + h = y2 - y1 + + # 计算 padding + pad_w = w * padding_ratio + pad_h = h * padding_ratio + + # 应用 padding 并确保不越界 + new_x1 = max(0, int(x1 - pad_w)) + new_y1 = max(0, int(y1 - pad_h)) + new_x2 = min(img_w, int(x2 + pad_w)) + new_y2 = min(img_h, int(y2 + pad_h)) + + return image.crop((new_x1, new_y1, new_x2, new_y2)) + +def create_highlighted_visualization(image: Image.Image, masks, output_path: str): + """ + 创建一个突出显示头部(Mask区域)的可视化图,背景变暗 + """ + # Convert PIL to numpy RGB + img_np = np.array(image) + + # Create darkened background (e.g., 30% brightness) + darkened_np = (img_np * 0.3).astype(np.uint8) + + # Combine all masks + if len(masks) > 0: + combined_mask = np.zeros(img_np.shape[:2], dtype=bool) + for mask in masks: + # Handle tensor/numpy conversions + if isinstance(mask, torch.Tensor): + m = mask.cpu().numpy().squeeze() + else: + m = mask.squeeze() + + # Ensure 2D + if m.ndim > 2: + m = m[0] + + # Threshold if probability or float + if m.dtype != bool: + m = m > 0.5 + + # Resize mask if it doesn't match image size (rare but possible with some internal resizing) + if m.shape != img_np.shape[:2]: + # resize to match image + m = cv2.resize(m.astype(np.uint8), (img_np.shape[1], img_np.shape[0]), interpolation=cv2.INTER_NEAREST).astype(bool) + + combined_mask = np.logical_or(combined_mask, m) + + # Expand mask to 3 channels for broadcasting + mask_3ch = np.stack([combined_mask]*3, axis=-1) + + # Composite: Original where mask is True, Darkened where False + result_np = np.where(mask_3ch, img_np, darkened_np) + else: + result_np = darkened_np # No masks, just dark + + # Save + Image.fromarray(result_np).save(output_path) + +def analyze_demographics_with_qwen(image_path: str) -> dict: + """ + 调用 Qwen-VL 模型分析人物的年龄和性别 + """ + try: + # 确保路径是绝对路径 + abs_path = os.path.abspath(image_path) + file_url = f"file://{abs_path}" + + # 构造 Prompt + messages = [ + { + "role": "user", + "content": [ + {"image": file_url}, + {"text": """请仔细观察这张图片中的人物头部/面部特写: +1. 识别性别 (Gender):男性/女性 +2. 预估年龄 (Age):请给出一个合理的年龄范围,例如 "25-30岁" +3. 简要描述:发型、发色、是否有眼镜等显著特征。 + +请以 JSON 格式返回,包含 'gender', 'age', 'description' 字段。 +不要包含 Markdown 标记。"""} + ] + } + ] + + # 调用模型 + response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages) + + if response.status_code == 200: + content = response.output.choices[0].message.content[0]['text'] + # 清理 Markdown 代码块标记 + clean_content = content.replace("```json", "").replace("```", "").strip() + try: + result = json.loads(clean_content) + return result + except json.JSONDecodeError: + return {"raw_analysis": clean_content} + else: + return {"error": f"API Error: {response.code} - {response.message}"} + + except Exception as e: + return {"error": f"分析失败: {str(e)}"} + +def process_face_segmentation_and_analysis( + processor, + image: Image.Image, + prompt: str = "head", + output_base_dir: str = "static/results" +) -> dict: + """ + 核心处理逻辑: + 1. SAM3 分割 (默认提示词 "head" 以包含头发) + 2. 裁剪图片 + 3. Qwen-VL 识别性别年龄 + 4. 返回结果 + """ + + # 1. SAM3 推理 + inference_state = processor.set_image(image) + output = processor.set_text_prompt(state=inference_state, prompt=prompt) + masks, boxes, scores = output["masks"], output["boxes"], output["scores"] + + detected_count = len(masks) + if detected_count == 0: + return { + "status": "success", + "message": "未检测到目标", + "detected_count": 0, + "results": [] + } + + # 准备结果目录 + request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" + output_dir = os.path.join(output_base_dir, request_id) + os.makedirs(output_dir, exist_ok=True) + + # --- 新增:生成背景变暗的整体可视化图 --- + vis_filename = f"seg_{uuid.uuid4().hex}.jpg" + vis_path = os.path.join(output_dir, vis_filename) + try: + create_highlighted_visualization(image, masks, vis_path) + full_vis_relative_path = f"results/{request_id}/{vis_filename}" + except Exception as e: + print(f"可视化生成失败: {e}") + full_vis_relative_path = None + # ------------------------------------- + + results = [] + + # 转换 boxes 为 numpy + if isinstance(boxes, torch.Tensor): + boxes_np = boxes.cpu().numpy() + else: + boxes_np = boxes + + # 转换 scores 为 list + if isinstance(scores, torch.Tensor): + scores_list = scores.tolist() + else: + scores_list = scores if isinstance(scores, list) else [float(scores)] + + for i, box in enumerate(boxes_np): + # 2. 裁剪 (带一点 padding 以保留完整发型) + # 2. 裁剪 (带一点 padding 以保留完整发型) + cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1) + + # 保存裁剪图 + filename = f"face_{i}.jpg" + save_path = os.path.join(output_dir, filename) + cropped_img.save(save_path) + + # 3. 识别 + analysis = analyze_demographics_with_qwen(save_path) + + # 构造返回结果 + # 注意:URL 生成需要依赖外部的 request context,这里只返回相对路径或文件名 + # 由调用方组装完整 URL + results.append({ + "filename": filename, + "relative_path": f"results/{request_id}/{filename}", + "analysis": analysis, + "score": float(scores_list[i]) if i < len(scores_list) else 0.0 + }) + + return { + "status": "success", + "message": f"成功检测并分析 {detected_count} 个人脸", + "detected_count": detected_count, + "request_id": request_id, + "full_visualization": full_vis_relative_path, # 返回相对路径 + "scores": scores_list, # 返回全部分数 + "results": results + } diff --git a/static/results/1771165722_cd68234d/segment_9a700ea672454cc2804f9557be357aa6_0.png b/static/results/1771165722_cd68234d/segment_9a700ea672454cc2804f9557be357aa6_0.png new file mode 100644 index 0000000..b56d507 Binary files /dev/null and b/static/results/1771165722_cd68234d/segment_9a700ea672454cc2804f9557be357aa6_0.png differ diff --git a/static/results/1771165793_d82b3afa/face_0.jpg b/static/results/1771165793_d82b3afa/face_0.jpg new file mode 100644 index 0000000..44d239c Binary files /dev/null and b/static/results/1771165793_d82b3afa/face_0.jpg differ diff --git a/static/results/1771165793_d82b3afa/seg_cf489e48f6664b89a83fab2a577f8205.jpg b/static/results/1771165793_d82b3afa/seg_cf489e48f6664b89a83fab2a577f8205.jpg new file mode 100644 index 0000000..42f424a Binary files /dev/null and b/static/results/1771165793_d82b3afa/seg_cf489e48f6664b89a83fab2a577f8205.jpg differ diff --git a/static/results/seg_487ee0634f7d4aceb8d29565ac2b8149.jpg b/static/results/seg_487ee0634f7d4aceb8d29565ac2b8149.jpg new file mode 100644 index 0000000..d58cda7 Binary files /dev/null and b/static/results/seg_487ee0634f7d4aceb8d29565ac2b8149.jpg differ diff --git a/static/results/seg_5ef48ba87d364a819ed281458ea2c7a0.jpg b/static/results/seg_5ef48ba87d364a819ed281458ea2c7a0.jpg new file mode 100644 index 0000000..2a0e579 Binary files /dev/null and b/static/results/seg_5ef48ba87d364a819ed281458ea2c7a0.jpg differ diff --git a/static/results/seg_f73c363e8f4946e0a0cfb47a3709f7a5.jpg b/static/results/seg_f73c363e8f4946e0a0cfb47a3709f7a5.jpg new file mode 100644 index 0000000..3081831 Binary files /dev/null and b/static/results/seg_f73c363e8f4946e0a0cfb47a3709f7a5.jpg differ