Coverage for src/counting_functions.py: 0%

34 statements  

coverage.py v7.8.0, created at 2025-04-16 12:50 +0000

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
import numpy as np
import torch


def calculate_metrics_per_class(y_true, y_pred):
    # Convert logits of shape (N, C) to hard class predictions and move both
    # tensors to NumPy arrays on the CPU.
    y_pred = torch.argmax(y_pred, dim=1).cpu().detach().numpy()
    y_true = y_true.cpu().detach().numpy()

    classes = np.unique(y_true)
    metrics = {}

    for cls in classes:
        # One-vs-rest binarisation for the current class.
        y_true_binary = (y_true == cls).astype(int)
        y_pred_binary = (y_pred == cls).astype(int)

        tp = np.sum((y_true == cls) & (y_pred == cls))
        fp = np.sum((y_true != cls) & (y_pred == cls))
        fn = np.sum((y_true == cls) & (y_pred != cls))
        union = tp + fp + fn

        metrics[cls.item()] = {
            'accuracy': accuracy_score(y_true_binary, y_pred_binary),
            # zero_division=0 keeps the result at 0 (the previous effective value)
            # without the UndefinedMetricWarning when a class is never predicted.
            'precision': precision_score(y_true_binary, y_pred_binary, zero_division=0),
            'recall': recall_score(y_true_binary, y_pred_binary, zero_division=0),
            'f1': f1_score(y_true_binary, y_pred_binary, zero_division=0),
            'IoU': tp / union if union != 0 else 0
        }

    return metrics
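A minimal sketch of how calculate_metrics_per_class might be exercised, e.g. from a unit test. The import path and all tensor values below are assumptions for illustration, not taken from the project; y_pred is expected to hold raw per-class scores of shape (N, C), while y_true holds integer class labels of shape (N,).

    # Hypothetical usage; the import path assumes the project root is on PYTHONPATH.
    import torch
    from src.counting_functions import calculate_metrics_per_class

    y_true = torch.tensor([0, 1, 2, 1])          # integer class labels, shape (N,)
    y_pred = torch.tensor([[2.0, 0.1, 0.3],      # per-class scores, shape (N, C), C = 3
                           [0.2, 1.5, 0.1],
                           [0.3, 0.2, 0.9],
                           [1.2, 0.4, 0.1]])     # argmax along dim=1 -> [0, 1, 2, 0]

    per_class = calculate_metrics_per_class(y_true, y_pred)
    for cls, m in per_class.items():             # dict keyed by class index
        print(cls, m['precision'], m['recall'], m['IoU'])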

def count_metrics(y_true, y_pred):
    # Same preprocessing: logits of shape (N, C) -> predicted class indices on the CPU.
    y_pred = torch.argmax(y_pred, dim=1).cpu().detach().numpy()
    y_true = y_true.cpu().detach().numpy()

    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)

    # Compute mIoU: per-class intersection over union, averaged over the
    # classes that occur in y_true.
    classes = np.unique(y_true)
    iou = []

    for cls in classes:
        tp = np.sum((y_true == cls) & (y_pred == cls))
        fp = np.sum((y_true != cls) & (y_pred == cls))
        fn = np.sum((y_true == cls) & (y_pred != cls))
        union = tp + fp + fn
        iou.append(tp / union if union != 0 else 0)

    mIoU = np.mean(iou)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'mIoU': mIoU
    }
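A similarly hedged sketch for count_metrics, which returns metrics aggregated over all classes (weighted averages plus mIoU). The tensors are again invented for illustration only.

    # Hypothetical usage; values are illustrative only.
    import torch
    from src.counting_functions import count_metrics   # assumed import path

    y_true = torch.tensor([0, 0, 1, 2, 2, 1])
    logits = torch.tensor([[1.8, 0.2, 0.1],    # -> class 0 (correct)
                           [0.1, 1.1, 0.3],    # -> class 1 (true class is 0)
                           [0.2, 2.0, 0.4],    # -> class 1 (correct)
                           [0.1, 0.3, 1.6],    # -> class 2 (correct)
                           [0.4, 0.2, 1.2],    # -> class 2 (correct)
                           [0.3, 1.4, 0.2]])   # -> class 1 (correct)

    results = count_metrics(y_true, logits)
    print(results)   # {'accuracy': ..., 'precision': ..., 'recall': ..., 'f1': ..., 'mIoU': ...}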