Coverage for src/main.py: 0%
94 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-16 12:50 +0000
1import os
2from pathlib import Path
4import kornia.augmentation as kaug
5import torch
6import wandb
7import yaml
8from pytorch_lightning import Trainer
9from pytorch_lightning.callbacks import EarlyStopping
10from pytorch_lightning.loggers import WandbLogger
12from models.classifier_module import ClassifierModule
13from dataset import ForestDataModule, ForestDataset, OversampledDataset, UndersampledDataset
14from callbacks import PrintMetricsCallback
15from dataset_functions import download_data, load_dataset
16from git_functions import get_git_branch, generate_short_hash
17from counting_functions import calculate_metrics_per_class, count_metrics
18from visualization_functions import (show_n_samples, plot_metrics,
19 get_confusion_matrix,
20 get_precision_recall_curve,
21 get_roc_auc_curve)
23import torchvision
def main():
    """Train and evaluate a tree-species image classifier on the PureForest dataset.

    Pipeline:
      1. Load hyperparameters from ``src/config.yaml``.
      2. Download the dataset and build train/val/test splits (optionally
         over- or under-sampled, per config).
      3. Train a ``ClassifierModule`` with PyTorch Lightning, logging to
         Weights & Biases (optionally with early stopping).
      4. Run a test pass and log per-class metrics, a confusion matrix,
         precision-recall and ROC-AUC curves as W&B artifacts.

    Side effects: creates the dataset folder, writes plot PNGs under
    ``src/plots/``, and creates a W&B run. Requires the ``WANDB_API_KEY``
    environment variable for authenticated logging.
    """
    # Load configuration file.
    with open("src/config.yaml", "r") as c:
        config = yaml.safe_load(c)

    # Create a dedicated folder for the PureForest dataset to keep each tree
    # species organized, avoiding multiple directories in the main content folder.
    dataset_folder = Path.cwd() / config["dataset"]["folder"]
    dataset_folder.mkdir(exist_ok=True)

    species_folders = config["dataset"]["species_folders"]
    main_subfolders = config["dataset"]["main_subfolders"]

    # =========================== DATA LOADING AND PREPROCESSING ================================== #

    download_data(species_folders, main_subfolders, dataset_folder)
    dataset, label_map = load_dataset(dataset_folder, species_folders)
    show_n_samples(dataset, species_folders)

    # =========================== INITIALIZING DATA AND MODEL ================================== #
    training_cfg = config["training"]
    batch_size = training_cfg["batch_size"]
    num_classes = len(label_map)
    learning_rate = training_cfg["learning_rate"]
    freeze = training_cfg["freeze"]
    weight_decay = training_cfg["weight_decay"]
    model_name = config["model"]["name"]
    # inception_v3 expects 299x299 inputs; the other backbones use 224x224.
    image_size = 299 if model_name == "inception_v3" else 224
    transforms = kaug.Resize(size=(image_size, image_size))

    # Default dataset class; swapped below if re-sampling is configured.
    dataset_module = ForestDataset
    dataset_args = {}

    if "oversample" in training_cfg:
        # Oversample minority classes, augmenting the duplicated samples so
        # the extra copies are not byte-identical.
        dataset_module = OversampledDataset
        dataset_args = {
            "minority_transform": torchvision.transforms.Compose([
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.RandomVerticalFlip(),
                torchvision.transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
            ]),
            "oversample_factor": training_cfg["oversample"]["oversample_factor"],
            "oversample_threshold": training_cfg["oversample"]["oversample_threshold"]
        }
    elif "undersample" in training_cfg:
        # Undersample majority classes down to a fixed target size.
        dataset_module = UndersampledDataset
        dataset_args = {
            "target_size": training_cfg["undersample"]["target_size"]
        }

    datamodule = ForestDataModule(
        dataset['train'],
        dataset['val'],
        dataset['test'],
        dataset=dataset_module,
        dataset_args=dataset_args,
        batch_size=batch_size
    )

    model = ClassifierModule(
        model_name=model_name,
        num_classes=num_classes,
        freeze=freeze,
        transform=transforms,
        learning_rate=learning_rate,
        weight_decay=weight_decay
    )

    # ====================================== TRAINING ========================================== #
    max_epochs = training_cfg["max_epochs"]
    # Fall back to CPU when CUDA is unavailable regardless of configured device.
    device = config["device"] if torch.cuda.is_available() else "cpu"
    callbacks = [PrintMetricsCallback()]

    early_stopping_cfg = training_cfg["early_stopping"]
    if early_stopping_cfg['apply']:
        callbacks.append(EarlyStopping(monitor=early_stopping_cfg['monitor'],
                                       patience=early_stopping_cfg['patience'],
                                       mode=early_stopping_cfg['mode']))

    # Name the run after the current git branch plus a short random hash so
    # runs are traceable back to the code state that produced them.
    branch_name = get_git_branch()
    short_hash = generate_short_hash()
    run_name = f'{branch_name}-{short_hash}'

    wandb_api_key = os.environ.get('WANDB_API_KEY')
    wandb.login(key=wandb_api_key)
    wandb.init(project="ghost-irim", name=run_name)

    # Log config.yaml to wandb so the run's hyperparameters are archived.
    wandb.save("src/config.yaml")

    wandb_logger = WandbLogger(
        name=run_name,
        project='ghost-irim',
        log_model=True
    )

    # cuDNN autotuning: fine here because input sizes are fixed by the resize.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    trainer = Trainer(
        logger=wandb_logger,
        max_epochs=max_epochs,
        accelerator=device,
        devices=1,
        callbacks=callbacks
    )

    trainer.fit(model, datamodule)

    # ====================================== TESTING ========================================== #
    trainer.test(model, datamodule=datamodule)

    # Callbacks' service: pull the per-epoch metrics gathered during training
    # and log the accuracy/loss curves figure.
    for callback in callbacks:
        if isinstance(callback, PrintMetricsCallback):
            train_metrics = callback.train_metrics
            val_metrics = callback.val_metrics
            plot_metrics(train_metrics, val_metrics)
            wandb.log({'Accuracy and Loss Curves': wandb.Image('src/plots/acc_loss_curves.png')})

    # Logging plots
    preds = model.predictions
    targets = model.targets

    # Log aggregate metrics.
    metrics_per_experiment = count_metrics(targets, preds)
    for key, value in metrics_per_experiment.items():
        wandb.log({key: value})

    # Log metrics per class and classnames.
    metrics_per_class = calculate_metrics_per_class(targets, preds)
    accs = [metrics_per_class[key]['accuracy'] for key in metrics_per_class.keys()]
    precs = [metrics_per_class[key]['precision'] for key in metrics_per_class.keys()]
    recs = [metrics_per_class[key]['recall'] for key in metrics_per_class.keys()]
    f1s = [metrics_per_class[key]['f1'] for key in metrics_per_class.keys()]
    ious = [metrics_per_class[key]['IoU'] for key in metrics_per_class.keys()]
    names_and_labels = [[key, value] for key, value in label_map.items()]
    # NOTE(review): assumes label_map iteration order matches the class order
    # of calculate_metrics_per_class — verify in counting_functions.
    logged_metrics = [[name, label, acc, prec, rec, f1, iou] for [name, label], acc, prec, rec, f1, iou in zip(names_and_labels, accs, precs, recs, f1s, ious)]

    training_table = wandb.Table(columns=['Class name', 'Label', 'Accuracy', 'Precision', 'Recall', 'F1-score', 'IoU'], data=logged_metrics)
    wandb.log({'Classes': training_table})

    # Log confusion matrix, precision-recall curve and roc-auc curve
    # (each helper writes its PNG under src/plots/).
    get_confusion_matrix(preds, targets, class_names=list(label_map.keys()))
    get_roc_auc_curve(preds, targets, class_names=list(label_map.keys()))
    get_precision_recall_curve(preds, targets, class_names=list(label_map.keys()))

    filenames = ['confusion_matrix.png', 'precision_recall_curve.png', 'roc_auc_curve.png']
    titles = ['Confusion Matrix', 'Precision-Recall Curve', 'ROC AUC Curve']
    for filename, title in zip(filenames, titles):
        # BUG FIX: the f-string previously contained no placeholder, so the
        # loop logged the same nonexistent path three times instead of each
        # generated plot; interpolate the actual filename.
        wandb.log({title: wandb.Image(f'src/plots/{filename}')})

    wandb.finish()
# Script entry point: run the full train/evaluate pipeline when executed
# directly (not when imported as a module).
if __name__ == "__main__":
    main()