Mirror of https://github.com/macaodha/batdetect2.git
Fix clip misalignment in validation dataset
commit 4fd2e84773
parent 74c419f674
@@ -126,7 +126,7 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)
         clip_annotations = [
-            dataset.clip_annotations[int(example_idx)]
+            dataset.get_clip_annotation(int(example_idx))
             for example_idx in batch.idx
         ]
 
 
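The point of this hunk: the callback previously indexed the dataset's raw annotation list, while the model's outputs were produced from the clipped (padded) version of each clip, so the two could disagree on clip bounds. A minimal illustration of the mismatch, with hypothetical values (the index 0 and the 7.3 s / 10.0 s durations are made up for this example):

    raw = dataset.clip_annotations[0]         # raw annotation, e.g. a 7.3 s clip
    clipped = dataset.get_clip_annotation(0)  # same annotation after the clipper
                                              # pads it, e.g. to 10.0 s
    # Metrics computed against `raw` were misaligned with predictions made on
    # the padded clip; routing the callback through get_clip_annotation keeps
    # both sides consistent.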
@@ -158,8 +158,8 @@ class PaddedClipConfig(BaseConfig):
 
 @registry.register(PaddedClipConfig)
 class PaddedClip:
-    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
-        self.duration = duration
+    def __init__(self, chunk_size: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.chunk_size = chunk_size
 
     def __call__(
         self,
@@ -168,7 +168,9 @@ class PaddedClip:
         clip = clip_annotation.clip
         duration = clip.duration
 
-        target_duration = self.duration * np.ceil(duration / self.duration)
+        target_duration = float(
+            self.chunk_size * np.ceil(duration / self.chunk_size)
+        )
         clip = clip.model_copy(
             update=dict(
                 end_time=clip.start_time + target_duration,
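The arithmetic here rounds each clip's duration up to a whole number of training chunks; wrapping the result in float() presumably keeps a plain Python float (rather than a numpy scalar) out of the pydantic model_copy update. A standalone sketch of the same computation (pad_to_multiple is a hypothetical helper, not part of the codebase):

    import numpy as np

    def pad_to_multiple(duration: float, chunk_size: float) -> float:
        # Round up to the nearest whole number of chunks,
        # e.g. duration=7.3 with chunk_size=5.0 -> 10.0.
        return float(chunk_size * np.ceil(duration / chunk_size))

    assert pad_to_multiple(7.3, 5.0) == 10.0
    assert pad_to_multiple(10.0, 5.0) == 10.0  # exact multiples are unchanged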
@@ -178,7 +180,7 @@ class PaddedClip:
 
     @classmethod
     def from_config(cls, config: PaddedClipConfig):
-        return cls(duration=config.chunk_size)
+        return cls(chunk_size=config.chunk_size)
 
 
 ClipConfig = Annotated[
@@ -1,4 +1,4 @@
-from typing import Optional, Sequence
+from typing import Optional, Sequence, Tuple
 
 import torch
 from soundevent import data
@@ -101,20 +101,10 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)
 
     def __getitem__(self, idx) -> TrainExample:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
+        wav, clip_annotation = self.load_audio(idx)
 
         clip = clip_annotation.clip
-        wav = self.audio_loader.load_clip(
-            clip_annotation.clip,
-            audio_dir=self.audio_dir,
-        )
-
-        wav_tensor = torch.tensor(wav).unsqueeze(0)
-
-        spectrogram = self.preprocessor(wav_tensor)
+        spectrogram = self.preprocessor(wav)
 
         heatmaps = self.labeller(clip_annotation, spectrogram)
 
@@ -127,3 +117,17 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
+
+    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        return clip_annotation
+
+    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
+        clip_annotation = self.get_clip_annotation(idx)
+        clip = clip_annotation.clip
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        return torch.tensor(wav).unsqueeze(0), clip_annotation
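Taken together, the two new methods give a single code path for clipping, so __getitem__ and the ValidationMetrics callback can no longer drift apart. A rough usage sketch, assuming a ValidationDataset instance has already been constructed (arguments omitted):

    # Validation-step path: audio and annotation come from one call.
    wav, clip_annotation = dataset.load_audio(0)

    # Callback path: the same clipped annotation, instead of re-indexing
    # the raw clip_annotations list.
    same = dataset.get_clip_annotation(0)
    assert same.clip == clip_annotation.clip  # identical (padded) clip bounds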