Coverage for src/main.py: 0%
105 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-01 17:37 +0000
1import os
2from pathlib import Path
4import kornia.augmentation as kaug
5import torch
6import wandb
7from omegaconf import OmegaConf
8from pytorch_lightning import Trainer
9from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
10from pytorch_lightning.loggers import WandbLogger
11from dataset import ForestDataModule, ForestDataset, OversampledDataset, UndersampledDataset, CurriculumLearningDataset
12from callbacks import PrintMetricsCallback, CurriculumLearningCallback
13from models.classifier_module import ClassifierModule
14from dataset_functions import load_dataset
15from git_functions import get_git_branch, generate_short_hash
16from counting_functions import calculate_metrics_per_class, count_metrics
17from visualization_functions import show_n_samples, plot_metrics, get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve
19import torchvision
20import math
# Path to the YAML experiment configuration, relative to the repository root.
CONFIG_PATH = "src/config.yaml"
def main():
    """Train, validate and test a tree-species classifier on the PureForest dataset.

    Reads all hyper-parameters from ``CONFIG_PATH``, builds the data module
    (optionally with over-/under-sampling or curriculum learning), trains with
    PyTorch Lightning, evaluates the best checkpoint on the test split, and
    logs metrics, tables and plots to Weights & Biases.

    Raises:
        ValueError: if class weights are combined with resampling, or if no
            best checkpoint is available after training.
    """
    # Load configuration file
    config = OmegaConf.load(CONFIG_PATH)

    # Create a dedicated folder for the PureForest dataset to keep each tree species
    # organized, avoiding multiple directories in the main content folder.
    dataset_folder = Path.cwd() / config.dataset.folder
    dataset_folder.mkdir(exist_ok=True)

    # =========================== DATA LOADING AND PREPROCESSING ================================== #

    dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)
    show_n_samples(dataset, config.dataset.species_folders)

    # =========================== INITIALIZING DATA AND MODEL ================================== #

    num_classes = len(label_map)
    class_weights = config.training.get("class_weights", None)

    # Class weights and resampling both compensate for class imbalance; using
    # them together would double-correct, so reject that configuration.
    if class_weights and (config.training.get("oversample", None) or config.training.get("undersample", None)):
        raise ValueError("Can't use class weights and resampling at the same time.")

    # Inception v3 requires 299x299 inputs; the other backbones use 224x224.
    image_size = 299 if config.model.name == "inception_v3" else 224
    transforms = kaug.Resize(size=(image_size, image_size))

    dataset_module = ForestDataset
    dataset_args = {}

    if config.training.get("oversample", None):
        dataset_module = OversampledDataset
        dataset_args = {
            # Augmentations applied to minority-class duplicates so the
            # oversampled images are not exact copies of the originals.
            "minority_transform": torchvision.transforms.Compose(
                [
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.RandomVerticalFlip(),
                    torchvision.transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
                ]
            ),
            "oversample_factor": config.training.oversample.oversample_factor,
            "oversample_threshold": config.training.oversample.oversample_threshold,
        }
    elif config.training.get("undersample", None):
        dataset_module = UndersampledDataset
        dataset_args = {"target_size": config.training.undersample.target_size}
    elif config.training.get("curriculum_learning", None):
        dataset_module = CurriculumLearningDataset
        dataset_args = {
            # The list cannot be empty, since the dataloader doesn't accept an empty dataset
            "indices": [0]
        }

    datamodule = ForestDataModule(
        dataset["train"],
        dataset["val"],
        dataset["test"],
        dataset=dataset_module,
        dataset_args=dataset_args,
        batch_size=config.training.batch_size,
    )

    model = ClassifierModule(
        model_name=config.model.name,
        num_classes=num_classes,
        step_size=config.training.step_size,
        gamma=config.training.gamma,
        freeze=config.training.freeze,
        transform=transforms,
        weight=torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None,
        learning_rate=config.training.learning_rate,
        weight_decay=config.training.weight_decay,
    )

    # ====================================== TRAINING ========================================== #

    device = config.device if torch.cuda.is_available() else "cpu"
    callbacks = [PrintMetricsCallback()]

    if config.training.early_stopping.apply:
        callbacks.append(EarlyStopping(monitor=config.training.early_stopping.monitor, patience=config.training.early_stopping.patience, mode=config.training.early_stopping.mode))
    callbacks.append(ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, save_last=False, dirpath=config.training.get("checkpoint_dir", "checkpoints/")))

    if config.training.get("curriculum_learning", None):
        callbacks.append(
            CurriculumLearningCallback(
                config.training.curriculum_learning.get("initial_ratio", None),
                config.training.curriculum_learning.get("step_size", None),
                config.training.curriculum_learning.get("class_order", None),
                dataset["train"]["labels"],
            )
        )
        # Train long enough for the curriculum to introduce every class at least once.
        min_epochs = math.ceil(num_classes / config.training.curriculum_learning.initial_ratio) * config.training.step_size
        # BUG FIX: `step_size` was previously unassigned on this branch, raising a
        # NameError when constructing the Trainer below. Reload the dataloaders
        # every `step_size` epochs so the curriculum subset is refreshed.
        # NOTE(review): assumes the reload cadence matches `config.training.step_size`
        # (the value used in the min_epochs formula above) — confirm against
        # CurriculumLearningCallback's schedule.
        step_size = config.training.step_size
    else:
        min_epochs = None
        step_size = 0  # 0 disables periodic dataloader reloading in Lightning.

    # Name the run after the current git branch plus a short random hash so runs are unique.
    branch_name = get_git_branch()
    short_hash = generate_short_hash()
    run_name = f"{branch_name}-{short_hash}"

    wandb_api_key = os.environ.get("WANDB_API_KEY")
    wandb.login(key=wandb_api_key)
    wandb.init(project="ghost-irim", name=run_name)

    # Log config.yaml to wandb
    wandb.save("src/config.yaml")

    wandb_logger = WandbLogger(name=run_name, project="ghost-irim", log_model=True)

    # cuDNN autotuning: faster for fixed-size inputs (all images are resized above).
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    trainer = Trainer(
        logger=wandb_logger, min_epochs=min_epochs, max_epochs=config.training.max_epochs, accelerator=device, devices=1, callbacks=callbacks, reload_dataloaders_every_n_epochs=step_size
    )

    trainer.fit(model, datamodule)

    # ====================================== TESTING ========================================== #
    # Retrieve the best checkpoint path from the ModelCheckpoint callback
    best_ckpt_path = None
    for callback in callbacks:
        if isinstance(callback, ModelCheckpoint):
            best_ckpt_path = callback.best_model_path
            break

    if not best_ckpt_path:
        raise ValueError("No ModelCheckpoint callback found or no best checkpoint available.")

    trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

    # Callbacks' service: pull the accumulated train/val metrics and plot them.
    for callback in callbacks:
        if isinstance(callback, PrintMetricsCallback):
            train_metrics = callback.train_metrics
            val_metrics = callback.val_metrics
            plot_metrics(train_metrics, val_metrics)
            wandb.log({"Accuracy and Loss Curves": wandb.Image("src/plots/acc_loss_curves.png")})

    # Logging plots
    preds = model.predictions
    targets = model.targets

    # Log metrics
    metrics_per_experiment = count_metrics(targets, preds)
    for key, value in metrics_per_experiment.items():
        wandb.log({key: value})

    # Log metrics per class and classnames
    metrics_per_class = calculate_metrics_per_class(targets, preds)
    accs = [metrics_per_class[key]["accuracy"] for key in metrics_per_class.keys()]
    precs = [metrics_per_class[key]["precision"] for key in metrics_per_class.keys()]
    recs = [metrics_per_class[key]["recall"] for key in metrics_per_class.keys()]
    f1s = [metrics_per_class[key]["f1"] for key in metrics_per_class.keys()]
    ious = [metrics_per_class[key]["IoU"] for key in metrics_per_class.keys()]
    names_and_labels = [[key, value] for key, value in label_map.items()]
    logged_metrics = [[name, label, acc, prec, rec, f1, iou] for [name, label], acc, prec, rec, f1, iou in zip(names_and_labels, accs, precs, recs, f1s, ious, strict=False)]

    training_table = wandb.Table(columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"], data=logged_metrics)
    wandb.log({"Classes": training_table})

    # Log confusion matrix, precision-recall curve and roc-auc curve
    get_confusion_matrix(preds, targets, class_names=list(label_map.keys()))
    get_roc_auc_curve(preds, targets, class_names=list(label_map.keys()))
    get_precision_recall_curve(preds, targets, class_names=list(label_map.keys()))

    filenames = ["confusion_matrix.png", "precision_recall_curve.png", "roc_auc_curve.png"]
    titles = ["Confusion Matrix", "Precision-Recall Curve", "ROC AUC Curve"]
    for filename, title in zip(filenames, titles, strict=False):
        # BUG FIX: the logged path previously did not interpolate `filename`,
        # so every title pointed at a nonexistent file.
        wandb.log({title: wandb.Image(f"src/plots/{filename}")})

    wandb.finish()
# Script entry point: run the full train/evaluate/log pipeline when executed directly.
if __name__ == "__main__":
    main()