#!/usr/bin/env python3
"""
compare_rfdetr_forklift.py — Per-job pre-label accuracy report for RF-DETR forklift detection.
Compares:
PREDICT : forklift-bbox_datumaro_predict_part3.json (model output — bbox only)
GT : datumaro_GT_part3.json (QA-finalized ground truth — bbox + keypoints)
1000 frames split into 10 jobs of 100 frames each (Job 1 = frames 0-99, ...).
────────────────────────────────────────────────────────────────────────────────
Annotation type flags
────────────────────────────────────────────────────────────────────────────────
--forklift-bbox forklift-with-roll + forklift-no-roll (bbox, IoU >= 0.85)
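--pr-keypoints roll-keypoints (skeleton; paired by nearest centroid, keypoint error if > --threshold px, default 5)
--clamp-keypoints clamp-2-arm + clamp-3-arm (skeleton; paired by nearest centroid, keypoint error if > --threshold px, default 5)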
If no flag is given, all three types are shown.
────────────────────────────────────────────────────────────────────────────────
Output columns (per job row)
────────────────────────────────────────────────────────────────────────────────
Job | No. Model | Error BBox/Keypoints | Missed* | Over-det** | Accuracy (%)
No. Model — total boxes/keypoints predicted by the model
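Error — bbox: model boxes with no matching GT box (IoU < threshold);
keypoints: points of matched skeletons farther than --threshold px from GT
(points of unmatched predicted skeletons are counted under Over-det, not Error)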
Missed* — GT annotations the model did not predict (annotator had to add manually)
Over-det** — model predicted more than GT (annotator had to delete the extra)
Accuracy — Precision: (No. Model - Error) / No. Model × 100
Measures what fraction of model output was correct.
Does NOT penalise for missed GT annotations.
────────────────────────────────────────────────────────────────────────────────
Usage
────────────────────────────────────────────────────────────────────────────────
python scripts/compare_rfdetr_forklift.py
python scripts/compare_rfdetr_forklift.py --forklift-bbox
python scripts/compare_rfdetr_forklift.py --pr-keypoints
python scripts/compare_rfdetr_forklift.py --clamp-keypoints
python scripts/compare_rfdetr_forklift.py --forklift-bbox --job 1
python scripts/compare_rfdetr_forklift.py --forklift-bbox --threshold 10
python scripts/compare_rfdetr_forklift.py --forklift-bbox --csv out.csv
python scripts/compare_rfdetr_forklift.py --predict other_pred.json --gt other_gt.json
"""
import argparse
import csv
import json
import math
from collections import defaultdict
from pathlib import Path
import numpy as np
# ── Default file paths ────────────────────────────────────────────────────────
# Resolved relative to this script: <script dir>/../data/rfdter/object-detection/forklift/run01
# NOTE(review): "rfdter" may be a transposition of "rfdetr" — confirm against the
# directory name on disk before changing; the path is used verbatim at runtime.
_BASE = Path(__file__).parent.parent / "data" / "rfdter" / "object-detection" / "forklift" / "run01"
DEFAULT_PRED = str(_BASE / "forklift-bbox_datumaro_predict_part3.json")
DEFAULT_GT = str(_BASE / "datumaro_GT_part3.json")
# ── Keypoint visibility encoding ─────────────────────────────────────────────
# Datumaro skeleton format: points = [x0, y0, vis0, x1, y1, vis1, ...]
# vis == 0 means the keypoint is absent / not annotated
VIS_ABSENT = 0
# ── Annotation type groups ────────────────────────────────────────────────────
# Maps each report section to the label names it covers (applied to both the
# prediction and the GT annotations when filtering)
FORKLIFT_BBOX_LABELS = {"forklift-with-roll", "forklift-no-roll"}
PR_KP_LABELS = {"roll-keypoints"}
CLAMP_KP_LABELS = {"clamp-2-arm", "clamp-3-arm"}
# Datumaro annotation type strings used for filtering
# NOTE(review): accumulate_bbox / accumulate_kp compare against the literal
# "bbox" / "skeleton" strings; these two sets appear unreferenced in this file.
FORKLIFT_BBOX_TYPES = {"bbox"}
KP_TYPES = {"skeleton"}
# Number of consecutive GT items (after sorting by frame) that form one job
FRAMES_PER_JOB = 100
# ── I/O ───────────────────────────────────────────────────────────────────────
def load_json(path: str) -> dict:
    """Parse the Datumaro-format JSON export stored at *path*.

    The file is decoded as UTF-8 and the parsed top-level JSON value is
    returned unchanged.
    """
    with open(path, "r", encoding="utf-8") as handle:
        raw_text = handle.read()
    return json.loads(raw_text)
def build_label_map(data: dict) -> dict[int, str]:
    """Map each label's positional index to its name.

    Reads ``data["categories"]["label"]["labels"]``; any missing level is
    treated as empty, so a dataset without categories yields ``{}``.
    """
    categories = data.get("categories", {})
    label_block = categories.get("label", {})
    entries = label_block.get("labels", [])
    return dict(enumerate(entry["name"] for entry in entries))
# ── BBox IoU ──────────────────────────────────────────────────────────────────
def bbox_iou(b1: list, b2: list) -> float:
    """Return the intersection-over-union of two boxes.

    Each box is a 4-sequence ``[x, y, w, h]`` where the box spans
    ``x .. x + w`` horizontally and ``y .. y + h`` vertically.
    Returns 0.0 when the union area is not positive.
    """
    ax, ay, aw, ah = b1
    bx, by, bw, bh = b2

    # Edges of the overlap rectangle (may be inverted when the boxes are disjoint)
    left = max(ax, bx)
    top = max(ay, by)
    right = min(ax + aw, bx + bw)
    bottom = min(ay + ah, by + bh)

    overlap = max(0.0, right - left) * max(0.0, bottom - top)
    combined = aw * ah + bw * bh - overlap
    if combined > 0:
        return overlap / combined
    return 0.0
def match_boxes_by_iou(
pred: list[dict],
gt: list[dict],
iou_thresh: float = 0.85,
pred_labels: list[str] | None = None,
gt_labels: list[str] | None = None,
) -> tuple[list[tuple], list[dict], list[dict]]:
"""
Greedy highest-IoU-first matching between predicted and GT bounding boxes.
Returns (matched_pairs, unmatched_pred, unmatched_gt).
A pair is formed only when IoU >= iou_thresh.
If pred_labels and gt_labels are provided, a pair is only eligible when
pred_labels[i] == gt_labels[j] (class-label check enabled).
"""
if not pred or not gt:
return [], pred[:], gt[:]
n_p, n_g = len(pred), len(gt)
mat = np.zeros((n_p, n_g))
for i, pa in enumerate(pred):
for j, ga in enumerate(gt):
# Zero out IoU for label mismatches when class check is enabled
if pred_labels is not None and gt_labels is not None:
if pred_labels[i] != gt_labels[j]:
continue
mat[i, j] = bbox_iou(pa["bbox"], ga["bbox"])
matched, used_p, used_g = [], set(), set()
while True:
idx = np.unravel_index(np.argmax(mat), mat.shape)
if mat[idx] < iou_thresh:
break
i, j = int(idx[0]), int(idx[1])
matched.append((pred[i], gt[j]))
used_p.add(i)
used_g.add(j)
mat[i, :] = -1.0
mat[:, j] = -1.0
return (
matched,
[pred[i] for i in range(n_p) if i not in used_p],
[gt[j] for j in range(n_g) if j not in used_g],
)
# ── Skeleton matching & keypoint counting ────────────────────────────────────
def _centroid(ann: dict) -> tuple[float, float]:
"""Return the mean (x, y) of all visible keypoints in a skeleton annotation."""
pts = ann.get("points", [])
n = len(pts) // 3
xs = [pts[i * 3] for i in range(n) if pts[i * 3 + 2] > VIS_ABSENT]
ys = [pts[i * 3 + 1] for i in range(n) if pts[i * 3 + 2] > VIS_ABSENT]
if not xs:
return 0.0, 0.0
return float(sum(xs) / len(xs)), float(sum(ys) / len(ys))
def match_skeletons(
    pred: list[dict],
    gt: list[dict],
    dist_thresh: float = 300.0,
) -> tuple[list[tuple], list[dict], list[dict]]:
    """
    Pair predicted and GT skeletons greedily, closest centroids first.

    Returns (matched_pairs, unmatched_pred, unmatched_gt).
    Two skeletons can be paired only while their centroid distance is
    <= dist_thresh pixels; each skeleton is used in at most one pair.
    The caller is expected to pass skeletons of a single label class.
    """
    if not (pred and gt):
        return [], pred[:], gt[:]

    pred_centers = [_centroid(skel) for skel in pred]
    gt_centers = [_centroid(skel) for skel in gt]

    # dist[r, c] = centroid distance between pred[r] and gt[c]
    dist = np.array(
        [
            [math.hypot(px - gx, py - gy) for gx, gy in gt_centers]
            for px, py in pred_centers
        ],
        dtype=float,
    )

    pairs = []
    taken_pred, taken_gt = set(), set()
    while True:
        row, col = np.unravel_index(np.argmin(dist), dist.shape)
        if dist[row, col] > dist_thresh:
            break
        row, col = int(row), int(col)
        pairs.append((pred[row], gt[col]))
        taken_pred.add(row)
        taken_gt.add(col)
        # Retire both skeletons so neither can be paired again
        dist[row, :] = np.inf
        dist[:, col] = np.inf

    leftover_pred = [skel for k, skel in enumerate(pred) if k not in taken_pred]
    leftover_gt = [skel for k, skel in enumerate(gt) if k not in taken_gt]
    return pairs, leftover_pred, leftover_gt
def kp_present(ann: dict) -> int:
    """Return the number of complete (x, y, vis) triples in ``ann["points"]``.

    Visibility flags are not inspected; a missing "points" entry counts as 0.
    """
    slots, _ = divmod(len(ann.get("points", [])), 3)
    return slots
def kp_errors(pred_ann: dict, gt_ann: dict, threshold_px: float) -> int:
    """
    Count keypoints whose predicted position is farther than threshold_px from GT.

    Keypoints are compared index by index over the triples both annotations
    share; visibility flags are ignored. A distance exactly equal to
    threshold_px is not an error.
    """
    predicted = pred_ann.get("points", [])
    reference = gt_ann.get("points", [])
    shared = min(len(predicted), len(reference)) // 3

    errors = 0
    for base in range(0, 3 * shared, 3):
        dx = predicted[base] - reference[base]
        dy = predicted[base + 1] - reference[base + 1]
        if math.hypot(dx, dy) > threshold_px:
            errors += 1
    return errors
def kp_manual(pred_ann: dict, gt_ann: dict) -> int:
    """
    Count keypoints marked present in GT but absent in the prediction.

    Compares the vis flag of each keypoint index shared by both annotations:
    a keypoint is counted when the GT flag differs from VIS_ABSENT while the
    predicted flag equals VIS_ABSENT.
    """
    predicted = pred_ann.get("points", [])
    reference = gt_ann.get("points", [])
    shared = min(len(predicted), len(reference)) // 3

    added = 0
    for k in range(shared):
        vis = 3 * k + 2
        if reference[vis] != VIS_ABSENT and predicted[vis] == VIS_ABSENT:
            added += 1
    return added
# ── Per-job stats accumulator ────────────────────────────────────────────────
def empty_stats() -> dict:
    """
    Build a fresh, zeroed counter dict for one (job, annotation-type) pair.

    Keys
    ----
    no_model    : total boxes / keypoints predicted by the model
    error_model : model predictions counted as wrong
    no_missed   : GT count > model count — annotator had to add
    no_over     : model count > GT count — model predicted too many
    """
    counters = ("no_model", "error_model", "no_missed", "no_over")
    return dict.fromkeys(counters, 0)
def accumulate_bbox(
    pred_anns: list[dict],
    gt_anns: list[dict],
    target_labels: set[str],
    pred_lmap: dict,
    gt_lmap: dict,
    stats: dict,
    label_match: bool = False,
):
    """
    Add one frame's bounding-box counts to *stats* (mutated in place).

    Only annotations of type "bbox" whose label name is in target_labels are
    considered. Updates:
      no_model    += number of predicted boxes
      no_missed   += GT count minus predicted count, when GT has more
      no_over     += predicted count minus GT count, when the model has more
      error_model += predicted boxes left unmatched by match_boxes_by_iou

    With label_match=True the label names are passed to the matcher so that
    only same-class boxes can pair.
    """
    def _boxes(anns: list[dict], lmap: dict) -> list[dict]:
        return [
            ann for ann in anns
            if lmap.get(ann["label_id"]) in target_labels and ann["type"] == "bbox"
        ]

    pred_boxes = _boxes(pred_anns, pred_lmap)
    gt_boxes = _boxes(gt_anns, gt_lmap)

    if label_match:
        pred_names = [pred_lmap.get(box["label_id"]) for box in pred_boxes]
        gt_names = [gt_lmap.get(box["label_id"]) for box in gt_boxes]
    else:
        pred_names = gt_names = None

    spurious = match_boxes_by_iou(
        pred_boxes, gt_boxes, pred_labels=pred_names, gt_labels=gt_names
    )[1]

    n_pred, n_gt = len(pred_boxes), len(gt_boxes)
    stats["no_model"] += n_pred
    # Missed / over-detected are pure count differences, independent of matching
    stats["no_missed"] += max(0, n_gt - n_pred)
    stats["no_over"] += max(0, n_pred - n_gt)
    stats["error_model"] += len(spurious)
def accumulate_kp(
    pred_anns: list[dict],
    gt_anns: list[dict],
    target_labels: set[str],
    pred_lmap: dict,
    gt_lmap: dict,
    stats: dict,
    threshold_px: float,
):
    """
    Add one frame's keypoint counts to *stats* (mutated in place).

    Skeletons are processed one label at a time, so a prediction can only be
    paired with a GT skeleton of the same label. Per label:
      no_model    += keypoint slots of every predicted skeleton (vis ignored)
      error_model += keypoints of matched pairs farther than threshold_px apart
      no_over     += keypoint slots of predicted skeletons with no GT match
      no_missed   += keypoint slots of GT skeletons with no predicted match
    """
    def _skeletons(anns: list[dict], lmap: dict, wanted: str) -> list[dict]:
        return [
            ann for ann in anns
            if lmap.get(ann["label_id"]) == wanted and ann["type"] == "skeleton"
        ]

    for label in target_labels:
        pred_skels = _skeletons(pred_anns, pred_lmap, label)
        gt_skels = _skeletons(gt_anns, gt_lmap, label)
        pairs, extra_pred, extra_gt = match_skeletons(pred_skels, gt_skels)

        # Per-point position errors inside matched pairs
        for pred_skel, gt_skel in pairs:
            stats["error_model"] += kp_errors(pred_skel, gt_skel, threshold_px)

        # Whole-skeleton tallies: every keypoint slot of each listed skeleton counts
        whole_skeleton_tallies = (
            ("no_model", pred_skels),
            ("no_over", extra_pred),
            ("no_missed", extra_gt),
        )
        for field, skeletons in whole_skeleton_tallies:
            for skel in skeletons:
                stats[field] += kp_present(skel)
# ── Main compare ──────────────────────────────────────────────────────────────
def run_compare(
    pred_path: str,
    gt_path: str,
    show_bbox: bool,
    show_pr_kp: bool,
    show_clamp_kp: bool,
    threshold_px: float,
    job_filter: int | None,
    label_match: bool = False,
) -> dict:
    """
    Compare a prediction file against a GT file and collect per-job stats.

    GT items are sorted by their ``attr["frame"]`` value (0 when absent) and
    grouped into jobs of FRAMES_PER_JOB consecutive items, numbered from 1.
    Each GT item is compared with the prediction item of the same "id"; a GT
    item with no prediction is treated as having no predicted annotations.
    When job_filter is given, only that job is processed.

    Returns
    -------
    {
        job_id: {
            "forklift_bbox":   {"no_model": int, "error_model": int,
                                "no_missed": int, "no_over": int},
            "pr_keypoints":    {...},
            "clamp_keypoints": {...},
        },
        ...
    }
    Sections whose show_* flag is False stay zero-filled.
    """
    predictions = load_json(pred_path)
    ground_truth = load_json(gt_path)
    pred_lmap = build_label_map(predictions)
    gt_lmap = build_label_map(ground_truth)

    pred_lookup = {item["id"]: item for item in predictions["items"]}
    # Ordering by frame index makes the item -> job assignment deterministic
    ordered_gt = sorted(
        ground_truth["items"], key=lambda item: item["attr"].get("frame", 0)
    )

    # (enabled, stats key, label group) for the two keypoint sections
    keypoint_sections = (
        (show_pr_kp, "pr_keypoints", PR_KP_LABELS),
        (show_clamp_kp, "clamp_keypoints", CLAMP_KP_LABELS),
    )

    report: dict[int, dict] = {}
    for position, gt_item in enumerate(ordered_gt):
        job_id = position // FRAMES_PER_JOB + 1
        if job_filter is not None and job_id != job_filter:
            continue

        bucket = report.get(job_id)
        if bucket is None:
            bucket = report[job_id] = {
                section: empty_stats()
                for section in ("forklift_bbox", "pr_keypoints", "clamp_keypoints")
            }

        pred_item = pred_lookup.get(gt_item["id"], {"annotations": []})
        pred_anns = pred_item.get("annotations", [])
        gt_anns = gt_item.get("annotations", [])

        if show_bbox:
            accumulate_bbox(
                pred_anns, gt_anns, FORKLIFT_BBOX_LABELS, pred_lmap, gt_lmap,
                bucket["forklift_bbox"], label_match=label_match,
            )
        for enabled, section, labels in keypoint_sections:
            if enabled:
                accumulate_kp(
                    pred_anns, gt_anns, labels, pred_lmap, gt_lmap,
                    bucket[section], threshold_px,
                )
    return report
# ── Display helpers ──────────────────────────────────────────────────────────
def _acc(no_model: int, error: int) -> str:
"""
Compute precision-based pre-label accuracy.
Formula: (no_model - error) / no_model * 100
Measures what fraction of model predictions were correct.
Returns 'N/A' when the model produced no predictions.
"""
if no_model == 0:
return "N/A"
return f"{(no_model - error) / no_model * 100:.1f}%"
# Section key -> heading printed above each table by print_report.
# export_csv keeps only the text before the first "(" as the CSV type name.
TYPE_LABELS = {
"forklift_bbox": "Forklift BBox (forklift-with-roll + forklift-no-roll)",
"pr_keypoints": "PR Keypoints (roll-keypoints)",
"clamp_keypoints": "Clamp KP (clamp-2-arm + clamp-3-arm)",
}
# Fixed order in which the sections are printed and exported.
TYPE_KEYS = ["forklift_bbox", "pr_keypoints", "clamp_keypoints"]
# Section key -> name of the matching show_* parameter.
# NOTE(review): appears unreferenced in this file; print_report and export_csv
# build their own show_map from their arguments instead.
SHOW_FLAGS = {
"forklift_bbox": "show_bbox",
"pr_keypoints": "show_pr_kp",
"clamp_keypoints": "show_clamp_kp",
}
def print_report(
    jobs: dict,
    pred_file: str,
    gt_file: str,
    threshold_px: float,
    show_bbox: bool,
    show_pr_kp: bool,
    show_clamp_kp: bool,
    label_match: bool = False,
):
    """
    Print the per-job accuracy report to stdout.

    Prints a banner describing the run, then one table per enabled section
    (in TYPE_KEYS order) with a row per job, sorted by job id, and a TOTAL
    row that sums the counters over all listed jobs.
    """
    banner_width = 85
    # Column widths: Job | No. Model | Error | Missed* | Over-det** | Accuracy (%)
    widths = (7, 12, 9, 10, 12, 14)
    table_rule = " " + "─" * sum(widths)
    banner_rule = "=" * banner_width

    def _row(cells) -> str:
        # One leading space, then every cell right-aligned in its column
        return " " + "".join(f"{cell:>{width}}" for cell, width in zip(cells, widths))

    def _stats_row(first_cell, s: dict) -> str:
        return _row((
            first_cell,
            s["no_model"],
            s["error_model"],
            s["no_missed"],
            s["no_over"],
            _acc(s["no_model"], s["error_model"]),
        ))

    enabled = {
        "forklift_bbox": show_bbox,
        "pr_keypoints": show_pr_kp,
        "clamp_keypoints": show_clamp_kp,
    }
    active_types = [key for key in TYPE_KEYS if enabled[key]]

    match_note = " + same class label" if label_match else " (position only, no class check)"
    first_job = min(jobs) if jobs else "—"
    last_job = max(jobs) if jobs else "—"

    print("\n" + banner_rule)
    print(" RF-DETR FORKLIFT — PRE-LABEL ACCURACY REPORT (per job)")
    print(banner_rule)
    print(f" Predicted : {pred_file}")
    print(f" GT : {gt_file}")
    print(f" Threshold : {threshold_px} px (keypoint position error)")
    print(f" BBox match: IoU >= 0.85{match_note}")
    print(f" Jobs : {first_job} – {last_job} ({FRAMES_PER_JOB} frames each)")
    print(banner_rule)

    for type_key in active_types:
        print(f"\n ── {TYPE_LABELS[type_key]} ──")
        print(_row(("Job", "No. Model", "Error", "Missed*", "Over-det**", "Accuracy (%)")))
        print(" * Missed = GT > Model (model missed — annotator drew these manually)")
        print(" ** Over-det = Model > GT (model over-detected — annotator deleted these)")
        print(table_rule)

        total = empty_stats()
        for job_id in sorted(jobs):
            job_stats = jobs[job_id][type_key]
            print(_stats_row(job_id, job_stats))
            for field in total:
                total[field] += job_stats[field]

        # Total row
        print(table_rule)
        print(_stats_row("TOTAL", total))

    print("\n" + banner_rule + "\n")
def export_csv(
    jobs: dict,
    csv_path: str,
    show_bbox: bool,
    show_pr_kp: bool,
    show_clamp_kp: bool,
):
    """
    Write the per-job stats to a CSV file at csv_path (UTF-8, overwritten).

    For each enabled section (in TYPE_KEYS order) the file contains one row
    per job sorted by job id, a TOTAL row, and a blank separator row. The
    accuracy column holds the percentage with one decimal and no '%' sign;
    it is left empty when the model produced no predictions, for job rows
    and the TOTAL row alike, so the column never mixes numbers with text.
    """
    show_map = {
        "forklift_bbox": show_bbox,
        "pr_keypoints": show_pr_kp,
        "clamp_keypoints": show_clamp_kp,
    }
    active_types = [k for k in TYPE_KEYS if show_map[k]]

    FIELDS = ["job", "type", "no_model", "error_model", "no_missed", "no_over", "accuracy_pct"]
    HEADERS = {
        "job": "Job",
        "type": "Annotation Type",
        "no_model": "No. BBox/KP (Model)",
        "error_model": "Error BBox/KP (Model)",
        "no_missed": "Missed — GT > Model (Manual Add)",
        "no_over": "Over-detect — Model > GT (Spurious)",
        "accuracy_pct": "Pre-label Accuracy (%)",
    }

    def _accuracy_cell(no_model: int, error: int) -> str:
        # Single formatter shared by job and TOTAL rows; empty when undefined
        if no_model <= 0:
            return ""
        return f"{(no_model - error) / no_model * 100:.1f}"

    def _row(job, type_name: str, s: dict) -> dict:
        return {
            "job": job,
            "type": type_name,
            "no_model": s["no_model"],
            "error_model": s["error_model"],
            "no_missed": s["no_missed"],
            "no_over": s["no_over"],
            "accuracy_pct": _accuracy_cell(s["no_model"], s["error_model"]),
        }

    rows = []
    for type_key in active_types:
        # Short section name: the heading text before the first "("
        type_name = TYPE_LABELS[type_key].split("(")[0].strip()
        total = empty_stats()
        for job_id in sorted(jobs):
            s = jobs[job_id][type_key]
            rows.append(_row(job_id, type_name, s))
            for k in total:
                total[k] += s[k]
        # Total row per type
        rows.append(_row("TOTAL", type_name, total))
        rows.append(dict.fromkeys(FIELDS, ""))  # blank separator

    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        # Human-readable header row instead of the raw field names
        writer.writerow({k: HEADERS[k] for k in FIELDS})
        writer.writerows(rows)
    print(f" CSV exported → {csv_path}")
# ── CLI ───────────────────────────────────────────────────────────────────────
def main():
    """
    Command-line entry point: parse arguments, run the comparison, print the
    report and optionally export it as CSV.

    When none of the section flags is given, all three sections are shown.
    If the predict or GT file does not exist, the parser's error() method is
    used, which prints the message to stderr and exits with status 2, so
    calling scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Per-job pre-label accuracy report for RF-DETR forklift part3.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    # Type selection flags
    parser.add_argument("--forklift-bbox", action="store_true",
                        help="Show Forklift BBox stats (forklift-with-roll + forklift-no-roll)")
    parser.add_argument("--pr-keypoints", action="store_true",
                        help="Show Paper Roll Keypoints stats (roll-keypoints)")
    parser.add_argument("--clamp-keypoints", action="store_true",
                        help="Show Clamp Keypoints stats (clamp-2-arm + clamp-3-arm)")
    parser.add_argument("--label-match", action="store_true",
                        help="BBox: require pred and GT to have the same class label to match "
                             "(default: off — position-only IoU matching).")
    # Options
    parser.add_argument("--job", type=int, metavar="N",
                        help="Show only job N (1–10). Default: all jobs.")
    parser.add_argument("--threshold", type=float, default=5.0,
                        help="Keypoint error threshold in pixels (default: 5).")
    parser.add_argument("--csv", dest="csv_path",
                        help="Export results to a CSV file.")
    parser.add_argument("--predict", default=DEFAULT_PRED,
                        help=f"Path to predict JSON (default: {DEFAULT_PRED})")
    parser.add_argument("--gt", default=DEFAULT_GT,
                        help=f"Path to GT JSON (default: {DEFAULT_GT})")
    args = parser.parse_args()

    # If no type flag given → show all
    show_bbox = args.forklift_bbox
    show_pr_kp = args.pr_keypoints
    show_clamp_kp = args.clamp_keypoints
    if not (show_bbox or show_pr_kp or show_clamp_kp):
        show_bbox = show_pr_kp = show_clamp_kp = True

    # Fail loudly (stderr + non-zero exit status) instead of returning success
    if not Path(args.predict).exists():
        parser.error(f"predict file not found: {args.predict}")
    if not Path(args.gt).exists():
        parser.error(f"GT file not found: {args.gt}")

    jobs = run_compare(
        pred_path=args.predict,
        gt_path=args.gt,
        show_bbox=show_bbox,
        show_pr_kp=show_pr_kp,
        show_clamp_kp=show_clamp_kp,
        threshold_px=args.threshold,
        job_filter=args.job,
        label_match=args.label_match,
    )
    print_report(
        jobs,
        pred_file=Path(args.predict).name,
        gt_file=Path(args.gt).name,
        threshold_px=args.threshold,
        show_bbox=show_bbox,
        show_pr_kp=show_pr_kp,
        show_clamp_kp=show_clamp_kp,
        label_match=args.label_match,
    )
    if args.csv_path:
        export_csv(jobs, args.csv_path, show_bbox, show_pr_kp, show_clamp_kp)


if __name__ == "__main__":
    main()