# -*- coding: utf-8 -*-
"""Face segmentation (SAM3) + demographics analysis (Qwen-VL) helpers."""
import ast
import asyncio
import io
import json
import os
import re
import time
import uuid

import cv2
import numpy as np
import requests
import torch
from PIL import Image

from dashscope import MultiModalConversation
# Configuration (keep consistent with fastAPI_tarot.py, or pass in via parameters).
# These constants can be adjusted as needed, or supplied by the main module.
QWEN_MODEL = 'qwen-vl-max'  # default Qwen-VL model name for demographics analysis
def load_image_from_url(url: str) -> Image.Image:
    """Download an image from *url* and return it as an RGB PIL image.

    Raises:
        Exception: if the download or decoding fails (Chinese message kept
            byte-for-byte for compatibility with existing callers).
    """
    try:
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        # BUG FIX: read via ``response.content`` (not ``response.raw``) so
        # requests transparently decodes Content-Encoding (gzip/deflate).
        # The raw socket stream hands still-compressed bytes to PIL, which
        # fails (or corrupts) on servers that compress image responses.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
        return image
    except Exception as e:
        raise Exception(f"无法下载图片: {str(e)}")
def crop_head_with_padding(image: Image.Image, box, padding_ratio=0.1) -> Image.Image:
    """Crop *image* to *box* (``[x1, y1, x2, y2]``), enlarged by *padding_ratio*.

    The padding added on each side is a fraction of the box width/height;
    the padded box is clamped to the image bounds before cropping.
    """
    img_width, img_height = image.size
    x1, y1, x2, y2 = box

    # Padding is proportional to the box size on each axis.
    pad_x = (x2 - x1) * padding_ratio
    pad_y = (y2 - y1) * padding_ratio

    # Expand the box, then clamp so it never leaves the image.
    left = max(0, int(x1 - pad_x))
    top = max(0, int(y1 - pad_y))
    right = min(img_width, int(x2 + pad_x))
    bottom = min(img_height, int(y2 + pad_y))

    return image.crop((left, top, right, bottom))
def create_highlighted_visualization(image: Image.Image, masks, output_path: str):
    """Save a visualization where mask (head) regions keep their original
    pixels and everything outside them is dimmed to 30% brightness."""
    rgb = np.array(image)
    dimmed = (rgb * 0.3).astype(np.uint8)

    if len(masks) == 0:
        # Nothing detected: the whole frame is dimmed.
        Image.fromarray(dimmed).save(output_path)
        return

    # Union of all masks as a single boolean map.
    union = np.zeros(rgb.shape[:2], dtype=bool)
    for mask in masks:
        # Accept both torch tensors and numpy arrays.
        if isinstance(mask, torch.Tensor):
            m = mask.cpu().numpy().squeeze()
        else:
            m = mask.squeeze()

        # Collapse any leading channel dimension to 2D.
        if m.ndim > 2:
            m = m[0]

        # Probability / float map -> binary mask.
        if m.dtype != bool:
            m = m > 0.5

        # Masks occasionally come back at a different resolution
        # (internal resizing); scale to the image size if so.
        if m.shape != rgb.shape[:2]:
            m = cv2.resize(
                m.astype(np.uint8),
                (rgb.shape[1], rgb.shape[0]),
                interpolation=cv2.INTER_NEAREST,
            ).astype(bool)

        union |= m

    # Original pixels inside the union, dimmed pixels outside
    # (the trailing axis broadcasts across the 3 color channels).
    highlighted = np.where(union[..., None], rgb, dimmed)
    Image.fromarray(highlighted).save(output_path)
def extract_json_from_response(text: str) -> dict:
    """Robustly extract a JSON object from model output *text*.

    Handles, in order:
      1. Markdown code fences (```json ... ```)
      2. A bare ``{ ... }`` block embedded in surrounding prose
      3. Python-dict style single quotes (via ``ast.literal_eval``)

    Raises:
        ValueError: if every parse strategy fails.
    """
    # Locate the most likely JSON payload up front, so the fallback parser
    # below always operates on a well-defined string. (The original code
    # assigned ``clean_text`` inside the try block and read it again in the
    # except handler, so an early failure raised NameError instead of the
    # intended ValueError.)
    fence = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
    if fence:
        clean_text = fence.group(1).strip()
    else:
        brace = re.search(r'\{.*\}', text, re.DOTALL)
        clean_text = brace.group(0).strip() if brace else text.strip()

    try:
        # Standard JSON first.
        return json.loads(clean_text)
    except Exception as e1:
        try:
            # Fallback: Python literal syntax (single-quoted keys, etc.).
            # literal_eval is safe on untrusted input — it never executes code.
            return ast.literal_eval(clean_text)
        except Exception as e2:
            raise ValueError(f"Could not parse JSON: {e1} | {e2} | Content: {text[:100]}...")
||
def analyze_demographics_with_qwen(image_path: str, model_name: str = 'qwen-vl-max', prompt_template: str = None) -> dict:
    """Analyze the age/gender of the person in *image_path* via Qwen-VL.

    Returns the parsed JSON result augmented with ``model_used``, or a dict
    containing an ``error`` key when the API call or parsing fails.
    """
    try:
        # dashscope expects a file:// URL built from an absolute path.
        file_url = f"file://{os.path.abspath(image_path)}"

        default_prompt = """请仔细观察这张图片中的人物头部/面部特写:
1. 识别性别 (Gender):男性/女性
2. 预估年龄 (Age):请给出一个合理的年龄范围,例如 "25-30岁"
3. 简要描述:发型、发色、是否有眼镜等显著特征。

请以 JSON 格式返回,包含 'gender', 'age', 'description' 字段。
不要包含 Markdown 标记。"""

        messages = [
            {
                "role": "user",
                "content": [
                    {"image": file_url},
                    {"text": prompt_template if prompt_template else default_prompt},
                ],
            }
        ]

        response = MultiModalConversation.call(model=model_name, messages=messages)

        # Non-200: surface the API error to the caller.
        if response.status_code != 200:
            return {"error": f"API Error: {response.code} - {response.message}"}

        content = response.output.choices[0].message.content[0]['text']
        try:
            result = extract_json_from_response(content)
            result["model_used"] = model_name
            return result
        except Exception as e:
            # Parsing failed: return the raw model text for inspection.
            print(f"JSON Parse Error in face analysis: {e}")
            return {"raw_analysis": content, "error": str(e), "model_used": model_name}

    except Exception as e:
        return {"error": f"分析失败: {str(e)}"}
||
import asyncio
|
||
|
||
def process_face_segmentation_and_analysis(
    processor,
    image: Image.Image,
    prompt: str = "head",
    output_base_dir: str = "static/results",
    qwen_model: str = "qwen-vl-max",
    analysis_prompt: str = None
) -> dict:
    """Segment heads with SAM3, crop each one, and analyze it with Qwen-VL.

    NOTE(review): this definition is shadowed by a later definition of the
    same name in this module. The original body built an inner async helper
    that was never awaited and fell through to ``pass``, so any caller would
    have received None. It is rewritten here as a correct synchronous
    pipeline so the module contains no dead/broken code path.

    Pipeline:
      1. SAM3 segmentation (prompt defaults to "head" so hair is included)
      2. Crop each detection with padding and save it
      3. Qwen-VL demographics analysis per crop (serial)
      4. Aggregate and return results
    """
    # 1. SAM3 inference (synchronous; GPU-bound).
    inference_state = processor.set_image(image)
    output = processor.set_text_prompt(state=inference_state, prompt=prompt)
    masks, boxes, scores = output["masks"], output["boxes"], output["scores"]

    detected_count = len(masks)
    if detected_count == 0:
        return {
            "status": "success",
            "message": "未检测到目标",
            "detected_count": 0,
            "results": []
        }

    # Per-request output directory.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(output_base_dir, request_id)
    os.makedirs(output_dir, exist_ok=True)

    # Full-frame highlighted visualization (best-effort; failure is non-fatal).
    vis_filename = f"seg_{uuid.uuid4().hex}.jpg"
    vis_path = os.path.join(output_dir, vis_filename)
    try:
        create_highlighted_visualization(image, masks, vis_path)
        full_vis_relative_path = f"results/{request_id}/{vis_filename}"
    except Exception as e:
        print(f"可视化生成失败: {e}")
        full_vis_relative_path = None

    # Normalize boxes/scores to plain Python containers.
    boxes_np = boxes.cpu().numpy() if isinstance(boxes, torch.Tensor) else boxes
    if isinstance(scores, torch.Tensor):
        scores_list = scores.tolist()
    else:
        scores_list = scores if isinstance(scores, list) else [float(scores)]

    results = []
    for i, box in enumerate(boxes_np):
        # 2. Crop and persist each detected head.
        cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1)
        filename = f"face_{i}.jpg"
        save_path = os.path.join(output_dir, filename)
        cropped_img.save(save_path)

        # 3. Demographics analysis (serial; use the *_async variant for concurrency).
        analysis = analyze_demographics_with_qwen(save_path, model_name=qwen_model, prompt_template=analysis_prompt)

        results.append({
            "filename": filename,
            "relative_path": f"results/{request_id}/{filename}",
            "analysis": analysis,
            "score": float(scores_list[i]) if i < len(scores_list) else 0.0
        })

    return {
        "status": "success",
        "message": f"成功检测并分析 {detected_count} 个人脸",
        "detected_count": detected_count,
        "request_id": request_id,
        "full_visualization": full_vis_relative_path,
        "scores": scores_list,
        "results": results
    }
|
||
async def process_face_segmentation_and_analysis_async(
    processor,
    image: Image.Image,
    prompt: str = "head",
    output_base_dir: str = "static/results",
    qwen_model: str = "qwen-vl-max",
    analysis_prompt: str = None
) -> dict:
    """Async variant: segment heads with SAM3, crop each one, and analyze
    the crops with Qwen-VL concurrently (one executor task per detection).

    Args:
        processor: SAM3 processor exposing ``set_image`` / ``set_text_prompt``.
        image: source PIL image.
        prompt: SAM3 text prompt ("head" by default so hair is included).
        output_base_dir: root directory for per-request result folders.
        qwen_model: Qwen-VL model name passed to the analysis call.
        analysis_prompt: optional custom prompt for the demographics analysis.

    Returns:
        dict with status, counts, relative file paths and per-face analyses.
    """
    # 1. SAM3 inference (synchronous, GPU-bound).
    inference_state = processor.set_image(image)
    output = processor.set_text_prompt(state=inference_state, prompt=prompt)
    masks, boxes, scores = output["masks"], output["boxes"], output["scores"]

    detected_count = len(masks)
    if detected_count == 0:
        return {
            "status": "success",
            "message": "未检测到目标",
            "detected_count": 0,
            "results": []
        }

    # Per-request output directory.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(output_base_dir, request_id)
    os.makedirs(output_dir, exist_ok=True)

    # Full-frame highlighted visualization (best-effort; failure is non-fatal).
    vis_filename = f"seg_{uuid.uuid4().hex}.jpg"
    vis_path = os.path.join(output_dir, vis_filename)
    try:
        create_highlighted_visualization(image, masks, vis_path)
        full_vis_relative_path = f"results/{request_id}/{vis_filename}"
    except Exception as e:
        print(f"可视化生成失败: {e}")
        full_vis_relative_path = None

    # Normalize boxes/scores to plain Python containers.
    if isinstance(boxes, torch.Tensor):
        boxes_np = boxes.cpu().numpy()
    else:
        boxes_np = boxes

    if isinstance(scores, torch.Tensor):
        scores_list = scores.tolist()
    else:
        scores_list = scores if isinstance(scores, list) else [float(scores)]

    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() here is deprecated since Python 3.10.
    loop = asyncio.get_running_loop()
    tasks = []
    results = []

    for i, box in enumerate(boxes_np):
        # 2. Crop and persist each detected head.
        cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1)
        filename = f"face_{i}.jpg"
        save_path = os.path.join(output_dir, filename)
        cropped_img.save(save_path)

        # 3. Run the blocking Qwen-VL call in the default thread pool so
        # all detections are analyzed concurrently.
        tasks.append(loop.run_in_executor(
            None,
            analyze_demographics_with_qwen,
            save_path,
            qwen_model,
            analysis_prompt
        ))

        results.append({
            "filename": filename,
            # BUG FIX: previously pointed at a literal "(unknown)" instead
            # of the saved crop's filename.
            "relative_path": f"results/{request_id}/{filename}",
            "score": float(scores_list[i]) if i < len(scores_list) else 0.0
        })

    analysis_results = await asyncio.gather(*tasks) if tasks else []

    # 4. Attach each analysis to its corresponding crop entry.
    for i, item in enumerate(results):
        item["analysis"] = analysis_results[i]

    return {
        "status": "success",
        "message": f"成功检测并分析 {detected_count} 个人脸",
        "detected_count": detected_count,
        "request_id": request_id,
        "full_visualization": full_vis_relative_path,
        "scores": scores_list,
        "results": results
    }
|
||
# The old synchronous interface is kept below for compatibility with other
# potential callers. It analyzes faces serially and will block an event loop
# if called from async code; prefer process_face_segmentation_and_analysis_async
# in FastAPI handlers.
def process_face_segmentation_and_analysis(
    processor,
    image: Image.Image,
    prompt: str = "head",
    output_base_dir: str = "static/results",
    qwen_model: str = "qwen-vl-max",
    analysis_prompt: str = None
) -> dict:
    """Synchronous version (kept for backward compatibility).

    Note: if called from inside an async event loop this blocks the loop,
    because the per-face Qwen-VL calls run serially. For concurrent analysis
    use ``process_face_segmentation_and_analysis_async``.

    Pipeline:
      1. SAM3 segmentation (prompt defaults to "head" so hair is included)
      2. Crop each detection with padding and save it
      3. Qwen-VL demographics analysis per crop (serial)
      4. Aggregate and return results
    """
    # 1. SAM3 inference.
    inference_state = processor.set_image(image)
    output = processor.set_text_prompt(state=inference_state, prompt=prompt)
    masks, boxes, scores = output["masks"], output["boxes"], output["scores"]

    detected_count = len(masks)
    if detected_count == 0:
        return {
            "status": "success",
            "message": "未检测到目标",
            "detected_count": 0,
            "results": []
        }

    # Per-request output directory.
    request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    output_dir = os.path.join(output_base_dir, request_id)
    os.makedirs(output_dir, exist_ok=True)

    # Full-frame highlighted visualization (best-effort; failure is non-fatal).
    vis_filename = f"seg_{uuid.uuid4().hex}.jpg"
    vis_path = os.path.join(output_dir, vis_filename)
    try:
        create_highlighted_visualization(image, masks, vis_path)
        full_vis_relative_path = f"results/{request_id}/{vis_filename}"
    except Exception as e:
        print(f"可视化生成失败: {e}")
        full_vis_relative_path = None

    # Normalize boxes/scores to plain Python containers.
    if isinstance(boxes, torch.Tensor):
        boxes_np = boxes.cpu().numpy()
    else:
        boxes_np = boxes

    if isinstance(scores, torch.Tensor):
        scores_list = scores.tolist()
    else:
        scores_list = scores if isinstance(scores, list) else [float(scores)]

    results = []
    for i, box in enumerate(boxes_np):
        # 2. Crop and persist each detected head.
        cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1)
        filename = f"face_{i}.jpg"
        save_path = os.path.join(output_dir, filename)
        cropped_img.save(save_path)

        # 3. Synchronous demographics analysis.
        analysis = analyze_demographics_with_qwen(save_path, model_name=qwen_model, prompt_template=analysis_prompt)

        results.append({
            "filename": filename,
            # BUG FIX: previously pointed at a literal "(unknown)" instead
            # of the saved crop's filename.
            "relative_path": f"results/{request_id}/{filename}",
            "analysis": analysis,
            "score": float(scores_list[i]) if i < len(scores_list) else 0.0
        })

    return {
        "status": "success",
        "message": f"成功检测并分析 {detected_count} 个人脸",
        "detected_count": detected_count,
        "request_id": request_id,
        "full_visualization": full_vis_relative_path,
        "scores": scores_list,
        "results": results
    }