tarot
194
fastAPI_tarot.py
@@ -6,6 +6,8 @@ import numpy as np
|
|||||||
import cv2
|
import cv2
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import dashscope
|
||||||
|
from dashscope import MultiModalConversation
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import matplotlib
|
import matplotlib
|
||||||
@@ -31,6 +33,10 @@ os.makedirs(RESULT_IMAGE_DIR, exist_ok=True)
|
|||||||
VALID_API_KEY = "123quant-speed"
|
VALID_API_KEY = "123quant-speed"
|
||||||
API_KEY_HEADER_NAME = "X-API-Key"
|
API_KEY_HEADER_NAME = "X-API-Key"
|
||||||
|
|
||||||
|
# Dashscope 配置 (Qwen-VL)
|
||||||
|
dashscope.api_key = 'sk-ce2404f55f744a1987d5ece61c6bac58'
|
||||||
|
QWEN_MODEL = 'qwen-vl-max'
|
||||||
|
|
||||||
# 定义 Header 认证
|
# 定义 Header 认证
|
||||||
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
|
api_key_header = APIKeyHeader(name=API_KEY_HEADER_NAME, auto_error=False)
|
||||||
|
|
||||||
@@ -264,6 +270,82 @@ def generate_and_save_result(image: Image.Image, inference_state, output_dir: st
|
|||||||
plt.close()
|
plt.close()
|
||||||
return filename
|
return filename
|
||||||
|
|
||||||
|
def recognize_card_with_qwen(image_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
调用 Qwen-VL 识别塔罗牌
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 确保路径是绝对路径并加上 file:// 前缀
|
||||||
|
abs_path = os.path.abspath(image_path)
|
||||||
|
file_url = f"file://{abs_path}"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"image": file_url},
|
||||||
|
{"text": "这是一张塔罗牌。请识别它的名字(中文),并判断它是正位还是逆位。请以JSON格式返回,包含 'name' 和 'position' 两个字段。例如:{'name': '愚者', 'position': '正位'}。不要包含Markdown代码块标记。"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
content = response.output.choices[0].message.content[0]['text']
|
||||||
|
# 尝试解析简单的 JSON
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
# 清理可能存在的 markdown 标记
|
||||||
|
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||||
|
result = json.loads(clean_content)
|
||||||
|
return result
|
||||||
|
except:
|
||||||
|
return {"raw_response": content}
|
||||||
|
else:
|
||||||
|
return {"error": f"API Error: {response.code} - {response.message}"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"识别失败: {str(e)}"}
|
||||||
|
|
||||||
|
def recognize_spread_with_qwen(image_path: str) -> dict:
|
||||||
|
"""
|
||||||
|
调用 Qwen-VL 识别塔罗牌牌阵
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 确保路径是绝对路径并加上 file:// 前缀
|
||||||
|
abs_path = os.path.abspath(image_path)
|
||||||
|
file_url = f"file://{abs_path}"
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": [
|
||||||
|
{"image": file_url},
|
||||||
|
{"text": "这是一张包含多张塔罗牌的图片。请根据牌的排列方式识别这是什么牌阵(例如:圣三角、凯尔特十字、三张牌等)。如果看不出明显的正规牌阵,请返回“不是正规牌阵”。请以JSON格式返回,包含 'spread_name' 和 'description' 两个字段。例如:{'spread_name': '圣三角', 'description': '常见的时间流占卜法'}。不要包含Markdown代码块标记。"}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = MultiModalConversation.call(model=QWEN_MODEL, messages=messages)
|
||||||
|
|
||||||
|
if response.status_code == 200:
|
||||||
|
content = response.output.choices[0].message.content[0]['text']
|
||||||
|
# 尝试解析简单的 JSON
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
# 清理可能存在的 markdown 标记
|
||||||
|
clean_content = content.replace("```json", "").replace("```", "").strip()
|
||||||
|
result = json.loads(clean_content)
|
||||||
|
return result
|
||||||
|
except:
|
||||||
|
return {"raw_response": content, "spread_name": "Unknown"}
|
||||||
|
else:
|
||||||
|
return {"error": f"API Error: {response.code} - {response.message}"}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {"error": f"牌阵识别失败: {str(e)}"}
|
||||||
|
|
||||||
# ------------------- API 接口 (强制依赖验证) -------------------
|
# ------------------- API 接口 (强制依赖验证) -------------------
|
||||||
@app.post("/segment", dependencies=[Depends(verify_api_key)])
|
@app.post("/segment", dependencies=[Depends(verify_api_key)])
|
||||||
async def segment(
|
async def segment(
|
||||||
@@ -398,6 +480,118 @@ async def segment_tarot(
|
|||||||
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
||||||
})
|
})
|
||||||
|
|
||||||
|
@app.post("/recognize_tarot", dependencies=[Depends(verify_api_key)])
|
||||||
|
async def recognize_tarot(
|
||||||
|
request: Request,
|
||||||
|
file: Optional[UploadFile] = File(None),
|
||||||
|
image_url: Optional[str] = Form(None),
|
||||||
|
expected_count: int = Form(3)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
塔罗牌全流程接口: 分割 + 矫正 + 识别
|
||||||
|
1. 检测是否包含指定数量的塔罗牌 (SAM3)
|
||||||
|
2. 分割并透视矫正
|
||||||
|
3. 调用 Qwen-VL 识别每张牌的名称和正逆位
|
||||||
|
"""
|
||||||
|
if not file and not image_url:
|
||||||
|
raise HTTPException(status_code=400, detail="必须提供 file (图片文件) 或 image_url (图片链接)")
|
||||||
|
|
||||||
|
try:
|
||||||
|
if file:
|
||||||
|
image = Image.open(file.file).convert("RGB")
|
||||||
|
elif image_url:
|
||||||
|
image = load_image_from_url(image_url)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=400, detail=f"图片解析失败: {str(e)}")
|
||||||
|
|
||||||
|
processor = request.app.state.processor
|
||||||
|
|
||||||
|
try:
|
||||||
|
inference_state = processor.set_image(image)
|
||||||
|
# 固定 Prompt 检测塔罗牌
|
||||||
|
output = processor.set_text_prompt(state=inference_state, prompt="tarot card")
|
||||||
|
masks, boxes, scores = output["masks"], output["boxes"], output["scores"]
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"模型推理错误: {str(e)}")
|
||||||
|
|
||||||
|
# 核心逻辑:判断数量
|
||||||
|
detected_count = len(masks)
|
||||||
|
|
||||||
|
# 创建本次请求的独立文件夹
|
||||||
|
request_id = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||||
|
output_dir = os.path.join(RESULT_IMAGE_DIR, request_id)
|
||||||
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 保存整体效果图 (无论是成功还是失败,都先保存一张主图)
|
||||||
|
try:
|
||||||
|
main_filename = generate_and_save_result(image, inference_state, output_dir=output_dir)
|
||||||
|
main_file_path = os.path.join(output_dir, main_filename)
|
||||||
|
main_file_url = str(request.url_for("static", path=f"results/{request_id}/{main_filename}"))
|
||||||
|
except:
|
||||||
|
main_filename = None
|
||||||
|
main_file_path = None
|
||||||
|
main_file_url = None
|
||||||
|
|
||||||
|
# Step 0: 牌阵识别 (在判断数量之前或之后都可以,这里放在前面作为全局判断)
|
||||||
|
spread_info = {"spread_name": "Unknown"}
|
||||||
|
if main_file_path:
|
||||||
|
# 使用带有mask绘制的主图或者原始图?
|
||||||
|
# 使用原始图可能更好,不受mask遮挡干扰,但是main_filename是带mask的。
|
||||||
|
# 我们这里暂时用原始图保存一份临时文件给Qwen看
|
||||||
|
temp_raw_path = os.path.join(output_dir, "raw_for_spread.jpg")
|
||||||
|
image.save(temp_raw_path)
|
||||||
|
spread_info = recognize_spread_with_qwen(temp_raw_path)
|
||||||
|
|
||||||
|
# 如果识别结果明确说是“不是正规牌阵”,是否要继续?
|
||||||
|
# 用户需求:“如果没有正确的牌阵则返回‘不是正规牌阵’”
|
||||||
|
# 我们将其放在返回结果中,由客户端决定是否展示警告
|
||||||
|
|
||||||
|
if detected_count != expected_count:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=400,
|
||||||
|
content={
|
||||||
|
"status": "failed",
|
||||||
|
"message": f"检测到 {detected_count} 个目标,需要严格的 {expected_count} 张塔罗牌。请调整拍摄角度或背景。",
|
||||||
|
"detected_count": detected_count,
|
||||||
|
"spread_info": spread_info,
|
||||||
|
"debug_image_url": str(main_file_url) if main_file_url else None
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# 数量正确,执行抠图 + 矫正
|
||||||
|
try:
|
||||||
|
saved_objects = crop_and_save_objects(image, masks, boxes, output_dir=output_dir)
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"抠图处理错误: {str(e)}")
|
||||||
|
|
||||||
|
# 遍历每张卡片进行识别
|
||||||
|
tarot_cards = []
|
||||||
|
for obj in saved_objects:
|
||||||
|
fname = obj["filename"]
|
||||||
|
file_path = os.path.join(output_dir, fname)
|
||||||
|
|
||||||
|
# 调用 Qwen-VL 识别
|
||||||
|
# 注意:这里会串行调用,速度可能较慢。
|
||||||
|
recognition_res = recognize_card_with_qwen(file_path)
|
||||||
|
|
||||||
|
file_url = str(request.url_for("static", path=f"results/{request_id}/{fname}"))
|
||||||
|
tarot_cards.append({
|
||||||
|
"url": file_url,
|
||||||
|
"is_rotated": obj["is_rotated_by_algorithm"],
|
||||||
|
"orientation_status": "corrected_to_portrait" if obj["is_rotated_by_algorithm"] else "original_portrait",
|
||||||
|
"recognition": recognition_res,
|
||||||
|
"note": obj["note"]
|
||||||
|
})
|
||||||
|
|
||||||
|
return JSONResponse(content={
|
||||||
|
"status": "success",
|
||||||
|
"message": f"成功识别并分割 {expected_count} 张塔罗牌 (含Qwen识别结果)",
|
||||||
|
"spread_info": spread_info,
|
||||||
|
"tarot_cards": tarot_cards,
|
||||||
|
"full_visualization": main_file_url,
|
||||||
|
"scores": scores.tolist() if torch.is_tensor(scores) else scores
|
||||||
|
})
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
import uvicorn
|
import uvicorn
|
||||||
# 注意:如果你的文件名不是 fastAPI_tarot.py,请修改下面第一个参数
|
# 注意:如果你的文件名不是 fastAPI_tarot.py,请修改下面第一个参数
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
uvicorn
|
uvicorn
|
||||||
python-multipart
|
python-multipart
|
||||||
fastapi
|
fastapi
|
||||||
|
dashscope
|
||||||
BIN
static/results/1771145685_9aee4513/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 197 KiB |
|
After Width: | Height: | Size: 142 KiB |
|
After Width: | Height: | Size: 107 KiB |
|
After Width: | Height: | Size: 92 KiB |
|
After Width: | Height: | Size: 122 KiB |
|
After Width: | Height: | Size: 102 KiB |
|
After Width: | Height: | Size: 116 KiB |
BIN
static/results/1771145893_ff902547/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 300 KiB |
|
After Width: | Height: | Size: 92 KiB |
|
After Width: | Height: | Size: 293 KiB |
|
After Width: | Height: | Size: 272 KiB |
BIN
static/results/1771146096_519396b9/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 205 KiB |
|
After Width: | Height: | Size: 146 KiB |
|
After Width: | Height: | Size: 83 KiB |
|
After Width: | Height: | Size: 52 KiB |
|
After Width: | Height: | Size: 64 KiB |
|
After Width: | Height: | Size: 82 KiB |
|
After Width: | Height: | Size: 74 KiB |
|
After Width: | Height: | Size: 54 KiB |
|
After Width: | Height: | Size: 73 KiB |
|
After Width: | Height: | Size: 51 KiB |
BIN
static/results/1771146680_f72a7212/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 205 KiB |
|
After Width: | Height: | Size: 146 KiB |
|
After Width: | Height: | Size: 54 KiB |
|
After Width: | Height: | Size: 74 KiB |
|
After Width: | Height: | Size: 82 KiB |
|
After Width: | Height: | Size: 83 KiB |
|
After Width: | Height: | Size: 73 KiB |
|
After Width: | Height: | Size: 52 KiB |
|
After Width: | Height: | Size: 64 KiB |
|
After Width: | Height: | Size: 51 KiB |
BIN
static/results/1771147171_7043dd6e/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 212 KiB |
|
After Width: | Height: | Size: 84 KiB |
BIN
static/results/1771147215_299c8af0/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 212 KiB |
|
After Width: | Height: | Size: 84 KiB |
BIN
static/results/1771147221_98dd0e08/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 212 KiB |
|
After Width: | Height: | Size: 84 KiB |
BIN
static/results/1771147306_7b73fc1a/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 232 KiB |
|
After Width: | Height: | Size: 84 KiB |
|
After Width: | Height: | Size: 342 KiB |
|
After Width: | Height: | Size: 266 KiB |
|
After Width: | Height: | Size: 343 KiB |
BIN
static/results/1771147521_e160c2b3/raw_for_spread.jpg
Normal file
|
After Width: | Height: | Size: 232 KiB |
|
After Width: | Height: | Size: 84 KiB |
|
After Width: | Height: | Size: 343 KiB |
|
After Width: | Height: | Size: 342 KiB |
|
After Width: | Height: | Size: 266 KiB |