Coverage for src/dataset.py: 88%

116 statements  

coverage.py v7.11.0, created at 2026-01-06 01:30 +0000

  1  import os
  2  import yaml
  3  import psutil
  4  import numpy as np
  5  from PIL import Image
  6  from math import floor
  7  import pytorch_lightning as pl
  8  import torchvision.transforms as transforms
  9  from torch.utils.data import Dataset, DataLoader
 10  import random

 13  with open("src/config.yaml", "r") as c:
 14      config = yaml.safe_load(c)
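The keys read throughout this module imply a parsed config of roughly the following shape (a sketch; the values shown are illustrative assumptions, not the project's actual settings):

    config = {
        "device": "gpu",
        "training": {
            "dataloader": {
                "auto": True,
                "num_workers": 4,
                "prefetch_factor": 2,
                "pin_memory": True,
                "persistent_workers": True,
            },
        },
    }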

 17  def calculate_dataloader_params(batch_size, img_size=(224, 224), image_channels=3, precision=32, ram_fraction=0.8):
 18      """
 19      Calculates the number of DataLoader workers and the prefetch factor
 20      based on the RAM currently available.

 22      Input:
 23          batch_size: int - the batch size used in the DataLoader
 24          img_size: tuple - the (height, width) of each image
 25          image_channels: int - the number of channels per image
 26          precision: int - bits per pixel value (e.g. 32 for float32)
 27          ram_fraction: float - the fraction of available RAM to budget
 28      Output:
 29          dict of params: num_workers, prefetch_factor, pin_memory, persistent_workers
 30          num_workers: int - the number of worker processes
 31          prefetch_factor: int - the number of batches prefetched per worker
 32          pin_memory: bool - whether to use pinned (page-locked) memory
 33          persistent_workers: bool - whether to keep workers alive between epochs
 34      """
 35      if config["training"]["dataloader"]["auto"]:  # 35 ↛ 51: didn't jump to line 51 (condition was always true)
 36          total_ram = psutil.virtual_memory().available * ram_fraction
 37          img_memory = np.prod(img_size) * image_channels * (precision / 8)
 38          batch_memory = batch_size * img_memory

 40          if batch_memory > total_ram:  # 40 ↛ 41: didn't jump to line 41 (condition was never true)
 41              raise ValueError("Batch size too large for available RAM. Reduce the batch size or image dimensions.")

 43          max_batches_in_ram = floor(total_ram / batch_memory)

 45          prefetch_factor = min(max_batches_in_ram, 16)
 46          num_workers = min(floor(prefetch_factor / 2), os.cpu_count())

 48          params = {"num_workers": num_workers, "prefetch_factor": prefetch_factor, "pin_memory": config["device"] == "gpu", "persistent_workers": True}

 50      else:
 51          params = {
 52              "num_workers": config["training"]["dataloader"]["num_workers"],
 53              "prefetch_factor": config["training"]["dataloader"]["prefetch_factor"],
 54              "pin_memory": config["training"]["dataloader"]["pin_memory"],
 55              "persistent_workers": config["training"]["dataloader"]["persistent_workers"],
 56          }

 58      return params
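A minimal usage sketch (the dataset variable is a placeholder, and the resulting numbers depend on the machine's free RAM and core count). One caveat worth knowing: PyTorch's DataLoader rejects a non-default prefetch_factor, and persistent_workers=True, when num_workers is 0, which this heuristic yields whenever only one batch fits in the RAM budget:

    # Hypothetical: with ample free RAM, prefetch_factor caps at 16 and
    # num_workers at min(8, os.cpu_count()).
    params = calculate_dataloader_params(batch_size=32)
    loader = DataLoader(my_train_dataset, batch_size=32, shuffle=True, **params)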

 61  class ForestDataset(Dataset):
 62      def __init__(self, image_paths, labels, transform=None):
 63          self.image_paths = image_paths
 64          self.labels = labels

 66          if transform is None:
 67              self.transform = transforms.Compose(
 68                  [
 69                      transforms.Resize((224, 224)),
 70                      transforms.ToTensor(),
 71                      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
 72                  ]
 73              )
 74          else:
 75              self.transform = transform

 77      def __len__(self):
 78          return len(self.image_paths)

 80      def __getitem__(self, idx):
 81          image_path = self.image_paths[idx]
 82          label = self.labels[idx]

 84          with Image.open(image_path) as img:
 85              # Remove the "near-infrared" channel if present (4-channel RGBA)
 86              if img.mode == "RGBA":  # 86 ↛ 88: didn't jump to line 88 (condition was never true)
 87                  # Convert RGBA to RGB (drops the alpha/near-infrared channel)
 88                  image = img.convert("RGB")
 89              else:
 90                  # Keep as a PIL Image for the transforms
 91                  image = img.copy()

 93          # Apply transformations (expects a PIL Image)
 94          if self.transform:  # 94 ↛ 97: didn't jump to line 97 (condition was always true)
 95              image = self.transform(image)

 97          return image, label
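The default pipeline resizes to 224x224, converts to a tensor, and normalizes with the standard ImageNet statistics. A minimal construction (the path and label are placeholders):

    ds = ForestDataset(image_paths=["data/sample.png"], labels=[0])
    image, label = ds[0]  # for an RGB input: float tensor of shape (3, 224, 224)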

100  class UndersampledDataset(ForestDataset):
101      def __init__(self, image_paths, labels, transform=None, target_size=None):
102          super().__init__(image_paths, labels, transform)

104          class_indices = {}
105          for idx, label in enumerate(labels):
106              class_indices.setdefault(label, []).append(idx)

108          # Find the smallest class size
109          min_count = min(len(indices) for indices in class_indices.values())

111          # If no target_size is provided, fall back to the smallest class size
112          target_size = target_size if target_size else min_count

114          self.sampled_indices = []
115          for indices in class_indices.values():
116              # Cap each class at target_size images (sampling without replacement)
117              self.sampled_indices.extend(random.sample(indices, min(target_size, len(indices))))

119      def __len__(self):
120          return len(self.sampled_indices)

122      def __getitem__(self, idx):
123          return super().__getitem__(self.sampled_indices[idx])
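A sketch of the balancing behaviour under an assumed imbalanced split (paths and labels are placeholders):

    # Hypothetical counts: 500 images of class 0, 120 of class 1
    balanced = UndersampledDataset(paths, labels)                 # keeps 120 per class
    capped = UndersampledDataset(paths, labels, target_size=200)  # caps classes at 200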

126  class OversampledDataset(ForestDataset):
127      def __init__(self, image_paths, labels, transform=None, minority_transform=None, oversample_factor=2, oversample_threshold=200):
128          super().__init__(image_paths, labels, transform)
129          self.minority_transform = minority_transform

131          class_indices = {}
132          for idx, label in enumerate(labels):
133              class_indices.setdefault(label, []).append(idx)

135          self.to_transform = set()
136          self.sampled_indices = []
137          for label, indices in class_indices.items():
138              if len(indices) <= oversample_threshold:
139                  self.to_transform.add(label)
140                  # Sample the minority class with replacement
141                  self.sampled_indices.extend(random.choices(indices, k=int(oversample_factor * len(indices))))
142              else:
143                  self.sampled_indices.extend(indices)

145      def __len__(self):
146          return len(self.sampled_indices)

148      def __getitem__(self, idx):
149          image, label = super().__getitem__(self.sampled_indices[idx])
150          if label in self.to_transform and self.minority_transform:  # 150 ↛ 151: didn't jump to line 151 (condition was never true)
151              image = self.minority_transform(image)
152          return image, label
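Because minority_transform runs after the base transform, it receives a normalized tensor rather than a PIL Image, so it should be tensor-compatible. A sketch (the flip augmentation is an assumption, not something this module prescribes; paths and labels are placeholders):

    aug = transforms.RandomHorizontalFlip(p=0.5)  # operates on tensors as well as PIL Images
    train_ds = OversampledDataset(paths, labels, minority_transform=aug,
                                  oversample_factor=2, oversample_threshold=200)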

155  class CurriculumLearningDataset(ForestDataset):
156      def __init__(self, image_paths, labels, indices, transform=None):
157          super().__init__(image_paths, labels, transform)
158          self.indices = indices

160      def __len__(self):
161          return len(self.indices)

163      def __getitem__(self, idx):
164          return super().__getitem__(self.indices[idx])
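The subset and its order come entirely from the caller-supplied indices; a sketch in which hypothetical per-sample difficulty scores drive an easy-first stage:

    easy_first = sorted(range(len(labels)), key=lambda i: difficulty[i])  # difficulty: hypothetical scores
    stage1 = CurriculumLearningDataset(paths, labels, indices=easy_first[: len(easy_first) // 2])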

167  class ForestDataModule(pl.LightningDataModule):
168      def __init__(self, train_data, val_data, test_data, dataset, dataset_args=None, batch_size=32):
169          if dataset_args is None:  # 169 ↛ 172: didn't jump to line 172 (condition was always true)
170              dataset_args = {}

172          super().__init__()
173          self.test_dataset = None
174          self.train_dataset = None
175          self.val_dataset = None
176          self.train_data = train_data
177          self.val_data = val_data
178          self.test_data = test_data
179          self.dataset = dataset
180          self.dataset_args = dataset_args
181          self.batch_size = batch_size
182          self.params = calculate_dataloader_params(batch_size)

184      def setup(self, stage=None):
185          self.train_dataset = self.dataset(image_paths=self.train_data["paths"], labels=self.train_data["labels"], **self.dataset_args)
186          self.val_dataset = ForestDataset(image_paths=self.val_data["paths"], labels=self.val_data["labels"])
187          self.test_dataset = ForestDataset(image_paths=self.test_data["paths"], labels=self.test_data["labels"])

189      def train_dataloader(self):
190          return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, **self.params)

192      def val_dataloader(self):
193          return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.params)

195      def test_dataloader(self):
196          return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.params)

199  if __name__ == "__main__":  # 199 ↛ 200: didn't jump to line 200 (condition was never true)
200      params = calculate_dataloader_params(32)
201      print(params)
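A hedged end-to-end sketch of wiring the module into training (the split dicts, model, and epoch count are placeholders; each split dict holds "paths" and "labels" lists, as setup() expects):

    dm = ForestDataModule(train_data, val_data, test_data,
                          dataset=OversampledDataset,
                          dataset_args={"oversample_factor": 2},
                          batch_size=32)
    trainer = pl.Trainer(max_epochs=10)
    trainer.fit(model, datamodule=dm)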