Coverage for src/counting_functions.py: 0%
34 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-16 12:50 +0000
1from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
2import numpy as np
3import torch
def calculate_metrics_per_class(y_true, y_pred):
    """Compute one-vs-rest classification metrics and IoU for each class.

    Args:
        y_true: Tensor of integer ground-truth labels. It is flattened before
            scoring, so per-sample labels and per-pixel label maps both work.
        y_pred: Tensor of per-class scores with the class axis at dim 1
            (e.g. ``(N, C)`` logits or ``(B, C, H, W)`` score maps). It is
            reduced with ``argmax(dim=1)``; the result must contain the same
            number of elements as ``y_true``.

    Returns:
        dict mapping each class label present in ``y_true`` (as a Python
        scalar) to a dict with keys ``'accuracy'``, ``'precision'``,
        ``'recall'``, ``'f1'`` and ``'IoU'``. Classes that occur only in the
        predictions are not reported.
    """
    # Flatten so sklearn always receives 1-D label vectors. Without this, 2-D
    # 0/1 arrays are interpreted as multilabel-indicator input: accuracy_score
    # silently returns subset (exact-row) accuracy and the binary scorers raise.
    y_pred = torch.argmax(y_pred, dim=1).detach().cpu().numpy().ravel()
    y_true = y_true.detach().cpu().numpy().ravel()

    metrics = {}
    for cls in np.unique(y_true):
        # Compute each membership mask once; reuse it for sklearn and for IoU.
        true_mask = y_true == cls
        pred_mask = y_pred == cls
        y_true_binary = true_mask.astype(int)
        y_pred_binary = pred_mask.astype(int)

        tp = np.sum(true_mask & pred_mask)
        fp = np.sum(~true_mask & pred_mask)
        fn = np.sum(true_mask & ~pred_mask)
        union = tp + fp + fn

        metrics[cls.item()] = {
            'accuracy': accuracy_score(y_true_binary, y_pred_binary),
            # zero_division=0 yields the same 0.0 the default "warn" mode
            # returns for a never-predicted class, minus UndefinedMetricWarning.
            'precision': precision_score(y_true_binary, y_pred_binary, zero_division=0),
            'recall': recall_score(y_true_binary, y_pred_binary, zero_division=0),
            'f1': f1_score(y_true_binary, y_pred_binary, zero_division=0),
            # cls comes from y_true, so union >= 1; the guard is defensive.
            'IoU': float(tp / union) if union else 0.0,
        }

    return metrics
def count_metrics(y_true, y_pred):
    """Compute aggregate classification metrics and mean IoU.

    Args:
        y_true: Tensor of integer ground-truth labels. It is flattened before
            scoring, so per-sample labels and per-pixel label maps both work.
        y_pred: Tensor of per-class scores with the class axis at dim 1
            (e.g. ``(N, C)`` logits or ``(B, C, H, W)`` score maps). It is
            reduced with ``argmax(dim=1)``; the result must contain the same
            number of elements as ``y_true``.

    Returns:
        dict with keys:
            ``'accuracy'``: overall fraction of correct labels.
            ``'precision'``, ``'recall'``, ``'f1'``: support-weighted averages
            (sklearn ``average='weighted'``).
            ``'mIoU'``: unweighted mean of per-class IoU over the classes
            present in ``y_true``.
    """
    # Flatten so sklearn sees 1-D label vectors; multi-dimensional label maps
    # would otherwise be rejected as multiclass-multioutput targets.
    y_pred = torch.argmax(y_pred, dim=1).detach().cpu().numpy().ravel()
    y_true = y_true.detach().cpu().numpy().ravel()

    accuracy = accuracy_score(y_true, y_pred)
    # zero_division=0 matches the value of the default "warn" mode (0.0 for
    # undefined per-class scores) without emitting UndefinedMetricWarning.
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    # mIoU is averaged over the classes present in y_true. A class that appears
    # only in the predictions is not averaged in, but its samples still count
    # as false negatives of their true class.
    iou = []
    for cls in np.unique(y_true):
        true_mask = y_true == cls
        pred_mask = y_pred == cls
        tp = np.sum(true_mask & pred_mask)
        fp = np.sum(~true_mask & pred_mask)
        fn = np.sum(true_mask & ~pred_mask)
        union = tp + fp + fn
        iou.append(tp / union if union else 0.0)

    mIoU = float(np.mean(iou))

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'mIoU': mIoU,
    }