Coverage for src/dataset.py: 92%
104 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-16 12:50 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-04-16 12:50 +0000
1import os
2import yaml
3import psutil
4import numpy as np
5from PIL import Image
6from math import floor
7import pytorch_lightning as pl
8from transforms import Preprocess
9import torchvision.transforms as transforms
10from torch.utils.data import Dataset, DataLoader
11import random
# Project configuration, parsed once when this module is imported.
# The path is relative to the current working directory, so the module has to
# be imported from the repository root.
# The encoding is pinned so parsing does not depend on the platform's locale
# default.
with open("src/config.yaml", "r", encoding="utf-8") as c:
    config = yaml.safe_load(c)
def calculate_dataloader_params(batch_size, img_size=(224, 224), image_channels=3, precision=32, ram_fraction=0.8):
    """
    Function calculates the number of workers and prefetch factor
    for DataLoader, either from the available RAM or from the config.

    When config['training']['dataloader']['auto'] is truthy the values are
    derived from the RAM psutil currently reports as available; otherwise
    they are copied unchanged from config['training']['dataloader'].

    Input:
        batch_size: int - the batch size used in DataLoader
        img_size: tuple of int - the spatial size of one image, e.g. (224, 224)
        image_channels: int - the number of channels in the image
        precision: int - bits per stored value (divided by 8 to get bytes)
        ram_fraction: float - the fraction of the available RAM to budget
    Output:
        dict of params: num_workers, prefetch_factor, pin_memory, persistent_workers
            num_workers: int - the number of workers
            prefetch_factor: int - the prefetch factor
            pin_memory: bool - whether to use pin_memory
            persistent_workers: bool - whether to use persistent workers
    Raises:
        ValueError: in auto mode, when the batch size or image dimensions are
            not positive, or when a single batch exceeds the RAM budget.
    """
    loader_config = config['training']['dataloader']

    if not loader_config['auto']:
        # Manual mode: trust whatever the config file specifies.
        return {"num_workers": loader_config['num_workers'],
                "prefetch_factor": loader_config['prefetch_factor'],
                "pin_memory": loader_config['pin_memory'],
                "persistent_workers": loader_config['persistent_workers']}

    total_ram = psutil.virtual_memory().available * ram_fraction
    img_memory = np.prod(img_size) * image_channels * (precision / 8)
    batch_memory = batch_size * img_memory

    # A zero-sized batch would otherwise reach the division below.
    if batch_memory <= 0:
        raise ValueError("Batch size and image dimensions must be positive.")

    if batch_memory > total_ram:
        raise ValueError("Batch size too large for available RAM. Reduce the batch size or image dimensions.")

    max_batches_in_ram = floor(total_ram / batch_memory)

    prefetch_factor = min(max_batches_in_ram, 16)

    # os.cpu_count() may return None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1
    # Keep at least one worker: the dict below always sets prefetch_factor and
    # persistent_workers=True, and DataLoader rejects both when num_workers is 0.
    num_workers = max(1, min(prefetch_factor // 2, cpu_count))
    # NOTE(review): DataLoader applies prefetch_factor per worker, so up to
    # num_workers * prefetch_factor batches can be in flight; the RAM budget
    # above only bounds prefetch_factor itself - confirm this is acceptable.

    return {"num_workers": num_workers,
            "prefetch_factor": prefetch_factor,
            "pin_memory": config['device'] == 'gpu',
            "persistent_workers": True}
class ForestDataset(Dataset):
    """Dataset that loads one image from disk per (path, label) pair."""

    def __init__(self, image_paths, labels, transform=None):
        """
        Input:
            image_paths: sequence of image file paths, indexable by position
            labels: sequence of labels, aligned with image_paths
            transform: callable applied to the loaded numpy image; when it is
                missing (or falsy) a ToTensor + Normalize pipeline is used
        """
        self.image_paths = image_paths
        self.labels = labels
        # Define a default transform if none is provided
        # TODO: Use transforms suitable for the model
        self.transform = transform or transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
                ),  # Adjust as needed for RGB channels
            ]
        )

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    @staticmethod
    def _drop_leading_channel(image):
        """Return `image` without its first channel when it is H x W x 4.

        Arrays of any other shape are returned unchanged. The ndim check
        matters: a 2-D (single-channel) image whose width happens to be 4
        must not be indexed with three axes.
        """
        if image.ndim == 3 and image.shape[-1] == 4:
            return image[:, :, 1:]
        return image

    def __getitem__(self, idx):
        """Load the image at position `idx` and return `(image, label)`."""
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        with Image.open(image_path) as img:
            # Convert to numpy array while the file is still open.
            # The original authors found that PIL's conversion to RGB kept the
            # near-infrared channel, so the extra channel is sliced off by
            # hand instead.
            image = self._drop_leading_channel(np.array(img))

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        return image, label
class UndersampledDataset(ForestDataset):
    """ForestDataset view that caps the number of samples kept per class.

    The subset is drawn once, in the constructor, with the module-level
    `random` generator; seed `random` beforehand for a repeatable subset.
    """

    def __init__(self, image_paths, labels, transform=None, target_size=None):
        """
        Input:
            image_paths: sequence of image file paths
            labels: sequence of hashable labels, aligned with image_paths
            transform: forwarded to ForestDataset
            target_size: maximum number of samples kept per class; when it is
                None (or 0) the size of the smallest class is used
        """
        super().__init__(image_paths, labels, transform)

        # Group sample positions by their label.
        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        # Find the minimum number of samples in a class.
        # default=0 lets an empty label list produce an empty dataset instead
        # of raising from min().
        min_count = min((len(indices) for indices in class_indices.values()), default=0)

        # If the target_size is not provided, set it to the minimum count
        target_size = target_size if target_size else min_count

        self.sampled_indices = []
        for indices in class_indices.values():
            # Limit the number of images per class to target_size (if it exceeds the target_size)
            self.sampled_indices.extend(random.sample(indices, min(target_size, len(indices))))

    def __len__(self):
        """Return the number of samples kept after undersampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return the `(image, label)` pair for the idx-th kept sample."""
        return super().__getitem__(self.sampled_indices[idx])
class OversampledDataset(ForestDataset):
    """ForestDataset view that repeats samples of small classes.

    Classes with at most `oversample_threshold` samples are redrawn with
    replacement; larger classes are kept exactly as they are.
    """

    def __init__(self, image_paths, labels, transform=None, minority_transform=None, oversample_factor=2,
                 oversample_threshold=200):
        """
        Input:
            image_paths: sequence of image file paths
            labels: sequence of hashable labels, aligned with image_paths
            transform: forwarded to ForestDataset
            minority_transform: optional callable applied, after `transform`,
                to samples whose class was oversampled
            oversample_factor: multiplier for the size of an oversampled class
            oversample_threshold: largest class size that still gets oversampled
        """
        super().__init__(image_paths, labels, transform)
        self.minority_transform = minority_transform

        # Collect the positions of every sample, keyed by its class.
        positions_by_class = {}
        for position, cls in enumerate(labels):
            if cls not in positions_by_class:
                positions_by_class[cls] = []
            positions_by_class[cls].append(position)

        self.to_transform = set()
        self.sampled_indices = []
        for cls, members in positions_by_class.items():
            if len(members) > oversample_threshold:
                # Large enough class: keep every sample exactly once.
                self.sampled_indices.extend(members)
                continue

            self.to_transform.add(cls)
            # Sampling the minority class with replacement
            draw_count = int(oversample_factor * len(members))
            self.sampled_indices.extend(random.choices(members, k=draw_count))

    def __len__(self):
        """Return the number of samples after oversampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return `(image, label)`, with the extra transform for oversampled classes."""
        source_index = self.sampled_indices[idx]
        image, label = super().__getitem__(source_index)
        if label in self.to_transform:
            if self.minority_transform:
                image = self.minority_transform(image)
        return image, label
class ForestDataModule(pl.LightningDataModule):
    """LightningDataModule that builds the train, validation and test loaders.

    The training split uses the dataset class passed in as `dataset`; the
    validation and test splits always use a plain ForestDataset. All three
    use `Preprocess()` as their transform.
    """

    def __init__(self, train_data, val_data, test_data, dataset, dataset_args=None, batch_size=32):
        """
        Input:
            train_data, val_data, test_data: mappings with "paths" and "labels" entries
            dataset: dataset class used for the training split
            dataset_args: optional dict of extra keyword arguments for `dataset`
            batch_size: int - the batch size used by every DataLoader
        """
        super().__init__()
        self.test_dataset = None
        self.train_dataset = None
        self.val_dataset = None
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.dataset = dataset
        # A fresh dict per instance; a `{}` default in the signature would be
        # one object shared by every instance created without dataset_args.
        self.dataset_args = {} if dataset_args is None else dataset_args
        self.batch_size = batch_size
        self.params = calculate_dataloader_params(batch_size)

    def setup(self, stage=None):
        """Instantiate the three datasets; `stage` is accepted but not used."""
        self.train_dataset = self.dataset(
            image_paths=self.train_data["paths"],
            labels=self.train_data["labels"],
            transform=Preprocess(),
            **self.dataset_args
        )
        self.val_dataset = ForestDataset(
            image_paths=self.val_data["paths"],
            labels=self.val_data["labels"],
            transform=Preprocess()
        )
        self.test_dataset = ForestDataset(
            image_paths=self.test_data["paths"],
            labels=self.test_data["labels"],
            transform=Preprocess()
        )

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, **self.params)

    def val_dataloader(self):
        """Return the DataLoader over the validation dataset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.params)

    def test_dataloader(self):
        """Return the DataLoader over the test dataset."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.params)
if __name__ == "__main__":
    # Manual check: print the loader settings computed for a batch size of 32.
    print(calculate_dataloader_params(32))