Coverage for src/main.py: 0%

146 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-04-08 16:37 +0000

1import os 

2import json 

3from pathlib import Path 

4from collections import Counter 

5 

6import kornia.augmentation as kaug 

7import torch 

8import wandb 

9import numpy as np 

10from omegaconf import OmegaConf 

11from pytorch_lightning import Trainer 

12from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 

13from pytorch_lightning.loggers import WandbLogger 

14from dataset import ForestDataModule, ForestDataset, OversampledDataset, UndersampledDataset, CurriculumLearningDataset 

15from callbacks import PrintMetricsCallback, CurriculumLearningCallback 

16from models.classifier_module import ClassifierModule 

17from dataset_functions import load_dataset 

18from git_functions import get_git_branch, generate_short_hash 

19from counting_functions import calculate_metrics_per_class, count_metrics 

20from visualization_functions import show_n_samples, plot_metrics, get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve 

21 

22import torchvision 

23import math 

24 

# Location of the experiment configuration, relative to the working directory
# the script is launched from.
CONFIG_PATH = "src/config.yaml"

26 

27 

28def _wandb_safe_metadata_value(value): 

29 if isinstance(value, float) and not math.isfinite(value): 

30 return None 

31 if isinstance(value, torch.Tensor): 

32 if value.numel() == 1: 

33 scalar = value.item() 

34 return None if isinstance(scalar, float) and not math.isfinite(scalar) else scalar 

35 return value.tolist() 

36 return value 

37 

38 

def print_split_summary(dataset, label_map):
    """Print per-class sample counts for the train, val and test splits.

    For each split the total number of samples is printed, followed by one line
    per class present in that split (in ascending class-id order) and a warning
    listing every class from ``label_map`` that has no sample in the split.

    Args:
        dataset: Mapping with the keys ``"train"``, ``"val"`` and ``"test"``;
            each value must expose its labels under the ``"labels"`` key.
        label_map: Mapping from class name to class id.

    Raises:
        KeyError: If a split or its ``"labels"`` entry is missing.
    """
    id_to_name = {class_id: class_name for class_name, class_id in label_map.items()}
    known_class_ids = set(label_map.values())

    def display_name(class_id):
        # Ids unknown to the label map are shown as their raw value.
        return id_to_name.get(class_id, str(class_id))

    print("\n=== Dataset split summary ===")
    for split in ("train", "val", "test"):
        labels = dataset[split]["labels"]
        counts = Counter(labels)

        print(f"{split}: {len(labels)} samples")
        for class_id in sorted(counts):
            print(f" - {display_name(class_id)}: {counts[class_id]}")

        # Classes declared in the label map but absent from this split.
        missing = sorted(known_class_ids.difference(counts))
        if missing:
            missing_names = [display_name(class_id) for class_id in missing]
            print(f" WARNING: missing classes in {split}: {missing_names}")

56 

57 

def _select_dataset_strategy(training_cfg):
    """Return the training dataset class and its extra constructor kwargs.

    The options are checked in this order and the first one configured wins:
    oversampling, undersampling, curriculum learning. With none of them set,
    the plain ``ForestDataset`` is used without extra arguments.
    """
    oversample_cfg = training_cfg.get("oversample", None)
    if oversample_cfg:
        minority_transform = torchvision.transforms.Compose(
            [
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.RandomVerticalFlip(),
                torchvision.transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
            ]
        )
        return OversampledDataset, {
            "minority_transform": minority_transform,
            "oversample_factor": oversample_cfg.oversample_factor,
            "oversample_threshold": oversample_cfg.oversample_threshold,
        }

    undersample_cfg = training_cfg.get("undersample", None)
    if undersample_cfg:
        return UndersampledDataset, {"target_size": undersample_cfg.target_size}

    if training_cfg.get("curriculum_learning", None):
        # The list cannot be empty, since the dataloader doesn't accept an empty dataset.
        return CurriculumLearningDataset, {"indices": [0]}

    return ForestDataset, {}


def _log_best_model_artifact(wandb_logger, run_name, best_ckpt_path, label_map_path, metadata):
    """Upload the best checkpoint and its label map as one W&B model artifact."""
    best_model_artifact = wandb.Artifact(
        name=f"best-model-{run_name}",
        type="model",
        description=f"Best checkpoint from run {run_name}",
        metadata=metadata,
    )
    best_model_artifact.add_file(str(best_ckpt_path))
    best_model_artifact.add_file(str(label_map_path))
    wandb_logger.experiment.log_artifact(best_model_artifact, aliases=["best", "latest"])


def _log_test_reports(model, label_map):
    """Log test metrics, the per-class table and the diagnostic plots to W&B.

    Reads ``model.predictions`` and ``model.targets``, which are expected to be
    filled by the preceding ``trainer.test`` call.
    """
    preds = model.predictions
    targets = model.targets

    # Experiment-level metrics, one log call per metric as before.
    for key, value in count_metrics(targets, preds).items():
        wandb.log({key: value})

    # Per-class metrics paired with class names and labels.
    # NOTE(review): rows are paired positionally, so this assumes
    # calculate_metrics_per_class returns classes in the same order as
    # label_map; strict=False silently drops rows on a length mismatch.
    metrics_per_class = calculate_metrics_per_class(targets, preds)
    logged_metrics = [
        [name, label, scores["accuracy"], scores["precision"], scores["recall"], scores["f1"], scores["IoU"]]
        for (name, label), scores in zip(label_map.items(), metrics_per_class.values(), strict=False)
    ]
    training_table = wandb.Table(
        columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"],
        data=logged_metrics,
    )
    wandb.log({"Classes": training_table})

    # Confusion matrix, ROC-AUC curve and precision-recall curve; the image
    # logging below expects the resulting files under src/plots/.
    for make_plot in (get_confusion_matrix, get_roc_auc_curve, get_precision_recall_curve):
        make_plot(preds, targets, class_names=list(label_map))

    plot_files = {
        "Confusion Matrix": "confusion_matrix.png",
        "Precision-Recall Curve": "precision_recall_curve.png",
        "ROC AUC Curve": "roc_auc_curve.png",
    }
    for title, filename in plot_files.items():
        wandb.log({title: wandb.Image(f"src/plots/{filename}")})


def main():
    """Run one full experiment: load data, train, checkpoint, test and report.

    Steps, all driven by the configuration at ``CONFIG_PATH``:
      1. load the dataset and print/show a summary of its splits;
      2. build the data module (optionally with over/undersampling or
         curriculum learning) and the classifier;
      3. train with PyTorch Lightning, logging to Weights & Biases;
      4. store the label map next to the best checkpoint and upload both as a
         W&B artifact;
      5. test the best checkpoint and log metrics, tables and plots.

    Raises:
        ValueError: If class weights are combined with resampling, or if no
            best checkpoint is available after training.
    """
    # Load configuration file
    config = OmegaConf.load(CONFIG_PATH)
    training_cfg = config.training

    # Create a dedicated folder for the PureForest dataset to keep each tree species
    # organized, avoiding multiple directories in the main content folder.
    dataset_folder = Path.cwd() / config.dataset.folder
    # parents=True so a nested folder from the config does not fail.
    dataset_folder.mkdir(parents=True, exist_ok=True)

    # =========================== DATA LOADING AND PREPROCESSING ================================== #

    dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)
    print_split_summary(dataset, label_map)
    show_n_samples(dataset, config.dataset.species_folders)

    # =========================== INITIALIZING DATA AND MODEL ================================== #
    num_classes = len(label_map)
    class_weights = training_cfg.get("class_weights", None)
    resampling = training_cfg.get("oversample", None) or training_cfg.get("undersample", None)

    if class_weights and resampling:
        raise ValueError("Can't use class weights and resampling at the same time.")

    dataset_module, dataset_args = _select_dataset_strategy(training_cfg)

    datamodule = ForestDataModule(
        dataset["train"],
        dataset["val"],
        dataset["test"],
        dataset=dataset_module,
        dataset_args=dataset_args,
        batch_size=training_cfg.batch_size,
    )

    # Convert the config sequence to a plain list before building the tensor.
    loss_weight = torch.tensor(list(class_weights), dtype=torch.float) if class_weights is not None else None

    model = ClassifierModule(
        model_name=config.model.name,
        num_classes=num_classes,
        step_size=config["training"]["step_size"],
        gamma=config["training"]["gamma"],
        freeze=training_cfg.freeze,
        weight=loss_weight,
        learning_rate=training_cfg.learning_rate,
        weight_decay=training_cfg.weight_decay,
    )

    # ====================================== TRAINING ========================================== #
    device = config.device if torch.cuda.is_available() else "cpu"

    # Keep direct references to the callbacks whose state is read after training.
    print_metrics_callback = PrintMetricsCallback()
    callbacks = [print_metrics_callback]

    early_stopping_cfg = training_cfg.early_stopping
    checkpoint_monitor = early_stopping_cfg.get("monitor", "val_loss")
    checkpoint_mode = early_stopping_cfg.get("mode", "min")

    if early_stopping_cfg.apply:
        callbacks.append(
            EarlyStopping(
                monitor=checkpoint_monitor,
                patience=early_stopping_cfg.patience,
                mode=checkpoint_mode,
            )
        )

    checkpoint_callback = ModelCheckpoint(
        monitor=checkpoint_monitor,
        mode=checkpoint_mode,
        save_top_k=1,
        save_last=False,
        dirpath=training_cfg.get("checkpoint_dir", "checkpoints/"),
    )
    callbacks.append(checkpoint_callback)

    # Defaults for a run without curriculum learning; step_size is always bound
    # so the Trainer construction below cannot hit an unassigned variable.
    min_epochs = None
    step_size = 0

    curriculum_cfg = training_cfg.get("curriculum_learning", None)
    if curriculum_cfg:
        callbacks.append(
            CurriculumLearningCallback(
                curriculum_cfg.get("initial_ratio", None),
                curriculum_cfg.get("step_size", None),
                curriculum_cfg.get("class_order", None),
                dataset["train"]["labels"],
            )
        )
        # NOTE(review): this uses the LR-scheduler step size (training.step_size),
        # not curriculum_learning.step_size - confirm which one is intended.
        min_epochs = math.ceil(num_classes / curriculum_cfg.initial_ratio) * training_cfg.step_size
        # NOTE(review): presumably the dataloaders must be rebuilt whenever the
        # curriculum advances, so reload on the curriculum step size - confirm.
        step_size = curriculum_cfg.get("step_size", None) or 0

    run_name = f"{get_git_branch()}-{generate_short_hash()}"

    wandb.login(key=os.environ.get("WANDB_API_KEY"))

    wandb_logger = WandbLogger(name=run_name, project="ghost-irim", log_model=False)
    # Attach the configuration file that was actually loaded to the run.
    wandb_logger.experiment.save(CONFIG_PATH)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    trainer = Trainer(
        logger=wandb_logger,
        min_epochs=min_epochs,
        max_epochs=training_cfg.max_epochs,
        accelerator=device,
        devices=1,
        callbacks=callbacks,
        reload_dataloaders_every_n_epochs=step_size,
    )

    trainer.fit(model, datamodule)

    # ====================================== TESTING ========================================== #
    # Retrieve the best checkpoint path from the ModelCheckpoint callback
    best_ckpt_path = checkpoint_callback.best_model_path
    if not best_ckpt_path:
        raise ValueError("No ModelCheckpoint callback found or no best checkpoint available.")
    best_ckpt_score = checkpoint_callback.best_model_score

    print(f"Best checkpoint selected: {best_ckpt_path}")
    print(f"Best checkpoint score ({checkpoint_monitor}, {checkpoint_mode}): {best_ckpt_score}")

    # Persist label map next to checkpoint so inference can reuse exact class indices.
    label_map_path = Path(best_ckpt_path).with_suffix(".label_map.json")
    with open(label_map_path, "w", encoding="utf-8") as f:
        json.dump(label_map, f, indent=2)

    # Log explicit artifact for the selected best checkpoint.
    _log_best_model_artifact(
        wandb_logger,
        run_name,
        best_ckpt_path,
        label_map_path,
        metadata={
            "best_ckpt_path": str(best_ckpt_path),
            "monitor": checkpoint_monitor,
            "mode": checkpoint_mode,
            "best_model_score": _wandb_safe_metadata_value(float(best_ckpt_score) if best_ckpt_score is not None else None),
            "num_classes": num_classes,
            "model_name": config.model.name,
        },
    )

    trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

    # Training/validation curves collected by the metrics callback.
    plot_metrics(print_metrics_callback.train_metrics, print_metrics_callback.val_metrics)
    wandb.log({"Accuracy and Loss Curves": wandb.Image("src/plots/acc_loss_curves.png")})

    # Test metrics, per-class table and diagnostic plots.
    _log_test_reports(model, label_map)

    wandb.finish()

264 

265 

# Run the experiment only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()