FastAPI
This commit is contained in:
159
fastAPI_main.py
Normal file
159
fastAPI_main.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
import os
import uuid
from contextlib import asynccontextmanager
from typing import Optional

import requests
import torch

import matplotlib

# Non-interactive backend: required on headless servers where no display
# is available (must be set before pyplot is imported).
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from PIL import Image

# SAM3 imports (the sam3 package must be installed in this environment).
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_results

# ------------------- Paths & configuration -------------------
STATIC_DIR = "static"
RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results")
os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
|
||||||
|
|
||||||
|
# ------------------- 生命周期管理 -------------------
|
||||||
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan manager: load the SAM3 model once at startup, clean up on shutdown.

    The model, processor and device are stored on ``app.state`` so every
    request handler shares the single loaded instance.
    """
    print("正在加载 SAM3 模型到 GPU...")

    # 1. Pick the device: GPU when available, otherwise fall back to CPU.
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_ok else "cpu")
    if not cuda_ok:
        print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。")

    # 2. Build the global singleton model and switch to inference mode.
    model = build_sam3_image_model().to(device)
    model.eval()

    # 3-4. Wrap the model in a processor and expose everything via app.state.
    app.state.model = model
    app.state.processor = Sam3Processor(model)
    app.state.device = device

    print(f"模型加载完成,设备: {device}")

    yield  # application is serving requests here

    # Shutdown-side cleanup hook (nothing to free explicitly for now).
    print("正在清理资源...")
|
||||||
|
|
||||||
|
# ------------------- FastAPI application setup -------------------
app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API")

# Serve the generated result images over HTTP via the /static mount.
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
|
||||||
|
|
||||||
|
# ------------------- 辅助函数 -------------------
|
||||||
|
def load_image_from_url(url: str) -> Image.Image:
    """Download an image from a URL and return it as an RGB PIL image.

    Raises:
        HTTPException: 400 when the download or decoding fails.
    """
    from io import BytesIO  # local import: only needed for URL-based inputs

    try:
        # Some hosts reject requests without a browser-like User-Agent.
        headers = {'User-Agent': 'Mozilla/5.0'}
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        # BUG FIX: the old code fed ``response.raw`` to PIL. That stream
        # bypasses requests' Content-Encoding handling, so gzip-compressed
        # bodies arrive as compressed bytes and PIL fails to decode them.
        # ``response.content`` is the fully decoded body.
        image = Image.open(BytesIO(response.content)).convert("RGB")
        return image
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") from e
|
||||||
|
|
||||||
|
def generate_and_save_result(image: Image.Image, inference_state) -> str:
    """Render the segmentation overlay and save it as a JPEG.

    Args:
        image: The original input image.
        inference_state: SAM3 inference state holding the masks/boxes to draw.

    Returns:
        The generated file name (relative to RESULT_IMAGE_DIR).
    """
    # Unique name so concurrent requests never overwrite each other.
    filename = f"seg_{uuid.uuid4().hex}.jpg"
    save_path = os.path.join(RESULT_IMAGE_DIR, filename)

    # plot_results draws onto the current matplotlib figure.
    plot_results(image, inference_state)

    try:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    finally:
        # BUG FIX: close in a finally block — if savefig raises, the old
        # code leaked the figure (memory grows with every failed request).
        plt.close()

    return filename
|
||||||
|
|
||||||
|
# ------------------- API 接口 -------------------
|
||||||
|
@app.post("/segment")
async def segment(
    request: Request,
    prompt: str = Form(...),
    file: Optional[UploadFile] = File(None),
    image_url: Optional[str] = Form(None)
):
    """
    Accept an image (file upload OR URL) plus a text prompt; run SAM3
    segmentation and return the URL of the rendered result image.
    """
    # 1. Validate input: at least one image source is required.
    if not file and not image_url:
        raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")

    # 2. Load the image (upload takes precedence over URL).
    try:
        if file:
            image = Image.open(file.file).convert("RGB")
        else:
            image = load_image_from_url(image_url)
    except HTTPException:
        # BUG FIX: load_image_from_url already raises a well-formed 400;
        # the old broad except re-wrapped it as "图片解析失败: 400: ...".
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")

    # 3. Fetch the shared processor created during lifespan startup.
    processor = request.app.state.processor

    # 4. Run inference.
    try:
        # Sam3Processor handles moving tensors to the right device internally.
        inference_state = processor.set_image(image)
        output = processor.set_text_prompt(state=inference_state, prompt=prompt)

        masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")

    # 5. Render the visualization and persist it under static/results/.
    try:
        filename = generate_and_save_result(image, inference_state)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}")

    # 6. Build the public URL for the saved file.
    # BUG FIX: the path must interpolate the generated filename — the old
    # code emitted a literal placeholder, so every returned URL 404'd.
    # request.url_for derives the correct host/scheme from the request.
    file_url = request.url_for("static", path=f"results/{filename}")

    return JSONResponse(content={
        "status": "success",
        "result_image_url": str(file_url),
        "detected_count": len(masks),
        "scores": scores.tolist() if torch.is_tensor(scores) else scores
    })
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import uvicorn

    # Configuration passed as function arguments instead of CLI flags.
    uvicorn.run(
        "fastAPI_main:app",  # module name : app object — adjust if the file is renamed
        host="0.0.0.0",
        port=55600,
        proxy_headers=True,       # equivalent of --proxy-headers
        forwarded_allow_ips="*",  # equivalent of --forwarded-allow-ips="*"
    )
|
||||||
3
requirement.txt
Normal file
3
requirement.txt
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
fastapi
uvicorn
python-multipart
requests
pillow
torch
matplotlib
|
||||||
@@ -875,6 +875,8 @@ def plot_results(img, results):
|
|||||||
relative_coords=False,
|
relative_coords=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
def single_visualization(img, anns, title):
|
def single_visualization(img, anns, title):
|
||||||
"""
|
"""
|
||||||
|
|||||||
BIN
sam3_image_food_result.jpg
Normal file
BIN
sam3_image_food_result.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 84 KiB |
190
test.py
190
test.py
@@ -1,39 +1,165 @@
|
|||||||
import os

# ==================== GPU memory configuration ====================
# BUG FIX: these environment variables must be set BEFORE anything can
# initialize CUDA — setting CUDA_VISIBLE_DEVICES after `import torch`
# (as the old code did) may have no effect once the CUDA context exists.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch

#################################### For Image ####################################
import matplotlib.pyplot as plt
import cv2
import numpy as np
from PIL import Image

# Only the core modules that actually exist in SAM3.
from sam3.model_builder import build_sam3_image_model, build_sam3_video_predictor
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results

torch.cuda.empty_cache()
||||||
#################################### For Video ####################################
|
# 通用视频帧读取函数
|
||||||
|
def get_video_frames(video_path):
    """Return ``(frame_paths, total_frames)`` for a video file or frame folder.

    For a video file (.mp4/.avi/.mov) ``frame_paths`` is empty and
    ``total_frames`` comes from the container metadata. For a directory,
    ``frame_paths`` lists the .jpg/.png files in sorted order.

    Raises:
        ValueError: when the path is neither a supported video file nor a folder.
    """
    frame_paths = []
    if os.path.isfile(video_path) and video_path.endswith(('.mp4', '.avi', '.mov')):
        # BUG FIX: the old one-liner `cv2.VideoCapture(video_path).get(...)`
        # leaked the capture handle; open and release it explicitly.
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            cap.release()
        return frame_paths, total_frames
    elif os.path.isdir(video_path):
        frame_files = sorted([f for f in os.listdir(video_path) if f.endswith(('.jpg', '.png'))])
        frame_paths = [os.path.join(video_path, f) for f in frame_files]
        total_frames = len(frame_paths)
        return frame_paths, total_frames
    else:
        raise ValueError(f"不支持的视频路径:{video_path}")
|
||||||
|
|
||||||
# from sam3.model_builder import build_sam3_video_predictor
|
# 缩小视频帧分辨率(减少显存占用)
|
||||||
|
def resize_frame(frame, max_side=640):
    """Downscale *frame* so its longest side is at most *max_side* pixels.

    Frames that are already small enough are returned unchanged — the
    function only ever downscales, never upscales.
    """
    height, width = frame.shape[:2]
    ratio = max_side / max(height, width)
    if ratio >= 1:
        # Already within bounds; no resampling needed.
        return frame
    new_size = (int(width * ratio), int(height * ratio))  # cv2 expects (w, h)
    return cv2.resize(frame, new_size, interpolation=cv2.INTER_AREA)
|
||||||
|
|
||||||
# video_predictor = build_sam3_video_predictor()
|
def sam3_video_inference_low_memory(
    video_path,
    text_prompt="",  # empty prompt: let SAM3 detect every object (maximizes recall)
    save_dir="./sam3_video_results",
    max_frames=5,
    max_frame_side=640
):
    """
    Low-VRAM SAM3 video inference: process the first frame of *video_path*
    with *text_prompt*, save an annotated image into *save_dir*, and return
    ``(session_id, first_boxes)``.

    This version avoids ``end_session`` (not supported by this SAM3 build)
    and keeps thresholds very low to maximize detections.
    """
    # 1. Initialize the SAM3 video predictor and output directory.
    video_predictor = build_sam3_video_predictor()
    os.makedirs(save_dir, exist_ok=True)

    # 2. Start a SAM3 session. No custom model_config is passed: this SAM3
    # version force-overrides it (max_num_objects=10000) and a custom one
    # would conflict with the internals.
    session_response = video_predictor.handle_request(
        request=dict(
            type="start_session",
            resource_path=video_path
        )
    )
    session_id = session_response["session_id"]
    print(f"[低显存模式] 会话启动成功,ID: {session_id}")

    # 3. Gather frame information for the input video/folder.
    frame_paths, total_frames = get_video_frames(video_path)
    total_frames = min(total_frames, max_frames)
    print(f"[低显存模式] 视频总帧数:{total_frames},本次处理前{total_frames}帧")

    # 4. Run inference on the first frame. Very low thresholds (0.1) force
    # detection of every plausible object.
    first_frame_idx = 0
    prompt_response = video_predictor.handle_request(
        request=dict(
            type="add_prompt",
            session_id=session_id,
            frame_index=first_frame_idx,
            text=text_prompt,
            prompt_config=dict(
                box_threshold=0.1,
                mask_threshold=0.1
            )
        )
    )
    first_output = prompt_response["outputs"]
    first_boxes = first_output.get("boxes", [])
    first_scores = first_output.get("scores", [])
    print(f"[低显存模式] 首帧({first_frame_idx})检出框数:{len(first_boxes)},分数:{first_scores}")

    # 5. Load the first frame's pixels (video file via VideoCapture,
    # frame folder via imread).
    frame_save_paths = []
    cap = cv2.VideoCapture(video_path) if os.path.isfile(video_path) else None

    frame_idx = first_frame_idx
    if cap is not None:
        cap.set(cv2.CAP_PROP_POS_FRAMES, frame_idx)
        ret, frame = cap.read()
        if not ret:
            print(f"[警告] 无法读取帧{frame_idx}")
            cap.release()
            return session_id, first_boxes
    else:
        frame = cv2.imread(frame_paths[frame_idx])
        if frame is None:
            # ROBUSTNESS FIX: imread returns None (instead of raising) on
            # unreadable files; the old code would crash in resize_frame.
            print(f"[警告] 无法读取帧{frame_idx}")
            return session_id, first_boxes

    # Shrink the frame to cap memory usage, then convert BGR -> RGB for PIL.
    frame = resize_frame(frame, max_frame_side)
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    frame_pil = Image.fromarray(frame_rgb)

    # Draw the boxes (save even when zero boxes, to help debugging).
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.imshow(frame_pil)
    if len(first_boxes) > 0:
        for box, score in zip(first_boxes, first_scores):
            box_abs = normalize_bbox(box, frame_pil.width, frame_pil.height)
            draw_box_on_image(
                ax, box_abs,
                label=f"SAM3 Score: {score:.2f}",
                color="red", thickness=2
            )
    else:
        ax.text(0.5, 0.5, "未检出目标", ha="center", va="center", fontsize=20, color="red")
    ax.axis("off")

    # Persist the annotated frame.
    frame_save_path = os.path.join(save_dir, f"sam3_frame_{frame_idx:04d}.jpg")
    plt.savefig(frame_save_path, dpi=100, bbox_inches="tight")
    plt.close(fig)
    frame_save_paths.append(frame_save_path)
    print(f"[低显存模式] 帧{frame_idx}处理完成,保存至:{frame_save_path}")

    # 6. Release resources. end_session is intentionally NOT called (this
    # SAM3 version does not support it); dropping the predictor and emptying
    # the CUDA cache serves the same purpose.
    if cap is not None:
        cap.release()
    del video_predictor
    torch.cuda.empty_cache()
    print("[低显存模式] 推理完成,资源已释放(无需end_session)")
    return session_id, first_boxes
|
||||||
|
|
||||||
|
# ==================== 主函数 ====================
|
||||||
|
if __name__ == "__main__":
    # 1. Image inference (low-VRAM: downscale input, free the model afterwards).
    image_path = "/home/quant/data/dev/sam3/assets/images/groceries.jpg"
    image_model = build_sam3_image_model()
    image_processor = Sam3Processor(image_model)

    image = Image.open(image_path).convert("RGB")
    image = image.resize((640, 480), Image.Resampling.LANCZOS)

    inference_state = image_processor.set_image(image)
    image_output = image_processor.set_text_prompt(state=inference_state, prompt="food")

    plot_results(image, inference_state)
    plt.savefig("./sam3_image_food_result.jpg", dpi=100, bbox_inches='tight')
    plt.close()

    # Free the image model before starting video inference.
    del image_model, image_processor
    torch.cuda.empty_cache()
    print("✅ 图像推理完成(低显存模式)")

    # 2. Video inference: no end_session, thresholds tuned for recall.
    video_path = "/home/quant/data/dev/sam3/assets/videos/bedroom.mp4"
    sam3_video_inference_low_memory(
        video_path=video_path,
        text_prompt="",  # empty prompt: detect all objects (recall first)
        max_frames=5,
        max_frame_side=640
    )
|
||||||
43
test1.py
Normal file
43
test1.py
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
import torch
import matplotlib.pyplot as plt  # used to save the result figure to disk

#################################### For Image ####################################
from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results

# Build the model and its processor once.
model = build_sam3_image_model()
processor = Sam3Processor(model)

# Load the input image; convert to RGB so the channel count is always 3.
image = Image.open("/home/quant/data/dev/sam3/assets/images/groceries.jpg").convert("RGB")

# Sanity check: report mode and size.
print(f"图像模式: {image.mode}, 尺寸: {image.size}")

# Feed the image to SAM3.
inference_state = processor.set_image(image)

# Text-prompted inference.
output = processor.set_text_prompt(state=inference_state, prompt="food")

# Collect the results.
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]

# Visualize and persist the annotated image (jpg; png works as well).
plot_results(image, inference_state)
plt.savefig("./sam3_food_detection_result.jpg",  # output path in the current directory
            dpi=150,                             # resolution, optional
            bbox_inches='tight')                 # trim surrounding whitespace
plt.close()  # close the canvas to release memory

# Print a short summary of what was detected.
print(f"检测到的mask数量: {len(masks)}")
print(f"检测到的box数量: {len(boxes)}")
print(f"置信度分数: {scores}")
print("图片已保存到当前目录:./sam3_food_detection_result.jpg")
|
||||||
|
|
||||||
Reference in New Issue
Block a user