Coverage for src/dataset.py: 87%
122 statements
« prev ^ index » next coverage.py v7.11.0, created at 2026-04-08 16:37 +0000
« prev ^ index » next coverage.py v7.11.0, created at 2026-04-08 16:37 +0000
1import os
2import yaml
3import psutil
4import numpy as np
5from PIL import Image
6from math import floor
7import pytorch_lightning as pl
8import torchvision.transforms as transforms
9from torch.utils.data import Dataset, DataLoader
10import random
# Load the project configuration once, at import time. The functions below
# read their settings from the resulting module-level ``config`` mapping.
with open("src/config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)
def get_train_transform():
    """
    Build the augmentation pipeline applied to training images.

    Settings are read from config["training"]["augmentation"]. Every key is
    optional and falls back to the default visible in the code below.

    Output:
        torchvision Compose that resizes to 224x224, applies random
        horizontal/vertical flips, a random rotation and colour jitter, then
        converts to a tensor and normalizes per channel;
        or None when the "enabled" setting is falsy.
    """
    # A YAML key with no value (``augmentation:``) is loaded as None, so the
    # ``.get`` default alone does not guarantee a dict here.
    augmentation_cfg = config["training"].get("augmentation") or {}
    if not augmentation_cfg.get("enabled", True):
        return None

    return transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(p=augmentation_cfg.get("horizontal_flip_p", 0.5)),
            transforms.RandomVerticalFlip(p=augmentation_cfg.get("vertical_flip_p", 0.2)),
            transforms.RandomRotation(degrees=augmentation_cfg.get("rotation_deg", 15)),
            transforms.ColorJitter(
                brightness=augmentation_cfg.get("brightness", 0.15),
                contrast=augmentation_cfg.get("contrast", 0.15),
                saturation=augmentation_cfg.get("saturation", 0.1),
                hue=augmentation_cfg.get("hue", 0.02),
            ),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
def calculate_dataloader_params(batch_size, img_size=(224, 224), image_channels=3, precision=32, ram_fraction=0.8):
    """
    Function calculates the number of workers and prefetch factor
    for DataLoader based on the available RAM.

    When config["training"]["dataloader"]["auto"] is falsy, nothing is
    calculated and the four values are taken verbatim from the config.

    Input:
        batch_size: int - the batch size used in DataLoader
        img_size: tuple of int - the image dimensions; their product is the pixel count
        image_channels: int - the number of channels in the image
        precision: int - the number of bits per stored value (divided by 8 to get bytes)
        ram_fraction: float - the fraction of the currently available RAM to use
    Output:
        dict of params: num_workers, prefetch_factor, pin_memory, persistent_workers
            num_workers: int - the number of workers (at least 1 in auto mode)
            prefetch_factor: int - the prefetch factor (at most 16 in auto mode)
            pin_memory: bool - whether to use pin_memory
            persistent_workers: bool - whether to use persistent workers
    Raises:
        ValueError - in auto mode, when a single batch does not fit in the RAM budget
    """
    dataloader_cfg = config["training"]["dataloader"]

    if not dataloader_cfg["auto"]:
        # Manual mode: trust the configuration file as-is.
        return {
            "num_workers": dataloader_cfg["num_workers"],
            "prefetch_factor": dataloader_cfg["prefetch_factor"],
            "pin_memory": dataloader_cfg["pin_memory"],
            "persistent_workers": dataloader_cfg["persistent_workers"],
        }

    total_ram = psutil.virtual_memory().available * ram_fraction
    img_memory = np.prod(img_size) * image_channels * (precision / 8)
    batch_memory = batch_size * img_memory

    if batch_memory > total_ram:
        raise ValueError("Batch size too large for available RAM. Reduce the batch size or image dimensions.")

    max_batches_in_ram = floor(total_ram / batch_memory)
    prefetch_factor = min(max_batches_in_ram, 16)

    # os.cpu_count() may return None when the count cannot be determined.
    cpu_total = os.cpu_count() or 1
    # Keep at least one worker: the dict below always carries a prefetch_factor
    # and persistent_workers=True, and DataLoader rejects both of those when
    # num_workers is 0 (which happened whenever prefetch_factor was 1).
    num_workers = max(1, min(floor(prefetch_factor / 2), cpu_total))

    device = config.get("device")
    return {
        "num_workers": num_workers,
        "prefetch_factor": prefetch_factor,
        "pin_memory": str(config.get("device", "")).startswith("cuda") or device == "gpu",
        "persistent_workers": True,
    }
class ForestDataset(Dataset):
    """
    Dataset that loads images from disk and pairs them with labels.

    Input:
        image_paths: indexable collection of image file paths
        labels: indexable collection of labels, aligned with image_paths
        transform: callable applied to the loaded PIL image; when None, a
            default resize (224x224) + ToTensor + Normalize pipeline is used
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels

        if transform is None:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
            )
        else:
            self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load the image at position idx and return (image, label).

        The image is always handed to the transform as a 3-channel RGB
        PIL image.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        with Image.open(image_path) as img:
            if img.mode != "RGB":
                # Covers RGBA (drops the alpha/near-infrared channel) as well
                # as grayscale, palette, CMYK, ... images, which would
                # otherwise reach the 3-channel Normalize step with the wrong
                # number of channels.
                image = img.convert("RGB")
            else:
                # Detach the image from the file, which is closed when the
                # with-block exits.
                image = img.copy()

        # Apply transformations (expects PIL Image)
        if self.transform:
            image = self.transform(image)

        return image, label
class UndersampledDataset(ForestDataset):
    """
    ForestDataset view that keeps at most target_size randomly chosen
    samples per class.

    Input:
        image_paths, labels, transform: forwarded to ForestDataset
        target_size: int - per-class cap; when None (or 0) the size of the
            smallest class is used, which balances all classes
    """

    def __init__(self, image_paths, labels, transform=None, target_size=None):
        super().__init__(image_paths, labels, transform)

        # Group sample positions by label, in order of first appearance.
        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        # Find the minimum number of samples in a class. ``default=0`` lets an
        # empty label list produce an empty dataset instead of a ValueError.
        min_count = min((len(indices) for indices in class_indices.values()), default=0)

        # If the target_size is not provided, set it to the minimum count
        target_size = target_size if target_size else min_count

        self.sampled_indices = []
        for indices in class_indices.values():
            # Limit the number of images per class to target_size (if it exceeds the target_size)
            self.sampled_indices.extend(random.sample(indices, min(target_size, len(indices))))

    def __len__(self):
        """Return the number of samples kept after undersampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return the (image, label) pair for the idx-th kept sample."""
        return super().__getitem__(self.sampled_indices[idx])
class OversampledDataset(ForestDataset):
    """
    ForestDataset view that resamples small classes with replacement.

    A class counts as a minority class when it has at most
    oversample_threshold samples. Such a class contributes
    int(oversample_factor * class_size) randomly drawn positions; every other
    class contributes its positions unchanged.

    Input:
        image_paths, labels, transform: forwarded to ForestDataset
        minority_transform: optional callable applied, after the regular
            transform, to items whose label belongs to a minority class
        oversample_factor: multiplier for the size of each minority class
        oversample_threshold: int - largest class size still treated as a minority
    """

    def __init__(self, image_paths, labels, transform=None, minority_transform=None, oversample_factor=2, oversample_threshold=200):
        super().__init__(image_paths, labels, transform)
        self.minority_transform = minority_transform

        # Positions of every sample, bucketed per label.
        positions_by_label = {}
        for position, cls in enumerate(labels):
            positions_by_label.setdefault(cls, []).append(position)

        self.to_transform = set()
        self.sampled_indices = []
        for cls, positions in positions_by_label.items():
            if len(positions) > oversample_threshold:
                # Large enough class: keep it exactly as it is.
                self.sampled_indices.extend(positions)
                continue

            # Minority class: remember it and draw with replacement.
            self.to_transform.add(cls)
            draw_count = int(oversample_factor * len(positions))
            self.sampled_indices.extend(random.choices(positions, k=draw_count))

    def __len__(self):
        """Return the number of samples after oversampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return (image, label), with the extra transform for minority labels."""
        source_position = self.sampled_indices[idx]
        image, label = super().__getitem__(source_position)
        if label in self.to_transform and self.minority_transform:
            return self.minority_transform(image), label
        return image, label
class CurriculumLearningDataset(ForestDataset):
    """
    ForestDataset view restricted to (and ordered by) a given list of indices.

    Input:
        image_paths, labels, transform: forwarded to ForestDataset
        indices: sequence of positions into image_paths/labels; item i of this
            dataset is the sample stored at indices[i]
    """

    def __init__(self, image_paths, labels, indices, transform=None):
        super().__init__(image_paths, labels, transform)
        self.indices = indices

    def __len__(self):
        """Return how many indices were selected."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the (image, label) pair the idx-th selected index points to."""
        source_position = self.indices[idx]
        return super().__getitem__(source_position)
class ForestDataModule(pl.LightningDataModule):
    """
    Lightning data module wiring the train/val/test splits to DataLoaders.

    Input:
        train_data, val_data, test_data: mappings with "paths" and "labels" entries
        dataset: dataset class (or factory) used for the training split; it is
            called with image_paths, labels, transform and **dataset_args
        dataset_args: dict of extra keyword arguments for dataset (default: none)
        batch_size: int - the batch size used by all three loaders
    """

    def __init__(self, train_data, val_data, test_data, dataset, dataset_args=None, batch_size=32):
        super().__init__()

        # Raw splits, kept until setup() turns them into datasets.
        self.train_data, self.val_data, self.test_data = train_data, val_data, test_data
        self.train_dataset = self.val_dataset = self.test_dataset = None

        self.dataset = dataset
        self.dataset_args = {} if dataset_args is None else dataset_args
        self.batch_size = batch_size
        self.params = calculate_dataloader_params(batch_size)

    def setup(self, stage=None):
        """Instantiate the datasets; only the training split is augmented."""
        augmentation = get_train_transform()
        self.train_dataset = self.dataset(
            image_paths=self.train_data["paths"],
            labels=self.train_data["labels"],
            transform=augmentation,
            **self.dataset_args,
        )
        # Validation and test data always use ForestDataset's default transform.
        self.val_dataset = ForestDataset(
            image_paths=self.val_data["paths"],
            labels=self.val_data["labels"],
        )
        self.test_dataset = ForestDataset(
            image_paths=self.test_data["paths"],
            labels=self.test_data["labels"],
        )

    def _loader_kwargs(self):
        """Keyword arguments shared by the three DataLoaders."""
        return {"batch_size": self.batch_size, **self.params}

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        return DataLoader(self.train_dataset, shuffle=True, **self._loader_kwargs())

    def val_dataloader(self):
        """Return the loader over the validation dataset."""
        return DataLoader(self.val_dataset, **self._loader_kwargs())

    def test_dataloader(self):
        """Return the loader over the test dataset."""
        return DataLoader(self.test_dataset, **self._loader_kwargs())
if __name__ == "__main__":
    # Manual smoke check: show the DataLoader settings chosen for batch size 32.
    print(calculate_dataloader_params(32))