Updating to new Augmentation object

mbsantiago 2025-04-22 09:01:46 +01:00
parent 541be15c9e
commit 285c6a3347
5 changed files with 39 additions and 41 deletions

View File

@@ -5,16 +5,10 @@ import click
 from batdetect2.cli.base import cli
 from batdetect2.data import load_dataset_from_config
-from batdetect2.preprocess import (
-    load_preprocessing_config,
-)
-from batdetect2.targets import (
-    load_label_config,
-    load_target_config,
-)
-from batdetect2.train import (
-    preprocess_annotations,
-)
+from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
+from batdetect2.targets import build_targets, load_target_config
+from batdetect2.train import load_label_config, preprocess_annotations
+from batdetect2.train.labels import build_clip_labeler
 
 __all__ = ["train"]
@@ -129,9 +123,9 @@ def train(): ...
 def preprocess(
     dataset_config: Path,
     output: Path,
+    target_config: Path,
     base_dir: Optional[Path] = None,
     preprocess_config: Optional[Path] = None,
-    target_config: Optional[Path] = None,
     label_config: Optional[Path] = None,
     force: bool = False,
     num_workers: Optional[int] = None,
@@ -152,13 +146,9 @@ def preprocess(
         else None
     )
-    target = (
-        load_target_config(
-            target_config,
-            field=target_config_field,
-        )
-        if target_config
-        else None
+    target = load_target_config(
+        target_config,
+        field=target_config_field,
     )
 
     label = (
@@ -176,15 +166,18 @@ def preprocess(
         base_dir=base_dir,
     )
 
+    targets = build_targets(config=target)
+    preprocessor = build_preprocessor(config=preprocess)
+    labeller = build_clip_labeler(targets, config=label)
+
     if not output.exists():
         output.mkdir(parents=True)
 
     preprocess_annotations(
-        dataset.clip_annotations,
+        dataset,
         output_dir=output,
+        preprocessor=preprocessor,
+        labeller=labeller,
         replace=force,
-        preprocessing_config=preprocess,
-        label_config=label,
-        target_config=target,
        max_workers=num_workers,
     )
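
For context, the CLI's preprocess step now resolves the loaded configs into concrete objects and passes those objects along, rather than forwarding the config models themselves. A minimal sketch of that flow, assuming the loaders accept a plain path (their full signatures are not shown in this diff) and using hypothetical config file names:

from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import load_label_config, preprocess_annotations
from batdetect2.train.labels import build_clip_labeler

# Hypothetical config paths, for illustration only.
target = load_target_config(Path("targets.yaml"))
preprocess = load_preprocessing_config(Path("preprocessing.yaml"))
label = load_label_config(Path("labels.yaml"))
dataset = load_dataset_from_config(Path("dataset.yaml"), base_dir=Path("."))

# Configs are resolved into objects once, up front.
targets = build_targets(config=target)
preprocessor = build_preprocessor(config=preprocess)
labeller = build_clip_labeler(targets, config=label)

output = Path("preprocessed")
output.mkdir(parents=True, exist_ok=True)

# The objects, not the configs, are what preprocess_annotations receives now.
preprocess_annotations(
    dataset,
    output_dir=output,
    preprocessor=preprocessor,
    labeller=labeller,
)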

View File

@@ -63,7 +63,7 @@ def extract_values_at_positions(
             ],
             Dimensions.time.value: positions.coords[Dimensions.time.value],
         }
-    )
+    ).T


 def extract_detection_xr_dataset(
@@ -109,9 +109,9 @@ def extract_detection_xr_dataset(
     DataArrays share the 'detection' dimension and associated
     time/frequency coordinates.
     """
-    sizes = extract_values_at_positions(sizes, positions).T
-    classes = extract_values_at_positions(classes, positions).T
-    features = extract_values_at_positions(features, positions).T
+    sizes = extract_values_at_positions(sizes, positions)
+    classes = extract_values_at_positions(classes, positions)
+    features = extract_values_at_positions(features, positions)
     return xr.Dataset(
         {
             "scores": positions,

View File

@@ -1,8 +1,12 @@
 from batdetect2.train.augmentations import (
     AugmentationsConfig,
+    EchoAugmentationConfig,
+    FrequencyMaskAugmentationConfig,
+    TimeMaskAugmentationConfig,
+    VolumeAugmentationConfig,
+    WarpAugmentationConfig,
     add_echo,
-    augment_example,
-    load_agumentation_config,
+    build_augmentations,
     mask_frequency,
     mask_time,
     mix_examples,
@@ -17,6 +21,7 @@ from batdetect2.train.dataset import (
     TrainExample,
     get_preprocessed_files,
 )
+from batdetect2.train.labels import load_label_config
 from batdetect2.train.losses import compute_loss
 from batdetect2.train.preprocess import (
     generate_train_example,
@@ -26,17 +31,21 @@ from batdetect2.train.train import TrainerConfig, load_trainer_config, train
 __all__ = [
     "AugmentationsConfig",
+    "EchoAugmentationConfig",
+    "FrequencyMaskAugmentationConfig",
     "LabeledDataset",
     "SubclipConfig",
+    "TimeMaskAugmentationConfig",
     "TrainExample",
     "TrainerConfig",
     "TrainingConfig",
+    "VolumeAugmentationConfig",
+    "WarpAugmentationConfig",
     "add_echo",
-    "augment_example",
+    "build_augmentations",
     "compute_loss",
     "generate_train_example",
     "get_preprocessed_files",
-    "load_agumentation_config",
     "load_train_config",
     "load_trainer_config",
     "mask_frequency",
@@ -47,4 +56,5 @@ __all__ = [
     "select_subclip",
     "train",
     "warp_spectrogram",
+    "load_label_config",
 ]
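
The new exports split augmentation settings into per-step config models plus a build_augmentations factory, replacing augment_example and load_agumentation_config. A rough sketch of how they are meant to combine; the config fields and the factory's exact signature are assumptions, not shown in this commit:

from batdetect2.train import AugmentationsConfig, build_augmentations

# Assumed: the config model constructs with defaults, and the factory turns a
# config into a single callable that transforms one training example.
config = AugmentationsConfig()
augmentation = build_augmentations(config)

# The result is used as a plain callable, as train/dataset.py now does:
#     example = augmentation(example)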

View File

@@ -11,8 +11,8 @@ from torch.utils.data import Dataset
 from batdetect2.configs import BaseConfig
 from batdetect2.train.augmentations import (
+    Augmentation,
     AugmentationsConfig,
-    augment_example,
     select_subclip,
 )
 from batdetect2.train.preprocess import PreprocessorProtocol
@@ -54,7 +54,7 @@ class LabeledDataset(Dataset):
         preprocessor: PreprocessorProtocol,
         filenames: Sequence[PathLike],
         subclip: Optional[SubclipConfig] = None,
-        augmentation: Optional[AugmentationsConfig] = None,
+        augmentation: Optional[Augmentation] = None,
     ):
         self.preprocessor = preprocessor
         self.filenames = filenames
@@ -76,12 +76,7 @@
         )
 
         if self.augmentation:
-            dataset = augment_example(
-                dataset,
-                self.augmentation,
-                preprocessor=self.preprocessor,
-                others=self.get_random_example,
-            )
+            dataset = self.augmentation(dataset)
 
         return TrainExample(
             spec=self.to_tensor(dataset["spectrogram"]).unsqueeze(0),
@@ -98,7 +93,7 @@ class LabeledDataset(Dataset):
         preprocessor: PreprocessorProtocol,
         extension: str = ".nc",
         subclip: Optional[SubclipConfig] = None,
-        augmentation: Optional[AugmentationsConfig] = None,
+        augmentation: Optional[Augmentation] = None,
     ):
         return cls(
             preprocessor=preprocessor,
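
LabeledDataset now accepts any Augmentation, i.e. a callable applied directly to the loaded example (dataset = self.augmentation(dataset)), instead of an AugmentationsConfig that augment_example had to interpret. A minimal custom augmentation compatible with that interface, with the scaling purely illustrative:

import xarray as xr

def scale_spectrogram(example: xr.Dataset) -> xr.Dataset:
    # "spectrogram" is the variable __getitem__ reads from the example.
    return example.assign(spectrogram=example["spectrogram"] * 0.5)

# It can be passed straight to the dataset, e.g.
#     LabeledDataset(preprocessor=..., filenames=..., augmentation=scale_spectrogram)
# or swapped for the pipeline produced by build_augmentations.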

View File

@@ -71,7 +71,7 @@ class LabelConfig(BaseConfig):
 def build_clip_labeler(
     targets: TargetProtocol,
-    config: LabelConfig,
+    config: Optional[LabelConfig] = None,
 ) -> ClipLabeller:
     """Construct the final clip labelling function.
@@ -98,7 +98,7 @@ def build_clip_labeler(
     return partial(
         generate_clip_label,
         targets=targets,
-        config=config,
+        config=config or LabelConfig(),
     )
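
With config now optional, build_clip_labeler falls back to a default LabelConfig() when none is given. A small sketch; targets stands for whatever TargetProtocol object the caller has already built (for example via build_targets in the CLI above):

from batdetect2.train.labels import LabelConfig, build_clip_labeler

def make_labellers(targets):
    # Omitting the config is now valid and equivalent to passing LabelConfig(),
    # because the function substitutes `config or LabelConfig()` internally.
    default_labeller = build_clip_labeler(targets)
    explicit_labeller = build_clip_labeler(targets, config=LabelConfig())
    return default_labeller, explicit_labeller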