Coverage for src/callbacks.py: 48%
38 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-01 17:37 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-01 17:37 +0000
1from pytorch_lightning.callbacks import Callback
2from dataset import CurriculumLearningDataset
class PrintMetricsCallback(Callback):
    """Print per-epoch train/validation loss and accuracy and keep a history.

    Reads ``train_loss``/``train_acc`` and ``val_loss``/``val_acc`` from
    ``trainer.callback_metrics`` (a ``KeyError`` is raised if the module did
    not log them) and appends the scalar values to ``train_metrics`` /
    ``val_metrics``.
    """

    def __init__(self):
        super().__init__()
        # Per-epoch history, one float appended per epoch: {"loss": [...], "acc": [...]}.
        self.train_metrics = {"loss": [], "acc": []}
        self.val_metrics = {"loss": [], "acc": []}

    def on_train_epoch_end(self, trainer, pl_module):
        """Print and record the training loss/accuracy of the finished epoch."""
        metrics = trainer.callback_metrics
        train_loss = metrics["train_loss"].item()
        train_acc = metrics["train_acc"].item()
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        self.train_metrics["loss"].append(train_loss)
        self.train_metrics["acc"].append(train_acc)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Print and record validation loss/accuracy (skipped during the sanity check)."""
        # Lightning also invokes this hook for the sanity-check validation pass
        # that runs before training; recording it would add a spurious entry to
        # val_metrics and print an unterminated " || " line.
        if getattr(trainer, "sanity_checking", False):
            return

        metrics = trainer.callback_metrics
        val_loss = metrics["val_loss"].item()
        val_acc = metrics["val_acc"].item()
        # end=" || " presumably lets the train metrics printed by
        # on_train_epoch_end continue on the same line.
        print(f"Epoch: {trainer.current_epoch}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}", end=" || ")

        self.val_metrics["loss"].append(val_loss)
        self.val_metrics["acc"].append(val_acc)
class CurriculumLearningCallback(Callback):
    """Restrict training to a growing prefix of ``class_order`` as epochs advance.

    At the start of every training epoch the callback:
      1. computes the curriculum step as ``int(current_epoch / step_size)``;
      2. takes the first ``(step + 1) * initial_ratio`` labels of ``class_order``;
      3. collects the indices of all samples whose label is in that prefix;
      4. stores them in ``datamodule.dataset_args["indices"]`` and calls
         ``datamodule.setup()``.

    Args:
        initial_ratio: Multiplier for the number of classes used per step. It
            is used directly as a slice bound, so it must be an int.
            NOTE(review): despite the name it acts as a class count, not a
            fraction — confirm intent.
        step_size: Number of epochs per curriculum step.
        class_order: Sequence of labels in the order they are introduced.
        labels: Per-sample labels; position ``i`` is the label of sample ``i``.
    """

    def __init__(self, initial_ratio, step_size, class_order, labels):
        super().__init__()
        self.initial_ratio = initial_ratio
        self.step_size = step_size
        self.class_order = class_order

        # label -> list of sample indices carrying that label, in input order.
        self.class_indices = {}
        for idx, label in enumerate(labels):
            self.class_indices.setdefault(label, []).append(idx)

    @staticmethod
    def _is_curriculum_dataset(dataset):
        """Return True if *dataset* is CurriculumLearningDataset (class or instance, subclasses included)."""
        if isinstance(dataset, type):
            return issubclass(dataset, CurriculumLearningDataset)
        return isinstance(dataset, CurriculumLearningDataset)

    def on_train_epoch_start(self, trainer, pl_module):
        """Select the samples for the current curriculum step and re-run ``datamodule.setup()``.

        Raises:
            Exception: if the datamodule's ``dataset`` is not a
                CurriculumLearningDataset (class or instance).
            KeyError: if a label in the selected prefix of ``class_order``
                never occurs in ``labels``.
        """
        datamodule = trainer.datamodule
        dataset = getattr(datamodule, "dataset", None)

        # Validate before doing any work. `dataset` may be stored either as a
        # class (to be built from dataset_args in setup()) or as an instance;
        # the old `!=` comparison only worked for the exact class object.
        if not self._is_curriculum_dataset(dataset):
            dataset_type = dataset if isinstance(dataset, type) else type(dataset)
            raise Exception(f"Curriculum learning callback is being used, but the dataset in the datamodule is of type: {dataset_type}")

        current_step = int(trainer.current_epoch / self.step_size)
        labels = self.class_order[: (current_step + 1) * self.initial_ratio]
        indices = [idx for label in labels for idx in self.class_indices[label]]

        datamodule.dataset_args["indices"] = indices
        # NOTE(review): the trainer only re-requests the train dataloader when
        # `reload_dataloaders_every_n_epochs` is set, and may do so before this
        # hook runs — confirm the new indices take effect in the intended epoch.
        datamodule.setup()