Compare commits

No commits in common. "74c419f67404a5b072b73ee00347191d65d27bd6" and "615c811bb43fd54baf71d6ac7d758ccf2cc351f6" have entirely different histories.

5 changed files with 95 additions and 159 deletions

View File

@@ -111,19 +111,15 @@ train:
   trainer:
     max_epochs: 5
-  train_loader:
+  dataloaders:
+    train:
       batch_size: 8
       num_workers: 2
       shuffle: True
-    clipping_strategy:
-      name: random_subclip
-      duration: 0.256
-  val_loader:
+    val:
       batch_size: 1
       num_workers: 2
-    clipping_strategy:
-      name: whole_audio_padded
-      chunk_size: 0.256
   loss:
     detection:
@@ -140,7 +136,9 @@ train:
       weight: 0.1
   logger:
-    name: csv
+    name: mlflow
+    tracking_uri: http://10.20.20.211:9000
+    log_model: true
   augmentations:
     enabled: true
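
Note: the loader settings previously split between train_loader and val_loader now live under a single dataloaders block, and the logger switches from csv to mlflow. A rough sketch of how this YAML is consumed, assuming BaseConfig is a pydantic v2 model (model_copy appears elsewhere in this diff) and that this file is the full training config with a top-level train: section; the file name below is a placeholder:

import yaml

from batdetect2.train.config import FullTrainingConfig

# "train_config.yaml" is a placeholder name for the file shown in this hunk.
with open("train_config.yaml") as fh:
    raw = yaml.safe_load(fh)

config = FullTrainingConfig.model_validate(raw)
print(config.train.dataloaders.train.batch_size)  # 8
print(config.train.logger.name)                   # "mlflow"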

View File

@@ -1,32 +1,25 @@
-from typing import Annotated, List, Literal, Optional, Union
+from typing import List, Optional
 import numpy as np
 from loguru import logger
 from pydantic import Field
 from soundevent import data
 from soundevent.geometry import compute_bounds, intervals_overlap
 from batdetect2.configs import BaseConfig
-from batdetect2.data._core import Registry
 from batdetect2.typing import ClipperProtocol
 DEFAULT_TRAIN_CLIP_DURATION = 0.256
 DEFAULT_MAX_EMPTY_CLIP = 0.1
-registry: Registry[ClipperProtocol] = Registry("clipper")
-class RandomClipConfig(BaseConfig):
-    name: Literal["random_subclip"] = "random_subclip"
+class ClipingConfig(BaseConfig):
     duration: float = DEFAULT_TRAIN_CLIP_DURATION
+    random: bool = True
     max_empty: float = DEFAULT_MAX_EMPTY_CLIP
     min_sound_event_overlap: float = 0
-@registry.register(RandomClipConfig)
-class RandomClip:
+class Clipper:
     def __init__(
         self,
         duration: float = 0.5,
@@ -52,14 +45,6 @@ class RandomClip:
             min_sound_event_overlap=self.min_sound_event_overlap,
         )
-    @classmethod
-    def from_config(cls, config: RandomClipConfig):
-        return cls(
-            duration=config.duration,
-            max_empty=config.max_empty,
-            min_sound_event_overlap=config.min_sound_event_overlap,
-        )
 def get_subclip_annotation(
     clip_annotation: data.ClipAnnotation,
@@ -151,46 +136,17 @@ def select_sound_event_annotations(
     return selected
-class PaddedClipConfig(BaseConfig):
-    name: Literal["whole_audio_padded"] = "whole_audio_padded"
-    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
-@registry.register(PaddedClipConfig)
-class PaddedClip:
-    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
-        self.duration = duration
-    def __call__(
-        self,
-        clip_annotation: data.ClipAnnotation,
-    ) -> data.ClipAnnotation:
-        clip = clip_annotation.clip
-        duration = clip.duration
-        target_duration = self.duration * np.ceil(duration / self.duration)
-        clip = clip.model_copy(
-            update=dict(
-                end_time=clip.start_time + target_duration,
-            )
-        )
-        return clip_annotation.model_copy(update=dict(clip=clip))
-    @classmethod
-    def from_config(cls, config: PaddedClipConfig):
-        return cls(duration=config.chunk_size)
-ClipConfig = Annotated[
-    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
-]
-def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
-    config = config or RandomClipConfig()
+def build_clipper(
+    config: Optional[ClipingConfig] = None,
+    random: Optional[bool] = None,
+) -> ClipperProtocol:
+    config = config or ClipingConfig()
     logger.opt(lazy=True).debug(
         "Building clipper with config: \n{}",
         lambda: config.to_yaml_string(),
     )
-    return registry.build(config)
+    return Clipper(
+        duration=config.duration,
+        max_empty=config.max_empty,
+        random=config.random if random else False,
+    )
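
Note: the registry-based pair of clipping strategies (random_subclip / whole_audio_padded) collapses into a single ClipingConfig plus one Clipper class, selected via the random flag. A minimal usage sketch, assuming this file is the batdetect2.train.clips module referenced by the config import in the next file:

from batdetect2.train.clips import ClipingConfig, build_clipper

# Training: random subclips of the configured duration.
train_clipper = build_clipper(
    config=ClipingConfig(duration=0.256, max_empty=0.1),
    random=True,
)

# Validation: with the random argument left unset, the expression
# `config.random if random else False` forces random=False even if the
# config says otherwise.
val_clipper = build_clipper(ClipingConfig(duration=0.256))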

View File

@@ -10,11 +10,7 @@ from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
 )
-from batdetect2.train.clips import (
-    ClipConfig,
-    PaddedClipConfig,
-    RandomClipConfig,
-)
+from batdetect2.train.clips import ClipingConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig
@@ -48,39 +44,34 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None
-class ValLoaderConfig(BaseConfig):
-    num_workers: int = 0
-    clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: RandomClipConfig()
-    )
-class TrainLoaderConfig(BaseConfig):
-    num_workers: int = 0
+class DataLoaderConfig(BaseConfig):
     batch_size: int = 8
     shuffle: bool = False
     num_workers: int = 0
-    augmentations: AugmentationsConfig = Field(
-        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
+DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
+DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)
+class LoadersConfig(BaseConfig):
+    train: DataLoaderConfig = Field(
+        default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
     )
-    clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: PaddedClipConfig()
+    val: DataLoaderConfig = Field(
+        default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
     )
 class TrainingConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
-    train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
-    val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
+    dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
-    cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
     augmentations: AugmentationsConfig = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
     )
+    cliping: ClipingConfig = Field(default_factory=ClipingConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
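
Note: the per-loader classes (TrainLoaderConfig / ValLoaderConfig) are replaced by one DataLoaderConfig reused for both splits inside LoadersConfig, with the old per-split defaults preserved as module-level constants. A small sketch of the defaults an empty config now resolves to, assuming pydantic-style default_factory semantics:

from batdetect2.train.config import TrainingConfig

config = TrainingConfig()
assert config.dataloaders.train.batch_size == 8
assert config.dataloaders.train.shuffle is True
assert config.dataloaders.val.batch_size == 1
assert config.dataloaders.val.shuffle is False
assert config.cliping.duration == 0.256   # DEFAULT_TRAIN_CLIP_DURATION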

View File

@@ -87,7 +87,6 @@ class ValidationDataset(Dataset):
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         labeller: ClipLabeller,
-        clipper: Optional[ClipperProtocol] = None,
         audio_dir: Optional[data.PathLike] = None,
     ):
         self.clip_annotations = clip_annotations
@@ -95,18 +94,14 @@ class ValidationDataset(Dataset):
         self.preprocessor = preprocessor
         self.audio_loader = audio_loader
         self.audio_dir = audio_dir
-        self.clipper = clipper
     def __len__(self):
         return len(self.clip_annotations)
     def __getitem__(self, idx) -> TrainExample:
         clip_annotation = self.clip_annotations[idx]
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
         clip = clip_annotation.clip
         wav = self.audio_loader.load_clip(
             clip_annotation.clip,
             audio_dir=self.audio_dir,
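
Note: ValidationDataset no longer pads whole recordings up front via a whole_audio_padded clipper; items keep their natural length, and width mismatches are reconciled at batch time by the _collate_fn shown in the final file of this compare (the function itself is unchanged, only moved). A standalone illustration of that batch-time padding; pad_to_width here is a stand-in, not the library's adjust_width:

import torch
import torch.nn.functional as F

def pad_to_width(spec: torch.Tensor, width: int) -> torch.Tensor:
    # Right-pad (or crop) the last dimension to the requested width.
    diff = width - spec.shape[-1]
    return F.pad(spec, (0, diff)) if diff >= 0 else spec[..., :width]

specs = [torch.rand(1, 128, 97), torch.rand(1, 128, 256)]
max_width = max(s.shape[-1] for s in specs)
batch = torch.stack([pad_to_width(s, max_width) for s in specs])
print(batch.shape)  # torch.Size([2, 1, 128, 256])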

View File

@@ -24,11 +24,7 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import (
-    FullTrainingConfig,
-    TrainLoaderConfig,
-    ValLoaderConfig,
-)
+from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import TrainingDataset, ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
@@ -89,7 +85,7 @@ def train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train.train_loader,
+        config=config.train,
         num_workers=train_workers,
     )
@@ -99,7 +95,7 @@ def train(
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        config=config.train.val_loader,
+        config=config.train,
         num_workers=val_workers,
     )
         if val_annotations is not None
@@ -229,17 +225,10 @@ def build_train_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainLoaderConfig] = None,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
-    config = config or TrainLoaderConfig()
-    logger.info("Building training data loader...")
-    logger.opt(lazy=True).debug(
-        "Training data loader config: \n{config}",
-        config=lambda: config.to_yaml_string(exclude_none=True),
-    )
+    config = config or TrainingConfig()
     train_dataset = build_train_dataset(
         clip_annotations,
         audio_loader=audio_loader,
@@ -248,11 +237,17 @@ def build_train_loader(
         config=config,
     )
-    num_workers = num_workers or config.num_workers
+    logger.info("Building training data loader...")
+    loader_conf = config.dataloaders.train
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{config}",
+        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
+    )
+    num_workers = num_workers or loader_conf.num_workers
     return DataLoader(
         train_dataset,
-        batch_size=config.batch_size,
-        shuffle=config.shuffle,
+        batch_size=loader_conf.batch_size,
+        shuffle=loader_conf.shuffle,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )
@@ -263,15 +258,11 @@ def build_val_loader(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[ValLoaderConfig] = None,
+    config: Optional[TrainingConfig] = None,
     num_workers: Optional[int] = None,
 ):
     logger.info("Building validation data loader...")
-    config = config or ValLoaderConfig()
-    logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{config}",
-        config=lambda: config.to_yaml_string(exclude_none=True),
-    )
+    config = config or TrainingConfig()
     val_dataset = build_val_dataset(
         clip_annotations,
@@ -280,28 +271,56 @@ def build_val_loader(
         preprocessor=preprocessor,
         config=config,
     )
-    num_workers = num_workers or config.num_workers
+    loader_conf = config.dataloaders.val
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{config}",
+        config=lambda: loader_conf.to_yaml_string(exclude_none=True),
+    )
+    num_workers = num_workers or loader_conf.num_workers
     return DataLoader(
         val_dataset,
         batch_size=1,
-        shuffle=False,
+        shuffle=loader_conf.shuffle,
         num_workers=num_workers,
         collate_fn=_collate_fn,
     )
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
 def build_train_dataset(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[TrainLoaderConfig] = None,
+    config: Optional[TrainingConfig] = None,
 ) -> TrainingDataset:
     logger.info("Building training dataset...")
-    config = config or TrainLoaderConfig()
+    config = config or TrainingConfig()
-    clipper = build_clipper(config=config.clipping_strategy)
+    clipper = build_clipper(
+        config=config.cliping,
+        random=True,
+    )
     random_example_source = RandomAudioSource(
         clip_annotations,
@@ -335,37 +354,14 @@ def build_val_dataset(
     audio_loader: AudioLoader,
     labeller: ClipLabeller,
     preprocessor: PreprocessorProtocol,
-    config: Optional[ValLoaderConfig] = None,
+    config: Optional[TrainingConfig] = None,
 ) -> ValidationDataset:
     logger.info("Building validation dataset...")
-    config = config or ValLoaderConfig()
+    config = config or TrainingConfig()
-    clipper = build_clipper(config.clipping_strategy)
     return ValidationDataset(
         clip_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
         preprocessor=preprocessor,
-        clipper=clipper,
     )
-def _collate_fn(batch: List[TrainExample]) -> TrainExample:
-    max_width = max(item.spec.shape[-1] for item in batch)
-    return TrainExample(
-        spec=torch.stack(
-            [adjust_width(item.spec, max_width) for item in batch]
-        ),
-        detection_heatmap=torch.stack(
-            [adjust_width(item.detection_heatmap, max_width) for item in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(item.size_heatmap, max_width) for item in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(item.class_heatmap, max_width) for item in batch]
-        ),
-        idx=torch.stack([item.idx for item in batch]),
-        start_time=torch.stack([item.start_time for item in batch]),
-        end_time=torch.stack([item.end_time for item in batch]),
-    )
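
Note: both loader builders now receive the whole training section (config.train at the call site) and pull their split-specific settings from config.dataloaders, while an explicit num_workers argument still takes precedence over the config value. A brief sketch of what the builders read; the dataset and loader collaborators are omitted and the commented call is illustrative only:

from batdetect2.train.config import TrainingConfig

config = TrainingConfig()
loader_conf = config.dataloaders.train                 # what build_train_loader reads
requested_workers = None                               # the builders' num_workers argument
num_workers = requested_workers or loader_conf.num_workers  # explicit value wins, else config

# build_train_loader(clip_annotations, audio_loader=..., labeller=...,
#                    preprocessor=..., config=config, num_workers=num_workers)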