Remove preprocessing modules

mbsantiago 2025-09-08 18:11:58 +01:00
parent b056d7d28d
commit d8d2e5a2c2
5 changed files with 20 additions and 423 deletions

View File

@@ -1,142 +0,0 @@
import sys
from pathlib import Path
from typing import Optional
import click
import yaml
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import (
TrainPreprocessConfig,
load_train_preprocessing_config,
preprocess_dataset,
)
__all__ = ["preprocess"]
@cli.command()
@click.argument(
"dataset_config",
type=click.Path(exists=True),
)
@click.argument(
"output",
type=click.Path(),
)
@click.option(
"--dataset-field",
type=str,
help=(
"Specifies the key to access the dataset information within the "
"dataset configuration file, if the information is nested inside a "
"dictionary. If the dataset information is at the top level of the "
"config file, you don't need to specify this."
),
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help=(
"The main directory where your audio recordings and annotation "
"files are stored. This helps the program find your data, "
"especially if the paths in your dataset configuration file "
"are relative."
),
)
@click.option(
"--config",
type=click.Path(exists=True),
help=(
"Path to the configuration file. This file tells "
"the program how to prepare your audio data before training, such "
"as resampling or applying filters."
),
)
@click.option(
"--config-field",
type=str,
help=(
"If the preprocessing settings are inside a nested dictionary "
"within the preprocessing configuration file, specify the key "
"here to access them. If the preprocessing settings are at the "
"top level, you don't need to specify this."
),
)
@click.option(
"--num-workers",
type=int,
help=(
"The maximum number of computer cores to use when processing "
"your audio data. Using more cores can speed up the preprocessing, "
"but don't use more than your computer has available. By default, "
"the program will use all available cores."
),
)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def preprocess(
dataset_config: Path,
output: Path,
base_dir: Optional[Path] = None,
config: Optional[Path] = None,
config_field: Optional[str] = None,
num_workers: Optional[int] = None,
dataset_field: Optional[str] = None,
verbose: int = 0,
):
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Starting preprocessing.")
output = Path(output)
logger.info("Will save outputs to {output}", output=output)
base_dir = base_dir or Path.cwd()
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
if config:
logger.info(
"Loading preprocessing config from: {config}", config=config
)
conf = (
load_train_preprocessing_config(config, field=config_field)
if config is not None
else TrainPreprocessConfig()
)
logger.debug(
"Preprocessing config:\n{conf}",
conf=yaml.dump(conf.model_dump()),
)
dataset = load_dataset_from_config(
dataset_config,
field=dataset_field,
base_dir=base_dir,
)
logger.info(
"Loaded {num_examples} annotated clips from the configured dataset",
num_examples=len(dataset),
)
preprocess_dataset(
dataset,
conf,
output=output,
max_workers=num_workers,
)
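
Note: the removed command was a thin wrapper over the training-preprocessing API deleted further down in this commit. A minimal equivalent script against the parent commit (b056d7d28d), with placeholder paths, would be:

from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import (
    load_train_preprocessing_config,
    preprocess_dataset,
)

# Placeholder paths; substitute your own dataset and preprocessing configs.
config = load_train_preprocessing_config("preprocess.yaml")
dataset = load_dataset_from_config("dataset.yaml", base_dir=Path.cwd())
preprocess_dataset(dataset, config, output=Path("preprocessed"), max_workers=4)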

View File

@@ -10,16 +10,8 @@ from batdetect2.typing.targets import TargetProtocol
 def iterate_over_sound_events(
     dataset: Dataset,
     targets: TargetProtocol,
-    apply_filter: bool = True,
-    apply_transform: bool = True,
-    exclude_generic: bool = True,
 ) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
-    """Iterate over sound events in a dataset, applying filtering and
-    transformations.
-
-    This generator function processes sound event annotations from a given
-    dataset, allowing for optional filtering, transformation, and exclusion of
-    unclassifiable (generic) events based on the provided target definitions.
+    """Iterate over sound events in a dataset.
 
     Parameters
     ----------
@@ -29,18 +21,6 @@ def iterate_over_sound_events(
     targets : TargetProtocol
         An object implementing the `TargetProtocol`, which provides methods
         for filtering, transforming, and encoding sound events.
-    apply_filter : bool, optional
-        If True, sound events will be filtered using `targets.filter()`.
-        Only events for which `targets.filter()` returns True will be yielded.
-        Defaults to True.
-    apply_transform : bool, optional
-        If True, sound events will be transformed using `targets.transform()`
-        before being yielded. Defaults to True.
-    exclude_generic : bool, optional
-        If True, sound events that result in a `None` class name after
-        `targets.encode()` will be excluded. This is typically used to
-        filter out events that cannot be mapped to a specific target class.
-        Defaults to True.
 
     Yields
     ------
@@ -63,17 +43,9 @@ def iterate_over_sound_events(
     """
     for clip_annotation in dataset:
         for sound_event_annotation in clip_annotation.sound_events:
-            if apply_filter:
-                if not targets.filter(sound_event_annotation):
-                    continue
-            if apply_transform:
-                sound_event_annotation = targets.transform(
-                    sound_event_annotation
-                )
-            class_name = targets.encode_class(sound_event_annotation)
-            if class_name is None and exclude_generic:
+            if not targets.filter(sound_event_annotation):
                 continue
+            class_name = targets.encode_class(sound_event_annotation)
             yield class_name, sound_event_annotation
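
Note: the simplified generator now always applies `targets.filter()`, no longer calls `targets.transform()`, and yields generic events with a `None` class name instead of dropping them. Callers that relied on `exclude_generic=True` must skip `None` themselves; a minimal sketch, assuming `dataset` and `targets` are built as elsewhere in the training code:

from collections import Counter

# Tally encoded class names, bucketing generic (unclassifiable) events.
counts = Counter(
    name if name is not None else "<generic>"
    for name, _ in iterate_over_sound_events(dataset, targets)
)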

View File

@@ -33,10 +33,6 @@ from batdetect2.train.losses import (
     SizeLossConfig,
     build_loss,
 )
-from batdetect2.train.preprocess import (
-    generate_train_example,
-    preprocess_annotations,
-)
 from batdetect2.train.train import (
     build_train_dataset,
     build_train_loader,
@@ -74,14 +70,12 @@ __all__ = [
     "build_trainer",
     "build_val_dataset",
     "build_val_loader",
-    "generate_train_example",
     "load_full_training_config",
     "load_label_config",
     "load_train_config",
     "mask_frequency",
     "mask_time",
     "mix_audio",
-    "preprocess_annotations",
     "scale_volume",
     "select_subclip",
     "train",

View File

@@ -1,243 +0,0 @@
"""Preprocesses datasets for BatDetect2 model training."""
import os
from pathlib import Path
from typing import Callable, List, Optional, Sequence, TypedDict
import numpy as np
import torch
import torch.utils.data
from loguru import logger
from pydantic import Field
from soundevent import data
from tqdm import tqdm
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.audio import build_audio_loader
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.train.labels import LabelConfig, build_clip_labeler
from batdetect2.typing import ClipLabeller, PreprocessorProtocol
from batdetect2.typing.preprocess import AudioLoader
from batdetect2.typing.train import PreprocessedExample
__all__ = [
"preprocess_annotations",
"generate_train_example",
"preprocess_dataset",
"TrainPreprocessConfig",
"load_train_preprocessing_config",
"save_preprocessed_example",
"load_preprocessed_example",
]
FilenameFn = Callable[[data.ClipAnnotation], str]
"""Type alias for a function that generates an output filename."""
class TrainPreprocessConfig(BaseConfig):
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
targets: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
def load_train_preprocessing_config(
path: data.PathLike,
field: Optional[str] = None,
) -> TrainPreprocessConfig:
return load_config(path=path, schema=TrainPreprocessConfig, field=field)
def preprocess_dataset(
dataset: Dataset,
config: TrainPreprocessConfig,
output: Path,
max_workers: Optional[int] = None,
) -> None:
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
labeller = build_clip_labeler(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
config=config.labels,
)
audio_loader = build_audio_loader(config=config.preprocess.audio)
if not output.exists():
logger.debug("Creating directory {directory}", directory=output)
output.mkdir(parents=True)
preprocess_annotations(
dataset,
output_dir=output,
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
max_workers=max_workers,
)
class Example(TypedDict):
audio: torch.Tensor
spectrogram: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
def generate_train_example(
clip_annotation: data.ClipAnnotation,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
) -> PreprocessedExample:
"""Generate a complete training example for one annotation."""
wave = torch.tensor(
audio_loader.load_clip(clip_annotation.clip)
).unsqueeze(0)
spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
heatmaps = labeller(clip_annotation, spectrogram)
return PreprocessedExample(
audio=wave,
spectrogram=spectrogram,
detection_heatmap=heatmaps.detection,
class_heatmap=heatmaps.classes,
size_heatmap=heatmaps.size,
)
class PreprocessingDataset(torch.utils.data.Dataset):
def __init__(
self,
clips: Dataset,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
labeller: ClipLabeller,
filename_fn: FilenameFn,
output_dir: Path,
force: bool = False,
):
self.clips = clips
self.audio_loader = audio_loader
self.preprocessor = preprocessor
self.labeller = labeller
self.filename_fn = filename_fn
self.output_dir = output_dir
self.force = force
def __getitem__(self, idx) -> int:
clip_annotation = self.clips[idx]
filename = self.filename_fn(clip_annotation)
path = self.output_dir / filename
if path.exists() and not self.force:
return idx
if not path.parent.exists():
path.parent.mkdir()
example = generate_train_example(
clip_annotation,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
labeller=self.labeller,
)
save_preprocessed_example(example, clip_annotation, path)
return idx
def __len__(self) -> int:
return len(self.clips)
def save_preprocessed_example(
example: PreprocessedExample,
clip_annotation: data.ClipAnnotation,
path: data.PathLike,
) -> None:
np.savez_compressed(
path,
audio=example.audio.numpy(),
spectrogram=example.spectrogram.numpy(),
detection_heatmap=example.detection_heatmap.numpy(),
class_heatmap=example.class_heatmap.numpy(),
size_heatmap=example.size_heatmap.numpy(),
clip_annotation=clip_annotation,
)
def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
item = np.load(path, mmap_mode="r+")
return PreprocessedExample(
audio=torch.tensor(item["audio"]),
spectrogram=torch.tensor(item["spectrogram"]),
size_heatmap=torch.tensor(item["size_heatmap"]),
detection_heatmap=torch.tensor(item["detection_heatmap"]),
class_heatmap=torch.tensor(item["class_heatmap"]),
)
def list_preprocessed_files(
directory: data.PathLike, extension: str = ".npz"
) -> List[Path]:
return list(Path(directory).glob(f"*{extension}"))
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
"""Generate a default output filename based on the annotation UUID."""
return f"{clip_annotation.uuid}"
def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation],
output_dir: data.PathLike,
preprocessor: PreprocessorProtocol,
audio_loader: AudioLoader,
labeller: ClipLabeller,
filename_fn: FilenameFn = _get_filename,
max_workers: Optional[int] = None,
) -> None:
"""Preprocess a sequence of ClipAnnotations and save results to disk."""
output_dir = Path(output_dir)
if not output_dir.is_dir():
logger.info(
"Creating output directory: {output_dir}", output_dir=output_dir
)
output_dir.mkdir(parents=True)
logger.info(
"Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
num_annotations=len(clip_annotations),
max_workers=max_workers or "all available",
)
if max_workers is None:
max_workers = os.cpu_count() or 0
dataset = PreprocessingDataset(
clips=list(clip_annotations),
audio_loader=audio_loader,
preprocessor=preprocessor,
labeller=labeller,
output_dir=Path(output_dir),
filename_fn=filename_fn,
)
loader = torch.utils.data.DataLoader(
dataset,
batch_size=1,
shuffle=False,
num_workers=max_workers,
prefetch_factor=16,
)
for _ in tqdm(loader, total=len(dataset)):
pass
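
Note: `.npz` files produced by the removed `save_preprocessed_example` remain readable with plain NumPy. A minimal inspection sketch, using the array keys defined above (the file path is a placeholder):

import numpy as np

# allow_pickle is only needed to read back the pickled clip_annotation
# entry; the plain array keys load without it.
item = np.load("preprocessed/example.npz", allow_pickle=True)
print(item["spectrogram"].shape, item["detection_heatmap"].shape)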

View File

@@ -0,0 +1,16 @@
import pytest

from batdetect2.targets import terms
from batdetect2.targets.terms import TagInfo


def test_tag_info_and_get_tag_from_info():
    tag_info = TagInfo(value="Myotis myotis", key="event")
    tag = terms.get_tag_from_info(tag_info)
    assert tag.value == "Myotis myotis"
    assert tag.term == terms.call_type


def test_get_tag_from_info_key_not_found():
    tag_info = TagInfo(value="test", key="non_existent_key")
    with pytest.raises(KeyError):
        terms.get_tag_from_info(tag_info)