Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)
Remove train preprocessing

commit 40f6b64611
parent 1cec332dd5
@@ -17,7 +17,7 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.7.0",
+    "soundevent[audio,geometry,plot]>=2.8.0",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
@@ -6,19 +6,19 @@ import click
 from loguru import logger
 
 from batdetect2.cli.base import cli
+from batdetect2.data import load_dataset_from_config
 from batdetect2.train import (
     FullTrainingConfig,
     load_full_training_config,
     train,
 )
-from batdetect2.train.dataset import list_preprocessed_files
 
 __all__ = ["train_command"]
 
 
 @cli.command(name="train")
-@click.argument("train_dir", type=click.Path(exists=True))
-@click.option("--val-dir", type=click.Path(exists=True))
+@click.argument("train_dataset", type=click.Path(exists=True))
+@click.option("--val-dataset", type=click.Path(exists=True))
 @click.option("--model-path", type=click.Path(exists=True))
 @click.option("--config", type=click.Path(exists=True))
 @click.option("--config-field", type=str)
@@ -31,8 +31,8 @@ __all__ = ["train_command"]
     help="Increase verbosity. -v for INFO, -vv for DEBUG.",
 )
 def train_command(
-    train_dir: Path,
-    val_dir: Optional[Path] = None,
+    train_dataset: Path,
+    val_dataset: Optional[Path] = None,
     model_path: Optional[Path] = None,
     config: Optional[Path] = None,
     config_field: Optional[str] = None,
@@ -58,29 +58,27 @@ def train_command(
         else FullTrainingConfig()
     )
 
-    logger.info("Scanning for training and validation data...")
-    train_examples = list_preprocessed_files(train_dir)
+    logger.info("Loading training dataset...")
+    train_annotations = load_dataset_from_config(train_dataset)
     logger.debug(
-        "Found {num_files} training examples in {path}",
-        num_files=len(train_examples),
-        path=train_dir,
+        "Loaded {num_annotations} training examples",
+        num_annotations=len(train_annotations),
     )
 
-    val_examples = None
-    if val_dir is not None:
-        val_examples = list_preprocessed_files(val_dir)
+    val_annotations = None
+    if val_dataset is not None:
+        val_annotations = load_dataset_from_config(val_dataset)
         logger.debug(
-            "Found {num_files} validation examples in {path}",
-            num_files=len(val_examples),
-            path=val_dir,
+            "Loaded {num_annotations} validation examples",
+            num_files=len(val_annotations),
         )
     else:
         logger.debug("No validation directory provided.")
 
     logger.info("Configuration and data loaded. Starting training...")
     train(
-        train_examples=train_examples,
-        val_examples=val_examples,
+        train_annotations=train_annotations,
+        val_annotations=val_annotations,
         config=conf,
         model_path=model_path,
         train_workers=train_workers,
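Note: per the updated command signature above, training is now pointed at dataset configuration files rather than at directories of preprocessed examples. A hypothetical invocation (file names assumed, using the packaged `batdetect2` console script) would be:

    batdetect2 train train_dataset.yaml --val-dataset val_dataset.yaml --config training_config.yaml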
@@ -38,6 +38,8 @@ def plot_spectrogram(
     if isinstance(spec, torch.Tensor):
         spec = spec.numpy()
 
+    spec = spec.squeeze()
+
     ax = create_ax(ax=ax, figsize=figsize)
 
     if start_time is None:
@@ -25,6 +25,8 @@ def plot_detection_heatmap(
     if isinstance(heatmap, torch.Tensor):
         heatmap = heatmap.numpy()
 
+    heatmap = heatmap.squeeze()
+
     if threshold is not None:
         heatmap = np.ma.masked_where(
             heatmap < threshold,
@@ -2,7 +2,7 @@ from batdetect2.train.augmentations import (
     AugmentationsConfig,
     EchoAugmentationConfig,
     FrequencyMaskAugmentationConfig,
-    RandomExampleSource,
+    RandomAudioSource,
     TimeMaskAugmentationConfig,
     VolumeAugmentationConfig,
     WarpAugmentationConfig,
@@ -10,7 +10,7 @@ from batdetect2.train.augmentations import (
     build_augmentations,
     mask_frequency,
     mask_time,
-    mix_examples,
+    mix_audio,
     scale_volume,
     warp_spectrogram,
 )
@@ -22,10 +22,7 @@ from batdetect2.train.config import (
     load_full_training_config,
     load_train_config,
 )
-from batdetect2.train.dataset import (
-    LabeledDataset,
-    list_preprocessed_files,
-)
+from batdetect2.train.dataset import TrainingDataset
 from batdetect2.train.labels import build_clip_labeler, load_label_config
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.losses import (
@@ -56,11 +53,11 @@ __all__ = [
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
     "FullTrainingConfig",
-    "LabeledDataset",
+    "TrainingDataset",
     "LossConfig",
     "LossFunction",
     "PLTrainerConfig",
-    "RandomExampleSource",
+    "RandomAudioSource",
     "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainingConfig",
@@ -78,13 +75,12 @@ __all__ = [
     "build_val_dataset",
     "build_val_loader",
     "generate_train_example",
-    "list_preprocessed_files",
     "load_full_training_config",
     "load_label_config",
     "load_train_config",
     "mask_frequency",
     "mask_time",
-    "mix_examples",
+    "mix_audio",
     "preprocess_annotations",
     "scale_volume",
     "select_subclip",
@@ -9,14 +9,12 @@ import torch
 from loguru import logger
 from pydantic import Field
 from soundevent import data
+from soundevent.geometry import scale_geometry, shift_geometry
 
 from batdetect2.configs import BaseConfig, load_config
-from batdetect2.train.preprocess import (
-    list_preprocessed_files,
-    load_preprocessed_example,
-)
-from batdetect2.typing import Augmentation, PreprocessorProtocol
-from batdetect2.typing.train import ClipperProtocol, PreprocessedExample
+from batdetect2.train.clips import get_subclip_annotation
+from batdetect2.typing import Augmentation
+from batdetect2.typing.preprocess import AudioLoader
 from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
@@ -24,7 +22,7 @@ __all__ = [
     "AugmentationsConfig",
     "DEFAULT_AUGMENTATION_CONFIG",
     "EchoAugmentationConfig",
-    "ExampleSource",
+    "AudioSource",
     "FrequencyMaskAugmentationConfig",
     "MixAugmentationConfig",
     "TimeMaskAugmentationConfig",
@@ -35,365 +33,12 @@ __all__ = [
     "load_augmentation_config",
     "mask_frequency",
     "mask_time",
-    "mix_examples",
+    "mix_audio",
     "scale_volume",
     "warp_spectrogram",
 ]
 
-ExampleSource = Callable[[], PreprocessedExample]
-"""Type alias for a function that returns a training example"""
+AudioSource = Callable[[float], tuple[torch.Tensor, data.ClipAnnotation]]
-
-
-def mix_examples(
-    example: PreprocessedExample,
-    other: PreprocessedExample,
-    preprocessor: PreprocessorProtocol,
-    weight: float,
-) -> PreprocessedExample:
-    """Combine two training examples."""
-    audio1 = example.audio
-    audio2 = adjust_width(other.audio, audio1.shape[-1])
-
-    combined = weight * audio1 + (1 - weight) * audio2
-
-    spectrogram = preprocessor(combined)
-
-    # NOTE: The subclip's spectrogram might be slightly longer than the
-    # spectrogram computed from the subclip's audio. This is due to a
-    # simplification in the subclip process: It doesn't account for the
-    # spectrogram parameters to precisely determine the corresponding audio
-    # samples. To work around this, we pad the computed spectrogram with zeros
-    # as needed.
-    previous_width = example.spectrogram.shape[-1]
-    spectrogram = adjust_width(spectrogram, previous_width)
-
-    detection_heatmap = torch.maximum(
-        example.detection_heatmap,
-        adjust_width(other.detection_heatmap, previous_width),
-    )
-
-    class_heatmap = torch.maximum(
-        example.class_heatmap,
-        adjust_width(other.class_heatmap, previous_width),
-    )
-
-    size_heatmap = torch.maximum(
-        example.size_heatmap,
-        adjust_width(other.size_heatmap, previous_width),
-    )
-
-    return PreprocessedExample(
-        audio=combined,
-        spectrogram=spectrogram,
-        detection_heatmap=detection_heatmap,
-        class_heatmap=class_heatmap,
-        size_heatmap=size_heatmap,
-    )
-
-
-class EchoAugmentationConfig(BaseConfig):
-    """Configuration for adding synthetic echo/reverb."""
-
-    augmentation_type: Literal["add_echo"] = "add_echo"
-
-    probability: float = 0.2
-    """Probability of applying this augmentation."""
-
-    max_delay: float = 0.005
-    min_weight: float = 0.0
-    max_weight: float = 1.0
-
-
-class AddEcho(torch.nn.Module):
-    def __init__(
-        self,
-        preprocessor: PreprocessorProtocol,
-        min_weight: float = 0.1,
-        max_weight: float = 1.0,
-        max_delay: float = 0.005,
-    ):
-        super().__init__()
-        self.preprocessor = preprocessor
-        self.min_weight = min_weight
-        self.max_weight = max_weight
-        self.max_delay = max_delay
-
-    def forward(self, example: PreprocessedExample) -> PreprocessedExample:
-        delay = np.random.uniform(0, self.max_delay)
-        weight = np.random.uniform(self.min_weight, self.max_weight)
-        return add_echo(
-            example,
-            preprocessor=self.preprocessor,
-            delay=delay,
-            weight=weight,
-        )
-
-
-def add_echo(
-    example: PreprocessedExample,
-    preprocessor: PreprocessorProtocol,
-    delay: float,
-    weight: float,
-) -> PreprocessedExample:
-    """Add a synthetic echo to the audio waveform."""
-
-    audio = example.audio
-    delay_steps = int(preprocessor.input_samplerate * delay)
-
-    slices = [slice(None)] * audio.ndim
-    slices[-1] = slice(None, -delay_steps)
-    audio_delay = adjust_width(audio[tuple(slices)], audio.shape[-1]).roll(
-        delay_steps, dims=-1
-    )
-
-    audio = audio + weight * audio_delay
-    spectrogram = preprocessor(audio)
-
-    # NOTE: The subclip's spectrogram might be slightly longer than the
-    # spectrogram computed from the subclip's audio. This is due to a
-    # simplification in the subclip process: It doesn't account for the
-    # spectrogram parameters to precisely determine the corresponding audio
-    # samples. To work around this, we pad the computed spectrogram with zeros
-    # as needed.
-    spectrogram = adjust_width(
-        spectrogram,
-        example.spectrogram.shape[-1],
-    )
-
-    return PreprocessedExample(
-        audio=audio,
-        spectrogram=spectrogram,
-        detection_heatmap=example.detection_heatmap,
-        class_heatmap=example.class_heatmap,
-        size_heatmap=example.size_heatmap,
-    )
-
-
-class VolumeAugmentationConfig(BaseConfig):
-    """Configuration for random volume scaling of the spectrogram."""
-
-    augmentation_type: Literal["scale_volume"] = "scale_volume"
-    probability: float = 0.2
-    min_scaling: float = 0.0
-    max_scaling: float = 2.0
-
-
-class ScaleVolume(torch.nn.Module):
-    def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
-        super().__init__()
-        self.min_scaling = min_scaling
-        self.max_scaling = max_scaling
-
-    def forward(self, example: PreprocessedExample) -> PreprocessedExample:
-        factor = np.random.uniform(self.min_scaling, self.max_scaling)
-        return scale_volume(example, factor=factor)
-
-
-def scale_volume(
-    example: PreprocessedExample,
-    factor: Optional[float] = None,
-) -> PreprocessedExample:
-    """Scale the amplitude of the spectrogram by a random factor."""
-    return PreprocessedExample(
-        audio=example.audio,
-        size_heatmap=example.size_heatmap,
-        class_heatmap=example.class_heatmap,
-        detection_heatmap=example.detection_heatmap,
-        spectrogram=example.spectrogram * factor,
-    )
-
-
-class WarpAugmentationConfig(BaseConfig):
-    augmentation_type: Literal["warp"] = "warp"
-    probability: float = 0.2
-    delta: float = 0.04
-
-
-class WarpSpectrogram(torch.nn.Module):
-    def __init__(self, delta: float = 0.04) -> None:
-        super().__init__()
-        self.delta = delta
-
-    def forward(self, example: PreprocessedExample) -> PreprocessedExample:
-        factor = np.random.uniform(1 - self.delta, 1 + self.delta)
-        return warp_spectrogram(example, factor=factor)
-
-
-def warp_spectrogram(
-    example: PreprocessedExample, factor: float
-) -> PreprocessedExample:
-    """Apply time warping by resampling the time axis."""
-    width = example.spectrogram.shape[-1]
-    height = example.spectrogram.shape[-2]
-    target_shape = [height, width]
-    new_width = int(target_shape[-1] * factor)
-
-    spectrogram = torch.nn.functional.interpolate(
-        adjust_width(example.spectrogram, new_width).unsqueeze(0),
-        size=target_shape,
-        mode="bilinear",
-    ).squeeze(0)
-
-    detection = torch.nn.functional.interpolate(
-        adjust_width(example.detection_heatmap, new_width).unsqueeze(0),
-        size=target_shape,
-        mode="nearest",
-    ).squeeze(0)
-
-    classification = torch.nn.functional.interpolate(
-        adjust_width(example.class_heatmap, new_width).unsqueeze(1),
-        size=target_shape,
-        mode="nearest",
-    ).squeeze(1)
-
-    size = torch.nn.functional.interpolate(
-        adjust_width(example.size_heatmap, new_width).unsqueeze(1),
-        size=target_shape,
-        mode="nearest",
-    ).squeeze(1)
-
-    return PreprocessedExample(
-        audio=example.audio,
-        size_heatmap=size,
-        class_heatmap=classification,
-        detection_heatmap=detection,
-        spectrogram=spectrogram,
-    )
-
-
-class TimeMaskAugmentationConfig(BaseConfig):
-    augmentation_type: Literal["mask_time"] = "mask_time"
-    probability: float = 0.2
-    max_perc: float = 0.05
-    max_masks: int = 3
-
-
-class MaskTime(torch.nn.Module):
-    def __init__(
-        self,
-        max_perc: float = 0.05,
-        max_masks: int = 3,
-        mask_heatmaps: bool = False,
-    ) -> None:
-        super().__init__()
-        self.max_perc = max_perc
-        self.max_masks = max_masks
-        self.mask_heatmaps = mask_heatmaps
-
-    def forward(self, example: PreprocessedExample) -> PreprocessedExample:
-        num_masks = np.random.randint(1, self.max_masks + 1)
-        width = example.spectrogram.shape[-1]
-
-        mask_size = np.random.randint(
-            low=0,
-            high=int(self.max_perc * width),
-            size=num_masks,
-        )
-        mask_start = np.random.randint(
-            low=0,
-            high=width - mask_size,
-            size=num_masks,
-        )
-        masks = [
-            (start, start + size) for start, size in zip(mask_start, mask_size)
-        ]
-        return mask_time(example, masks, mask_heatmaps=self.mask_heatmaps)
-
-
-def mask_time(
-    example: PreprocessedExample,
-    masks: List[Tuple[int, int]],
-    mask_heatmaps: bool = False,
-) -> PreprocessedExample:
-    """Apply time masking to the spectrogram."""
-
-    for start, end in masks:
-        slices = [slice(None)] * example.spectrogram.ndim
-        slices[-1] = slice(start, end)
-
-        example.spectrogram[tuple(slices)] = 0
-
-        if not mask_heatmaps:
-            continue
-
-        example.class_heatmap[tuple(slices)] = 0
-        example.size_heatmap[tuple(slices)] = 0
-        example.detection_heatmap[tuple(slices)] = 0
-
-    return PreprocessedExample(
-        audio=example.audio,
-        size_heatmap=example.size_heatmap,
-        class_heatmap=example.class_heatmap,
-        detection_heatmap=example.detection_heatmap,
-        spectrogram=example.spectrogram,
-    )
-
-
-class FrequencyMaskAugmentationConfig(BaseConfig):
-    augmentation_type: Literal["mask_freq"] = "mask_freq"
-    probability: float = 0.2
-    max_perc: float = 0.10
-    max_masks: int = 3
-    mask_heatmaps: bool = False
-
-
-class MaskFrequency(torch.nn.Module):
-    def __init__(
-        self,
-        max_perc: float = 0.10,
-        max_masks: int = 3,
-        mask_heatmaps: bool = False,
-    ) -> None:
-        super().__init__()
-        self.max_perc = max_perc
-        self.max_masks = max_masks
-        self.mask_heatmaps = mask_heatmaps
-
-    def forward(self, example: PreprocessedExample) -> PreprocessedExample:
-        num_masks = np.random.randint(1, self.max_masks + 1)
-        height = example.spectrogram.shape[-2]
-
-        mask_size = np.random.randint(
-            low=0,
-            high=int(self.max_perc * height),
-            size=num_masks,
-        )
-        mask_start = np.random.randint(
-            low=0,
-            high=height - mask_size,
-            size=num_masks,
-        )
-        masks = [
-            (start, start + size) for start, size in zip(mask_start, mask_size)
-        ]
-        return mask_frequency(example, masks, mask_heatmaps=self.mask_heatmaps)
-
-
-def mask_frequency(
-    example: PreprocessedExample,
-    masks: List[Tuple[int, int]],
-    mask_heatmaps: bool = False,
-) -> PreprocessedExample:
-    """Apply frequency masking to the spectrogram."""
-    for start, end in masks:
-        slices = [slice(None)] * example.spectrogram.ndim
-        slices[-2] = slice(start, end)
-        example.spectrogram[tuple(slices)] = 0
-
-        if not mask_heatmaps:
-            continue
-
-        example.class_heatmap[tuple(slices)] = 0
-        example.size_heatmap[tuple(slices)] = 0
-        example.detection_heatmap[tuple(slices)] = 0
-
-    return PreprocessedExample(
-        audio=example.audio,
-        size_heatmap=example.size_heatmap,
-        class_heatmap=example.class_heatmap,
-        detection_heatmap=example.detection_heatmap,
-        spectrogram=example.spectrogram,
-    )
 
 
 class MixAugmentationConfig(BaseConfig):
@@ -416,8 +61,7 @@ class MixAudio(torch.nn.Module):
 
     def __init__(
         self,
-        example_source: ExampleSource,
-        preprocessor: PreprocessorProtocol,
+        example_source: AudioSource,
         min_weight: float = 0.3,
         max_weight: float = 0.7,
     ):
@@ -426,20 +70,364 @@ class MixAudio(torch.nn.Module):
         self.min_weight = min_weight
         self.example_source = example_source
         self.max_weight = max_weight
-        self.preprocessor = preprocessor
 
-    def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
+    def __call__(
+        self,
+        wav: torch.Tensor,
+        clip_annotation: data.ClipAnnotation,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
         """Fetch another example and perform mixup."""
-        other = self.example_source()
+        other_wav, other_clip_annotation = self.example_source(
+            clip_annotation.clip.duration
+        )
         weight = np.random.uniform(self.min_weight, self.max_weight)
-        return mix_examples(
-            example,
-            other,
-            self.preprocessor,
-            weight=weight,
-        )
-
+        mixed_audio = mix_audio(wav, other_wav, weight=weight)
+        mixed_annotations = combine_clip_annotations(
+            clip_annotation,
+            other_clip_annotation,
+        )
+        return mixed_audio, mixed_annotations
+
+
+def mix_audio(
+    wav1: torch.Tensor,
+    wav2: torch.Tensor,
+    weight: float,
+) -> torch.Tensor:
+    """Combine two training examples."""
+    wav2 = adjust_width(wav2, wav1.shape[-1])
+    return weight * wav1 + (1 - weight) * wav2
+
+
+def shift_sound_event_annotation(
+    sound_event_annotation: data.SoundEventAnnotation,
+    time: float,
+) -> data.SoundEventAnnotation:
+    sound_event = sound_event_annotation.sound_event
+    geometry = sound_event.geometry
+
+    if geometry is None:
+        return sound_event_annotation
+
+    sound_event = sound_event.model_copy(
+        update=dict(geometry=shift_geometry(geometry, time=time))
+    )
+    return sound_event_annotation.model_copy(
+        update=dict(sound_event=sound_event)
+    )
+
+
+def combine_clip_annotations(
+    clip_annotation1: data.ClipAnnotation,
+    clip_annotation2: data.ClipAnnotation,
+) -> data.ClipAnnotation:
+    time_shift = (
+        clip_annotation1.clip.start_time - clip_annotation2.clip.start_time
+    )
+    return clip_annotation1.model_copy(
+        update=dict(
+            sound_events=[
+                *clip_annotation1.sound_events,
+                *[
+                    shift_sound_event_annotation(sound_event, time=time_shift)
+                    for sound_event in clip_annotation2.sound_events
+                ],
+            ]
+        )
+    )
+
+
+class EchoAugmentationConfig(BaseConfig):
+    """Configuration for adding synthetic echo/reverb."""
+
+    augmentation_type: Literal["add_echo"] = "add_echo"
+    probability: float = 0.2
+    max_delay: float = 0.005
+    min_weight: float = 0.0
+    max_weight: float = 1.0
+
+
+class AddEcho(torch.nn.Module):
+    def __init__(
+        self,
+        min_weight: float = 0.1,
+        max_weight: float = 1.0,
+        max_delay: int = 2560,
+    ):
+        super().__init__()
+        self.min_weight = min_weight
+        self.max_weight = max_weight
+        self.max_delay = max_delay
+
+    def forward(
+        self,
+        wav: torch.Tensor,
+        clip_annotation: data.ClipAnnotation,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        delay = np.random.randint(0, self.max_delay)
+        weight = np.random.uniform(self.min_weight, self.max_weight)
+        return add_echo(wav, delay=delay, weight=weight), clip_annotation
+
+
+def add_echo(
+    wav: torch.Tensor,
+    delay: int,
+    weight: float,
+) -> torch.Tensor:
+    """Add a synthetic echo to the audio waveform."""
+
+    slices = [slice(None)] * wav.ndim
+    slices[-1] = slice(None, -delay)
+    audio_delay = adjust_width(wav[tuple(slices)], wav.shape[-1]).roll(
+        delay, dims=-1
+    )
+    return mix_audio(wav, audio_delay, weight)
+
+
+class VolumeAugmentationConfig(BaseConfig):
+    """Configuration for random volume scaling of the spectrogram."""
+
+    augmentation_type: Literal["scale_volume"] = "scale_volume"
+    probability: float = 0.2
+    min_scaling: float = 0.0
+    max_scaling: float = 2.0
+
+
+class ScaleVolume(torch.nn.Module):
+    def __init__(self, min_scaling: float = 0.0, max_scaling: float = 2.0):
+        super().__init__()
+        self.min_scaling = min_scaling
+        self.max_scaling = max_scaling
+
+    def forward(
+        self,
+        spec: torch.Tensor,
+        clip_annotation: data.ClipAnnotation,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        factor = np.random.uniform(self.min_scaling, self.max_scaling)
+        return scale_volume(spec, factor=factor), clip_annotation
+
+
+def scale_volume(spec: torch.Tensor, factor: float) -> torch.Tensor:
+    """Scale the amplitude of the spectrogram by a factor."""
+    return spec * factor
+
+
+class WarpAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["warp"] = "warp"
+    probability: float = 0.2
+    delta: float = 0.04
+
+
+class WarpSpectrogram(torch.nn.Module):
+    def __init__(self, delta: float = 0.04) -> None:
+        super().__init__()
+        self.delta = delta
+
+    def forward(
+        self,
+        spec: torch.Tensor,
+        clip_annotation: data.ClipAnnotation,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        factor = np.random.uniform(1 - self.delta, 1 + self.delta)
+        return (
+            warp_spectrogram(spec, factor=factor),
+            warp_clip_annotation(clip_annotation, factor=factor),
+        )
+
+
+def warp_sound_event_annotation(
+    sound_event_annotation: data.SoundEventAnnotation,
+    factor: float,
+    anchor: float,
+) -> data.SoundEventAnnotation:
+    sound_event = sound_event_annotation.sound_event
+    geometry = sound_event.geometry
+
+    if geometry is None:
+        return sound_event_annotation
+
+    sound_event = sound_event.model_copy(
+        update=dict(
+            geometry=scale_geometry(
+                geometry,
+                time=1 / factor,
+                time_anchor=anchor,
+            )
+        ),
+    )
+    return sound_event_annotation.model_copy(
+        update=dict(sound_event=sound_event)
+    )
+
+
+def warp_clip_annotation(
+    clip_annotation: data.ClipAnnotation,
+    factor: float,
+) -> data.ClipAnnotation:
+    return clip_annotation.model_copy(
+        update=dict(
+            sound_events=[
+                warp_sound_event_annotation(
+                    sound_event,
+                    factor=factor,
+                    anchor=clip_annotation.clip.start_time,
+                )
+                for sound_event in clip_annotation.sound_events
+            ]
+        )
+    )
+
+
+def warp_spectrogram(
+    spec: torch.Tensor,
+    factor: float,
+) -> torch.Tensor:
+    """Apply time warping by resampling the time axis."""
+    width = spec.shape[-1]
+    height = spec.shape[-2]
+    target_shape = [height, width]
+    new_width = int(target_shape[-1] * factor)
+    return torch.nn.functional.interpolate(
+        adjust_width(spec, new_width).unsqueeze(0),
+        size=target_shape,
+        mode="bilinear",
+    ).squeeze(0)
+
+
+class TimeMaskAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["mask_time"] = "mask_time"
+    probability: float = 0.2
+    max_perc: float = 0.05
+    max_masks: int = 3
+
+
+class MaskTime(torch.nn.Module):
+    def __init__(
+        self,
+        max_perc: float = 0.05,
+        max_masks: int = 3,
+        mask_heatmaps: bool = False,
+    ) -> None:
+        super().__init__()
+        self.max_perc = max_perc
+        self.max_masks = max_masks
+        self.mask_heatmaps = mask_heatmaps
+
+    def forward(
+        self,
+        spec: torch.Tensor,
+        clip_annotation: data.ClipAnnotation,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        num_masks = np.random.randint(1, self.max_masks + 1)
+        width = spec.shape[-1]
+
+        mask_size = np.random.randint(
+            low=0,
+            high=int(self.max_perc * width),
+            size=num_masks,
+        )
+        mask_start = np.random.randint(
+            low=0,
+            high=width - mask_size,
+            size=num_masks,
+        )
+        masks = [
+            (start, start + size) for start, size in zip(mask_start, mask_size)
+        ]
+        return mask_time(spec, masks), clip_annotation
+
+
+def mask_time(
+    spec: torch.Tensor,
+    masks: List[Tuple[int, int]],
+    value: float = 0,
+) -> torch.Tensor:
+    """Apply time masking to the spectrogram."""
+    for start, end in masks:
+        slices = [slice(None)] * spec.ndim
+        slices[-1] = slice(start, end)
+        spec[tuple(slices)] = value
+
+    return spec
+
+
+class FrequencyMaskAugmentationConfig(BaseConfig):
+    augmentation_type: Literal["mask_freq"] = "mask_freq"
+    probability: float = 0.2
+    max_perc: float = 0.10
+    max_masks: int = 3
+    mask_heatmaps: bool = False
+
+
+class MaskFrequency(torch.nn.Module):
+    def __init__(
+        self,
+        max_perc: float = 0.10,
+        max_masks: int = 3,
+        mask_heatmaps: bool = False,
+    ) -> None:
+        super().__init__()
+        self.max_perc = max_perc
+        self.max_masks = max_masks
+        self.mask_heatmaps = mask_heatmaps
+
+    def forward(
+        self,
+        spec: torch.Tensor,
+        clip_annotation: data.ClipAnnotation,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        num_masks = np.random.randint(1, self.max_masks + 1)
+        height = spec.shape[-2]
+
+        mask_size = np.random.randint(
+            low=0,
+            high=int(self.max_perc * height),
+            size=num_masks,
+        )
+        mask_start = np.random.randint(
+            low=0,
+            high=height - mask_size,
+            size=num_masks,
+        )
+        masks = [
+            (start, start + size) for start, size in zip(mask_start, mask_size)
+        ]
+        return mask_frequency(spec, masks), clip_annotation
+
+
+def mask_frequency(
+    spec: torch.Tensor,
+    masks: List[Tuple[int, int]],
+) -> torch.Tensor:
+    """Apply frequency masking to the spectrogram."""
+    for start, end in masks:
+        slices = [slice(None)] * spec.ndim
+        slices[-2] = slice(start, end)
+        spec[tuple(slices)] = 0
+
+    return spec
+
+
+AudioAugmentationConfig = Annotated[
+    Union[
+        MixAugmentationConfig,
+        EchoAugmentationConfig,
+    ],
+    Field(discriminator="augmentation_type"),
+]
+
+
+SpectrogramAugmentationConfig = Annotated[
+    Union[
+        VolumeAugmentationConfig,
+        WarpAugmentationConfig,
+        FrequencyMaskAugmentationConfig,
+        TimeMaskAugmentationConfig,
+    ],
+    Field(discriminator="augmentation_type"),
+]
 
 AugmentationConfig = Annotated[
     Union[
         MixAugmentationConfig,
@@ -459,7 +447,11 @@ class AugmentationsConfig(BaseConfig):
 
     enabled: bool = True
 
-    steps: List[AugmentationConfig] = Field(default_factory=list)
+    audio: List[AudioAugmentationConfig] = Field(default_factory=list)
+
+    spectrogram: List[SpectrogramAugmentationConfig] = Field(
+        default_factory=list
+    )
 
 
 class MaybeApply(torch.nn.Module):
@@ -470,46 +462,31 @@ class MaybeApply(torch.nn.Module):
         augmentation: Augmentation,
         probability: float = 0.2,
     ):
-        """Initialize the wrapper.
-
-        Parameters
-        ----------
-        augmentation : Augmentation (Callable[[xr.Dataset], xr.Dataset])
-            The augmentation function to potentially apply.
-        probability : float, default=0.5
-            The probability (0.0 to 1.0) of applying the augmentation.
-        """
+        """Initialize the wrapper."""
         super().__init__()
         self.augmentation = augmentation
         self.probability = probability
 
-    def __call__(self, example: PreprocessedExample) -> PreprocessedExample:
-        """Apply the wrapped augmentation with configured probability.
-
-        Parameters
-        ----------
-        example : xr.Dataset
-            The input training example.
-
-        Returns
-        -------
-        xr.Dataset
-            The potentially augmented training example.
-        """
+    def __call__(
+        self,
+        tensor: torch.Tensor,
+        clip_annotation: data.ClipAnnotation,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        """Apply the wrapped augmentation with configured probability."""
         if np.random.random() > self.probability:
-            return example
+            return tensor, clip_annotation
 
-        return self.augmentation(example)
+        return self.augmentation(tensor, clip_annotation)
 
 
 def build_augmentation_from_config(
     config: AugmentationConfig,
-    preprocessor: PreprocessorProtocol,
-    example_source: Optional[ExampleSource] = None,
+    samplerate: int,
+    audio_source: Optional[AudioSource] = None,
 ) -> Optional[Augmentation]:
     """Factory function to build a single augmentation from its config."""
     if config.augmentation_type == "mix_audio":
-        if example_source is None:
+        if audio_source is None:
             warnings.warn(
                 "Mix audio augmentation ('mix_audio') requires an "
                 "'example_source' callable to be provided.",
@@ -518,16 +495,14 @@ def build_augmentation_from_config(
             return None
 
         return MixAudio(
-            example_source=example_source,
-            preprocessor=preprocessor,
+            example_source=audio_source,
             min_weight=config.min_weight,
            max_weight=config.max_weight,
         )
 
     if config.augmentation_type == "add_echo":
         return AddEcho(
-            preprocessor=preprocessor,
-            max_delay=config.max_delay,
+            max_delay=int(config.max_delay * samplerate),
             min_weight=config.min_weight,
             max_weight=config.max_weight,
         )
@@ -562,37 +537,35 @@ def build_augmentation_from_config(
 
 
 DEFAULT_AUGMENTATION_CONFIG: AugmentationsConfig = AugmentationsConfig(
-    steps=[
+    enabled=True,
+    audio=[
         MixAugmentationConfig(),
         EchoAugmentationConfig(),
+    ],
+    spectrogram=[
         VolumeAugmentationConfig(),
         WarpAugmentationConfig(),
         TimeMaskAugmentationConfig(),
         FrequencyMaskAugmentationConfig(),
-    ]
+    ],
 )
 
 
-def build_augmentations(
-    preprocessor: PreprocessorProtocol,
-    config: Optional[AugmentationsConfig] = None,
-    example_source: Optional[ExampleSource] = None,
-) -> Augmentation:
-    """Build a composite augmentation pipeline function from configuration."""
-    config = config or DEFAULT_AUGMENTATION_CONFIG
-
-    logger.opt(lazy=True).debug(
-        "Building augmentations with config: \n{}",
-        lambda: config.to_yaml_string(),
-    )
+def build_augmentation_sequence(
+    samplerate: int,
+    steps: Optional[Sequence[AugmentationConfig]] = None,
+    audio_source: Optional[AudioSource] = None,
+) -> Optional[Augmentation]:
+    if not steps:
+        return None
 
     augmentations = []
 
-    for step_config in config.steps:
+    for step_config in steps:
         augmentation = build_augmentation_from_config(
             step_config,
-            preprocessor=preprocessor,
-            example_source=example_source,
+            samplerate=samplerate,
+            audio_source=audio_source,
         )
 
         if augmentation is None:
@@ -608,6 +581,33 @@ def build_augmentations(
     return torch.nn.Sequential(*augmentations)
 
 
+def build_augmentations(
+    samplerate: int,
+    config: Optional[AugmentationsConfig] = None,
+    audio_source: Optional[AudioSource] = None,
+) -> Tuple[Optional[Augmentation], Optional[Augmentation]]:
+    """Build a composite augmentation pipeline function from configuration."""
+    config = config or DEFAULT_AUGMENTATION_CONFIG
+
+    logger.opt(lazy=True).debug(
+        "Building augmentations with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+
+    audio_augmentation = build_augmentation_sequence(
+        samplerate,
+        steps=config.audio,
+        audio_source=audio_source,
+    )
+    spectrogram_augmentation = build_augmentation_sequence(
+        samplerate,
+        steps=config.audio,
+        audio_source=audio_source,
+    )
+
+    return audio_augmentation, spectrogram_augmentation
+
+
 def load_augmentation_config(
     path: data.PathLike, field: Optional[str] = None
 ) -> AugmentationsConfig:
@@ -615,23 +615,24 @@ def load_augmentation_config(
     return load_config(path, schema=AugmentationsConfig, field=field)
 
 
-class RandomExampleSource:
+class RandomAudioSource:
     def __init__(
         self,
-        filenames: Sequence[data.PathLike],
-        clipper: ClipperProtocol,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        audio_loader: AudioLoader,
     ):
-        self.filenames = filenames
-        self.clipper = clipper
+        self.audio_loader = audio_loader
+        self.clip_annotations = clip_annotations
 
-    def __call__(self) -> PreprocessedExample:
-        index = int(np.random.randint(len(self.filenames)))
-        filename = self.filenames[index]
-        example = load_preprocessed_example(filename)
-        example, _, _ = self.clipper(example)
-        return example
-
-    @classmethod
-    def from_directory(cls, path: data.PathLike, clipper: ClipperProtocol):
-        filenames = list_preprocessed_files(path)
-        return cls(filenames, clipper=clipper)
+    def __call__(
+        self,
+        duration: float,
+    ) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        index = int(np.random.randint(len(self.clip_annotations)))
+        clip_annotation = get_subclip_annotation(
+            self.clip_annotations[index],
+            duration=duration,
+            max_empty=0,
+        )
+        wav = self.audio_loader.load_clip(clip_annotation.clip)
+        return torch.from_numpy(wav).unsqueeze(0), clip_annotation
@@ -17,7 +17,7 @@ from batdetect2.evaluate.match import (
 from batdetect2.models import Model
 from batdetect2.plotting.evaluation import plot_example_gallery
 from batdetect2.postprocess import get_sound_event_predictions
-from batdetect2.train.dataset import LabeledDataset
+from batdetect2.train.dataset import TrainingDataset
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.typing import (
     BatDetect2Prediction,
@@ -49,11 +49,11 @@ class ValidationMetrics(Callback):
         Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
     ] = []
 
-    def get_dataset(self, trainer: Trainer) -> LabeledDataset:
+    def get_dataset(self, trainer: Trainer) -> TrainingDataset:
         dataloaders = trainer.val_dataloaders
         assert isinstance(dataloaders, DataLoader)
         dataset = dataloaders.dataset
-        assert isinstance(dataset, LabeledDataset)
+        assert isinstance(dataset, TrainingDataset)
         return dataset
 
     def plot_examples(
@@ -136,12 +136,12 @@ class ValidationMetrics(Callback):
 def _get_batch_clips_and_predictions(
     batch: TrainExample,
     outputs: ModelOutput,
-    dataset: LabeledDataset,
+    dataset: TrainingDataset,
     model: Model,
 ) -> List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]]:
     clip_annotations = [
         _get_subclip(
-            dataset.get_clip_annotation(example_id),
+            dataset.clip_annotations[int(example_id)],
             start_time=start_time.item(),
             end_time=end_time.item(),
             targets=model.targets,
@@ -1,14 +1,12 @@
-from typing import Optional, Tuple
+from typing import List, Optional
 
 import numpy as np
-import torch
 from loguru import logger
+from soundevent import data
+from soundevent.geometry import compute_bounds, intervals_overlap
 
 from batdetect2.configs import BaseConfig
 from batdetect2.typing import ClipperProtocol
-from batdetect2.typing.preprocess import PreprocessorProtocol
-from batdetect2.typing.train import PreprocessedExample
-from batdetect2.utils.arrays import adjust_width, slice_tensor
 
 DEFAULT_TRAIN_CLIP_DURATION = 0.256
 DEFAULT_MAX_EMPTY_CLIP = 0.1
@@ -18,50 +16,127 @@ class ClipingConfig(BaseConfig):
     duration: float = DEFAULT_TRAIN_CLIP_DURATION
     random: bool = True
     max_empty: float = DEFAULT_MAX_EMPTY_CLIP
+    min_sound_event_overlap: float = 0
 
 
-class Clipper(torch.nn.Module):
+class Clipper:
     def __init__(
         self,
-        preprocessor: PreprocessorProtocol,
         duration: float = 0.5,
         max_empty: float = 0.2,
         random: bool = True,
+        min_sound_event_overlap: float = 0,
     ):
         super().__init__()
-        self.preprocessor = preprocessor
         self.duration = duration
         self.random = random
         self.max_empty = max_empty
+        self.min_sound_event_overlap = min_sound_event_overlap
 
-    def forward(
+    def __call__(
         self,
-        example: PreprocessedExample,
-    ) -> Tuple[PreprocessedExample, float, float]:
-        start_time = 0
-        duration = example.audio.shape[-1] / self.preprocessor.input_samplerate
-
-        if self.random:
-            start_time = np.random.uniform(
-                -self.max_empty,
-                duration - self.duration + self.max_empty,
-            )
-
-        return (
-            select_subclip(
-                example,
-                start=start_time,
-                duration=self.duration,
-                input_samplerate=self.preprocessor.input_samplerate,
-                output_samplerate=self.preprocessor.output_samplerate,
-            ),
-            start_time,
-            start_time + self.duration,
-        )
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation:
+        return get_subclip_annotation(
+            clip_annotation,
+            random=self.random,
+            duration=self.duration,
+            max_empty=self.max_empty,
+            min_sound_event_overlap=self.min_sound_event_overlap,
+        )
+
+
+def get_subclip_annotation(
+    clip_annotation: data.ClipAnnotation,
+    random: bool = True,
+    duration: float = 0.5,
+    max_empty: float = 0.2,
+    min_sound_event_overlap: float = 0,
+) -> data.ClipAnnotation:
+    clip = clip_annotation.clip
+
+    subclip = select_subclip(
+        clip,
+        random=random,
+        duration=duration,
+        max_empty=max_empty,
+    )
+
+    sound_events = select_sound_event_annotations(
+        clip_annotation,
+        subclip,
+        min_overlap=min_sound_event_overlap,
+    )
+
+    return clip_annotation.model_copy(
+        update=dict(
+            clip=subclip,
+            sound_events=sound_events,
+        )
+    )
+
+
+def select_subclip(
+    clip: data.Clip,
+    random: bool = True,
+    duration: float = 0.5,
+    max_empty: float = 0.2,
+) -> data.Clip:
+    start_time = clip.start_time
+    end_time = clip.end_time
+
+    if duration > clip.duration + max_empty or not random:
+        return clip.model_copy(
+            update=dict(
+                start_time=start_time,
+                end_time=start_time + duration,
+            )
+        )
+
+    random_start_time = np.random.uniform(
+        low=start_time,
+        high=end_time + max_empty - duration,
+    )
+
+    return clip.model_copy(
+        update=dict(
+            start_time=random_start_time,
+            end_time=random_start_time + duration,
+        )
+    )
+
+
+def select_sound_event_annotations(
+    clip_annotation: data.ClipAnnotation,
+    subclip: data.Clip,
+    min_overlap: float = 0,
+) -> List[data.SoundEventAnnotation]:
+    selected = []
+
+    start_time = subclip.start_time
+    end_time = subclip.end_time
+
+    for sound_event_annotation in clip_annotation.sound_events:
+        geometry = sound_event_annotation.sound_event.geometry
+
+        if geometry is None:
+            continue
+
+        geom_start_time, _, geom_end_time, _ = compute_bounds(geometry)
+
+        if not intervals_overlap(
+            (start_time, end_time),
+            (geom_start_time, geom_end_time),
+            min_absolute_overlap=min_overlap,
+        ):
+            continue
+
+        selected.append(sound_event_annotation)
+
+    return selected
 
 
 def build_clipper(
-    preprocessor: PreprocessorProtocol,
     config: Optional[ClipingConfig] = None,
     random: Optional[bool] = None,
 ) -> ClipperProtocol:
@@ -71,73 +146,7 @@ def build_clipper(
         lambda: config.to_yaml_string(),
     )
     return Clipper(
-        preprocessor=preprocessor,
         duration=config.duration,
         max_empty=config.max_empty,
         random=config.random if random else False,
     )
-
-
-def select_subclip(
-    example: PreprocessedExample,
-    start: float,
-    duration: float,
-    input_samplerate: float,
-    output_samplerate: float,
-    fill_value: float = 0,
-) -> PreprocessedExample:
-    audio_width = int(np.floor(duration * input_samplerate))
-    audio_start = int(np.floor(start * input_samplerate))
-
-    audio = adjust_width(
-        slice_tensor(
-            example.audio,
-            start=audio_start,
-            end=audio_start + audio_width,
-            dim=-1,
-        ),
-        audio_width,
-        value=fill_value,
-    )
-
-    spec_start = int(np.floor(start * output_samplerate))
-    spec_width = int(np.floor(duration * output_samplerate))
-    return PreprocessedExample(
-        audio=audio,
-        spectrogram=adjust_width(
-            slice_tensor(
-                example.spectrogram,
-                start=spec_start,
-                end=spec_start + spec_width,
-                dim=-1,
-            ),
-            spec_width,
-        ),
-        class_heatmap=adjust_width(
-            slice_tensor(
-                example.class_heatmap,
-                start=spec_start,
-                end=spec_start + spec_width,
-                dim=-1,
-            ),
-            spec_width,
-        ),
-        detection_heatmap=adjust_width(
-            slice_tensor(
-                example.detection_heatmap,
-                start=spec_start,
-                end=spec_start + spec_width,
-                dim=-1,
-            ),
-            spec_width,
-        ),
-        size_heatmap=adjust_width(
-            slice_tensor(
-                example.size_heatmap,
-                start=spec_start,
-                end=spec_start + spec_width,
-                dim=-1,
-            ),
-            spec_width,
-        ),
-    )
|||||||
@ -6,11 +6,13 @@ from soundevent import data
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.train.clips import ClipingConfig
|
from batdetect2.train.clips import ClipingConfig
|
||||||
|
from batdetect2.train.labels import LabelConfig
|
||||||
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
|
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
|
||||||
from batdetect2.train.losses import LossConfig
|
from batdetect2.train.losses import LossConfig
|
||||||
|
|
||||||
@ -50,7 +52,7 @@ class DataLoaderConfig(BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
|
DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
|
||||||
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False)
|
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)
|
||||||
|
|
||||||
|
|
||||||
class LoadersConfig(BaseConfig):
|
class LoadersConfig(BaseConfig):
|
||||||
@ -73,6 +75,8 @@ class TrainingConfig(BaseConfig):
|
|||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_train_config(
|
def load_train_config(
|
||||||
|
|||||||
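Note: `TrainingConfig` now nests the target and label configuration directly, and the default validation loader drops to a batch size of 1. A hedged sketch of overriding the new fields programmatically; it assumes `TargetConfig` and `LabelConfig` can be constructed from their defaults.

from batdetect2.targets import TargetConfig
from batdetect2.train.config import TrainingConfig
from batdetect2.train.labels import LabelConfig

config = TrainingConfig(
    targets=TargetConfig(),  # target definitions (defaults shown)
    labels=LabelConfig(),    # how annotations are rendered into heatmaps
)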
@ -1,78 +1,77 @@
-from typing import Optional, Sequence, Tuple
+from typing import Optional, Sequence
 
-import numpy as np
 import torch
 from soundevent import data
 from torch.utils.data import Dataset
 
-from batdetect2.train.augmentations import Augmentation
-from batdetect2.train.preprocess import (
-    list_preprocessed_files,
-    load_preprocessed_example,
-)
 from batdetect2.typing import ClipperProtocol, TrainExample
-from batdetect2.typing.train import PreprocessedExample
+from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
+from batdetect2.typing.train import Augmentation, ClipLabeller
 
 __all__ = [
-    "LabeledDataset",
+    "TrainingDataset",
 ]
 
 
-class LabeledDataset(Dataset):
+class TrainingDataset(Dataset):
     def __init__(
         self,
-        filenames: Sequence[data.PathLike],
-        clipper: ClipperProtocol,
-        augmentation: Optional[Augmentation] = None,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        labeller: ClipLabeller,
+        clipper: Optional[ClipperProtocol] = None,
+        audio_augmentation: Optional[Augmentation] = None,
+        spectrogram_augmentation: Optional[Augmentation] = None,
+        audio_dir: Optional[data.PathLike] = None,
     ):
-        self.filenames = filenames
+        self.clip_annotations = clip_annotations
         self.clipper = clipper
-        self.augmentation = augmentation
+        self.labeller = labeller
+        self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
+        self.audio_augmentation = audio_augmentation
+        self.spectrogram_augmentation = spectrogram_augmentation
+        self.audio_dir = audio_dir
 
     def __len__(self):
-        return len(self.filenames)
+        return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
-        example = self.get_example(idx)
+        clip_annotation = self.clip_annotations[idx]
 
-        example, start_time, end_time = self.clipper(example)
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
 
-        if self.augmentation:
-            example = self.augmentation(example)
+        clip = clip_annotation.clip
 
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+
+        # Add channel dim
+        wav_tensor = torch.tensor(wav).unsqueeze(0)
+
+        if self.audio_augmentation is not None:
+            wav_tensor, clip_annotation = self.audio_augmentation(
+                wav_tensor,
+                clip_annotation,
+            )
+
+        spectrogram = self.preprocessor(wav_tensor)
+
+        if self.spectrogram_augmentation is not None:
+            spectrogram, clip_annotation = self.spectrogram_augmentation(
+                spectrogram,
+                clip_annotation,
+            )
+
+        heatmaps = self.labeller(clip_annotation, spectrogram)
+
         return TrainExample(
-            spec=example.spectrogram,
-            detection_heatmap=example.detection_heatmap,
-            class_heatmap=example.class_heatmap,
-            size_heatmap=example.size_heatmap,
+            spec=spectrogram,
+            detection_heatmap=heatmaps.detection,
+            class_heatmap=heatmaps.classes,
+            size_heatmap=heatmaps.size,
             idx=torch.tensor(idx),
-            start_time=torch.tensor(start_time),
-            end_time=torch.tensor(end_time),
+            start_time=torch.tensor(clip.start_time),
+            end_time=torch.tensor(clip.end_time),
         )
 
-    @classmethod
-    def from_directory(
-        cls,
-        directory: data.PathLike,
-        clipper: ClipperProtocol,
-        extension: str = ".npz",
-        augmentation: Optional[Augmentation] = None,
-    ):
-        return cls(
-            filenames=list_preprocessed_files(directory, extension),
-            clipper=clipper,
-            augmentation=augmentation,
-        )
-
-    def get_random_example(self) -> Tuple[PreprocessedExample, float, float]:
-        idx = np.random.randint(0, len(self))
-        dataset = self.get_example(idx)
-        dataset, start_time, end_time = self.clipper(dataset)
-        return dataset, start_time, end_time
-
-    def get_example(self, idx) -> PreprocessedExample:
-        return load_preprocessed_example(self.filenames[idx])
-
-    def get_clip_annotation(self, idx) -> data.ClipAnnotation:
-        item = np.load(self.filenames[idx], allow_pickle=True, mmap_mode="r+")
-        return item["clip_annotation"].tolist()
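Note: the rewritten dataset no longer reads pre-computed `.npz` examples; each `__getitem__` call optionally clips the annotation, loads the audio, applies any waveform augmentation, computes the spectrogram, applies any spectrogram augmentation, and finally generates the heatmaps. A hedged usage sketch, using only the constructor arguments visible in this diff (how the collaborators are built is shown elsewhere in this commit):

from batdetect2.train.dataset import TrainingDataset

dataset = TrainingDataset(
    clip_annotations,             # Sequence[data.ClipAnnotation]
    audio_loader=audio_loader,    # loads the waveform for a clip
    preprocessor=preprocessor,    # waveform -> spectrogram
    labeller=labeller,            # (annotation, spectrogram) -> heatmaps
    clipper=clipper,              # optional: crops each annotation to a window
)

example = dataset[0]  # TrainExample with spec, heatmaps, idx, start/end times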
@ -15,16 +15,18 @@ from batdetect2.evaluate.metrics import (
     DetectionAveragePrecision,
 )
 from batdetect2.models import Model, build_model
+from batdetect2.plotting.clips import AudioLoader, build_audio_loader
 from batdetect2.train.augmentations import (
-    RandomExampleSource,
+    RandomAudioSource,
     build_augmentations,
 )
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import (
-    LabeledDataset,
+    TrainingDataset,
 )
+from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
 from batdetect2.train.losses import build_loss
@ -33,6 +35,7 @@ from batdetect2.typing import (
     TargetProtocol,
     TrainExample,
 )
+from batdetect2.typing.train import ClipLabeller
 from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
@ -46,8 +49,8 @@ __all__ = [
 
 
 def train(
-    train_examples: Sequence[data.PathLike],
-    val_examples: Optional[Sequence[data.PathLike]] = None,
+    train_annotations: Sequence[data.ClipAnnotation],
+    val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
     config: Optional[FullTrainingConfig] = None,
     model_path: Optional[data.PathLike] = None,
     train_workers: Optional[int] = None,
@ -59,8 +62,19 @@ def train(
 
     trainer = build_trainer(config, targets=model.targets)
 
+    audio_loader = build_audio_loader(config=config.preprocess.audio)
+
+    labeller = build_clip_labeler(
+        model.targets,
+        min_freq=model.preprocessor.min_freq,
+        max_freq=model.preprocessor.max_freq,
+        config=config.train.labels,
+    )
+
     train_dataloader = build_train_loader(
-        train_examples,
+        train_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
         preprocessor=model.preprocessor,
         config=config.train,
         num_workers=train_workers,
@ -68,12 +82,14 @@ def train(
 
     val_dataloader = (
         build_val_loader(
-            val_examples,
+            val_annotations,
+            audio_loader=audio_loader,
+            labeller=labeller,
             preprocessor=model.preprocessor,
             config=config.train,
             num_workers=val_workers,
         )
-        if val_examples is not None
+        if val_annotations is not None
         else None
     )
 
@ -153,19 +169,23 @@ def build_trainer(
 
 
 def build_train_loader(
-    train_examples: Sequence[data.PathLike],
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: AudioLoader,
+    labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
    num_workers: Optional[int] = None,
 ) -> DataLoader:
     config = config or TrainingConfig()
 
-    logger.info("Building training data loader...")
     train_dataset = build_train_dataset(
-        train_examples,
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
         preprocessor=preprocessor,
         config=config,
     )
 
+    logger.info("Building training data loader...")
     loader_conf = config.dataloaders.train
     logger.opt(lazy=True).debug(
         "Training data loader config: \n{config}",
@ -182,16 +202,20 @@ def build_train_loader(
 
 
 def build_val_loader(
-    val_examples: Sequence[data.PathLike],
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: AudioLoader,
+    labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ):
+    logger.info("Building validation data loader...")
     config = config or TrainingConfig()
 
-    logger.info("Building validation data loader...")
     val_dataset = build_val_dataset(
-        val_examples,
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
         preprocessor=preprocessor,
         config=config,
     )
@ -203,7 +227,7 @@ def build_val_loader(
     num_workers = num_workers or loader_conf.num_workers
     return DataLoader(
         val_dataset,
-        batch_size=loader_conf.batch_size,
+        batch_size=1,
         shuffle=loader_conf.shuffle,
         num_workers=num_workers,
         collate_fn=_collate_fn,
@ -232,52 +256,60 @@ def _collate_fn(batch: List[TrainExample]) -> TrainExample:
 
 
 def build_train_dataset(
-    examples: Sequence[data.PathLike],
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: AudioLoader,
+    labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
-) -> LabeledDataset:
+) -> TrainingDataset:
     logger.info("Building training dataset...")
     config = config or TrainingConfig()
 
     clipper = build_clipper(
-        preprocessor=preprocessor,
         config=config.cliping,
         random=True,
     )
 
-    random_example_source = RandomExampleSource(
-        list(examples),
-        clipper=clipper,
+    random_example_source = RandomAudioSource(
+        clip_annotations,
+        audio_loader=audio_loader,
     )
 
-    if config.augmentations.enabled and config.augmentations.steps:
-        augmentations = build_augmentations(
-            preprocessor,
+    if config.augmentations.enabled:
+        audio_augmentation, spectrogram_augmentation = build_augmentations(
+            samplerate=preprocessor.input_samplerate,
             config=config.augmentations,
-            example_source=random_example_source,
+            audio_source=random_example_source,
         )
     else:
         logger.debug("No augmentations configured for training dataset.")
-        augmentations = None
+        audio_augmentation = None
+        spectrogram_augmentation = None
 
-    return LabeledDataset(
-        examples,
+    return TrainingDataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
         clipper=clipper,
-        augmentation=augmentations,
+        preprocessor=preprocessor,
+        audio_augmentation=audio_augmentation,
+        spectrogram_augmentation=spectrogram_augmentation,
     )
 
 
 def build_val_dataset(
-    examples: Sequence[data.PathLike],
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: AudioLoader,
+    labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
-    train: bool = True,
-) -> LabeledDataset:
+) -> TrainingDataset:
     logger.info("Building validation dataset...")
     config = config or TrainingConfig()
-    clipper = build_clipper(
+    return TrainingDataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
         preprocessor=preprocessor,
-        config=config.cliping,
-        random=train,
     )
-    return LabeledDataset(examples, clipper=clipper)
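Note: `train()` now consumes annotation sequences instead of directories of preprocessed files, builds a single audio loader and clip labeller, and shares them between the training and validation loaders; the validation loader is pinned to `batch_size=1`, presumably because un-clipped validation spectrograms can vary in width. A hedged sketch of calling the new entry point (the import path and default handling are assumptions):

from batdetect2.train import train

train(
    train_annotations,                # Sequence[data.ClipAnnotation]
    val_annotations=val_annotations,  # optional held-out annotations
    config=None,                      # falls back to the default full training config
)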
@ -49,7 +49,11 @@ spectrogram, applies all configured filtering, transformation, and encoding
 steps, and returns the final `Heatmaps` used for model training.
 """
 
-Augmentation = Callable[[PreprocessedExample], PreprocessedExample]
+Augmentation = Callable[
+    [torch.Tensor, data.ClipAnnotation],
+    Tuple[torch.Tensor, data.ClipAnnotation],
+]
 
 
 class TrainExample(NamedTuple):
@ -97,5 +101,6 @@ class LossProtocol(Protocol):
 
 class ClipperProtocol(Protocol):
     def __call__(
-        self, example: PreprocessedExample
-    ) -> Tuple[PreprocessedExample, float, float]: ...
+        self,
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation: ...
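Note: augmentations now operate on a raw tensor plus its `ClipAnnotation` and must return both, and clippers map an annotation to a (shorter) annotation. Below is a minimal sketch of a callable matching the new `Augmentation` alias; the gain jitter itself is purely illustrative and is not part of this commit.

from typing import Tuple

import torch
from soundevent import data


def random_gain(
    wav: torch.Tensor,
    clip_annotation: data.ClipAnnotation,
) -> Tuple[torch.Tensor, data.ClipAnnotation]:
    # Scale the waveform by a random factor in [0.5, 1.5); the annotation is
    # returned unchanged since timing and labels are unaffected.
    gain = 0.5 + torch.rand(1).item()
    return wav * gain, clip_annotation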
@ -6,7 +6,7 @@ from soundevent import data
 
 from batdetect2.train.augmentations import (
     add_echo,
-    mix_examples,
+    mix_audio,
 )
 from batdetect2.train.clips import select_subclip
 from batdetect2.train.preprocess import generate_train_example
@ -41,7 +41,7 @@ def test_mix_examples(
         labeller=sample_labeller,
     )
 
-    mixed = mix_examples(
+    mixed = mix_audio(
         example1,
         example2,
         weight=0.3,
@ -86,7 +86,7 @@ def test_mix_examples_of_different_durations(
         labeller=sample_labeller,
     )
 
-    mixed = mix_examples(
+    mixed = mix_audio(
         example1,
         example2,
         weight=0.3,