This commit is contained in:
2026-02-15 13:22:38 +08:00
parent 8bb00ac928
commit 3b5461371c
6 changed files with 365 additions and 32 deletions

43
test1.py Normal file
View File

@@ -0,0 +1,43 @@
import torch
import matplotlib.pyplot as plt # 新增导入matplotlib用于保存图片
#################################### For Image ####################################
from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results
# Load the model
model = build_sam3_image_model()
processor = Sam3Processor(model)
# Load an image - 保留之前的RGB转换修复
image = Image.open("/home/quant/data/dev/sam3/assets/images/groceries.jpg").convert("RGB")
# 可选:打印图像信息,验证通道数
print(f"图像模式: {image.mode}, 尺寸: {image.size}")
# 处理图像
inference_state = processor.set_image(image)
# 文本提示推理
output = processor.set_text_prompt(state=inference_state, prompt="food")
# 获取推理结果
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
# 可视化并保存图片(核心修改部分)
# 1. 生成可视化结果
plot_results(image, inference_state)
# 2. 保存图片到当前目录格式可选jpg/png这里用jpg示例
plt.savefig("./sam3_food_detection_result.jpg", # 保存路径:当前目录,文件名自定义
dpi=150, # 图片分辨率,可选
bbox_inches='tight') # 去除图片周围空白
# 3. 关闭plt画布避免内存占用
plt.close()
# 可选:打印输出信息
print(f"检测到的mask数量: {len(masks)}")
print(f"检测到的box数量: {len(boxes)}")
print(f"置信度分数: {scores}")
print("图片已保存到当前目录:./sam3_food_detection_result.jpg")