Compare commits


2 Commits

Author      SHA1        Message                                      Date
mbsantiago  74c419f674  Update default config                        2025-09-10 21:49:57 +01:00
mbsantiago  e65d5a6846  Added more clipping options for validation   2025-09-10 21:09:51 +01:00
5 changed files with 160 additions and 96 deletions

File 1 of 5

@@ -111,15 +111,19 @@ train:
   trainer:
     max_epochs: 5
-  dataloaders:
-    train:
-      batch_size: 8
-      num_workers: 2
-      shuffle: True
-    val:
-      batch_size: 1
-      num_workers: 2
+  train_loader:
+    batch_size: 8
+    num_workers: 2
+    shuffle: True
+    clipping_strategy:
+      name: random_subclip
+      duration: 0.256
+  val_loader:
+    num_workers: 2
+    clipping_strategy:
+      name: whole_audio_padded
+      chunk_size: 0.256
   loss:
     detection:
@@ -136,9 +140,7 @@ train:
       weight: 0.1
   logger:
-    name: mlflow
-    tracking_uri: http://10.20.20.211:9000
-    log_model: true
+    name: csv
   augmentations:
     enabled: true
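
The two clipping_strategy blocks above select a clipper implementation via their name key. As a rough sanity check, here is a minimal sketch of how these YAML mappings would parse, assuming the TrainLoaderConfig and ValLoaderConfig classes introduced in this change are Pydantic v2 models (the model_copy calls elsewhere in the diff suggest so):

from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig

# Pydantic picks the clipping variant from the "name" discriminator.
train = TrainLoaderConfig.model_validate(
    {
        "batch_size": 8,
        "num_workers": 2,
        "shuffle": True,
        "clipping_strategy": {"name": "random_subclip", "duration": 0.256},
    }
)
val = ValLoaderConfig.model_validate(
    {
        "num_workers": 2,
        "clipping_strategy": {"name": "whole_audio_padded", "chunk_size": 0.256},
    }
)
assert train.clipping_strategy.name == "random_subclip"
assert val.clipping_strategy.name == "whole_audio_padded"

Note that val_loader no longer carries batch_size: the validation DataLoader is hardcoded to batch_size=1 in the last file of this diff.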

File 2 of 5

@@ -1,25 +1,32 @@
-from typing import List, Optional
+from typing import Annotated, List, Literal, Optional, Union

 import numpy as np
 from loguru import logger
+from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds, intervals_overlap

 from batdetect2.configs import BaseConfig
+from batdetect2.data._core import Registry
 from batdetect2.typing import ClipperProtocol

 DEFAULT_TRAIN_CLIP_DURATION = 0.256
 DEFAULT_MAX_EMPTY_CLIP = 0.1

-class ClipingConfig(BaseConfig):
+registry: Registry[ClipperProtocol] = Registry("clipper")
+
+
+class RandomClipConfig(BaseConfig):
+    name: Literal["random_subclip"] = "random_subclip"
     duration: float = DEFAULT_TRAIN_CLIP_DURATION
     random: bool = True
     max_empty: float = DEFAULT_MAX_EMPTY_CLIP
     min_sound_event_overlap: float = 0

-class Clipper:
+@registry.register(RandomClipConfig)
+class RandomClip:
     def __init__(
         self,
         duration: float = 0.5,
@@ -45,6 +52,14 @@ class Clipper:
             min_sound_event_overlap=self.min_sound_event_overlap,
         )

+    @classmethod
+    def from_config(cls, config: RandomClipConfig):
+        return cls(
+            duration=config.duration,
+            max_empty=config.max_empty,
+            min_sound_event_overlap=config.min_sound_event_overlap,
+        )
+

 def get_subclip_annotation(
     clip_annotation: data.ClipAnnotation,
@@ -136,17 +151,46 @@ def select_sound_event_annotations(
     return selected


-def build_clipper(
-    config: Optional[ClipingConfig] = None,
-    random: Optional[bool] = None,
-) -> ClipperProtocol:
-    config = config or ClipingConfig()
+class PaddedClipConfig(BaseConfig):
+    name: Literal["whole_audio_padded"] = "whole_audio_padded"
+    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
+
+
+@registry.register(PaddedClipConfig)
+class PaddedClip:
+    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.duration = duration
+
+    def __call__(
+        self,
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation:
+        clip = clip_annotation.clip
+        duration = clip.duration
+        target_duration = self.duration * np.ceil(duration / self.duration)
+        clip = clip.model_copy(
+            update=dict(
+                end_time=clip.start_time + target_duration,
+            )
+        )
+        return clip_annotation.model_copy(update=dict(clip=clip))
+
+    @classmethod
+    def from_config(cls, config: PaddedClipConfig):
+        return cls(duration=config.chunk_size)
+
+
+ClipConfig = Annotated[
+    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
+]
+
+
+def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
+    config = config or RandomClipConfig()

     logger.opt(lazy=True).debug(
         "Building clipper with config: \n{}",
         lambda: config.to_yaml_string(),
     )

-    return Clipper(
-        duration=config.duration,
-        max_empty=config.max_empty,
-        random=config.random if random else False,
-    )
+    return registry.build(config)
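
PaddedClip does not subsample: it extends the clip's end_time so the duration becomes a whole number of chunk_size windows. A worked check of the rounding arithmetic, using only numpy:

import numpy as np

chunk = 0.256
duration = 1.0  # a 1.0 s validation clip

# Same formula as PaddedClip.__call__:
target_duration = chunk * np.ceil(duration / chunk)

# ceil(1.0 / 0.256) = ceil(3.90625) = 4 chunks -> 1.024 s
assert np.isclose(target_duration, 1.024)

Since validation runs with batch_size=1, this presumably only has to make each example chunk-aligned, not batch-aligned.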

File 3 of 5

@@ -10,7 +10,11 @@ from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
 )
-from batdetect2.train.clips import ClipingConfig
+from batdetect2.train.clips import (
+    ClipConfig,
+    PaddedClipConfig,
+    RandomClipConfig,
+)
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig
@@ -44,34 +48,39 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None


-class DataLoaderConfig(BaseConfig):
-    batch_size: int = 8
-    shuffle: bool = False
+class ValLoaderConfig(BaseConfig):
     num_workers: int = 0
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: RandomClipConfig()
+    )


-DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
-DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)
-
-
-class LoadersConfig(BaseConfig):
-    train: DataLoaderConfig = Field(
-        default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
-    )
-    val: DataLoaderConfig = Field(
-        default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
+class TrainLoaderConfig(BaseConfig):
+    num_workers: int = 0
+    batch_size: int = 8
+    shuffle: bool = False
+    augmentations: AugmentationsConfig = Field(
+        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
+    )
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: PaddedClipConfig()
     )


 class TrainingConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
-    dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
+    train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
+    val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
-    augmentations: AugmentationsConfig = Field(
-        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
-    )
-    cliping: ClipingConfig = Field(default_factory=ClipingConfig)
+    cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
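
With these config classes in place, build_clipper dispatches through the clipper registry on the config's type, so each variant constructs its registered class. A minimal usage sketch, assuming the imports resolve as in this diff:

from batdetect2.train.clips import (
    PaddedClipConfig,
    RandomClipConfig,
    build_clipper,
)

# Each registered config type maps to its clipper:
# RandomClipConfig -> RandomClip, PaddedClipConfig -> PaddedClip.
random_clipper = build_clipper(RandomClipConfig(duration=0.256))
padded_clipper = build_clipper(PaddedClipConfig(chunk_size=0.256))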

File 4 of 5

@@ -87,6 +87,7 @@ class ValidationDataset(Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
+        clipper: Optional[ClipperProtocol] = None,
         audio_dir: Optional[data.PathLike] = None,
     ):
         self.clip_annotations = clip_annotations
@@ -94,14 +95,18 @@ class ValidationDataset(Dataset):
         self.preprocessor = preprocessor
         self.audio_loader = audio_loader
         self.audio_dir = audio_dir
+        self.clipper = clipper

     def __len__(self):
         return len(self.clip_annotations)

     def __getitem__(self, idx) -> TrainExample:
         clip_annotation = self.clip_annotations[idx]
-        clip = clip_annotation.clip
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        clip = clip_annotation.clip

         wav = self.audio_loader.load_clip(
             clip_annotation.clip,
             audio_dir=self.audio_dir,
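
The optional clipper runs before audio loading, so the loader sees the (possibly padded) clip rather than the original one. A minimal stand-in sketch of the hook; FakeClip and pad_to are hypothetical, the real types being soundevent's ClipAnnotation and the clippers above:

from dataclasses import dataclass, replace
from typing import Callable, Optional

@dataclass(frozen=True)
class FakeClip:  # hypothetical stand-in for soundevent's Clip
    start_time: float
    end_time: float

def get_item(
    clip: FakeClip,
    clipper: Optional[Callable[[FakeClip], FakeClip]] = None,
) -> FakeClip:
    # Mirrors ValidationDataset.__getitem__: apply the clipper first,
    # then load audio for whatever clip it returned.
    if clipper is not None:
        clip = clipper(clip)
    return clip

# A padding clipper rewrites end_time before any audio is loaded.
pad_to = lambda c: replace(c, end_time=c.start_time + 1.024)
assert get_item(FakeClip(0.0, 1.0), pad_to).end_time == 1.024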

File 5 of 5

@@ -24,7 +24,11 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import FullTrainingConfig, TrainingConfig
+from batdetect2.train.config import (
+    FullTrainingConfig,
+    TrainLoaderConfig,
+    ValLoaderConfig,
+)
 from batdetect2.train.dataset import TrainingDataset, ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
@@ -85,7 +89,7 @@ def train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train,
+        config=config.train.train_loader,
         num_workers=train_workers,
     )
@@ -95,7 +99,7 @@ def train(
             audio_loader=audio_loader,
             labeller=labeller,
             preprocessor=preprocessor,
-            config=config.train,
+            config=config.train.val_loader,
             num_workers=val_workers,
         )
         if val_annotations is not None
@@ -225,10 +229,17 @@ def build_train_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[TrainLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
-    config = config or TrainingConfig()
+    config = config or TrainLoaderConfig()
+
+    logger.info("Building training data loader...")
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )

     train_dataset = build_train_dataset(
         clip_annotations,
         audio_loader=audio_loader,
@@ -237,17 +248,11 @@ def build_train_loader(
         config=config,
     )

-    logger.info("Building training data loader...")
-    loader_conf = config.dataloaders.train
-    logger.opt(lazy=True).debug(
-        "Training data loader config: \n{config}",
-        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
-    )
-
-    num_workers = num_workers or loader_conf.num_workers
+    num_workers = num_workers or config.num_workers

     return DataLoader(
         train_dataset,
-        batch_size=loader_conf.batch_size,
-        shuffle=loader_conf.shuffle,
+        batch_size=config.batch_size,
+        shuffle=config.shuffle,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )
@@ -258,11 +263,15 @@ def build_val_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[ValLoaderConfig] = None,
     num_workers: Optional[int] = None,
 ):
     logger.info("Building validation data loader...")
-    config = config or TrainingConfig()
+    config = config or ValLoaderConfig()
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )

     val_dataset = build_val_dataset(
         clip_annotations,
@@ -271,56 +280,28 @@ def build_val_loader(
         preprocessor=preprocessor,
         config=config,
     )

-    loader_conf = config.dataloaders.val
-    logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{config}",
-        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
-    )
-
-    num_workers = num_workers or loader_conf.num_workers
+    num_workers = num_workers or config.num_workers

     return DataLoader(
         val_dataset,
         batch_size=1,
-        shuffle=loader_conf.shuffle,
+        shuffle=False,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )

-def _collate_fn(batch: List[TrainExample]) -> TrainExample:
-    max_width = max(item.spec.shape[-1] for item in batch)
-    return TrainExample(
-        spec=torch.stack(
-            [adjust_width(item.spec, max_width) for item in batch]
-        ),
-        detection_heatmap=torch.stack(
-            [adjust_width(item.detection_heatmap, max_width) for item in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(item.size_heatmap, max_width) for item in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(item.class_heatmap, max_width) for item in batch]
-        ),
-        idx=torch.stack([item.idx for item in batch]),
-        start_time=torch.stack([item.start_time for item in batch]),
-        end_time=torch.stack([item.end_time for item in batch]),
-    )
-
-
 def build_train_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[TrainLoaderConfig] = None,
 ) -> TrainingDataset:
     logger.info("Building training dataset...")
-    config = config or TrainingConfig()
+    config = config or TrainLoaderConfig()

-    clipper = build_clipper(
-        config=config.cliping,
-        random=True,
-    )
+    clipper = build_clipper(config=config.clipping_strategy)

     random_example_source = RandomAudioSource(
         clip_annotations,
@@ -354,14 +335,37 @@ def build_val_dataset(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainingConfig] = None,
+    config: Optional[ValLoaderConfig] = None,
 ) -> ValidationDataset:
     logger.info("Building validation dataset...")
-    config = config or TrainingConfig()
+    config = config or ValLoaderConfig()
+
+    clipper = build_clipper(config.clipping_strategy)

     return ValidationDataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
+        clipper=clipper,
     )
+
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
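
_collate_fn pads every tensor in the batch to the widest spectrogram before stacking, which is what lets variable-duration clips share a batch. A minimal sketch of the idea, assuming adjust_width zero-pads (or truncates) the last axis; the real helper lives in batdetect2 and may use a different padding mode:

import torch
import torch.nn.functional as F

def adjust_width(x: torch.Tensor, width: int) -> torch.Tensor:
    # Assumed semantics: zero-pad (or truncate) the last axis to `width`.
    if x.shape[-1] >= width:
        return x[..., :width]
    return F.pad(x, (0, width - x.shape[-1]))

specs = [torch.rand(1, 128, 96), torch.rand(1, 128, 128)]
max_width = max(s.shape[-1] for s in specs)
batch = torch.stack([adjust_width(s, max_width) for s in specs])
assert batch.shape == (2, 1, 128, 128)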