Coverage for src/main.py: 0%

105 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-01 17:37 +0000

1import os 

2from pathlib import Path 

3 

4import kornia.augmentation as kaug 

5import torch 

6import wandb 

7from omegaconf import OmegaConf 

8from pytorch_lightning import Trainer 

9from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint 

10from pytorch_lightning.loggers import WandbLogger 

11from dataset import ForestDataModule, ForestDataset, OversampledDataset, UndersampledDataset, CurriculumLearningDataset 

12from callbacks import PrintMetricsCallback, CurriculumLearningCallback 

13from models.classifier_module import ClassifierModule 

14from dataset_functions import load_dataset 

15from git_functions import get_git_branch, generate_short_hash 

16from counting_functions import calculate_metrics_per_class, count_metrics 

17from visualization_functions import show_n_samples, plot_metrics, get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve 

18 

19import torchvision 

20import math 

21 

# Path of the experiment configuration file, relative to the repository root
# (the script is expected to be launched from the repo root — see Path.cwd()
# usage in main()).
CONFIG_PATH = "src/config.yaml"

23 

24 

def main():
    """Train, evaluate and log a tree-species classifier on the PureForest dataset.

    Pipeline:
      1. Load the OmegaConf experiment config from ``CONFIG_PATH``.
      2. Build the dataset / datamodule, optionally with oversampling,
         undersampling or curriculum learning (mutually exclusive branches).
      3. Train a ``ClassifierModule`` with PyTorch Lightning.
      4. Evaluate the best checkpoint on the test split.
      5. Log metrics, per-class tables and diagnostic plots to Weights & Biases.

    Raises:
        ValueError: if class weights are combined with resampling, or if no
            best checkpoint is available after training.
    """
    # Load configuration file
    config = OmegaConf.load(CONFIG_PATH)

    # Create a dedicated folder for the PureForest dataset to keep each tree species
    # organized, avoiding multiple directories in the main content folder.
    dataset_folder = Path.cwd() / config.dataset.folder
    dataset_folder.mkdir(exist_ok=True)

    # =========================== DATA LOADING AND PREPROCESSING ================================== #

    dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)
    show_n_samples(dataset, config.dataset.species_folders)

    # =========================== INITIALIZING DATA AND MODEL ================================== #
    num_classes = len(label_map)
    class_weights = config.training.get("class_weights", None)

    # Class weights and resampling both compensate for class imbalance;
    # enabling both would double-correct, so it is rejected up front.
    if class_weights and (config.training.get("oversample", None) or config.training.get("undersample", None)):
        raise ValueError("Can't use class weights and resampling at the same time.")

    # inception_v3 requires 299x299 inputs; the other backbones use 224x224.
    image_size = 299 if config.model.name == "inception_v3" else 224
    transforms = kaug.Resize(size=(image_size, image_size))

    dataset_module = ForestDataset
    dataset_args = {}

    if config.training.get("oversample", None):
        dataset_module = OversampledDataset
        dataset_args = {
            # Augment minority-class samples so the oversampled copies are not
            # exact duplicates of the originals.
            "minority_transform": torchvision.transforms.Compose(
                [
                    torchvision.transforms.RandomHorizontalFlip(),
                    torchvision.transforms.RandomVerticalFlip(),
                    torchvision.transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
                ]
            ),
            "oversample_factor": config.training.oversample.oversample_factor,
            "oversample_threshold": config.training.oversample.oversample_threshold,
        }
    elif config.training.get("undersample", None):
        dataset_module = UndersampledDataset
        dataset_args = {"target_size": config.training.undersample.target_size}
    elif config.training.get("curriculum_learning", None):
        dataset_module = CurriculumLearningDataset
        dataset_args = {
            # The list cannot be empty, since the dataloader doesn't accept an empty dataset
            "indices": [0]
        }

    datamodule = ForestDataModule(
        dataset["train"],
        dataset["val"],
        dataset["test"],
        dataset=dataset_module,
        dataset_args=dataset_args,
        batch_size=config.training.batch_size,
    )

    model = ClassifierModule(
        model_name=config.model.name,
        num_classes=num_classes,
        step_size=config.training.step_size,
        gamma=config.training.gamma,
        freeze=config.training.freeze,
        transform=transforms,
        weight=torch.tensor(class_weights, dtype=torch.float) if class_weights is not None else None,
        learning_rate=config.training.learning_rate,
        weight_decay=config.training.weight_decay,
    )

    # ====================================== TRAINING ========================================== #
    device = config.device if torch.cuda.is_available() else "cpu"
    callbacks = [PrintMetricsCallback()]

    if config.training.early_stopping.apply:
        callbacks.append(
            EarlyStopping(
                monitor=config.training.early_stopping.monitor,
                patience=config.training.early_stopping.patience,
                mode=config.training.early_stopping.mode,
            )
        )
    # BUGFIX: ModelCheckpoint was previously registered only together with
    # early stopping, yet the test phase below hard-requires one (it raises
    # ValueError otherwise). Register it unconditionally so the best model is
    # always saved.
    callbacks.append(
        ModelCheckpoint(
            monitor="val_loss",
            mode="min",
            save_top_k=1,
            save_last=False,
            dirpath=config.training.get("checkpoint_dir", "checkpoints/"),
        )
    )

    if config.training.get("curriculum_learning", None):
        callbacks.append(
            CurriculumLearningCallback(
                config.training.curriculum_learning.get("initial_ratio", None),
                config.training.curriculum_learning.get("step_size", None),
                config.training.curriculum_learning.get("class_order", None),
                dataset["train"]["labels"],
            )
        )
        # Train at least long enough for the curriculum to have exposed every class.
        min_epochs = math.ceil(num_classes / config.training.curriculum_learning.initial_ratio) * config.training.step_size
        # BUGFIX: step_size was left unbound on this branch, raising NameError at
        # the Trainer() call below. Reload dataloaders on the same cadence used
        # for min_epochs so the curriculum subset is refreshed.
        # NOTE(review): presumably this should match the curriculum schedule —
        # confirm whether config.training.curriculum_learning.step_size was intended.
        step_size = config.training.step_size
    else:
        min_epochs = None
        step_size = 0  # 0 disables dataloader reloading in Lightning

    # Run name encodes the git state for traceability of each experiment.
    branch_name = get_git_branch()
    short_hash = generate_short_hash()
    run_name = f"{branch_name}-{short_hash}"

    wandb_api_key = os.environ.get("WANDB_API_KEY")
    wandb.login(key=wandb_api_key)
    wandb.init(project="ghost-irim", name=run_name)

    # Log config.yaml to wandb
    wandb.save("src/config.yaml")

    wandb_logger = WandbLogger(name=run_name, project="ghost-irim", log_model=True)

    # Let cuDNN auto-tune convolution algorithms (inputs have a fixed size).
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    trainer = Trainer(
        logger=wandb_logger,
        min_epochs=min_epochs,
        max_epochs=config.training.max_epochs,
        accelerator=device,
        devices=1,
        callbacks=callbacks,
        reload_dataloaders_every_n_epochs=step_size,
    )

    trainer.fit(model, datamodule)

    # ====================================== TESTING ========================================== #
    # Retrieve the best checkpoint path from the ModelCheckpoint callback
    best_ckpt_path = None
    for callback in callbacks:
        if isinstance(callback, ModelCheckpoint):
            best_ckpt_path = callback.best_model_path
            break

    if not best_ckpt_path:
        raise ValueError("No ModelCheckpoint callback found or no best checkpoint available.")

    trainer.test(model, datamodule=datamodule, ckpt_path=best_ckpt_path)

    # Callbacks' service: plot the accuracy/loss curves gathered during training.
    for callback in callbacks:
        if isinstance(callback, PrintMetricsCallback):
            plot_metrics(callback.train_metrics, callback.val_metrics)
            wandb.log({"Accuracy and Loss Curves": wandb.Image("src/plots/acc_loss_curves.png")})

    # Predictions/targets accumulated by the model during the test phase.
    preds = model.predictions
    targets = model.targets

    # Log aggregate metrics
    metrics_per_experiment = count_metrics(targets, preds)
    for key, value in metrics_per_experiment.items():
        wandb.log({key: value})

    # Log metrics per class and classnames
    metrics_per_class = calculate_metrics_per_class(targets, preds)
    # NOTE(review): assumes label_map and metrics_per_class iterate in the same
    # class order — confirm against calculate_metrics_per_class.
    logged_metrics = [
        [name, label, m["accuracy"], m["precision"], m["recall"], m["f1"], m["IoU"]]
        for (name, label), m in zip(label_map.items(), metrics_per_class.values(), strict=False)
    ]

    training_table = wandb.Table(columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"], data=logged_metrics)
    wandb.log({"Classes": training_table})

    # Log confusion matrix, precision-recall curve and roc-auc curve
    class_names = list(label_map.keys())
    get_confusion_matrix(preds, targets, class_names=class_names)
    get_roc_auc_curve(preds, targets, class_names=class_names)
    get_precision_recall_curve(preds, targets, class_names=class_names)

    filenames = ["confusion_matrix.png", "precision_recall_curve.png", "roc_auc_curve.png"]
    titles = ["Confusion Matrix", "Precision-Recall Curve", "ROC AUC Curve"]
    for filename, title in zip(filenames, titles, strict=False):
        # BUGFIX: the f-string previously contained no placeholder, so all three
        # panels logged the same nonexistent path; interpolate the filename.
        wandb.log({title: wandb.Image(f"src/plots/{filename}")})

    wandb.finish()

192 

193 

# Script entry point: run the full train / evaluate / log pipeline.
if __name__ == "__main__":
    main()