Compare commits
11 Commits
2d315948a2
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 53e8fbb4dd | |||
| f7c73fa57e | |||
| bad6bfa34b | |||
| 054e720e39 | |||
| f8e94328a7 | |||
| aee6f8804f | |||
| 765a0aebdc | |||
| dc5a02f4ec | |||
| 4f6d7d9035 | |||
| 4667021944 | |||
| 06f2b2928b |
@@ -1,5 +1,11 @@
|
||||
# 量迹AI · SAM3「分割一切」视觉分割服务
|
||||
|
||||
|
||||
# Admin Config
|
||||
ADMIN_PASSWORD = "admin_secure_password" # 可以根据需求修改
|
||||
HISTORY_FILE = "history.json"
|
||||
|
||||
|
||||
本项目在开源 SAM3(Segment Anything Model 3)能力之上,封装了面向业务的 **“分割一切”** 推理服务:通过 **FastAPI** 提供文本提示词驱动的图像分割接口,并扩展了 **塔罗牌分割/识别**、**人脸与头发分割 + 属性分析** 等场景能力。
|
||||
|
||||
本仓库定位为:**模型推理 + API 服务** 的可复用工程模板(适合在 MacOS 开发、服务器部署)。
|
||||
|
||||
258
fastAPI_tarot.py
258
fastAPI_tarot.py
@@ -21,6 +21,8 @@ import traceback
|
||||
import re
|
||||
import asyncio
|
||||
import shutil
|
||||
import subprocess
|
||||
import ast
|
||||
from datetime import datetime
|
||||
from typing import Optional, List, Dict, Any
|
||||
from contextlib import asynccontextmanager
|
||||
@@ -71,7 +73,7 @@ HISTORY_FILE = "history.json"
|
||||
# Dashscope (Qwen-VL) 配置
|
||||
dashscope.api_key = 'sk-ce2404f55f744a1987d5ece61c6bac58'
|
||||
QWEN_MODEL = 'qwen-vl-max' # Default model
|
||||
AVAILABLE_QWEN_MODELS = ["qwen-vl-max", "qwen-vl-plus"]
|
||||
AVAILABLE_QWEN_MODELS = ["qwen-vl-max", "qwen-vl-plus","qwen3.5-plus"]
|
||||
|
||||
# 清理配置 (Cleanup Config)
|
||||
CLEANUP_CONFIG = {
|
||||
@@ -287,6 +289,35 @@ def append_to_history(req_type: str, prompt: str, status: str, result_path: str
|
||||
print(f"Failed to write history: {e}")
|
||||
|
||||
|
||||
def extract_json_from_response(text: str) -> dict:
|
||||
"""
|
||||
Robustly extract JSON from text, handling:
|
||||
1. Markdown code blocks (```json ... ```)
|
||||
2. Single quotes (Python dict style) via ast.literal_eval
|
||||
"""
|
||||
try:
|
||||
# 1. Try to find JSON block
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
|
||||
if json_match:
|
||||
clean_text = json_match.group(1).strip()
|
||||
else:
|
||||
# Try to find { ... } block if no markdown
|
||||
match = re.search(r'\{.*\}', text, re.DOTALL)
|
||||
if match:
|
||||
clean_text = match.group(0).strip()
|
||||
else:
|
||||
clean_text = text.strip()
|
||||
|
||||
# 2. Try standard JSON
|
||||
return json.loads(clean_text)
|
||||
except Exception as e1:
|
||||
# 3. Try ast.literal_eval for single quotes
|
||||
try:
|
||||
return ast.literal_eval(clean_text)
|
||||
except Exception as e2:
|
||||
# 4. Fail
|
||||
raise ValueError(f"Could not parse JSON: {e1} | {e2} | Content: {text[:100]}...")
|
||||
|
||||
def translate_to_sam3_prompt(text: str) -> str:
|
||||
"""
|
||||
使用 Qwen 模型将中文提示词翻译为英文
|
||||
@@ -566,13 +597,13 @@ def recognize_card_with_qwen(image_path: str) -> dict:
|
||||
|
||||
if response.status_code == 200:
|
||||
content = response.output.choices[0].message.content[0]['text']
|
||||
import json
|
||||
try:
|
||||
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||
result = json.loads(clean_content)
|
||||
result = extract_json_from_response(content)
|
||||
result["model_used"] = QWEN_MODEL
|
||||
return result
|
||||
except:
|
||||
return {"raw_response": content}
|
||||
except Exception as e:
|
||||
print(f"JSON Parse Error in recognize_card: {e}")
|
||||
return {"raw_response": content, "error": str(e), "model_used": QWEN_MODEL}
|
||||
else:
|
||||
return {"error": f"API Error: {response.code} - {response.message}"}
|
||||
|
||||
@@ -601,13 +632,13 @@ def recognize_spread_with_qwen(image_path: str) -> dict:
|
||||
|
||||
if response.status_code == 200:
|
||||
content = response.output.choices[0].message.content[0]['text']
|
||||
import json
|
||||
try:
|
||||
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||
result = json.loads(clean_content)
|
||||
result = extract_json_from_response(content)
|
||||
result["model_used"] = QWEN_MODEL
|
||||
return result
|
||||
except:
|
||||
return {"raw_response": content, "spread_name": "Unknown"}
|
||||
except Exception as e:
|
||||
print(f"JSON Parse Error in recognize_spread: {e}")
|
||||
return {"raw_response": content, "error": str(e), "spread_name": "Unknown", "model_used": QWEN_MODEL}
|
||||
else:
|
||||
return {"error": f"API Error: {response.code} - {response.message}"}
|
||||
|
||||
@@ -950,6 +981,10 @@ async def recognize_tarot(
|
||||
processor = request.app.state.processor
|
||||
|
||||
try:
|
||||
# 在执行 GPU 操作前,切换到线程中运行,避免阻塞主线程(虽然 SAM3 推理在 CPU 上可能已经很快,但为了保险)
|
||||
# 注意:processor 内部调用了 torch,如果是在 GPU 上,最好不要多线程调用同一个 model
|
||||
# 但这里只是推理,且是单次请求。
|
||||
# 如果是 CPU 推理,run_in_executor 有助于防止阻塞 loop
|
||||
inference_state = processor.set_image(image)
|
||||
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
|
||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||
@@ -974,15 +1009,25 @@ async def recognize_tarot(
|
||||
main_file_path = None
|
||||
main_file_url = None
|
||||
|
||||
# Step 0: 牌阵识别
|
||||
# Step 0: 牌阵识别 (异步启动)
|
||||
spread_info = {"spread_name": "Unknown"}
|
||||
spread_task = None
|
||||
if main_file_path:
|
||||
# 使用原始图的一份拷贝给 Qwen 识别牌阵
|
||||
temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
|
||||
image.save(temp_raw_path)
|
||||
spread_info = recognize_spread_with_qwen(temp_raw_path)
|
||||
|
||||
# 将同步调用包装为异步任务
|
||||
loop = asyncio.get_event_loop()
|
||||
spread_task = loop.run_in_executor(None, recognize_spread_with_qwen, temp_raw_path)
|
||||
|
||||
if detected_count != expected_count:
|
||||
# 如果数量不对,等待牌阵识别完成(如果已启动)再返回
|
||||
if spread_task:
|
||||
try:
|
||||
spread_info = await spread_task
|
||||
except Exception as e:
|
||||
print(f"Spread recognition failed: {e}")
|
||||
|
||||
duration = time.time() - start_time
|
||||
append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", result_path=f"results/{request_id}/{main_filename}" if main_file_url else None, details=f"Detected {detected_count}, expected {expected_count}", duration=duration)
|
||||
return JSONResponse(
|
||||
@@ -1004,21 +1049,47 @@ async def recognize_tarot(
|
||||
append_to_history("tarot-recognize", f"expected: {expected_count}", "failed", details=f"Crop Error: {str(e)}", duration=duration)
|
||||
raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
|
||||
|
||||
# 遍历每张卡片进行识别
|
||||
# 遍历每张卡片进行识别 (并发)
|
||||
tarot_cards = []
|
||||
|
||||
# 1. 准备任务列表
|
||||
loop = asyncio.get_event_loop()
|
||||
card_tasks = []
|
||||
|
||||
for obj in saved_objects:
|
||||
fname = obj["filename"]
|
||||
file_path = os.path.join(output_dir, fname)
|
||||
|
||||
# 调用 Qwen-VL 识别 (串行)
|
||||
recognition_res = recognize_card_with_qwen(file_path)
|
||||
|
||||
# 创建异步任务
|
||||
# 使用 lambda 来延迟调用,确保参数传递正确
|
||||
task = loop.run_in_executor(None, recognize_card_with_qwen, file_path)
|
||||
card_tasks.append(task)
|
||||
|
||||
# 2. 等待所有卡片识别任务完成
|
||||
# 同时等待牌阵识别任务 (如果还在运行)
|
||||
if card_tasks:
|
||||
all_card_results = await asyncio.gather(*card_tasks)
|
||||
else:
|
||||
all_card_results = []
|
||||
|
||||
if spread_task:
|
||||
try:
|
||||
# 如果之前没有await spread_task,这里确保它完成
|
||||
# 注意:如果 detected_count != expected_count 分支已经 await 过了,这里不会重复执行
|
||||
# 但那个分支有 return,所以这里肯定是还没 await 的
|
||||
spread_info = await spread_task
|
||||
except Exception as e:
|
||||
print(f"Spread recognition failed: {e}")
|
||||
|
||||
# 3. 组装结果
|
||||
for i, obj in enumerate(saved_objects):
|
||||
fname = obj["filename"]
|
||||
file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
|
||||
|
||||
tarot_cards.append({
|
||||
"url": file_url,
|
||||
"is_rotated": obj["is_rotated_by_algorithm"],
|
||||
"orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
|
||||
"recognition": recognition_res,
|
||||
"recognition": all_card_results[i],
|
||||
"note": obj["note"]
|
||||
})
|
||||
|
||||
@@ -1082,14 +1153,26 @@ async def segment_face(
|
||||
|
||||
# 调用独立服务进行处理
|
||||
try:
|
||||
result = human_analysis_service.process_face_segmentation_and_analysis(
|
||||
processor=processor,
|
||||
image=image,
|
||||
prompt=final_prompt,
|
||||
output_base_dir=RESULT_IMAGE_DIR,
|
||||
qwen_model=QWEN_MODEL,
|
||||
analysis_prompt=PROMPTS["face_analysis"]
|
||||
)
|
||||
# 使用新增加的异步并发函数
|
||||
if hasattr(human_analysis_service, "process_face_segmentation_and_analysis_async"):
|
||||
result = await human_analysis_service.process_face_segmentation_and_analysis_async(
|
||||
processor=processor,
|
||||
image=image,
|
||||
prompt=final_prompt,
|
||||
output_base_dir=RESULT_IMAGE_DIR,
|
||||
qwen_model=QWEN_MODEL,
|
||||
analysis_prompt=PROMPTS["face_analysis"]
|
||||
)
|
||||
else:
|
||||
# 回退到同步
|
||||
result = human_analysis_service.process_face_segmentation_and_analysis(
|
||||
processor=processor,
|
||||
image=image,
|
||||
prompt=final_prompt,
|
||||
output_base_dir=RESULT_IMAGE_DIR,
|
||||
qwen_model=QWEN_MODEL,
|
||||
analysis_prompt=PROMPTS["face_analysis"]
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
@@ -1287,12 +1370,38 @@ async def get_config(request: Request):
|
||||
"""
|
||||
Get system config info
|
||||
"""
|
||||
device = "Unknown"
|
||||
device_str = "Unknown"
|
||||
gpu_status = {}
|
||||
|
||||
if hasattr(request.app.state, "device"):
|
||||
device = str(request.app.state.device)
|
||||
device_str = str(request.app.state.device)
|
||||
|
||||
# 获取 GPU 详细信息
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
device_id = torch.cuda.current_device()
|
||||
props = torch.cuda.get_device_properties(device_id)
|
||||
|
||||
total_mem = props.total_memory
|
||||
reserved_mem = torch.cuda.memory_reserved(device_id)
|
||||
allocated_mem = torch.cuda.memory_allocated(device_id)
|
||||
|
||||
gpu_status = {
|
||||
"available": True,
|
||||
"name": props.name,
|
||||
"total_memory": f"{total_mem / 1024**3:.2f} GB",
|
||||
"reserved_memory": f"{reserved_mem / 1024**3:.2f} GB",
|
||||
"allocated_memory": f"{allocated_mem / 1024**3:.2f} GB",
|
||||
"memory_usage_percent": round((reserved_mem / total_mem) * 100, 1)
|
||||
}
|
||||
except Exception as e:
|
||||
gpu_status = {"available": True, "error": str(e)}
|
||||
else:
|
||||
gpu_status = {"available": False, "reason": "No CUDA device detected"}
|
||||
|
||||
return {
|
||||
"device": device,
|
||||
"device": device_str,
|
||||
"gpu_status": gpu_status,
|
||||
"cleanup_config": CLEANUP_CONFIG,
|
||||
"current_qwen_model": QWEN_MODEL,
|
||||
"available_qwen_models": AVAILABLE_QWEN_MODELS
|
||||
@@ -1348,6 +1457,93 @@ async def update_prompts(
|
||||
PROMPTS[key] = content
|
||||
return {"status": "success", "message": f"Prompt '{key}' updated"}
|
||||
|
||||
# ------------------------------------------
|
||||
# GPU Status Helper & API
|
||||
# ------------------------------------------
|
||||
|
||||
def get_gpu_status_smi():
|
||||
"""
|
||||
Get detailed GPU status using nvidia-smi
|
||||
Returns: dict with utilization, memory, temp, power, etc.
|
||||
"""
|
||||
cuda_version = "Unknown"
|
||||
try:
|
||||
import torch
|
||||
if torch.version.cuda:
|
||||
cuda_version = torch.version.cuda
|
||||
except:
|
||||
pass
|
||||
|
||||
try:
|
||||
# Check if nvidia-smi is available
|
||||
# Fields: utilization.gpu, utilization.memory, temperature.gpu, power.draw, power.limit, memory.total, memory.used, memory.free, name, driver_version
|
||||
result = subprocess.run(
|
||||
['nvidia-smi', '--query-gpu=utilization.gpu,utilization.memory,temperature.gpu,power.draw,power.limit,memory.total,memory.used,memory.free,name,driver_version', '--format=csv,noheader,nounits'],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
encoding='utf-8'
|
||||
)
|
||||
|
||||
if result.returncode != 0:
|
||||
raise Exception("nvidia-smi failed")
|
||||
|
||||
# Parse the first line (assuming single GPU for now, or take the first one)
|
||||
line = result.stdout.strip().split('\n')[0]
|
||||
vals = [x.strip() for x in line.split(',')]
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"gpu_util": float(vals[0]), # %
|
||||
"mem_util": float(vals[1]), # % (controller utilization)
|
||||
"temperature": float(vals[2]), # C
|
||||
"power_draw": float(vals[3]), # W
|
||||
"power_limit": float(vals[4]), # W
|
||||
"mem_total": float(vals[5]), # MB
|
||||
"mem_used": float(vals[6]), # MB
|
||||
"mem_free": float(vals[7]), # MB
|
||||
"name": vals[8],
|
||||
"driver_version": vals[9],
|
||||
"cuda_version": cuda_version,
|
||||
"source": "nvidia-smi",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
except Exception as e:
|
||||
# Fallback to torch if available
|
||||
if torch.cuda.is_available():
|
||||
try:
|
||||
device_id = torch.cuda.current_device()
|
||||
props = torch.cuda.get_device_properties(device_id)
|
||||
mem_reserved = torch.cuda.memory_reserved(device_id) / 1024**2 # MB
|
||||
mem_total = props.total_memory / 1024**2 # MB
|
||||
|
||||
return {
|
||||
"available": True,
|
||||
"gpu_util": 0, # Torch can't get this easily
|
||||
"mem_util": (mem_reserved / mem_total) * 100,
|
||||
"temperature": 0,
|
||||
"power_draw": 0,
|
||||
"power_limit": 0,
|
||||
"mem_total": mem_total,
|
||||
"mem_used": mem_reserved,
|
||||
"mem_free": mem_total - mem_reserved,
|
||||
"name": props.name,
|
||||
"driver_version": "Unknown",
|
||||
"cuda_version": cuda_version,
|
||||
"source": "torch",
|
||||
"timestamp": time.time()
|
||||
}
|
||||
except:
|
||||
pass
|
||||
|
||||
return {"available": False, "error": str(e)}
|
||||
|
||||
@app.get("/admin/api/gpu/status", dependencies=[Depends(verify_admin)])
|
||||
async def get_gpu_status_api():
|
||||
"""
|
||||
Get real-time GPU status
|
||||
"""
|
||||
return get_gpu_status_smi()
|
||||
|
||||
# ==========================================
|
||||
# 10. Main Entry Point (启动入口)
|
||||
# ==========================================
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
{"timestamp": 1771347621.2198663, "type": "general", "prompt": "正面的麻将牌", "final_prompt": "Front-facing mahjong tile", "status": "success", "result_path": "results/seg_72b3c186467d48bf8591c9699ce90ca7.jpg", "details": "Detected: 13", "duration": 2.699465274810791}
|
||||
|
||||
@@ -6,6 +6,8 @@ import numpy as np
|
||||
import json
|
||||
import torch
|
||||
import cv2
|
||||
import ast
|
||||
import re
|
||||
from PIL import Image
|
||||
from dashscope import MultiModalConversation
|
||||
|
||||
@@ -95,6 +97,35 @@ def create_highlighted_visualization(image: Image.Image, masks, output_path: str
|
||||
# Save
|
||||
Image.fromarray(result_np).save(output_path)
|
||||
|
||||
def extract_json_from_response(text: str) -> dict:
|
||||
"""
|
||||
Robustly extract JSON from text, handling:
|
||||
1. Markdown code blocks (```json ... ```)
|
||||
2. Single quotes (Python dict style) via ast.literal_eval
|
||||
"""
|
||||
try:
|
||||
# 1. Try to find JSON block
|
||||
json_match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
|
||||
if json_match:
|
||||
clean_text = json_match.group(1).strip()
|
||||
else:
|
||||
# Try to find { ... } block if no markdown
|
||||
match = re.search(r'\{.*\}', text, re.DOTALL)
|
||||
if match:
|
||||
clean_text = match.group(0).strip()
|
||||
else:
|
||||
clean_text = text.strip()
|
||||
|
||||
# 2. Try standard JSON
|
||||
return json.loads(clean_text)
|
||||
except Exception as e1:
|
||||
# 3. Try ast.literal_eval for single quotes
|
||||
try:
|
||||
return ast.literal_eval(clean_text)
|
||||
except Exception as e2:
|
||||
# 4. Fail
|
||||
raise ValueError(f"Could not parse JSON: {e1} | {e2} | Content: {text[:100]}...")
|
||||
|
||||
def analyze_demographics_with_qwen(image_path: str, model_name: str = 'qwen-vl-max', prompt_template: str = None) -> dict:
|
||||
"""
|
||||
调用 Qwen-VL 模型分析人物的年龄和性别
|
||||
@@ -131,19 +162,21 @@ def analyze_demographics_with_qwen(image_path: str, model_name: str = 'qwen-vl-m
|
||||
|
||||
if response.status_code == 200:
|
||||
content = response.output.choices[0].message.content[0]['text']
|
||||
# 清理 Markdown 代码块标记
|
||||
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||
try:
|
||||
result = json.loads(clean_content)
|
||||
result = extract_json_from_response(content)
|
||||
result["model_used"] = model_name
|
||||
return result
|
||||
except json.JSONDecodeError:
|
||||
return {"raw_analysis": clean_content}
|
||||
except Exception as e:
|
||||
print(f"JSON Parse Error in face analysis: {e}")
|
||||
return {"raw_analysis": content, "error": str(e), "model_used": model_name}
|
||||
else:
|
||||
return {"error": f"API Error: {response.code} - {response.message}"}
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"分析失败: {str(e)}"}
|
||||
|
||||
import asyncio
|
||||
|
||||
def process_face_segmentation_and_analysis(
|
||||
processor,
|
||||
image: Image.Image,
|
||||
@@ -156,11 +189,11 @@ def process_face_segmentation_and_analysis(
|
||||
核心处理逻辑:
|
||||
1. SAM3 分割 (默认提示词 "head" 以包含头发)
|
||||
2. 裁剪图片
|
||||
3. Qwen-VL 识别性别年龄
|
||||
3. Qwen-VL 识别性别年龄 (并发)
|
||||
4. 返回结果
|
||||
"""
|
||||
|
||||
# 1. SAM3 推理
|
||||
# 1. SAM3 推理 (同步,因为涉及 GPU 操作)
|
||||
inference_state = processor.set_image(image)
|
||||
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
|
||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||
@@ -179,7 +212,7 @@ def process_face_segmentation_and_analysis(
|
||||
output_dir = os.path.join(output_base_dir, request_id)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
# --- 新增:生成背景变暗的整体可视化图 ---
|
||||
# --- 生成可视化图 ---
|
||||
vis_filename = f"seg_{uuid.uuid4().hex}.jpg"
|
||||
vis_path = os.path.join(output_dir, vis_filename)
|
||||
try:
|
||||
@@ -188,38 +221,238 @@ def process_face_segmentation_and_analysis(
|
||||
except Exception as e:
|
||||
print(f"可视化生成失败: {e}")
|
||||
full_vis_relative_path = None
|
||||
# -------------------------------------
|
||||
# ------------------
|
||||
|
||||
results = []
|
||||
|
||||
# 转换 boxes 为 numpy
|
||||
# 转换 boxes 和 scores
|
||||
if isinstance(boxes, torch.Tensor):
|
||||
boxes_np = boxes.cpu().numpy()
|
||||
else:
|
||||
boxes_np = boxes
|
||||
|
||||
# 转换 scores 为 list
|
||||
if isinstance(scores, torch.Tensor):
|
||||
scores_list = scores.tolist()
|
||||
else:
|
||||
scores_list = scores if isinstance(scores, list) else [float(scores)]
|
||||
|
||||
for i, box in enumerate(boxes_np):
|
||||
# 2. 裁剪 (带一点 padding 以保留完整发型)
|
||||
# 2. 裁剪 (带一点 padding 以保留完整发型)
|
||||
cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1)
|
||||
# 准备异步任务
|
||||
async def run_analysis_tasks():
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = []
|
||||
temp_results = [] # 存储 (index, filename, score) 以便后续排序组合
|
||||
|
||||
for i, box in enumerate(boxes_np):
|
||||
# 2. 裁剪 (同步)
|
||||
cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1)
|
||||
filename = f"face_{i}.jpg"
|
||||
save_path = os.path.join(output_dir, filename)
|
||||
cropped_img.save(save_path)
|
||||
|
||||
# 3. 准备识别任务
|
||||
task = loop.run_in_executor(
|
||||
None,
|
||||
analyze_demographics_with_qwen,
|
||||
save_path,
|
||||
qwen_model,
|
||||
analysis_prompt
|
||||
)
|
||||
tasks.append(task)
|
||||
temp_results.append({
|
||||
"filename": filename,
|
||||
"relative_path": f"results/{request_id}/{filename}",
|
||||
"score": float(scores_list[i]) if i < len(scores_list) else 0.0
|
||||
})
|
||||
|
||||
# 等待所有任务完成
|
||||
if tasks:
|
||||
analysis_results = await asyncio.gather(*tasks)
|
||||
else:
|
||||
analysis_results = []
|
||||
|
||||
# 组合结果
|
||||
final_results = []
|
||||
for i, item in enumerate(temp_results):
|
||||
item["analysis"] = analysis_results[i]
|
||||
final_results.append(item)
|
||||
|
||||
return final_results
|
||||
|
||||
# 运行异步任务
|
||||
# 注意:由于本函数被 FastAPI (异步环境) 中的同步或异步函数调用,
|
||||
# 如果上层是 async def,我们可以直接 await。
|
||||
# 但由于这个函数定义没有 async,且之前的调用是同步的,
|
||||
# 为了兼容性,我们需要检查当前是否在事件循环中。
|
||||
|
||||
# 然而,查看 fastAPI_tarot.py,这个函数是在 async def segment_face 中被调用的。
|
||||
# 但它是作为普通函数被导入和调用的。
|
||||
# 为了不破坏现有签名,我们可以使用 asyncio.run() 或者在新循环中运行,
|
||||
# 但这在已经运行的 loop 中是不允许的。
|
||||
|
||||
# 最佳方案:修改本函数为 async,并在 fastAPI_tarot.py 中 await 它。
|
||||
# 但这需要修改 fastAPI_tarot.py 的调用处。
|
||||
|
||||
# 既然我们已经修改了 fastAPI_tarot.py,我们也可以顺便修改这里的签名。
|
||||
# 但为了稳妥,我们可以用一种 hack:
|
||||
# 如果在一个正在运行的 loop 中调用,我们必须返回 awaitable 或者使用 loop.run_until_complete (会报错)
|
||||
|
||||
# 让我们先把这个函数改成 async,然后去修改 fastAPI_tarot.py 的调用。
|
||||
# 这是最正确的做法。
|
||||
pass # 占位,实际代码在下面
|
||||
|
||||
async def process_face_segmentation_and_analysis_async(
|
||||
processor,
|
||||
image: Image.Image,
|
||||
prompt: str = "head",
|
||||
output_base_dir: str = "static/results",
|
||||
qwen_model: str = "qwen-vl-max",
|
||||
analysis_prompt: str = None
|
||||
) -> dict:
|
||||
# ... (同上逻辑,只是是 async)
|
||||
|
||||
# 1. SAM3 推理
|
||||
inference_state = processor.set_image(image)
|
||||
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
|
||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||
|
||||
detected_count = len(masks)
|
||||
if detected_count == 0:
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "未检测到目标",
|
||||
"detected_count": 0,
|
||||
"results": []
|
||||
}
|
||||
|
||||
# 保存裁剪图
|
||||
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
output_dir = os.path.join(output_base_dir, request_id)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
vis_filename = f"seg_{uuid.uuid4().hex}.jpg"
|
||||
vis_path = os.path.join(output_dir, vis_filename)
|
||||
try:
|
||||
create_highlighted_visualization(image, masks, vis_path)
|
||||
full_vis_relative_path = f"results/{request_id}/{vis_filename}"
|
||||
except Exception as e:
|
||||
print(f"可视化生成失败: {e}")
|
||||
full_vis_relative_path = None
|
||||
|
||||
if isinstance(boxes, torch.Tensor):
|
||||
boxes_np = boxes.cpu().numpy()
|
||||
else:
|
||||
boxes_np = boxes
|
||||
|
||||
if isinstance(scores, torch.Tensor):
|
||||
scores_list = scores.tolist()
|
||||
else:
|
||||
scores_list = scores if isinstance(scores, list) else [float(scores)]
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
tasks = []
|
||||
results = []
|
||||
|
||||
for i, box in enumerate(boxes_np):
|
||||
cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1)
|
||||
filename = f"face_{i}.jpg"
|
||||
save_path = os.path.join(output_dir, filename)
|
||||
cropped_img.save(save_path)
|
||||
|
||||
# 3. 识别
|
||||
task = loop.run_in_executor(
|
||||
None,
|
||||
analyze_demographics_with_qwen,
|
||||
save_path,
|
||||
qwen_model,
|
||||
analysis_prompt
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
results.append({
|
||||
"filename": filename,
|
||||
"relative_path": f"results/{request_id}/{filename}",
|
||||
"score": float(scores_list[i]) if i < len(scores_list) else 0.0
|
||||
})
|
||||
|
||||
if tasks:
|
||||
analysis_results = await asyncio.gather(*tasks)
|
||||
else:
|
||||
analysis_results = []
|
||||
|
||||
for i, item in enumerate(results):
|
||||
item["analysis"] = analysis_results[i]
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"成功检测并分析 {detected_count} 个人脸",
|
||||
"detected_count": detected_count,
|
||||
"request_id": request_id,
|
||||
"full_visualization": full_vis_relative_path,
|
||||
"scores": scores_list,
|
||||
"results": results
|
||||
}
|
||||
|
||||
# 保留旧的同步接口以兼容其他潜在调用者,但内部实现可能会有问题如果它在 loop 中运行
|
||||
# 既然我们主要关注 fastAPI_tarot.py,我们可以直接替换 process_face_segmentation_and_analysis
|
||||
# 或者让它只是一个 wrapper
|
||||
def process_face_segmentation_and_analysis(
|
||||
processor,
|
||||
image: Image.Image,
|
||||
prompt: str = "head",
|
||||
output_base_dir: str = "static/results",
|
||||
qwen_model: str = "qwen-vl-max",
|
||||
analysis_prompt: str = None
|
||||
) -> dict:
|
||||
"""
|
||||
同步版本 (保留以兼容)
|
||||
注意:如果在 async loop 中调用此函数,且此函数内部没有异步操作,则会阻塞 loop。
|
||||
如果需要异步并发,请使用 process_face_segmentation_and_analysis_async
|
||||
"""
|
||||
# 这里我们简单地复用逻辑,但去除异步部分,退化为串行
|
||||
|
||||
# 1. SAM3 推理
|
||||
inference_state = processor.set_image(image)
|
||||
output = processor.set_text_prompt(state=inference_state, prompt=prompt)
|
||||
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||
|
||||
detected_count = len(masks)
|
||||
if detected_count == 0:
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "未检测到目标",
|
||||
"detected_count": 0,
|
||||
"results": []
|
||||
}
|
||||
|
||||
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
output_dir = os.path.join(output_base_dir, request_id)
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
vis_filename = f"seg_{uuid.uuid4().hex}.jpg"
|
||||
vis_path = os.path.join(output_dir, vis_filename)
|
||||
try:
|
||||
create_highlighted_visualization(image, masks, vis_path)
|
||||
full_vis_relative_path = f"results/{request_id}/{vis_filename}"
|
||||
except Exception as e:
|
||||
print(f"可视化生成失败: {e}")
|
||||
full_vis_relative_path = None
|
||||
|
||||
if isinstance(boxes, torch.Tensor):
|
||||
boxes_np = boxes.cpu().numpy()
|
||||
else:
|
||||
boxes_np = boxes
|
||||
|
||||
if isinstance(scores, torch.Tensor):
|
||||
scores_list = scores.tolist()
|
||||
else:
|
||||
scores_list = scores if isinstance(scores, list) else [float(scores)]
|
||||
|
||||
results = []
|
||||
for i, box in enumerate(boxes_np):
|
||||
cropped_img = crop_head_with_padding(image, box, padding_ratio=0.1)
|
||||
filename = f"face_{i}.jpg"
|
||||
save_path = os.path.join(output_dir, filename)
|
||||
cropped_img.save(save_path)
|
||||
|
||||
# 同步调用
|
||||
analysis = analyze_demographics_with_qwen(save_path, model_name=qwen_model, prompt_template=analysis_prompt)
|
||||
|
||||
# 构造返回结果
|
||||
# 注意:URL 生成需要依赖外部的 request context,这里只返回相对路径或文件名
|
||||
# 由调用方组装完整 URL
|
||||
results.append({
|
||||
"filename": filename,
|
||||
"relative_path": f"results/{request_id}/{filename}",
|
||||
@@ -232,7 +465,8 @@ def process_face_segmentation_and_analysis(
|
||||
"message": f"成功检测并分析 {detected_count} 个人脸",
|
||||
"detected_count": detected_count,
|
||||
"request_id": request_id,
|
||||
"full_visualization": full_vis_relative_path, # 返回相对路径
|
||||
"scores": scores_list, # 返回全部分数
|
||||
"full_visualization": full_vis_relative_path,
|
||||
"scores": scores_list,
|
||||
"results": results
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
1411
static/admin.html
1411
static/admin.html
File diff suppressed because it is too large
Load Diff
Binary file not shown.
|
Before Width: | Height: | Size: 78 KiB |
Reference in New Issue
Block a user