Files
sam3_local/test1.py
2026-02-15 13:22:38 +08:00

44 lines
1.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import torch
import matplotlib.pyplot as plt # 新增导入matplotlib用于保存图片
#################################### For Image ####################################
from PIL import Image
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import draw_box_on_image, normalize_bbox, plot_results
# Load the model
model = build_sam3_image_model()
processor = Sam3Processor(model)
# Load an image - 保留之前的RGB转换修复
image = Image.open("/home/quant/data/dev/sam3/assets/images/groceries.jpg").convert("RGB")
# 可选:打印图像信息,验证通道数
print(f"图像模式: {image.mode}, 尺寸: {image.size}")
# 处理图像
inference_state = processor.set_image(image)
# 文本提示推理
output = processor.set_text_prompt(state=inference_state, prompt="food")
# 获取推理结果
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
# 可视化并保存图片(核心修改部分)
# 1. 生成可视化结果
plot_results(image, inference_state)
# 2. 保存图片到当前目录格式可选jpg/png这里用jpg示例
plt.savefig("./sam3_food_detection_result.jpg", # 保存路径:当前目录,文件名自定义
dpi=150, # 图片分辨率,可选
bbox_inches='tight') # 去除图片周围空白
# 3. 关闭plt画布避免内存占用
plt.close()
# 可选:打印输出信息
print(f"检测到的mask数量: {len(masks)}")
print(f"检测到的box数量: {len(boxes)}")
print(f"置信度分数: {scores}")
print("图片已保存到当前目录:./sam3_food_detection_result.jpg")