Fix clip misalignment in validation dataset

mbsantiago 2025-09-11 11:09:59 +01:00
parent 74c419f674
commit 4fd2e84773
3 changed files with 24 additions and 18 deletions

View File

@@ -126,7 +126,7 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)
         clip_annotations = [
-            dataset.clip_annotations[int(example_idx)]
+            dataset.get_clip_annotation(int(example_idx))
             for example_idx in batch.idx
         ]
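
For context, a minimal sketch of what this hunk fixes (the comments are an editor's reading of the diff, not the project's own documentation): the callback used to index the dataset's raw annotation list, so it never saw the clipper's padding.

    # Old behaviour: index the raw list, bypassing the dataset's clipper,
    # so metrics were scored against un-padded clip bounds.
    clip_annotations = [
        dataset.clip_annotations[int(example_idx)]
        for example_idx in batch.idx
    ]

    # New behaviour: go through the dataset, so the same clipper that shaped
    # the validation example (e.g. PaddedClip below) is applied here too.
    clip_annotations = [
        dataset.get_clip_annotation(int(example_idx))
        for example_idx in batch.idx
    ]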

View File

@@ -158,8 +158,8 @@ class PaddedClipConfig(BaseConfig):
 @registry.register(PaddedClipConfig)
 class PaddedClip:
-    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
-        self.duration = duration
+    def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.chunk_size = chunk_size
 
     def __call__(
         self,
@@ -168,7 +168,9 @@ class PaddedClip:
         clip = clip_annotation.clip
         duration = clip.duration
-        target_duration = self.duration * np.ceil(duration / self.duration)
+        target_duration = float(
+            self.chunk_size * np.ceil(duration / self.chunk_size)
+        )
         clip = clip.model_copy(
             update=dict(
                 end_time=clip.start_time + target_duration,
@@ -178,7 +180,7 @@ class PaddedClip:
     @classmethod
     def from_config(cls, config: PaddedClipConfig):
-        return cls(duration=config.chunk_size)
+        return cls(chunk_size=config.chunk_size)
 
 ClipConfig = Annotated[
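
To make the padding arithmetic concrete, a small sketch (the chunk_size value below is illustrative; DEFAULT_TRAIN_CLIP_DURATION is not shown in this diff). The new float(...) wrapper also matters: np.ceil returns a NumPy scalar, and wrapping converts target_duration back to a plain Python float before it is written into the clip via model_copy.

    import numpy as np

    chunk_size = 3.0  # illustrative; the real default is DEFAULT_TRAIN_CLIP_DURATION

    for duration in (2.1, 3.0, 7.2):
        target_duration = float(chunk_size * np.ceil(duration / chunk_size))
        print(duration, "->", target_duration)

    # 2.1 -> 3.0, 3.0 -> 3.0, 7.2 -> 9.0: every clip is padded up to the
    # next whole multiple of chunk_size, never truncated.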

View File

@@ -1,4 +1,4 @@
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 
 import torch
 from soundevent import data
@@ -101,20 +101,10 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
+        wav, clip_annotation = self.load_audio(idx)
         clip = clip_annotation.clip
 
-        wav = self.audio_loader.load_clip(
-            clip_annotation.clip,
-            audio_dir=self.audio_dir,
-        )
-
-        wav_tensor = torch.tensor(wav).unsqueeze(0)
-        spectrogram = self.preprocessor(wav_tensor)
+        spectrogram = self.preprocessor(wav)
         heatmaps = self.labeller(clip_annotation, spectrogram)
 
@@ -127,3 +117,17 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
+
+    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        return clip_annotation
+
+    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        clip_annotation = self.get_clip_annotation(idx)
+        clip = clip_annotation.clip
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        return torch.tensor(wav).unsqueeze(0), clip_annotation
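
Putting the three changes together, a self-contained sketch of the misalignment this commit fixes. Everything below is a simplified stand-in (Clip, PaddedClip, and ValidationDataset here are not the project's real classes), assuming a clipper that pads each clip to a whole number of chunks:

    import math
    from dataclasses import dataclass, replace


    @dataclass(frozen=True)
    class Clip:
        start_time: float
        end_time: float

        @property
        def duration(self) -> float:
            return self.end_time - self.start_time


    class PaddedClip:
        """Pads a clip to the next whole multiple of chunk_size seconds."""

        def __init__(self, chunk_size: float = 3.0):  # illustrative default
            self.chunk_size = chunk_size

        def __call__(self, clip: Clip) -> Clip:
            target = self.chunk_size * math.ceil(clip.duration / self.chunk_size)
            return replace(clip, end_time=clip.start_time + target)


    class ValidationDataset:
        def __init__(self, clip_annotations, clipper=None):
            self.clip_annotations = clip_annotations
            self.clipper = clipper

        def get_clip_annotation(self, idx: int) -> Clip:
            # Single source of truth: both __getitem__ and the metrics
            # callback now fetch annotations through this method, so the
            # clipper is applied consistently.
            clip = self.clip_annotations[idx]
            if self.clipper is not None:
                clip = self.clipper(clip)
            return clip


    dataset = ValidationDataset([Clip(0.0, 7.2)], clipper=PaddedClip(3.0))

    # Before the fix: the callback read dataset.clip_annotations[idx]
    # directly and scored against a 7.2 s clip, while the example itself was
    # built from the padded 9.0 s clip, hence the misalignment.
    assert dataset.clip_annotations[0].duration == 7.2
    assert dataset.get_clip_annotation(0).duration == 9.0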