Coverage for src/dataset.py: 87%

122 statements  

« prev     ^ index     » next       coverage.py v7.11.0, created at 2026-04-08 16:37 +0000

1import os 

2import yaml 

3import psutil 

4import numpy as np 

5from PIL import Image 

6from math import floor 

7import pytorch_lightning as pl 

8import torchvision.transforms as transforms 

9from torch.utils.data import Dataset, DataLoader 

10import random 

11 

12 

# Load the project configuration once, at import time.
# NOTE: the path is relative, so it is resolved against the current working
# directory of the process that imports this module.
# The encoding is given explicitly so parsing does not depend on the
# platform's default locale encoding.
with open("src/config.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

15 

16 

def get_train_transform():
    """
    Build the augmentation pipeline used for training images.

    Settings are read from config["training"]["augmentation"]; every key that
    is missing falls back to the default written next to it below.

    Output:
        transforms.Compose - resize to 224x224, random horizontal/vertical
            flip, random rotation, colour jitter, tensor conversion and
            channel normalization;
        None - when the configuration sets "enabled" to a falsy value.
    """
    # `or {}` also covers an `augmentation:` key that exists but is empty
    # (YAML loads it as None), which would otherwise crash on `.get` below.
    aug = config["training"].get("augmentation") or {}

    if not aug.get("enabled", True):
        return None

    color_jitter = transforms.ColorJitter(
        brightness=aug.get("brightness", 0.15),
        contrast=aug.get("contrast", 0.15),
        saturation=aug.get("saturation", 0.1),
        hue=aug.get("hue", 0.02),
    )

    pipeline = [
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=aug.get("horizontal_flip_p", 0.5)),
        transforms.RandomVerticalFlip(p=aug.get("vertical_flip_p", 0.2)),
        transforms.RandomRotation(degrees=aug.get("rotation_deg", 15)),
        color_jitter,
        transforms.ToTensor(),
        # NOTE(review): presumably the ImageNet channel statistics - confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(pipeline)

38 

39 

def calculate_dataloader_params(batch_size, img_size=(224, 224), image_channels=3, precision=32, ram_fraction=0.8):
    """
    Function calculates the number of workers and prefetch factor
    for DataLoader based on the available RAM.

    When config["training"]["dataloader"]["auto"] is falsy, the four values
    are taken verbatim from that configuration section instead.

    Input:
        batch_size: int - the batch size used in DataLoader
        img_size: tuple of int - the (height, width) of the image
        image_channels: int - the number of channels in the image
        precision: int - the number of bits per stored value
        ram_fraction: float - the fraction of available RAM to use
    Output:
        dict of params: num_workers, prefetch_factor, pin_memory, persistent_workers
        num_workers: int - the number of workers
        prefetch_factor: int - the prefetch factor
        pin_memory: bool - whether to use pin_memory
        persistent_workers: bool - whether to use persistent workers
    Raises:
        ValueError - in auto mode, when batch_size is not positive or a single
            batch does not fit into the RAM budget
    """
    loader_cfg = config["training"]["dataloader"]

    # Manual mode: hand the configured values straight back.
    if not loader_cfg["auto"]:
        return {
            "num_workers": loader_cfg["num_workers"],
            "prefetch_factor": loader_cfg["prefetch_factor"],
            "pin_memory": loader_cfg["pin_memory"],
            "persistent_workers": loader_cfg["persistent_workers"],
        }

    # A zero/negative batch would otherwise end in an obscure division error.
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer.")

    ram_budget = psutil.virtual_memory().available * ram_fraction
    bytes_per_image = np.prod(img_size) * image_channels * (precision / 8)
    bytes_per_batch = batch_size * bytes_per_image

    if bytes_per_batch > ram_budget:
        raise ValueError("Batch size too large for available RAM. Reduce the batch size or image dimensions.")

    max_batches_in_ram = floor(ram_budget / bytes_per_batch)
    prefetch_factor = min(max_batches_in_ram, 16)

    # os.cpu_count() may return None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1

    # Keep at least one worker: the returned dict always carries a
    # prefetch_factor and persistent_workers=True, both of which DataLoader
    # rejects when num_workers is 0.
    # NOTE(review): DataLoader prefetches `prefetch_factor` batches *per
    # worker*, so the total held in RAM can exceed max_batches_in_ram -
    # confirm whether the budget should be divided by num_workers.
    num_workers = max(1, min(prefetch_factor // 2, cpu_count))

    device = config.get("device", "")
    use_pinned_memory = str(device).startswith("cuda") or device == "gpu"

    return {
        "num_workers": num_workers,
        "prefetch_factor": prefetch_factor,
        "pin_memory": use_pinned_memory,
        "persistent_workers": True,
    }

87 

88 

class ForestDataset(Dataset):
    """
    Dataset that loads an image from disk on access and pairs it with its label.

    Input:
        image_paths: sequence - paths of the image files
        labels: sequence - label for each path, aligned by position
        transform: callable or None - applied to the loaded PIL image; when
            None, a default resize (224x224) / to-tensor / normalize pipeline
            is used
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels

        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
                ]
            )
        self.transform = transform

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load the image at position idx and return (image, label).

        The image is always handed to the transform as a 3-channel RGB PIL
        image, whatever mode the file was stored in.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        with Image.open(image_path) as img:
            # convert("RGB") drops the 4th (alpha / near-infrared) channel of
            # RGBA files and also expands grayscale, palette and other modes,
            # which would otherwise reach the 3-channel Normalize with the
            # wrong number of channels. It returns a new, loaded image, so the
            # result stays usable after the file is closed.
            image = img.convert("RGB")

        # Apply transformations (expects PIL Image)
        if self.transform:
            image = self.transform(image)

        return image, label

126 

127 

class UndersampledDataset(ForestDataset):
    """
    ForestDataset view that keeps at most target_size random samples per class.

    Input:
        image_paths, labels, transform: forwarded to ForestDataset
        target_size: int or None - maximum number of samples kept per class;
            when None (or 0) the size of the smallest class is used
        seed: int or None - when given, sampling uses a private
            random.Random(seed) and is reproducible; when None the global
            `random` module state is used, as before
    """

    def __init__(self, image_paths, labels, transform=None, target_size=None, seed=None):
        super().__init__(image_paths, labels, transform)

        # Group sample positions by their label.
        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        self.sampled_indices = []

        # No labels -> nothing to sample (min() below would raise on an empty sequence).
        if not class_indices:
            return

        rng = random if seed is None else random.Random(seed)

        # If the target_size is not provided, set it to the minimum class count
        if not target_size:
            target_size = min(len(indices) for indices in class_indices.values())

        for indices in class_indices.values():
            # Classes smaller than target_size are kept whole.
            keep = min(target_size, len(indices))
            self.sampled_indices.extend(rng.sample(indices, keep))

    def __len__(self):
        """Return the number of samples left after undersampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return (image, label) for the idx-th sampled position."""
        return super().__getitem__(self.sampled_indices[idx])

152 

153 

class OversampledDataset(ForestDataset):
    """
    ForestDataset view that enlarges small classes by repeating their samples.

    Input:
        image_paths, labels, transform: forwarded to ForestDataset
        minority_transform: callable or None - extra transform applied to the
            already transformed image of every minority-class sample
        oversample_factor: number - a minority class of n samples contributes
            int(oversample_factor * n) entries
        oversample_threshold: int - classes with at most this many samples
            are treated as minority classes
        seed: int or None - when given, sampling uses a private
            random.Random(seed) and is reproducible; when None the global
            `random` module state is used, as before
    """

    def __init__(self, image_paths, labels, transform=None, minority_transform=None, oversample_factor=2, oversample_threshold=200, seed=None):
        super().__init__(image_paths, labels, transform)
        self.minority_transform = minority_transform

        rng = random if seed is None else random.Random(seed)

        # Group sample positions by their label.
        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        self.to_transform = set()
        self.sampled_indices = []
        for label, indices in class_indices.items():
            if len(indices) > oversample_threshold:
                # Majority class: used as is.
                self.sampled_indices.extend(indices)
                continue

            self.to_transform.add(label)
            target_count = int(oversample_factor * len(indices))
            extra_count = target_count - len(indices)

            if extra_count >= 0:
                # Keep every original minority sample and only draw the
                # additional ones with replacement; drawing all of them at
                # random could leave some of the scarce originals out.
                self.sampled_indices.extend(indices)
                self.sampled_indices.extend(rng.choices(indices, k=extra_count))
            else:
                # Factor below 1: fewer entries than originals were requested,
                # so fall back to plain sampling with replacement.
                self.sampled_indices.extend(rng.choices(indices, k=target_count))

    def __len__(self):
        """Return the number of entries after oversampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return (image, label); minority samples also pass through minority_transform."""
        image, label = super().__getitem__(self.sampled_indices[idx])
        if label in self.to_transform and self.minority_transform:
            image = self.minority_transform(image)
        return image, label

181 

182 

class CurriculumLearningDataset(ForestDataset):
    """
    ForestDataset view restricted to, and ordered by, a caller-supplied list
    of positions into image_paths / labels.
    """

    def __init__(self, image_paths, labels, indices, transform=None):
        super().__init__(image_paths, labels, transform)
        # Positions into the underlying data, in the order they are served.
        self.indices = indices

    def __len__(self):
        """Return how many positions the view exposes."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return (image, label) for the idx-th entry of `indices`."""
        source_position = self.indices[idx]
        return super().__getitem__(source_position)

193 

194 

class ForestDataModule(pl.LightningDataModule):
    """
    LightningDataModule that builds the train / validation / test datasets
    and their DataLoaders.

    Input:
        train_data, val_data, test_data: mappings with "paths" and "labels" entries
        dataset: class (or callable) used to build the training dataset
        dataset_args: dict or None - extra keyword arguments for `dataset`
        batch_size: int - batch size shared by all three loaders
    """

    def __init__(self, train_data, val_data, test_data, dataset, dataset_args=None, batch_size=32):
        super().__init__()

        # Raw split descriptions.
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data

        # How the training dataset is constructed.
        self.dataset = dataset
        self.dataset_args = {} if dataset_args is None else dataset_args

        # Loader settings; the worker-related ones come from calculate_dataloader_params.
        self.batch_size = batch_size
        self.params = calculate_dataloader_params(batch_size)

        # Filled in by setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        """Instantiate the three datasets; only the training one receives the augmentation transform."""
        self.train_dataset = self.dataset(
            image_paths=self.train_data["paths"],
            labels=self.train_data["labels"],
            transform=get_train_transform(),
            **self.dataset_args,
        )
        self.val_dataset = ForestDataset(
            image_paths=self.val_data["paths"],
            labels=self.val_data["labels"],
        )
        self.test_dataset = ForestDataset(
            image_paths=self.test_data["paths"],
            labels=self.test_data["labels"],
        )

    def _make_loader(self, dataset, shuffle=False):
        """Wrap `dataset` in a DataLoader using the shared batch size and params."""
        return DataLoader(dataset, batch_size=self.batch_size, shuffle=shuffle, **self.params)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the validation loader (not shuffled)."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the test loader (not shuffled)."""
        return self._make_loader(self.test_dataset)

226 

227 

# Manual check: print the DataLoader parameters chosen for a batch size of 32.
if __name__ == "__main__":
    print(calculate_dataloader_params(32))