Coverage for src/inference.py: 0%
248 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-04-08 16:37 +0000
1import os
2from pathlib import Path
4import torch
5import wandb
6from omegaconf import OmegaConf
7from tqdm import tqdm
8from torchvision.utils import save_image
9import numpy as np
10from PIL import Image
12from models.classifier_module import ClassifierModule
13from models.segmentation_wrapper import SegmentationWrapper
14from dataset_functions import load_dataset
15from dataset import ForestDataset
17# from transforms import Transforms
18from counting_functions import calculate_metrics_per_class, count_metrics
19from visualization_functions import get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve
20import onnx
21import json
22import kornia.augmentation as kaug
23from kornia import image_to_tensor
24import torch.nn as nn
25from torch.utils.data import Dataset
class InferenceTransform(nn.Module):
    """Resize-only preprocessing used at inference time.

    Accepts a PIL image (dataset path), a numpy array (raw-image path), or a
    torch tensor, and returns a float CHW tensor resized to ``size``.
    No normalization happens here; the model wrapper handles scaling.
    """

    def __init__(self, size):
        super().__init__()
        # keepdim=True so kornia does not add/remove a batch dimension itself.
        self.resize = kaug.Resize(size=size, keepdim=True)

    @torch.no_grad()
    def forward(self, x):
        # Normalize the input to a float tensor first, whatever its type.
        if isinstance(x, Image.Image):
            x = np.array(x)

        if not torch.is_tensor(x):
            # numpy HWC image -> CHW tensor.
            prepared = image_to_tensor(x, keepdim=True).float()
        elif x.ndim == 3 and x.shape[0] in (1, 3, 4):
            # Already a CHW tensor; add a batch dim for the resize op.
            prepared = x.unsqueeze(0).float()
        else:
            # Tensor in an unexpected layout: round-trip through numpy.
            prepared = image_to_tensor(x.cpu().numpy(), keepdim=True).float()

        return self.resize(prepared).squeeze(0)
class RawImageDataset(Dataset):
    """Dataset for loading raw images without labels (for ONNX-style inference testing)"""

    def __init__(self, image_dir, transform=None, image_extensions=(".png", ".jpg", ".jpeg", ".tiff", ".tif")):
        self.image_dir = Path(image_dir)
        self.transform = transform

        # Collect image files, sorted within each extension group.
        self.image_paths = []
        for suffix in image_extensions:
            self.image_paths += sorted(self.image_dir.glob(f"*{suffix}"))

        print(f"Found {len(self.image_paths)} images in {image_dir}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        path = self.image_paths[idx]

        # Load the image; drop the alpha channel if present (ONNX-style input).
        with Image.open(path) as img:
            image = img.convert("RGB") if img.mode == "RGBA" else img.copy()

        # HWC uint8 array in [0, 255].
        pixels = np.array(image)

        # Either run the user-supplied transform or fall back to a raw CHW tensor.
        if self.transform:
            tensor = self.transform(pixels)
        else:
            tensor = torch.from_numpy(pixels).permute(2, 0, 1).float()

        return tensor, str(path.name)
def download_checkpoint_from_wandb(artifact_path, project_name="ghost-irim"):
    """Download a model-checkpoint artifact from W&B and return the local .ckpt path.

    Logs in with WANDB_API_KEY when set, starts a short-lived inference run,
    downloads the artifact, and returns the first *.ckpt file it contains.

    Raises:
        FileNotFoundError: if the artifact holds no .ckpt file.
    """
    print(f"Downloading checkpoint from W&B: {artifact_path}")

    # Non-interactive login when a key is available in the environment.
    api_key = os.environ.get("WANDB_API_KEY")
    if api_key:
        wandb.login(key=api_key)

    run = wandb.init(project=project_name, job_type="inference")
    artifact = run.use_artifact(artifact_path, type="model")
    artifact_dir = artifact.download()

    ckpt_files = list(Path(artifact_dir).glob("*.ckpt"))
    if not ckpt_files:
        raise FileNotFoundError(f"No .ckpt file found in artifact directory: {artifact_dir}")

    checkpoint_path = ckpt_files[0]
    print(f"Checkpoint downloaded to: {checkpoint_path}")

    run.finish()
    return checkpoint_path
def main():
    """Run inference with a trained classifier wrapped as a segmentor.

    Driven entirely by src/config.yaml. Two data modes:
      - labeled test split (default): computes global/per-class metrics,
        saves plots, and logs everything to W&B when a run is active;
      - raw image directory (inference.use_raw_images): writes per-image
        predictions to src/plots/raw_predictions/predictions.txt.

    Optionally exports the segmentation wrapper to ONNX with embedded
    class-name metadata and uploads it as a W&B artifact.
    """
    # =========================== CONFIG & SETUP ================================== #
    config = OmegaConf.load("src/config.yaml")

    config_device = config.device
    if config_device in ["gpu", "cuda"]:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = "cpu"
    print(f"Using device: {device}")

    model_name = config.model.name
    mask_size = config.inference.get("mask_size", 224)
    # inception_v3 requires 299x299 inputs; other backbones use 224x224.
    image_size = 299 if model_name == "inception_v3" else 224
    transforms = InferenceTransform(size=(image_size, image_size))

    # =========================== DATA LOADING ===================================== #
    use_raw_images = config.inference.get("use_raw_images", False)
    print(f"Use raw images for inference: {use_raw_images}")

    if use_raw_images:
        # Load raw images from the configured directory (no ground-truth labels).
        raw_images_path = config.inference.get("raw_images_path", "test_images_qgis/10122025_193353")
        raw_images_dir = Path.cwd() / raw_images_path

        if not raw_images_dir.exists():
            raise FileNotFoundError(f"Raw images directory not found: {raw_images_dir}")

        print(f"\n{'=' * 60}")
        print("Running inference on RAW IMAGES (ONNX-style testing)")
        print(f"Directory: {raw_images_dir}")
        print(f"{'=' * 60}\n")

        test_dataset = RawImageDataset(raw_images_dir, transform=transforms)
        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

        # Load label map for model initialization (but no ground truth labels available)
        dataset_folder = Path.cwd() / config.dataset.folder
        dataset_folder.mkdir(exist_ok=True)
        dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)
        num_classes = len(label_map)
        has_labels = False
    else:
        # Load the labeled test split.
        dataset_folder = Path.cwd() / config.dataset.folder
        dataset_folder.mkdir(exist_ok=True)

        dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)

        test_data = dataset["test"]
        test_dataset = ForestDataset(test_data["paths"], test_data["labels"], transform=transforms)

        test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=2)

        num_classes = len(label_map)
        has_labels = True

    # =========================== MODEL LOADING ==================================== #
    local_checkpoint_path = Path(config.inference.get("checkpoint_path", "checkpoints/model.ckpt"))
    wandb_artifact = config.inference.get("wandb_artifact", None)

    # Prefer a local checkpoint; fall back to a W&B artifact when configured.
    if local_checkpoint_path.exists():
        checkpoint_path = local_checkpoint_path
    elif wandb_artifact:
        wandb_project = config.inference.get("wandb_project", "ghost-irim")
        checkpoint_path = download_checkpoint_from_wandb(wandb_artifact, wandb_project)
    else:
        raise FileNotFoundError(f"Checkpoint not found at {local_checkpoint_path}. Set inference.checkpoint_path to a local .ckpt path or set inference.wandb_artifact in config.yaml.")

    print(f"Loading model from: {checkpoint_path}")

    # Reuse training-time class-index mapping if it exists next to the checkpoint.
    label_map_path = checkpoint_path.with_suffix(".label_map.json")
    if label_map_path.exists():
        with open(label_map_path, "r") as f:
            label_map = json.load(f)
        num_classes = len(label_map)
        print(f"Loaded label map from: {label_map_path}")

    classifier = ClassifierModule.load_from_checkpoint(
        checkpoint_path,
        model_name=model_name,
        num_classes=num_classes,
    )
    classifier = classifier.to(device).eval()

    # ImageNet normalization constants, applied inside the wrapper.
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]

    seg_model = SegmentationWrapper(
        classifier,
        mask_size=mask_size,
        mean=mean,  # TODO: fix
        std=std,  # TODO: fix
        input_rescale=True,  # Expects 0-255 input, scales to 0-1 internally
    ).to(device)
    seg_model.eval()

    # =========================== EXPORT TO ONNX =================================== #
    if config.inference.get("export_onnx", False):
        dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
        onnx_path = Path("segmentation_model.onnx")
        torch.onnx.export(
            seg_model,
            dummy_input,
            onnx_path,
            input_names=["input"],
            output_names=["mask"],
            opset_version=17,
            dynamic_axes={"input": {0: "batch_size"}, "mask": {0: "batch_size"}},
            do_constant_folding=True,
        )
        print(f"Exported model to {onnx_path.resolve()}")

        # Embed consumer-facing metadata (class names, tiling parameters) in the model.
        model_onnx = onnx.load(onnx_path)

        class_names = {v: k for k, v in label_map.items()}

        def add_meta(key, value):
            # Each metadata value is stored JSON-encoded.
            meta = model_onnx.metadata_props.add()
            meta.key = key
            meta.value = json.dumps(value)

        add_meta("model_type", "Segmentor")
        add_meta("class_names", class_names)
        add_meta("resolution", 20)
        add_meta("tiles_size", image_size)
        add_meta("tiles_overlap", 0)

        onnx.save(model_onnx, onnx_path)

        if wandb.run is not None:
            onnx_artifact = wandb.Artifact(
                name=f"segmentation-model-{model_name}",
                type="model",
                description=f"ONNX segmentation model ({model_name}, {num_classes} classes)",
                metadata={
                    "model_name": model_name,
                    "num_classes": num_classes,
                    "image_size": image_size,
                    "format": "onnx",
                    "opset_version": 17,
                },
            )
            onnx_artifact.add_file(str(onnx_path))
            wandb.log_artifact(onnx_artifact)
            print(f"ONNX model uploaded to W&B artifacts as 'segmentation-model-{model_name}'")
        else:
            print("Warning: W&B run not initialized. ONNX model not uploaded to artifacts.")

    # =========================== INFERENCE LOOP =================================== #
    print(f"Running inference on {len(test_loader)} samples...")
    all_preds = []
    all_targets = []
    all_filenames = []

    # Debug: check first batch input values
    debug_first_batch = True

    with torch.no_grad():
        for batch in tqdm(test_loader):
            if has_labels:
                imgs, labels = batch
                imgs = imgs.to(device)
                labels = labels.to(device)
                all_targets.append(labels)
            else:
                imgs, filenames = batch
                imgs = imgs.to(device)
                all_filenames.extend(filenames)

            # Debug: print input statistics for first batch
            if debug_first_batch:
                print(f"\n[DEBUG] Input tensor stats:")
                print(f"  Shape: {imgs.shape}")
                print(f"  dtype: {imgs.dtype}")
                print(f"  Min: {imgs.min().item():.4f}, Max: {imgs.max().item():.4f}")
                print(f"  Mean: {imgs.mean().item():.4f}, Std: {imgs.std().item():.4f}")

                # Check what SegmentationWrapper receives after rescale
                if seg_model.input_rescale:
                    rescaled = imgs / 255.0
                    print(f"\n[DEBUG] After /255 rescale:")
                    print(f"  Min: {rescaled.min().item():.4f}, Max: {rescaled.max().item():.4f}")

                    normalized = (rescaled - seg_model.mean) / seg_model.std
                    print(f"\n[DEBUG] After normalization:")
                    print(f"  Min: {normalized.min().item():.4f}, Max: {normalized.max().item():.4f}")

                debug_first_batch = False

            masks = seg_model(imgs)

            # Per-class scores taken from one mask pixel; assumes the mask is
            # spatially uniform per tile — TODO confirm against SegmentationWrapper.
            probs = masks[:, :, 0, 0]

            all_preds.append(probs)

    all_preds = torch.cat(all_preds, dim=0)
    if has_labels:
        all_targets = torch.cat(all_targets, dim=0)

    # =========================== METRICS & LOGGING ================================ #
    if has_labels:
        metrics_per_experiment = count_metrics(all_targets, all_preds)
        print("\n" + "=" * 60)
        print("LOCAL TEST METRICS")
        print("=" * 60)
        for key, value in metrics_per_experiment.items():
            print(f"{key}: {value:.6f}")

        metrics_per_class = calculate_metrics_per_class(all_targets, all_preds)
        print("\nPer-class metrics:")
        reverse_label_map = {v: k for k, v in label_map.items()}
        for class_id in sorted(metrics_per_class.keys()):
            class_name = reverse_label_map.get(class_id, str(class_id))
            class_metrics = metrics_per_class[class_id]
            print(
                f"- {class_name}: "
                f"acc={class_metrics['accuracy']:.4f}, "
                f"prec={class_metrics['precision']:.4f}, "
                f"rec={class_metrics['recall']:.4f}, "
                f"f1={class_metrics['f1']:.4f}, "
                f"iou={class_metrics['IoU']:.4f}"
            )
        print("=" * 60)

        # Build one row per class for the W&B table (assumes label_map iteration
        # order matches metrics_per_class key order — TODO confirm).
        names_and_labels = [[key, value] for key, value in label_map.items()]
        accs = [metrics_per_class[key]["accuracy"] for key in metrics_per_class.keys()]
        precs = [metrics_per_class[key]["precision"] for key in metrics_per_class.keys()]
        recs = [metrics_per_class[key]["recall"] for key in metrics_per_class.keys()]
        f1s = [metrics_per_class[key]["f1"] for key in metrics_per_class.keys()]
        ious = [metrics_per_class[key]["IoU"] for key in metrics_per_class.keys()]
        logged_metrics = [[name, label, acc, prec, rec, f1, iou] for [name, label], acc, prec, rec, f1, iou in zip(names_and_labels, accs, precs, recs, f1s, ious, strict=False)]

        plots_dir = Path("src/plots")
        plots_dir.mkdir(exist_ok=True, parents=True)
        get_confusion_matrix(all_preds, all_targets, class_names=list(label_map.keys()))
        get_roc_auc_curve(all_preds, all_targets, class_names=list(label_map.keys()))
        get_precision_recall_curve(all_preds, all_targets, class_names=list(label_map.keys()))

        if wandb.run is not None:
            for key, value in metrics_per_experiment.items():
                wandb.log({key: value})

            training_table = wandb.Table(columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"], data=logged_metrics)
            wandb.log({"Classes": training_table})

            filenames = ["confusion_matrix.png", "precision_recall_curve.png", "roc_auc_curve.png"]
            titles = ["Confusion Matrix", "Precision-Recall Curve", "ROC AUC Curve"]
            for filename, title in zip(filenames, titles, strict=False):
                # BUG FIX: interpolate the plot filename into the path (the
                # placeholder was previously missing, logging a bogus path).
                wandb.log({title: wandb.Image(f"src/plots/{filename}")})
        else:
            print("W&B run not active. Logged metrics locally only.")
    else:
        # For raw images, just output predictions
        print("\n" + "=" * 60)
        print("RAW IMAGE INFERENCE RESULTS")
        print("=" * 60)

        # Get predicted classes
        pred_classes = torch.argmax(all_preds, dim=1)

        # Create reverse label map
        reverse_label_map = {v: k for k, v in label_map.items()}

        # Save predictions to file
        output_dir = Path("src/plots/raw_predictions")
        output_dir.mkdir(exist_ok=True, parents=True)

        results_file = output_dir / "predictions.txt"
        with open(results_file, "w") as f:
            f.write("Filename\tPredicted_Class\tPredicted_Label\tConfidences\n")
            for i, (filename, pred_class, probs) in enumerate(zip(all_filenames, pred_classes, all_preds, strict=False)):
                pred_label = reverse_label_map[pred_class.item()]
                confidence_str = "\t".join([f"{reverse_label_map[j]}:{probs[j].item():.4f}" for j in range(num_classes)])
                # BUG FIX: write the image filename into the row (the
                # placeholder was previously missing from the f-string).
                f.write(f"{filename}\t{pred_class.item()}\t{pred_label}\t{confidence_str}\n")

                # Print first 10 predictions
                if i < 10:
                    # BUG FIX: include the filename in the console preview.
                    print(f"{filename}: {pred_label} (confidence: {probs[pred_class].item():.4f})")

        print("\n...")
        print(f"\nFull predictions saved to: {results_file}")
        print(f"Total images processed: {len(all_filenames)}")
        print("=" * 60)
# Script entry point: run the full inference pipeline when executed directly.
if __name__ == "__main__":
    main()