From 285c6a334763800b3ea7bd43c789fac41e5e17d0 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 22 Apr 2025 09:01:46 +0100 Subject: [PATCH] Updating to new Augmentation object --- batdetect2/cli/train.py | 37 +++++++++++----------------- batdetect2/postprocess/extraction.py | 8 +++--- batdetect2/train/__init__.py | 18 +++++++++++--- batdetect2/train/dataset.py | 13 +++------- batdetect2/train/labels.py | 4 +-- 5 files changed, 39 insertions(+), 41 deletions(-) diff --git a/batdetect2/cli/train.py b/batdetect2/cli/train.py index 60fbc45..a55ea0d 100644 --- a/batdetect2/cli/train.py +++ b/batdetect2/cli/train.py @@ -5,16 +5,10 @@ import click from batdetect2.cli.base import cli from batdetect2.data import load_dataset_from_config -from batdetect2.preprocess import ( - load_preprocessing_config, -) -from batdetect2.targets import ( - load_label_config, - load_target_config, -) -from batdetect2.train import ( - preprocess_annotations, -) +from batdetect2.preprocess import build_preprocessor, load_preprocessing_config +from batdetect2.targets import build_targets, load_target_config +from batdetect2.train import load_label_config, preprocess_annotations +from batdetect2.train.labels import build_clip_labeler __all__ = ["train"] @@ -129,9 +123,9 @@ def train(): ... def preprocess( dataset_config: Path, output: Path, + target_config: Path, base_dir: Optional[Path] = None, preprocess_config: Optional[Path] = None, - target_config: Optional[Path] = None, label_config: Optional[Path] = None, force: bool = False, num_workers: Optional[int] = None, @@ -152,13 +146,9 @@ def preprocess( else None ) - target = ( - load_target_config( - target_config, - field=target_config_field, - ) - if target_config - else None + target = load_target_config( + target_config, + field=target_config_field, ) label = ( @@ -176,15 +166,18 @@ def preprocess( base_dir=base_dir, ) + targets = build_targets(config=target) + preprocessor = build_preprocessor(config=preprocess) + labeller = build_clip_labeler(targets, config=label) + if not output.exists(): output.mkdir(parents=True) preprocess_annotations( - dataset.clip_annotations, + dataset, output_dir=output, + preprocessor=preprocessor, + labeller=labeller, replace=force, - preprocessing_config=preprocess, - label_config=label, - target_config=target, max_workers=num_workers, ) diff --git a/batdetect2/postprocess/extraction.py b/batdetect2/postprocess/extraction.py index 2592e59..84019a2 100644 --- a/batdetect2/postprocess/extraction.py +++ b/batdetect2/postprocess/extraction.py @@ -63,7 +63,7 @@ def extract_values_at_positions( ], Dimensions.time.value: positions.coords[Dimensions.time.value], } - ) + ).T def extract_detection_xr_dataset( @@ -109,9 +109,9 @@ def extract_detection_xr_dataset( DataArrays share the 'detection' dimension and associated time/frequency coordinates. """ - sizes = extract_values_at_positions(sizes, positions).T - classes = extract_values_at_positions(classes, positions).T - features = extract_values_at_positions(features, positions).T + sizes = extract_values_at_positions(sizes, positions) + classes = extract_values_at_positions(classes, positions) + features = extract_values_at_positions(features, positions) return xr.Dataset( { "scores": positions, diff --git a/batdetect2/train/__init__.py b/batdetect2/train/__init__.py index ab6c627..6caba28 100644 --- a/batdetect2/train/__init__.py +++ b/batdetect2/train/__init__.py @@ -1,8 +1,12 @@ from batdetect2.train.augmentations import ( AugmentationsConfig, + EchoAugmentationConfig, + FrequencyMaskAugmentationConfig, + TimeMaskAugmentationConfig, + VolumeAugmentationConfig, + WarpAugmentationConfig, add_echo, - augment_example, - load_agumentation_config, + build_augmentations, mask_frequency, mask_time, mix_examples, @@ -17,6 +21,7 @@ from batdetect2.train.dataset import ( TrainExample, get_preprocessed_files, ) +from batdetect2.train.labels import load_label_config from batdetect2.train.losses import compute_loss from batdetect2.train.preprocess import ( generate_train_example, @@ -26,17 +31,21 @@ from batdetect2.train.train import TrainerConfig, load_trainer_config, train __all__ = [ "AugmentationsConfig", + "EchoAugmentationConfig", + "FrequencyMaskAugmentationConfig", "LabeledDataset", "SubclipConfig", + "TimeMaskAugmentationConfig", "TrainExample", "TrainerConfig", "TrainingConfig", + "VolumeAugmentationConfig", + "WarpAugmentationConfig", "add_echo", - "augment_example", + "build_augmentations", "compute_loss", "generate_train_example", "get_preprocessed_files", - "load_agumentation_config", "load_train_config", "load_trainer_config", "mask_frequency", @@ -47,4 +56,5 @@ __all__ = [ "select_subclip", "train", "warp_spectrogram", + "load_label_config", ] diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index c1be352..cc839dd 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -11,8 +11,8 @@ from torch.utils.data import Dataset from batdetect2.configs import BaseConfig from batdetect2.train.augmentations import ( + Augmentation, AugmentationsConfig, - augment_example, select_subclip, ) from batdetect2.train.preprocess import PreprocessorProtocol @@ -54,7 +54,7 @@ class LabeledDataset(Dataset): preprocessor: PreprocessorProtocol, filenames: Sequence[PathLike], subclip: Optional[SubclipConfig] = None, - augmentation: Optional[AugmentationsConfig] = None, + augmentation: Optional[Augmentation] = None, ): self.preprocessor = preprocessor self.filenames = filenames @@ -76,12 +76,7 @@ class LabeledDataset(Dataset): ) if self.augmentation: - dataset = augment_example( - dataset, - self.augmentation, - preprocessor=self.preprocessor, - others=self.get_random_example, - ) + dataset = self.augmentation(dataset) return TrainExample( spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0), @@ -98,7 +93,7 @@ class LabeledDataset(Dataset): preprocessor: PreprocessorProtocol, extension: str = ".nc", subclip: Optional[SubclipConfig] = None, - augmentation: Optional[AugmentationsConfig] = None, + augmentation: Optional[Augmentation] = None, ): return cls( preprocessor=preprocessor, diff --git a/batdetect2/train/labels.py b/batdetect2/train/labels.py index 4f4454c..11955e4 100644 --- a/batdetect2/train/labels.py +++ b/batdetect2/train/labels.py @@ -71,7 +71,7 @@ class LabelConfig(BaseConfig): def build_clip_labeler( targets: TargetProtocol, - config: LabelConfig, + config: Optional[LabelConfig] = None, ) -> ClipLabeller: """Construct the final clip labelling function. @@ -98,7 +98,7 @@ def build_clip_labeler( return partial( generate_clip_label, targets=targets, - config=config, + config=config or LabelConfig(), )