Mirror of https://github.com/macaodha/batdetect2.git, synced 2026-01-10 00:59:34 +01:00
Fix clip misalignment in validation dataset
This commit is contained in:
parent 74c419f674
commit 4fd2e84773
@@ -126,7 +126,7 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)
 
         clip_annotations = [
-            dataset.clip_annotations[int(example_idx)]
+            dataset.get_clip_annotation(int(example_idx))
             for example_idx in batch.idx
         ]
 
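Note (not part of the diff): a minimal sketch of why this one-line change matters, using only names that appear in the hunks below; the callback internals beyond this line are assumed.

# Indexing clip_annotations directly returns the annotation for the raw, un-clipped recording,
# while get_clip_annotation(idx) (added further down) applies the dataset's clipper first.
raw = dataset.clip_annotations[int(example_idx)]          # clip as stored, no padding applied
adjusted = dataset.get_clip_annotation(int(example_idx))  # clip after the clipper, as seen by the model
# Evaluating against `adjusted` keeps ground truth aligned with the audio that was actually loaded.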
@@ -158,8 +158,8 @@ class PaddedClipConfig(BaseConfig):
 
 @registry.register(PaddedClipConfig)
 class PaddedClip:
-    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
-        self.duration = duration
+    def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.chunk_size = chunk_size
 
     def __call__(
         self,
@@ -168,7 +168,9 @@ class PaddedClip:
         clip = clip_annotation.clip
         duration = clip.duration
 
-        target_duration = self.duration * np.ceil(duration / self.duration)
+        target_duration = float(
+            self.chunk_size * np.ceil(duration / self.chunk_size)
+        )
         clip = clip.model_copy(
             update=dict(
                 end_time=clip.start_time + target_duration,
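Note (not part of the diff): a sketch of the padding arithmetic this hunk introduces, with illustrative numbers; the actual value of DEFAULT_TRAIN_CLIP_DURATION is assumed.

import numpy as np

chunk_size = 1.5   # assumed chunk duration in seconds
duration = 3.2     # example clip duration in seconds
# Round the clip duration up to the nearest whole multiple of chunk_size.
target_duration = float(chunk_size * np.ceil(duration / chunk_size))
print(target_duration)  # 4.5 -> the clip's end_time is extended by 1.3 s of padding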
@@ -178,7 +180,7 @@ class PaddedClip:
 
     @classmethod
     def from_config(cls, config: PaddedClipConfig):
-        return cls(duration=config.chunk_size)
+        return cls(chunk_size=config.chunk_size)
 
 
 ClipConfig = Annotated[
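Note (not part of the diff): a hypothetical usage sketch of the factory after the rename, assuming PaddedClipConfig exposes a chunk_size field (its definition is not shown here).

config = PaddedClipConfig(chunk_size=1.5)     # assumed field, per config.chunk_size above
clipper = PaddedClip.from_config(config)      # now forwards chunk_size to the renamed keyword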
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 
 import torch
 from soundevent import data
@@ -101,20 +101,10 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
+        wav, clip_annotation = self.load_audio(idx)
+
         clip = clip_annotation.clip
-        wav = self.audio_loader.load_clip(
-            clip_annotation.clip,
-            audio_dir=self.audio_dir,
-        )
-
-        wav_tensor = torch.tensor(wav).unsqueeze(0)
-
-        spectrogram = self.preprocessor(wav_tensor)
+        spectrogram = self.preprocessor(wav)
 
         heatmaps = self.labeller(clip_annotation, spectrogram)
 
@@ -127,3 +117,17 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
+
+    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        return clip_annotation
+
+    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        clip_annotation = self.get_clip_annotation(idx)
+        clip = clip_annotation.clip
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        return torch.tensor(wav).unsqueeze(0), clip_annotation
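Note (not part of the diff): a sketch of how the two new helpers fit together, mirroring the added code; the surrounding training loop is assumed.

wav, clip_annotation = dataset.load_audio(idx)   # channel-first waveform tensor plus the
                                                 # clipper-adjusted annotation (unsqueeze(0) adds the channel dim)
annotation = dataset.get_clip_annotation(idx)    # same adjusted annotation; this is what the
                                                 # ValidationMetrics callback now calls, so the
                                                 # evaluated clips match the clips the model saw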