Improved train module

mbsantiago 2025-04-03 16:49:58 +01:00
parent 7689580a24
commit e383a33cbf
8 changed files with 222 additions and 125 deletions

batdetect2/train/__init__.py

@@ -0,0 +1,48 @@
+from batdetect2.train.augmentations import (
+    AugmentationsConfig,
+    add_echo,
+    augment_example,
+    load_agumentation_config,
+    mask_frequency,
+    mask_time,
+    mix_examples,
+    scale_volume,
+    select_subclip,
+    warp_spectrogram,
+)
+from batdetect2.train.config import TrainingConfig, load_train_config
+from batdetect2.train.dataset import (
+    LabeledDataset,
+    SubclipConfig,
+    TrainExample,
+)
+from batdetect2.train.labels import LabelConfig, load_label_config
+from batdetect2.train.preprocess import preprocess_annotations
+from batdetect2.train.targets import TargetConfig, load_target_config
+from batdetect2.train.train import TrainerConfig, load_trainer_config, train
+
+__all__ = [
+    "AugmentationsConfig",
+    "LabelConfig",
+    "LabeledDataset",
+    "SubclipConfig",
+    "TargetConfig",
+    "TrainExample",
+    "TrainerConfig",
+    "TrainingConfig",
+    "add_echo",
+    "augment_example",
+    "load_agumentation_config",
+    "load_label_config",
+    "load_target_config",
+    "load_train_config",
+    "load_trainer_config",
+    "mask_frequency",
+    "mask_time",
+    "mix_examples",
+    "preprocess_annotations",
+    "scale_volume",
+    "select_subclip",
+    "train",
+    "warp_spectrogram",
+]
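
The new package __init__ re-exports the whole training API from a single namespace. A minimal sketch of what consumers can now write (the selection of names is illustrative):

# Illustrative: all public training utilities come from one import path.
from batdetect2.train import (
    LabeledDataset,
    TrainerConfig,
    preprocess_annotations,
    train,
)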

batdetect2/train/augmentations.py

@@ -3,16 +3,30 @@ from typing import Callable, Optional, Union

 import numpy as np
 import xarray as xr
 from pydantic import Field
-from soundevent import arrays
+from soundevent import arrays, data

-from batdetect2.configs import BaseConfig
+from batdetect2.configs import BaseConfig, load_config
 from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
 from batdetect2.preprocess.arrays import adjust_width

 Augmentation = Callable[[xr.Dataset], xr.Dataset]

-class AugmentationConfig(BaseConfig):
+__all__ = [
+    "AugmentationsConfig",
+    "load_agumentation_config",
+    "select_subclip",
+    "mix_examples",
+    "add_echo",
+    "scale_volume",
+    "warp_spectrogram",
+    "mask_time",
+    "mask_frequency",
+    "augment_example",
+]
+
+
+class BaseAugmentationConfig(BaseConfig):
     enable: bool = True
     probability: float = 0.2
@@ -63,7 +77,7 @@ def select_subclip(
     )

-class MixAugmentationConfig(AugmentationConfig):
+class MixAugmentationConfig(BaseAugmentationConfig):
     min_weight: float = 0.3
     max_weight: float = 0.7
@@ -133,7 +147,7 @@ def mix_examples(
     )

-class EchoAugmentationConfig(AugmentationConfig):
+class EchoAugmentationConfig(BaseAugmentationConfig):
     max_delay: float = 0.005
     min_weight: float = 0.0
     max_weight: float = 1.0
@@ -188,7 +202,7 @@ def add_echo(
     )

-class VolumeAugmentationConfig(AugmentationConfig):
+class VolumeAugmentationConfig(BaseAugmentationConfig):
     min_scaling: float = 0.0
     max_scaling: float = 2.0
@@ -206,7 +220,7 @@ def scale_volume(
     return example.assign(spectrogram=example["spectrogram"] * factor)

-class WarpAugmentationConfig(AugmentationConfig):
+class WarpAugmentationConfig(BaseAugmentationConfig):
     delta: float = 0.04
@@ -294,7 +308,7 @@ def mask_axis(
     return array.where(condition, other=mask_value)

-class TimeMaskAugmentationConfig(AugmentationConfig):
+class TimeMaskAugmentationConfig(BaseAugmentationConfig):
     max_perc: float = 0.05
     max_masks: int = 3
@@ -318,7 +332,7 @@ def mask_time(
     return example.assign(spectrogram=spectrogram)

-class FrequencyMaskAugmentationConfig(AugmentationConfig):
+class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
     max_perc: float = 0.10
     max_masks: int = 3
@@ -361,7 +375,13 @@ class AugmentationsConfig(BaseConfig):
     )

-def should_apply(config: AugmentationConfig) -> bool:
+def load_agumentation_config(
+    path: data.PathLike, field: Optional[str] = None
+) -> AugmentationsConfig:
+    return load_config(path, schema=AugmentationsConfig, field=field)
+
+
+def should_apply(config: BaseAugmentationConfig) -> bool:
     if not config.enable:
         return False
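
The new load_agumentation_config helper (spelling kept as in the source) loads an AugmentationsConfig from a configuration file, mirroring the other load_*_config functions in this commit. A usage sketch; the file name and the "augmentations" field are assumptions, not part of this commit:

# Hypothetical: read augmentation settings from a shared config file.
from batdetect2.train.augmentations import load_agumentation_config

config = load_agumentation_config("train.yaml", field="augmentations")

Since each concrete augmentation config now inherits enable and probability from BaseAugmentationConfig, individual augmentations can be toggled or reweighted per config file.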

batdetect2/train/dataset.py

@@ -43,28 +43,23 @@ class SubclipConfig(BaseConfig):

 class DatasetConfig(BaseConfig):
     subclip: SubclipConfig = Field(default_factory=SubclipConfig)
-    preprocessing: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
     augmentation: AugmentationsConfig = Field(
         default_factory=AugmentationsConfig
     )

 class LabeledDataset(Dataset):
-    config: DatasetConfig
-
     def __init__(
         self,
         filenames: Sequence[PathLike],
-        augment: bool = False,
-        subclip: bool = False,
-        config: Optional[DatasetConfig] = None,
+        subclip: Optional[SubclipConfig] = None,
+        augmentation: Optional[AugmentationsConfig] = None,
+        preprocessing: Optional[PreprocessingConfig] = None,
     ):
         self.filenames = filenames
-        self.augment = augment
         self.subclip = subclip
-        self.config = config or DatasetConfig()
+        self.augmentation = augmentation
+        self.preprocessing = preprocessing or PreprocessingConfig()

     def __len__(self):
         return len(self.filenames)
@@ -75,16 +70,16 @@ class LabeledDataset(Dataset):
         if self.subclip:
             dataset = select_subclip(
                 dataset,
-                duration=self.config.subclip.duration,
-                width=self.config.subclip.width,
-                random=self.config.subclip.random,
+                duration=self.subclip.duration,
+                width=self.subclip.width,
+                random=self.subclip.random,
             )

-        if self.augment:
+        if self.augmentation:
             dataset = augment_example(
                 dataset,
-                self.config.augmentation,
-                preprocessing_config=self.config.preprocessing,
+                self.augmentation,
+                preprocessing_config=self.preprocessing,
                 others=self.get_random_example,
             )
@@ -101,15 +96,15 @@ class LabeledDataset(Dataset):
         cls,
         directory: PathLike,
         extension: str = ".nc",
-        config: Optional[DatasetConfig] = None,
-        augment: bool = False,
-        subclip: bool = False,
+        subclip: Optional[SubclipConfig] = None,
+        augmentation: Optional[AugmentationsConfig] = None,
+        preprocessing: Optional[PreprocessingConfig] = None,
     ):
         return cls(
             get_files(directory, extension),
-            config=config,
-            augment=augment,
             subclip=subclip,
+            augmentation=augmentation,
+            preprocessing=preprocessing,
         )

     def get_random_example(self) -> xr.Dataset:
@@ -119,9 +114,9 @@ class LabeledDataset(Dataset):
         if self.subclip:
             dataset = select_subclip(
                 dataset,
-                duration=self.config.subclip.duration,
-                width=self.config.subclip.width,
-                random=self.config.subclip.random,
+                duration=self.subclip.duration,
+                width=self.subclip.width,
+                random=self.subclip.random,
             )

         return dataset
@@ -144,7 +139,7 @@ class LabeledDataset(Dataset):
         if not self.subclip:
             return tensor

-        width = self.config.subclip.width
+        width = self.subclip.width
         return adjust_width(tensor, width)
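
LabeledDataset no longer wraps everything in a DatasetConfig plus augment/subclip booleans; each concern is passed directly, and passing None disables that step (the truthiness checks above double as enable flags). A sketch, where train_data/ is a hypothetical directory of preprocessed .nc files:

# Hypothetical: build a dataset with explicit per-concern configs.
from batdetect2.preprocess import PreprocessingConfig
from batdetect2.train.augmentations import AugmentationsConfig
from batdetect2.train.dataset import LabeledDataset, SubclipConfig

dataset = LabeledDataset.from_directory(
    "train_data/",
    subclip=SubclipConfig(),             # None would skip subclipping
    augmentation=AugmentationsConfig(),  # None would skip augmentation
    preprocessing=PreprocessingConfig(),
)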

batdetect2/train/labels.py

@@ -3,13 +3,19 @@ from typing import Callable, List, Optional, Sequence, Tuple

 import numpy as np
 import xarray as xr
+from pydantic import Field
 from scipy.ndimage import gaussian_filter
 from soundevent import arrays, data, geometry
 from soundevent.geometry.operations import Positions

-from batdetect2.configs import BaseConfig
+from batdetect2.configs import BaseConfig, load_config

-__all__ = ["generate_heatmaps"]
+__all__ = [
+    "HeatmapsConfig",
+    "LabelConfig",
+    "generate_heatmaps",
+    "load_label_config",
+]

 class HeatmapsConfig(BaseConfig):
@@ -19,6 +25,10 @@ class HeatmapsConfig(BaseConfig):
     frequency_scale: float = 1 / 859.375

+class LabelConfig(BaseConfig):
+    heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
+
+
 def generate_heatmaps(
     sound_events: Sequence[data.SoundEventAnnotation],
     spec: xr.DataArray,
@@ -132,3 +142,9 @@ def generate_heatmaps(
     ).fillna(0.0)

     return detection_heatmap, class_heatmap, size_heatmap
+
+
+def load_label_config(
+    path: data.PathLike, field: Optional[str] = None
+) -> LabelConfig:
+    return load_config(path, schema=LabelConfig, field=field)
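
LabelConfig now nests HeatmapsConfig, and load_label_config reads it from a file like the other loaders in this commit. A sketch; the path and field name are assumptions:

# Hypothetical: load heatmap/label settings from a config file.
from batdetect2.train.labels import load_label_config

label_config = load_label_config("train.yaml", field="labels")
print(label_config.heatmaps.sigma)  # nested HeatmapsConfig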

batdetect2/train/losses.py

@@ -6,9 +6,15 @@ from pydantic import Field

 from batdetect2.configs import BaseConfig
 from batdetect2.models.typing import ModelOutput
-from batdetect2.plot import detection
 from batdetect2.train.dataset import TrainExample

+__all__ = [
+    "bbox_size_loss",
+    "compute_loss",
+    "focal_loss",
+    "mse_loss",
+]
+

 class SizeLossConfig(BaseConfig):
     weight: float = 0.1

batdetect2/train/preprocess.py

@@ -17,7 +17,7 @@ from batdetect2.preprocess import (
     compute_spectrogram,
     load_clip_audio,
 )
-from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
+from batdetect2.train.labels import LabelConfig, generate_heatmaps
 from batdetect2.train.targets import (
     TargetConfig,
     build_encoder,
@@ -30,6 +30,9 @@ FilenameFn = Callable[[data.ClipAnnotation], str]

 __all__ = [
     "preprocess_annotations",
+    "preprocess_single_annotation",
+    "generate_train_example",
+    "TrainPreprocessingConfig",
 ]
@@ -38,15 +41,21 @@ class TrainPreprocessingConfig(BaseConfig):
         default_factory=PreprocessingConfig
     )
     target: TargetConfig = Field(default_factory=TargetConfig)
-    heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
+    labels: LabelConfig = Field(default_factory=LabelConfig)

 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
-    config: Optional[TrainPreprocessingConfig] = None,
+    preprocessing_config: Optional[PreprocessingConfig] = None,
+    target_config: Optional[TargetConfig] = None,
+    label_config: Optional[LabelConfig] = None,
 ) -> xr.Dataset:
     """Generate a training example."""
-    config = config or TrainPreprocessingConfig()
+    config = TrainPreprocessingConfig(
+        preprocessing=preprocessing_config or PreprocessingConfig(),
+        target=target_config or TargetConfig(),
+        labels=label_config or LabelConfig(),
+    )

     wave = load_clip_audio(
         clip_annotation.clip,
@@ -78,10 +87,10 @@ def generate_train_example(
         spectrogram,
         class_names,
         encoder,
-        target_sigma=config.heatmaps.sigma,
-        position=config.heatmaps.position,
-        time_scale=config.heatmaps.time_scale,
-        frequency_scale=config.heatmaps.frequency_scale,
+        target_sigma=config.labels.heatmaps.sigma,
+        position=config.labels.heatmaps.position,
+        time_scale=config.labels.heatmaps.time_scale,
+        frequency_scale=config.labels.heatmaps.frequency_scale,
     )

     dataset = xr.Dataset(
@@ -133,14 +142,14 @@ def preprocess_annotations(
     output_dir: PathLike,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
-    config: Optional[TrainPreprocessingConfig] = None,
+    preprocessing_config: Optional[PreprocessingConfig] = None,
+    target_config: Optional[TargetConfig] = None,
+    label_config: Optional[LabelConfig] = None,
     max_workers: Optional[int] = None,
 ) -> None:
     """Preprocess annotations and save to disk."""
     output_dir = Path(output_dir)

-    config = config or TrainPreprocessingConfig()
-
     if not output_dir.is_dir():
         output_dir.mkdir(parents=True)
@@ -151,9 +160,11 @@ def preprocess_annotations(
                 partial(
                     preprocess_single_annotation,
                     output_dir=output_dir,
-                    config=config,
                     filename_fn=filename_fn,
                     replace=replace,
+                    preprocessing_config=preprocessing_config,
+                    target_config=target_config,
+                    label_config=label_config,
                 ),
                 clip_annotations,
             ),
@@ -165,7 +176,9 @@ def preprocess_annotations(
 def preprocess_single_annotation(
     clip_annotation: data.ClipAnnotation,
     output_dir: PathLike,
-    config: TrainPreprocessingConfig,
+    preprocessing_config: Optional[PreprocessingConfig] = None,
+    target_config: Optional[TargetConfig] = None,
+    label_config: Optional[LabelConfig] = None,
     filename_fn: FilenameFn = _get_filename,
     replace: bool = False,
 ) -> None:
@@ -181,7 +194,12 @@ def preprocess_single_annotation(
         path.unlink()

     try:
-        sample = generate_train_example(clip_annotation, config=config)
+        sample = generate_train_example(
+            clip_annotation,
+            preprocessing_config=preprocessing_config,
+            target_config=target_config,
+            label_config=label_config,
+        )
     except Exception as error:
         raise RuntimeError(
             f"Failed to process annotation: {clip_annotation.uuid}"

batdetect2/train/targets.py

@@ -9,6 +9,14 @@ from soundevent import data
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.terms import TagInfo, get_tag_from_info

+__all__ = [
+    "TargetConfig",
+    "load_target_config",
+    "build_encoder",
+    "build_decoder",
+    "filter_sound_event",
+]
+

 class ReplaceConfig(BaseConfig):
     """Configuration for replacing tags."""

batdetect2/train/train.py

@@ -1,82 +1,68 @@
-from typing import Callable, NamedTuple, Optional
+from typing import Optional, Union

-import torch
-from soundevent import data
-from torch.optim import Adam
-from torch.optim.lr_scheduler import CosineAnnealingLR
+from lightning import LightningModule
+from lightning.pytorch import Trainer
+from soundevent.data import PathLike
 from torch.utils.data import DataLoader

-from batdetect2.data.datasets import ClipAnnotationDataset
-from batdetect2.models.typing import DetectionModel
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.train.dataset import LabeledDataset

+__all__ = [
+    "train",
+    "TrainerConfig",
+    "load_trainer_config",
+]

-class TrainInputs(NamedTuple):
-    spec: torch.Tensor
-    detection_heatmap: torch.Tensor
-    class_heatmap: torch.Tensor
-    size_heatmap: torch.Tensor
+class TrainerConfig(BaseConfig):
+    accelerator: str = "auto"
+    accumulate_grad_batches: int = 1
+    deterministic: bool = True
+    check_val_every_n_epoch: int = 1
+    devices: Union[str, int] = "auto"
+    enable_checkpointing: bool = True
+    gradient_clip_val: Optional[float] = None
+    limit_train_batches: Optional[Union[int, float]] = None
+    limit_test_batches: Optional[Union[int, float]] = None
+    limit_val_batches: Optional[Union[int, float]] = None
+    log_every_n_steps: Optional[int] = None
+    max_epochs: Optional[int] = None
+    min_epochs: Optional[int] = 100
+    max_steps: Optional[int] = None
+    min_steps: Optional[int] = None
+    max_time: Optional[str] = None
+    precision: Optional[str] = None
+    reload_dataloaders_every_n_epochs: Optional[int] = None
+    val_check_interval: Optional[Union[int, float]] = None

-def train_loop(
-    model: DetectionModel,
-    train_dataset: ClipAnnotationDataset[TrainInputs],
-    validation_dataset: ClipAnnotationDataset[TrainInputs],
-    device: Optional[torch.device] = None,
-    num_epochs: int = 100,
-    learning_rate: float = 1e-4,
-):
-    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
-    validation_loader = DataLoader(validation_dataset, batch_size=32)
-
-    model.to(device)
-
-    optimizer = Adam(model.parameters(), lr=learning_rate)
-    scheduler = CosineAnnealingLR(
-        optimizer,
-        num_epochs * len(train_loader),
-    )
-
-    for epoch in range(num_epochs):
-        train_loss = train_single_epoch(
-            model,
-            train_loader,
-            optimizer,
-            device,
-            scheduler,
-        )
-
-
-def train_single_epoch(
-    model: DetectionModel,
-    train_loader: DataLoader,
-    optimizer: Adam,
-    device: torch.device,
-    scheduler: CosineAnnealingLR,
-):
-    model.train()
-
-    train_loss = tu.AverageMeter()
-
-    for batch in train_loader:
-        optimizer.zero_grad()
-
-        spec = batch.spec.to(device)
-        detection_heatmap = batch.detection_heatmap.to(device)
-        class_heatmap = batch.class_heatmap.to(device)
-        size_heatmap = batch.size_heatmap.to(device)
-
-        outputs = model(spec)
-
-        loss = loss_fun(
-            outputs,
-            gt_det,
-            gt_size,
-            gt_class,
-            det_criterion,
-            params,
-            class_inv_freq,
-        )
-
-        train_loss.update(loss.item(), data.shape[0])
-        loss.backward()
-        optimizer.step()
-        scheduler.step()
+def load_trainer_config(path: PathLike, field: Optional[str] = None):
+    return load_config(path, schema=TrainerConfig, field=field)
+
+
+def train(
+    module: LightningModule,
+    train_dataset: LabeledDataset,
+    trainer_config: Optional[TrainerConfig] = None,
+    dev_run: bool = False,
+    overfit_batches: bool = False,
+    profiler: Optional[str] = None,
+):
+    trainer_config = trainer_config or TrainerConfig()
+    trainer = Trainer(
+        **trainer_config.model_dump(
+            exclude_unset=True,
+            exclude_none=True,
+        ),
+        fast_dev_run=dev_run,
+        overfit_batches=overfit_batches,
+        profiler=profiler,
+    )
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=module.config.train.batch_size,
+        shuffle=True,
+        num_workers=7,
+    )
+    trainer.fit(module, train_dataloaders=train_loader)
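
The hand-rolled optimizer loop is replaced by a Lightning Trainer built from TrainerConfig; model_dump(exclude_unset=True, exclude_none=True) forwards only explicitly set, non-None options, so Lightning's own defaults apply everywhere else. A usage sketch; DetectorModule is a hypothetical LightningModule, and note that train() reads module.config.train.batch_size when building the DataLoader:

# Hypothetical: run training with the new Lightning-based entry point.
from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.train import TrainerConfig, train

dataset = LabeledDataset.from_directory("train_data/")
module = DetectorModule()  # stand-in for the project's LightningModule

train(
    module,
    dataset,
    trainer_config=TrainerConfig(max_epochs=200),
)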