Coverage for src/dataset.py: 89%

114 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-01 17:37 +0000

1import os 

2import yaml 

3import psutil 

4import numpy as np 

5from PIL import Image 

6from math import floor 

7import pytorch_lightning as pl 

8from transforms import Preprocess 

9import torchvision.transforms as transforms 

10from torch.utils.data import Dataset, DataLoader 

11import random 

12 

13 

# Load the project configuration once, at import time. The path is relative,
# so it is resolved against the current working directory of the process.
with open("src/config.yaml", "r") as config_file:
    config = yaml.safe_load(config_file)

16 

17 

def calculate_dataloader_params(batch_size, img_size=(224, 224), image_channels=3, precision=32, ram_fraction=0.8):
    """
    Function builds the keyword arguments that are passed to DataLoader.

    When config["training"]["dataloader"]["auto"] is truthy, the number of
    workers and the prefetch factor are derived from the RAM that is
    currently available. Otherwise all four values are read unchanged from
    config["training"]["dataloader"].

    Input:
        batch_size: int - the batch size used in DataLoader
        img_size: tuple of int - the spatial size of one image, e.g. (224, 224)
        image_channels: int - the number of channels in the image
        precision: int - the number of bits per stored value (32 -> 4 bytes)
        ram_fraction: float - the fraction of the available RAM to use
    Output:
        dict of params: num_workers, prefetch_factor, pin_memory, persistent_workers
        num_workers: int - the number of workers
        prefetch_factor: int - the prefetch factor
        pin_memory: bool - whether to use pin_memory
        persistent_workers: bool - whether to use persistent workers
    Raises:
        ValueError - (auto mode only) if the estimated size of one batch is
        not positive, or if one batch does not fit into the RAM budget
    """
    dataloader_config = config["training"]["dataloader"]

    if not dataloader_config["auto"]:
        # Manual mode: hand the configured values through untouched.
        return {
            "num_workers": dataloader_config["num_workers"],
            "prefetch_factor": dataloader_config["prefetch_factor"],
            "pin_memory": dataloader_config["pin_memory"],
            "persistent_workers": dataloader_config["persistent_workers"],
        }

    total_ram = psutil.virtual_memory().available * ram_fraction
    # precision is given in bits, hence the division by 8 to get bytes.
    img_memory = np.prod(img_size) * image_channels * (precision / 8)
    batch_memory = batch_size * img_memory

    if batch_memory <= 0:
        # Guards the division below against a zero or negative batch size.
        raise ValueError("Estimated batch memory must be positive. Check the batch size, image dimensions and precision.")

    if batch_memory > total_ram:
        raise ValueError("Batch size too large for available RAM. Reduce the batch size or image dimensions.")

    max_batches_in_ram = floor(total_ram / batch_memory)
    prefetch_factor = min(max_batches_in_ram, 16)

    # os.cpu_count() returns None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1
    # At least one worker is required: with num_workers == 0 DataLoader
    # rejects persistent_workers=True and an explicit prefetch_factor.
    num_workers = max(1, min(floor(prefetch_factor / 2), cpu_count))

    return {
        "num_workers": num_workers,
        "prefetch_factor": prefetch_factor,
        "pin_memory": config["device"] == "gpu",
        "persistent_workers": True,
    }

60 

61 

class ForestDataset(Dataset):
    """Dataset that loads an image from a path and returns it with its label."""

    def __init__(self, image_paths, labels, transform=None):
        """
        Input:
            image_paths: sequence of image paths, indexed by sample
            labels: sequence of labels, indexed in parallel with image_paths
            transform: callable applied to the loaded image array; when it is
                omitted (or falsy) a default of ToTensor followed by
                Normalize(mean=0.5, std=0.5) on three channels is used
        """
        self.image_paths = image_paths
        self.labels = labels
        # Define a default transform if none is provided
        # TODO: Use transforms suitable for the model
        self.transform = transform or transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # Adjust as needed for RGB channels
            ]
        )

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.image_paths)

    @staticmethod
    def _drop_extra_channel(image):
        """
        Return the image without its first channel when it has four channels.

        The original authors note that the first channel is the
        "near-infrared" one and that PIL conversion to RGB kept it, which was
        not desired. Arrays that are not 3-D (e.g. grayscale images) are
        returned unchanged, so a 2-D image whose width happens to be 4 is not
        indexed with three axes.
        """
        if image.ndim == 3 and image.shape[-1] == 4:
            return image[:, :, 1:]
        return image

    def __getitem__(self, idx):
        """
        Load the image at position idx and return (image, label).

        The image is read with PIL, converted to a numpy array, stripped of
        its extra channel if it has four, and then passed through
        self.transform.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        with Image.open(image_path) as img:
            # Convert to numpy array
            image = np.array(img)

        image = self._drop_extra_channel(image)

        # Apply transformations
        if self.transform:
            image = self.transform(image)

        return image, label

95 

96 

class UndersampledDataset(ForestDataset):
    """ForestDataset that exposes at most target_size randomly chosen samples per class."""

    def __init__(self, image_paths, labels, transform=None, target_size=None):
        """
        Input:
            image_paths: sequence of image paths
            labels: sequence of hashable labels, parallel to image_paths
            transform: forwarded to ForestDataset
            target_size: int - maximum number of samples kept per class; when
                it is None (or any other falsy value such as 0) the size of
                the smallest class is used
        Raises:
            ValueError - if target_size is negative
        """
        super().__init__(image_paths, labels, transform)

        # Group sample positions by label.
        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        # Find the minimum number of samples in a class.
        # default=0 keeps an empty label list from raising inside min().
        min_count = min((len(indices) for indices in class_indices.values()), default=0)

        # If the target_size is not provided, set it to the minimum count
        target_size = target_size if target_size else min_count
        if target_size < 0:
            raise ValueError("target_size must not be negative.")

        self.sampled_indices = []
        for indices in class_indices.values():
            # Limit the number of images per class to target_size (if it exceeds the target_size)
            self.sampled_indices.extend(random.sample(indices, min(target_size, len(indices))))

    def __len__(self):
        """Return the number of samples kept after undersampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return the (image, label) pair of the idx-th kept sample."""
        return super().__getitem__(self.sampled_indices[idx])

121 

122 

class OversampledDataset(ForestDataset):
    """ForestDataset that repeats samples of small classes."""

    def __init__(self, image_paths, labels, transform=None, minority_transform=None, oversample_factor=2, oversample_threshold=200):
        """
        Input:
            image_paths: sequence of image paths
            labels: sequence of hashable labels, parallel to image_paths
            transform: forwarded to ForestDataset
            minority_transform: optional callable applied, after transform, to
                images whose label belongs to an oversampled class
            oversample_factor: number - a class with n samples is represented
                by int(oversample_factor * n) entries
            oversample_threshold: int - classes with at most this many samples
                are oversampled; larger classes are kept as they are
        """
        super().__init__(image_paths, labels, transform)
        self.minority_transform = minority_transform

        # Group sample positions by label.
        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        self.to_transform = set()
        self.sampled_indices = []
        for label, indices in class_indices.items():
            if len(indices) <= oversample_threshold:
                self.to_transform.add(label)
                target_count = int(oversample_factor * len(indices))
                if target_count >= len(indices):
                    # Keep every original sample once and draw only the
                    # surplus with replacement. Drawing the whole class with
                    # replacement could leave some originals out entirely.
                    self.sampled_indices.extend(indices)
                    self.sampled_indices.extend(random.choices(indices, k=target_count - len(indices)))
                else:
                    # A factor below 1 asks for fewer entries than samples;
                    # sample with replacement as before.
                    self.sampled_indices.extend(random.choices(indices, k=target_count))
            else:
                self.sampled_indices.extend(indices)

    def __len__(self):
        """Return the number of entries after oversampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """
        Return the (image, label) pair of the idx-th entry.

        minority_transform, when set, is applied on top of the regular
        transform for labels recorded in to_transform.
        """
        image, label = super().__getitem__(self.sampled_indices[idx])
        if label in self.to_transform and self.minority_transform:
            image = self.minority_transform(image)
        return image, label

150 

151 

class CurriculumLearningDataset(ForestDataset):
    """ForestDataset restricted to, and ordered by, a caller-supplied list of positions."""

    def __init__(self, image_paths, labels, indices, transform=None):
        """
        Input:
            image_paths: sequence of image paths
            labels: sequence of labels, parallel to image_paths
            indices: sequence of positions into image_paths / labels that
                defines which samples are exposed and in which order
            transform: forwarded to ForestDataset
        """
        super().__init__(image_paths=image_paths, labels=labels, transform=transform)
        self.indices = indices

    def __len__(self):
        """Return how many positions were supplied."""
        subset_size = len(self.indices)
        return subset_size

    def __getitem__(self, idx):
        """Return the (image, label) pair stored at the idx-th supplied position."""
        source_idx = self.indices[idx]
        return super().__getitem__(source_idx)

162 

163 

class ForestDataModule(pl.LightningDataModule):
    """LightningDataModule that builds the train, validation and test loaders."""

    def __init__(self, train_data, val_data, test_data, dataset, dataset_args=None, batch_size=32):
        """
        Input:
            train_data, val_data, test_data: mappings with the keys "paths"
                and "labels"
            dataset: callable used to build the training dataset; it is
                called with image_paths, labels, transform and dataset_args
            dataset_args: dict - extra keyword arguments for dataset
                (defaults to an empty dict)
            batch_size: int - the batch size used by every DataLoader
        """
        super().__init__()

        # Raw splits and the recipe for the training dataset.
        self.train_data, self.val_data, self.test_data = train_data, val_data, test_data
        self.dataset = dataset
        self.dataset_args = {} if dataset_args is None else dataset_args
        self.batch_size = batch_size

        # The datasets themselves are created in setup().
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

        # The same DataLoader keyword arguments are reused for all three loaders.
        self.params = calculate_dataloader_params(batch_size)

    def setup(self, stage=None):
        """Create the three datasets; stage is accepted but not used."""
        train_split = self.train_data
        self.train_dataset = self.dataset(
            image_paths=train_split["paths"],
            labels=train_split["labels"],
            transform=Preprocess(),
            **self.dataset_args,
        )

        # Validation and test always use the plain ForestDataset.
        val_split = self.val_data
        self.val_dataset = ForestDataset(
            image_paths=val_split["paths"],
            labels=val_split["labels"],
            transform=Preprocess(),
        )

        test_split = self.test_data
        self.test_dataset = ForestDataset(
            image_paths=test_split["paths"],
            labels=test_split["labels"],
            transform=Preprocess(),
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training dataset."""
        loader_kwargs = dict(self.params)
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, **loader_kwargs)

    def val_dataloader(self):
        """Return the loader over the validation dataset."""
        loader_kwargs = dict(self.params)
        return DataLoader(self.val_dataset, batch_size=self.batch_size, **loader_kwargs)

    def test_dataloader(self):
        """Return the loader over the test dataset."""
        loader_kwargs = dict(self.params)
        return DataLoader(self.test_dataset, batch_size=self.batch_size, **loader_kwargs)

194 

195 

if __name__ == "__main__":
    # Manual check: print the DataLoader parameters computed for a batch size of 32.
    params = calculate_dataloader_params(batch_size=32)
    print(params)
198 print(params)