Compare commits


2 Commits

Author      SHA1        Message                                     Date
mbsantiago  74c419f674  Update default config                       2025-09-10 21:49:57 +01:00
mbsantiago  e65d5a6846  Added more clipping options for validation  2025-09-10 21:09:51 +01:00
5 changed files with 160 additions and 96 deletions

Default training config (YAML)

@@ -111,15 +111,19 @@ train:
   trainer:
     max_epochs: 5
-  dataloaders:
-    train:
-      batch_size: 8
-      num_workers: 2
-      shuffle: True
+  train_loader:
+    batch_size: 8
+    num_workers: 2
+    shuffle: True
+    clipping_strategy:
+      name: random_subclip
+      duration: 0.256
-    val:
-      batch_size: 1
-      num_workers: 2
+  val_loader:
+    num_workers: 2
+    clipping_strategy:
+      name: whole_audio_padded
+      chunk_size: 0.256
   loss:
     detection:
@@ -136,9 +140,7 @@ train:
       weight: 0.1
   logger:
-    name: mlflow
-    tracking_uri: http://10.20.20.211:9000
-    log_model: true
+    name: csv
   augmentations:
     enabled: true
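Note: the clipping_strategy blocks above are parsed into the discriminated union added in batdetect2/train/clips.py below. A minimal sketch of that round trip, assuming BaseConfig is a pydantic v2 model (the diff already relies on model_copy and Field(discriminator=...)):

# Hypothetical parse of the two YAML blocks above; the "name" key picks the class.
from pydantic import TypeAdapter

from batdetect2.train.clips import ClipConfig, PaddedClipConfig, RandomClipConfig

adapter = TypeAdapter(ClipConfig)
train_clip = adapter.validate_python({"name": "random_subclip", "duration": 0.256})
val_clip = adapter.validate_python({"name": "whole_audio_padded", "chunk_size": 0.256})
assert isinstance(train_clip, RandomClipConfig)
assert isinstance(val_clip, PaddedClipConfig)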

batdetect2/train/clips.py

@@ -1,25 +1,32 @@
-from typing import List, Optional
+from typing import Annotated, List, Literal, Optional, Union

 import numpy as np
 from loguru import logger
+from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds, intervals_overlap

 from batdetect2.configs import BaseConfig
+from batdetect2.data._core import Registry
 from batdetect2.typing import ClipperProtocol

 DEFAULT_TRAIN_CLIP_DURATION = 0.256
 DEFAULT_MAX_EMPTY_CLIP = 0.1

-class ClipingConfig(BaseConfig):
+registry: Registry[ClipperProtocol] = Registry("clipper")

+class RandomClipConfig(BaseConfig):
+    name: Literal["random_subclip"] = "random_subclip"
     duration: float = DEFAULT_TRAIN_CLIP_DURATION
-    random: bool = True
     max_empty: float = DEFAULT_MAX_EMPTY_CLIP
     min_sound_event_overlap: float = 0

-class Clipper:
+@registry.register(RandomClipConfig)
+class RandomClip:
     def __init__(
         self,
         duration: float = 0.5,
@@ -45,6 +52,14 @@ class Clipper:
             min_sound_event_overlap=self.min_sound_event_overlap,
         )

+    @classmethod
+    def from_config(cls, config: RandomClipConfig):
+        return cls(
+            duration=config.duration,
+            max_empty=config.max_empty,
+            min_sound_event_overlap=config.min_sound_event_overlap,
+        )

 def get_subclip_annotation(
     clip_annotation: data.ClipAnnotation,
@@ -136,17 +151,46 @@ def select_sound_event_annotations(
     return selected

-def build_clipper(
-    config: Optional[ClipingConfig] = None,
-    random: Optional[bool] = None,
-) -> ClipperProtocol:
-    config = config or ClipingConfig()
+class PaddedClipConfig(BaseConfig):
+    name: Literal["whole_audio_padded"] = "whole_audio_padded"
+    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION

+@registry.register(PaddedClipConfig)
+class PaddedClip:
+    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.duration = duration

+    def __call__(
+        self,
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation:
+        clip = clip_annotation.clip
+        duration = clip.duration
+        target_duration = self.duration * np.ceil(duration / self.duration)
+        clip = clip.model_copy(
+            update=dict(
+                end_time=clip.start_time + target_duration,
+            )
+        )
+        return clip_annotation.model_copy(update=dict(clip=clip))

+    @classmethod
+    def from_config(cls, config: PaddedClipConfig):
+        return cls(duration=config.chunk_size)

+ClipConfig = Annotated[
+    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
+]

+def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
+    config = config or RandomClipConfig()
     logger.opt(lazy=True).debug(
         "Building clipper with config: \n{}",
         lambda: config.to_yaml_string(),
     )
-    return Clipper(
-        duration=config.duration,
-        max_empty=config.max_empty,
-        random=config.random if random else False,
-    )
+    return registry.build(config)
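PaddedClip extends the clip end time to the next whole multiple of the chunk size, and build_clipper now dispatches on the config type through the registry. A quick worked check of the padding arithmetic (the values are illustrative):

import numpy as np

# A 0.9 s clip with a 0.256 s chunk size is padded to 4 chunks = 1.024 s.
chunk = 0.256
clip_len = 0.9
target = chunk * np.ceil(clip_len / chunk)
assert round(float(target), 3) == 1.024

# Registry dispatch: the config type selects the registered clipper class.
from batdetect2.train.clips import PaddedClipConfig, build_clipper

clipper = build_clipper(PaddedClipConfig(chunk_size=0.256))  # -> PaddedClip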

batdetect2/train/config.py

@@ -10,7 +10,11 @@ from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
 )
-from batdetect2.train.clips import ClipingConfig
+from batdetect2.train.clips import (
+    ClipConfig,
+    PaddedClipConfig,
+    RandomClipConfig,
+)
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig
@@ -44,34 +48,39 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None

-class DataLoaderConfig(BaseConfig):
-    batch_size: int = 8
-    shuffle: bool = False
+class ValLoaderConfig(BaseConfig):
     num_workers: int = 0

-DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
-DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)

-class LoadersConfig(BaseConfig):
-    train: DataLoaderConfig = Field(
-        default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: RandomClipConfig()
     )
-    val: DataLoaderConfig = Field(
-        default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
+
+class TrainLoaderConfig(BaseConfig):
+    num_workers: int = 0
+    batch_size: int = 8
+    shuffle: bool = False
+    augmentations: AugmentationsConfig = Field(
+        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
     )
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: PaddedClipConfig()
+    )

 class TrainingConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
-    dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
+    train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
+    val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
     augmentations: AugmentationsConfig = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
     )
-    cliping: ClipingConfig = Field(default_factory=ClipingConfig)
+    cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
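For reference, a hypothetical programmatic equivalent of the updated default config, built from the new per-loader models (the YAML pins clipping_strategy on both loaders rather than relying on the class defaults shown above):

from batdetect2.train.clips import PaddedClipConfig, RandomClipConfig
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig

train_loader = TrainLoaderConfig(
    batch_size=8,
    num_workers=2,
    shuffle=True,
    clipping_strategy=RandomClipConfig(duration=0.256),
)
val_loader = ValLoaderConfig(
    num_workers=2,
    clipping_strategy=PaddedClipConfig(chunk_size=0.256),
)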

batdetect2/train/dataset.py

@@ -87,6 +87,7 @@ class ValidationDataset(Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
+        clipper: Optional[ClipperProtocol] = None,
         audio_dir: Optional[data.PathLike] = None,
     ):
         self.clip_annotations = clip_annotations
@@ -94,14 +95,18 @@ class ValidationDataset(Dataset):
         self.preprocessor = preprocessor
         self.audio_loader = audio_loader
         self.audio_dir = audio_dir
+        self.clipper = clipper

     def __len__(self):
         return len(self.clip_annotations)

     def __getitem__(self, idx) -> TrainExample:
         clip_annotation = self.clip_annotations[idx]
-        clip = clip_annotation.clip
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        clip = clip_annotation.clip

         wav = self.audio_loader.load_clip(
             clip_annotation.clip,
             audio_dir=self.audio_dir,
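A sketch of opting in to validation-time clipping through the new argument, mirroring build_val_dataset in train.py below (the four dataset inputs are assumed to come from the usual builders):

from batdetect2.train.clips import PaddedClipConfig, build_clipper
from batdetect2.train.dataset import ValidationDataset

def make_val_dataset(annotations, audio_loader, labeller, preprocessor):
    # Pad each validation clip to whole chunks before any audio is loaded.
    clipper = build_clipper(PaddedClipConfig(chunk_size=0.256))
    return ValidationDataset(
        annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        clipper=clipper,  # new optional argument from this diff
    )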

batdetect2/train/train.py

@@ -24,7 +24,11 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import FullTrainingConfig, TrainingConfig
+from batdetect2.train.config import (
+    FullTrainingConfig,
+    TrainLoaderConfig,
+    ValLoaderConfig,
+)
 from batdetect2.train.dataset import TrainingDataset, ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
@@ -85,7 +89,7 @@ def train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train,
+        config=config.train.train_loader,
         num_workers=train_workers,
     )
@@ -95,7 +99,7 @@ def train(
             audio_loader=audio_loader,
             labeller=labeller,
             preprocessor=preprocessor,
-            config=config.train,
+            config=config.train.val_loader,
             num_workers=val_workers,
         )
         if val_annotations is not None
@@ -225,10 +229,17 @@ def build_train_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[TrainLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
-    config = config or TrainingConfig()
+    config = config or TrainLoaderConfig()
+
+    logger.info("Building training data loader...")
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )

     train_dataset = build_train_dataset(
         clip_annotations,
         audio_loader=audio_loader,
@@ -237,17 +248,11 @@ def build_train_loader(
         config=config,
     )

-    logger.info("Building training data loader...")
-    loader_conf = config.dataloaders.train
-    logger.opt(lazy=True).debug(
-        "Training data loader config: \n{config}",
-        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
-    )
-    num_workers = num_workers or loader_conf.num_workers
+    num_workers = num_workers or config.num_workers

     return DataLoader(
         train_dataset,
-        batch_size=loader_conf.batch_size,
-        shuffle=loader_conf.shuffle,
+        batch_size=config.batch_size,
+        shuffle=config.shuffle,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )
@@ -258,11 +263,15 @@ def build_val_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[ValLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ):
     logger.info("Building validation data loader...")
-    config = config or TrainingConfig()
+    config = config or ValLoaderConfig()
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )

     val_dataset = build_val_dataset(
         clip_annotations,
@@ -271,56 +280,28 @@ def build_val_loader(
         preprocessor=preprocessor,
         config=config,
     )

-    loader_conf = config.dataloaders.val
-    logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{config}",
-        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
-    )
-    num_workers = num_workers or loader_conf.num_workers
+    num_workers = num_workers or config.num_workers

     return DataLoader(
         val_dataset,
         batch_size=1,
-        shuffle=loader_conf.shuffle,
+        shuffle=False,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )

-def _collate_fn(batch: List[TrainExample]) -> TrainExample:
-    max_width = max(item.spec.shape[-1] for item in batch)
-    return TrainExample(
-        spec=torch.stack(
-            [adjust_width(item.spec, max_width) for item in batch]
-        ),
-        detection_heatmap=torch.stack(
-            [adjust_width(item.detection_heatmap, max_width) for item in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(item.size_heatmap, max_width) for item in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(item.class_heatmap, max_width) for item in batch]
-        ),
-        idx=torch.stack([item.idx for item in batch]),
-        start_time=torch.stack([item.start_time for item in batch]),
-        end_time=torch.stack([item.end_time for item in batch]),
-    )

 def build_train_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[TrainLoaderConfig] = None,
 ) -> TrainingDataset:
     logger.info("Building training dataset...")
-    config = config or TrainingConfig()
-    clipper = build_clipper(
-        config=config.cliping,
-        random=True,
-    )
+    config = config or TrainLoaderConfig()
+    clipper = build_clipper(config=config.clipping_strategy)

     random_example_source = RandomAudioSource(
         clip_annotations,
@@ -354,14 +335,37 @@ def build_val_dataset(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[ValLoaderConfig] = None,
 ) -> ValidationDataset:
     logger.info("Building validation dataset...")
-    config = config or TrainingConfig()
+    config = config or ValLoaderConfig()
+    clipper = build_clipper(config.clipping_strategy)

     return ValidationDataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
+        clipper=clipper,
     )

+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
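_collate_fn is only moved below the dataset builders here; its behavior is unchanged. A self-contained sketch of the pad-to-widest-and-stack pattern it implements, with adjust_width as a simplified stand-in for batdetect2's helper:

import torch
import torch.nn.functional as F

def adjust_width(x: torch.Tensor, width: int) -> torch.Tensor:
    # Right-pad the last dimension with zeros (or crop) to `width`.
    if x.shape[-1] < width:
        return F.pad(x, (0, width - x.shape[-1]))
    return x[..., :width]

specs = [torch.rand(1, 128, 90), torch.rand(1, 128, 100)]
max_width = max(s.shape[-1] for s in specs)
batch = torch.stack([adjust_width(s, max_width) for s in specs])
assert batch.shape == (2, 1, 128, 100)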