Coverage for src/models/resnet.py: 96%

73 statements  

coverage.py v7.8.0, created at 2025-04-16 12:50 +0000

import torch
from torch import nn
import pytorch_lightning as pl
from torchmetrics import Accuracy
import torchvision.models as models


class ResNetClassifier(pl.LightningModule):
    def __init__(self, num_classes=2, learning_rate=1e-3, weight_decay=0, transform=None, freeze=False):
        super().__init__()
        self.save_hyperparameters()

        self.transform = transform
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.model = models.resnet18(weights='DEFAULT')

        # Freeze the pre-trained backbone so only the new head trains
        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False

        # Replace the final fully connected layer with a head sized for num_classes
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

        # Define a loss function and metric
        self.criterion = nn.CrossEntropyLoss()
        if num_classes == 2:
            self.accuracy = Accuracy(task="binary")
        else:
            self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

        # Containers for test-time predictions and targets
        self.predictions = None
        self.targets = None

    def forward(self, x):
        return self.model(x)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Apply the optional transform after the batch has been moved to the device
        x, y = batch
        if self.transform:
            x = self.transform(x)
        return x, y

    def training_step(self, batch, batch_idx):
        images, labels = batch
        labels = labels.long()
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        # Calculate and log accuracy
        predicted_classes = torch.argmax(outputs, dim=1)
        acc = self.accuracy(predicted_classes, labels)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log('learning_rate', current_lr, prog_bar=True)  # log learning rate so scheduler behaviour can be checked in tests
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        labels = labels.long()
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        predicted_classes = torch.argmax(outputs, dim=1)
        acc = self.accuracy(predicted_classes, labels)
        self.log("val_loss", loss)
        self.log("val_acc", acc, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        images, labels = batch
        labels = labels.long()
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        predicted_classes = torch.argmax(outputs, dim=1)
        acc = self.accuracy(predicted_classes, labels)
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)

        predicted_probs = torch.softmax(outputs, dim=1)

        # Save predictions and targets for later use.
        # Coverage note: the else branch below never executed; on every call
        # self.predictions was still None (the suite feeds a single test batch).
        if self.predictions is None:
            self.predictions = predicted_probs
            self.targets = labels
        else:
            self.predictions = torch.cat((self.predictions, predicted_probs), dim=0)
            self.targets = torch.cat((self.targets, labels), dim=0)

        return loss

    def configure_optimizers(self):
        # Optimize only the parameters that are still trainable: the new fc head
        # when freeze=True, the full network when freeze=False
        trainable_params = [p for p in self.model.parameters() if p.requires_grad]
        optimizer = torch.optim.Adam(trainable_params, lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

        # Decay LR by a factor of 0.1 every 4 epochs
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)  # TODO: tune parameters

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler
            }
        }
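
For reference, a minimal usage sketch of the module above. The dataloaders (train_loader, val_loader, test_loader) and the max_epochs value are illustrative assumptions, not part of src/models/resnet.py:

import pytorch_lightning as pl
from src.models.resnet import ResNetClassifier  # import path assumes the file path in the report header

model = ResNetClassifier(num_classes=2, learning_rate=1e-3, freeze=True)
trainer = pl.Trainer(max_epochs=8)  # max_epochs chosen arbitrarily for the sketch
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model, dataloaders=test_loader)

# After trainer.test(), the accumulated outputs are available as tensors:
probs = model.predictions  # softmax probabilities, shape (num_test_samples, num_classes)
targets = model.targets    # ground-truth labels, shape (num_test_samples,)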
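
The one partially covered branch behind the 96% figure is the else in test_step: self.predictions was always None when the condition ran, so the concatenation path was never exercised. A sketch of a test that would cover it by feeding two test batches; the tensor shapes, batch size, and import path are assumptions for illustration:

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset
from src.models.resnet import ResNetClassifier  # assumed import path

def test_predictions_accumulate_across_batches():
    # Eight fake RGB images split into two batches of four,
    # so test_step runs twice and takes the torch.cat branch
    images = torch.randn(8, 3, 224, 224)
    labels = torch.randint(0, 2, (8,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    model = ResNetClassifier(num_classes=2)  # downloads pretrained weights on first run
    trainer = pl.Trainer(logger=False, enable_checkpointing=False)
    trainer.test(model, dataloaders=loader)

    # Outputs from both batches were accumulated
    assert model.predictions.shape == (8, 2)
    assert model.targets.shape == (8,)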