Coverage for src/inference.py: 0%

248 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-04-08 16:37 +0000

1import os 

2from pathlib import Path 

3 

4import torch 

5import wandb 

6from omegaconf import OmegaConf 

7from tqdm import tqdm 

8from torchvision.utils import save_image 

9import numpy as np 

10from PIL import Image 

11 

12from models.classifier_module import ClassifierModule 

13from models.segmentation_wrapper import SegmentationWrapper 

14from dataset_functions import load_dataset 

15from dataset import ForestDataset 

16 

17# from transforms import Transforms 

18from counting_functions import calculate_metrics_per_class, count_metrics 

19from visualization_functions import get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve 

20import onnx 

21import json 

22import kornia.augmentation as kaug 

23from kornia import image_to_tensor 

24import torch.nn as nn 

25from torch.utils.data import Dataset 

26 

27 

class InferenceTransform(nn.Module):
    """Resize-only preprocessing for inference.

    Accepts a PIL image, a numpy HWC array, or a torch tensor and returns a
    float tensor resized to the target size with the batch dim squeezed away.
    """

    def __init__(self, size):
        super().__init__()
        # keepdim=True: a 3D (unbatched) input comes back 3D from the resize op.
        self.resize = kaug.Resize(size=size, keepdim=True)

    @torch.no_grad()
    def forward(self, x):
        tensor = self._as_float_tensor(x)
        resized = self.resize(tensor)
        return resized.squeeze(0)

    @staticmethod
    def _as_float_tensor(x):
        # PIL images take the same numpy path as raw arrays.
        if isinstance(x, Image.Image):
            x = np.array(x)

        if not torch.is_tensor(x):
            return image_to_tensor(x, keepdim=True).float()

        # CHW tensor with a plausible channel count: just prepend a batch dim.
        if x.ndim == 3 and x.shape[0] in (1, 3, 4):
            return x.unsqueeze(0).float()

        # Anything else: round-trip through numpy so kornia infers the layout.
        return image_to_tensor(x.cpu().numpy(), keepdim=True).float()

49 

50 

class RawImageDataset(Dataset):
    """Dataset yielding raw, unlabeled images (for ONNX-style inference testing).

    Returns (image_tensor, filename) pairs; no ground-truth labels exist here.
    """

    def __init__(self, image_dir, transform=None, image_extensions=(".png", ".jpg", ".jpeg", ".tiff", ".tif")):
        self.image_dir = Path(image_dir)
        self.transform = transform

        # Collect files grouped by extension, each group in sorted order.
        paths = []
        for suffix in image_extensions:
            paths += sorted(self.image_dir.glob(f"*{suffix}"))
        self.image_paths = paths

        print(f"Found {len(self.image_paths)} images in {image_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]

        # Load the image, dropping any alpha channel (same as ONNX input).
        with Image.open(path) as img:
            image = img.convert("RGB") if img.mode == "RGBA" else img.copy()

        # numpy array [H, W, C] with values in [0, 255].
        array = np.array(image)

        # Either run the configured transform or fall back to a plain
        # HWC -> CHW float tensor conversion.
        if self.transform:
            tensor = self.transform(array)
        else:
            tensor = torch.from_numpy(array).permute(2, 0, 1).float()

        return tensor, str(path.name)

88 

89 

def download_checkpoint_from_wandb(artifact_path, project_name="ghost-irim"):
    """Download a model checkpoint artifact from W&B and return its local .ckpt path.

    Raises FileNotFoundError when the downloaded artifact contains no .ckpt file.
    """
    print(f"Downloading checkpoint from W&B: {artifact_path}")

    # Authenticate explicitly when an API key is present in the environment.
    api_key = os.environ.get("WANDB_API_KEY")
    if api_key:
        wandb.login(key=api_key)

    run = wandb.init(project=project_name, job_type="inference")

    artifact = run.use_artifact(artifact_path, type="model")
    artifact_dir = artifact.download()

    ckpt_files = list(Path(artifact_dir).glob("*.ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in artifact directory: {artifact_dir}")

    # Use the first checkpoint found (artifacts are expected to hold one).
    checkpoint_path = ckpt_files[0]
    print(f"Checkpoint downloaded to: {checkpoint_path}")

    run.finish()

    return checkpoint_path

114 

115 

def main():
    """Run test-set (or raw-image) inference with the segmentation wrapper.

    Loads config from src/config.yaml, restores the classifier checkpoint
    (locally or from a W&B artifact), optionally exports the wrapped model to
    ONNX, runs inference, and either reports metrics (labeled data) or writes
    per-image predictions (raw images).

    Fix: several output strings had lost their ``{filename}`` interpolation
    (they emitted a literal placeholder instead of the image's file name);
    restored in the W&B plot logging and the raw-prediction outputs below.
    """
    # =========================== CONFIG & SETUP ================================== #
    config = OmegaConf.load("src/config.yaml")

    config_device = config.device
    if config_device in ["gpu", "cuda"]:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model_name = config.model.name
    mask_size = config.inference.get("mask_size", 224)
    # inception_v3 requires 299x299 inputs; every other backbone uses 224x224.
    image_size = 299 if model_name == "inception_v3" else 224
    transforms = InferenceTransform(size=(image_size, image_size))

    # =========================== DATA LOADING ===================================== #
    use_raw_images = config.inference.get("use_raw_images", False)
    print(f"Use raw images for inference: {use_raw_images}")

    if use_raw_images:
        # Load raw images from the configured directory (no ground-truth labels).
        raw_images_path = config.inference.get("raw_images_path", "test_images_qgis/10122025_193353")
        raw_images_dir = Path.cwd() / raw_images_path

        if not raw_images_dir.exists():
            raise FileNotFoundError(f"Raw images directory not found: {raw_images_dir}")

        print(f"\n{'=' * 60}")
        print("Running inference on RAW IMAGES (ONNX-style testing)")
        print(f"Directory: {raw_images_dir}")
        print(f"{'=' * 60}\n")

        test_dataset = RawImageDataset(raw_images_dir, transform=transforms)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

        # The label map is still needed to size the model's output layer,
        # even though no ground-truth labels are available on this path.
        dataset_folder = Path.cwd() / config.dataset.folder
        dataset_folder.mkdir(exist_ok=True)
        dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)
        num_classes = len(label_map)
        has_labels = False
    else:
        # Load the labeled test split.
        dataset_folder = Path.cwd() / config.dataset.folder
        dataset_folder.mkdir(exist_ok=True)

        dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)

        test_data = dataset["test"]
        test_dataset = ForestDataset(test_data["paths"], test_data["labels"], transform=transforms)

        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

        num_classes = len(label_map)
        has_labels = True

    # =========================== MODEL LOADING ==================================== #
    local_checkpoint_path = Path(config.inference.get("checkpoint_path", "checkpoints/model.ckpt"))
    wandb_artifact = config.inference.get("wandb_artifact", None)

    # Prefer a local checkpoint; fall back to a W&B artifact if configured.
    if local_checkpoint_path.exists():
        checkpoint_path = local_checkpoint_path
    elif wandb_artifact:
        wandb_project = config.inference.get("wandb_project", "ghost-irim")
        checkpoint_path = download_checkpoint_from_wandb(wandb_artifact, wandb_project)
    else:
        raise FileNotFoundError(f"Checkpoint not found at {local_checkpoint_path}. Set inference.checkpoint_path to a local .ckpt path or set inference.wandb_artifact in config.yaml.")

    print(f"Loading model from: {checkpoint_path}")

    # Reuse training-time class-index mapping if it exists next to the checkpoint.
    label_map_path = checkpoint_path.with_suffix(".label_map.json")
    if label_map_path.exists():
        with open(label_map_path, "r") as f:
            label_map = json.load(f)
        num_classes = len(label_map)
        print(f"Loaded label map from: {label_map_path}")

    classifier = ClassifierModule.load_from_checkpoint(
        checkpoint_path,
        model_name=model_name,
        num_classes=num_classes,
    )
    classifier = classifier.to(device).eval()

    # ImageNet normalization constants, applied inside the wrapper.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    seg_model = SegmentationWrapper(
        classifier,
        mask_size=mask_size,
        mean=mean,  # TODO: fix
        std=std,  # TODO: fix
        input_rescale=True,  # Expects 0-255 input, scales to 0-1 internally
    ).to(device)
    seg_model.eval()

    # =========================== EXPORT TO ONNX =================================== #
    if config.inference.get("export_onnx", False):
        dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
        onnx_path = Path("segmentation_model.onnx")
        torch.onnx.export(
            seg_model,
            dummy_input,
            onnx_path,
            input_names=["input"],
            output_names=["mask"],
            opset_version=17,
            dynamic_axes={"input": {0: "batch_size"}, "mask": {0: "batch_size"}},
            do_constant_folding=True,
        )
        print(f"Exported model to {onnx_path.resolve()}")

        # Attach deployment metadata (class names, tiling parameters) to the
        # ONNX file so downstream consumers need no side-channel config.
        model_onnx = onnx.load(onnx_path)

        class_names = {v: k for k, v in label_map.items()}

        def add_meta(key, value):
            # Values are JSON-encoded so consumers can round-trip them.
            meta = model_onnx.metadata_props.add()
            meta.key = key
            meta.value = json.dumps(value)

        add_meta("model_type", "Segmentor")
        add_meta("class_names", class_names)
        add_meta("resolution", 20)
        add_meta("tiles_size", image_size)
        add_meta("tiles_overlap", 0)

        onnx.save(model_onnx, onnx_path)

        if wandb.run is not None:
            onnx_artifact = wandb.Artifact(
                name=f"segmentation-model-{model_name}",
                type="model",
                description=f"ONNX segmentation model ({model_name}, {num_classes} classes)",
                metadata={
                    "model_name": model_name,
                    "num_classes": num_classes,
                    "image_size": image_size,
                    "format": "onnx",
                    "opset_version": 17,
                },
            )
            onnx_artifact.add_file(str(onnx_path))
            wandb.log_artifact(onnx_artifact)
            print(f"ONNX model uploaded to W&B artifacts as 'segmentation-model-{model_name}'")
        else:
            print("Warning: W&B run not initialized. ONNX model not uploaded to artifacts.")

    # =========================== INFERENCE LOOP =================================== #
    print(f"Running inference on {len(test_loader)} samples...")
    all_preds = []
    all_targets = []
    all_filenames = []

    # Debug: check first batch input values
    debug_first_batch = True

    with torch.no_grad():
        for batch in tqdm(test_loader):
            if has_labels:
                imgs, labels = batch
                imgs = imgs.to(device)
                labels = labels.to(device)
                all_targets.append(labels)
            else:
                imgs, filenames = batch
                imgs = imgs.to(device)
                all_filenames.extend(filenames)

            # Debug: print input statistics for the first batch only.
            if debug_first_batch:
                print("\n[DEBUG] Input tensor stats:")
                print(f" Shape: {imgs.shape}")
                print(f" dtype: {imgs.dtype}")
                print(f" Min: {imgs.min().item():.4f}, Max: {imgs.max().item():.4f}")
                print(f" Mean: {imgs.mean().item():.4f}, Std: {imgs.std().item():.4f}")

                # Mirror what SegmentationWrapper does internally after rescale.
                if seg_model.input_rescale:
                    rescaled = imgs / 255.0
                    print("\n[DEBUG] After /255 rescale:")
                    print(f" Min: {rescaled.min().item():.4f}, Max: {rescaled.max().item():.4f}")

                    normalized = (rescaled - seg_model.mean) / seg_model.std
                    print("\n[DEBUG] After normalization:")
                    print(f" Min: {normalized.min().item():.4f}, Max: {normalized.max().item():.4f}")

                debug_first_batch = False

            masks = seg_model(imgs)

            # Per-class probabilities read from the mask's top-left pixel
            # (the wrapper broadcasts the classifier output spatially).
            probs = masks[:, :, 0, 0]

            all_preds.append(probs)

    all_preds = torch.cat(all_preds, dim=0)
    if has_labels:
        all_targets = torch.cat(all_targets, dim=0)

    # =========================== METRICS & LOGGING ================================ #
    if has_labels:
        metrics_per_experiment = count_metrics(all_targets, all_preds)
        print("\n" + "=" * 60)
        print("LOCAL TEST METRICS")
        print("=" * 60)
        for key, value in metrics_per_experiment.items():
            print(f"{key}: {value:.6f}")

        metrics_per_class = calculate_metrics_per_class(all_targets, all_preds)
        print("\nPer-class metrics:")
        reverse_label_map = {v: k for k, v in label_map.items()}
        for class_id in sorted(metrics_per_class.keys()):
            class_name = reverse_label_map.get(class_id, str(class_id))
            class_metrics = metrics_per_class[class_id]
            print(
                f"- {class_name}: "
                f"acc={class_metrics['accuracy']:.4f}, "
                f"prec={class_metrics['precision']:.4f}, "
                f"rec={class_metrics['recall']:.4f}, "
                f"f1={class_metrics['f1']:.4f}, "
                f"iou={class_metrics['IoU']:.4f}"
            )
        print("=" * 60)

        names_and_labels = [[key, value] for key, value in label_map.items()]
        accs = [metrics_per_class[key]["accuracy"] for key in metrics_per_class.keys()]
        precs = [metrics_per_class[key]["precision"] for key in metrics_per_class.keys()]
        recs = [metrics_per_class[key]["recall"] for key in metrics_per_class.keys()]
        f1s = [metrics_per_class[key]["f1"] for key in metrics_per_class.keys()]
        ious = [metrics_per_class[key]["IoU"] for key in metrics_per_class.keys()]
        logged_metrics = [[name, label, acc, prec, rec, f1, iou] for [name, label], acc, prec, rec, f1, iou in zip(names_and_labels, accs, precs, recs, f1s, ious, strict=False)]

        plots_dir = Path("src/plots")
        plots_dir.mkdir(exist_ok=True, parents=True)
        get_confusion_matrix(all_preds, all_targets, class_names=list(label_map.keys()))
        get_roc_auc_curve(all_preds, all_targets, class_names=list(label_map.keys()))
        get_precision_recall_curve(all_preds, all_targets, class_names=list(label_map.keys()))

        if wandb.run is not None:
            for key, value in metrics_per_experiment.items():
                wandb.log({key: value})

            training_table = wandb.Table(columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"], data=logged_metrics)
            wandb.log({"Classes": training_table})

            filenames = ["confusion_matrix.png", "precision_recall_curve.png", "roc_auc_curve.png"]
            titles = ["Confusion Matrix", "Precision-Recall Curve", "ROC AUC Curve"]
            for filename, title in zip(filenames, titles, strict=False):
                # FIX: interpolate the plot filename (was a literal placeholder).
                wandb.log({title: wandb.Image(f"src/plots/{filename}")})
        else:
            print("W&B run not active. Logged metrics locally only.")
    else:
        # For raw images, just output predictions
        print("\n" + "=" * 60)
        print("RAW IMAGE INFERENCE RESULTS")
        print("=" * 60)

        # Predicted class index per image.
        pred_classes = torch.argmax(all_preds, dim=1)

        # Map class index back to species name.
        reverse_label_map = {v: k for k, v in label_map.items()}

        # Save predictions to a tab-separated file.
        output_dir = Path("src/plots/raw_predictions")
        output_dir.mkdir(exist_ok=True, parents=True)

        results_file = output_dir / "predictions.txt"
        with open(results_file, "w") as f:
            f.write("Filename\tPredicted_Class\tPredicted_Label\tConfidences\n")
            for i, (filename, pred_class, probs) in enumerate(zip(all_filenames, pred_classes, all_preds)):
                pred_label = reverse_label_map[pred_class.item()]
                confidence_str = "\t".join([f"{reverse_label_map[j]}:{probs[j].item():.4f}" for j in range(num_classes)])
                # FIX: interpolate the image filename (was a literal placeholder).
                f.write(f"{filename}\t{pred_class.item()}\t{pred_label}\t{confidence_str}\n")

                # Print first 10 predictions to the console.
                if i < 10:
                    # FIX: interpolate the image filename (was a literal placeholder).
                    print(f"{filename}: {pred_label} (confidence: {probs[pred_class].item():.4f})")

        print("\n...")
        print(f"\nFull predictions saved to: {results_file}")
        print(f"Total images processed: {len(all_filenames)}")
        print("=" * 60)

402 

403 

# Script entry point: run inference only when executed directly, not on import.
if __name__ == "__main__":
    main()