@@ -299,7 +299,9 @@ def evaluate_detection_batch(
299
299
if len (targets ) == 0 :
300
300
# No ground truth - all detections are FP
301
301
class_id_idx = 4
302
- detection_classes = np .array (detection_batch_filtered [:, class_id_idx ], dtype = np .int16 )
302
+ detection_classes = np .array (
303
+ detection_batch_filtered [:, class_id_idx ], dtype = np .int16
304
+ )
303
305
for det_class in detection_classes :
304
306
result_matrix [num_classes , det_class ] += 1
305
307
return result_matrix
@@ -316,9 +318,11 @@ def evaluate_detection_batch(
316
318
# print("Debug IoU calculations:")
317
319
# print(f"GT boxes: {true_boxes}")
318
320
# print(f"Detection boxes: {detection_boxes}")
319
-
321
+
320
322
# Calculate IoU matrix
321
- iou_batch = box_iou_batch (boxes_true = true_boxes , boxes_detection = detection_boxes )
323
+ iou_batch = box_iou_batch (
324
+ boxes_true = true_boxes , boxes_detection = detection_boxes
325
+ )
322
326
# print(f"IoU matrix:\n{iou_batch}")
323
327
324
328
# Find all valid matches (IoU > threshold, regardless of class)
@@ -329,28 +333,28 @@ def evaluate_detection_batch(
329
333
if iou > iou_threshold :
330
334
gt_class = true_classes [gt_idx ]
331
335
det_class = detection_classes [det_idx ]
332
- class_match = ( gt_class == det_class )
336
+ class_match = gt_class == det_class
333
337
valid_matches .append ((gt_idx , det_idx , iou , class_match ))
334
- # print(f"Valid match: GT[{gt_idx}] class={gt_class} vs
335
- # Det[{det_idx}] class={det_class}, IoU={iou:.3f},
338
+ # print(f"Valid match: GT[{gt_idx}] class={gt_class} vs
339
+ # Det[{det_idx}] class={det_class}, IoU={iou:.3f},
336
340
# class_match={class_match}")
337
341
338
342
# Sort matches by class match first (True before False), then by IoU descending
339
343
# This prioritizes correct class predictions over higher IoU with wrong class
340
344
valid_matches .sort (key = lambda x : (x [3 ], x [2 ]), reverse = True )
341
345
# print(f"Sorted matches: {valid_matches}")
342
346
343
- # Greedily assign matches, ensuring each GT
347
+ # Greedily assign matches, ensuring each GT
344
348
# and detection is matched at most once
345
349
matched_gt_idx = set ()
346
350
matched_det_idx = set ()
347
-
351
+
348
352
for gt_idx , det_idx , iou , class_match in valid_matches :
349
353
if gt_idx not in matched_gt_idx and det_idx not in matched_det_idx :
350
354
# Valid spatial match - record the class prediction
351
355
gt_class = true_classes [gt_idx ]
352
356
det_class = detection_classes [det_idx ]
353
- # print(f"Assigning match: GT[{gt_idx}] class={gt_class} ->
357
+ # print(f"Assigning match: GT[{gt_idx}] class={gt_class} ->
354
358
# Det[{det_idx}] class={det_class}")
355
359
# This handles both correct classification (TP) and misclassification
356
360
result_matrix [gt_class , det_class ] += 1
0 commit comments