Coverage for src/dataset.py: 92%

104 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-16 12:50 +0000

1import os 

2import yaml 

3import psutil 

4import numpy as np 

5from PIL import Image 

6from math import floor 

7import pytorch_lightning as pl 

8from transforms import Preprocess 

9import torchvision.transforms as transforms 

10from torch.utils.data import Dataset, DataLoader 

11import random 

12 

13 

# Load the project configuration once, at import time. The helpers below read
# their settings from the resulting module-level ``config`` mapping.
# NOTE: the path is relative to the current working directory, so importing
# this module only works when the process is started from the repository root.
# The encoding is pinned so parsing does not depend on the platform default.
with open("src/config.yaml", "r", encoding="utf-8") as c:
    config = yaml.safe_load(c)

16 

17 

def calculate_dataloader_params(batch_size, img_size=(224, 224), image_channels=3, precision=32, ram_fraction=0.8):
    """
    Function calculates the number of workers and prefetch factor
    for DataLoader based on the available RAM.

    When ``config['training']['dataloader']['auto']`` is falsy, nothing is
    calculated: the four values are read from the same config section and
    the arguments of this function are ignored.

    Input:
        batch_size: int - the batch size used in DataLoader
        img_size: tuple of int (or int) - the spatial size of the image
        image_channels: int - the number of channels in the image
        precision: int - the number of bits per stored value (32 -> 4 bytes)
        ram_fraction: float - the fraction of the currently available RAM to use
    Output:
        dict of params: num_workers, prefetch_factor, pin_memory, persistent_workers
        num_workers: int - the number of workers (at least 1 in auto mode)
        prefetch_factor: int - the prefetch factor (at most 16 in auto mode)
        pin_memory: bool - whether to use pin_memory
        persistent_workers: bool - whether to use persistent workers
    Raises:
        ValueError - if batch_size is not positive, or if a single batch does
                     not fit into the RAM budget (auto mode only)
    """
    dataloader_config = config['training']['dataloader']

    if not dataloader_config['auto']:
        # Manual mode: hand the configured values through unchanged.
        return {"num_workers": dataloader_config['num_workers'],
                "prefetch_factor": dataloader_config['prefetch_factor'],
                "pin_memory": dataloader_config['pin_memory'],
                "persistent_workers": dataloader_config['persistent_workers']}

    if batch_size <= 0:
        # Guard against a division by zero / negative sizes further down.
        raise ValueError("Batch size must be a positive integer.")

    # RAM budget in bytes and the size of one batch in bytes.
    total_ram = psutil.virtual_memory().available * ram_fraction
    img_memory = np.prod(img_size) * image_channels * (precision / 8)
    batch_memory = batch_size * img_memory

    if batch_memory > total_ram:
        raise ValueError("Batch size too large for available RAM. Reduce the batch size or image dimensions.")

    # batch_memory <= total_ram here, so at least one batch always fits.
    max_batches_in_ram = floor(total_ram / batch_memory)

    prefetch_factor = min(max_batches_in_ram, 16)

    # os.cpu_count() may return None when the count cannot be determined.
    cpu_count = os.cpu_count() or 1

    # Keep at least one worker: DataLoader rejects ``prefetch_factor`` and
    # ``persistent_workers=True`` when ``num_workers`` is 0, which the plain
    # floor(prefetch_factor / 2) would produce when only one batch fits.
    num_workers = max(1, min(floor(prefetch_factor / 2), cpu_count))

    return {"num_workers": num_workers,
            "prefetch_factor": prefetch_factor,
            "pin_memory": config['device'] == 'gpu',
            "persistent_workers": True}

61 

62 

class ForestDataset(Dataset):
    """
    Dataset that loads an image from disk and returns it with its label.

    Input:
        image_paths: sequence of image paths, indexed by sample position
        labels: sequence of labels, aligned by index with image_paths
        transform: callable applied to the loaded numpy image; when None,
                   a default ToTensor + Normalize(mean=0.5, std=0.5) pipeline
                   is used
    """

    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        # Define a default transform if none is provided.
        # ``is None`` (rather than truthiness) so that a falsy but valid
        # transform object is not silently replaced by the default.
        # TODO: Use transforms suitable for the model
        if transform is None:
            transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Normalize(
                        mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
                    ),  # Adjust as needed for RGB channels
                ]
            )
        self.transform = transform

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """
        Load the image at position ``idx`` and return ``(image, label)``.

        The image is read with PIL, converted to a numpy array, reduced to
        three channels when it has four, and then passed through the transform.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        with Image.open(image_path) as img:
            # Convert to numpy array
            image = np.array(img)

        # Drop the first of four channels (the "near-infrared" channel
        # according to the original authors, who found that PIL's conversion
        # to RGB kept it). The ndim check keeps 2-D images that happen to be
        # 4 pixels wide from being sliced on an axis they do not have.
        if image.ndim == 3 and image.shape[-1] == 4:
            image = image[:, :, 1:]

        # Apply transformations
        if self.transform is not None:
            image = self.transform(image)

        return image, label

98 

99 

class UndersampledDataset(ForestDataset):
    """
    ForestDataset variant that keeps at most ``target_size`` randomly chosen
    samples per class.

    Input:
        image_paths: sequence of image paths
        labels: sequence of hashable labels, aligned with image_paths
        transform: forwarded to ForestDataset
        target_size: int - maximum number of samples kept per class; when
                     falsy (None or 0) the size of the smallest class is used
    """

    def __init__(self, image_paths, labels, transform=None, target_size=None):
        super().__init__(image_paths, labels, transform)

        # Group sample positions by their label.
        class_indices = {}
        for idx, label in enumerate(labels):
            class_indices.setdefault(label, []).append(idx)

        # Find the minimum number of samples in a class.
        # ``default=0`` keeps an empty label list from raising in min();
        # the result is simply an empty dataset.
        min_count = min((len(indices) for indices in class_indices.values()), default=0)

        # If the target_size is not provided, set it to the minimum count
        target_size = target_size if target_size else min_count

        self.sampled_indices = []
        for indices in class_indices.values():
            # Limit the number of images per class to target_size (if it exceeds the target_size)
            self.sampled_indices.extend(random.sample(indices, min(target_size, len(indices))))

    def __len__(self):
        """Return the number of samples kept after undersampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair of the ``idx``-th kept sample."""
        return super().__getitem__(self.sampled_indices[idx])

124 

125 

class OversampledDataset(ForestDataset):
    """
    ForestDataset variant that re-draws small classes with replacement.

    A class counts as a minority class when it has at most
    ``oversample_threshold`` samples. Such a class contributes
    ``int(oversample_factor * class_size)`` randomly drawn positions; every
    other class contributes its positions unchanged. Samples whose label
    belongs to a minority class additionally pass through
    ``minority_transform`` when one is given.
    """

    def __init__(self, image_paths, labels, transform=None, minority_transform=None, oversample_factor=2,
                 oversample_threshold=200):
        super().__init__(image_paths, labels, transform)
        self.minority_transform = minority_transform

        # Map each label to the list of positions where it occurs.
        indices_by_class = {}
        for position, class_label in enumerate(labels):
            if class_label not in indices_by_class:
                indices_by_class[class_label] = []
            indices_by_class[class_label].append(position)

        self.to_transform = set()
        self.sampled_indices = []
        for class_label, members in indices_by_class.items():
            if len(members) > oversample_threshold:
                # Large enough class: keep every sample exactly once.
                self.sampled_indices += members
                continue

            # Minority class: remember the label and sample with replacement.
            self.to_transform.add(class_label)
            draw_count = int(oversample_factor * len(members))
            self.sampled_indices += random.choices(members, k=draw_count)

    def __len__(self):
        """Return the number of positions after oversampling."""
        return len(self.sampled_indices)

    def __getitem__(self, idx):
        """Return ``(image, label)``, applying the minority transform if applicable."""
        source_index = self.sampled_indices[idx]
        image, label = super().__getitem__(source_index)

        needs_extra_transform = label in self.to_transform and self.minority_transform
        if needs_extra_transform:
            image = self.minority_transform(image)

        return image, label

154 

155 

class ForestDataModule(pl.LightningDataModule):
    """
    LightningDataModule that builds the train / validation / test datasets
    and their DataLoaders.

    Input:
        train_data, val_data, test_data: mappings with the keys "paths" and
                                         "labels"
        dataset: dataset class (or factory) used for the training split; it is
                 called with image_paths, labels, transform and dataset_args
        dataset_args: dict of extra keyword arguments for ``dataset``
                      (None means no extra arguments)
        batch_size: int - the batch size used by all three DataLoaders
    """

    def __init__(self, train_data, val_data, test_data, dataset, dataset_args=None, batch_size=32):
        super().__init__()
        self.test_dataset = None
        self.train_dataset = None
        self.val_dataset = None
        self.train_data = train_data
        self.val_data = val_data
        self.test_data = test_data
        self.dataset = dataset
        # A fresh dict per instance: a ``{}`` default in the signature would be
        # one shared object across every ForestDataModule.
        self.dataset_args = {} if dataset_args is None else dataset_args
        self.batch_size = batch_size
        # DataLoader keyword arguments shared by all three loaders.
        self.params = calculate_dataloader_params(batch_size)

    def setup(self, stage=None):
        """Instantiate the datasets; ``stage`` is accepted but not used."""
        # Only the training split uses the configurable dataset class.
        self.train_dataset = self.dataset(
            image_paths=self.train_data["paths"],
            labels=self.train_data["labels"],
            transform=Preprocess(),
            **self.dataset_args
        )
        # Validation and test always use the plain ForestDataset.
        self.val_dataset = ForestDataset(
            image_paths=self.val_data["paths"],
            labels=self.val_data["labels"],
            transform=Preprocess()
        )
        self.test_dataset = ForestDataset(
            image_paths=self.test_data["paths"],
            labels=self.test_data["labels"],
            transform=Preprocess()
        )

    def train_dataloader(self):
        """Return the shuffled DataLoader over the training dataset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, **self.params)

    def val_dataloader(self):
        """Return the DataLoader over the validation dataset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, **self.params)

    def test_dataloader(self):
        """Return the DataLoader over the test dataset."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, **self.params)

196 

197 

if __name__ == '__main__':
    # Manual check: print the DataLoader parameters for a batch size of 32.
    print(calculate_dataloader_params(32))