Coverage for src/callbacks.py: 100%
17 statements
coverage.py v7.8.0, created at 2025-04-16 12:50 +0000
1from pytorch_lightning.callbacks import Callback
class PrintMetricsCallback(Callback):
    """Print per-epoch loss/accuracy and keep a history of the values.

    The values are read from ``trainer.callback_metrics`` under the names
    ``train_loss``, ``train_acc``, ``val_loss`` and ``val_acc``, so the
    LightningModule must log metrics with exactly those names.

    Attributes:
        train_metrics: ``{"loss": [...], "acc": [...]}``; one value is
            appended to each list every time ``on_train_epoch_end`` runs.
        val_metrics: same layout; one value is appended to each list every
            time ``on_validation_epoch_end`` runs outside the sanity check.
    """

    def __init__(self):
        super().__init__()
        self.train_metrics = {"loss": [], "acc": []}
        self.val_metrics = {"loss": [], "acc": []}

    @staticmethod
    def _read_metric(trainer, name):
        """Return ``trainer.callback_metrics[name]`` converted with ``.item()``.

        Raises:
            KeyError: if ``name`` is not present; the message lists the
                metrics that are available so a naming mismatch is obvious.
        """
        metrics = trainer.callback_metrics
        try:
            value = metrics[name]
        except KeyError:
            raise KeyError(
                f"PrintMetricsCallback: metric {name!r} not found in "
                f"trainer.callback_metrics (available: {sorted(metrics)}); "
                f"log it from the LightningModule with self.log({name!r}, ...)"
            ) from None
        # assumes the stored value is a single-element tensor, as the original
        # `.item()` calls did — TODO confirm for any non-tensor metrics
        return value.item()

    def on_train_epoch_end(self, trainer, pl_module):
        """Print and record the training loss/accuracy for the finished epoch."""
        train_loss = self._read_metric(trainer, 'train_loss')
        train_acc = self._read_metric(trainer, 'train_acc')
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")

        self.train_metrics['loss'].append(train_loss)
        self.train_metrics['acc'].append(train_acc)

    def on_validation_epoch_end(self, trainer, pl_module):
        """Print and record the validation loss/accuracy for the finished run."""
        # Lightning also fires this hook for the sanity-check validation pass
        # that runs before training; recording it would add a spurious first
        # entry and misalign val history with the training epochs.
        if trainer.sanity_checking:
            return

        val_loss = self._read_metric(trainer, 'val_loss')
        val_acc = self._read_metric(trainer, 'val_acc')
        # end=' || ' keeps the cursor on the same line so the next print
        # (e.g. the train summary) continues the same row.
        print(f"Epoch: {trainer.current_epoch}, "
              f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}", end=' || ')

        self.val_metrics['loss'].append(val_loss)
        self.val_metrics['acc'].append(val_acc)