Coverage for src/dataset.py: 88%
116 statements
coverage.py v7.11.0, created at 2026-01-06 01:30 +0000
import os
import random
from math import floor

import numpy as np
import psutil
import pytorch_lightning as pl
import torchvision.transforms as transforms
import yaml
from PIL import Image
from torch.utils.data import Dataset, DataLoader
with open("src/config.yaml", "r") as c:
    config = yaml.safe_load(c)
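# The keys read from config.yaml below are inferred from this module's usage;
# a minimal sketch of the assumed structure (values are illustrative only):
#
#   device: gpu
#   training:
#     dataloader:
#       auto: true
#       num_workers: 4
#       prefetch_factor: 2
#       pin_memory: true
#       persistent_workers: true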
def calculate_dataloader_params(batch_size, img_size=(224, 224), image_channels=3, precision=32, ram_fraction=0.8):
    """
    Calculate the number of workers and the prefetch factor for a DataLoader
    based on the available RAM.

    Input:
        batch_size: int - the batch size used in the DataLoader
        img_size: tuple - the (height, width) of each image
        image_channels: int - the number of channels in each image
        precision: int - the precision of an image tensor, in bits
        ram_fraction: float - the fraction of available RAM to use
    Output:
        dict of params: num_workers, prefetch_factor, pin_memory, persistent_workers
            num_workers: int - the number of worker processes
            prefetch_factor: int - the number of batches prefetched per worker
            pin_memory: bool - whether to use pinned memory
            persistent_workers: bool - whether workers persist across epochs
    """
    if config["training"]["dataloader"]["auto"]:  # partial branch: always true under test, so the manual path is uncovered
        total_ram = psutil.virtual_memory().available * ram_fraction
        img_memory = np.prod(img_size) * image_channels * (precision / 8)
        batch_memory = batch_size * img_memory

        if batch_memory > total_ram:  # partial branch: never true under test
            raise ValueError("Batch size too large for available RAM. Reduce the batch size or image dimensions.")

        max_batches_in_ram = floor(total_ram / batch_memory)

        prefetch_factor = min(max_batches_in_ram, 16)
        # At least one worker: DataLoader rejects prefetch_factor and
        # persistent_workers when num_workers == 0.
        num_workers = max(1, min(floor(prefetch_factor / 2), os.cpu_count()))

        params = {"num_workers": num_workers, "prefetch_factor": prefetch_factor, "pin_memory": config["device"] == "gpu", "persistent_workers": True}
    else:
        params = {
            "num_workers": config["training"]["dataloader"]["num_workers"],
            "prefetch_factor": config["training"]["dataloader"]["prefetch_factor"],
            "pin_memory": config["training"]["dataloader"]["pin_memory"],
            "persistent_workers": config["training"]["dataloader"]["persistent_workers"],
        }

    return params
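# A minimal sketch of how the returned params feed a DataLoader (hypothetical
# `dataset` argument; the concrete values depend on the machine's free RAM):
def _example_dataloader_params(dataset, batch_size=32):
    params = calculate_dataloader_params(batch_size)
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, **params)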
class ForestDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels

        if transform is None:
            # Default: resize to 224x224 and apply ImageNet normalization
            self.transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
            )
        else:
            self.transform = transform
    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        with Image.open(image_path) as img:
            # Remove "near-infrared" channel if present (4-channel RGBA)
            if img.mode == "RGBA":  # partial branch: no RGBA image was seen under test
                # Convert RGBA to RGB (drops alpha/near-infrared channel)
                image = img.convert("RGB")
            else:
                # Keep as PIL Image for transforms
                image = img.copy()

        # Apply transformations (expects PIL Image)
        if self.transform:  # partial branch: always true under test
            image = self.transform(image)

        return image, label
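# Usage sketch (hypothetical paths and labels); with the default transform the
# returned image is a normalized float tensor of shape (3, 224, 224):
def _example_forest_dataset():
    paths = ["data/forest/img_0001.png", "data/forest/img_0002.png"]  # hypothetical files
    labels = [0, 1]
    dataset = ForestDataset(paths, labels)
    image, label = dataset[0]
    return image.shape, label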
class UndersampledDataset(ForestDataset):
    def __init__(self, image_paths, labels, transform=None, target_size=None):
        super().__init__(image_paths, labels, transform)

        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        # Find the minimum number of samples in a class
        min_count = min(len(indices) for indices in class_indices.values())

        # If target_size is not provided, fall back to the minimum count
        target_size = target_size if target_size else min_count

        self.sampled_indices = []
        for indices in class_indices.values():
            # Cap the number of images per class at target_size
            self.sampled_indices.extend(random.sample(indices, min(target_size, len(indices))))

    def __len__(self):
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        return super().__getitem__(self.sampled_indices[idx])
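# Worked example (hypothetical counts): with class sizes {0: 500, 1: 40} and no
# target_size, every class is capped at 40, so len(dataset) == 80.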
class OversampledDataset(ForestDataset):
    def __init__(self, image_paths, labels, transform=None, minority_transform=None, oversample_factor=2, oversample_threshold=200):
        super().__init__(image_paths, labels, transform)
        self.minority_transform = minority_transform

        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        self.to_transform = set()
        self.sampled_indices = []
        for label, indices in class_indices.items():
            if len(indices) <= oversample_threshold:
                self.to_transform.add(label)
                # Sample the minority class with replacement
                self.sampled_indices.extend(random.choices(indices, k=int(oversample_factor * len(indices))))
            else:
                self.sampled_indices.extend(indices)

    def __len__(self):
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        image, label = super().__getitem__(self.sampled_indices[idx])
        if label in self.to_transform and self.minority_transform:  # partial branch: never true under test
            image = self.minority_transform(image)
        return image, label
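# Worked example (hypothetical counts): with oversample_threshold=200 and
# oversample_factor=2, a class of 150 images is resampled with replacement to
# 300 entries and flagged for minority_transform, while a class of 500 images
# is kept as-is. Note that minority_transform runs after the base transform,
# so it receives a tensor; a plausible tensor-compatible augmentation sketch:
def _example_minority_transform():
    return transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
        ]
    )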
class CurriculumLearningDataset(ForestDataset):
    def __init__(self, image_paths, labels, indices, transform=None):
        super().__init__(image_paths, labels, transform)
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        return super().__getitem__(self.indices[idx])
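# Sketch of a curriculum stage (hypothetical difficulty ordering): start from
# the easiest third of the data and widen the index set in later stages.
def _example_curriculum_stage(image_paths, labels, difficulty_order):
    easy_indices = difficulty_order[: len(difficulty_order) // 3]
    return CurriculumLearningDataset(image_paths, labels, indices=easy_indices)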
class ForestDataModule(pl.LightningDataModule):
    def __init__(self, train_data, val_data, test_data, dataset, dataset_args=None, batch_size=32):
        if dataset_args is None:  # partial branch: always true under test
            dataset_args = {}

        super().__init__()
        self.test_dataset = None
        self.train_dataset = None
        self.val_dataset = None
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.dataset = dataset
        self.dataset_args = dataset_args
        self.batch_size = batch_size
        self.params = calculate_dataloader_params(batch_size)
    def setup(self, stage=None):
        # The configurable dataset class applies only to training;
        # validation and test always use the plain ForestDataset.
        self.train_dataset = self.dataset(image_paths=self.train_data["paths"], labels=self.train_data["labels"], **self.dataset_args)
        self.val_dataset = ForestDataset(image_paths=self.val_data["paths"], labels=self.val_data["labels"])
        self.test_dataset = ForestDataset(image_paths=self.test_data["paths"], labels=self.test_data["labels"])

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, **self.params)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.params)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.params)
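# End-to-end sketch (hypothetical splits, each a dict with "paths" and
# "labels" keys): pick a sampling strategy for training via `dataset`.
def _example_datamodule(train_data, val_data, test_data):
    dm = ForestDataModule(
        train_data,
        val_data,
        test_data,
        dataset=OversampledDataset,
        dataset_args={"oversample_factor": 2, "oversample_threshold": 200},
        batch_size=32,
    )
    dm.setup()
    return dm.train_dataloader()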
if __name__ == "__main__":  # not executed under test
    params = calculate_dataloader_params(32)
    print(params)