fork download
  1. #!/usr/bin/env python3
  2. """
  3. compare_rfdetr_forklift.py — Per-job pre-label accuracy report for RF-DETR forklift detection.
  4.  
  5. Compares:
  6. PREDICT : forklift-bbox_datumaro_predict_part3.json (model output — bbox only)
  7. GT : datumaro_GT_part3.json (QA-finalized ground truth — bbox + keypoints)
  8.  
  9. 1000 frames split into 10 jobs of 100 frames each (Job 1 = frames 0-99, ...).
  10.  
  11. ────────────────────────────────────────────────────────────────────────────────
  12. Annotation type flags
  13. ────────────────────────────────────────────────────────────────────────────────
  14. --forklift-bbox forklift-with-roll + forklift-no-roll (bbox, IoU >= 0.85)
  15. --pr-keypoints roll-keypoints (skeleton; per-keypoint error threshold, 5 px by default)
  16. --clamp-keypoints clamp-2-arm + clamp-3-arm (skeleton; per-keypoint error threshold, 5 px by default)
  17.  
  18. If no flag is given, all three types are shown.
  19.  
  20. ────────────────────────────────────────────────────────────────────────────────
  21. Output columns (per job row)
  22. ────────────────────────────────────────────────────────────────────────────────
  23. Job | No. Model | Error BBox/Keypoints | Missed* | Over-det** | Accuracy (%)
  24.  
  25. No. Model — total boxes/keypoints predicted by the model
  26. Error — bbox: predictions with no GT match at IoU >= threshold; keypoints: points of matched skeletons farther than --threshold px from GT
  27. Missed* — GT annotations the model did not predict (annotator had to add manually)
  28. Over-det** — model predicted more than GT (annotator had to delete the extra)
  29. Accuracy — Precision: (No. Model - Error) / No. Model × 100
  30. Measures what fraction of model output was correct.
  31. Does NOT penalise for missed GT annotations.
  32.  
  33. ────────────────────────────────────────────────────────────────────────────────
  34. Usage
  35. ────────────────────────────────────────────────────────────────────────────────
  36. python scripts/compare_rfdetr_forklift.py
  37. python scripts/compare_rfdetr_forklift.py --forklift-bbox
  38. python scripts/compare_rfdetr_forklift.py --pr-keypoints
  39. python scripts/compare_rfdetr_forklift.py --clamp-keypoints
  40. python scripts/compare_rfdetr_forklift.py --forklift-bbox --job 1
  41. python scripts/compare_rfdetr_forklift.py --forklift-bbox --threshold 10
  42. python scripts/compare_rfdetr_forklift.py --forklift-bbox --csv out.csv
  43. python scripts/compare_rfdetr_forklift.py --predict other_pred.json --gt other_gt.json
  44. """
  45.  
  46. import argparse
  47. import csv
  48. import json
  49. import math
  50. from collections import defaultdict
  51. from pathlib import Path
  52.  
  53. import numpy as np
  54.  
  55. # ── Default file paths ────────────────────────────────────────────────────────
  56. _BASE = Path(__file__).parent.parent / "data" / "rfdter" / "object-detection" / "forklift" / "run01"
  57. DEFAULT_PRED = str(_BASE / "forklift-bbox_datumaro_predict_part3.json")
  58. DEFAULT_GT = str(_BASE / "datumaro_GT_part3.json")
  59.  
  60. # ── Keypoint visibility encoding ─────────────────────────────────────────────
  61. # Datumaro skeleton format: points = [x0, y0, vis0, x1, y1, vis1, ...]
  62. # vis == 0 means the keypoint is absent / not annotated
  63. VIS_ABSENT = 0
  64.  
  65. # ── Annotation type groups ────────────────────────────────────────────────────
  66. # Maps each report section to the GT label names it covers
  67. FORKLIFT_BBOX_LABELS = {"forklift-with-roll", "forklift-no-roll"}
  68. PR_KP_LABELS = {"roll-keypoints"}
  69. CLAMP_KP_LABELS = {"clamp-2-arm", "clamp-3-arm"}
  70.  
  71. # Datumaro annotation type strings used for filtering
  72. FORKLIFT_BBOX_TYPES = {"bbox"}
  73. KP_TYPES = {"skeleton"}
  74.  
  75. FRAMES_PER_JOB = 100
  76.  
  77.  
  78. # ── I/O ───────────────────────────────────────────────────────────────────────
  79.  
  80. def load_json(path: str) -> dict:
  81. """Read a Datumaro-format JSON export file."""
  82. with open(path, "r", encoding="utf-8") as f:
  83. return json.load(f)
  84.  
  85.  
  86. def build_label_map(data: dict) -> dict[int, str]:
  87. """Return {label_id: label_name} from a Datumaro dataset's categories block."""
  88. labels = data.get("categories", {}).get("label", {}).get("labels", [])
  89. return {i: lbl["name"] for i, lbl in enumerate(labels)}
  90.  
  91.  
  92. # ── BBox IoU ──────────────────────────────────────────────────────────────────
  93.  
  94. def bbox_iou(b1: list, b2: list) -> float:
  95. x1, y1, w1, h1 = b1
  96. x2, y2, w2, h2 = b2
  97. xi1, yi1 = max(x1, x2), max(y1, y2)
  98. xi2, yi2 = min(x1 + w1, x2 + w2), min(y1 + h1, y2 + h2)
  99. inter = max(0.0, xi2 - xi1) * max(0.0, yi2 - yi1)
  100. union = w1 * h1 + w2 * h2 - inter
  101. return inter / union if union > 0 else 0.0
  102.  
  103.  
  104. def match_boxes_by_iou(
  105. pred: list[dict],
  106. gt: list[dict],
  107. iou_thresh: float = 0.85,
  108. pred_labels: list[str] | None = None,
  109. gt_labels: list[str] | None = None,
  110. ) -> tuple[list[tuple], list[dict], list[dict]]:
  111. """
  112. Greedy highest-IoU-first matching between predicted and GT bounding boxes.
  113.  
  114. Returns (matched_pairs, unmatched_pred, unmatched_gt).
  115. A pair is formed only when IoU >= iou_thresh.
  116. If pred_labels and gt_labels are provided, a pair is only eligible when
  117. pred_labels[i] == gt_labels[j] (class-label check enabled).
  118. """
  119. if not pred or not gt:
  120. return [], pred[:], gt[:]
  121.  
  122. n_p, n_g = len(pred), len(gt)
  123. mat = np.zeros((n_p, n_g))
  124. for i, pa in enumerate(pred):
  125. for j, ga in enumerate(gt):
  126. # Zero out IoU for label mismatches when class check is enabled
  127. if pred_labels is not None and gt_labels is not None:
  128. if pred_labels[i] != gt_labels[j]:
  129. continue
  130. mat[i, j] = bbox_iou(pa["bbox"], ga["bbox"])
  131.  
  132. matched, used_p, used_g = [], set(), set()
  133. while True:
  134. idx = np.unravel_index(np.argmax(mat), mat.shape)
  135. if mat[idx] < iou_thresh:
  136. break
  137. i, j = int(idx[0]), int(idx[1])
  138. matched.append((pred[i], gt[j]))
  139. used_p.add(i)
  140. used_g.add(j)
  141. mat[i, :] = -1.0
  142. mat[:, j] = -1.0
  143.  
  144. return (
  145. matched,
  146. [pred[i] for i in range(n_p) if i not in used_p],
  147. [gt[j] for j in range(n_g) if j not in used_g],
  148. )
  149.  
  150.  
  151. # ── Skeleton matching & keypoint counting ────────────────────────────────────
  152.  
  153. def _centroid(ann: dict) -> tuple[float, float]:
  154. """Return the mean (x, y) of all visible keypoints in a skeleton annotation."""
  155. pts = ann.get("points", [])
  156. n = len(pts) // 3
  157. xs = [pts[i * 3] for i in range(n) if pts[i * 3 + 2] > VIS_ABSENT]
  158. ys = [pts[i * 3 + 1] for i in range(n) if pts[i * 3 + 2] > VIS_ABSENT]
  159. if not xs:
  160. return 0.0, 0.0
  161. return float(sum(xs) / len(xs)), float(sum(ys) / len(ys))
  162.  
  163.  
  164. def match_skeletons(
  165. pred: list[dict],
  166. gt: list[dict],
  167. dist_thresh: float = 300.0,
  168. ) -> tuple[list[tuple], list[dict], list[dict]]:
  169. """
  170. Greedy nearest-centroid matching between predicted and GT skeleton annotations.
  171.  
  172. Returns (matched_pairs, unmatched_pred, unmatched_gt).
  173. A pair is formed only when centroid distance <= dist_thresh pixels.
  174. Matching is done per label class (called separately for each label).
  175. """
  176. if not pred or not gt:
  177. return [], pred[:], gt[:]
  178.  
  179. n_p, n_g = len(pred), len(gt)
  180. mat = np.full((n_p, n_g), np.inf)
  181. pc = [_centroid(a) for a in pred]
  182. gc = [_centroid(a) for a in gt]
  183. for i in range(n_p):
  184. for j in range(n_g):
  185. mat[i, j] = math.hypot(pc[i][0] - gc[j][0], pc[i][1] - gc[j][1])
  186.  
  187. matched, used_p, used_g = [], set(), set()
  188. while True:
  189. idx = np.unravel_index(np.argmin(mat), mat.shape)
  190. if mat[idx] > dist_thresh:
  191. break
  192. i, j = int(idx[0]), int(idx[1])
  193. matched.append((pred[i], gt[j]))
  194. used_p.add(i)
  195. used_g.add(j)
  196. mat[i, :] = np.inf
  197. mat[:, j] = np.inf
  198.  
  199. return (
  200. matched,
  201. [pred[i] for i in range(n_p) if i not in used_p],
  202. [gt[j] for j in range(n_g) if j not in used_g],
  203. )
  204.  
  205.  
  206. def kp_present(ann: dict) -> int:
  207. """Count all keypoints in a skeleton annotation regardless of visibility flag."""
  208. pts = ann.get("points", [])
  209. return len(pts) // 3
  210.  
  211.  
  212. def kp_errors(pred_ann: dict, gt_ann: dict, threshold_px: float) -> int:
  213. """
  214. Count keypoints where the predicted position is too far from the GT position.
  215.  
  216. Compares ALL N corresponding keypoints regardless of visibility flag.
  217. Each point where distance > threshold_px counts as one error.
  218. """
  219. pred_pts = pred_ann.get("points", [])
  220. gt_pts = gt_ann.get("points", [])
  221. n = min(len(pred_pts), len(gt_pts)) // 3
  222. return sum(
  223. 1 for i in range(n)
  224. if math.hypot(pred_pts[i*3] - gt_pts[i*3],
  225. pred_pts[i*3+1] - gt_pts[i*3+1]) > threshold_px
  226. )
  227.  
  228.  
  229. def kp_manual(pred_ann: dict, gt_ann: dict) -> int:
  230. """
  231. Count keypoints that are visible in GT but absent in the prediction.
  232.  
  233. These are keypoints the annotator had to add manually during QA review.
  234. """
  235. pred_pts = pred_ann.get("points", [])
  236. gt_pts = gt_ann.get("points", [])
  237. n = min(len(pred_pts), len(gt_pts)) // 3
  238. return sum(
  239. 1 for i in range(n)
  240. if gt_pts[i*3+2] != VIS_ABSENT and pred_pts[i*3+2] == VIS_ABSENT
  241. )
  242.  
  243.  
  244. # ── Per-job stats accumulator ────────────────────────────────────────────────
  245.  
  246. def empty_stats() -> dict:
  247. """
  248. Initialise a zero-filled stats dict for one (job, annotation-type) pair.
  249.  
  250. Fields
  251. ------
  252. no_model : total boxes / keypoints predicted by the model
  253. error_model : model predictions with no matching GT (wrong position or unmatched)
  254. no_missed : GT annotations the model did not predict (annotator added manually)
  255. no_over : model predicted more than GT (annotator had to delete the extras)
  256. """
  257. return {
  258. "no_model": 0,
  259. "error_model": 0,
  260. "no_missed": 0, # GT count > model count — annotator had to add
  261. "no_over": 0, # model count > GT count — model predicted too many
  262. }
  263.  
  264.  
  265. def accumulate_bbox(
  266. pred_anns: list[dict],
  267. gt_anns: list[dict],
  268. target_labels: set[str],
  269. pred_lmap: dict,
  270. gt_lmap: dict,
  271. stats: dict,
  272. label_match: bool = False,
  273. ):
  274. p = [a for a in pred_anns if pred_lmap.get(a["label_id"]) in target_labels and a["type"] == "bbox"]
  275. g = [a for a in gt_anns if gt_lmap.get(a["label_id"]) in target_labels and a["type"] == "bbox"]
  276.  
  277. # no_model = total predicted boxes in this frame
  278. # error_model = predicted boxes with no IoU match in GT (wrong/spurious position)
  279. # no_missed = GT has more boxes → delta positive (annotator had to draw them)
  280. # no_over = model has more boxes → delta negative (annotator had to delete them)
  281. pred_labels = [pred_lmap.get(a["label_id"]) for a in p] if label_match else None
  282. gt_labels = [gt_lmap.get(a["label_id"]) for a in g] if label_match else None
  283. _, unmatched_p, _ = match_boxes_by_iou(p, g, pred_labels=pred_labels, gt_labels=gt_labels)
  284. delta = len(g) - len(p) # positive = missed, negative = over-detected
  285. stats["no_model"] += len(p)
  286. stats["no_missed"] += max(0, delta)
  287. stats["no_over"] += max(0, -delta)
  288. stats["error_model"] += len(unmatched_p)
  289.  
  290.  
  291. def accumulate_kp(
  292. pred_anns: list[dict],
  293. gt_anns: list[dict],
  294. target_labels: set[str],
  295. pred_lmap: dict,
  296. gt_lmap: dict,
  297. stats: dict,
  298. threshold_px: float,
  299. ):
  300. # Each skeleton label is matched independently.
  301. # A skeleton paired to the wrong class counts as an error.
  302. for label in target_labels:
  303. p = [a for a in pred_anns if pred_lmap.get(a["label_id"]) == label and a["type"] == "skeleton"]
  304. g = [a for a in gt_anns if gt_lmap.get(a["label_id"]) == label and a["type"] == "skeleton"]
  305.  
  306. matched, unmatched_p, unmatched_g = match_skeletons(p, g)
  307.  
  308. # No. Model = all KP slots in every pred skeleton (vis ignored)
  309. for pa in p:
  310. stats["no_model"] += kp_present(pa)
  311.  
  312. # Error = per-point position error in matched skeleton pairs (vis ignored)
  313. for pa, ga in matched:
  314. stats["error_model"] += kp_errors(pa, ga, threshold_px)
  315.  
  316. # Over-det = all points of pred skeletons with no matching GT skeleton
  317. for pa in unmatched_p:
  318. stats["no_over"] += kp_present(pa)
  319.  
  320. # Missed = all points of GT skeletons the model did not predict
  321. for ga in unmatched_g:
  322. stats["no_missed"] += kp_present(ga)
  323.  
  324.  
  325. # ── Main compare ──────────────────────────────────────────────────────────────
  326.  
  327. def run_compare(
  328. pred_path: str,
  329. gt_path: str,
  330. show_bbox: bool,
  331. show_pr_kp: bool,
  332. show_clamp_kp: bool,
  333. threshold_px: float,
  334. job_filter: int | None,
  335. label_match: bool = False,
  336. ) -> dict:
  337. """
  338. Load predict and GT files, iterate over all frames, and accumulate per-job stats.
  339.  
  340. Returns
  341. -------
  342. {
  343. job_id (1-10): {
  344. "forklift_bbox": {"no_model": int, "error_model": int,
  345. "no_missed": int, "no_over": int},
  346. "pr_keypoints": {...},
  347. "clamp_keypoints": {...},
  348. },
  349. ...
  350. }
  351. """
  352. pred_data = load_json(pred_path)
  353. gt_data = load_json(gt_path)
  354.  
  355. pred_lmap = build_label_map(pred_data)
  356. gt_lmap = build_label_map(gt_data)
  357.  
  358. pred_by_id = {it["id"]: it for it in pred_data["items"]}
  359.  
  360. # Sort GT items by sequential frame index so job assignment is deterministic
  361. gt_sorted = sorted(gt_data["items"], key=lambda x: x["attr"].get("frame", 0))
  362.  
  363. jobs: dict[int, dict] = {}
  364.  
  365. for seq_idx, gt_item in enumerate(gt_sorted):
  366. job_id = seq_idx // FRAMES_PER_JOB + 1
  367. if job_filter is not None and job_id != job_filter:
  368. continue
  369.  
  370. if job_id not in jobs:
  371. jobs[job_id] = {
  372. "forklift_bbox": empty_stats(),
  373. "pr_keypoints": empty_stats(),
  374. "clamp_keypoints": empty_stats(),
  375. }
  376.  
  377. pred_item = pred_by_id.get(gt_item["id"], {"annotations": []})
  378. pred_anns = pred_item.get("annotations", [])
  379. gt_anns = gt_item.get("annotations", [])
  380.  
  381. if show_bbox:
  382. accumulate_bbox(pred_anns, gt_anns, FORKLIFT_BBOX_LABELS, pred_lmap, gt_lmap,
  383. jobs[job_id]["forklift_bbox"], label_match=label_match)
  384.  
  385. if show_pr_kp:
  386. accumulate_kp(pred_anns, gt_anns, PR_KP_LABELS, pred_lmap, gt_lmap,
  387. jobs[job_id]["pr_keypoints"], threshold_px)
  388.  
  389. if show_clamp_kp:
  390. accumulate_kp(pred_anns, gt_anns, CLAMP_KP_LABELS, pred_lmap, gt_lmap,
  391. jobs[job_id]["clamp_keypoints"], threshold_px)
  392.  
  393. return jobs
  394.  
  395.  
  396. # ── Display helpers ──────────────────────────────────────────────────────────
  397.  
  398. def _acc(no_model: int, error: int) -> str:
  399. """
  400. Compute precision-based pre-label accuracy.
  401.  
  402. Formula: (no_model - error) / no_model * 100
  403. Measures what fraction of model predictions were correct.
  404. Returns 'N/A' when the model produced no predictions.
  405. """
  406. if no_model == 0:
  407. return "N/A"
  408. return f"{(no_model - error) / no_model * 100:.1f}%"
  409.  
  410.  
  411. TYPE_LABELS = {
  412. "forklift_bbox": "Forklift BBox (forklift-with-roll + forklift-no-roll)",
  413. "pr_keypoints": "PR Keypoints (roll-keypoints)",
  414. "clamp_keypoints": "Clamp KP (clamp-2-arm + clamp-3-arm)",
  415. }
  416.  
  417. TYPE_KEYS = ["forklift_bbox", "pr_keypoints", "clamp_keypoints"]
  418. SHOW_FLAGS = {
  419. "forklift_bbox": "show_bbox",
  420. "pr_keypoints": "show_pr_kp",
  421. "clamp_keypoints": "show_clamp_kp",
  422. }
  423.  
  424.  
  425. def print_report(
  426. jobs: dict,
  427. pred_file: str,
  428. gt_file: str,
  429. threshold_px: float,
  430. show_bbox: bool,
  431. show_pr_kp: bool,
  432. show_clamp_kp: bool,
  433. label_match: bool = False,
  434. ):
  435. W = 85
  436. show_map = {
  437. "forklift_bbox": show_bbox,
  438. "pr_keypoints": show_pr_kp,
  439. "clamp_keypoints": show_clamp_kp,
  440. }
  441. active_types = [k for k in TYPE_KEYS if show_map[k]]
  442.  
  443. print("\n" + "=" * W)
  444. print(" RF-DETR FORKLIFT — PRE-LABEL ACCURACY REPORT (per job)")
  445. print("=" * W)
  446. print(f" Predicted : {pred_file}")
  447. print(f" GT : {gt_file}")
  448. print(f" Threshold : {threshold_px} px (keypoint position error)")
  449. print(f" BBox match: IoU >= 0.85{' + same class label' if label_match else ' (position only, no class check)'}")
  450. print(f" Jobs : {min(jobs) if jobs else '—'} – {max(jobs) if jobs else '—'} ({FRAMES_PER_JOB} frames each)")
  451. print("=" * W)
  452.  
  453. # Column widths for the per-job table
  454. # Job | No. Model | Error | Missed* | Over-det** | Accuracy (%)
  455. C_JOB, C1, C2, C3, C4, C5 = 7, 12, 9, 10, 12, 14
  456. SEP = C_JOB + C1 + C2 + C3 + C4 + C5
  457.  
  458. for type_key in active_types:
  459. print(f"\n ── {TYPE_LABELS[type_key]} ──")
  460. print(
  461. f" {'Job':>{C_JOB}}"
  462. f"{'No. Model':>{C1}}"
  463. f"{'Error':>{C2}}"
  464. f"{'Missed*':>{C3}}"
  465. f"{'Over-det**':>{C4}}"
  466. f"{'Accuracy (%)':>{C5}}"
  467. )
  468. print(" * Missed = GT > Model (model missed — annotator drew these manually)")
  469. print(" ** Over-det = Model > GT (model over-detected — annotator deleted these)")
  470. print(" " + "─" * SEP)
  471.  
  472. total = empty_stats()
  473. for job_id in sorted(jobs):
  474. s = jobs[job_id][type_key]
  475. print(
  476. f" {job_id:>{C_JOB}}"
  477. f"{s['no_model']:>{C1}}"
  478. f"{s['error_model']:>{C2}}"
  479. f"{s['no_missed']:>{C3}}"
  480. f"{s['no_over']:>{C4}}"
  481. f"{_acc(s['no_model'], s['error_model']):>{C5}}"
  482. )
  483. for k in total:
  484. total[k] += s[k]
  485.  
  486. # Total row
  487. print(" " + "─" * SEP)
  488. print(
  489. f" {'TOTAL':>{C_JOB}}"
  490. f"{total['no_model']:>{C1}}"
  491. f"{total['error_model']:>{C2}}"
  492. f"{total['no_missed']:>{C3}}"
  493. f"{total['no_over']:>{C4}}"
  494. f"{_acc(total['no_model'], total['error_model']):>{C5}}"
  495. )
  496.  
  497. print("\n" + "=" * W + "\n")
  498.  
  499.  
  500. def export_csv(
  501. jobs: dict,
  502. csv_path: str,
  503. show_bbox: bool,
  504. show_pr_kp: bool,
  505. show_clamp_kp: bool,
  506. ):
  507. show_map = {
  508. "forklift_bbox": show_bbox,
  509. "pr_keypoints": show_pr_kp,
  510. "clamp_keypoints": show_clamp_kp,
  511. }
  512. active_types = [k for k in TYPE_KEYS if show_map[k]]
  513.  
  514. FIELDS = ["job", "type", "no_model", "error_model", "no_missed", "no_over", "accuracy_pct"]
  515. HEADERS = {
  516. "job": "Job",
  517. "type": "Annotation Type",
  518. "no_model": "No. BBox/KP (Model)",
  519. "error_model": "Error BBox/KP (Model)",
  520. "no_missed": "Missed — GT > Model (Manual Add)",
  521. "no_over": "Over-detect — Model > GT (Spurious)",
  522. "accuracy_pct": "Pre-label Accuracy (%)",
  523. }
  524.  
  525. rows = []
  526. for type_key in active_types:
  527. total = empty_stats()
  528. for job_id in sorted(jobs):
  529. s = jobs[job_id][type_key]
  530. no_m = s["no_model"]
  531. err = s["error_model"]
  532. rows.append({
  533. "job": job_id,
  534. "type": TYPE_LABELS[type_key].split("(")[0].strip(),
  535. "no_model": no_m,
  536. "error_model": err,
  537. "no_missed": s["no_missed"],
  538. "no_over": s["no_over"],
  539. "accuracy_pct": f"{(no_m - err) / no_m * 100:.1f}" if no_m > 0 else "",
  540. })
  541. for k in total:
  542. total[k] += s[k]
  543. # Total row per type
  544. rows.append({
  545. "job": "TOTAL",
  546. "type": TYPE_LABELS[type_key].split("(")[0].strip(),
  547. "no_model": total["no_model"],
  548. "error_model": total["error_model"],
  549. "no_missed": total["no_missed"],
  550. "no_over": total["no_over"],
  551. "accuracy_pct": _acc(total["no_model"], total["error_model"]).replace("%", ""),
  552. })
  553. rows.append(dict.fromkeys(FIELDS, "")) # blank separator
  554.  
  555. with open(csv_path, "w", newline="", encoding="utf-8") as f:
  556. writer = csv.DictWriter(f, fieldnames=FIELDS)
  557. writer.writerow({k: HEADERS[k] for k in FIELDS})
  558. writer.writerows(rows)
  559.  
  560. print(f" CSV exported → {csv_path}")
  561.  
  562.  
  563. # ── CLI ───────────────────────────────────────────────────────────────────────
  564.  
  565. def main():
  566. parser = argparse.ArgumentParser(
  567. description="Per-job pre-label accuracy report for RF-DETR forklift part3.",
  568. formatter_class=argparse.RawDescriptionHelpFormatter,
  569. )
  570.  
  571. # Type selection flags
  572. parser.add_argument("--forklift-bbox", action="store_true",
  573. help="Show Forklift BBox stats (forklift-with-roll + forklift-no-roll)")
  574. parser.add_argument("--pr-keypoints", action="store_true",
  575. help="Show Paper Roll Keypoints stats (roll-keypoints)")
  576. parser.add_argument("--clamp-keypoints", action="store_true",
  577. help="Show Clamp Keypoints stats (clamp-2-arm + clamp-3-arm)")
  578.  
  579. parser.add_argument("--label-match", action="store_true",
  580. help="BBox: require pred and GT to have the same class label to match "
  581. "(default: off — position-only IoU matching).")
  582.  
  583. # Options
  584. parser.add_argument("--job", type=int, metavar="N",
  585. help="Show only job N (1–10). Default: all jobs.")
  586. parser.add_argument("--threshold", type=float, default=5.0,
  587. help="Keypoint error threshold in pixels (default: 5).")
  588. parser.add_argument("--csv", dest="csv_path",
  589. help="Export results to a CSV file.")
  590. parser.add_argument("--predict", default=DEFAULT_PRED,
  591. help=f"Path to predict JSON (default: {DEFAULT_PRED})")
  592. parser.add_argument("--gt", default=DEFAULT_GT,
  593. help=f"Path to GT JSON (default: {DEFAULT_GT})")
  594.  
  595. args = parser.parse_args()
  596.  
  597. # If no type flag given → show all
  598. show_bbox = args.forklift_bbox
  599. show_pr_kp = args.pr_keypoints
  600. show_clamp_kp = args.clamp_keypoints
  601. if not show_bbox and not show_pr_kp and not show_clamp_kp:
  602. show_bbox = show_pr_kp = show_clamp_kp = True
  603.  
  604. if not Path(args.predict).exists():
  605. print(f"Error: predict file not found: {args.predict}")
  606. return
  607. if not Path(args.gt).exists():
  608. print(f"Error: GT file not found: {args.gt}")
  609. return
  610.  
  611. jobs = run_compare(
  612. pred_path=args.predict,
  613. gt_path=args.gt,
  614. show_bbox=show_bbox,
  615. show_pr_kp=show_pr_kp,
  616. show_clamp_kp=show_clamp_kp,
  617. threshold_px=args.threshold,
  618. job_filter=args.job,
  619. label_match=args.label_match,
  620. )
  621.  
  622. print_report(
  623. jobs,
  624. pred_file=Path(args.predict).name,
  625. gt_file=Path(args.gt).name,
  626. threshold_px=args.threshold,
  627. show_bbox=show_bbox,
  628. show_pr_kp=show_pr_kp,
  629. show_clamp_kp=show_clamp_kp,
  630. label_match=args.label_match,
  631. )
  632.  
  633. if args.csv_path:
  634. export_csv(jobs, args.csv_path, show_bbox, show_pr_kp, show_clamp_kp)
  635.  
  636.  
  637. if __name__ == "__main__":
  638. main()
  639.  
Success #stdin #stdout 0.69s 41908KB
stdin
Standard input is empty
stdout
Error: predict file not found: /home/data/rfdter/object-detection/forklift/run01/forklift-bbox_datumaro_predict_part3.json