Coverage for src/main.py: 0%
146 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-04-08 16:37 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-04-08 16:37 +0000
1import os
2import json
3from pathlib import Path
4from collections import Counter
6import kornia.augmentation as kaug
7import torch
8import wandb
9import numpy as np
10from omegaconf import OmegaConf
11from pytorch_lightning import Trainer
12from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
13from pytorch_lightning.loggers import WandbLogger
14from dataset import ForestDataModule, ForestDataset, OversampledDataset, UndersampledDataset, CurriculumLearningDataset
15from callbacks import PrintMetricsCallback, CurriculumLearningCallback
16from models.classifier_module import ClassifierModule
17from dataset_functions import load_dataset
18from git_functions import get_git_branch, generate_short_hash
19from counting_functions import calculate_metrics_per_class, count_metrics
20from visualization_functions import show_n_samples, plot_metrics, get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve
22import torchvision
23import math
# Path of the YAML configuration file that main() loads with OmegaConf.
CONFIG_PATH = "src/config.yaml"
def _wandb_safe_metadata_value(value):
    """Convert a value into a form that is safe to store as W&B artifact metadata.

    Tensors are unwrapped to plain Python values (a scalar for one-element
    tensors, a nested list otherwise). Non-finite floats (NaN, +/-inf) are
    replaced with ``None``, including those nested inside lists, so that a
    multi-element tensor containing NaN/inf is sanitized as well.

    Args:
        value: Any value; floats, lists and ``torch.Tensor`` get special handling.

    Returns:
        ``None`` for a non-finite float, a sanitized scalar or (nested) list
        for a tensor or list, and ``value`` unchanged for anything else.
    """
    if isinstance(value, torch.Tensor):
        # Unwrap first so the float / list handling below applies uniformly.
        value = value.item() if value.numel() == 1 else value.tolist()
    if isinstance(value, float):
        return value if math.isfinite(value) else None
    if isinstance(value, list):
        # tolist() on multi-element tensors may still carry NaN/inf entries.
        return [_wandb_safe_metadata_value(item) for item in value]
    return value
def print_split_summary(dataset, label_map):
    """Print per-split sample counts and warn about classes absent from a split.

    Args:
        dataset: Mapping with "train", "val" and "test" entries, each holding
            a "labels" sequence of class ids.
        label_map: Mapping from class name to class id.
    """
    id_to_name = {class_id: class_name for class_name, class_id in label_map.items()}
    expected_ids = set(label_map.values())

    def display_name(class_id):
        # Fall back to the raw id for labels that are not in the label map.
        return id_to_name.get(class_id, str(class_id))

    print("\n=== Dataset split summary ===")
    for split in ("train", "val", "test"):
        split_labels = dataset[split]["labels"]
        per_class = Counter(split_labels)
        print(f"{split}: {len(split_labels)} samples")
        for class_id in sorted(per_class):
            print(f" - {display_name(class_id)}: {per_class[class_id]}")

        # Classes known to the label map that never occur in this split.
        absent_ids = sorted(expected_ids.difference(per_class))
        if not absent_ids:
            continue
        missing_names = list(map(display_name, absent_ids))
        print(f" WARNING: missing classes in {split}: {missing_names}")
def main():
    """Train, checkpoint, test and log one classifier run.

    Loads the OmegaConf config at ``CONFIG_PATH``, loads the dataset, builds
    the data module and classifier, trains with PyTorch Lightning, uploads the
    best checkpoint together with its label map as a W&B artifact, then tests
    that checkpoint and logs metrics, a per-class table and plots to W&B.

    Raises:
        ValueError: If class weights are configured together with
            oversampling/undersampling, or if no best checkpoint path is
            available after training.
    """
    # Load configuration file
    config = OmegaConf.load(CONFIG_PATH)
    training_cfg = config.training

    # Create a dedicated folder for the PureForest dataset to keep each tree species
    # organized, avoiding multiple directories in the main content folder.
    dataset_folder = Path.cwd() / config.dataset.folder
    # parents=True so a nested folder from the config does not fail on a fresh checkout.
    dataset_folder.mkdir(parents=True, exist_ok=True)

    # =========================== DATA LOADING AND PREPROCESSING ================================== #

    dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)
    print_split_summary(dataset, label_map)
    show_n_samples(dataset, config.dataset.species_folders)

    # =========================== INITIALIZING DATA AND MODEL ================================== #
    num_classes = len(label_map)
    class_weights = training_cfg.get("class_weights", None)
    oversample_cfg = training_cfg.get("oversample", None)
    undersample_cfg = training_cfg.get("undersample", None)
    curriculum_cfg = training_cfg.get("curriculum_learning", None)

    if class_weights and (oversample_cfg or undersample_cfg):
        raise ValueError("Can't use class weights and resampling at the same time.")

    # Pick the dataset wrapper; the first enabled strategy wins, in this order.
    dataset_module = ForestDataset
    dataset_args = {}

    if oversample_cfg:
        dataset_module = OversampledDataset
        dataset_args = {
            "minority_transform": torchvision.transforms.Compose(
                [
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.RandomVerticalFlip(),
                    torchvision.transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
                ]
            ),
            "oversample_factor": training_cfg.oversample.oversample_factor,
            "oversample_threshold": training_cfg.oversample.oversample_threshold,
        }
    elif undersample_cfg:
        dataset_module = UndersampledDataset
        dataset_args = {"target_size": training_cfg.undersample.target_size}
    elif curriculum_cfg:
        dataset_module = CurriculumLearningDataset
        dataset_args = {
            # The list cannot be empty, since the DataLoader doesn't accept an empty dataset
            "indices": [0]
        }

    datamodule = ForestDataModule(
        dataset["train"],
        dataset["val"],
        dataset["test"],
        dataset=dataset_module,
        dataset_args=dataset_args,
        batch_size=training_cfg.batch_size,
    )

    model = ClassifierModule(
        model_name=config.model.name,
        num_classes=num_classes,
        step_size=training_cfg.step_size,
        gamma=training_cfg.gamma,
        freeze=training_cfg.freeze,
        weight=torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None,
        learning_rate=training_cfg.learning_rate,
        weight_decay=training_cfg.weight_decay,
    )

    # ====================================== TRAINING ========================================== #
    device = config.device if torch.cuda.is_available() else "cpu"

    # Keep direct references to the callbacks that are queried after training,
    # instead of searching the callback list for them later.
    metrics_callback = PrintMetricsCallback()
    callbacks = [metrics_callback]

    # Early stopping and checkpointing share the same monitored metric and mode.
    checkpoint_monitor = training_cfg.early_stopping.get("monitor", "val_loss")
    checkpoint_mode = training_cfg.early_stopping.get("mode", "min")

    if training_cfg.early_stopping.apply:
        callbacks.append(
            EarlyStopping(
                monitor=checkpoint_monitor,
                patience=training_cfg.early_stopping.patience,
                mode=checkpoint_mode,
            )
        )

    checkpoint_callback = ModelCheckpoint(
        monitor=checkpoint_monitor,
        mode=checkpoint_mode,
        save_top_k=1,
        save_last=False,
        dirpath=training_cfg.get("checkpoint_dir", "checkpoints/"),
    )
    callbacks.append(checkpoint_callback)

    if curriculum_cfg:
        callbacks.append(
            CurriculumLearningCallback(
                curriculum_cfg.get("initial_ratio", None),
                curriculum_cfg.get("step_size", None),
                curriculum_cfg.get("class_order", None),
                dataset["train"]["labels"],
            )
        )
        min_epochs = math.ceil(num_classes / curriculum_cfg.initial_ratio) * training_cfg.step_size
        # The reload interval used to be assigned only in the non-curriculum
        # branch, so enabling curriculum learning crashed with a NameError when
        # the Trainer was built.
        # NOTE(review): reloading on the curriculum step size (every epoch when
        # it is unset) is assumed to be the intent - confirm.
        reload_every_n_epochs = curriculum_cfg.get("step_size", None) or 1
    else:
        min_epochs = None
        reload_every_n_epochs = 0

    branch_name = get_git_branch()
    short_hash = generate_short_hash()
    run_name = f"{branch_name}-{short_hash}"

    wandb_api_key = os.environ.get("WANDB_API_KEY")
    wandb.login(key=wandb_api_key)

    wandb_logger = WandbLogger(name=run_name, project="ghost-irim", log_model=False)
    # Upload the same config file that was loaded above.
    wandb_logger.experiment.save(CONFIG_PATH)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    trainer = Trainer(
        logger=wandb_logger,
        min_epochs=min_epochs,
        max_epochs=training_cfg.max_epochs,
        accelerator=device,
        devices=1,
        callbacks=callbacks,
        reload_dataloaders_every_n_epochs=reload_every_n_epochs,
    )

    trainer.fit(model, datamodule)

    # ====================================== TESTING ========================================== #
    # Retrieve the best checkpoint path and score from the ModelCheckpoint callback
    best_ckpt_path = checkpoint_callback.best_model_path
    if not best_ckpt_path:
        raise ValueError("No ModelCheckpoint callback found or no best checkpoint available.")
    best_ckpt_score = checkpoint_callback.best_model_score

    print(f"Best checkpoint selected: {best_ckpt_path}")
    print(f"Best checkpoint score ({checkpoint_monitor}, {checkpoint_mode}): {best_ckpt_score}")

    # Persist label map next to checkpoint so inference can reuse exact class indices.
    label_map_path = Path(best_ckpt_path).with_suffix(".label_map.json")
    with open(label_map_path, "w", encoding="utf-8") as f:
        json.dump(label_map, f, indent=2)

    # Log explicit artifact for the selected best checkpoint.
    best_model_artifact = wandb.Artifact(
        name=f"best-model-{run_name}",
        type="model",
        description=f"Best checkpoint from run {run_name}",
        metadata={
            "best_ckpt_path": str(best_ckpt_path),
            "monitor": checkpoint_monitor,
            "mode": checkpoint_mode,
            "best_model_score": _wandb_safe_metadata_value(float(best_ckpt_score) if best_ckpt_score is not None else None),
            "num_classes": num_classes,
            "model_name": config.model.name,
        },
    )
    best_model_artifact.add_file(str(best_ckpt_path))
    best_model_artifact.add_file(str(label_map_path))
    wandb_logger.experiment.log_artifact(best_model_artifact, aliases=["best", "latest"])

    trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

    # Plot the train/val curves collected by the metrics callback
    plot_metrics(metrics_callback.train_metrics, metrics_callback.val_metrics)
    wandb.log({"Accuracy and Loss Curves": wandb.Image("src/plots/acc_loss_curves.png")})

    # Predictions and targets exposed by the model after the test run
    preds = model.predictions
    targets = model.targets

    # Log metrics (one wandb.log call per metric, as before)
    metrics_per_experiment = count_metrics(targets, preds)
    for key, value in metrics_per_experiment.items():
        wandb.log({key: value})

    # Log metrics per class and classnames
    metrics_per_class = calculate_metrics_per_class(targets, preds)
    metric_keys = ("accuracy", "precision", "recall", "f1", "IoU")
    # Rows pair label_map entries with per-class metrics positionally.
    # NOTE(review): this relies on both mappings iterating in the same class order - confirm.
    logged_metrics = [
        [name, label, *(class_metrics[metric] for metric in metric_keys)]
        for (name, label), class_metrics in zip(label_map.items(), metrics_per_class.values(), strict=False)
    ]

    training_table = wandb.Table(
        columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"],
        data=logged_metrics,
    )
    wandb.log({"Classes": training_table})

    # Log confusion matrix, precision-recall curve and roc-auc curve
    class_names = list(label_map)
    get_confusion_matrix(preds, targets, class_names=class_names)
    get_roc_auc_curve(preds, targets, class_names=class_names)
    get_precision_recall_curve(preds, targets, class_names=class_names)

    plot_titles = {
        "confusion_matrix.png": "Confusion Matrix",
        "precision_recall_curve.png": "Precision-Recall Curve",
        "roc_auc_curve.png": "ROC AUC Curve",
    }
    for filename, title in plot_titles.items():
        wandb.log({title: wandb.Image(f"src/plots/{filename}")})

    wandb.finish()
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()