Updating to new Augmentation object

This commit is contained in:
mbsantiago 2025-04-22 09:01:46 +01:00
parent 541be15c9e
commit 285c6a3347
5 changed files with 39 additions and 41 deletions

View File

@ -5,16 +5,10 @@ import click
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import ( from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
load_preprocessing_config, from batdetect2.targets import build_targets, load_target_config
) from batdetect2.train import load_label_config, preprocess_annotations
from batdetect2.targets import ( from batdetect2.train.labels import build_clip_labeler
load_label_config,
load_target_config,
)
from batdetect2.train import (
preprocess_annotations,
)
__all__ = ["train"] __all__ = ["train"]
@ -129,9 +123,9 @@ def train(): ...
def preprocess( def preprocess(
dataset_config: Path, dataset_config: Path,
output: Path, output: Path,
target_config: Path,
base_dir: Optional[Path] = None, base_dir: Optional[Path] = None,
preprocess_config: Optional[Path] = None, preprocess_config: Optional[Path] = None,
target_config: Optional[Path] = None,
label_config: Optional[Path] = None, label_config: Optional[Path] = None,
force: bool = False, force: bool = False,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
@ -152,13 +146,9 @@ def preprocess(
else None else None
) )
target = ( target = load_target_config(
load_target_config( target_config,
target_config, field=target_config_field,
field=target_config_field,
)
if target_config
else None
) )
label = ( label = (
@ -176,15 +166,18 @@ def preprocess(
base_dir=base_dir, base_dir=base_dir,
) )
targets = build_targets(config=target)
preprocessor = build_preprocessor(config=preprocess)
labeller = build_clip_labeler(targets, config=label)
if not output.exists(): if not output.exists():
output.mkdir(parents=True) output.mkdir(parents=True)
preprocess_annotations( preprocess_annotations(
dataset.clip_annotations, dataset,
output_dir=output, output_dir=output,
preprocessor=preprocessor,
labeller=labeller,
replace=force, replace=force,
preprocessing_config=preprocess,
label_config=label,
target_config=target,
max_workers=num_workers, max_workers=num_workers,
) )

View File

@ -63,7 +63,7 @@ def extract_values_at_positions(
], ],
Dimensions.time.value: positions.coords[Dimensions.time.value], Dimensions.time.value: positions.coords[Dimensions.time.value],
} }
) ).T
def extract_detection_xr_dataset( def extract_detection_xr_dataset(
@ -109,9 +109,9 @@ def extract_detection_xr_dataset(
DataArrays share the 'detection' dimension and associated DataArrays share the 'detection' dimension and associated
time/frequency coordinates. time/frequency coordinates.
""" """
sizes = extract_values_at_positions(sizes, positions).T sizes = extract_values_at_positions(sizes, positions)
classes = extract_values_at_positions(classes, positions).T classes = extract_values_at_positions(classes, positions)
features = extract_values_at_positions(features, positions).T features = extract_values_at_positions(features, positions)
return xr.Dataset( return xr.Dataset(
{ {
"scores": positions, "scores": positions,

View File

@ -1,8 +1,12 @@
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
AugmentationsConfig, AugmentationsConfig,
EchoAugmentationConfig,
FrequencyMaskAugmentationConfig,
TimeMaskAugmentationConfig,
VolumeAugmentationConfig,
WarpAugmentationConfig,
add_echo, add_echo,
augment_example, build_augmentations,
load_agumentation_config,
mask_frequency, mask_frequency,
mask_time, mask_time,
mix_examples, mix_examples,
@ -17,6 +21,7 @@ from batdetect2.train.dataset import (
TrainExample, TrainExample,
get_preprocessed_files, get_preprocessed_files,
) )
from batdetect2.train.labels import load_label_config
from batdetect2.train.losses import compute_loss from batdetect2.train.losses import compute_loss
from batdetect2.train.preprocess import ( from batdetect2.train.preprocess import (
generate_train_example, generate_train_example,
@ -26,17 +31,21 @@ from batdetect2.train.train import TrainerConfig, load_trainer_config, train
__all__ = [ __all__ = [
"AugmentationsConfig", "AugmentationsConfig",
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"LabeledDataset", "LabeledDataset",
"SubclipConfig", "SubclipConfig",
"TimeMaskAugmentationConfig",
"TrainExample", "TrainExample",
"TrainerConfig", "TrainerConfig",
"TrainingConfig", "TrainingConfig",
"VolumeAugmentationConfig",
"WarpAugmentationConfig",
"add_echo", "add_echo",
"augment_example", "build_augmentations",
"compute_loss", "compute_loss",
"generate_train_example", "generate_train_example",
"get_preprocessed_files", "get_preprocessed_files",
"load_agumentation_config",
"load_train_config", "load_train_config",
"load_trainer_config", "load_trainer_config",
"mask_frequency", "mask_frequency",
@ -47,4 +56,5 @@ __all__ = [
"select_subclip", "select_subclip",
"train", "train",
"warp_spectrogram", "warp_spectrogram",
"load_label_config",
] ]

View File

@ -11,8 +11,8 @@ from torch.utils.data import Dataset
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
Augmentation,
AugmentationsConfig, AugmentationsConfig,
augment_example,
select_subclip, select_subclip,
) )
from batdetect2.train.preprocess import PreprocessorProtocol from batdetect2.train.preprocess import PreprocessorProtocol
@ -54,7 +54,7 @@ class LabeledDataset(Dataset):
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
filenames: Sequence[PathLike], filenames: Sequence[PathLike],
subclip: Optional[SubclipConfig] = None, subclip: Optional[SubclipConfig] = None,
augmentation: Optional[AugmentationsConfig] = None, augmentation: Optional[Augmentation] = None,
): ):
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.filenames = filenames self.filenames = filenames
@ -76,12 +76,7 @@ class LabeledDataset(Dataset):
) )
if self.augmentation: if self.augmentation:
dataset = augment_example( dataset = self.augmentation(dataset)
dataset,
self.augmentation,
preprocessor=self.preprocessor,
others=self.get_random_example,
)
return TrainExample( return TrainExample(
spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0), spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
@ -98,7 +93,7 @@ class LabeledDataset(Dataset):
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
extension: str = ".nc", extension: str = ".nc",
subclip: Optional[SubclipConfig] = None, subclip: Optional[SubclipConfig] = None,
augmentation: Optional[AugmentationsConfig] = None, augmentation: Optional[Augmentation] = None,
): ):
return cls( return cls(
preprocessor=preprocessor, preprocessor=preprocessor,

View File

@ -71,7 +71,7 @@ class LabelConfig(BaseConfig):
def build_clip_labeler( def build_clip_labeler(
targets: TargetProtocol, targets: TargetProtocol,
config: LabelConfig, config: Optional[LabelConfig] = None,
) -> ClipLabeller: ) -> ClipLabeller:
"""Construct the final clip labelling function. """Construct the final clip labelling function.
@ -98,7 +98,7 @@ def build_clip_labeler(
return partial( return partial(
generate_clip_label, generate_clip_label,
targets=targets, targets=targets,
config=config, config=config or LabelConfig(),
) )