Coverage for src/dataset.py: 89%
114 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-01 17:37 +0000
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-01 17:37 +0000
1import os
2import yaml
3import psutil
4import numpy as np
5from PIL import Image
6from math import floor
7import pytorch_lightning as pl
8from transforms import Preprocess
9import torchvision.transforms as transforms
10from torch.utils.data import Dataset, DataLoader
11import random
# Load the project configuration once at import time.
# The encoding is pinned to UTF-8 so parsing does not depend on the
# platform's locale default (which is not UTF-8 on every system).
with open("src/config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)
def calculate_dataloader_params(batch_size, img_size=(224, 224), image_channels=3, precision=32, ram_fraction=0.8):
    """
    Build the keyword arguments passed to DataLoader.

    When config["training"]["dataloader"]["auto"] is truthy, the values are
    derived from the RAM currently available; otherwise they are read
    unchanged from the "training" -> "dataloader" section of the config.

    Input:
        batch_size: int - the batch size used in DataLoader
        img_size: tuple of int - the dimensions of one image, e.g. (224, 224)
        image_channels: int - the number of channels in the image
        precision: int - the number of bits used to store one value
        ram_fraction: float - the fraction of the available RAM to budget
    Output:
        dict of params: num_workers, prefetch_factor, pin_memory, persistent_workers
            num_workers: int - the number of workers (at least 1 in auto mode)
            prefetch_factor: int - the prefetch factor (at most 16 in auto mode)
            pin_memory: bool - whether to use pin_memory
            persistent_workers: bool - whether to use persistent workers
    Raises:
        ValueError: in auto mode, if the estimated size of one batch is not
            positive or exceeds the RAM budget.
    """
    dataloader_config = config["training"]["dataloader"]

    if not dataloader_config["auto"]:
        # Manual mode: hand the configured values through untouched.
        return {
            "num_workers": dataloader_config["num_workers"],
            "prefetch_factor": dataloader_config["prefetch_factor"],
            "pin_memory": dataloader_config["pin_memory"],
            "persistent_workers": dataloader_config["persistent_workers"],
        }

    total_ram = psutil.virtual_memory().available * ram_fraction
    # Bytes for one image: pixel count * channels * bytes per value.
    img_memory = np.prod(img_size) * image_channels * (precision / 8)
    batch_memory = batch_size * img_memory

    if batch_memory <= 0:
        # Guards the division below and prevents a negative prefetch factor.
        raise ValueError("batch_size, img_size, image_channels and precision must all be positive.")

    if batch_memory > total_ram:
        raise ValueError("Batch size too large for available RAM. Reduce the batch size or image dimensions.")

    max_batches_in_ram = floor(total_ram / batch_memory)

    prefetch_factor = min(max_batches_in_ram, 16)

    # os.cpu_count() may return None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1
    # Keep at least one worker: DataLoader rejects a prefetch_factor or
    # persistent_workers=True when num_workers is 0.
    num_workers = max(1, min(prefetch_factor // 2, cpu_count))

    return {
        "num_workers": num_workers,
        "prefetch_factor": prefetch_factor,
        "pin_memory": config["device"] == "gpu",
        "persistent_workers": True,
    }
class ForestDataset(Dataset):
    """Dataset that loads an image from disk by index and pairs it with its label."""

    def __init__(self, image_paths, labels, transform=None):
        """
        Input:
            image_paths: sequence of image file paths
            labels: sequence of labels, indexed in parallel with image_paths
            transform: callable applied to the loaded image array; when it is
                missing (or falsy) a ToTensor + Normalize pipeline is used
        """
        self.image_paths = image_paths
        self.labels = labels

        # TODO: Use transforms suitable for the model
        if not transform:
            # Fallback pipeline; mean/std may need adjusting for the RGB channels.
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
                ]
            )
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at position idx and return (image, label)."""
        path, target = self.image_paths[idx], self.labels[idx]

        with Image.open(path) as handle:
            pixels = np.array(handle)

        # A 4-channel image loses its first channel here. Per the original
        # author's note this is the near-infrared band, which PIL's own RGB
        # conversion was found to keep.
        if pixels.shape[-1] == 4:
            pixels = pixels[:, :, 1:]

        if self.transform:
            pixels = self.transform(pixels)

        return pixels, target
class UndersampledDataset(ForestDataset):
    """ForestDataset view that keeps at most target_size random samples per class."""

    def __init__(self, image_paths, labels, transform=None, target_size=None):
        """
        Input:
            image_paths: sequence of image file paths
            labels: sequence of hashable labels, parallel to image_paths
            transform: forwarded to ForestDataset
            target_size: maximum number of samples kept per class; when None
                (or 0) the size of the smallest class is used
        Raises:
            ValueError: if target_size is negative
        """
        super().__init__(image_paths, labels, transform)

        if target_size is not None and target_size < 0:
            # Fail early with a clear message instead of inside random.sample.
            raise ValueError("target_size must be non-negative.")

        # Group sample positions by label, keeping first-seen label order.
        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        # Size of the smallest class; default=0 lets an empty label list
        # produce an empty dataset instead of raising inside min().
        min_count = min((len(indices) for indices in class_indices.values()), default=0)

        # If the target_size is not provided, set it to the minimum count
        target_size = target_size or min_count

        self.sampled_indices = []
        for indices in class_indices.values():
            # Sample without replacement; classes smaller than target_size are kept whole.
            self.sampled_indices.extend(random.sample(indices, min(target_size, len(indices))))

    def __len__(self):
        """Return the number of samples kept after undersampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return (image, label) for the idx-th kept sample."""
        return super().__getitem__(self.sampled_indices[idx])
class OversampledDataset(ForestDataset):
    """ForestDataset view that resamples small classes with replacement."""

    def __init__(self, image_paths, labels, transform=None, minority_transform=None, oversample_factor=2, oversample_threshold=200):
        """
        Input:
            image_paths: sequence of image file paths
            labels: sequence of hashable labels, parallel to image_paths
            transform: forwarded to ForestDataset
            minority_transform: optional callable applied, after transform,
                to samples whose label belongs to an oversampled class
            oversample_factor: an oversampled class contributes
                int(oversample_factor * class_size) draws
            oversample_threshold: classes with at most this many samples
                are oversampled
        """
        super().__init__(image_paths, labels, transform)
        self.minority_transform = minority_transform

        # Positions of every sample, grouped by label in first-seen order.
        by_class = {}
        for position, class_label in enumerate(labels):
            if class_label not in by_class:
                by_class[class_label] = []
            by_class[class_label].append(position)

        self.to_transform = set()
        self.sampled_indices = []
        for class_label, members in by_class.items():
            if len(members) > oversample_threshold:
                # Large enough class: keep every sample exactly once.
                self.sampled_indices.extend(members)
                continue

            # Minority class: remember the label and draw with replacement.
            self.to_transform.add(class_label)
            draws = int(oversample_factor * len(members))
            self.sampled_indices.extend(random.choices(members, k=draws))

    def __len__(self):
        """Return the number of samples after oversampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return (image, label), applying minority_transform to oversampled classes."""
        sample, target = super().__getitem__(self.sampled_indices[idx])
        needs_extra_transform = target in self.to_transform and self.minority_transform
        if needs_extra_transform:
            sample = self.minority_transform(sample)
        return sample, target
class CurriculumLearningDataset(ForestDataset):
    """ForestDataset view restricted to, and ordered by, a caller-supplied index list."""

    def __init__(self, image_paths, labels, indices, transform=None):
        """
        Input:
            image_paths: sequence of image file paths
            labels: sequence of labels, parallel to image_paths
            indices: positions into image_paths/labels exposed by this dataset
            transform: forwarded to ForestDataset
        """
        super().__init__(image_paths, labels, transform)
        self.indices = indices

    def __len__(self):
        """Return how many positions are exposed."""
        exposed = self.indices
        return len(exposed)

    def __getitem__(self, idx):
        """Return (image, label) for the underlying position self.indices[idx]."""
        source_position = self.indices[idx]
        return super().__getitem__(source_position)
class ForestDataModule(pl.LightningDataModule):
    """LightningDataModule that builds the train/val/test datasets and their DataLoaders."""

    def __init__(self, train_data, val_data, test_data, dataset, dataset_args=None, batch_size=32):
        """
        Input:
            train_data, val_data, test_data: mappings with "paths" and "labels" entries
            dataset: dataset class (or factory) used for the training split
            dataset_args: extra keyword arguments for dataset; defaults to {}
            batch_size: batch size shared by all three DataLoaders
        """
        super().__init__()

        # Raw split descriptions; the datasets themselves are built in setup().
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data

        self.dataset = dataset
        self.dataset_args = {} if dataset_args is None else dataset_args
        self.batch_size = batch_size

        # Placeholders until setup() runs.
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # Shared DataLoader keyword arguments (workers, prefetching, ...).
        self.params = calculate_dataloader_params(batch_size)

    def setup(self, stage=None):
        """Instantiate the datasets; only the training split uses the configurable dataset class."""
        train_split = self.train_data
        self.train_dataset = self.dataset(
            image_paths=train_split["paths"],
            labels=train_split["labels"],
            transform=Preprocess(),
            **self.dataset_args,
        )

        val_split = self.val_data
        self.val_dataset = ForestDataset(
            image_paths=val_split["paths"],
            labels=val_split["labels"],
            transform=Preprocess(),
        )

        test_split = self.test_data
        self.test_dataset = ForestDataset(
            image_paths=test_split["paths"],
            labels=test_split["labels"],
            transform=Preprocess(),
        )

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training dataset."""
        loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, **self.params)
        return loader

    def val_dataloader(self):
        """Return the DataLoader over the validation dataset."""
        loader = DataLoader(self.val_dataset, batch_size=self.batch_size, **self.params)
        return loader

    def test_dataloader(self):
        """Return the DataLoader over the test dataset."""
        loader = DataLoader(self.test_dataset, batch_size=self.batch_size, **self.params)
        return loader
if __name__ == "__main__":
    # Manual smoke check: show the DataLoader parameters for a batch size of 32.
    params = calculate_dataloader_params(batch_size=32)
    print(params)