Coverage for src/inference.py: 0%
138 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-01-07 20:17 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-01-07 20:17 +0000
1import os
2from pathlib import Path
4import torch
5import wandb
6from omegaconf import OmegaConf
7from tqdm import tqdm
8from torchvision.utils import save_image
10from models.classifier_module import ClassifierModule
11from models.segmentation_wrapper import SegmentationWrapper
12from dataset_functions import load_dataset
13from dataset import ForestDataset
14from transforms import Transforms
15from counting_functions import calculate_metrics_per_class, count_metrics
16from visualization_functions import get_confusion_matrix, get_precision_recall_curve, get_roc_auc_curve
17import onnx
18import json
19import kornia.augmentation as kaug
20from kornia import image_to_tensor
21import torch.nn as nn
class InferenceTransform(nn.Module):
    """Turn one image array into a float tensor resized to ``size``.

    ``main`` passes an instance as the ``transform`` of the test ``ForestDataset``.
    """

    def __init__(self, size):
        super().__init__()
        # keepdim=True: per kornia's augmentation API this keeps the output rank
        # equal to the input rank instead of forcing a batch dim — confirm
        # against the installed kornia version.
        self.resize = kaug.Resize(size=size, keepdim=True)

    @torch.no_grad()
    def forward(self, x):
        """Return ``x`` converted to a float tensor and resized.

        ``x`` is expected to be a numpy array laid out as (H, W, C).
        """
        as_float = image_to_tensor(x, keepdim=True).float()
        resized = self.resize(as_float)
        # squeeze(0) only drops dim 0 when its size is 1.
        # NOTE(review): for a (C, H, W) result with C > 1 this is a no-op; it only
        # changes the shape of single-channel images — confirm this is intended.
        return resized.squeeze(0)
def download_checkpoint_from_wandb(artifact_path, project_name="ghost-irim"):
    """Download a W&B model artifact and return the path of its ``.ckpt`` file.

    If ``WANDB_API_KEY`` is set, it is used to log in explicitly. A W&B run is
    started with ``job_type="inference"`` and deliberately left open, because this
    module later checks ``wandb.run`` to decide whether to upload the ONNX model
    and metrics.

    Args:
        artifact_path: W&B artifact reference passed to ``run.use_artifact``.
        project_name: W&B project for the inference run.

    Returns:
        pathlib.Path: the selected checkpoint file inside the downloaded directory.

    Raises:
        FileNotFoundError: if the artifact contains no ``.ckpt`` file.
    """
    print(f"Downloading checkpoint from W&B: {artifact_path}")

    wandb_api_key = os.environ.get("WANDB_API_KEY")
    if wandb_api_key:
        wandb.login(key=wandb_api_key)

    run = wandb.init(project=project_name, job_type="inference")

    artifact = run.use_artifact(artifact_path, type="model")
    artifact_dir = artifact.download()

    # glob() yields entries in filesystem (arbitrary) order; sort so the chosen
    # checkpoint is deterministic when the artifact holds more than one.
    checkpoint_files = sorted(Path(artifact_dir).glob("*.ckpt"))

    if not checkpoint_files:
        raise FileNotFoundError(f"No .ckpt file found in artifact directory: {artifact_dir}")

    checkpoint_path = checkpoint_files[0]
    if len(checkpoint_files) > 1:
        print(f"Warning: {len(checkpoint_files)} .ckpt files found in {artifact_dir}; using {checkpoint_path.name}")
    print(f"Checkpoint downloaded to: {checkpoint_path}")

    return checkpoint_path
def _resolve_device(config_device):
    """Translate the ``device`` config value into a torch device string.

    ``"gpu"`` / ``"cuda"`` select CUDA when it is available and fall back to CPU;
    any other value selects CPU.
    """
    if config_device in ("gpu", "cuda"):
        return "cuda" if torch.cuda.is_available() else "cpu"
    return "cpu"


def _resolve_checkpoint(inference_config):
    """Return the checkpoint path to load for inference.

    A configured ``wandb_artifact`` takes precedence and is fetched with
    ``download_checkpoint_from_wandb``, which also starts a W&B run. Otherwise the
    optional ``checkpoint_path`` entry is used when it points to an existing file.

    Raises:
        FileNotFoundError: if neither source yields a checkpoint file.
    """
    wandb_artifact = inference_config.get("wandb_artifact", None)
    if wandb_artifact:
        wandb_project = inference_config.get("wandb_project", "ghost-irim")
        return download_checkpoint_from_wandb(wandb_artifact, wandb_project)

    local_checkpoint = inference_config.get("checkpoint_path", None)
    if local_checkpoint and Path(local_checkpoint).is_file():
        return Path(local_checkpoint)

    raise FileNotFoundError(
        f"Checkpoint not found at {local_checkpoint}. Please set 'wandb_artifact' in config.yaml to download from W&B, "
        "or set 'inference.checkpoint_path' to an existing local checkpoint."
    )


def _export_onnx(seg_model, onnx_path, *, image_size, device, label_map, model_name, num_classes):
    """Export ``seg_model`` to ONNX with metadata, then upload it if a W&B run is active."""
    opset_version = 17
    dummy_input = torch.randn(1, 3, image_size, image_size, device=device)
    torch.onnx.export(
        seg_model,
        dummy_input,
        onnx_path,
        input_names=["input"],
        output_names=["mask"],
        opset_version=opset_version,
        dynamic_axes={"input": {0: "batch_size"}, "mask": {0: "batch_size"}},
        do_constant_folding=True,
    )
    print(f"Exported model to {onnx_path.resolve()}")

    # Attach metadata for downstream consumers of the ONNX file.
    model_onnx = onnx.load(onnx_path)
    # label_map maps class name -> label index; store the inverse (index -> name).
    class_names = {v: k for k, v in label_map.items()}
    metadata = (
        ("model_type", "Segmentor"),
        ("class_names", class_names),
        # NOTE(review): units/meaning of resolution=20 are not defined here — confirm with the consuming tool.
        ("resolution", 20),
        ("tiles_size", image_size),
        ("tiles_overlap", 0),
    )
    for key, value in metadata:
        meta = model_onnx.metadata_props.add()
        meta.key = key
        meta.value = json.dumps(value)  # ONNX metadata props are string -> string
    onnx.save(model_onnx, onnx_path)

    if wandb.run is None:
        print("Warning: W&B run not initialized. ONNX model not uploaded to artifacts.")
        return

    onnx_artifact = wandb.Artifact(
        name=f"segmentation-model-{model_name}",
        type="model",
        description=f"ONNX segmentation model ({model_name}, {num_classes} classes)",
        metadata={
            "model_name": model_name,
            "num_classes": num_classes,
            "image_size": image_size,
            "format": "onnx",
            "opset_version": opset_version,
        },
    )
    onnx_artifact.add_file(str(onnx_path))
    wandb.log_artifact(onnx_artifact)
    print(f"ONNX model uploaded to W&B artifacts as 'segmentation-model-{model_name}'")


def _run_inference(seg_model, loader, device):
    """Run ``seg_model`` over ``loader``; return concatenated (predictions, targets).

    Raises:
        RuntimeError: if ``loader`` yields no batches.
    """
    all_preds = []
    all_targets = []
    with torch.no_grad():
        for imgs, labels in tqdm(loader):
            masks = seg_model(imgs.to(device))
            # Take the per-class values at pixel (0, 0) as the image-level scores.
            # NOTE(review): assumes SegmentationWrapper emits a spatially constant
            # per-class map — confirm in models/segmentation_wrapper.py.
            all_preds.append(masks[:, :, 0, 0])
            all_targets.append(labels.to(device))

    if not all_preds:
        raise RuntimeError("Test split is empty; nothing to run inference on.")
    return torch.cat(all_preds, dim=0), torch.cat(all_targets, dim=0)


def _log_metrics(all_preds, all_targets, label_map):
    """Log aggregate metrics, a per-class table and diagnostic plots to the active W&B run."""
    print("Calculating and logging metrics...")

    metrics_per_experiment = count_metrics(all_targets, all_preds)
    print(f"Test Metrics: {metrics_per_experiment}")
    # One call logs every scalar at the same W&B step (one call per key would advance the step each time).
    wandb.log(dict(metrics_per_experiment))

    metrics_per_class = calculate_metrics_per_class(all_targets, all_preds)
    # NOTE(review): assumes metrics_per_class iterates classes in label_map order — confirm.
    logged_metrics = [
        [name, label, m["accuracy"], m["precision"], m["recall"], m["f1"], m["IoU"]]
        for (name, label), m in zip(label_map.items(), metrics_per_class.values(), strict=False)
    ]
    training_table = wandb.Table(
        columns=["Class name", "Label", "Accuracy", "Precision", "Recall", "F1-score", "IoU"],
        data=logged_metrics,
    )
    wandb.log({"Classes": training_table})

    plots_dir = Path("src/plots")
    plots_dir.mkdir(exist_ok=True, parents=True)

    class_names = list(label_map.keys())
    get_confusion_matrix(all_preds, all_targets, class_names=class_names)
    get_roc_auc_curve(all_preds, all_targets, class_names=class_names)
    get_precision_recall_curve(all_preds, all_targets, class_names=class_names)

    # Presumably the plotting helpers above write these files into src/plots — confirm in visualization_functions.
    plot_files = {
        "Confusion Matrix": "confusion_matrix.png",
        "Precision-Recall Curve": "precision_recall_curve.png",
        "ROC AUC Curve": "roc_auc_curve.png",
    }
    for title, filename in plot_files.items():
        wandb.log({title: wandb.Image(str(plots_dir / filename))})


def main():
    """Evaluate a trained classifier, wrapped as a segmentation model, on the test split.

    Steps:
      1. Load ``src/config.yaml``.
      2. Build the test loader.
      3. Load the checkpoint, from a W&B artifact or the local ``inference.checkpoint_path``.
      4. Wrap the classifier in ``SegmentationWrapper``.
      5. Optionally export the wrapper to ONNX.
      6. Run inference.
      7. When a W&B run is active, log metrics and plots.
    """
    # =========================== CONFIG & SETUP ================================== #
    config = OmegaConf.load("src/config.yaml")

    device = _resolve_device(config.device)
    print(f"Using device: {device}")

    model_name = config.model.name
    mask_size = config.inference.get("mask_size", 224)
    # inception_v3 is run at 299x299; every other backbone at 224x224.
    image_size = 299 if model_name == "inception_v3" else 224
    transforms = InferenceTransform(size=(image_size, image_size))

    # =========================== DATA LOADING ===================================== #
    dataset_folder = Path.cwd() / config.dataset.folder
    dataset_folder.mkdir(exist_ok=True)

    dataset, label_map = load_dataset(dataset_folder, config.dataset.species_folders)

    test_data = dataset["test"]
    test_dataset = ForestDataset(test_data["paths"], test_data["labels"], transform=transforms)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config.inference.get("num_workers", 2),
    )

    num_classes = len(label_map)

    # =========================== MODEL LOADING ==================================== #
    checkpoint_path = _resolve_checkpoint(config.inference)
    print(f"Loading model from: {checkpoint_path}")

    classifier = ClassifierModule.load_from_checkpoint(
        checkpoint_path,
        model_name=model_name,
        num_classes=num_classes,
    )
    classifier = classifier.to(device).eval()

    # TODO: fix — input normalisation is disabled (mean/std=None). Values
    # mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5] were defined for this but never
    # passed; confirm they match training-time preprocessing before enabling.
    seg_model = SegmentationWrapper(
        classifier,
        mask_size=mask_size,
        mean=None,
        std=None,
        input_rescale=True,  # Expects 0-255 input, scales to 0-1 internally
    ).to(device)
    seg_model.eval()

    # =========================== EXPORT TO ONNX =================================== #
    if config.inference.get("export_onnx", False):
        _export_onnx(
            seg_model,
            Path("segmentation_model.onnx"),
            image_size=image_size,
            device=device,
            label_map=label_map,
            model_name=model_name,
            num_classes=num_classes,
        )

    # =========================== INFERENCE LOOP =================================== #
    print(f"Running inference on {len(test_loader)} samples...")
    all_preds, all_targets = _run_inference(seg_model, test_loader, device)

    # =========================== METRICS & LOGGING ================================ #
    if wandb.run is not None:
        _log_metrics(all_preds, all_targets, label_map)
    else:
        print("W&B run not active. Skipping metrics logging.")
if __name__ == "__main__":
    # Script entry point: run the full inference/evaluation pipeline defined in main().
    main()