Coverage for src/main.py: 0%

94 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-16 12:50 +0000

1import os 

2from pathlib import Path 

3 

4import kornia.augmentation as kaug 

5import torch 

6import wandb 

7import yaml 

8from pytorch_lightning import Trainer 

9from pytorch_lightning.callbacks import EarlyStopping 

10from pytorch_lightning.loggers import WandbLogger 

11 

12from models.classifier_module import ClassifierModule 

13from dataset import ForestDataModule, ForestDataset, OversampledDataset, UndersampledDataset 

14from callbacks import PrintMetricsCallback 

15from dataset_functions import download_data, load_dataset 

16from git_functions import get_git_branch, generate_short_hash 

17from counting_functions import calculate_metrics_per_class, count_metrics 

18from visualization_functions import (show_n_samples, plot_metrics, 

19 get_confusion_matrix, 

20 get_precision_recall_curve, 

21 get_roc_auc_curve) 

22 

23import torchvision 

24 

25 

def main():
    """Train and evaluate a tree-species classifier on the PureForest dataset.

    Pipeline: load YAML config -> download/load the dataset -> build the
    Lightning datamodule and model -> train -> test -> log metrics, a
    per-class table, and diagnostic plots to Weights & Biases.
    """
    # Load configuration file
    with open("src/config.yaml", "r") as c:
        config = yaml.safe_load(c)

    # Create a dedicated folder for the PureForest dataset to keep each tree
    # species organized, avoiding multiple directories in the main content folder.
    dataset_folder = Path.cwd() / config["dataset"]["folder"]
    dataset_folder.mkdir(exist_ok=True)

    species_folders = config["dataset"]["species_folders"]
    main_subfolders = config["dataset"]["main_subfolders"]

    # =========================== DATA LOADING AND PREPROCESSING =============== #

    download_data(species_folders, main_subfolders, dataset_folder)
    dataset, label_map = load_dataset(dataset_folder, species_folders)
    show_n_samples(dataset, species_folders)

    # =========================== INITIALIZING DATA AND MODEL ================== #
    training_cfg = config["training"]  # hoisted: read many times below
    batch_size = training_cfg["batch_size"]
    num_classes = len(label_map)
    learning_rate = training_cfg["learning_rate"]
    freeze = training_cfg["freeze"]
    weight_decay = training_cfg["weight_decay"]
    model_name = config["model"]["name"]
    # inception_v3 requires 299x299 inputs; the other backbones expect 224x224.
    image_size = 299 if model_name == "inception_v3" else 224
    transforms = kaug.Resize(size=(image_size, image_size))

    dataset_module = ForestDataset
    dataset_args = {}

    if "oversample" in training_cfg:
        # Oversample minority classes, augmenting the duplicates with random
        # flips and a mild affine jitter so copies are not pixel-identical.
        dataset_module = OversampledDataset
        dataset_args = {
            "minority_transform": torchvision.transforms.Compose([
                torchvision.transforms.RandomHorizontalFlip(),
                torchvision.transforms.RandomVerticalFlip(),
                torchvision.transforms.RandomAffine(
                    degrees=30, translate=(0.1, 0.1), scale=(1, 1.2), shear=10),
            ]),
            "oversample_factor": training_cfg["oversample"]["oversample_factor"],
            "oversample_threshold": training_cfg["oversample"]["oversample_threshold"],
        }
    elif "undersample" in training_cfg:
        dataset_module = UndersampledDataset
        dataset_args = {
            "target_size": training_cfg["undersample"]["target_size"],
        }

    datamodule = ForestDataModule(
        dataset['train'],
        dataset['val'],
        dataset['test'],
        dataset=dataset_module,
        dataset_args=dataset_args,
        batch_size=batch_size,
    )

    model = ClassifierModule(
        model_name=model_name,
        num_classes=num_classes,
        freeze=freeze,
        transform=transforms,
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )

    # ====================================== TRAINING ========================== #
    max_epochs = training_cfg["max_epochs"]
    device = config["device"] if torch.cuda.is_available() else "cpu"
    callbacks = [PrintMetricsCallback()]

    early_stopping_cfg = training_cfg["early_stopping"]
    if early_stopping_cfg['apply']:
        callbacks.append(EarlyStopping(monitor=early_stopping_cfg['monitor'],
                                       patience=early_stopping_cfg['patience'],
                                       mode=early_stopping_cfg['mode']))

    # Run name encodes the git branch and a short hash for traceability.
    branch_name = get_git_branch()
    short_hash = generate_short_hash()
    run_name = f'{branch_name}-{short_hash}'

    # NOTE(review): if WANDB_API_KEY is unset this passes key=None and relies
    # on a pre-existing local wandb login — confirm that is intended.
    wandb_api_key = os.environ.get('WANDB_API_KEY')
    wandb.login(key=wandb_api_key)
    wandb.init(project="ghost-irim", name=run_name)

    # Log config.yaml to wandb
    wandb.save("src/config.yaml")

    wandb_logger = WandbLogger(
        name=run_name,
        project='ghost-irim',
        log_model=True,
    )

    # Let cuDNN benchmark and select the fastest convolution algorithms
    # (safe here because the input size is fixed by the Resize transform).
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.enabled = True

    trainer = Trainer(
        logger=wandb_logger,
        max_epochs=max_epochs,
        accelerator=device,
        devices=1,
        callbacks=callbacks,
    )

    trainer.fit(model, datamodule)

    # ====================================== TESTING =========================== #
    trainer.test(model, datamodule=datamodule)

    # Callbacks' service: plot the accuracy/loss curves gathered during training.
    for callback in callbacks:
        if isinstance(callback, PrintMetricsCallback):
            plot_metrics(callback.train_metrics, callback.val_metrics)
            wandb.log({'Accuracy and Loss Curves':
                       wandb.Image('src/plots/acc_loss_curves.png')})

    # Logging plots
    preds = model.predictions
    targets = model.targets

    # Log experiment-level metrics
    metrics_per_experiment = count_metrics(targets, preds)
    for key, value in metrics_per_experiment.items():
        wandb.log({key: value})

    # Log metrics per class alongside class names and labels.
    # label_map and metrics_per_class are assumed to be in the same class
    # order (dicts preserve insertion order).
    metrics_per_class = calculate_metrics_per_class(targets, preds)
    logged_metrics = [
        [name, label,
         m['accuracy'], m['precision'], m['recall'], m['f1'], m['IoU']]
        for (name, label), m in zip(label_map.items(), metrics_per_class.values())
    ]

    training_table = wandb.Table(
        columns=['Class name', 'Label', 'Accuracy', 'Precision',
                 'Recall', 'F1-score', 'IoU'],
        data=logged_metrics)
    wandb.log({'Classes': training_table})

    # Log confusion matrix, precision-recall curve and roc-auc curve
    class_names = list(label_map.keys())
    get_confusion_matrix(preds, targets, class_names=class_names)
    get_roc_auc_curve(preds, targets, class_names=class_names)
    get_precision_recall_curve(preds, targets, class_names=class_names)

    filenames = ['confusion_matrix.png', 'precision_recall_curve.png',
                 'roc_auc_curve.png']
    titles = ['Confusion Matrix', 'Precision-Recall Curve', 'ROC AUC Curve']
    for filename, title in zip(filenames, titles):
        # FIX: interpolate the per-plot filename; previously the f-string
        # contained a literal placeholder, so every title logged the same
        # nonexistent path and `filename` was unused.
        wandb.log({title: wandb.Image(f'src/plots/{filename}')})

    wandb.finish()

177 

178 

# Script entry point: run the full train/evaluate pipeline.
if __name__ == "__main__":
    main()