Coverage for src/callbacks.py: 48%

38 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-01 17:37 +0000

1from pytorch_lightning.callbacks import Callback 

2from dataset import CurriculumLearningDataset 

3 

4 

class PrintMetricsCallback(Callback):
    """Print per-epoch train/validation loss and accuracy and keep a history.

    Reads ``train_loss`` / ``train_acc`` and ``val_loss`` / ``val_acc`` from
    ``trainer.callback_metrics``; the LightningModule must log those keys,
    otherwise the hooks raise ``KeyError``. Values are expected to be
    tensors (``.item()`` is called on them).

    Attributes:
        train_metrics: ``{"loss": [...], "acc": [...]}`` with one float per
            finished training epoch.
        val_metrics: ``{"loss": [...], "acc": [...]}`` with one float per
            finished validation epoch; the sanity-check run Lightning performs
            before training is not recorded.
    """

    def __init__(self):
        super().__init__()
        self.train_metrics = {"loss": [], "acc": []}
        self.val_metrics = {"loss": [], "acc": []}

    def on_train_epoch_end(self, trainer, pl_module):
        """Print and record the training loss/accuracy of the finished epoch."""
        metrics = trainer.callback_metrics
        train_loss = metrics["train_loss"].item()
        train_acc = metrics["train_acc"].item()
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        self.train_metrics["loss"].append(train_loss)
        self.train_metrics["acc"].append(train_acc)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Print and record the validation loss/accuracy of the finished epoch."""
        # Lightning also calls this hook for the sanity-check validation pass
        # that runs before training starts; recording it would leave
        # val_metrics one entry longer than the number of real epochs.
        if getattr(trainer, "sanity_checking", False):
            return

        metrics = trainer.callback_metrics
        val_loss = metrics["val_loss"].item()
        val_acc = metrics["val_acc"].item()
        # end=" || " leaves the cursor on the same row, presumably so the
        # train-metrics line printed by on_train_epoch_end continues it.
        print(f"Epoch: {trainer.current_epoch}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}", end=" || ")

        self.val_metrics["loss"].append(val_loss)
        self.val_metrics["acc"].append(val_acc)

25 

26 

class CurriculumLearningCallback(Callback):
    """Grow the training set class by class on a fixed epoch schedule.

    The training run is split into curriculum steps of ``step_size`` epochs.
    During step ``k`` (``k = current_epoch // step_size``) the first
    ``(k + 1) * initial_ratio`` labels of ``class_order`` are active. At the
    start of every training epoch, the positions of all samples carrying an
    active label are written to ``datamodule.dataset_args["indices"]``, and
    then ``datamodule.setup()`` is called.

    Args:
        initial_ratio: number of classes admitted in the first step and added
            at each later step. It is used as a slice bound, so it must be an
            integer.
        step_size: number of epochs per curriculum step; must be positive.
        class_order: sequence of labels giving the order in which classes are
            admitted.
        labels: per-sample labels; position ``i`` is recorded as the index of
            sample ``i`` (presumably aligned with the dataset's sample
            indices — confirm against the datamodule).

    Raises:
        ValueError: if ``step_size`` is not positive.
    """

    def __init__(self, initial_ratio, step_size, class_order, labels):
        super().__init__()
        if step_size <= 0:
            raise ValueError(f"step_size must be positive, got {step_size}")
        self.initial_ratio = initial_ratio
        self.step_size = step_size
        self.class_order = class_order

        # label -> positions (in order) of every sample with that label.
        # NOTE(review): labels are used as dict keys, so they must hash by
        # value (plain ints/strs); 0-d torch tensors hash by identity and
        # would never match class_order — confirm callers pass plain values.
        self.class_indices = {}
        for idx, label in enumerate(labels):
            self.class_indices.setdefault(label, []).append(idx)

    @staticmethod
    def _check_dataset(dataset):
        """Raise ``TypeError`` unless *dataset* is (an instance of) a CurriculumLearningDataset."""
        # The datamodule may hold either the dataset class (presumably
        # instantiated from dataset_args inside setup()) or an already-built
        # instance; accept both, including subclasses.
        dataset_cls = dataset if isinstance(dataset, type) else type(dataset)
        if not issubclass(dataset_cls, CurriculumLearningDataset):
            raise TypeError(
                "Curriculum learning callback is being used, but the dataset "
                f"in the datamodule is of type: {dataset_cls}"
            )

    def on_train_epoch_start(self, trainer, pl_module):
        """Restrict the datamodule to the classes active in this epoch's step."""
        datamodule = trainer.datamodule
        # Validate first so a misconfigured datamodule fails before any work.
        self._check_dataset(datamodule.dataset)

        current_step = trainer.current_epoch // self.step_size
        active_labels = self.class_order[: (current_step + 1) * self.initial_ratio]
        indices = [idx for label in active_labels for idx in self.class_indices[label]]

        datamodule.dataset_args["indices"] = indices
        # NOTE(review): the new subset only reaches training if the train
        # dataloader is rebuilt after this hook (e.g. the Trainer is created
        # with reload_dataloaders_every_n_epochs=1) — confirm the setup.
        datamodule.setup()