Improved train module

mbsantiago 2025-04-03 16:49:58 +01:00
parent 7689580a24
commit e383a33cbf
8 changed files with 222 additions and 125 deletions

View File

@@ -0,0 +1,48 @@
from batdetect2.train.augmentations import (
AugmentationsConfig,
add_echo,
augment_example,
load_augmentation_config,
mask_frequency,
mask_time,
mix_examples,
scale_volume,
select_subclip,
warp_spectrogram,
)
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import (
LabeledDataset,
SubclipConfig,
TrainExample,
)
from batdetect2.train.labels import LabelConfig, load_label_config
from batdetect2.train.preprocess import preprocess_annotations
from batdetect2.train.targets import TargetConfig, load_target_config
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
__all__ = [
"AugmentationsConfig",
"LabelConfig",
"LabeledDataset",
"SubclipConfig",
"TargetConfig",
"TrainExample",
"TrainerConfig",
"TrainingConfig",
"add_echo",
"augment_example",
"load_agumentation_config",
"load_label_config",
"load_target_config",
"load_train_config",
"load_trainer_config",
"mask_frequency",
"mask_time",
"mix_examples",
"preprocess_annotations",
"scale_volume",
"select_subclip",
"train",
"warp_spectrogram",
]
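
Taken together, the re-exports give batdetect2.train a single entry point for the preprocess-then-train flow. A minimal sketch of that flow, assuming the annotations and Lightning module come from elsewhere (the output directory name is hypothetical):

from typing import Sequence

from lightning import LightningModule
from soundevent import data

from batdetect2.train import LabeledDataset, preprocess_annotations, train


def run_training(
    clip_annotations: Sequence[data.ClipAnnotation],
    module: LightningModule,
    output_dir: str = "train_examples",  # hypothetical location
) -> None:
    # Write one .nc training example per annotated clip, using default
    # preprocessing, target, and label configs.
    preprocess_annotations(clip_annotations, output_dir)

    # Load the saved examples back as a torch Dataset.
    dataset = LabeledDataset.from_directory(output_dir)

    # Fit with Lightning; train() reads module.config.train.batch_size
    # to build the DataLoader.
    train(module, dataset)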

View File

@@ -3,16 +3,30 @@ from typing import Callable, Optional, Union
import numpy as np
import xarray as xr
from pydantic import Field
from soundevent import arrays
from soundevent import arrays, data
from batdetect2.configs import BaseConfig
from batdetect2.configs import BaseConfig, load_config
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
from batdetect2.preprocess.arrays import adjust_width
Augmentation = Callable[[xr.Dataset], xr.Dataset]
class AugmentationConfig(BaseConfig):
__all__ = [
"AugmentationsConfig",
"load_agumentation_config",
"select_subclip",
"mix_examples",
"add_echo",
"scale_volume",
"warp_spectrogram",
"mask_time",
"mask_frequency",
"augment_example",
]
class BaseAugmentationConfig(BaseConfig):
enable: bool = True
probability: float = 0.2
@@ -63,7 +77,7 @@ def select_subclip(
)
class MixAugmentationConfig(AugmentationConfig):
class MixAugmentationConfig(BaseAugmentationConfig):
min_weight: float = 0.3
max_weight: float = 0.7
@@ -133,7 +147,7 @@ def mix_examples(
)
class EchoAugmentationConfig(AugmentationConfig):
class EchoAugmentationConfig(BaseAugmentationConfig):
max_delay: float = 0.005
min_weight: float = 0.0
max_weight: float = 1.0
@@ -188,7 +202,7 @@ def add_echo(
)
class VolumeAugmentationConfig(AugmentationConfig):
class VolumeAugmentationConfig(BaseAugmentationConfig):
min_scaling: float = 0.0
max_scaling: float = 2.0
@@ -206,7 +220,7 @@ def scale_volume(
return example.assign(spectrogram=example["spectrogram"] * factor)
class WarpAugmentationConfig(AugmentationConfig):
class WarpAugmentationConfig(BaseAugmentationConfig):
delta: float = 0.04
@@ -294,7 +308,7 @@ def mask_axis(
return array.where(condition, other=mask_value)
class TimeMaskAugmentationConfig(AugmentationConfig):
class TimeMaskAugmentationConfig(BaseAugmentationConfig):
max_perc: float = 0.05
max_masks: int = 3
@@ -318,7 +332,7 @@ def mask_time(
return example.assign(spectrogram=spectrogram)
class FrequencyMaskAugmentationConfig(AugmentationConfig):
class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
max_perc: float = 0.10
max_masks: int = 3
@@ -361,7 +375,13 @@ class AugmentationsConfig(BaseConfig):
)
def should_apply(config: AugmentationConfig) -> bool:
def load_augmentation_config(
path: data.PathLike, field: Optional[str] = None
) -> AugmentationsConfig:
return load_config(path, schema=AugmentationsConfig, field=field)
def should_apply(config: BaseAugmentationConfig) -> bool:
if not config.enable:
return False
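
The new load_augmentation_config helper mirrors the other per-module loaders, pulling an AugmentationsConfig out of a larger config file. A sketch, with a hypothetical file name and field:

from batdetect2.train.augmentations import (
    AugmentationsConfig,
    load_augmentation_config,
)

# Read only the "augmentations" section of a shared YAML config
# (both the path and the field name are assumed for illustration).
config: AugmentationsConfig = load_augmentation_config(
    "config.yaml",
    field="augmentations",
)

# Each augmentation's config carries its own `enable` flag and
# `probability`, which should_apply checks at runtime.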

View File

@@ -43,28 +43,23 @@ class SubclipConfig(BaseConfig):
class DatasetConfig(BaseConfig):
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
augmentation: AugmentationsConfig = Field(
default_factory=AugmentationsConfig
)
class LabeledDataset(Dataset):
config: DatasetConfig
def __init__(
self,
filenames: Sequence[PathLike],
augment: bool = False,
subclip: bool = False,
config: Optional[DatasetConfig] = None,
subclip: Optional[SubclipConfig] = None,
augmentation: Optional[AugmentationsConfig] = None,
preprocessing: Optional[PreprocessingConfig] = None,
):
self.filenames = filenames
self.augment = augment
self.subclip = subclip
self.config = config or DatasetConfig()
self.augmentation = augmentation
self.preprocessing = preprocessing or PreprocessingConfig()
def __len__(self):
return len(self.filenames)
@@ -75,16 +70,16 @@ class LabeledDataset(Dataset):
if self.subclip:
dataset = select_subclip(
dataset,
duration=self.config.subclip.duration,
width=self.config.subclip.width,
random=self.config.subclip.random,
duration=self.subclip.duration,
width=self.subclip.width,
random=self.subclip.random,
)
if self.augment:
if self.augmentation:
dataset = augment_example(
dataset,
self.config.augmentation,
preprocessing_config=self.config.preprocessing,
self.augmentation,
preprocessing_config=self.preprocessing,
others=self.get_random_example,
)
@@ -101,15 +96,15 @@ class LabeledDataset(Dataset):
cls,
directory: PathLike,
extension: str = ".nc",
config: Optional[DatasetConfig] = None,
augment: bool = False,
subclip: bool = False,
subclip: Optional[SubclipConfig] = None,
augmentation: Optional[AugmentationsConfig] = None,
preprocessing: Optional[PreprocessingConfig] = None,
):
return cls(
get_files(directory, extension),
config=config,
augment=augment,
subclip=subclip,
augmentation=augmentation,
preprocessing=preprocessing,
)
def get_random_example(self) -> xr.Dataset:
@@ -119,9 +114,9 @@ class LabeledDataset(Dataset):
if self.subclip:
dataset = select_subclip(
dataset,
duration=self.config.subclip.duration,
width=self.config.subclip.width,
random=self.config.subclip.random,
duration=self.subclip.duration,
width=self.subclip.width,
random=self.subclip.random,
)
return dataset
@@ -144,7 +139,7 @@ class LabeledDataset(Dataset):
if not self.subclip:
return tensor
width = self.config.subclip.width
width = self.subclip.width
return adjust_width(tensor, width)
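
With the DatasetConfig wrapper gone, each processing step is enabled simply by passing its config; leaving one as None switches that step off. A sketch (the directory is hypothetical):

from batdetect2.train.augmentations import AugmentationsConfig
from batdetect2.train.dataset import LabeledDataset, SubclipConfig

# Subclipping and augmentation both default to off; providing a config
# turns the corresponding step on.
dataset = LabeledDataset.from_directory(
    "train_examples",                    # hypothetical directory of .nc files
    subclip=SubclipConfig(),             # crop a fixed-width window per example
    augmentation=AugmentationsConfig(),  # apply augmentations with defaults
)

# Without either config the examples are returned as saved:
plain = LabeledDataset.from_directory("train_examples")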

View File

@@ -3,13 +3,19 @@ from typing import Callable, List, Optional, Sequence, Tuple
import numpy as np
import xarray as xr
from pydantic import Field
from scipy.ndimage import gaussian_filter
from soundevent import arrays, data, geometry
from soundevent.geometry.operations import Positions
from batdetect2.configs import BaseConfig
from batdetect2.configs import BaseConfig, load_config
__all__ = ["generate_heatmaps"]
__all__ = [
"HeatmapsConfig",
"LabelConfig",
"generate_heatmaps",
"load_label_config",
]
class HeatmapsConfig(BaseConfig):
@@ -19,6 +25,10 @@ class HeatmapsConfig(BaseConfig):
frequency_scale: float = 1 / 859.375
class LabelConfig(BaseConfig):
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
def generate_heatmaps(
sound_events: Sequence[data.SoundEventAnnotation],
spec: xr.DataArray,
@@ -132,3 +142,9 @@
).fillna(0.0)
return detection_heatmap, class_heatmap, size_heatmap
def load_label_config(
path: data.PathLike, field: Optional[str] = None
) -> LabelConfig:
return load_config(path, schema=LabelConfig, field=field)
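
LabelConfig currently just nests HeatmapsConfig, and load_label_config follows the shared loader pattern. A sketch (file and field names are assumed):

from batdetect2.train.labels import (
    HeatmapsConfig,
    LabelConfig,
    load_label_config,
)

# From a YAML file (hypothetical path and field):
label_config = load_label_config("config.yaml", field="labels")

# Or built directly; heatmap geometry keeps its defaults unless overridden.
label_config = LabelConfig(heatmaps=HeatmapsConfig())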

View File

@@ -6,9 +6,15 @@ from pydantic import Field
from batdetect2.configs import BaseConfig
from batdetect2.models.typing import ModelOutput
from batdetect2.plot import detection
from batdetect2.train.dataset import TrainExample
__all__ = [
"bbox_size_loss",
"compute_loss",
"focal_loss",
"mse_loss",
]
class SizeLossConfig(BaseConfig):
weight: float = 0.1
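
The exported loss helpers are configured the same way as the rest of the module; for the size loss only a weight is exposed. A minimal sketch:

from batdetect2.train.losses import SizeLossConfig

# Relative weight of the bounding-box size term in the combined loss
# (0.1 is the default).
size_config = SizeLossConfig(weight=0.1)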

View File

@@ -17,7 +17,7 @@ from batdetect2.preprocess import (
compute_spectrogram,
load_clip_audio,
)
from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
from batdetect2.train.labels import LabelConfig, generate_heatmaps
from batdetect2.train.targets import (
TargetConfig,
build_encoder,
@@ -30,6 +30,9 @@ FilenameFn = Callable[[data.ClipAnnotation], str]
__all__ = [
"preprocess_annotations",
"preprocess_single_annotation",
"generate_train_example",
"TrainPreprocessingConfig",
]
@@ -38,15 +41,21 @@ class TrainPreprocessingConfig(BaseConfig):
default_factory=PreprocessingConfig
)
target: TargetConfig = Field(default_factory=TargetConfig)
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
def generate_train_example(
clip_annotation: data.ClipAnnotation,
config: Optional[TrainPreprocessingConfig] = None,
preprocessing_config: Optional[PreprocessingConfig] = None,
target_config: Optional[TargetConfig] = None,
label_config: Optional[LabelConfig] = None,
) -> xr.Dataset:
"""Generate a training example."""
config = config or TrainPreprocessingConfig()
config = TrainPreprocessingConfig(
preprocessing=preprocessing_config or PreprocessingConfig(),
target=target_config or TargetConfig(),
labels=label_config or LabelConfig(),
)
wave = load_clip_audio(
clip_annotation.clip,
@@ -78,10 +87,10 @@ def generate_train_example(
spectrogram,
class_names,
encoder,
target_sigma=config.heatmaps.sigma,
position=config.heatmaps.position,
time_scale=config.heatmaps.time_scale,
frequency_scale=config.heatmaps.frequency_scale,
target_sigma=config.labels.heatmaps.sigma,
position=config.labels.heatmaps.position,
time_scale=config.labels.heatmaps.time_scale,
frequency_scale=config.labels.heatmaps.frequency_scale,
)
dataset = xr.Dataset(
@@ -133,14 +142,14 @@
output_dir: PathLike,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
config: Optional[TrainPreprocessingConfig] = None,
preprocessing_config: Optional[PreprocessingConfig] = None,
target_config: Optional[TargetConfig] = None,
label_config: Optional[LabelConfig] = None,
max_workers: Optional[int] = None,
) -> None:
"""Preprocess annotations and save to disk."""
output_dir = Path(output_dir)
config = config or TrainPreprocessingConfig()
if not output_dir.is_dir():
output_dir.mkdir(parents=True)
@@ -151,9 +160,11 @@
partial(
preprocess_single_annotation,
output_dir=output_dir,
config=config,
filename_fn=filename_fn,
replace=replace,
preprocessing_config=preprocessing_config,
target_config=target_config,
label_config=label_config,
),
clip_annotations,
),
@@ -165,7 +176,9 @@
def preprocess_single_annotation(
clip_annotation: data.ClipAnnotation,
output_dir: PathLike,
config: TrainPreprocessingConfig,
preprocessing_config: Optional[PreprocessingConfig] = None,
target_config: Optional[TargetConfig] = None,
label_config: Optional[LabelConfig] = None,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
) -> None:
@@ -181,7 +194,12 @@
path.unlink()
try:
sample = generate_train_example(clip_annotation, config=config)
sample = generate_train_example(
clip_annotation,
preprocessing_config=preprocessing_config,
target_config=target_config,
label_config=label_config,
)
except Exception as error:
raise RuntimeError(
f"Failed to process annotation: {clip_annotation.uuid}"

View File

@@ -9,6 +9,14 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.terms import TagInfo, get_tag_from_info
__all__ = [
"TargetConfig",
"load_target_config",
"build_encoder",
"build_decoder",
"filter_sound_event",
]
class ReplaceConfig(BaseConfig):
"""Configuration for replacing tags."""

View File

@@ -1,82 +1,68 @@
from typing import Callable, NamedTuple, Optional
from typing import Optional, Union
import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from lightning import LightningModule
from lightning.pytorch import Trainer
from soundevent.data import PathLike
from torch.utils.data import DataLoader
from batdetect2.data.datasets import ClipAnnotationDataset
from batdetect2.models.typing import DetectionModel
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.dataset import LabeledDataset
__all__ = [
"train",
"TrainerConfig",
"load_trainer_config",
]
class TrainInputs(NamedTuple):
spec: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
class TrainerConfig(BaseConfig):
accelerator: str = "auto"
accumulate_grad_batches: int = 1
deterministic: bool = True
check_val_every_n_epoch: int = 1
devices: Union[str, int] = "auto"
enable_checkpointing: bool = True
gradient_clip_val: Optional[float] = None
limit_train_batches: Optional[Union[int, float]] = None
limit_test_batches: Optional[Union[int, float]] = None
limit_val_batches: Optional[Union[int, float]] = None
log_every_n_steps: Optional[int] = None
max_epochs: Optional[int] = None
min_epochs: Optional[int] = 100
max_steps: Optional[int] = None
min_steps: Optional[int] = None
max_time: Optional[str] = None
precision: Optional[str] = None
reload_dataloaders_every_n_epochs: Optional[int] = None
val_check_interval: Optional[Union[int, float]] = None
def train_loop(
model: DetectionModel,
train_dataset: ClipAnnotationDataset[TrainInputs],
validation_dataset: ClipAnnotationDataset[TrainInputs],
device: Optional[torch.device] = None,
num_epochs: int = 100,
learning_rate: float = 1e-4,
def load_trainer_config(path: PathLike, field: Optional[str] = None):
return load_config(path, schema=TrainerConfig, field=field)
def train(
module: LightningModule,
train_dataset: LabeledDataset,
trainer_config: Optional[TrainerConfig] = None,
dev_run: bool = False,
overfit_batches: bool = False,
profiler: Optional[str] = None,
):
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32)
model.to(device)
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(
optimizer,
num_epochs * len(train_loader),
trainer_config = trainer_config or TrainerConfig()
trainer = Trainer(
**trainer_config.model_dump(
exclude_unset=True,
exclude_none=True,
),
fast_dev_run=dev_run,
overfit_batches=overfit_batches,
profiler=profiler,
)
for epoch in range(num_epochs):
train_loss = train_single_epoch(
model,
train_loader,
optimizer,
device,
scheduler,
)
def train_single_epoch(
model: DetectionModel,
train_loader: DataLoader,
optimizer: Adam,
device: torch.device,
scheduler: CosineAnnealingLR,
):
model.train()
train_loss = tu.AverageMeter()
for batch in train_loader:
optimizer.zero_grad()
spec = batch.spec.to(device)
detection_heatmap = batch.detection_heatmap.to(device)
class_heatmap = batch.class_heatmap.to(device)
size_heatmap = batch.size_heatmap.to(device)
outputs = model(spec)
loss = loss_fun(
outputs,
gt_det,
gt_size,
gt_class,
det_criterion,
params,
class_inv_freq,
)
train_loss.update(loss.item(), data.shape[0])
loss.backward()
optimizer.step()
scheduler.step()
train_loader = DataLoader(
train_dataset,
batch_size=module.config.train.batch_size,
shuffle=True,
num_workers=7,
)
trainer.fit(module, train_dataloaders=train_loader)
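
Because the config is dumped with exclude_unset and exclude_none, Lightning's own defaults apply to any field left untouched; only explicitly set values reach the Trainer. A sketch of the new entry point (module construction is elided, and the dataset path is hypothetical):

from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.train import TrainerConfig, train

trainer_config = TrainerConfig(
    max_epochs=200,
    precision="16-mixed",  # assumed precision string
    log_every_n_steps=10,
)


def fit(module) -> None:
    # `module` is a LightningModule; train() reads
    # module.config.train.batch_size to size the DataLoader.
    dataset = LabeledDataset.from_directory("train_examples")  # hypothetical
    # dev_run=True maps to Trainer(fast_dev_run=...) for a quick smoke test.
    train(module, dataset, trainer_config=trainer_config, dev_run=True)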