Coverage for src/dataset_functions.py: 71% (78 statements)
import zipfile
from pathlib import Path
from typing import List, Dict, Optional

import numpy as np
from huggingface_hub import hf_hub_download
def print_extracted_files(extract_dir: Path):
    print(f"Successfully extracted to {extract_dir}")

    # Materialise the iterator once; otherwise the first use exhausts it
    # and the "more files" branch can never fire
    extracted_files = list(Path(extract_dir).iterdir())
    print("Extracted files:")
    for extracted_file in extracted_files[:5]:
        print(f"- {extracted_file.stem}")
    if len(extracted_files) > 5:
        print(f"... and {len(extracted_files) - 5} more files")
def extract_files(file_path: str, extract_dir: Path, main_subfolders: Dict):
    with zipfile.ZipFile(file_path, "r") as zip_ref:
        # Get list of all files in the zip
        image_file_list = zip_ref.namelist()

        # Extract all files, modifying their paths
        for image_file in image_file_list:
            # Skip directory entries; only regular files are written out
            if image_file.endswith("/"):
                continue

            # Read the file contents from the archive
            source = zip_ref.read(image_file)

            # We assume aerial imagery data here. If needed, a simple
            # function can be written that chooses either aerial or
            # LiDAR data.
            target_path = extract_dir / Path(image_file).relative_to(main_subfolders["aerial_imagery"])

            # Create directories if they don't exist
            target_path.parent.mkdir(parents=True, exist_ok=True)

            with open(target_path, "wb") as f:
                f.write(source)

    print_extracted_files(extract_dir)
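# Worked example of the path rewriting above (the archive's internal layout
# is an assumption, inferred from the "aerial_imagery" prefix):
#   archive member : data/imagery/Quercus_robur/train/tile_0001.tif
#   main_subfolders: {"aerial_imagery": "data/imagery"}
#   target_path    : <extract_dir>/Quercus_robur/train/tile_0001.tif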
def download_data(species_folders: Dict, main_subfolders: Dict, dataset_folder: Path):
    """
    Download the specified archives of the PureForest dataset from the
    Hugging Face Hub and extract them into dataset_folder.
    """
    for filename in species_folders:
        print(f"\nProcessing {species_folders[filename]}...")

        # Download file
        file_path = hf_hub_download(
            repo_id="IGNF/PureForest",
            filename=species_folders[filename],
            repo_type="dataset",
        )

        extract_dir = dataset_folder / filename
        extract_dir.mkdir(exist_ok=True, parents=True)

        try:
            extract_files(file_path, extract_dir, main_subfolders)
        except zipfile.BadZipFile:
            print(f"Error: {filename} is not a valid zip file")
def load_dataset(main_dir: Path, species_folders: Dict, splits: Optional[List[str]] = None):
    if splits is None:
        splits = ["train", "val", "test"]

    # NOTE: downstream code relies on the "paths" key; keep it.
    dataset: Dict = {split: {"labels": [], "paths": []} for split in splits}
    merged_labels = {
        "Quercus_petraea": "Deciduous_oak",
        "Quercus_pubescens": "Deciduous_oak",
        "Quercus_robur": "Deciduous_oak",
        "Quercus_rubra": "Deciduous_oak",
        "Quercus_ilex": "Evergreen_oak",
        "Fagus_sylvatica": "Beech",
        "Castanea_sativa": "Chestnut",
        "Robinia_pseudoacacia": "Black_locust",
        "Pinus_pinaster": "Maritime_pine",
        "Pinus_sylvestris": "Scotch_pine",
        "Pinus_nigra_laricio": "Black_pine",
        "Pinus_nigra": "Black_pine",
        "Pinus_halepensis": "Aleppo_pine",
        "Abies_alba": "Fir",
        "Abies_nordmanniana": "Fir",
        "Picea_abies": "Spruce",
        "Larix_decidua": "Larch",
        "Pseudotsuga_menziesii": "Douglas",
    }
    # Filter merged_labels down to the classes present in config.yaml
    available_labels = {key: merged_labels[key] for key in species_folders if key in merged_labels}

    unique_labels = sorted(set(available_labels.values()))
    label_map = {label: idx for idx, label in enumerate(unique_labels)}
    print("Label mapping:", label_map)

    # Derive the species folder names from the archive names, e.g.
    # "data/imagery-Quercus_robur.zip" -> "Quercus_robur"
    base_dirs = [
        species_folders[filename].replace("data/imagery-", "").replace(".zip", "")
        for filename in species_folders
    ]
    # Load images and create labels
    for base_dir in base_dirs:
        original_label = base_dir
        merged_label = available_labels.get(original_label)
        if merged_label is None:
            continue

        label = label_map[merged_label]

        for split in splits:
            split_dir = main_dir / base_dir / split
            if not split_dir.exists():
                print(f"Warning: {split_dir} does not exist")
                continue

            # Get all TIFF files in the directory
            tiff_files = list(split_dir.glob("*.tiff")) + list(
                split_dir.glob("*.tif")
            )

            print(f"Loading {len(tiff_files)} images from {split_dir}")

            for tiff_path in tiff_files:
                dataset[split]["labels"].append(label)
                dataset[split]["paths"].append(tiff_path)
    # Convert label lists to numpy arrays; clip_balanced_dataset indexes
    # them with arrays of indices, which plain Python lists do not support
    for split in splits:
        dataset[split]["labels"] = np.array(dataset[split]["labels"])

    return dataset, label_map
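# Illustrative shape of the values returned by load_dataset (the paths,
# labels, and mapping below are hypothetical):
#   dataset["train"]["labels"] -> array([0, 0, 1, ...])
#   dataset["train"]["paths"]  -> [PosixPath("datasets/Quercus_robur/train/tile_0001.tif"), ...]
#   label_map                  -> {"Beech": 0, "Deciduous_oak": 1, ...}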
def clip_balanced_dataset(dataset: Dict):
    clipped_dataset = {}
    for split in dataset.keys():
        if len(dataset[split]["paths"]) == 0:
            continue

        # Identify the minimum class count for this split
        unique_labels, label_counts = np.unique(
            dataset[split]["labels"], return_counts=True
        )
        min_class_count = min(label_counts)

        # Prepare clipped data
        labels_clipped = []
        paths_clipped = []

        for label in unique_labels:
            # Find indices of images with the current label
            indices = np.where(dataset[split]["labels"] == label)[0]

            # Randomly select min_class_count indices from these
            selected_indices = np.random.choice(indices, min_class_count, replace=False)

            # Append selected samples to the clipped data lists
            labels_clipped.extend(dataset[split]["labels"][selected_indices])
            paths_clipped.extend([dataset[split]["paths"][i] for i in selected_indices])

        # Convert labels to a numpy array; paths stay a list
        clipped_dataset[split] = {
            "labels": np.array(labels_clipped),
            "paths": paths_clipped,
        }

    return clipped_dataset
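if __name__ == "__main__":
    # End-to-end sketch tying the functions above together. The species and
    # archive names below are assumptions for illustration; real runs should
    # take them from config.yaml.
    species_folders = {
        "Quercus_robur": "data/imagery-Quercus_robur.zip",
        "Fagus_sylvatica": "data/imagery-Fagus_sylvatica.zip",
    }
    main_subfolders = {"aerial_imagery": "data/imagery"}
    dataset_folder = Path("datasets")

    download_data(species_folders, main_subfolders, dataset_folder)
    dataset, label_map = load_dataset(dataset_folder, species_folders)
    balanced = clip_balanced_dataset(dataset)
    for split, data in balanced.items():
        print(f"{split}: {len(data['paths'])} samples after balancing")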