#!/usr/bin/env python3
"""
compare_rfdetr_forklift.py — Per-job pre-label accuracy report for RF-DETR forklift detection.

Compares:
  PREDICT : forklift-bbox_datumaro_predict_part3.json  (model output — bbox only)
  GT      : datumaro_GT_part3.json                     (QA-finalized ground truth — bbox + keypoints)

1000 frames split into 10 jobs of 100 frames each (Job 1 = frames 0-99, ...).

────────────────────────────────────────────────────────────────────────────────
Annotation type flags
────────────────────────────────────────────────────────────────────────────────
  --forklift-bbox    forklift-with-roll + forklift-no-roll  (bbox,     IoU >= 0.85)
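  --pr-keypoints     roll-keypoints                         (skeleton, per-keypoint error > 5 px)
  --clamp-keypoints  clamp-2-arm + clamp-3-arm              (skeleton, per-keypoint error > 5 px)

  Skeletons are first paired by nearest centroid (within 300 px); the 5 px
  default (--threshold) is then applied to each keypoint of a matched pair.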

  If no flag is given, all three types are shown.

────────────────────────────────────────────────────────────────────────────────
Output columns (per job row)
────────────────────────────────────────────────────────────────────────────────
  Job | No. Model | Error BBox/Keypoints | Missed* | Over-det** | Accuracy (%)

  No. Model   — total boxes/keypoints predicted by the model
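  Error       — bbox: predicted boxes with no GT match at IoU >= 0.85;
                keypoints: points of matched skeletons farther than --threshold px from GT
  Missed*     — bbox: per-frame surplus of GT boxes over predicted boxes;
                keypoints: all points of GT skeletons with no matching prediction
                (annotator had to add manually)
  Over-det**  — bbox: per-frame surplus of predicted boxes over GT boxes;
                keypoints: all points of predicted skeletons with no matching GT
                (annotator had to delete the extra)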
  Accuracy    — Precision: (No. Model - Error) / No. Model × 100
                Measures what fraction of model output was correct.
                Does NOT penalise for missed GT annotations.

────────────────────────────────────────────────────────────────────────────────
Usage
────────────────────────────────────────────────────────────────────────────────
  python scripts/compare_rfdetr_forklift.py
  python scripts/compare_rfdetr_forklift.py --forklift-bbox
  python scripts/compare_rfdetr_forklift.py --pr-keypoints
  python scripts/compare_rfdetr_forklift.py --clamp-keypoints
  python scripts/compare_rfdetr_forklift.py --forklift-bbox --job 1
  python scripts/compare_rfdetr_forklift.py --forklift-bbox --threshold 10
  python scripts/compare_rfdetr_forklift.py --forklift-bbox --csv out.csv
  python scripts/compare_rfdetr_forklift.py --predict other_pred.json --gt other_gt.json
"""

import argparse
import csv
import json
import math
from collections import defaultdict
from pathlib import Path

import numpy as np

# ── Default file paths ────────────────────────────────────────────────────────
# Both defaults live under <parent of this script's directory>/data/...; they are
# stored as plain strings so they can be used directly as argparse defaults.
# NOTE(review): the "rfdter" path segment differs from the "rfdetr" spelling used
# in the script name — confirm it matches the directory name on disk.
_BASE = Path(__file__).parent.parent / "data" / "rfdter" / "object-detection" / "forklift" / "run01"
DEFAULT_PRED = str(_BASE / "forklift-bbox_datumaro_predict_part3.json")
DEFAULT_GT   = str(_BASE / "datumaro_GT_part3.json")

# ── Keypoint visibility encoding ─────────────────────────────────────────────
# Datumaro skeleton format: points = [x0, y0, vis0, x1, y1, vis1, ...]
# vis == 0 means the keypoint is absent / not annotated
VIS_ABSENT = 0

# ── Annotation type groups ────────────────────────────────────────────────────
# Maps each report section to the label names it covers. The same sets are used
# to filter both predicted and GT annotations.
FORKLIFT_BBOX_LABELS  = {"forklift-with-roll", "forklift-no-roll"}
PR_KP_LABELS          = {"roll-keypoints"}
CLAMP_KP_LABELS       = {"clamp-2-arm", "clamp-3-arm"}

# Datumaro annotation type strings used for filtering
# NOTE(review): these two sets are not referenced in the visible part of this
# file — the accumulate_* functions compare against the literals "bbox" and
# "skeleton" directly.
FORKLIFT_BBOX_TYPES   = {"bbox"}
KP_TYPES              = {"skeleton"}

# Number of consecutive GT frames (after sorting) that make up one job.
FRAMES_PER_JOB = 100


# ── I/O ───────────────────────────────────────────────────────────────────────

def load_json(path: str) -> dict:
    """Parse the UTF-8 encoded JSON file at *path* and return the decoded content."""
    with open(path, mode="r", encoding="utf-8") as handle:
        text = handle.read()
    return json.loads(text)


def build_label_map(data: dict) -> dict[int, str]:
    """
    Build the {label_id: label_name} lookup for a Datumaro dataset.

    The id of a label is its position in ``categories.label.labels``.
    Any missing level of that path yields an empty mapping.
    """
    categories = data.get("categories", {})
    label_section = categories.get("label", {})
    names = (entry["name"] for entry in label_section.get("labels", []))
    return dict(enumerate(names))


# ── BBox IoU ──────────────────────────────────────────────────────────────────

def bbox_iou(b1: list, b2: list) -> float:
    """
    Return the intersection-over-union of two boxes given as [x, y, w, h].

    Boxes that do not overlap, or whose union area is not positive, score 0.0.
    """
    left1, top1, width1, height1 = b1
    left2, top2, width2, height2 = b2

    # Extent of the overlap along each axis; negative means the boxes are apart.
    overlap_w = min(left1 + width1, left2 + width2) - max(left1, left2)
    overlap_h = min(top1 + height1, top2 + height2) - max(top1, top2)
    inter = max(0.0, overlap_w) * max(0.0, overlap_h)

    union = width1 * height1 + width2 * height2 - inter
    if union > 0:
        return inter / union
    return 0.0


def match_boxes_by_iou(
    pred: list[dict],
    gt: list[dict],
    iou_thresh: float = 0.85,
    pred_labels: list[str] | None = None,
    gt_labels: list[str] | None = None,
) -> tuple[list[tuple], list[dict], list[dict]]:
    """
    Greedy highest-IoU-first matching between predicted and GT bounding boxes.

    Returns (matched_pairs, unmatched_pred, unmatched_gt).
    A pair is formed only when IoU >= iou_thresh.
    If pred_labels and gt_labels are provided, a pair is only eligible when
    pred_labels[i] == gt_labels[j]  (class-label check enabled).
    """
    if not pred or not gt:
        return [], pred[:], gt[:]

    n_p, n_g = len(pred), len(gt)
    mat = np.zeros((n_p, n_g))
    for i, pa in enumerate(pred):
        for j, ga in enumerate(gt):
            # Zero out IoU for label mismatches when class check is enabled
            if pred_labels is not None and gt_labels is not None:
                if pred_labels[i] != gt_labels[j]:
                    continue
            mat[i, j] = bbox_iou(pa["bbox"], ga["bbox"])

    matched, used_p, used_g = [], set(), set()
    while True:
        idx = np.unravel_index(np.argmax(mat), mat.shape)
        if mat[idx] < iou_thresh:
            break
        i, j = int(idx[0]), int(idx[1])
        matched.append((pred[i], gt[j]))
        used_p.add(i)
        used_g.add(j)
        mat[i, :] = -1.0
        mat[:, j] = -1.0

    return (
        matched,
        [pred[i] for i in range(n_p) if i not in used_p],
        [gt[j]   for j in range(n_g) if j not in used_g],
    )


# ── Skeleton matching & keypoint counting ────────────────────────────────────

def _centroid(ann: dict) -> tuple[float, float]:
    """Return the mean (x, y) of all visible keypoints in a skeleton annotation."""
    pts = ann.get("points", [])
    n   = len(pts) // 3
    xs  = [pts[i * 3]     for i in range(n) if pts[i * 3 + 2] > VIS_ABSENT]
    ys  = [pts[i * 3 + 1] for i in range(n) if pts[i * 3 + 2] > VIS_ABSENT]
    if not xs:
        return 0.0, 0.0
    return float(sum(xs) / len(xs)), float(sum(ys) / len(ys))


def match_skeletons(
    pred: list[dict],
    gt: list[dict],
    dist_thresh: float = 300.0,
) -> tuple[list[tuple], list[dict], list[dict]]:
    """
    Pair predicted and GT skeletons greedily, closest centroids first.

    The centroid of every skeleton comes from ``_centroid``. The smallest
    remaining centroid distance is taken repeatedly until it exceeds
    ``dist_thresh`` pixels; every skeleton is used in at most one pair.
    Labels are not inspected here.

    Returns (matched_pairs, unmatched_pred, unmatched_gt), the unmatched lists
    preserving input order.
    """
    if not (pred and gt):
        return [], pred[:], gt[:]

    pred_centers = [_centroid(ann) for ann in pred]
    gt_centers = [_centroid(ann) for ann in gt]

    distances = np.full((len(pred), len(gt)), np.inf)
    for row, (px, py) in enumerate(pred_centers):
        for col, (gx, gy) in enumerate(gt_centers):
            distances[row, col] = math.hypot(px - gx, py - gy)

    pairs = []
    taken_pred, taken_gt = set(), set()
    n_cols = distances.shape[1]
    while True:
        # Flat argmin -> (row, col); ties resolve to the first cell in row-major order.
        row, col = divmod(int(distances.argmin()), n_cols)
        if distances[row, col] > dist_thresh:
            break
        pairs.append((pred[row], gt[col]))
        taken_pred.add(row)
        taken_gt.add(col)
        # Retire the whole row and column so neither skeleton is paired again.
        distances[row, :] = np.inf
        distances[:, col] = np.inf

    leftover_pred = [ann for k, ann in enumerate(pred) if k not in taken_pred]
    leftover_gt = [ann for k, ann in enumerate(gt) if k not in taken_gt]
    return pairs, leftover_pred, leftover_gt


def kp_present(ann: dict) -> int:
    """
    Return the number of complete (x, y, vis) triplets in ``ann["points"]``.

    The visibility flag is not consulted; a missing "points" entry counts as 0.
    """
    triplets, _ = divmod(len(ann.get("points", [])), 3)
    return triplets


def kp_errors(pred_ann: dict, gt_ann: dict, threshold_px: float) -> int:
    """
    Count keypoints whose predicted position lies more than ``threshold_px`` from GT.

    Keypoints are compared index by index over the triplets both annotations
    have in common. Visibility flags are ignored, and a distance exactly equal
    to the threshold is not an error.
    """
    pred_pts = pred_ann.get("points", [])
    gt_pts = gt_ann.get("points", [])
    shared = min(len(pred_pts), len(gt_pts)) // 3

    errors = 0
    for k in range(shared):
        base = 3 * k
        dx = pred_pts[base] - gt_pts[base]
        dy = pred_pts[base + 1] - gt_pts[base + 1]
        if math.hypot(dx, dy) > threshold_px:
            errors += 1
    return errors


def kp_manual(pred_ann: dict, gt_ann: dict) -> int:
    """
    Count keypoints flagged present in GT but flagged VIS_ABSENT in the prediction.

    Only the triplets both annotations have in common are compared, index by index.
    """
    pred_pts = pred_ann.get("points", [])
    gt_pts = gt_ann.get("points", [])
    shared = min(len(pred_pts), len(gt_pts)) // 3

    added = 0
    for k in range(shared):
        vis_at = 3 * k + 2
        if gt_pts[vis_at] == VIS_ABSENT:
            continue  # GT has no keypoint here, nothing to add
        if pred_pts[vis_at] == VIS_ABSENT:
            added += 1
    return added


# ── Per-job stats accumulator ────────────────────────────────────────────────

def empty_stats() -> dict:
    """
    Create a fresh counter dict, all zeros, for one (job, annotation-type) pair.

    Keys
    ----
    no_model    : boxes / keypoints produced by the model
    error_model : model output counted as wrong by the accumulate_* functions
    no_missed   : GT content the model lacked  (annotator had to add)
    no_over     : model content in excess of GT  (annotator had to delete)
    """
    counters = ("no_model", "error_model", "no_missed", "no_over")
    return dict.fromkeys(counters, 0)


def accumulate_bbox(
    pred_anns: list[dict],
    gt_anns: list[dict],
    target_labels: set[str],
    pred_lmap: dict,
    gt_lmap: dict,
    stats: dict,
    label_match: bool = False,
):
    """
    Add one frame's bounding-box counts to ``stats`` (mutated in place).

    Only annotations of type "bbox" whose label name is in ``target_labels``
    are considered. With ``label_match`` enabled, label names are passed to the
    matcher so only same-class boxes can pair.

    Updates
    -------
    no_model    += number of predicted boxes
    error_model += predicted boxes left unmatched by match_boxes_by_iou
    no_missed   += how many more GT boxes than predicted boxes (0 if none)
    no_over     += how many more predicted boxes than GT boxes (0 if none)
    """
    def _select(anns: list[dict], lmap: dict) -> list[dict]:
        # Keep bbox annotations that belong to one of the target labels.
        return [a for a in anns if lmap.get(a["label_id"]) in target_labels and a["type"] == "bbox"]

    pred_boxes = _select(pred_anns, pred_lmap)
    gt_boxes = _select(gt_anns, gt_lmap)

    if label_match:
        pred_names = [pred_lmap.get(box["label_id"]) for box in pred_boxes]
        gt_names = [gt_lmap.get(box["label_id"]) for box in gt_boxes]
    else:
        pred_names = gt_names = None

    unmatched_pred = match_boxes_by_iou(
        pred_boxes, gt_boxes, pred_labels=pred_names, gt_labels=gt_names
    )[1]

    # Missed / over-detected are derived from the count difference alone,
    # not from which individual boxes were matched.
    gt_surplus = len(gt_boxes) - len(pred_boxes)
    stats["no_model"] += len(pred_boxes)
    stats["no_missed"] += max(0, gt_surplus)
    stats["no_over"] += max(0, -gt_surplus)
    stats["error_model"] += len(unmatched_pred)


def accumulate_kp(
    pred_anns: list[dict],
    gt_anns: list[dict],
    target_labels: set[str],
    pred_lmap: dict,
    gt_lmap: dict,
    stats: dict,
    threshold_px: float,
):
    """
    Add one frame's keypoint counts to ``stats`` (mutated in place).

    Skeletons are matched separately for every label in ``target_labels``, so a
    predicted skeleton can only pair with a GT skeleton of the same label.

    Updates (all counted in keypoints, visibility ignored)
    -------
    no_model    += keypoints of every predicted skeleton
    error_model += keypoints of matched pairs farther than ``threshold_px`` from GT
    no_over     += keypoints of predicted skeletons without a GT partner
    no_missed   += keypoints of GT skeletons without a predicted partner
    """
    def _skeletons(anns: list[dict], lmap: dict, wanted: str) -> list[dict]:
        # Keep skeleton annotations carrying exactly the wanted label.
        return [a for a in anns if lmap.get(a["label_id"]) == wanted and a["type"] == "skeleton"]

    for label in target_labels:
        pred_skel = _skeletons(pred_anns, pred_lmap, label)
        gt_skel = _skeletons(gt_anns, gt_lmap, label)

        paired, extra_pred, extra_gt = match_skeletons(pred_skel, gt_skel)

        stats["no_model"] += sum(kp_present(skel) for skel in pred_skel)
        # NOTE(review): keypoints of unmatched predicted skeletons are counted
        # in no_over only, not in error_model, so the precision figure derived
        # from error_model does not penalise them — confirm this is intended.
        stats["error_model"] += sum(
            kp_errors(pred_one, gt_one, threshold_px) for pred_one, gt_one in paired
        )
        stats["no_over"] += sum(kp_present(skel) for skel in extra_pred)
        stats["no_missed"] += sum(kp_present(skel) for skel in extra_gt)


# ── Main compare ──────────────────────────────────────────────────────────────

def run_compare(
    pred_path: str,
    gt_path: str,
    show_bbox: bool,
    show_pr_kp: bool,
    show_clamp_kp: bool,
    threshold_px: float,
    job_filter: int | None,
    label_match: bool = False,
) -> dict:
    """
    Load predict and GT files, iterate over all GT frames, and accumulate per-job stats.

    GT items are sorted by their ``attr.frame`` value (0 when absent) and cut
    into consecutive jobs of FRAMES_PER_JOB items; the job number comes from the
    position in that sorted list, not from the frame value itself. Predictions
    are looked up by item id; a GT item without a prediction is compared
    against an empty annotation list. Prediction items with no GT counterpart
    are not visited.

    Parameters
    ----------
    pred_path, gt_path : paths of the Datumaro JSON files
    show_bbox, show_pr_kp, show_clamp_kp : which annotation types to accumulate;
        types that are switched off keep all-zero stats
    threshold_px : per-keypoint distance above which a keypoint is an error
    job_filter : when not None, only this job number is accumulated
    label_match : forwarded to the bbox accumulator (same-class matching)

    Returns
    -------
    {
      job_id (1, 2, ...): {
        "forklift_bbox":   {"no_model": int, "error_model": int,
                            "no_missed": int, "no_over": int},
        "pr_keypoints":    {...},
        "clamp_keypoints": {...},
      },
      ...
    }
    """
    pred_data = load_json(pred_path)
    gt_data   = load_json(gt_path)

    pred_lmap = build_label_map(pred_data)
    gt_lmap   = build_label_map(gt_data)

    pred_by_id = {it["id"]: it for it in pred_data["items"]}

    def _frame_index(item: dict) -> int:
        # An item may come without an "attr" block (or with a null one); treat
        # that like a missing frame number instead of raising KeyError.
        return (item.get("attr") or {}).get("frame", 0)

    # Sort GT items by sequential frame index so job assignment is deterministic
    gt_sorted = sorted(gt_data["items"], key=_frame_index)

    jobs: dict[int, dict] = {}

    for seq_idx, gt_item in enumerate(gt_sorted):
        job_id = seq_idx // FRAMES_PER_JOB + 1
        if job_filter is not None and job_id != job_filter:
            continue

        if job_id not in jobs:
            jobs[job_id] = {
                "forklift_bbox":   empty_stats(),
                "pr_keypoints":    empty_stats(),
                "clamp_keypoints": empty_stats(),
            }

        # No prediction for this frame -> compare GT against nothing.
        pred_item = pred_by_id.get(gt_item["id"], {"annotations": []})
        pred_anns = pred_item.get("annotations", [])
        gt_anns   = gt_item.get("annotations", [])

        if show_bbox:
            accumulate_bbox(pred_anns, gt_anns, FORKLIFT_BBOX_LABELS, pred_lmap, gt_lmap,
                            jobs[job_id]["forklift_bbox"], label_match=label_match)

        if show_pr_kp:
            accumulate_kp(pred_anns, gt_anns, PR_KP_LABELS, pred_lmap, gt_lmap,
                          jobs[job_id]["pr_keypoints"], threshold_px)

        if show_clamp_kp:
            accumulate_kp(pred_anns, gt_anns, CLAMP_KP_LABELS, pred_lmap, gt_lmap,
                          jobs[job_id]["clamp_keypoints"], threshold_px)

    return jobs


# ── Display helpers ──────────────────────────────────────────────────────────

def _acc(no_model: int, error: int) -> str:
    """
    Compute precision-based pre-label accuracy.

    Formula: (no_model - error) / no_model * 100
    Measures what fraction of model predictions were correct.
    Returns 'N/A' when the model produced no predictions.
    """
    if no_model == 0:
        return "N/A"
    return f"{(no_model - error) / no_model * 100:.1f}%"


# Human-readable section title for every stats key. The part before the first
# "(" is also used as the annotation-type name in the CSV export.
TYPE_LABELS = {
    "forklift_bbox":   "Forklift BBox (forklift-with-roll + forklift-no-roll)",
    "pr_keypoints":    "PR Keypoints  (roll-keypoints)",
    "clamp_keypoints": "Clamp KP      (clamp-2-arm + clamp-3-arm)",
}

# Order in which the sections appear in the console report and the CSV.
TYPE_KEYS = ["forklift_bbox", "pr_keypoints", "clamp_keypoints"]
# Stats key -> name of the corresponding show_* parameter.
# NOTE(review): not referenced in the visible part of this file — print_report
# and export_csv build their own key -> flag mapping.
SHOW_FLAGS = {
    "forklift_bbox":   "show_bbox",
    "pr_keypoints":    "show_pr_kp",
    "clamp_keypoints": "show_clamp_kp",
}


def print_report(
    jobs: dict,
    pred_file: str,
    gt_file: str,
    threshold_px: float,
    show_bbox: bool,
    show_pr_kp: bool,
    show_clamp_kp: bool,
    label_match: bool = False,
):
    """
    Print the per-job accuracy report to stdout.

    A header block names the inputs and settings. For every enabled annotation
    type a table follows with one row per job (ascending job number) and a
    closing TOTAL row that sums the job rows.

    ``jobs`` has the structure returned by run_compare. ``pred_file``,
    ``gt_file``, ``threshold_px`` and ``label_match`` are only echoed in the
    header.
    """
    banner = "=" * 85
    enabled = {
        "forklift_bbox":   show_bbox,
        "pr_keypoints":    show_pr_kp,
        "clamp_keypoints": show_clamp_kp,
    }
    sections = [key for key in TYPE_KEYS if enabled[key]]

    if label_match:
        match_note = "  +  same class label"
    else:
        match_note = "  (position only, no class check)"
    # A dash stands in for the job range when nothing was accumulated.
    first_job = min(jobs) if jobs else "—"
    last_job = max(jobs) if jobs else "—"

    print("\n" + banner)
    print("  RF-DETR FORKLIFT — PRE-LABEL ACCURACY REPORT (per job)")
    print(banner)
    print(f"  Predicted : {pred_file}")
    print(f"  GT        : {gt_file}")
    print(f"  Threshold : {threshold_px} px  (keypoint position error)")
    print(f"  BBox match: IoU >= 0.85{match_note}")
    print(f"  Jobs      : {first_job} – {last_job}  ({FRAMES_PER_JOB} frames each)")
    print(banner)

    # (title, width) of every table column; all cells are right-aligned.
    columns = (
        ("Job", 7),
        ("No. Model", 12),
        ("Error", 9),
        ("Missed*", 10),
        ("Over-det**", 12),
        ("Accuracy (%)", 14),
    )
    widths = [width for _, width in columns]
    rule = "  " + "─" * sum(widths)

    def _line(cells) -> str:
        # One table line: two leading spaces, then each cell padded to its column.
        return "  " + "".join(f"{cell:>{width}}" for cell, width in zip(cells, widths))

    def _stats_line(tag, s: dict) -> str:
        # Row for one job (or the total), with the accuracy derived from the counts.
        return _line((
            tag,
            s["no_model"],
            s["error_model"],
            s["no_missed"],
            s["no_over"],
            _acc(s["no_model"], s["error_model"]),
        ))

    for key in sections:
        print(f"\n  ── {TYPE_LABELS[key]} ──")
        print(_line([title for title, _ in columns]))
        print("  * Missed    = GT > Model  (model missed — annotator drew these manually)")
        print("  ** Over-det = Model > GT  (model over-detected — annotator deleted these)")
        print(rule)

        running = empty_stats()
        for job_id in sorted(jobs):
            job_stats = jobs[job_id][key]
            print(_stats_line(job_id, job_stats))
            for field in running:
                running[field] += job_stats[field]

        print(rule)
        print(_stats_line("TOTAL", running))

    print("\n" + banner + "\n")


def export_csv(
    jobs: dict,
    csv_path: str,
    show_bbox: bool,
    show_pr_kp: bool,
    show_clamp_kp: bool,
):
    """
    Write the per-job stats to ``csv_path`` (UTF-8) and print a confirmation line.

    For every enabled annotation type the file contains one row per job
    (ascending job number), a TOTAL row summing them, and a blank separator
    row. The accuracy column holds the bare percentage number with one decimal;
    it is left empty — for job rows and the TOTAL row alike — when the model
    produced no predictions.

    ``jobs`` has the structure returned by run_compare.
    """
    show_map = {
        "forklift_bbox":   show_bbox,
        "pr_keypoints":    show_pr_kp,
        "clamp_keypoints": show_clamp_kp,
    }
    active_types = [k for k in TYPE_KEYS if show_map[k]]

    FIELDS = ["job", "type", "no_model", "error_model", "no_missed", "no_over", "accuracy_pct"]
    HEADERS = {
        "job":          "Job",
        "type":         "Annotation Type",
        "no_model":     "No. BBox/KP (Model)",
        "error_model":  "Error BBox/KP (Model)",
        "no_missed":    "Missed — GT > Model (Manual Add)",
        "no_over":      "Over-detect — Model > GT (Spurious)",
        "accuracy_pct": "Pre-label Accuracy (%)",
    }

    def _accuracy_cell(no_model: int, error: int) -> str:
        # One formatter for job rows and the TOTAL row, so the numeric column
        # never mixes empty cells with an "N/A" text marker.
        if no_model > 0:
            return f"{(no_model - error) / no_model * 100:.1f}"
        return ""

    def _row(job_label, type_name: str, stats: dict) -> dict:
        # Build one CSV record from a stats dict.
        return {
            "job":          job_label,
            "type":         type_name,
            "no_model":     stats["no_model"],
            "error_model":  stats["error_model"],
            "no_missed":    stats["no_missed"],
            "no_over":      stats["no_over"],
            "accuracy_pct": _accuracy_cell(stats["no_model"], stats["error_model"]),
        }

    rows = []
    for type_key in active_types:
        # Short type name = section title up to its parenthesised label list.
        type_name = TYPE_LABELS[type_key].split("(")[0].strip()
        total = empty_stats()
        for job_id in sorted(jobs):
            s = jobs[job_id][type_key]
            rows.append(_row(job_id, type_name, s))
            for k in total:
                total[k] += s[k]
        # Total row per type
        rows.append(_row("TOTAL", type_name, total))
        rows.append(dict.fromkeys(FIELDS, ""))  # blank separator

    # newline="" lets the csv module control line endings itself.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=FIELDS)
        writer.writerow({k: HEADERS[k] for k in FIELDS})
        writer.writerows(rows)

    print(f"  CSV exported → {csv_path}")


# ── CLI ───────────────────────────────────────────────────────────────────────

def main():
    """
    Command-line entry point: parse arguments, run the comparison, print the
    report and optionally export it as CSV.

    Exits through SystemExit with a non-zero status when an input file does not
    exist (message on stderr) or when --job is not a positive number, so that
    calling scripts can detect the failure.
    """
    parser = argparse.ArgumentParser(
        description="Per-job pre-label accuracy report for RF-DETR forklift part3.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Type selection flags
    parser.add_argument("--forklift-bbox",   action="store_true",
                        help="Show Forklift BBox stats (forklift-with-roll + forklift-no-roll)")
    parser.add_argument("--pr-keypoints",    action="store_true",
                        help="Show Paper Roll Keypoints stats (roll-keypoints)")
    parser.add_argument("--clamp-keypoints", action="store_true",
                        help="Show Clamp Keypoints stats (clamp-2-arm + clamp-3-arm)")

    parser.add_argument("--label-match", action="store_true",
                        help="BBox: require pred and GT to have the same class label to match "
                             "(default: off — position-only IoU matching).")

    # Options
    parser.add_argument("--job",       type=int, metavar="N",
                        help="Show only job N (1–10). Default: all jobs.")
    parser.add_argument("--threshold", type=float, default=5.0,
                        help="Keypoint error threshold in pixels (default: 5).")
    parser.add_argument("--csv",       dest="csv_path",
                        help="Export results to a CSV file.")
    parser.add_argument("--predict",   default=DEFAULT_PRED,
                        help=f"Path to predict JSON (default: {DEFAULT_PRED})")
    parser.add_argument("--gt",        default=DEFAULT_GT,
                        help=f"Path to GT JSON (default: {DEFAULT_GT})")

    args = parser.parse_args()

    # Job numbers start at 1, so anything lower could only yield an empty report.
    if args.job is not None and args.job < 1:
        parser.error("--job must be a positive job number (jobs are numbered from 1)")

    # If no type flag given → show all
    show_bbox     = args.forklift_bbox
    show_pr_kp    = args.pr_keypoints
    show_clamp_kp = args.clamp_keypoints
    if not (show_bbox or show_pr_kp or show_clamp_kp):
        show_bbox = show_pr_kp = show_clamp_kp = True

    # Fail with exit status 1 rather than returning normally, otherwise the
    # process would report success although nothing was compared.
    if not Path(args.predict).exists():
        raise SystemExit(f"Error: predict file not found: {args.predict}")
    if not Path(args.gt).exists():
        raise SystemExit(f"Error: GT file not found: {args.gt}")

    jobs = run_compare(
        pred_path=args.predict,
        gt_path=args.gt,
        show_bbox=show_bbox,
        show_pr_kp=show_pr_kp,
        show_clamp_kp=show_clamp_kp,
        threshold_px=args.threshold,
        job_filter=args.job,
        label_match=args.label_match,
    )

    print_report(
        jobs,
        pred_file=Path(args.predict).name,
        gt_file=Path(args.gt).name,
        threshold_px=args.threshold,
        show_bbox=show_bbox,
        show_pr_kp=show_pr_kp,
        show_clamp_kp=show_clamp_kp,
        label_match=args.label_match,
    )

    if args.csv_path:
        export_csv(jobs, args.csv_path, show_bbox, show_pr_kp, show_clamp_kp)

# Run the CLI only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
