diff --git a/fastAPI_main.py b/fastAPI_main.py index 88e469c..55d5d2e 100644 --- a/fastAPI_main.py +++ b/fastAPI_main.py @@ -25,48 +25,60 @@ STATIC_DIR = "static" RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results") os.makedirs(RESULT_IMAGE_DIR, exist_ok=True) +# ------------------- 核心修改:图片压缩函数 ------------------- +def compress_image(image: Image.Image, max_size: int = 1920, quality: int = 85) -> Image.Image: + """ + 如果图片边长超过 max_size,则按比例压缩。 + :param image: PIL Image 对象 + :param max_size: 图片最大边长 (宽或高) + :param quality: 仅用于保存时的参考,这里主要做尺寸压缩 + :return: 压缩后的 PIL Image 对象 + """ + width, height = image.size + + # 如果图片本身就很小,直接返回 + if width <= max_size and height <= max_size: + return image + + # 计算缩放比例 + if width > height: + new_width = max_size + new_height = int(height * (max_size / width)) + else: + new_height = max_size + new_width = int(width * (max_size / height)) + + # 使用 LANCZOS 滤镜进行高质量下采样 + print(f"压缩图片: {width}x{height} -> {new_width}x{new_height}") + return image.resize((new_width, new_height), Image.Resampling.LANCZOS) + # ------------------- 生命周期管理 ------------------- @asynccontextmanager async def lifespan(app: FastAPI): - """ - FastAPI 生命周期管理器:在服务启动时加载模型,关闭时清理资源 - """ print("正在加载 SAM3 模型到 GPU...") - - # 1. 检测设备 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not torch.cuda.is_available(): print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。") - # 2. 加载模型 (全局单例) model = build_sam3_image_model() model = model.to(device) - model.eval() # 切换到评估模式 - - # 3. 初始化 Processor + model.eval() processor = Sam3Processor(model) - # 4. 存入 app.state 供全局访问 app.state.model = model app.state.processor = processor app.state.device = device print(f"模型加载完成,设备: {device}") - - yield # 服务运行中... 
- - # 清理资源 (如果需要) + yield print("正在清理资源...") # ------------------- FastAPI 初始化 ------------------- app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API") - -# 挂载静态文件目录,用于通过 URL 访问生成的图片 app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") # ------------------- 辅助函数 ------------------- def load_image_from_url(url: str) -> Image.Image: - """从网络 URL 下载图片""" try: headers = {'User-Agent': 'Mozilla/5.0'} response = requests.get(url, headers=headers, stream=True, timeout=10) @@ -77,18 +89,11 @@ def load_image_from_url(url: str) -> Image.Image: raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") def generate_and_save_result(image: Image.Image, inference_state) -> str: - """生成可视化结果图并保存,返回文件名""" - # 生成唯一文件名防止冲突 filename = f"seg_{uuid.uuid4().hex}.jpg" save_path = os.path.join(RESULT_IMAGE_DIR, filename) - - # 绘图 (复用你提供的逻辑) plot_results(image, inference_state) - - # 保存 plt.savefig(save_path, dpi=150, bbox_inches='tight') - plt.close() # 务必关闭,防止内存泄漏 - + plt.close() return filename # ------------------- API 接口 ------------------- @@ -99,44 +104,41 @@ async def segment( file: Optional[UploadFile] = File(None), image_url: Optional[str] = Form(None) ): - """ - 接收图片 (文件上传 或 URL) 和 文本提示词,返回分割后的图片 URL。 - """ - - # 1. 校验输入 if not file and not image_url: - raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") + raise HTTPException(status_code=400, detail="必须提供 file 或 image_url") - # 2. 获取图片对象 + # 1. 获取图片对象 try: if file: image = Image.open(file.file).convert("RGB") elif image_url: image = load_image_from_url(image_url) + + # ========== 关键修改位置 ========== + # 在送入模型前,强制压缩图片 + image = compress_image(image, max_size=1920) + # =================================== + except Exception as e: raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") - # 3. 获取模型 + # 2. 获取模型 processor = request.app.state.processor - # 4. 执行推理 + # 3. 
执行推理 try: - # 这一步内部应该已经由 Sam3Processor 处理了 GPU 张量转移 inference_state = processor.set_image(image) output = processor.set_text_prompt(state=inference_state, prompt=prompt) - masks, boxes, scores = output["masks"], output["boxes"], output["scores"] except Exception as e: raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") - # 5. 生成可视化并保存 + # 4. 生成可视化并保存 try: filename = generate_and_save_result(image, inference_state) except Exception as e: raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") - # 6. 构建返回 URL - # request.url_for 会自动根据当前域名生成正确的访问链接 file_url = request.url_for("static", path=f"results/{filename}") return JSONResponse(content={ @@ -148,12 +150,10 @@ async def segment( if __name__ == "__main__": import uvicorn - - # 使用 Python 函数参数的方式传递配置 uvicorn.run( - "fastAPI_main:app", # 注意:这里要改成你的文件名:app对象名 + "fastAPI_main:app", host="127.0.0.1", port=55600, - proxy_headers=True, # 对应 --proxy-headers - forwarded_allow_ips="*" # 对应 --forwarded-allow-ips="*" + proxy_headers=True, + forwarded_allow_ips="*" ) \ No newline at end of file diff --git a/fastAPI_nocom.py b/fastAPI_nocom.py new file mode 100644 index 0000000..7d7d3ac --- /dev/null +++ b/fastAPI_nocom.py @@ -0,0 +1,159 @@ +import os +import uuid +import requests +from typing import Optional +from contextlib import asynccontextmanager + +import torch +import matplotlib +# 关键:设置非交互式后端,避免服务器环境下报错 +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Request +from fastapi.staticfiles import StaticFiles +from fastapi.responses import JSONResponse +from PIL import Image + +# SAM3 相关导入 (请确保你的环境中已正确安装 sam3) +from sam3.model_builder import build_sam3_image_model +from sam3.model.sam3_image_processor import Sam3Processor +from sam3.visualization_utils import plot_results + +# ------------------- 配置与路径 ------------------- +STATIC_DIR = "static" +RESULT_IMAGE_DIR = os.path.join(STATIC_DIR, "results") +os.makedirs(RESULT_IMAGE_DIR, 
exist_ok=True) + +# ------------------- 生命周期管理 ------------------- +@asynccontextmanager +async def lifespan(app: FastAPI): + """ + FastAPI 生命周期管理器:在服务启动时加载模型,关闭时清理资源 + """ + print("正在加载 SAM3 模型到 GPU...") + + # 1. 检测设备 + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if not torch.cuda.is_available(): + print("警告: 未检测到 GPU,将使用 CPU,速度会较慢。") + + # 2. 加载模型 (全局单例) + model = build_sam3_image_model() + model = model.to(device) + model.eval() # 切换到评估模式 + + # 3. 初始化 Processor + processor = Sam3Processor(model) + + # 4. 存入 app.state 供全局访问 + app.state.model = model + app.state.processor = processor + app.state.device = device + + print(f"模型加载完成,设备: {device}") + + yield # 服务运行中... + + # 清理资源 (如果需要) + print("正在清理资源...") + +# ------------------- FastAPI 初始化 ------------------- +app = FastAPI(lifespan=lifespan, title="SAM3 Segmentation API") + +# 挂载静态文件目录,用于通过 URL 访问生成的图片 +app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static") + +# ------------------- 辅助函数 ------------------- +def load_image_from_url(url: str) -> Image.Image: + """从网络 URL 下载图片""" + try: + headers = {'User-Agent': 'Mozilla/5.0'} + response = requests.get(url, headers=headers, stream=True, timeout=10) + response.raise_for_status() + image = Image.open(response.raw).convert("RGB") + return image + except Exception as e: + raise HTTPException(status_code=400, detail=f"无法下载图片: {str(e)}") + +def generate_and_save_result(image: Image.Image, inference_state) -> str: + """生成可视化结果图并保存,返回文件名""" + # 生成唯一文件名防止冲突 + filename = f"seg_{uuid.uuid4().hex}.jpg" + save_path = os.path.join(RESULT_IMAGE_DIR, filename) + + # 绘图 (复用你提供的逻辑) + plot_results(image, inference_state) + + # 保存 + plt.savefig(save_path, dpi=150, bbox_inches='tight') + plt.close() # 务必关闭,防止内存泄漏 + + return filename + +# ------------------- API 接口 ------------------- +@app.post("/segment") +async def segment( + request: Request, + prompt: str = Form(...), + file: Optional[UploadFile] = File(None), + image_url: Optional[str] = 
Form(None) +): + """ + 接收图片 (文件上传 或 URL) 和 文本提示词,返回分割后的图片 URL。 + """ + + # 1. 校验输入 + if not file and not image_url: + raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)") + + # 2. 获取图片对象 + try: + if file: + image = Image.open(file.file).convert("RGB") + elif image_url: + image = load_image_from_url(image_url) + except Exception as e: + raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}") + + # 3. 获取模型 + processor = request.app.state.processor + + # 4. 执行推理 + try: + # 这一步内部应该已经由 Sam3Processor 处理了 GPU 张量转移 + inference_state = processor.set_image(image) + output = processor.set_text_prompt(state=inference_state, prompt=prompt) + + masks, boxes, scores = output["masks"], output["boxes"], output["scores"] + except Exception as e: + raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}") + + # 5. 生成可视化并保存 + try: + filename = generate_and_save_result(image, inference_state) + except Exception as e: + raise HTTPException(status_code=500, detail=f"绘图保存错误: {str(e)}") + + # 6. 
构建返回 URL + # request.url_for 会自动根据当前域名生成正确的访问链接 + file_url = request.url_for("static", path=f"results/{filename}") + + return JSONResponse(content={ + "status": "success", + "result_image_url": str(file_url), + "detected_count": len(masks), + "scores": scores.tolist() if torch.is_tensor(scores) else scores + }) + +if __name__ == "__main__": + import uvicorn + + # 使用 Python 函数参数的方式传递配置 + uvicorn.run( + "fastAPI_main:app", # 注意:这里要改成你的文件名:app对象名 + host="127.0.0.1", + port=55600, + proxy_headers=True, # 对应 --proxy-headers + forwarded_allow_ips="*" # 对应 --forwarded-allow-ips="*" + ) diff --git a/sam3_image_food_result.jpg b/sam3_image_food_result.jpg deleted file mode 100644 index 5dda773..0000000 Binary files a/sam3_image_food_result.jpg and /dev/null differ diff --git a/static/results/seg_4da8384c943c49099c0cc06f91b5f5e0.jpg b/static/results/seg_4da8384c943c49099c0cc06f91b5f5e0.jpg new file mode 100644 index 0000000..ac64486 Binary files /dev/null and b/static/results/seg_4da8384c943c49099c0cc06f91b5f5e0.jpg differ diff --git a/static/results/seg_b46b74373a1642fd8c0e9c009632774c.jpg b/static/results/seg_b46b74373a1642fd8c0e9c009632774c.jpg new file mode 100644 index 0000000..29d812e Binary files /dev/null and b/static/results/seg_b46b74373a1642fd8c0e9c009632774c.jpg differ