Coverage for src/main.py: 0%

103 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-01-06 01:30 +0000

1import os 

2from pathlib import Path 

3 

4import kornia.augmentation as kaug 

5import torch 

6import wandb 

7from omegaconf import OmegaConf 

8from pytorch_lightning import Trainer 

9from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 

10from pytorch_lightning.loggers import WandbLogger 

11from dataset import ForestDataModule, ForestDataset, OversampledDataset, UndersampledDataset, CurriculumLearningDataset 

12from callbacks import PrintMetricsCallback, CurriculumLearningCallback 

13from models.classifier_module import ClassifierModule 

14from dataset_functions import load_dataset 

15from git_functions import get_git_branch, generate_short_hash 

16from counting_functions import calculate_metrics_per_class, count_metrics 

17from visualization_functions import show_n_samples, plot_metrics, get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve 

18 

19import torchvision 

20import math 

21 

22CONFIG_PATH = "src/config.yaml" 

23 

24 

25def main(): 

26 # Load configuration file 

27 config = OmegaConf.load(CONFIG_PATH) 

28 

29 # Create a dedicated folder for the PureForest dataset to keep each tree species 

30 # organized, avoiding multiple directories in the main content folder. 

31 dataset_folder = Path.cwd() / config.dataset.folder 

32 dataset_folder.mkdir(exist_ok=True) 

33 

34 # =========================== DATA LOADING AND PREPROCESSING ================================== # 

35 

36 dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders) 

37 show_n_samples(dataset, config.dataset.species_folders) 

38 

39 # =========================== INITIALIZING DATA AND MODEL ================================== # 

40 num_classes = len(label_map) 

41 class_weights = config.training.get("class_weights", None) 

42 

43 if config.training.get("class_weights", None) and (config.training.get("oversample", None) or config.training.get("undersample", None)): 

44 raise ValueError("Can't use class weights and resampling at the same time.") 

45 

46 dataset_module = ForestDataset 

47 dataset_args = {} 

48 

49 if config.training.get("oversample", None): 

50 dataset_module = OversampledDataset 

51 dataset_args = { 

52 "minority_transform": torchvision.transforms.Compose( 

53 [ 

54 torchvision.transforms.RandomHorizontalFlip(), 

55 torchvision.transforms.RandomVerticalFlip(), 

56 torchvision.transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10), 

57 ] 

58 ), 

59 "oversample_factor": config.training.oversample.oversample_factor, 

60 "oversample_threshold": config.training.oversample.oversample_threshold, 

61 } 

62 elif config.training.get("undersample", None): 

63 dataset_module = UndersampledDataset 

64 dataset_args = {"target_size": config.training.undersample.target_size} 

65 elif config.training.get("curriculum_learning", None): 

66 dataset_module = CurriculumLearningDataset 

67 dataset_args = { 

68 # The list cannot be empty, since the dataloder doesn't accept empty dataset 

69 "indices": [0] 

70 } 

71 

72 datamodule = ForestDataModule( 

73 dataset["train"], 

74 dataset["val"], 

75 dataset["test"], 

76 dataset=dataset_module, 

77 dataset_args=dataset_args, 

78 batch_size=config.training.batch_size, 

79 ) 

80 

81 model = ClassifierModule( 

82 model_name=config.model.name, 

83 num_classes=num_classes, 

84 step_size=config["training"]["step_size"], 

85 gamma=config["training"]["gamma"], 

86 freeze=config.training.freeze, 

87 weight=torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None, 

88 learning_rate=config.training.learning_rate, 

89 weight_decay=config.training.weight_decay, 

90 ) 

91 

92 # ====================================== TRAINING ========================================== # 

93 device = config.device if torch.cuda.is_available() else "cpu" 

94 callbacks = [PrintMetricsCallback()] 

95 

96 if config.training.early_stopping.apply: 

97 callbacks.append(EarlyStopping(monitor=config.training.early_stopping.monitor, patience=config.training.early_stopping.patience, mode=config.training.early_stopping.mode)) 

98 callbacks.append(ModelCheckpoint(monitor="val_loss", mode="min", save_top_k=1, save_last=False, dirpath=config.training.get("checkpoint_dir", "checkpoints/"))) 

99 

100 if config.training.get("curriculum_learning", None): 

101 callbacks.append( 

102 CurriculumLearningCallback( 

103 config.training.curriculum_learning.get("initial_ratio", None), 

104 config.training.curriculum_learning.get("step_size", None), 

105 config.training.curriculum_learning.get("class_order", None), 

106 dataset["train"]["labels"], 

107 ) 

108 ) 

109 min_epochs = math.ceil(num_classes / config.training.curriculum_learning.initial_ratio) * config.training.step_size 

110 else: 

111 min_epochs = None 

112 step_size = 0 

113 

114 branch_name = get_git_branch() 

115 short_hash = generate_short_hash() 

116 run_name = f"{branch_name}-{short_hash}" 

117 

118 wandb_api_key = os.environ.get("WANDB_API_KEY") 

119 wandb.login(key=wandb_api_key) 

120 wandb.init(project="ghost-irim", name=run_name) 

121 

122 # Log config.yaml to wandb 

123 wandb.save("src/config.yaml") 

124 

125 wandb_logger = WandbLogger(name=run_name, project="ghost-irim", log_model=True) 

126 

127 torch.backends.cudnn.benchmark = True 

128 torch.backends.cudnn.enabled = True 

129 

130 trainer = Trainer( 

131 logger=wandb_logger, min_epochs=min_epochs, max_epochs=config.training.max_epochs, accelerator=device, devices=1, callbacks=callbacks, reload_dataloaders_every_n_epochs=step_size 

132 ) 

133 

134 trainer.fit(model, datamodule) 

135 

136 # ====================================== TESTING ========================================== # 

137 # Retrieve the best checkpoint path from the ModelCheckpoint callback 

138 best_ckpt_path = None 

139 for callback in callbacks: 

140 if isinstance(callback, ModelCheckpoint): 

141 best_ckpt_path = callback.best_model_path 

142 break 

143 

144 if not best_ckpt_path: 

145 raise ValueError("No ModelCheckpoint callback found or no best checkpoint available.") 

146 

147 trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path) 

148 # Callbacks' service 

149 for callback in callbacks: 

150 if isinstance(callback, PrintMetricsCallback): 

151 train_metrics = callback.train_metrics 

152 val_metrics = callback.val_metrics 

153 plot_metrics(train_metrics, val_metrics) 

154 wandb.log({"Accuracy and Loss Curves": wandb.Image("src/plots/acc_loss_curves.png")}) 

155 

156 # Logging plots 

157 preds = model.predictions 

158 targets = model.targets 

159 

160 # Log metrics 

161 metrics_per_experiment = count_metrics(targets, preds) 

162 for key, value in metrics_per_experiment.items(): 

163 wandb.log({key: value}) 

164 

165 # Log metrics per class and classnames 

166 metrics_per_class = calculate_metrics_per_class(targets, preds) 

167 accs = [metrics_per_class[key]["accuracy"] for key in metrics_per_class.keys()] 

168 precs = [metrics_per_class[key]["precision"] for key in metrics_per_class.keys()] 

169 recs = [metrics_per_class[key]["recall"] for key in metrics_per_class.keys()] 

170 f1s = [metrics_per_class[key]["f1"] for key in metrics_per_class.keys()] 

171 ious = [metrics_per_class[key]["IoU"] for key in metrics_per_class.keys()] 

172 names_and_labels = [[key, value] for key, value in label_map.items()] 

173 logged_metrics = [[name, label, acc, prec, rec, f1, iou] for [name, label], acc, prec, rec, f1, iou in zip(names_and_labels, accs, precs, recs, f1s, ious, strict=False)] 

174 

175 training_table = wandb.Table(columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"], data=logged_metrics) 

176 wandb.log({"Classes": training_table}) 

177 

178 # Log confusion matrix, precision-recall curve and roc-auc curve 

179 get_confusion_matrix(preds, targets, class_names=list(label_map.keys())) 

180 get_roc_auc_curve(preds, targets, class_names=list(label_map.keys())) 

181 get_precision_recall_curve(preds, targets, class_names=list(label_map.keys())) 

182 

183 filenames = ["confusion_matrix.png", "precision_recall_curve.png", "roc_auc_curve.png"] 

184 titles = ["Confusion Matrix", "Precision-Recall Curve", "ROC AUC Curve"] 

185 for filename, title in zip(filenames, titles, strict=False): 

186 wandb.log({title: wandb.Image(f"src/plots/{filename}")}) 

187 

188 wandb.finish() 

189 

190 

191if __name__ == "__main__": 

192 main()