Coverage for src/callbacks.py: 100%

17 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-16 12:50 +0000

1from pytorch_lightning.callbacks import Callback 

2 

3 

class PrintMetricsCallback(Callback):
    """Print and record per-epoch loss and accuracy for train and validation.

    Each hook reads its metrics from ``trainer.callback_metrics``, prints
    them with four decimal places, and appends the scalar values to the
    matching history dictionary.

    Attributes:
        train_metrics: ``{"loss": [...], "acc": [...]}``; one entry is
            appended per ``on_train_epoch_end`` call.
        val_metrics: ``{"loss": [...], "acc": [...]}``; one entry is
            appended per ``on_validation_epoch_end`` call.
    """

    def __init__(self):
        """Create the callback with empty metric histories."""
        # Run the base-class initialiser so the callback stays correct if
        # the base class (or a mixin) ever sets up state of its own.
        super().__init__()
        self.train_metrics: dict[str, list[float]] = {"loss": [], "acc": []}
        self.val_metrics: dict[str, list[float]] = {"loss": [], "acc": []}

    def on_train_epoch_end(self, trainer, pl_module):
        """Print and store ``train_loss`` and ``train_acc`` for the epoch.

        Args:
            trainer: Object exposing ``callback_metrics``; the values stored
                under ``'train_loss'`` and ``'train_acc'`` must provide
                ``.item()``.
            pl_module: Unused; present to match the hook signature.

        A missing metric key is not handled here: the lookup error raised by
        ``trainer.callback_metrics`` propagates to the caller.
        """
        metrics = trainer.callback_metrics
        # .item() turns each stored value into a plain Python scalar before
        # it is formatted and kept in the history lists.
        train_loss = metrics['train_loss'].item()
        train_acc = metrics['train_acc'].item()
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        self.train_metrics['loss'].append(train_loss)
        self.train_metrics['acc'].append(train_acc)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Print and store ``val_loss`` and ``val_acc`` for the epoch.

        Args:
            trainer: Object exposing ``callback_metrics`` and
                ``current_epoch``; the values stored under ``'val_loss'``
                and ``'val_acc'`` must provide ``.item()``.
            pl_module: Unused; present to match the hook signature.

        A missing metric key is not handled here: the lookup error raised by
        ``trainer.callback_metrics`` propagates to the caller.
        """
        metrics = trainer.callback_metrics
        val_loss = metrics['val_loss'].item()
        val_acc = metrics['val_acc'].item()
        # end=' || ' suppresses the newline, so whatever is printed next
        # continues on the same console line.
        # NOTE(review): presumably that is the "Train Loss" line from
        # on_train_epoch_end -- this depends on the hook order; confirm.
        print(f"Epoch: {trainer.current_epoch}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}", end=' || ')

        self.val_metrics['loss'].append(val_loss)
        self.val_metrics['acc'].append(val_acc)