Coverage for src/dataset_functions.py: 71%

78 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-04-16 12:50 +0000

1import zipfile 

2from pathlib import Path 

3 

4import numpy as np 

5from huggingface_hub import hf_hub_download 

6 

7from typing import List, Dict, Optional 

8 

9 

def print_extracted_files(extract_dir: Path):
    """Print a short summary of the files extracted into *extract_dir*.

    Shows up to five entry stems, followed by an "... and N more files"
    line when the directory holds more than five entries.
    """
    print(f"Successfully extracted to {extract_dir}")

    # Materialize the iterator exactly once: iterdir() is a one-shot
    # generator, so the original code's repeated list(extracted_files)
    # calls after a partial consumption always saw 0 entries and the
    # "... and N more" branch could never fire.
    extracted_files = list(Path(extract_dir).iterdir())
    print("Extracted files:")
    for extracted_file in extracted_files[:5]:
        print(f"- {extracted_file.stem}")
    if len(extracted_files) > 5:
        print(f"... and {len(extracted_files) - 5} more files")

19 

20 

def extract_files(file_path: str, extract_dir: Path, main_subfolders: Dict):
    """Extract the aerial-imagery members of a zip archive into *extract_dir*.

    Member paths are re-rooted: the ``main_subfolders["aerial_imagery"]``
    prefix is stripped so files land directly under *extract_dir*.

    Raises
    ------
    zipfile.BadZipFile
        If *file_path* is not a valid zip archive.
    """
    # NOTE: we assume aerial imagery data here. If needed, a small selector
    # could choose between the aerial and LiDAR subfolders instead.
    aerial_root = main_subfolders["aerial_imagery"]

    with zipfile.ZipFile(file_path, "r") as zip_ref:
        for member in zip_ref.namelist():
            # Directory entries carry no payload; writing them as zero-byte
            # files would later block mkdir() for their children — skip them.
            if member.endswith("/"):
                continue

            try:
                relative = Path(member).relative_to(aerial_root)
            except ValueError:
                # Member lives outside the aerial-imagery subtree; skip it
                # rather than aborting the whole extraction.
                continue

            target_path = extract_dir / relative

            # Create parent directories on demand.
            target_path.parent.mkdir(parents=True, exist_ok=True)
            target_path.write_bytes(zip_ref.read(member))

    print_extracted_files(extract_dir)

43 

44 

def download_data(species_folders: Dict, main_subfolders: Dict, dataset_folder: Path):
    """
    Function downloads specified data from HF (PureForest dataset).

    Parameters
    ----------
    species_folders : Dict
        Maps a local folder name to the zip filename inside the HF repo.
    main_subfolders : Dict
        Passed through to ``extract_files`` (archive path prefixes).
    dataset_folder : Path
        Local root under which one directory per species is created.
    """
    for folder_name, repo_filename in species_folders.items():
        print(f"\nProcessing {repo_filename}...")

        # Download file (hf_hub_download caches locally and returns the path).
        file_path = hf_hub_download(
            repo_id="IGNF/PureForest",
            filename=repo_filename,
            repo_type="dataset",
        )

        extract_dir = dataset_folder / folder_name
        extract_dir.mkdir(exist_ok=True, parents=True)

        try:
            extract_files(file_path, extract_dir, main_subfolders)
        except zipfile.BadZipFile:
            # Bug fix: the message previously printed a literal placeholder
            # instead of naming the offending archive.
            print(f"Error: {repo_filename} is not a valid zip file")

67 

68 

def load_dataset(main_dir: Dict, species_folders: Dict, splits: Optional[List[str]] = None):
    """Scan per-species split directories for TIFF images and label them.

    Species are merged into coarser groups (e.g. every ``Quercus_*`` oak
    becomes ``Deciduous_oak``), and each group gets an integer label.

    Returns a ``(dataset, label_map)`` pair where ``dataset`` maps each
    split to ``{"labels": [...], "paths": [...]}``.
    NOTE: downstream code relies on the "paths" key — keep it.
    """
    # NOTE(review): main_dir is annotated Dict but is used as a Path below —
    # the annotation looks wrong, left untouched to preserve the signature.
    if splits is None:
        splits = ["train", "val", "test"]
    dataset: Dict = {s: {"labels": [], "paths": []} for s in splits}

    # Species -> merged group name.
    merged_labels = {
        "Quercus_petraea": "Deciduous_oak",
        "Quercus_pubescens": "Deciduous_oak",
        "Quercus_robur": "Deciduous_oak",
        "Quercus_rubra": "Deciduous_oak",
        "Quercus_ilex": "Evergreen_oak",
        "Fagus_sylvatica": "Beech",
        "Castanea_sativa": "Chestnut",
        "Robinia_pseudoacacia": "Black_locust",
        "Pinus_pinaster": "Maritime_pine",
        "Pinus_sylvestris": "Scotch_pine",
        "Pinus_nigra_laricio": "Black_pine",
        "Pinus_nigra": "Black_pine",
        "Pinus_halepensis": "Aleppo pine",
        "Abies_alba": "Fir",
        "Abies_nordmanniana": "Fir",
        "Picea_abies": "Spruce",
        "Larix_decidua": "Larch",
        "Pseudotsuga_menziesii": "Douglas",
    }

    # Keep only the species actually requested (classes from config.yaml).
    available_labels = {
        species: group
        for species, group in merged_labels.items()
        if species in species_folders
    }

    # Deterministic integer ids: sort the merged group names first.
    unique_labels = sorted(set(available_labels.values()))
    label_map = {label: idx for idx, label in enumerate(unique_labels)}
    print("Label mapping:", label_map)

    # Derive the on-disk folder names from the repo zip filenames,
    # e.g. "data/imagery-Fagus_sylvatica.zip" -> "Fagus_sylvatica".
    base_dirs = [
        zip_name.replace("data/imagery-", "").replace(".zip", "")
        for zip_name in species_folders.values()
    ]

    # Walk every species folder and every split, collecting image paths.
    for species_dir in base_dirs:
        group = available_labels.get(species_dir, None)
        if group is None:
            continue

        label = label_map[group]

        for split in splits:
            split_dir = main_dir / species_dir / split
            if not split_dir.exists():
                print(f"Warning: {split_dir} does not exist")
                continue

            # Both .tiff and .tif extensions occur in the dataset.
            tiff_files = list(split_dir.glob("*.tiff")) + list(split_dir.glob("*.tif"))

            print(f"Loading {len(tiff_files)} images from {split_dir}")

            for tiff_path in tiff_files:
                dataset[split]["labels"].append(label)
                dataset[split]["paths"].append(tiff_path)

    # Round-trip through numpy, then back to a list: the result is a plain
    # list of numpy integer scalars (NOT an ndarray) — kept as-is because
    # callers may rely on the list type.
    for split in splits:
        dataset[split]["labels"] = list(np.array(dataset[split]["labels"]))

    return dataset, label_map

139 

140 

def clip_balanced_dataset(dataset: Dict):
    """Downsample every split so all classes have the same sample count.

    For each split, the class with the fewest samples sets the budget;
    every other class is randomly subsampled (without replacement) down to
    that count. Splits with no paths are dropped from the result.

    Returns a dict mapping split -> {"labels": np.ndarray, "paths": list}.
    Note: selection uses np.random.choice, so results depend on the global
    numpy RNG state.
    """
    clipped_dataset: Dict = {}
    for split in dataset.keys():
        paths = dataset[split]["paths"]
        if len(paths) == 0:
            continue

        # Bug fix: coerce labels to an ndarray once. The original compared
        # a plain Python list against a scalar (`labels == label`), which
        # yields False instead of an element-wise mask, leaving np.where
        # with no indices; list fancy-indexing also raised TypeError.
        labels = np.asarray(dataset[split]["labels"])

        # Identify the minimum class count for this split.
        unique_labels, label_counts = np.unique(labels, return_counts=True)
        min_class_count = min(label_counts)

        # Prepare clipped data.
        labels_clipped = []
        paths_clipped = []

        for label in unique_labels:
            # Find indices of images with the current label.
            indices = np.where(labels == label)[0]

            # Randomly select min_class_count indices from these.
            selected_indices = np.random.choice(indices, min_class_count, replace=False)

            # Append selected samples to clipped data lists.
            labels_clipped.extend(labels[selected_indices])
            paths_clipped.extend([paths[i] for i in selected_indices])

        # Convert to numpy arrays (paths stay a Python list).
        clipped_dataset[split] = {
            "labels": np.array(labels_clipped),
            "paths": paths_clipped,
        }

    return clipped_dataset