mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)
Remove preprocessing modules

parent b056d7d28d
commit d8d2e5a2c2
@@ -1,142 +0,0 @@
-import sys
-from pathlib import Path
-from typing import Optional
-
-import click
-import yaml
-from loguru import logger
-
-from batdetect2.cli.base import cli
-from batdetect2.data import load_dataset_from_config
-from batdetect2.train.preprocess import (
-    TrainPreprocessConfig,
-    load_train_preprocessing_config,
-    preprocess_dataset,
-)
-
-__all__ = ["preprocess"]
-
-
-@cli.command()
-@click.argument(
-    "dataset_config",
-    type=click.Path(exists=True),
-)
-@click.argument(
-    "output",
-    type=click.Path(),
-)
-@click.option(
-    "--dataset-field",
-    type=str,
-    help=(
-        "Specifies the key to access the dataset information within the "
-        "dataset configuration file, if the information is nested inside a "
-        "dictionary. If the dataset information is at the top level of the "
-        "config file, you don't need to specify this."
-    ),
-)
-@click.option(
-    "--base-dir",
-    type=click.Path(exists=True),
-    help=(
-        "The main directory where your audio recordings and annotation "
-        "files are stored. This helps the program find your data, "
-        "especially if the paths in your dataset configuration file "
-        "are relative."
-    ),
-)
-@click.option(
-    "--config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the configuration file. This file tells "
-        "the program how to prepare your audio data before training, such "
-        "as resampling or applying filters."
-    ),
-)
-@click.option(
-    "--config-field",
-    type=str,
-    help=(
-        "If the preprocessing settings are inside a nested dictionary "
-        "within the preprocessing configuration file, specify the key "
-        "here to access them. If the preprocessing settings are at the "
-        "top level, you don't need to specify this."
-    ),
-)
-@click.option(
-    "--num-workers",
-    type=int,
-    help=(
-        "The maximum number of computer cores to use when processing "
-        "your audio data. Using more cores can speed up the preprocessing, "
-        "but don't use more than your computer has available. By default, "
-        "the program will use all available cores."
-    ),
-)
-@click.option(
-    "-v",
-    "--verbose",
-    count=True,
-    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
-)
-def preprocess(
-    dataset_config: Path,
-    output: Path,
-    base_dir: Optional[Path] = None,
-    config: Optional[Path] = None,
-    config_field: Optional[str] = None,
-    num_workers: Optional[int] = None,
-    dataset_field: Optional[str] = None,
-    verbose: int = 0,
-):
-    logger.remove()
-    if verbose == 0:
-        log_level = "WARNING"
-    elif verbose == 1:
-        log_level = "INFO"
-    else:
-        log_level = "DEBUG"
-    logger.add(sys.stderr, level=log_level)
-
-    logger.info("Starting preprocessing.")
-
-    output = Path(output)
-    logger.info("Will save outputs to {output}", output=output)
-
-    base_dir = base_dir or Path.cwd()
-    logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
-
-    if config:
-        logger.info(
-            "Loading preprocessing config from: {config}", config=config
-        )
-
-    conf = (
-        load_train_preprocessing_config(config, field=config_field)
-        if config is not None
-        else TrainPreprocessConfig()
-    )
-    logger.debug(
-        "Preprocessing config:\n{conf}",
-        conf=yaml.dump(conf.model_dump()),
-    )
-
-    dataset = load_dataset_from_config(
-        dataset_config,
-        field=dataset_field,
-        base_dir=base_dir,
-    )
-
-    logger.info(
-        "Loaded {num_examples} annotated clips from the configured dataset",
-        num_examples=len(dataset),
-    )
-
-    preprocess_dataset(
-        dataset,
-        conf,
-        output=output,
-        max_workers=num_workers,
-    )
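For context, the removed command simply wired load_dataset_from_config into preprocess_dataset. A minimal programmatic sketch of the same flow, using only APIs deleted in this commit; the file paths are hypothetical, not taken from the repo:

from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import TrainPreprocessConfig, preprocess_dataset

# Load annotated clips as described by a dataset config (hypothetical path).
dataset = load_dataset_from_config("dataset.yaml", base_dir=Path.cwd())

# Preprocess with default settings, writing one .npz file per clip.
preprocess_dataset(
    dataset,
    TrainPreprocessConfig(),
    output=Path("preprocessed"),
    max_workers=4,
)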
@@ -10,16 +10,8 @@ from batdetect2.typing.targets import TargetProtocol
 def iterate_over_sound_events(
     dataset: Dataset,
     targets: TargetProtocol,
-    apply_filter: bool = True,
-    apply_transform: bool = True,
-    exclude_generic: bool = True,
 ) -> Generator[Tuple[Optional[str], data.SoundEventAnnotation], None, None]:
-    """Iterate over sound events in a dataset, applying filtering and
-    transformations.
-
-    This generator function processes sound event annotations from a given
-    dataset, allowing for optional filtering, transformation, and exclusion of
-    unclassifiable (generic) events based on the provided target definitions.
+    """Iterate over sound events in a dataset.
 
     Parameters
     ----------
@@ -29,18 +21,6 @@ def iterate_over_sound_events(
     targets : TargetProtocol
         An object implementing the `TargetProtocol`, which provides methods
         for filtering, transforming, and encoding sound events.
-    apply_filter : bool, optional
-        If True, sound events will be filtered using `targets.filter()`.
-        Only events for which `targets.filter()` returns True will be yielded.
-        Defaults to True.
-    apply_transform : bool, optional
-        If True, sound events will be transformed using `targets.transform()`
-        before being yielded. Defaults to True.
-    exclude_generic : bool, optional
-        If True, sound events that result in a `None` class name after
-        `targets.encode()` will be excluded. This is typically used to
-        filter out events that cannot be mapped to a specific target class.
-        Defaults to True.
 
     Yields
     ------
@@ -63,17 +43,9 @@ def iterate_over_sound_events(
     """
     for clip_annotation in dataset:
         for sound_event_annotation in clip_annotation.sound_events:
-            if apply_filter:
-                if not targets.filter(sound_event_annotation):
-                    continue
-
-            if apply_transform:
-                sound_event_annotation = targets.transform(
-                    sound_event_annotation
-                )
-
-            class_name = targets.encode_class(sound_event_annotation)
-            if class_name is None and exclude_generic:
+            if not targets.filter(sound_event_annotation):
                 continue
-
+
+            class_name = targets.encode_class(sound_event_annotation)
 
             yield class_name, sound_event_annotation
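After this change the generator always applies targets.filter and always encodes a class name, which may still be None for generic events. A hedged usage sketch, assuming dataset and targets have been built elsewhere:

from collections import Counter

# Tally sound events per encoded class; skip events that encode to None.
class_counts = Counter(
    class_name
    for class_name, _ in iterate_over_sound_events(dataset, targets)
    if class_name is not None
)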
@@ -33,10 +33,6 @@ from batdetect2.train.losses import (
     SizeLossConfig,
     build_loss,
 )
-from batdetect2.train.preprocess import (
-    generate_train_example,
-    preprocess_annotations,
-)
 from batdetect2.train.train import (
     build_train_dataset,
     build_train_loader,
@@ -74,14 +70,12 @@ __all__ = [
     "build_trainer",
     "build_val_dataset",
     "build_val_loader",
-    "generate_train_example",
     "load_full_training_config",
     "load_label_config",
     "load_train_config",
     "mask_frequency",
     "mask_time",
     "mix_audio",
-    "preprocess_annotations",
     "scale_volume",
     "select_subclip",
     "train",
@@ -1,243 +0,0 @@
-"""Preprocesses datasets for BatDetect2 model training."""
-
-import os
-from pathlib import Path
-from typing import Callable, List, Optional, Sequence, TypedDict
-
-import numpy as np
-import torch
-import torch.utils.data
-from loguru import logger
-from pydantic import Field
-from soundevent import data
-from tqdm import tqdm
-
-from batdetect2.configs import BaseConfig, load_config
-from batdetect2.data.datasets import Dataset
-from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.preprocess.audio import build_audio_loader
-from batdetect2.targets import TargetConfig, build_targets
-from batdetect2.train.labels import LabelConfig, build_clip_labeler
-from batdetect2.typing import ClipLabeller, PreprocessorProtocol
-from batdetect2.typing.preprocess import AudioLoader
-from batdetect2.typing.train import PreprocessedExample
-
-__all__ = [
-    "preprocess_annotations",
-    "generate_train_example",
-    "preprocess_dataset",
-    "TrainPreprocessConfig",
-    "load_train_preprocessing_config",
-    "save_preprocessed_example",
-    "load_preprocessed_example",
-]
-
-FilenameFn = Callable[[data.ClipAnnotation], str]
-"""Type alias for a function that generates an output filename."""
-
-
-class TrainPreprocessConfig(BaseConfig):
-    preprocess: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    targets: TargetConfig = Field(default_factory=TargetConfig)
-    labels: LabelConfig = Field(default_factory=LabelConfig)
-
-
-def load_train_preprocessing_config(
-    path: data.PathLike,
-    field: Optional[str] = None,
-) -> TrainPreprocessConfig:
-    return load_config(path=path, schema=TrainPreprocessConfig, field=field)
-
-
-def preprocess_dataset(
-    dataset: Dataset,
-    config: TrainPreprocessConfig,
-    output: Path,
-    max_workers: Optional[int] = None,
-) -> None:
-    targets = build_targets(config=config.targets)
-    preprocessor = build_preprocessor(config=config.preprocess)
-    labeller = build_clip_labeler(
-        targets,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
-        config=config.labels,
-    )
-    audio_loader = build_audio_loader(config=config.preprocess.audio)
-
-    if not output.exists():
-        logger.debug("Creating directory {directory}", directory=output)
-        output.mkdir(parents=True)
-
-    preprocess_annotations(
-        dataset,
-        output_dir=output,
-        audio_loader=audio_loader,
-        preprocessor=preprocessor,
-        labeller=labeller,
-        max_workers=max_workers,
-    )
-
-
-class Example(TypedDict):
-    audio: torch.Tensor
-    spectrogram: torch.Tensor
-    detection_heatmap: torch.Tensor
-    class_heatmap: torch.Tensor
-    size_heatmap: torch.Tensor
-
-
-def generate_train_example(
-    clip_annotation: data.ClipAnnotation,
-    audio_loader: AudioLoader,
-    preprocessor: PreprocessorProtocol,
-    labeller: ClipLabeller,
-) -> PreprocessedExample:
-    """Generate a complete training example for one annotation."""
-    wave = torch.tensor(
-        audio_loader.load_clip(clip_annotation.clip)
-    ).unsqueeze(0)
-    spectrogram = preprocessor(wave.unsqueeze(0)).squeeze(0)
-    heatmaps = labeller(clip_annotation, spectrogram)
-    return PreprocessedExample(
-        audio=wave,
-        spectrogram=spectrogram,
-        detection_heatmap=heatmaps.detection,
-        class_heatmap=heatmaps.classes,
-        size_heatmap=heatmaps.size,
-    )
-
-
-class PreprocessingDataset(torch.utils.data.Dataset):
-    def __init__(
-        self,
-        clips: Dataset,
-        audio_loader: AudioLoader,
-        preprocessor: PreprocessorProtocol,
-        labeller: ClipLabeller,
-        filename_fn: FilenameFn,
-        output_dir: Path,
-        force: bool = False,
-    ):
-        self.clips = clips
-        self.audio_loader = audio_loader
-        self.preprocessor = preprocessor
-        self.labeller = labeller
-        self.filename_fn = filename_fn
-        self.output_dir = output_dir
-        self.force = force
-
-    def __getitem__(self, idx) -> int:
-        clip_annotation = self.clips[idx]
-
-        filename = self.filename_fn(clip_annotation)
-
-        path = self.output_dir / filename
-
-        if path.exists() and not self.force:
-            return idx
-
-        if not path.parent.exists():
-            path.parent.mkdir()
-
-        example = generate_train_example(
-            clip_annotation,
-            audio_loader=self.audio_loader,
-            preprocessor=self.preprocessor,
-            labeller=self.labeller,
-        )
-
-        save_preprocessed_example(example, clip_annotation, path)
-
-        return idx
-
-    def __len__(self) -> int:
-        return len(self.clips)
-
-
-def save_preprocessed_example(
-    example: PreprocessedExample,
-    clip_annotation: data.ClipAnnotation,
-    path: data.PathLike,
-) -> None:
-    np.savez_compressed(
-        path,
-        audio=example.audio.numpy(),
-        spectrogram=example.spectrogram.numpy(),
-        detection_heatmap=example.detection_heatmap.numpy(),
-        class_heatmap=example.class_heatmap.numpy(),
-        size_heatmap=example.size_heatmap.numpy(),
-        clip_annotation=clip_annotation,
-    )
-
-
-def load_preprocessed_example(path: data.PathLike) -> PreprocessedExample:
-    item = np.load(path, mmap_mode="r+")
-    return PreprocessedExample(
-        audio=torch.tensor(item["audio"]),
-        spectrogram=torch.tensor(item["spectrogram"]),
-        size_heatmap=torch.tensor(item["size_heatmap"]),
-        detection_heatmap=torch.tensor(item["detection_heatmap"]),
-        class_heatmap=torch.tensor(item["class_heatmap"]),
-    )
-
-
-def list_preprocessed_files(
-    directory: data.PathLike, extension: str = ".npz"
-) -> List[Path]:
-    return list(Path(directory).glob(f"*{extension}"))
-
-
-def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
-    """Generate a default output filename based on the annotation UUID."""
-    return f"{clip_annotation.uuid}"
-
-
-def preprocess_annotations(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    output_dir: data.PathLike,
-    preprocessor: PreprocessorProtocol,
-    audio_loader: AudioLoader,
-    labeller: ClipLabeller,
-    filename_fn: FilenameFn = _get_filename,
-    max_workers: Optional[int] = None,
-) -> None:
-    """Preprocess a sequence of ClipAnnotations and save results to disk."""
-    output_dir = Path(output_dir)
-
-    if not output_dir.is_dir():
-        logger.info(
-            "Creating output directory: {output_dir}", output_dir=output_dir
-        )
-        output_dir.mkdir(parents=True)
-
-    logger.info(
-        "Starting preprocessing of {num_annotations} annotations with {max_workers} workers.",
-        num_annotations=len(clip_annotations),
-        max_workers=max_workers or "all available",
-    )
-
-    if max_workers is None:
-        max_workers = os.cpu_count() or 0
-
-    dataset = PreprocessingDataset(
-        clips=list(clip_annotations),
-        audio_loader=audio_loader,
-        preprocessor=preprocessor,
-        labeller=labeller,
-        output_dir=Path(output_dir),
-        filename_fn=filename_fn,
-    )
-
-    loader = torch.utils.data.DataLoader(
-        dataset,
-        batch_size=1,
-        shuffle=False,
-        num_workers=max_workers,
-        prefetch_factor=16,
-    )
-
-    for _ in tqdm(loader, total=len(dataset)):
-        pass
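The deleted module also defined the on-disk format: one compressed .npz file per annotated clip, holding the audio, spectrogram, and the detection, class, and size heatmaps. A small sketch of reading examples back with the deleted helpers; the output directory is hypothetical:

from batdetect2.train.preprocess import (
    list_preprocessed_files,
    load_preprocessed_example,
)

# Enumerate the .npz files written by preprocess_dataset (hypothetical dir).
for path in list_preprocessed_files("preprocessed"):
    example = load_preprocessed_example(path)
    print(path.name, example.spectrogram.shape, example.size_heatmap.shape)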
16 tests/test_targets/test_terms.py Normal file
@@ -0,0 +1,16 @@
+import pytest
+
+from batdetect2.targets import terms
+
+
+def test_tag_info_and_get_tag_from_info():
+    tag_info = terms.TagInfo(value="Myotis myotis", key="event")
+    tag = terms.get_tag_from_info(tag_info)
+    assert tag.value == "Myotis myotis"
+    assert tag.term == terms.call_type
+
+
+def test_get_tag_from_info_key_not_found():
+    tag_info = terms.TagInfo(value="test", key="non_existent_key")
+    with pytest.raises(KeyError):
+        terms.get_tag_from_info(tag_info)