Coverage for src/counting_functions.py: 0% (34 statements)
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import torch


def calculate_metrics_per_class(y_true, y_pred):
    """Compute accuracy, precision, recall, F1, and IoU for each class.

    y_pred: raw model outputs of shape (N, num_classes); y_true: integer
    class labels of shape (N,). Returns a dict keyed by class label.
    """
    # Convert logits to hard class predictions and move both tensors to NumPy.
    y_pred = torch.argmax(y_pred, dim=1).cpu().detach().numpy()
    y_true = y_true.cpu().detach().numpy()

    classes = np.unique(y_true)
    metrics = {}

    for cls in classes:
        # One-vs-rest binarization for the sklearn metrics.
        y_true_binary = (y_true == cls).astype(int)
        y_pred_binary = (y_pred == cls).astype(int)

        tp = np.sum((y_true == cls) & (y_pred == cls))
        fp = np.sum((y_true != cls) & (y_pred == cls))
        fn = np.sum((y_true == cls) & (y_pred != cls))
        union = tp + fp + fn

        metrics[cls.item()] = {
            "accuracy": accuracy_score(y_true_binary, y_pred_binary),
            # zero_division=0 avoids undefined-metric warnings when a class
            # is never predicted (tp + fp == 0).
            "precision": precision_score(y_true_binary, y_pred_binary, zero_division=0),
            "recall": recall_score(y_true_binary, y_pred_binary, zero_division=0),
            "f1": f1_score(y_true_binary, y_pred_binary, zero_division=0),
            "IoU": tp / union if union != 0 else 0.0,
        }

    return metrics
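# --- Hedged usage sketch (not in the original module) -----------------------
# Shows the expected input shapes for calculate_metrics_per_class:
# y_pred as raw (N, num_classes) scores/logits, y_true as (N,) integer labels.
# The tensors below are assumed values for illustration only:
#
#   logits = torch.tensor([[2.0, 0.1], [0.2, 1.5], [1.8, 0.3]])
#   labels = torch.tensor([0, 1, 0])
#   per_class = calculate_metrics_per_class(labels, logits)
#   # -> {0: {"accuracy": ..., "precision": ..., "recall": ..., "f1": ..., "IoU": ...},
#   #     1: {...}}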
def count_metrics(y_true, y_pred):
    """Compute aggregate accuracy, weighted precision/recall/F1, and mIoU."""
    # Convert logits to hard class predictions and move both tensors to NumPy.
    y_pred = torch.argmax(y_pred, dim=1).cpu().detach().numpy()
    y_true = y_true.cpu().detach().numpy()

    accuracy = accuracy_score(y_true, y_pred)
    # Weighted averaging accounts for class imbalance; zero_division=0 avoids
    # undefined-metric warnings for classes that are never predicted.
    precision = precision_score(y_true, y_pred, average="weighted", zero_division=0)
    recall = recall_score(y_true, y_pred, average="weighted", zero_division=0)
    f1 = f1_score(y_true, y_pred, average="weighted", zero_division=0)

    # Compute mean IoU over the classes present in y_true (classes that
    # appear only in y_pred are not included in the average).
    classes = np.unique(y_true)
    iou = []

    for cls in classes:
        tp = np.sum((y_true == cls) & (y_pred == cls))
        fp = np.sum((y_true != cls) & (y_pred == cls))
        fn = np.sum((y_true == cls) & (y_pred != cls))
        union = tp + fp + fn
        iou.append(tp / union if union != 0 else 0.0)

    mIoU = np.mean(iou)

    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1, "mIoU": mIoU}
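# --- Hedged usage sketch (not in the original module) -----------------------
# A minimal, self-contained demo under assumed shapes: random 3-class logits
# for 8 samples and matching integer labels. Running the file directly prints
# both the per-class and the aggregate metric dictionaries.
if __name__ == "__main__":
    torch.manual_seed(0)
    logits = torch.randn(8, 3)          # raw scores, shape (N, num_classes)
    labels = torch.randint(0, 3, (8,))  # ground-truth labels, shape (N,)

    print(calculate_metrics_per_class(labels, logits))
    print(count_metrics(labels, logits))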