Coverage for src/inference.py: 0%

138 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-01-07 20:17 +0000

1import os 

2from pathlib import Path 

3 

4import torch 

5import wandb 

6from omegaconf import OmegaConf 

7from tqdm import tqdm 

8from torchvision.utils import save_image 

9 

10from models.classifier_module import ClassifierModule 

11from models.segmentation_wrapper import SegmentationWrapper 

12from dataset_functions import load_dataset 

13from dataset import ForestDataset 

14from transforms import Transforms 

15from counting_functions import calculate_metrics_per_class, count_metrics 

16from visualization_functions import get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve 

17import onnx 

18import json 

19import kornia.augmentation as kaug 

20from kornia import image_to_tensor 

21import torch.nn as nn 

22 

23 

class InferenceTransform(nn.Module):
    """Turn one channel-last image into a float tensor resized to ``size``.

    Args:
        size: Target spatial size handed to ``kornia.augmentation.Resize``
            (this file passes ``(image_size, image_size)``).
    """

    def __init__(self, size):
        super().__init__()
        # keepdim=True: presumably keeps the input's rank instead of adding a
        # batch dimension — see kornia augmentation docs to confirm.
        self.resize = kaug.Resize(size=size, keepdim=True)

    @torch.no_grad()
    def forward(self, x):
        """Convert ``x`` to a float tensor, resize it and drop a leading size-1 dim.

        ``x`` is expected to be an (H, W, C) numpy array.
        """
        as_float = image_to_tensor(x, keepdim=True).float()
        # NOTE(review): squeeze(0) only changes the result when the leading
        # dimension is 1 (e.g. a single-channel image) — confirm it is needed.
        return self.resize(as_float).squeeze(0)

35 

36 

def download_checkpoint_from_wandb(artifact_path, project_name="ghost-irim"):
    """Download a W&B model artifact and return the path of its ``.ckpt`` file.

    Logs in with ``WANDB_API_KEY`` when that environment variable is set,
    starts a W&B run (``job_type="inference"``) in ``project_name`` and
    downloads ``artifact_path`` through it. The run is deliberately left open
    so later code can keep logging to ``wandb.run``.

    Args:
        artifact_path: Artifact reference passed to ``run.use_artifact``
            (presumably ``"entity/project/name:version"`` — confirm).
        project_name: W&B project in which the inference run is created.

    Returns:
        ``pathlib.Path`` of a ``.ckpt`` file in the downloaded artifact
        directory. When several exist, the lexicographically first one is
        returned so the choice is reproducible.

    Raises:
        FileNotFoundError: If the artifact contains no ``.ckpt`` file.
    """
    print(f"Downloading checkpoint from W&B: {artifact_path}")

    wandb_api_key = os.environ.get("WANDB_API_KEY")
    if wandb_api_key:
        wandb.login(key=wandb_api_key)

    # Intentionally not finished here: callers read wandb.run to log results.
    run = wandb.init(project=project_name, job_type="inference")

    artifact = run.use_artifact(artifact_path, type="model")
    artifact_dir = artifact.download()

    # Path.glob yields entries in filesystem-dependent order; sort so the
    # selected checkpoint does not change between machines or runs.
    checkpoint_files = sorted(Path(artifact_dir).glob("*.ckpt"))

    if not checkpoint_files:
        raise FileNotFoundError(f"No .ckpt file found in artifact directory: {artifact_dir}")

    checkpoint_path = checkpoint_files[0]
    if len(checkpoint_files) > 1:
        print(f"Warning: found {len(checkpoint_files)} .ckpt files in {artifact_dir}; using {checkpoint_path.name}")
    print(f"Checkpoint downloaded to: {checkpoint_path}")

    return checkpoint_path

59 

60 

def _select_device(config_device):
    """Return "cuda" if the config asks for a GPU and one is available, else "cpu"."""
    if config_device in ["gpu", "cuda"] and torch.cuda.is_available():
        return "cuda"
    return "cpu"


def _build_test_loader(config, transforms):
    """Load the dataset and return ``(test_loader, label_map)`` for its test split."""
    dataset_folder = Path.cwd() / config.dataset.folder
    dataset_folder.mkdir(parents=True, exist_ok=True)

    dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)

    test_data = dataset["test"]
    test_dataset = ForestDataset(test_data["paths"], test_data["labels"], transform=transforms)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)
    return test_loader, label_map


def _resolve_checkpoint_path(config):
    """Return the checkpoint path, downloading it from W&B as configured.

    Raises:
        FileNotFoundError: If ``inference.wandb_artifact`` is not set.
    """
    wandb_artifact = config.inference.get("wandb_artifact", None)
    if not wandb_artifact:
        # No checkpoint path exists on this branch, so the message must not
        # interpolate one (doing so raised UnboundLocalError instead).
        raise FileNotFoundError("No model checkpoint configured. Please set 'wandb_artifact' in config.yaml to download the checkpoint from W&B.")
    wandb_project = config.inference.get("wandb_project", "ghost-irim")
    return download_checkpoint_from_wandb(wandb_artifact, wandb_project)


def _export_onnx(seg_model, *, model_name, image_size, num_classes, label_map, device):
    """Export ``seg_model`` to ONNX with metadata; upload it to W&B if a run is active."""
    opset_version = 17
    dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
    onnx_path = Path("segmentation_model.onnx")
    torch.onnx.export(
        seg_model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["mask"],
        opset_version=opset_version,
        dynamic_axes={"input": {0: "batch_size"}, "mask": {0: "batch_size"}},
        do_constant_folding=True,
    )
    print(f"Exported model to {onnx_path.resolve()}")

    # Attach JSON-encoded metadata properties to the exported model.
    model_onnx = onnx.load(onnx_path)
    class_names = {index: name for name, index in label_map.items()}

    def add_meta(key, value):
        meta = model_onnx.metadata_props.add()
        meta.key = key
        meta.value = json.dumps(value)

    add_meta("model_type", "Segmentor")
    add_meta("class_names", class_names)
    # NOTE(review): 20 is presumably the expected ground resolution for the consumer — confirm units.
    add_meta("resolution", 20)
    add_meta("tiles_size", image_size)
    add_meta("tiles_overlap", 0)

    onnx.save(model_onnx, onnx_path)

    if wandb.run is None:
        print("Warning: W&B run not initialized. ONNX model not uploaded to artifacts.")
        return

    onnx_artifact = wandb.Artifact(
        name=f"segmentation-model-{model_name}",
        type="model",
        description=f"ONNX segmentation model ({model_name}, {num_classes} classes)",
        metadata={
            "model_name": model_name,
            "num_classes": num_classes,
            "image_size": image_size,
            "format": "onnx",
            "opset_version": opset_version,
        },
    )
    onnx_artifact.add_file(str(onnx_path))
    wandb.log_artifact(onnx_artifact)
    print(f"ONNX model uploaded to W&B artifacts as 'segmentation-model-{model_name}'")


def _run_inference(seg_model, test_loader, device):
    """Run ``seg_model`` over ``test_loader``; return concatenated ``(preds, targets)``.

    Raises:
        RuntimeError: If the loader yields no batches.
    """
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for imgs, labels in tqdm(test_loader):
            masks = seg_model(imgs.to(device))
            # Class scores are read from the mask's top-left pixel; presumably
            # the wrapper spreads one prediction over the whole mask — confirm.
            all_preds.append(masks[:, :, 0, 0])
            all_targets.append(labels.to(device))

    if not all_preds:
        raise RuntimeError("Test split is empty; nothing to run inference on.")
    return torch.cat(all_preds, dim=0), torch.cat(all_targets, dim=0)


def _log_metrics(all_preds, all_targets, label_map):
    """Compute summary/per-class metrics and plots and log them to the active W&B run."""
    print("Calculating and logging metrics...")

    metrics_per_experiment = count_metrics(all_targets, all_preds)
    print(f"Test Metrics: {metrics_per_experiment}")
    # One wandb.log call keeps all summary metrics on the same step
    # (every wandb.log call advances the step counter by default).
    wandb.log(dict(metrics_per_experiment))

    metrics_per_class = calculate_metrics_per_class(all_targets, all_preds)
    metric_keys = ["accuracy", "precision", "recall", "f1", "IoU"]
    per_class_rows = [[metrics_per_class[key][metric] for metric in metric_keys] for key in metrics_per_class]
    # Rows pair label_map entries with per-class metrics positionally;
    # assumes both are in the same class order — TODO confirm.
    logged_metrics = [[name, label, *row] for (name, label), row in zip(label_map.items(), per_class_rows, strict=False)]

    training_table = wandb.Table(columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"], data=logged_metrics)
    wandb.log({"Classes": training_table})

    plots_dir = Path("src/plots")
    plots_dir.mkdir(exist_ok=True, parents=True)

    class_names = list(label_map.keys())
    get_confusion_matrix(all_preds, all_targets, class_names=class_names)
    get_roc_auc_curve(all_preds, all_targets, class_names=class_names)
    get_precision_recall_curve(all_preds, all_targets, class_names=class_names)

    # Assumes the plotting helpers write these files into plots_dir — confirm.
    plot_files = {
        "Confusion Matrix": "confusion_matrix.png",
        "Precision-Recall Curve": "precision_recall_curve.png",
        "ROC AUC Curve": "roc_auc_curve.png",
    }
    for title, filename in plot_files.items():
        wandb.log({title: wandb.Image(str(plots_dir / filename))})


def main():
    """Run test-set inference for the configured classifier wrapped as a segmentor.

    Reads ``src/config.yaml`` (keys: ``device``, ``model.name``,
    ``dataset.folder``, ``dataset.species_folders``, ``inference.mask_size``
    [default 224], ``inference.wandb_artifact``, ``inference.wandb_project``
    [default "ghost-irim"], ``inference.export_onnx`` [default False]).
    It then:

    1. Loads the checkpoint from W&B.
    2. Optionally exports the model to ONNX.
    3. Runs inference on the test split.
    4. Logs metrics and plots when a W&B run is active.

    Raises:
        FileNotFoundError: If no W&B artifact is configured or it holds no checkpoint.
    """
    # =========================== CONFIG & SETUP ================================== #
    config = OmegaConf.load("src/config.yaml")

    device = _select_device(config.device)
    print(f"Using device: {device}")

    model_name = config.model.name
    mask_size = config.inference.get("mask_size", 224)
    image_size = 299 if model_name == "inception_v3" else 224
    transforms = InferenceTransform(size=(image_size, image_size))

    # =========================== DATA LOADING ===================================== #
    test_loader, label_map = _build_test_loader(config, transforms)
    num_classes = len(label_map)

    # =========================== MODEL LOADING ==================================== #
    checkpoint_path = _resolve_checkpoint_path(config)
    print(f"Loading model from: {checkpoint_path}")

    classifier = ClassifierModule.load_from_checkpoint(
        checkpoint_path,
        model_name=model_name,
        num_classes=num_classes,
    )
    classifier = classifier.to(device).eval()

    # TODO: fix normalization. mean=std=[0.5, 0.5, 0.5] were previously defined
    # here but never passed; confirm what the classifier was trained with.
    seg_model = SegmentationWrapper(
        classifier,
        mask_size=mask_size,
        mean=None,  # TODO: fix
        std=None,  # TODO: fix
        input_rescale=True,  # Expects 0-255 input, scales to 0-1 internally
    ).to(device)
    seg_model.eval()

    # =========================== EXPORT TO ONNX =================================== #
    if config.inference.get("export_onnx", False):
        _export_onnx(
            seg_model,
            model_name=model_name,
            image_size=image_size,
            num_classes=num_classes,
            label_map=label_map,
            device=device,
        )

    # =========================== INFERENCE LOOP =================================== #
    print(f"Running inference on {len(test_loader.dataset)} samples...")
    all_preds, all_targets = _run_inference(seg_model, test_loader, device)

    # =========================== METRICS & LOGGING ================================ #
    if wandb.run is None:
        print("W&B run not active. Skipping metrics logging.")
        return

    _log_metrics(all_preds, all_targets, label_map)
    # Close the run opened while downloading the checkpoint so uploads are flushed.
    wandb.finish()

228 

229 

if __name__ == "__main__":
    # Script entry point: run the full inference pipeline.
    main()