# Coverage report: src/main.py, 0% of 103 statements covered (coverage.py v7.11.0, 2026-01-06 01:30 +0000)
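"""Training entry point for the PureForest tree-species classifier.

Loads settings from src/config.yaml, builds the data module and classifier,
trains with PyTorch Lightning, evaluates the best checkpoint, and logs
metrics and plots to Weights & Biases.
"""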
import math
import os
from pathlib import Path

import kornia.augmentation as kaug
import torch
import torchvision
import wandb
from omegaconf import OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

from callbacks import PrintMetricsCallback, CurriculumLearningCallback
from counting_functions import calculate_metrics_per_class, count_metrics
from dataset import ForestDataModule, ForestDataset, OversampledDataset, UndersampledDataset, CurriculumLearningDataset
from dataset_functions import load_dataset
from git_functions import get_git_branch, generate_short_hash
from models.classifier_module import ClassifierModule
from visualization_functions import show_n_samples, plot_metrics, get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve

CONFIG_PATH = "src/config.yaml"
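# Sketch of the config.yaml shape this script expects. The keys are inferred from
# how `config` is accessed below; the values are illustrative placeholders only,
# not the project's actual settings.
#
#   device: cuda
#   dataset:
#     folder: PureForest
#     species_folders: [...]
#   model:
#     name: resnet18
#   training:
#     batch_size: 32
#     learning_rate: 1.0e-3
#     weight_decay: 1.0e-4
#     step_size: 5
#     gamma: 0.1
#     freeze: false
#     max_epochs: 50
#     checkpoint_dir: checkpoints/
#     class_weights: null          # mutually exclusive with oversample/undersample
#     early_stopping:
#       apply: true
#       monitor: val_loss
#       patience: 5
#       mode: min
#     oversample:                  # optional
#       oversample_factor: 2
#       oversample_threshold: 0.1
#     undersample:                 # optional
#       target_size: 1000
#     curriculum_learning:         # optional
#       initial_ratio: 2
#       step_size: 5
#       class_order: null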
def main():
    # Load configuration file
    config = OmegaConf.load(CONFIG_PATH)

    # Create a dedicated folder for the PureForest dataset to keep each tree species
    # organized, avoiding multiple directories in the main content folder.
    dataset_folder = Path.cwd() / config.dataset.folder
    dataset_folder.mkdir(exist_ok=True)

    # =========================== DATA LOADING AND PREPROCESSING ================================== #

    dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)
    show_n_samples(dataset, config.dataset.species_folders)

    # =========================== INITIALIZING DATA AND MODEL ================================== #

    num_classes = len(label_map)
    class_weights = config.training.get("class_weights", None)

    if class_weights and (config.training.get("oversample", None) or config.training.get("undersample", None)):
        raise ValueError("Can't use class weights and resampling at the same time.")

    dataset_module = ForestDataset
    dataset_args = {}
    if config.training.get("oversample", None):
        dataset_module = OversampledDataset
        dataset_args = {
            "minority_transform": torchvision.transforms.Compose(
                [
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.RandomVerticalFlip(),
                    torchvision.transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
                ]
            ),
            "oversample_factor": config.training.oversample.oversample_factor,
            "oversample_threshold": config.training.oversample.oversample_threshold,
        }
    elif config.training.get("undersample", None):
        dataset_module = UndersampledDataset
        dataset_args = {"target_size": config.training.undersample.target_size}
    elif config.training.get("curriculum_learning", None):
        dataset_module = CurriculumLearningDataset
        dataset_args = {
            # The list cannot be empty, since the dataloader doesn't accept an empty dataset
            "indices": [0]
        }
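    # The class and kwargs chosen above are forwarded to the datamodule below; for
    # curriculum learning the single-index placeholder is presumably grown by
    # CurriculumLearningCallback as more classes are unlocked during training.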
    datamodule = ForestDataModule(
        dataset["train"],
        dataset["val"],
        dataset["test"],
        dataset=dataset_module,
        dataset_args=dataset_args,
        batch_size=config.training.batch_size,
    )

    model = ClassifierModule(
        model_name=config.model.name,
        num_classes=num_classes,
        step_size=config.training.step_size,
        gamma=config.training.gamma,
        freeze=config.training.freeze,
        weight=torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None,
        learning_rate=config.training.learning_rate,
        weight_decay=config.training.weight_decay,
    )
    # ====================================== TRAINING ========================================== #
    device = config.device if torch.cuda.is_available() else "cpu"
    callbacks = [PrintMetricsCallback()]

    if config.training.early_stopping.apply:
        callbacks.append(EarlyStopping(monitor=config.training.early_stopping.monitor, patience=config.training.early_stopping.patience, mode=config.training.early_stopping.mode))
    callbacks.append(ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, save_last=False, dirpath=config.training.get("checkpoint_dir", "checkpoints/")))

    if config.training.get("curriculum_learning", None):
        callbacks.append(
            CurriculumLearningCallback(
                config.training.curriculum_learning.get("initial_ratio", None),
                config.training.curriculum_learning.get("step_size", None),
                config.training.curriculum_learning.get("class_order", None),
                dataset["train"]["labels"],
            )
        )
        # Reload the dataloaders once per curriculum step (see reload_dataloaders_every_n_epochs below)
        step_size = config.training.curriculum_learning.step_size
        # Minimum number of epochs implied by the curriculum schedule
        min_epochs = math.ceil(num_classes / config.training.curriculum_learning.initial_ratio) * config.training.step_size
    else:
        min_epochs = None
        step_size = 0
    branch_name = get_git_branch()
    short_hash = generate_short_hash()
    run_name = f"{branch_name}-{short_hash}"

    # W&B authentication expects the WANDB_API_KEY environment variable to be set
    wandb_api_key = os.environ.get("WANDB_API_KEY")
    wandb.login(key=wandb_api_key)
    wandb.init(project="ghost-irim", name=run_name)

    # Log config.yaml to wandb
    wandb.save("src/config.yaml")

    wandb_logger = WandbLogger(name=run_name, project="ghost-irim", log_model=True)

    # Let cuDNN benchmark and pick the fastest convolution algorithms for fixed-size inputs
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True
    trainer = Trainer(
        logger=wandb_logger, min_epochs=min_epochs, max_epochs=config.training.max_epochs, accelerator=device, devices=1, callbacks=callbacks, reload_dataloaders_every_n_epochs=step_size
    )
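    # `reload_dataloaders_every_n_epochs` makes Lightning rebuild the train dataloader
    # every `step_size` epochs (0 disables reloading), presumably so curriculum updates
    # to the dataset's index list are picked up between stages.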
    trainer.fit(model, datamodule)

    # ====================================== TESTING ========================================== #
    # Retrieve the best checkpoint path from the ModelCheckpoint callback
    best_ckpt_path = None
    for callback in callbacks:
        if isinstance(callback, ModelCheckpoint):
            best_ckpt_path = callback.best_model_path
            break

    if not best_ckpt_path:
        raise ValueError("No ModelCheckpoint callback found or no best checkpoint available.")

    trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)
    # Plot the training/validation curves collected by PrintMetricsCallback
    for callback in callbacks:
        if isinstance(callback, PrintMetricsCallback):
            train_metrics = callback.train_metrics
            val_metrics = callback.val_metrics
            plot_metrics(train_metrics, val_metrics)
            wandb.log({"Accuracy and Loss Curves": wandb.Image("src/plots/acc_loss_curves.png")})

    # Predictions and targets accumulated by the model during testing
    preds = model.predictions
    targets = model.targets

    # Log overall metrics
    metrics_per_experiment = count_metrics(targets, preds)
    for key, value in metrics_per_experiment.items():
        wandb.log({key: value})
    # Log per-class metrics together with class names and labels
    metrics_per_class = calculate_metrics_per_class(targets, preds)
    accs = [metrics_per_class[key]["accuracy"] for key in metrics_per_class]
    precs = [metrics_per_class[key]["precision"] for key in metrics_per_class]
    recs = [metrics_per_class[key]["recall"] for key in metrics_per_class]
    f1s = [metrics_per_class[key]["f1"] for key in metrics_per_class]
    ious = [metrics_per_class[key]["IoU"] for key in metrics_per_class]
    names_and_labels = [[key, value] for key, value in label_map.items()]
    logged_metrics = [[name, label, acc, prec, rec, f1, iou] for [name, label], acc, prec, rec, f1, iou in zip(names_and_labels, accs, precs, recs, f1s, ious, strict=False)]

    training_table = wandb.Table(columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"], data=logged_metrics)
    wandb.log({"Classes": training_table})
    # Log confusion matrix, precision-recall curve and ROC AUC curve
    get_confusion_matrix(preds, targets, class_names=list(label_map.keys()))
    get_roc_auc_curve(preds, targets, class_names=list(label_map.keys()))
    get_precision_recall_curve(preds, targets, class_names=list(label_map.keys()))

    filenames = ["confusion_matrix.png", "precision_recall_curve.png", "roc_auc_curve.png"]
    titles = ["Confusion Matrix", "Precision-Recall Curve", "ROC AUC Curve"]
    for filename, title in zip(filenames, titles, strict=False):
        wandb.log({title: wandb.Image(f"src/plots/{filename}")})

    wandb.finish()
if __name__ == "__main__":
    main()