From 4fd2e84773f7003cdd5a49aec26347fe4e687bbb Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Thu, 11 Sep 2025 11:09:59 +0100
Subject: [PATCH] Fix clip misalignment in validation dataset

The ValidationMetrics callback looked up annotations with
dataset.clip_annotations[idx], bypassing the clipper that
ValidationDataset.__getitem__ applies before loading audio, so metrics
were computed against annotations that did not match the padded clips
the model actually saw. Expose the clipped annotation through
ValidationDataset.get_clip_annotation and use it in the callback. Also
rename PaddedClip's attribute from duration to chunk_size to match its
config field, and cast the padded duration to a plain float.
---
 src/batdetect2/train/callbacks.py |  2 +-
 src/batdetect2/train/clips.py     | 10 ++++++----
 src/batdetect2/train/dataset.py   | 30 +++++++++++++++++-------------
 3 files changed, 24 insertions(+), 18 deletions(-)

diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py
index 5108469..bb182b7 100644
--- a/src/batdetect2/train/callbacks.py
+++ b/src/batdetect2/train/callbacks.py
@@ -126,7 +126,7 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)
 
         clip_annotations = [
-            dataset.clip_annotations[int(example_idx)]
+            dataset.get_clip_annotation(int(example_idx))
             for example_idx in batch.idx
         ]
 
diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py
index 0ebd203..6333ebb 100644
--- a/src/batdetect2/train/clips.py
+++ b/src/batdetect2/train/clips.py
@@ -158,8 +158,8 @@ class PaddedClipConfig(BaseConfig):
 
 @registry.register(PaddedClipConfig)
 class PaddedClip:
-    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
-        self.duration = duration
+    def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.chunk_size = chunk_size
 
     def __call__(
         self,
@@ -168,7 +168,9 @@ class PaddedClip:
         clip = clip_annotation.clip
         duration = clip.duration
 
-        target_duration = self.duration * np.ceil(duration / self.duration)
+        target_duration = float(
+            self.chunk_size * np.ceil(duration / self.chunk_size)
+        )
         clip = clip.model_copy(
             update=dict(
                 end_time=clip.start_time + target_duration,
@@ -178,7 +180,7 @@ class PaddedClip:
 
     @classmethod
     def from_config(cls, config: PaddedClipConfig):
-        return cls(duration=config.chunk_size)
+        return cls(chunk_size=config.chunk_size)
 
 
 ClipConfig = Annotated[
diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py
index 865c1aa..16e76ad 100644
--- a/src/batdetect2/train/dataset.py
+++ b/src/batdetect2/train/dataset.py
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 
 import torch
 from soundevent import data
@@ -101,20 +101,10 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
+        wav, clip_annotation = self.load_audio(idx)
         clip = clip_annotation.clip
 
-        wav = self.audio_loader.load_clip(
-            clip_annotation.clip,
-            audio_dir=self.audio_dir,
-        )
-        wav_tensor = torch.tensor(wav).unsqueeze(0)
-
-        spectrogram = self.preprocessor(wav_tensor)
+        spectrogram = self.preprocessor(wav)
 
         heatmaps = self.labeller(clip_annotation, spectrogram)
 
@@ -127,3 +117,17 @@
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
+
+    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        return clip_annotation
+
+    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        clip_annotation = self.get_clip_annotation(idx)
+        clip = clip_annotation.clip
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        return torch.tensor(wav).unsqueeze(0), clip_annotation
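
Note on the PaddedClip hunk: besides renaming duration to chunk_size so
the attribute matches PaddedClipConfig's field, the float(...) wrapper
keeps target_duration a plain Python float instead of a numpy scalar
when it is written into the clip via model_copy. A minimal sketch of
the padding rule; the chunk size here is illustrative only, the real
default is DEFAULT_TRAIN_CLIP_DURATION in clips.py:

    import numpy as np

    # Illustrative values only; the real default comes from
    # DEFAULT_TRAIN_CLIP_DURATION in clips.py.
    chunk_size = 0.5
    duration = 1.3  # clip is 2.6 chunks long

    # Pad up to the next whole multiple of chunk_size, as
    # PaddedClip.__call__ now does.
    target_duration = float(chunk_size * np.ceil(duration / chunk_size))
    assert target_duration == 1.5
    assert type(target_duration) is float  # float(), not np.float64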
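
And the reason the callback must call get_clip_annotation rather than
index clip_annotations directly: __getitem__ builds the example's
tensors from the clipped (padded) annotation, so only that annotation
lines up with the example's time bounds. A rough consistency check,
assuming an already-constructed ValidationDataset named dataset with a
PaddedClip clipper (hypothetical, not part of this patch):

    # `dataset` and `idx` are assumed for illustration.
    idx = 0

    raw = dataset.clip_annotations[idx]         # before the clipper runs
    clipped = dataset.get_clip_annotation(idx)  # same path __getitem__ uses
    example = dataset[idx]

    # The example's bounds come from the padded clip, so they agree with
    # the clipped annotation, not necessarily with the raw one.
    assert float(example.start_time) == clipped.clip.start_time
    assert float(example.end_time) == clipped.clip.end_time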