Coverage for src/dataset_functions.py: 0%

47 statements  

« prev     ^ index     » next       coverage.py v7.9.1, created at 2025-07-01 17:37 +0000

from pathlib import Path
from typing import Dict, List, Optional

import numpy as np

4 

5 

def load_dataset(main_dir: Path, species_folders: Dict[str, str], splits: Optional[List[str]] = None):
    """Index TIFF image paths and integer class labels for each dataset split.

    Walks ``main_dir/<species>/<split>`` for every species listed in
    ``species_folders``, merging fine-grained species names into coarser
    classes (e.g. several Quercus species -> "Deciduous_oak").

    Args:
        main_dir: Root directory containing one sub-directory per species.
        species_folders: Maps species name -> archive path of the form
            "data/imagery-<species>.zip" (as listed in config.yaml).
        splits: Split names to load; defaults to ["train", "val", "test"].

    Returns:
        Tuple ``(dataset, label_map)`` where ``dataset[split]`` holds
        parallel ``"labels"`` (np.ndarray of int) and ``"paths"`` (list of
        Path) entries, and ``label_map`` maps merged class name -> int id.
    """
    if splits is None:
        splits = ["train", "val", "test"]
    # PLEASE KEEP "paths" KEY!!!!!
    dataset: Dict = {split: {"labels": [], "paths": []} for split in splits}

    # Fine-grained species name -> merged class name.
    merged_labels = {
        "Quercus_petraea": "Deciduous_oak",
        "Quercus_pubescens": "Deciduous_oak",
        "Quercus_robur": "Deciduous_oak",
        "Quercus_rubra": "Deciduous_oak",
        "Quercus_ilex": "Evergreen_oak",
        "Fagus_sylvatica": "Beech",
        "Castanea_sativa": "Chestnut",
        "Robinia_pseudoacacia": "Black_locust",
        "Pinus_pinaster": "Maritime_pine",
        "Pinus_sylvestris": "Scotch_pine",
        "Pinus_nigra_laricio": "Black_pine",
        "Pinus_nigra": "Black_pine",
        "Pinus_halepensis": "Aleppo pine",
        "Abies_alba": "Fir",
        "Abies_nordmanniana": "Fir",
        "Picea_abies": "Spruce",
        "Larix_decidua": "Larch",
        "Pseudotsuga_menziesii": "Douglas",
    }

    # Filtering merged_labels to present classes in config.yaml
    available_labels = {key: merged_labels[key] for key in species_folders if key in merged_labels}

    unique_labels = sorted(set(available_labels.values()))
    label_map = {label: idx for idx, label in enumerate(unique_labels)}
    print("Label mapping:", label_map)

    # Recover the species folder name from the archive path, e.g.
    # "data/imagery-Fagus_sylvatica.zip" -> "Fagus_sylvatica".
    base_dirs = [path.replace("data/imagery-", "").replace(".zip", "") for path in species_folders.values()]

    # Load images and create labels
    for base_dir in base_dirs:
        merged_label = available_labels.get(base_dir)
        if merged_label is None:
            continue  # species absent from the merge table; skip it

        label = label_map[merged_label]

        for split in splits:
            split_dir = main_dir / base_dir / split
            if not split_dir.exists():
                print(f"Warning: {split_dir} does not exist")
                continue

            # Get all TIFF files in the directory (both extensions occur).
            tiff_files = list(split_dir.glob("*.tiff")) + list(split_dir.glob("*.tif"))

            print(f"Loading {len(tiff_files)} images from {split_dir}")

            for tiff_path in tiff_files:
                dataset[split]["labels"].append(label)
                dataset[split]["paths"].append(tiff_path)

    # Convert label lists to numpy arrays. (The original wrapped this in
    # list(...), defeating the conversion and breaking the ndarray fancy
    # indexing that clip_balanced_dataset performs on "labels".)
    for split in splits:
        dataset[split]["labels"] = np.array(dataset[split]["labels"])

    return dataset, label_map

71 

72 

def clip_balanced_dataset(dataset: Dict) -> Dict:
    """Undersample every class to the size of the rarest class, per split.

    Args:
        dataset: Mapping split -> {"labels": array-like of int,
            "paths": list}, as produced by ``load_dataset``.

    Returns:
        New mapping with the same structure; within each split every class
        contributes exactly min-class-count randomly chosen samples
        (without replacement). Splits with no paths are dropped. Sampling
        uses np.random, so results depend on the global random state.
    """
    clipped_dataset: Dict = {}
    for split, data in dataset.items():
        if len(data["paths"]) == 0:
            continue  # nothing to balance in this split

        # Normalize once: fancy indexing below requires an ndarray, and
        # "labels" may arrive as a plain Python list.
        labels = np.asarray(data["labels"])

        # Identify minimum class count for this split
        unique_labels, label_counts = np.unique(labels, return_counts=True)
        min_class_count = int(label_counts.min())

        # Prepare clipped data
        labels_clipped = []
        paths_clipped = []

        for label in unique_labels:
            # Find indices of images with the current label
            indices = np.where(labels == label)[0]

            # Randomly select min_class_count indices from these
            selected_indices = np.random.choice(indices, min_class_count, replace=False)

            # Append selected samples to clipped data lists
            labels_clipped.extend(labels[selected_indices])
            paths_clipped.extend([data["paths"][i] for i in selected_indices])

        # Convert to numpy arrays
        clipped_dataset[split] = {
            "labels": np.array(labels_clipped),
            "paths": paths_clipped,
        }

    return clipped_dataset