mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Added more clipping options for validation
This commit is contained in:
parent
615c811bb4
commit
e65d5a6846
@ -1,25 +1,32 @@
|
||||
from typing import List, Optional
|
||||
from typing import Annotated, List, Literal, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds, intervals_overlap
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.data._core import Registry
|
||||
from batdetect2.typing import ClipperProtocol
|
||||
|
||||
DEFAULT_TRAIN_CLIP_DURATION = 0.256
|
||||
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||
|
||||
|
||||
class ClipingConfig(BaseConfig):
|
||||
registry: Registry[ClipperProtocol] = Registry("clipper")
|
||||
|
||||
|
||||
class RandomClipConfig(BaseConfig):
    """Settings for the ``random_subclip`` clipping strategy.

    ``name`` is a fixed literal tag; the ``ClipConfig`` union in this module
    discriminates on it to pick this config class.
    """

    # Discriminator tag, always "random_subclip".
    name: Literal["random_subclip"] = "random_subclip"

    # Length of the extracted subclip (presumably seconds -- TODO confirm).
    duration: float = DEFAULT_TRAIN_CLIP_DURATION

    # Presumably toggles random placement of the subclip -- verify against
    # the clipper implementation.
    random: bool = True

    # NOTE(review): exact meaning is not visible here; looks like a limit on
    # how "empty" a selected subclip may be -- confirm with the clipper.
    max_empty: float = DEFAULT_MAX_EMPTY_CLIP

    # NOTE(review): presumably the minimum overlap a sound event needs with
    # the subclip to be kept -- confirm.
    min_sound_event_overlap: float = 0
|
||||
|
||||
|
||||
class Clipper:
|
||||
@registry.register(RandomClipConfig)
|
||||
class RandomClip:
|
||||
def __init__(
|
||||
self,
|
||||
duration: float = 0.5,
|
||||
@ -45,6 +52,14 @@ class Clipper:
|
||||
min_sound_event_overlap=self.min_sound_event_overlap,
|
||||
)
|
||||
|
||||
@classmethod
def from_config(cls, config: RandomClipConfig):
    """Create a clipper instance from a ``RandomClipConfig``.

    Every tunable field of the config is forwarded to the constructor,
    including ``random``; leaving it out would make that setting a no-op.

    Parameters
    ----------
    config : RandomClipConfig
        Configuration holding ``duration``, ``max_empty``, ``random`` and
        ``min_sound_event_overlap``.

    Returns
    -------
    An instance of ``cls`` configured from ``config``.
    """
    return cls(
        duration=config.duration,
        max_empty=config.max_empty,
        random=config.random,
        min_sound_event_overlap=config.min_sound_event_overlap,
    )
|
||||
|
||||
|
||||
def get_subclip_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
@ -136,17 +151,46 @@ def select_sound_event_annotations(
|
||||
return selected
|
||||
|
||||
|
||||
def build_clipper(
|
||||
config: Optional[ClipingConfig] = None,
|
||||
random: Optional[bool] = None,
|
||||
) -> ClipperProtocol:
|
||||
config = config or ClipingConfig()
|
||||
class PaddedClipConfig(BaseConfig):
    """Settings for the ``whole_audio_padded`` clipping strategy."""

    # Discriminator tag, always "whole_audio_padded".
    name: Literal["whole_audio_padded"] = "whole_audio_padded"

    # Chunk length handed to ``PaddedClip`` as its ``duration``; clips are
    # extended to a whole multiple of it (presumably seconds -- TODO confirm).
    chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION
|
||||
|
||||
|
||||
@registry.register(PaddedClipConfig)
class PaddedClip:
    """Clipper that extends a clip to a whole number of fixed-size chunks.

    The clip start time is kept; only the end time is moved so that the clip
    duration becomes the smallest multiple of ``duration`` that is not
    shorter than the original clip duration.
    """

    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
        """Store the chunk length.

        Parameters
        ----------
        duration : float
            Chunk length; must be strictly positive.

        Raises
        ------
        ValueError
            If ``duration`` is zero, negative or NaN. Such a value would
            otherwise fail later with a division by zero or yield a clip
            whose end precedes its start.
        """
        # ``not duration > 0`` also rejects NaN, unlike ``duration <= 0``.
        if not duration > 0:
            raise ValueError(
                f"PaddedClip duration must be positive, got {duration!r}"
            )
        self.duration = duration

    def __call__(
        self,
        clip_annotation: data.ClipAnnotation,
    ) -> data.ClipAnnotation:
        """Return a copy of ``clip_annotation`` whose clip is padded.

        Parameters
        ----------
        clip_annotation : data.ClipAnnotation
            Annotation whose ``clip`` should be extended.

        Returns
        -------
        data.ClipAnnotation
            A copy with ``clip.end_time`` replaced; the input object is not
            modified.
        """
        clip = clip_annotation.clip
        num_chunks = np.ceil(clip.duration / self.duration)

        # ``model_copy(update=...)`` performs no validation, so convert the
        # NumPy scalar explicitly to keep ``end_time`` a plain Python float.
        padded_end = float(clip.start_time + self.duration * num_chunks)

        padded_clip = clip.model_copy(update={"end_time": padded_end})
        return clip_annotation.model_copy(update={"clip": padded_clip})

    @classmethod
    def from_config(cls, config: PaddedClipConfig):
        """Create an instance using ``config.chunk_size`` as the duration."""
        return cls(duration=config.chunk_size)
|
||||
|
||||
|
||||
ClipConfig = Annotated[
|
||||
Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
|
||||
]
|
||||
|
||||
|
||||
def build_clipper(config: Optional[ClipConfig] = None) -> ClipperProtocol:
    """Build the clipper described by ``config``.

    Parameters
    ----------
    config : Optional[ClipConfig]
        Clipping configuration. When omitted, a default
        ``RandomClipConfig`` is used.

    Returns
    -------
    ClipperProtocol
        Whatever ``registry.build`` produces for the given config.
    """
    config = config or RandomClipConfig()

    # Lazy logging: the YAML dump is only rendered if the debug message is
    # actually emitted.
    logger.opt(lazy=True).debug(
        "Building clipper with config: \n{}",
        lambda: config.to_yaml_string(),
    )

    # The former direct ``Clipper(...)`` construction is gone: it referenced
    # a ``random`` argument this signature no longer has and made the
    # registry dispatch below unreachable.
    return registry.build(config)
|
||||
|
||||
@ -10,7 +10,11 @@ from batdetect2.train.augmentations import (
|
||||
DEFAULT_AUGMENTATION_CONFIG,
|
||||
AugmentationsConfig,
|
||||
)
|
||||
from batdetect2.train.clips import ClipingConfig
|
||||
from batdetect2.train.clips import (
|
||||
ClipConfig,
|
||||
PaddedClipConfig,
|
||||
RandomClipConfig,
|
||||
)
|
||||
from batdetect2.train.labels import LabelConfig
|
||||
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
|
||||
from batdetect2.train.losses import LossConfig
|
||||
@ -44,34 +48,39 @@ class PLTrainerConfig(BaseConfig):
|
||||
val_check_interval: Optional[Union[int, float]] = None
|
||||
|
||||
|
||||
class DataLoaderConfig(BaseConfig):
|
||||
batch_size: int = 8
|
||||
shuffle: bool = False
|
||||
class ValLoaderConfig(BaseConfig):
|
||||
num_workers: int = 0
|
||||
|
||||
|
||||
DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
|
||||
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=1, shuffle=False)
|
||||
|
||||
|
||||
class LoadersConfig(BaseConfig):
|
||||
train: DataLoaderConfig = Field(
|
||||
default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
|
||||
clipping_strategy: ClipConfig = Field(
|
||||
default_factory=lambda: RandomClipConfig()
|
||||
)
|
||||
val: DataLoaderConfig = Field(
|
||||
default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
|
||||
|
||||
|
||||
class TrainLoaderConfig(BaseConfig):
    """Settings for the training data loader and its dataset."""

    # Number of loader worker processes.
    num_workers: int = 0

    # Number of examples per training batch.
    batch_size: int = 8

    # Whether the loader shuffles examples.
    shuffle: bool = False

    # Each instance gets its own copy of the default augmentation config.
    augmentations: AugmentationsConfig = Field(
        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
    )

    # NOTE(review): the training default is the padded whole-clip strategy,
    # whereas the earlier training code built a random clipper -- confirm
    # this default is intended.
    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )
|
||||
|
||||
|
||||
class TrainingConfig(BaseConfig):
|
||||
learning_rate: float = 1e-3
|
||||
t_max: int = 100
|
||||
dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
|
||||
|
||||
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
|
||||
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
|
||||
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
augmentations: AugmentationsConfig = Field(
|
||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
|
||||
)
|
||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
|
||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||
|
||||
@ -87,6 +87,7 @@ class ValidationDataset(Dataset):
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
labeller: ClipLabeller,
|
||||
clipper: Optional[ClipperProtocol] = None,
|
||||
audio_dir: Optional[data.PathLike] = None,
|
||||
):
|
||||
self.clip_annotations = clip_annotations
|
||||
@ -94,14 +95,18 @@ class ValidationDataset(Dataset):
|
||||
self.preprocessor = preprocessor
|
||||
self.audio_loader = audio_loader
|
||||
self.audio_dir = audio_dir
|
||||
self.clipper = clipper
|
||||
|
||||
def __len__(self):
|
||||
return len(self.clip_annotations)
|
||||
|
||||
def __getitem__(self, idx) -> TrainExample:
|
||||
clip_annotation = self.clip_annotations[idx]
|
||||
clip = clip_annotation.clip
|
||||
|
||||
if self.clipper is not None:
|
||||
clip_annotation = self.clipper(clip_annotation)
|
||||
|
||||
clip = clip_annotation.clip
|
||||
wav = self.audio_loader.load_clip(
|
||||
clip_annotation.clip,
|
||||
audio_dir=self.audio_dir,
|
||||
|
||||
@ -24,7 +24,11 @@ from batdetect2.train.augmentations import (
|
||||
)
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
|
||||
from batdetect2.train.config import (
|
||||
FullTrainingConfig,
|
||||
TrainLoaderConfig,
|
||||
ValLoaderConfig,
|
||||
)
|
||||
from batdetect2.train.dataset import TrainingDataset, ValidationDataset
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
@ -85,7 +89,7 @@ def train(
|
||||
audio_loader=audio_loader,
|
||||
labeller=labeller,
|
||||
preprocessor=preprocessor,
|
||||
config=config.train,
|
||||
config=config.train.train_loader,
|
||||
num_workers=train_workers,
|
||||
)
|
||||
|
||||
@ -95,7 +99,7 @@ def train(
|
||||
audio_loader=audio_loader,
|
||||
labeller=labeller,
|
||||
preprocessor=preprocessor,
|
||||
config=config.train,
|
||||
config=config.train.val_loader,
|
||||
num_workers=val_workers,
|
||||
)
|
||||
if val_annotations is not None
|
||||
@ -225,10 +229,17 @@ def build_train_loader(
|
||||
audio_loader: AudioLoader,
|
||||
labeller: ClipLabeller,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: Optional[TrainingConfig] = None,
|
||||
config: Optional[TrainLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
) -> DataLoader:
|
||||
config = config or TrainingConfig()
|
||||
config = config or TrainLoaderConfig()
|
||||
|
||||
logger.info("Building training data loader...")
|
||||
logger.opt(lazy=True).debug(
|
||||
"Training data loader config: \n{config}",
|
||||
config=lambda: config.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
|
||||
train_dataset = build_train_dataset(
|
||||
clip_annotations,
|
||||
audio_loader=audio_loader,
|
||||
@ -237,17 +248,11 @@ def build_train_loader(
|
||||
config=config,
|
||||
)
|
||||
|
||||
logger.info("Building training data loader...")
|
||||
loader_conf = config.dataloaders.train
|
||||
logger.opt(lazy=True).debug(
|
||||
"Training data loader config: \n{config}",
|
||||
config=lambda: loader_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
num_workers = num_workers or loader_conf.num_workers
|
||||
num_workers = num_workers or config.num_workers
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=loader_conf.batch_size,
|
||||
shuffle=loader_conf.shuffle,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=config.shuffle,
|
||||
num_workers=num_workers,
|
||||
collate_fn=_collate_fn,
|
||||
)
|
||||
@ -258,11 +263,15 @@ def build_val_loader(
|
||||
audio_loader: AudioLoader,
|
||||
labeller: ClipLabeller,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: Optional[TrainingConfig] = None,
|
||||
config: Optional[ValLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
):
|
||||
logger.info("Building validation data loader...")
|
||||
config = config or TrainingConfig()
|
||||
config = config or ValLoaderConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Validation data loader config: \n{config}",
|
||||
config=lambda: config.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
|
||||
val_dataset = build_val_dataset(
|
||||
clip_annotations,
|
||||
@ -271,56 +280,28 @@ def build_val_loader(
|
||||
preprocessor=preprocessor,
|
||||
config=config,
|
||||
)
|
||||
loader_conf = config.dataloaders.val
|
||||
logger.opt(lazy=True).debug(
|
||||
"Validation data loader config: \n{config}",
|
||||
config=lambda: loader_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
num_workers = num_workers or loader_conf.num_workers
|
||||
|
||||
num_workers = num_workers or config.num_workers
|
||||
return DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
shuffle=loader_conf.shuffle,
|
||||
shuffle=False,
|
||||
num_workers=num_workers,
|
||||
collate_fn=_collate_fn,
|
||||
)
|
||||
|
||||
|
||||
def _collate_fn(batch: List[TrainExample]) -> TrainExample:
    """Combine individual training examples into one batched example.

    The spectrogram and the three heatmaps are passed through
    ``adjust_width`` with the largest last-axis size found among the
    spectrograms in the batch (presumably padding them to a common width --
    confirm against ``adjust_width``) and then stacked. ``idx``,
    ``start_time`` and ``end_time`` are stacked unchanged.
    """
    target_width = max(example.spec.shape[-1] for example in batch)

    def _stack_adjusted(field: str):
        # Width-adjust the named tensor of every example, then stack.
        return torch.stack(
            [
                adjust_width(getattr(example, field), target_width)
                for example in batch
            ]
        )

    def _stack_plain(field: str):
        # Stack the named tensor of every example without modification.
        return torch.stack([getattr(example, field) for example in batch])

    return TrainExample(
        spec=_stack_adjusted("spec"),
        detection_heatmap=_stack_adjusted("detection_heatmap"),
        size_heatmap=_stack_adjusted("size_heatmap"),
        class_heatmap=_stack_adjusted("class_heatmap"),
        idx=_stack_plain("idx"),
        start_time=_stack_plain("start_time"),
        end_time=_stack_plain("end_time"),
    )
|
||||
|
||||
|
||||
def build_train_dataset(
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
audio_loader: AudioLoader,
|
||||
labeller: ClipLabeller,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: Optional[TrainingConfig] = None,
|
||||
config: Optional[TrainLoaderConfig] = None,
|
||||
) -> TrainingDataset:
|
||||
logger.info("Building training dataset...")
|
||||
config = config or TrainingConfig()
|
||||
config = config or TrainLoaderConfig()
|
||||
|
||||
clipper = build_clipper(
|
||||
config=config.cliping,
|
||||
random=True,
|
||||
)
|
||||
clipper = build_clipper(config=config.clipping_strategy)
|
||||
|
||||
random_example_source = RandomAudioSource(
|
||||
clip_annotations,
|
||||
@ -354,14 +335,37 @@ def build_val_dataset(
|
||||
audio_loader: AudioLoader,
|
||||
labeller: ClipLabeller,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: Optional[TrainingConfig] = None,
|
||||
config: Optional[ValLoaderConfig] = None,
|
||||
) -> ValidationDataset:
|
||||
logger.info("Building validation dataset...")
|
||||
config = config or TrainingConfig()
|
||||
config = config or ValLoaderConfig()
|
||||
|
||||
clipper = build_clipper(config.clipping_strategy)
|
||||
return ValidationDataset(
|
||||
clip_annotations,
|
||||
audio_loader=audio_loader,
|
||||
labeller=labeller,
|
||||
preprocessor=preprocessor,
|
||||
clipper=clipper,
|
||||
)
|
||||
|
||||
|
||||
def _collate_fn(batch: List[TrainExample]) -> TrainExample:
    """Combine individual training examples into one batched example.

    The spectrogram and the three heatmaps are passed through
    ``adjust_width`` with the largest last-axis size found among the
    spectrograms in the batch (presumably padding them to a common width --
    confirm against ``adjust_width``) and then stacked. ``idx``,
    ``start_time`` and ``end_time`` are stacked unchanged.
    """
    target_width = max(example.spec.shape[-1] for example in batch)

    def _stack_adjusted(field: str):
        # Width-adjust the named tensor of every example, then stack.
        return torch.stack(
            [
                adjust_width(getattr(example, field), target_width)
                for example in batch
            ]
        )

    def _stack_plain(field: str):
        # Stack the named tensor of every example without modification.
        return torch.stack([getattr(example, field) for example in batch])

    return TrainExample(
        spec=_stack_adjusted("spec"),
        detection_heatmap=_stack_adjusted("detection_heatmap"),
        size_heatmap=_stack_adjusted("size_heatmap"),
        class_heatmap=_stack_adjusted("class_heatmap"),
        idx=_stack_plain("idx"),
        start_time=_stack_plain("start_time"),
        end_time=_stack_plain("end_time"),
    )
|
||||
|
||||
Loading…
Reference in New Issue
Block a user