Files
sam3_local/examples/saco_veval_eval_example.ipynb
facebook-github-bot a13e358df4 Initial commit
fbshipit-source-id: da6be2f26e3a1202f4bffde8cb980e2dcb851294
2025-11-18 23:07:54 -08:00

138 lines
4.3 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"id": "0e0d2e74",
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"import os\n",
"\n",
"from sam3.eval.saco_veval_eval import VEvalEvaluator"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b31ab5d3",
"metadata": {},
"outputs": [],
"source": [
"DATASETS_TO_EVAL = [\n",
" \"saco_veval_sav_test\",\n",
" \"saco_veval_yt1b_test\",\n",
" \"saco_veval_smartglasses_test\",\n",
"]\n",
"# Update to the directories where the GT annotation and PRED files exist\n",
"GT_DIR = None # PUT YOUR ANNOTATION PATH HERE\n",
"PRED_DIR = None # PUT YOUR PREDICTION PATH HERE"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3a602fef",
"metadata": {},
"outputs": [],
"source": [
"all_eval_res = {}\n",
"for dataset_name in DATASETS_TO_EVAL:\n",
" gt_annot_file = os.path.join(GT_DIR, dataset_name + \".json\")\n",
" pred_file = os.path.join(PRED_DIR, dataset_name + \"_preds.json\")\n",
" eval_res_file = os.path.join(PRED_DIR, dataset_name + \"_eval_res.json\")\n",
"\n",
" if os.path.exists(eval_res_file):\n",
" with open(eval_res_file, \"r\") as f:\n",
" eval_res = json.load(f)\n",
" else:\n",
" # Alternatively, we can run the evaluator offline first\n",
" # by leveraging sam3/eval/saco_veval_eval.py\n",
" print(f\"=== Running evaluation for Pred {pred_file} vs GT {gt_annot_file} ===\")\n",
" veval_evaluator = VEvalEvaluator(\n",
" gt_annot_file=gt_annot_file, eval_res_file=eval_res_file\n",
" )\n",
" eval_res = veval_evaluator.run_eval(pred_file=pred_file)\n",
" print(f\"=== Results saved to {eval_res_file} ===\")\n",
"\n",
" all_eval_res[dataset_name] = eval_res"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6dbec47",
"metadata": {},
"outputs": [],
"source": [
"REPORT_METRICS = {\n",
" \"video_mask_demo_cgf1_micro_50_95\": \"cgf1\",\n",
" \"video_mask_all_phrase_HOTA\": \"pHOTA\",\n",
"}"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cc28d29f",
"metadata": {},
"outputs": [],
"source": [
"res_to_print = []\n",
"for dataset_name in DATASETS_TO_EVAL:\n",
" eval_res = all_eval_res[dataset_name]\n",
" row = [dataset_name]\n",
" for metric_k, metric_v in REPORT_METRICS.items():\n",
" row.append(eval_res[\"dataset_results\"][metric_k])\n",
" res_to_print.append(row)\n",
"\n",
"# Print dataset header (each dataset spans 2 metrics: 13 + 3 + 13 = 29 chars)\n",
"print(\"| \" + \" | \".join(f\"{ds:^29}\" for ds in DATASETS_TO_EVAL) + \" |\")\n",
"\n",
"# Print metric header\n",
"metrics = list(REPORT_METRICS.values())\n",
"print(\"| \" + \" | \".join(f\"{m:^13}\" for _ in DATASETS_TO_EVAL for m in metrics) + \" |\")\n",
"\n",
"# Print eval results\n",
"values = []\n",
"for row in res_to_print:\n",
" values.extend([f\"{v * 100:^13.1f}\" for v in row[1:]])\n",
"print(\"| \" + \" | \".join(values) + \" |\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9976908b",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"fileHeader": "",
"fileUid": "bdaa3851-85de-435f-9582-efb46951a1d0",
"isAdHoc": false,
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
}
},
"nbformat": 4,
"nbformat_minor": 2
}