# sam3_local/human_analysis_service.py
# (exported from repository web viewer, 2026-02-15 22:32:05 +08:00; 232 lines, 7.7 KiB)
import io
import json
import os
import time
import uuid

import cv2
import numpy as np
import requests
import torch
from PIL import Image

from dashscope import MultiModalConversation
# Configuration (kept consistent with fastAPI_tarot.py, or passed in via parameters).
# These constants can be adjusted as needed, or supplied by the main module.
QWEN_MODEL = 'qwen-vl-max'
def load_image_from_url(url: str) -> Image.Image:
    """
    Download an image from a URL and convert it to RGB.

    Args:
        url: HTTP(S) URL of the image.

    Returns:
        The downloaded image as an RGB ``PIL.Image.Image``.

    Raises:
        Exception: if the download or decoding fails; wraps and chains the
            original error.
    """
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        # BUGFIX: decode from the fully-read body instead of ``response.raw``.
        # Reading the raw socket stream can yield content-encoded (gzip/deflate)
        # bytes that PIL cannot parse; ``response.content`` is always decoded.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return image
    except Exception as e:
        # Keep the original (Chinese) message for callers that match on it;
        # chain the cause so the underlying error is not lost.
        raise Exception(f"无法下载图片: {str(e)}") from e
def crop_head_with_padding(image: Image.Image, box, padding_ratio=0.1) -> Image.Image:
    """
    Crop an image to a bounding box, expanded by a padding margin on each side.

    Args:
        image: Source image.
        box: Bounding box in ``[x1, y1, x2, y2]`` pixel coordinates.
        padding_ratio: Fraction of the box width/height added on every side.

    Returns:
        The cropped ``PIL.Image.Image``.
    """
    img_width, img_height = image.size
    left, top, right, bottom = box
    pad_x = (right - left) * padding_ratio
    pad_y = (bottom - top) * padding_ratio
    # Expand the box by the padding, clamping to the image bounds.
    crop_box = (
        max(0, int(left - pad_x)),
        max(0, int(top - pad_y)),
        min(img_width, int(right + pad_x)),
        min(img_height, int(bottom + pad_y)),
    )
    return image.crop(crop_box)
def create_highlighted_visualization(image: Image.Image, masks, output_path: str):
    """
    Save a visualization that keeps the masked (head) regions at full
    brightness while darkening the rest of the image to 30% brightness.

    Args:
        image: Source image.
        masks: Sequence of masks (torch tensors or numpy arrays). Values may
            be boolean or float probabilities; floats are thresholded at 0.5.
        output_path: File path where the rendered image is written.
    """
    rgb = np.array(image)
    # Darkened background at 30% brightness.
    dimmed = (rgb * 0.3).astype(np.uint8)
    if len(masks) == 0:
        # Nothing detected: the whole frame stays darkened.
        Image.fromarray(dimmed).save(output_path)
        return
    union = np.zeros(rgb.shape[:2], dtype=bool)
    for raw in masks:
        # Normalize tensor/ndarray to a squeezed numpy array.
        if isinstance(raw, torch.Tensor):
            layer = raw.cpu().numpy().squeeze()
        else:
            layer = raw.squeeze()
        # Collapse any leading channel dimension down to 2D.
        if layer.ndim > 2:
            layer = layer[0]
        # Binarize probability / float masks at 0.5.
        if layer.dtype != bool:
            layer = layer > 0.5
        # Some pipelines emit masks at a different resolution; snap to image size.
        if layer.shape != rgb.shape[:2]:
            layer = cv2.resize(
                layer.astype(np.uint8),
                (rgb.shape[1], rgb.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            ).astype(bool)
        union |= layer
    # Composite: original pixels where the mask is set, darkened elsewhere.
    highlight = np.where(union[..., None], rgb, dimmed)
    Image.fromarray(highlight).save(output_path)
def analyze_demographics_with_qwen(image_path: str) -> dict:
    """
    Ask the Qwen-VL model to estimate gender/age and describe the person in
    a head/face crop image.

    Args:
        image_path: Path to the cropped image on local disk.

    Returns:
        Parsed JSON dict (expected keys: 'gender', 'age', 'description');
        ``{"raw_analysis": ...}`` when the model output is not valid JSON;
        ``{"error": ...}`` on API failure or any local exception.
    """
    try:
        # dashscope expects a file:// URI built from an absolute path.
        file_url = f"file://{os.path.abspath(image_path)}"
        prompt_text = """请仔细观察这张图片中的人物头部/面部特写:
1. 识别性别 (Gender):男性/女性
2. 预估年龄 (Age):请给出一个合理的年龄范围,例如 "25-30岁"
3. 简要描述:发型、发色、是否有眼镜等显著特征。
请以 JSON 格式返回,包含 'gender', 'age', 'description' 字段。
不要包含 Markdown 标记。"""
        messages = [{
            "role": "user",
            "content": [
                {"image": file_url},
                {"text": prompt_text},
            ],
        }]
        response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
        if response.status_code != 200:
            return {"error": f"API Error: {response.code} - {response.message}"}
        raw_text = response.output.choices[0].message.content[0]['text']
        # Strip Markdown code-fence markers the model sometimes adds anyway.
        cleaned = raw_text.replace("```json", "").replace("```", "").strip()
        try:
            return json.loads(cleaned)
        except json.JSONDecodeError:
            # Not valid JSON: hand back the raw text for the caller to inspect.
            return {"raw_analysis": cleaned}
    except Exception as e:
        return {"error": f"分析失败: {str(e)}"}
def process_face_segmentation_and_analysis(
    processor,
    image: Image.Image,
    prompt: str = "head",
    output_base_dir: str = "static/results"
) -> dict:
    """
    Core pipeline:
    1. SAM3 segmentation (default prompt "head" so hair is included)
    2. Crop each detection (with padding)
    3. Qwen-VL demographic analysis of each crop
    4. Return the aggregated results

    Args:
        processor: SAM3 processor exposing ``set_image`` / ``set_text_prompt``.
        image: Input image.
        prompt: Text prompt for segmentation.
        output_base_dir: Root directory for per-request result folders.

    Returns:
        Summary dict: status, message, detected_count, request_id,
        full_visualization (relative path or None), scores, and per-face
        results (filename, relative_path, analysis, score).
    """
    # 1. SAM3 inference
    inference_state = processor.set_image(image)
    output = processor.set_text_prompt(state=inference_state, prompt=prompt)
    masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    detected_count = len(masks)
    if detected_count == 0:
        return {
            "status": "success",
            "message": "未检测到目标",
            "detected_count": 0,
            "results": []
        }
    # Per-request output directory (timestamp + short uuid keeps it unique).
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(output_base_dir, request_id)
    os.makedirs(output_dir, exist_ok=True)
    # Whole-image visualization with darkened background (best-effort: the
    # per-face pipeline still runs if this fails).
    vis_filename = f"seg_{uuid.uuid4().hex}.jpg"
    vis_path = os.path.join(output_dir, vis_filename)
    try:
        create_highlighted_visualization(image, masks, vis_path)
        full_vis_relative_path = f"results/{request_id}/{vis_filename}"
    except Exception as e:
        print(f"可视化生成失败: {e}")
        full_vis_relative_path = None
    results = []
    # Normalize boxes to numpy and scores to a plain list.
    if isinstance(boxes, torch.Tensor):
        boxes_np = boxes.cpu().numpy()
    else:
        boxes_np = boxes
    if isinstance(scores, torch.Tensor):
        scores_list = scores.tolist()
    else:
        scores_list = scores if isinstance(scores, list) else [float(scores)]
    for i, box in enumerate(boxes_np):
        # 2. Crop with a little padding so the full hairstyle is preserved.
        cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1)
        filename = f"face_{i}.jpg"
        save_path = os.path.join(output_dir, filename)
        cropped_img.save(save_path)
        # 3. Demographic analysis of the saved crop.
        analysis = analyze_demographics_with_qwen(save_path)
        # NOTE: only relative paths are returned; the caller assembles full
        # URLs from its own request context.
        results.append({
            "filename": filename,
            # BUGFIX: was a literal "(unknown)" placeholder; must reference
            # the crop file actually saved above.
            "relative_path": f"results/{request_id}/{filename}",
            "analysis": analysis,
            "score": float(scores_list[i]) if i < len(scores_list) else 0.0
        })
    return {
        "status": "success",
        "message": f"成功检测并分析 {detected_count} 个人脸",
        "detected_count": detected_count,
        "request_id": request_id,
        "full_visualization": full_vis_relative_path,  # relative path or None
        "scores": scores_list,  # all detection scores
        "results": results
    }