Coverage for src/dataset_functions.py: 0%
47 statements
« prev ^ index » next coverage.py v7.9.1, created at 2025-07-01 17:37 +0000
1from pathlib import Path
2from typing import Dict, List, Optional
3
4import numpy as np
def load_dataset(main_dir: Path, species_folders: Dict[str, str], splits: Optional[List[str]] = None):
    """Collect image paths and integer class labels for each dataset split.

    Args:
        main_dir: Root directory containing one sub-directory per species,
            each holding ``train``/``val``/``test`` folders of TIFF images.
        species_folders: Maps species names to their archive paths, e.g.
            ``"data/imagery-<species>.zip"``; only the species part is used
            to locate the on-disk folder.
        splits: Split names to load; defaults to ``["train", "val", "test"]``.

    Returns:
        Tuple ``(dataset, label_map)`` where ``dataset[split]`` holds
        parallel ``"labels"`` and ``"paths"`` lists, and ``label_map`` maps
        each merged class name to its integer label.
    """
    if splits is None:
        splits = ["train", "val", "test"]
    # NOTE: downstream consumers rely on the "paths" key — do not rename it.
    dataset: Dict = {split: {"labels": [], "paths": []} for split in splits}

    # Species-to-class merging: several species collapse into one class.
    merged_labels = {
        "Quercus_petraea": "Deciduous_oak",
        "Quercus_pubescens": "Deciduous_oak",
        "Quercus_robur": "Deciduous_oak",
        "Quercus_rubra": "Deciduous_oak",
        "Quercus_ilex": "Evergreen_oak",
        "Fagus_sylvatica": "Beech",
        "Castanea_sativa": "Chestnut",
        "Robinia_pseudoacacia": "Black_locust",
        "Pinus_pinaster": "Maritime_pine",
        "Pinus_sylvestris": "Scotch_pine",
        "Pinus_nigra_laricio": "Black_pine",
        "Pinus_nigra": "Black_pine",
        # NOTE(review): "Aleppo pine" uses a space unlike the other class
        # names — confirm this is intentional before normalizing it.
        "Pinus_halepensis": "Aleppo pine",
        "Abies_alba": "Fir",
        "Abies_nordmanniana": "Fir",
        "Picea_abies": "Spruce",
        "Larix_decidua": "Larch",
        "Pseudotsuga_menziesii": "Douglas",
    }

    # Keep only the merged labels whose species appear in species_folders
    # (i.e. the classes enabled in config.yaml).
    available_labels = {key: merged_labels[key] for key in species_folders if key in merged_labels}

    # Deterministic class -> index mapping (sorted for reproducibility).
    unique_labels = sorted(set(available_labels.values()))
    label_map = {label: idx for idx, label in enumerate(unique_labels)}
    print("Label mapping:", label_map)

    # Derive the on-disk folder name from each archive path,
    # e.g. "data/imagery-Fagus_sylvatica.zip" -> "Fagus_sylvatica".
    base_dirs = [
        species_folders[filename].replace("data/imagery-", "").replace(".zip", "")
        for filename in species_folders
    ]

    # Walk each species folder and record every TIFF path with its label.
    for base_dir in base_dirs:
        merged_label = available_labels.get(base_dir)
        if merged_label is None:
            # Folder present but not part of any merged class: skip it.
            continue
        label = label_map[merged_label]

        for split in splits:
            split_dir = main_dir / base_dir / split
            if not split_dir.exists():
                print(f"Warning: {split_dir} does not exist")
                continue

            # Accept both common TIFF extensions.
            tiff_files = list(split_dir.glob("*.tiff")) + list(split_dir.glob("*.tif"))
            print(f"Loading {len(tiff_files)} images from {split_dir}")

            for tiff_path in tiff_files:
                dataset[split]["labels"].append(label)
                dataset[split]["paths"].append(tiff_path)

    # Labels stay plain Python ints. (The previous list(np.array(...))
    # round-trip only changed element types to np.int64 with no benefit.)
    return dataset, label_map
def clip_balanced_dataset(dataset: Dict) -> Dict:
    """Downsample every split so all classes have equal sample counts.

    For each split, the smallest per-class count is found and that many
    samples of every class are kept, selected at random without
    replacement. Splits with no samples are dropped from the result.
    Uses the global ``np.random`` state; seed it beforehand for
    reproducible selections.

    Args:
        dataset: Mapping ``split -> {"labels": ..., "paths": ...}`` as
            produced by ``load_dataset`` (labels may be a list or ndarray).

    Returns:
        New mapping with the same structure; ``"labels"`` are numpy arrays.
    """
    clipped_dataset: Dict = {}
    for split in dataset:
        paths = dataset[split]["paths"]
        if len(paths) == 0:
            continue

        # Work on an ndarray so the fancy indexing below is valid even when
        # the incoming labels are a plain Python list (the original code
        # raised TypeError when indexing a list with a numpy index array).
        labels = np.asarray(dataset[split]["labels"])

        # The rarest class dictates how many samples every class keeps.
        unique_labels, label_counts = np.unique(labels, return_counts=True)
        min_class_count = min(label_counts)

        labels_clipped = []
        paths_clipped = []
        for label in unique_labels:
            # Indices of all samples belonging to the current class.
            indices = np.where(labels == label)[0]
            # Random subsample without replacement down to the minimum count.
            selected_indices = np.random.choice(indices, min_class_count, replace=False)
            labels_clipped.extend(labels[selected_indices])
            paths_clipped.extend([paths[i] for i in selected_indices])

        clipped_dataset[split] = {
            "labels": np.array(labels_clipped),
            "paths": paths_clipped,
        }

    return clipped_dataset