From d8d2e5a2c2d0e9291ea7609108100afa130f67c3 Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Mon, 8 Sep 2025 18:11:58 +0100
Subject: [PATCH] Remove preprocessing modules

---
 src/batdetect2/cli/preprocess.py   | 142 -----------------
 src/batdetect2/data/iterators.py   |  36 +----
 src/batdetect2/train/__init__.py   |   6 -
 src/batdetect2/train/preprocess.py | 243 -----------------------------
 tests/test_targets/test_terms.py   |  16 ++
 5 files changed, 20 insertions(+), 423 deletions(-)
 delete mode 100644 src/batdetect2/cli/preprocess.py
 delete mode 100644 src/batdetect2/train/preprocess.py
 create mode 100644 tests/test_targets/test_terms.py

diff --git a/src/batdetect2/cli/preprocess.py b/src/batdetect2/cli/preprocess.py
deleted file mode 100644
index b539d5d..0000000
--- a/src/batdetect2/cli/preprocess.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import sys
-from pathlib import Path
-from typing import Optional
-
-import click
-import yaml
-from loguru import logger
-
-from batdetect2.cli.base import cli
-from batdetect2.data import load_dataset_from_config
-from batdetect2.train.preprocess import (
-    TrainPreprocessConfig,
-    load_train_preprocessing_config,
-    preprocess_dataset,
-)
-
-__all__ = ["preprocess"]
-
-
-@cli.command()
-@click.argument(
-    "dataset_config",
-    type=click.Path(exists=True),
-)
-@click.argument(
-    "output",
-    type=click.Path(),
-)
-@click.option(
-    "--dataset-field",
-    type=str,
-    help=(
-        "Specifies the key to access the dataset information within the "
-        "dataset configuration file, if the information is nested inside a "
-        "dictionary. If the dataset information is at the top level of the "
-        "config file, you don't need to specify this."
-    ),
-)
-@click.option(
-    "--base-dir",
-    type=click.Path(exists=True),
-    help=(
-        "The main directory where your audio recordings and annotation "
-        "files are stored. This helps the program find your data, "
-        "especially if the paths in your dataset configuration file "
-        "are relative."
-    ),
-)
-@click.option(
-    "--config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the configuration file. This file tells "
-        "the program how to prepare your audio data before training, such "
-        "as resampling or applying filters."
-    ),
-)
-@click.option(
-    "--config-field",
-    type=str,
-    help=(
-        "If the preprocessing settings are inside a nested dictionary "
-        "within the preprocessing configuration file, specify the key "
-        "here to access them. If the preprocessing settings are at the "
-        "top level, you don't need to specify this."
-    ),
-)
-@click.option(
-    "--num-workers",
-    type=int,
-    help=(
-        "The maximum number of computer cores to use when processing "
-        "your audio data. Using more cores can speed up the preprocessing, "
-        "but don't use more than your computer has available. By default, "
-        "the program will use all available cores."
-    ),
-)
-@click.option(
-    "-v",
-    "--verbose",
-    count=True,
-    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
-)
-v for INFO, -vv for DEBUG.", -) -def preprocess( - dataset_config: Path, - output: Path, - base_dir: Optional[Path] = None, - config: Optional[Path] = None, - config_field: Optional[str] = None, - num_workers: Optional[int] = None, - dataset_field: Optional[str] = None, - verbose: int = 0, -): - logger.remove() - if verbose == 0: - log_level = "WARNING" - elif verbose == 1: - log_level = "INFO" - else: - log_level = "DEBUG" - logger.add(sys.stderr, level=log_level) - - logger.info("Starting preprocessing.") - - output = Path(output) - logger.info("Will save outputs to {output}", output=output) - - base_dir = base_dir or Path.cwd() - logger.debug("Current working directory: {base_dir}", base_dir=base_dir) - - if config: - logger.info( - "Loading preprocessing config from: {config}", config=config - ) - - conf = ( - load_train_preprocessing_config(config, field=config_field) - if config is not None - else TrainPreprocessConfig() - ) - logger.debug( - "Preprocessing config:\n{conf}", - conf=yaml.dump(conf.model_dump()), - ) - - dataset = load_dataset_from_config( - dataset_config, - field=dataset_field, - base_dir=base_dir, - ) - - logger.info( - "Loaded {num_examples} annotated clips from the configured dataset", - num_examples=len(dataset), - ) - - preprocess_dataset( - dataset, - conf, - output=output, - max_workers=num_workers, - ) diff --git a/src/batdetect2/data/iterators.py b/src/batdetect2/data/iterators.py index 08c4411..f3f3ff7 100644 --- a/src/batdetect2/data/iterators.py +++ b/src/batdetect2/data/iterators.py @@ -10,16 +10,8 @@ from batdetect2.typing.targets import TargetProtocol def iterate_over_sound_events( dataset: Dataset, targets: TargetProtocol, - apply_filter: bool = True, - apply_transform: bool = True, - exclude_generic: bool = True, ) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]: - """Iterate over sound events in a dataset, applying filtering and - transformations. - - This generator function processes sound event annotations from a given - dataset, allowing for optional filtering, transformation, and exclusion of - unclassifiable (generic) events based on the provided target definitions. + """Iterate over sound events in a dataset. Parameters ---------- @@ -29,18 +21,6 @@ def iterate_over_sound_events( targets : TargetProtocol An object implementing the `TargetProtocol`, which provides methods for filtering, transforming, and encoding sound events. - apply_filter : bool, optional - If True, sound events will be filtered using `targets.filter()`. - Only events for which `targets.filter()` returns True will be yielded. - Defaults to True. - apply_transform : bool, optional - If True, sound events will be transformed using `targets.transform()` - before being yielded. Defaults to True. - exclude_generic : bool, optional - If True, sound events that result in a `None` class name after - `targets.encode()` will be excluded. This is typically used to - filter out events that cannot be mapped to a specific target class. - Defaults to True. 
 
     Yields
     ------
@@ -63,17 +43,9 @@
     """
     for clip_annotation in dataset:
         for sound_event_annotation in clip_annotation.sound_events:
-            if apply_filter:
-                if not targets.filter(sound_event_annotation):
-                    continue
-
-            if apply_transform:
-                sound_event_annotation = targets.transform(
-                    sound_event_annotation
-                )
-
-            class_name = targets.encode_class(sound_event_annotation)
-            if class_name is None and exclude_generic:
+            if not targets.filter(sound_event_annotation):
                 continue
+            class_name = targets.encode_class(sound_event_annotation)
+
             yield class_name, sound_event_annotation
diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py
index bf180e3..faff98b 100644
--- a/src/batdetect2/train/__init__.py
+++ b/src/batdetect2/train/__init__.py
@@ -33,10 +33,6 @@ from batdetect2.train.losses import (
     SizeLossConfig,
     build_loss,
 )
-from batdetect2.train.preprocess import (
-    generate_train_example,
-    preprocess_annotations,
-)
 from batdetect2.train.train import (
     build_train_dataset,
     build_train_loader,
@@ -74,14 +70,12 @@ __all__ = [
     "build_trainer",
     "build_val_dataset",
     "build_val_loader",
-    "generate_train_example",
     "load_full_training_config",
     "load_label_config",
     "load_train_config",
     "mask_frequency",
     "mask_time",
     "mix_audio",
-    "preprocess_annotations",
     "scale_volume",
     "select_subclip",
     "train",
diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py
deleted file mode 100644
index 9d6a660..0000000
--- a/src/batdetect2/train/preprocess.py
+++ /dev/null
@@ -1,243 +0,0 @@
-"""Preprocesses datasets for BatDetect2 model training."""
-
-import os
-from pathlib import Path
-from typing import Callable, List, Optional, Sequence, TypedDict
-
-import numpy as np
-import torch
-import torch.utils.data
-from loguru import logger
-from pydantic import Field
-from soundevent import data
-from tqdm import tqdm
-
-from batdetect2.configs import BaseConfig, load_config
-from batdetect2.data.datasets import Dataset
-from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.preprocess.audio import build_audio_loader
-from batdetect2.targets import TargetConfig, build_targets
-from batdetect2.train.labels import LabelConfig, build_clip_labeler
-from batdetect2.typing import ClipLabeller, PreprocessorProtocol
-from batdetect2.typing.preprocess import AudioLoader
-from batdetect2.typing.train import PreprocessedExample
-
-__all__ = [
-    "preprocess_annotations",
-    "generate_train_example",
-    "preprocess_dataset",
-    "TrainPreprocessConfig",
-    "load_train_preprocessing_config",
-    "save_preprocessed_example",
-    "load_preprocessed_example",
-]
-
-FilenameFn = Callable[[data.ClipAnnotation], str]
-"""Type alias for a function that generates an output filename."""
-
-
-class TrainPreprocessConfig(BaseConfig):
-    preprocess: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    targets: TargetConfig = Field(default_factory=TargetConfig)
-    labels: LabelConfig = Field(default_factory=LabelConfig)
-
-
-def load_train_preprocessing_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> TrainPreprocessConfig:
-    return load_config(path=path, schema=TrainPreprocessConfig, field=field)
-
-
-def preprocess_dataset(
-    dataset: Dataset,
-    config: TrainPreprocessConfig,
-    output: Path,
-    max_workers: Optional[int] = None,
-) -> None:
-    targets = build_targets(config=config.targets)
-    preprocessor = build_preprocessor(config=config.preprocess)
-    labeller = build_clip_labeler(
-        targets,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
-        config=config.labels,
-    )
-    audio_loader = build_audio_loader(config=config.preprocess.audio)
-
-    if not output.exists():
-        logger.debug("Creating directory {directory}", directory=output)
-        output.mkdir(parents=True)
-
-    preprocess_annotations(
-        dataset,
-        output_dir=output,
-        audio_loader=audio_loader,
-        preprocessor=preprocessor,
-        labeller=labeller,
-        max_workers=max_workers,
-    )
-
-
-class Example(TypedDict):
-    audio: torch.Tensor
-    spectrogram: torch.Tensor
-    detection_heatmap: torch.Tensor
-    class_heatmap: torch.Tensor
-    size_heatmap: torch.Tensor
-
-
-def generate_train_example(
-    clip_annotation: data.ClipAnnotation,
-    audio_loader: AudioLoader,
-    preprocessor: PreprocessorProtocol,
-    labeller: ClipLabeller,
-) -> PreprocessedExample:
-    """Generate a complete training example for one annotation."""
-    wave = torch.tensor(
-        audio_loader.load_clip(clip_annotation.clip)
-    ).unsqueeze(0)
-    spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
-    heatmaps = labeller(clip_annotation, spectrogram)
-    return PreprocessedExample(
-        audio=wave,
-        spectrogram=spectrogram,
-        detection_heatmap=heatmaps.detection,
-        class_heatmap=heatmaps.classes,
-        size_heatmap=heatmaps.size,
-    )
-
-
-class PreprocessingDataset(torch.utils.data.Dataset):
-    def __init__(
-        self,
-        clips: Dataset,
-        audio_loader: AudioLoader,
-        preprocessor: PreprocessorProtocol,
-        labeller: ClipLabeller,
-        filename_fn: FilenameFn,
-        output_dir: Path,
-        force: bool = False,
-    ):
-        self.clips = clips
-        self.audio_loader = audio_loader
-        self.preprocessor = preprocessor
-        self.labeller = labeller
-        self.filename_fn = filename_fn
-        self.output_dir = output_dir
-        self.force = force
-
-    def __getitem__(self, idx) -> int:
-        clip_annotation = self.clips[idx]
-
-        filename = self.filename_fn(clip_annotation)
-
-        path = self.output_dir / filename
-
-        if path.exists() and not self.force:
-            return idx
-
-        if not path.parent.exists():
-            path.parent.mkdir()
-
-        example = generate_train_example(
-            clip_annotation,
-            audio_loader=self.audio_loader,
-            preprocessor=self.preprocessor,
-            labeller=self.labeller,
-        )
-
-        save_preprocessed_example(example, clip_annotation, path)
-
-        return idx
-
-    def __len__(self) -> int:
-        return len(self.clips)
-
-
-def save_preprocessed_example(
-    example: PreprocessedExample,
-    clip_annotation: data.ClipAnnotation,
-    path: data.PathLike,
-) -> None:
-    np.savez_compressed(
-        path,
-        audio=example.audio.numpy(),
-        spectrogram=example.spectrogram.numpy(),
-        detection_heatmap=example.detection_heatmap.numpy(),
-        class_heatmap=example.class_heatmap.numpy(),
-        size_heatmap=example.size_heatmap.numpy(),
-        clip_annotation=clip_annotation,
-    )
-
-
-def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
-    item = np.load(path, mmap_mode="r+")
-    return PreprocessedExample(
-        audio=torch.tensor(item["audio"]),
-        spectrogram=torch.tensor(item["spectrogram"]),
-        size_heatmap=torch.tensor(item["size_heatmap"]),
-        detection_heatmap=torch.tensor(item["detection_heatmap"]),
-        class_heatmap=torch.tensor(item["class_heatmap"]),
-    )
-
-
-def list_preprocessed_files(
-    directory: data.PathLike, extension: str = ".npz"
-) -> List[Path]:
-    return list(Path(directory).glob(f"*{extension}"))
-
-
-def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
-    """Generate a default output filename based on the annotation UUID."""
-    return f"{clip_annotation.uuid}"
-
-
-def preprocess_annotations(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    output_dir: data.PathLike,
-    preprocessor: PreprocessorProtocol,
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    filename_fn: FilenameFn = _get_filename,
-    max_workers: Optional[int] = None,
-) -> None:
-    """Preprocess a sequence of ClipAnnotations and save results to disk."""
-    output_dir = Path(output_dir)
-
-    if not output_dir.is_dir():
-        logger.info(
-            "Creating output directory: {output_dir}", output_dir=output_dir
-        )
-        output_dir.mkdir(parents=True)
-
-    logger.info(
-        "Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
-        num_annotations=len(clip_annotations),
-        max_workers=max_workers or "all available",
-    )
-
-    if max_workers is None:
-        max_workers = os.cpu_count() or 0
-
-    dataset = PreprocessingDataset(
-        clips=list(clip_annotations),
-        audio_loader=audio_loader,
-        preprocessor=preprocessor,
-        labeller=labeller,
-        output_dir=Path(output_dir),
-        filename_fn=filename_fn,
-    )
-
-    loader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=1,
-        shuffle=False,
-        num_workers=max_workers,
-        prefetch_factor=16,
-    )
-
-    for _ in tqdm(loader, total=len(dataset)):
-        pass
diff --git a/tests/test_targets/test_terms.py b/tests/test_targets/test_terms.py
new file mode 100644
index 0000000..088a2e6
--- /dev/null
+++ b/tests/test_targets/test_terms.py
@@ -0,0 +1,16 @@
+import pytest
+
+from batdetect2.targets import terms
+
+
+def test_tag_info_and_get_tag_from_info():
+    tag_info = terms.TagInfo(value="Myotis myotis", key="event")
+    tag = terms.get_tag_from_info(tag_info)
+    assert tag.value == "Myotis myotis"
+    assert tag.term == terms.call_type
+
+
+def test_get_tag_from_info_key_not_found():
+    tag_info = terms.TagInfo(value="test", key="non_existent_key")
+    with pytest.raises(KeyError):
+        terms.get_tag_from_info(tag_info)
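---
Reviewer note: a minimal sketch (not part of the patch) of how the simplified
iterate_over_sound_events is called after this change. Filtering via
targets.filter is now always applied, the transform step is gone, and sound
events that cannot be encoded to a class are yielded with a None class name
instead of being skipped, so callers must handle None themselves. The
"dataset.yaml" path and the default TargetConfig are placeholders;
load_dataset_from_config and build_targets are the loaders imported by the
removed modules.

    from collections import Counter

    from batdetect2.data import load_dataset_from_config
    from batdetect2.data.iterators import iterate_over_sound_events
    from batdetect2.targets import TargetConfig, build_targets

    # Placeholder dataset config; point this at a real file.
    dataset = load_dataset_from_config("dataset.yaml")
    targets = build_targets(config=TargetConfig())

    # Tally encoded class names; unmappable (generic) sound events now
    # arrive with class_name == None instead of being filtered out.
    counts = Counter(
        class_name if class_name is not None else "<generic>"
        for class_name, _ in iterate_over_sound_events(dataset, targets)
    )
    print(counts.most_common())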