Compare commits


No commits in common. "74c419f67404a5b072b73ee00347191d65d27bd6" and "615c811bb43fd54baf71d6ac7d758ccf2cc351f6" have entirely different histories.

5 changed files with 95 additions and 159 deletions

View File

@@ -111,19 +111,15 @@ train:
   trainer:
     max_epochs: 5
-  train_loader:
-    batch_size: 8
-    num_workers: 2
-    shuffle: True
-    clipping_strategy:
-      name: random_subclip
-      duration: 0.256
-  val_loader:
-    num_workers: 2
-    clipping_strategy:
-      name: whole_audio_padded
-      chunk_size: 0.256
+  dataloaders:
+    train:
+      batch_size: 8
+      num_workers: 2
+      shuffle: True
+    val:
+      batch_size: 1
+      num_workers: 2
   loss:
     detection:
@@ -140,7 +136,9 @@ train:
       weight: 0.1
   logger:
-    name: csv
+    name: mlflow
+    tracking_uri: http://10.20.20.211:9000
+    log_model: true
   augmentations:
     enabled: true
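
The hunks above collapse the separate train_loader/val_loader sections into a single dataloaders block (clipping moves to its own section) and switch the logger from csv to mlflow. As a sanity check, a minimal sketch of loading the new block with the pydantic models introduced later in this diff; yaml.safe_load and pydantic v2's model_validate are assumptions about BaseConfig, not code shown here:

import yaml  # assumption: PyYAML is available

from batdetect2.train.config import TrainingConfig

raw = yaml.safe_load(
    """
dataloaders:
  train:
    batch_size: 8
    num_workers: 2
    shuffle: True
  val:
    batch_size: 1
    num_workers: 2
"""
)

# BaseConfig appears to be a pydantic v2 model (model_copy is used elsewhere
# in this diff), so model_validate should accept the parsed mapping.
config = TrainingConfig.model_validate(raw)
assert config.dataloaders.train.shuffle is True
assert config.dataloaders.val.batch_size == 1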

View File

@@ -1,32 +1,25 @@
-from typing import Annotated, List, Literal, Optional, Union
+from typing import List, Optional
 
 import numpy as np
 from loguru import logger
-from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds, intervals_overlap
 
 from batdetect2.configs import BaseConfig
-from batdetect2.data._core import Registry
 from batdetect2.typing import ClipperProtocol
 
 DEFAULT_TRAIN_CLIP_DURATION = 0.256
 DEFAULT_MAX_EMPTY_CLIP = 0.1
 
-registry: Registry[ClipperProtocol] = Registry("clipper")
-
-
-class RandomClipConfig(BaseConfig):
-    name: Literal["random_subclip"] = "random_subclip"
+
+class ClipingConfig(BaseConfig):
     duration: float = DEFAULT_TRAIN_CLIP_DURATION
     random: bool = True
     max_empty: float = DEFAULT_MAX_EMPTY_CLIP
     min_sound_event_overlap: float = 0
 
 
-@registry.register(RandomClipConfig)
-class RandomClip:
+class Clipper:
     def __init__(
         self,
         duration: float = 0.5,
@@ -52,14 +45,6 @@ class RandomClip:
             min_sound_event_overlap=self.min_sound_event_overlap,
         )
 
-    @classmethod
-    def from_config(cls, config: RandomClipConfig):
-        return cls(
-            duration=config.duration,
-            max_empty=config.max_empty,
-            min_sound_event_overlap=config.min_sound_event_overlap,
-        )
-
 
 def get_subclip_annotation(
     clip_annotation: data.ClipAnnotation,
@@ -151,46 +136,17 @@ def select_sound_event_annotations(
     return selected
 
 
-class PaddedClipConfig(BaseConfig):
-    name: Literal["whole_audio_padded"] = "whole_audio_padded"
-    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
-
-
-@registry.register(PaddedClipConfig)
-class PaddedClip:
-    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
-        self.duration = duration
-
-    def __call__(
-        self,
-        clip_annotation: data.ClipAnnotation,
-    ) -> data.ClipAnnotation:
-        clip = clip_annotation.clip
-        duration = clip.duration
-        target_duration = self.duration * np.ceil(duration / self.duration)
-        clip = clip.model_copy(
-            update=dict(
-                end_time=clip.start_time + target_duration,
-            )
-        )
-        return clip_annotation.model_copy(update=dict(clip=clip))
-
-    @classmethod
-    def from_config(cls, config: PaddedClipConfig):
-        return cls(duration=config.chunk_size)
-
-
-ClipConfig = Annotated[
-    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
-]
-
-
-def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
-    config = config or RandomClipConfig()
+def build_clipper(
+    config: Optional[ClipingConfig] = None,
+    random: Optional[bool] = None,
+) -> ClipperProtocol:
+    config = config or ClipingConfig()
     logger.opt(lazy=True).debug(
         "Building clipper with config: \n{}",
         lambda: config.to_yaml_string(),
    )
-    return registry.build(config)
+    return Clipper(
+        duration=config.duration,
+        max_empty=config.max_empty,
+        random=config.random if random else False,
+    )
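
With the registry and the two clipper variants gone, everything funnels through the single build_clipper entry point. A usage sketch based only on the signatures above; note that the random argument, not config.random, is what gates randomness, since random=config.random if random else False collapses to False whenever random is left as None:

from batdetect2.train.clips import ClipingConfig, build_clipper

config = ClipingConfig(duration=0.256, max_empty=0.1)

# Training path: build_train_dataset (later in this diff) passes random=True,
# so config.random is honoured.
train_clipper = build_clipper(config=config, random=True)

# Default call: random is None, so the clipper is deterministic regardless
# of config.random.
val_clipper = build_clipper(config=config)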

View File

@@ -10,11 +10,7 @@ from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
 )
-from batdetect2.train.clips import (
-    ClipConfig,
-    PaddedClipConfig,
-    RandomClipConfig,
-)
+from batdetect2.train.clips import ClipingConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig
@@ -48,39 +44,34 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None
 
 
-class ValLoaderConfig(BaseConfig):
-    num_workers: int = 0
-    clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: RandomClipConfig()
-    )
-
-
-class TrainLoaderConfig(BaseConfig):
-    num_workers: int = 0
+class DataLoaderConfig(BaseConfig):
     batch_size: int = 8
     shuffle: bool = False
-    augmentations: AugmentationsConfig = Field(
-        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
-    )
-    clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: PaddedClipConfig()
-    )
+    num_workers: int = 0
+
+
+DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
+DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)
+
+
+class LoadersConfig(BaseConfig):
+    train: DataLoaderConfig = Field(
+        default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
+    )
+    val: DataLoaderConfig = Field(
+        default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
+    )
 
 
 class TrainingConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
-    train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
-    val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
+    dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
-    cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
+    augmentations: AugmentationsConfig = Field(
+        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
+    )
+    cliping: ClipingConfig = Field(default_factory=ClipingConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
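
Loader options now nest under TrainingConfig.dataloaders. A sketch of overriding them programmatically with the models defined above (the field values are illustrative):

from batdetect2.train.config import (
    DataLoaderConfig,
    LoadersConfig,
    TrainingConfig,
)

config = TrainingConfig(
    dataloaders=LoadersConfig(
        train=DataLoaderConfig(batch_size=16, shuffle=True, num_workers=4),
        val=DataLoaderConfig(batch_size=1, num_workers=2),
    ),
)

# Fields left unset fall back to their default_factory values.
assert config.cliping.duration == 0.256  # DEFAULT_TRAIN_CLIP_DURATION
assert config.dataloaders.val.shuffle is False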

View File

@@ -87,7 +87,6 @@ class ValidationDataset(Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
-        clipper: Optional[ClipperProtocol] = None,
         audio_dir: Optional[data.PathLike] = None,
     ):
         self.clip_annotations = clip_annotations
@@ -95,18 +94,14 @@ class ValidationDataset(Dataset):
         self.preprocessor = preprocessor
         self.audio_loader = audio_loader
         self.audio_dir = audio_dir
-        self.clipper = clipper
 
     def __len__(self):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
         clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
         clip = clip_annotation.clip
         wav = self.audio_loader.load_clip(
             clip_annotation.clip,
             audio_dir=self.audio_dir,
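
Removing the clipper means each validation example now keeps the recording's natural duration. For contrast, the deleted PaddedClip rounded every clip up to a whole number of chunks; a toy recalculation of that old behaviour, using the formula from the removed code:

import numpy as np

duration, chunk = 0.9, 0.256  # illustrative clip length and chunk size

# Old behaviour (deleted PaddedClip.__call__): pad up to a chunk multiple.
target_duration = chunk * np.ceil(duration / chunk)  # 4 chunks -> 1.024 s

# New behaviour: the 0.9 s clip is loaded as-is, and width differences are
# reconciled at batch time by _collate_fn (see the next file).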

View File

@@ -24,11 +24,7 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import (
-    FullTrainingConfig,
-    TrainLoaderConfig,
-    ValLoaderConfig,
-)
+from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import TrainingDataset, ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
@@ -89,7 +85,7 @@ def train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train.train_loader,
+        config=config.train,
         num_workers=train_workers,
     )
@@ -99,7 +95,7 @@ def train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train.val_loader,
+        config=config.train,
         num_workers=val_workers,
     )
     if val_annotations is not None
@@ -229,17 +225,10 @@ def build_train_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainLoaderConfig] = None,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
-    config = config or TrainLoaderConfig()
+    config = config or TrainingConfig()
 
-    logger.info("Building training data loader...")
-    logger.opt(lazy=True).debug(
-        "Training data loader config: \n{config}",
-        config=lambda: config.to_yaml_string(exclude_none=True),
-    )
-
     train_dataset = build_train_dataset(
         clip_annotations,
         audio_loader=audio_loader,
@@ -248,11 +237,17 @@ def build_train_loader(
         config=config,
     )
 
-    num_workers = num_workers or config.num_workers
+    logger.info("Building training data loader...")
+    loader_conf = config.dataloaders.train
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{config}",
+        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
+    )
+
+    num_workers = num_workers or loader_conf.num_workers
 
     return DataLoader(
         train_dataset,
-        batch_size=config.batch_size,
-        shuffle=config.shuffle,
+        batch_size=loader_conf.batch_size,
+        shuffle=loader_conf.shuffle,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )
@@ -263,15 +258,11 @@ def build_val_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[ValLoaderConfig] = None,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ):
     logger.info("Building validation data loader...")
-    config = config or ValLoaderConfig()
-
-    logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{config}",
-        config=lambda: config.to_yaml_string(exclude_none=True),
-    )
+    config = config or TrainingConfig()
 
     val_dataset = build_val_dataset(
         clip_annotations,
@@ -280,28 +271,56 @@ def build_val_loader(
         preprocessor=preprocessor,
         config=config,
     )
 
-    num_workers = num_workers or config.num_workers
+    loader_conf = config.dataloaders.val
+
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{config}",
+        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
+    )
+
+    num_workers = num_workers or loader_conf.num_workers
 
     return DataLoader(
         val_dataset,
         batch_size=1,
-        shuffle=False,
+        shuffle=loader_conf.shuffle,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )
+
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
 
 
 def build_train_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainLoaderConfig] = None,
+    config: Optional[TrainingConfig] = None,
 ) -> TrainingDataset:
     logger.info("Building training dataset...")
-    config = config or TrainLoaderConfig()
-    clipper = build_clipper(config=config.clipping_strategy)
+    config = config or TrainingConfig()
+    clipper = build_clipper(
+        config=config.cliping,
+        random=True,
+    )
 
     random_example_source = RandomAudioSource(
         clip_annotations,
@@ -335,37 +354,14 @@ def build_val_dataset(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[ValLoaderConfig] = None,
+    config: Optional[TrainingConfig] = None,
 ) -> ValidationDataset:
     logger.info("Building validation dataset...")
-    config = config or ValLoaderConfig()
-    clipper = build_clipper(config.clipping_strategy)
+    config = config or TrainingConfig()
 
     return ValidationDataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        clipper=clipper,
     )
-
-
-def _collate_fn(batch: List[TrainExample]) -> TrainExample:
-    max_width = max(item.spec.shape[-1] for item in batch)
-    return TrainExample(
-        spec=torch.stack(
-            [adjust_width(item.spec, max_width) for item in batch]
-        ),
-        detection_heatmap=torch.stack(
-            [adjust_width(item.detection_heatmap, max_width) for item in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(item.size_heatmap, max_width) for item in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(item.class_heatmap, max_width) for item in batch]
-        ),
-        idx=torch.stack([item.idx for item in batch]),
-        start_time=torch.stack([item.start_time for item in batch]),
-        end_time=torch.stack([item.end_time for item in batch]),
-    )
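
The relocated _collate_fn is what absorbs those variable clip lengths: every tensor in a batch is brought to the width of the widest spectrogram before stacking. A self-contained sketch of that padding step with toy shapes, where torch.nn.functional.pad stands in for batdetect2's adjust_width (assumed here to right-pad along the last axis):

import torch
import torch.nn.functional as F

# Two "spectrograms" of different widths, as whole-recording validation
# clips would produce.
specs = [torch.rand(1, 128, 90), torch.rand(1, 128, 100)]

max_width = max(s.shape[-1] for s in specs)
padded = torch.stack([F.pad(s, (0, max_width - s.shape[-1])) for s in specs])

assert padded.shape == (2, 1, 128, 100)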