# src/models/resnet.py

import torch
from torch import nn
import pytorch_lightning as pl
from torchmetrics import Accuracy
import torchvision.models as models


class ResNetClassifier(pl.LightningModule):
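    """Image classifier built on a pre-trained ResNet-18 backbone.

    The final fully connected layer is replaced to match ``num_classes``,
    and the pre-trained layers can optionally be frozen so only the new
    head is trained. Uses cross-entropy loss and a torchmetrics Accuracy
    metric (binary or multiclass, depending on ``num_classes``).
    """
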
    def __init__(self, num_classes=2, learning_rate=1e-3, weight_decay=0, transform=None, freeze=False):
        super().__init__()
        self.save_hyperparameters()

        self.transform = transform
        self.learning_rate = learning_rate
        self.weight_decay = weight_decay
        self.model = models.resnet18(weights='DEFAULT')

        # Freeze pre-trained layers
        if freeze:
            for param in self.model.parameters():
                param.requires_grad = False

        # Replace the final fully connected layer to match num_classes
        in_features = self.model.fc.in_features
        self.model.fc = nn.Linear(in_features, num_classes)

        # Define a loss function and metric
        self.criterion = nn.CrossEntropyLoss()
        if num_classes == 2:
            self.accuracy = Accuracy(task="binary")
        else:
            self.accuracy = Accuracy(task="multiclass", num_classes=num_classes)

        # Containers for test-time predictions and targets
        self.predictions = None
        self.targets = None

    def forward(self, x):
        return self.model(x)

    def on_after_batch_transfer(self, batch, dataloader_idx):
        # Apply the optional transform after the batch has been moved to the
        # target device, so the transform can run on the GPU
        x, y = batch
        if self.transform:
            x = self.transform(x)
        return x, y

    def training_step(self, batch, batch_idx):
        images, labels = batch
        labels = labels.long()
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        # Calculate and log accuracy
        predicted_classes = torch.argmax(outputs, dim=1)
        acc = self.accuracy(predicted_classes, labels)
        self.log("train_loss", loss, on_epoch=True, prog_bar=True)
        self.log("train_acc", acc, prog_bar=True)
        current_lr = self.trainer.optimizers[0].param_groups[0]['lr']
        self.log('learning_rate', current_lr, prog_bar=True)  # log learning rate for testing purposes
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        labels = labels.long()
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        predicted_classes = torch.argmax(outputs, dim=1)
        acc = self.accuracy(predicted_classes, labels)
        self.log("val_loss", loss)
        self.log("val_acc", acc, prog_bar=True)

        return loss

    def test_step(self, batch, batch_idx):
        images, labels = batch
        labels = labels.long()
        outputs = self(images)
        loss = self.criterion(outputs, labels)

        predicted_classes = torch.argmax(outputs, dim=1)
        acc = self.accuracy(predicted_classes, labels)
        self.log("test_loss", loss, on_epoch=True, prog_bar=True)
        self.log("test_acc", acc, prog_bar=True)

        predicted_probs = torch.softmax(outputs, dim=1)

        # Save predictions and targets for later use
        if self.predictions is None:
            self.predictions = predicted_probs
            self.targets = labels
        else:
            self.predictions = torch.cat((self.predictions, predicted_probs), dim=0)
            self.targets = torch.cat((self.targets, labels), dim=0)

        return loss

    def configure_optimizers(self):
        # Only the parameters of the classification head are optimized;
        # the backbone is left untouched regardless of the freeze flag
        optimizer = torch.optim.Adam(self.model.fc.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)

        # Decay LR by a factor of 0.1 every 4 epochs
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.1)  # TODO: tune parameters

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler
            }
        }
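

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): wires the classifier
# to a Lightning Trainer for a single-batch smoke test. The random-tensor
# DataLoader below is an illustrative assumption; substitute real image data
# in practice. Note that weights='DEFAULT' downloads the pre-trained ResNet-18
# weights on first use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    # Toy data: 8 RGB images at 224x224 with binary labels (illustration only)
    images = torch.randn(8, 3, 224, 224)
    labels = torch.randint(0, 2, (8,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=4)

    model = ResNetClassifier(num_classes=2, freeze=True)
    trainer = pl.Trainer(fast_dev_run=True)  # runs a single batch per stage
    trainer.fit(model, train_dataloaders=loader, val_dataloaders=loader)
    trainer.test(model, dataloaders=loader)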