Mirror of https://github.com/macaodha/batdetect2.git
Synced 2026-01-09 16:59:33 +01:00

Commit d80377981e (parent ad5293e0d0): Fix plotting

@@ -66,7 +66,7 @@ def evaluate(
     predictions = get_raw_predictions(
         outputs,
-        clips=[
+        start_times=[
             clip_annotation.clip for clip_annotation in clip_annotations
         ],
         targets=targets,

@@ -100,7 +100,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(false_positives[:n_examples]):

@@ -112,7 +112,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(false_negatives[:n_examples]):

@@ -124,7 +124,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(cross_triggers[:n_examples]):

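The three `plot_class_examples` hunks above all widen the same `except` clause: plotting an example can now also fail with `FileNotFoundError` (for instance when a clip's audio file is missing from disk), and the loop skips that example instead of aborting the run. A minimal runnable sketch of the pattern; `render`, `plot_all`, and the file names are hypothetical stand-ins, not the batdetect2 API:

    def render(example: str) -> None:
        # Hypothetical stand-in for the real plotting helper; loading the
        # audio behind a match may raise FileNotFoundError if the file moved.
        if example.endswith("missing.wav"):
            raise FileNotFoundError(example)
        print(f"plotted {example}")

    def plot_all(examples, n_examples=5):
        for example in examples[:n_examples]:
            try:
                render(example)
            except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
                continue  # skip the example instead of aborting the whole loop

    plot_all(["a.wav", "missing.wav", "b.wav"])  # plots a.wav and b.wav
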
@@ -191,7 +191,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate

@@ -203,7 +203,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
             threshold=self.detection_threshold,
         )

-        if clips is None:
+        if start_times is None:
             return detections

         width = output.detection_probs.shape[-1]

@@ -211,12 +211,12 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         return [
             map_detection_to_clip(
                 detection,
-                start_time=clip.start_time,
-                end_time=clip.start_time + duration,
+                start_time=start_time,
+                end_time=start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
-            for detection, clip in zip(detections, clips)
+            for detection, start_time in zip(detections, start_times)
         ]

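Taken together, the three `Postprocessor` hunks above replace `clips` with `start_times` in `get_detections`: as the diff shows, the method only ever used each clip's start time, so a plain list of floats is enough to map clip-relative detections to absolute time. A toy sketch of that remapping, assuming a simplified list-based representation (the real `DetectionsTensor` and `map_detection_to_clip` are tensor-based):

    from typing import List, Optional

    def map_detections_to_time(
        detections: List[List[float]],       # per-clip detection times, clip-relative
        start_times: Optional[List[float]],  # clip start times in seconds
    ) -> List[List[float]]:
        # Without start times, detections stay in clip-relative coordinates,
        # mirroring the early return in get_detections.
        if start_times is None:
            return detections
        return [
            [start + t for t in clip]
            for clip, start in zip(detections, start_times)
        ]

    # Two clips starting at 0 s and 5 s:
    print(map_detections_to_time([[0.25], [0.5]], start_times=[0.0, 5.0]))
    # -> [[0.25], [5.5]]
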
@@ -224,28 +224,10 @@ def get_raw_predictions(
     output: ModelOutput,
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
-    clips: Optional[List[data.Clip]] = None,
+    start_times: Optional[List[float]] = None,
 ) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch.
-
-    Processes raw model output through remapping, NMS, detection, data
-    extraction, and geometry recovery via the configured
-    `targets.recover_roi`.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[List[RawPrediction]]
-        List of lists (one inner list per input clip). Each inner list
-        contains `RawPrediction` objects for detections in that clip.
-    """
-    detections = postprocessor.get_detections(output, clips)
+    """Extract intermediate RawPrediction objects for a batch."""
+    detections = postprocessor.get_detections(output, start_times)
     return [
         to_raw_predictions(detection.numpy(), targets=targets)
         for detection in detections

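`get_raw_predictions` keeps its two-stage shape: detections per batch item from the postprocessor, then a per-item conversion to raw predictions, returning one inner list per input. A self-contained toy version under that assumption; every name below is an illustrative stand-in, not the actual batdetect2 function:

    from typing import Callable, List, Optional

    def toy_raw_predictions(
        output: List[List[float]],
        convert: Callable[[List[float]], List[str]],
        start_times: Optional[List[float]] = None,
    ) -> List[List[str]]:
        # Stage 1: map detections to absolute time when start times are given.
        detections = output if start_times is None else [
            [start + t for t in clip] for clip, start in zip(output, start_times)
        ]
        # Stage 2: one inner list of converted predictions per batch item.
        return [convert(clip) for clip in detections]

    print(toy_raw_predictions(
        [[0.1, 0.4], [0.2]],
        convert=lambda clip: [f"det@{t:.1f}s" for t in clip],
        start_times=[0.0, 5.0],
    ))  # [['det@0.1s', 'det@0.4s'], ['det@5.2s']]
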
@@ -254,16 +236,16 @@ def get_raw_predictions(

 def get_sound_event_predictions(
     output: ModelOutput,
-    clips: List[data.Clip],
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
+    clips: List[data.Clip],
     classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
 ) -> List[List[BatDetect2Prediction]]:
     raw_predictions = get_raw_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         [

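`get_sound_event_predictions` still accepts full clips and now derives the start times itself at the call site. A small runnable sketch of that derivation, using a simplified `Clip` stand-in rather than `soundevent.data.Clip`:

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class Clip:
        # Hypothetical stand-in for soundevent.data.Clip (the real one has more fields).
        start_time: float
        end_time: float

    def start_times_from_clips(clips: List[Clip]) -> List[float]:
        # The expression now used at every former `clips=clips` call site.
        return [clip.start_time for clip in clips]

    print(start_times_from_clips([Clip(0.0, 2.0), Clip(5.0, 7.0)]))  # [0.0, 5.0]
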
@@ -312,7 +294,7 @@ def get_predictions(
         output,
         targets=targets,
         postprocessor=postprocessor,
-        clips=clips,
+        start_times=[clip.start_time for clip in clips],
     )
     return [
         convert_raw_predictions_to_clip_prediction(

@@ -126,14 +126,15 @@ class ValidationMetrics(Callback):
         dataset = self.get_dataset(trainer)

         clip_annotations = [
-            dataset.get_clip_annotation(int(example_idx))
+            dataset.clip_annotations[int(example_idx)]
             for example_idx in batch.idx
         ]

         predictions = get_raw_predictions(
             outputs,
-            clips=[
-                clip_annotation.clip for clip_annotation in clip_annotations
+            start_times=[
+                clip_annotation.clip.start_time
+                for clip_annotation in clip_annotations
             ],
             targets=targets,
             postprocessor=postprocessor,

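The callback now indexes `dataset.clip_annotations` directly rather than calling `dataset.get_clip_annotation`, which is deleted further down in this commit. A minimal sketch of the new access pattern; `ToyValidationDataset` is a hypothetical stand-in for the real dataset class:

    from typing import List

    class ToyValidationDataset:
        """Minimal stand-in for ValidationDataset; the real class wraps soundevent data."""
        def __init__(self, clip_annotations: List[str]):
            self.clip_annotations = clip_annotations

    dataset = ToyValidationDataset(["ann_0", "ann_1", "ann_2"])
    batch_idx = [2.0, 0.0]  # indices arrive as tensor values, hence the int(...) casts
    clip_annotations = [dataset.clip_annotations[int(i)] for i in batch_idx]
    print(clip_annotations)  # ['ann_2', 'ann_0']
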
@@ -101,8 +101,15 @@ class ValidationDataset(Dataset):
         return len(self.clip_annotations)

     def __getitem__(self, idx) -> TrainExample:
-        wav, clip_annotation = self.load_audio(idx)
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        clip = clip_annotation.clip
+        wav = torch.tensor(
+            self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        ).unsqueeze(0)

         spectrogram = self.preprocessor(wav)

@@ -117,17 +124,3 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
-
-    def get_clip_annotation(self, idx: int) -> data.ClipAnnotation:
-        clip_annotation = self.clip_annotations[idx]
-
-        if self.clipper is not None:
-            clip_annotation = self.clipper(clip_annotation)
-
-        return clip_annotation
-
-    def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]:
-        clip_annotation = self.get_clip_annotation(idx)
-        clip = clip_annotation.clip
-        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
-        return torch.tensor(wav).unsqueeze(0), clip_annotation

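The deleted `get_clip_annotation` and `load_audio` helpers were inlined into `__getitem__` (see the hunk above at line 101), so clip selection, optional clipping, and audio loading now happen in one place. A condensed runnable sketch of the resulting flow, with stand-ins for the annotations and the audio loader:

    import torch

    class ToyDataset:
        """Sketch of the inlined __getitem__ flow; annotations and loading are stand-ins."""
        def __init__(self, clip_annotations, clipper=None):
            self.clip_annotations = clip_annotations
            self.clipper = clipper

        def __getitem__(self, idx):
            clip_annotation = self.clip_annotations[idx]
            if self.clipper is not None:
                clip_annotation = self.clipper(clip_annotation)
            # Stand-in for audio_loader.load_clip(...): a fake 1-D waveform,
            # batched with unsqueeze(0) as in the real __getitem__.
            wav = torch.zeros(16000).unsqueeze(0)
            return wav, clip_annotation

    wav, ann = ToyDataset(["ann_0"])[0]
    print(wav.shape, ann)  # torch.Size([1, 16000]) ann_0
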
@@ -97,5 +97,5 @@ class PostprocessorProtocol(Protocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]: ...

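Finally, the `PostprocessorProtocol` signature is updated to match, so any conforming postprocessor must accept optional `start_times` rather than clips. A minimal structural example under simplified types (the real `ModelOutput` and `DetectionsTensor` are tensor-based structures, not plain lists):

    from typing import List, Optional, Protocol

    # Simplified stand-ins for the real tensor-based types.
    ModelOutput = List[List[float]]
    DetectionsTensor = List[float]

    class PostprocessorProtocol(Protocol):
        def get_detections(
            self,
            output: ModelOutput,
            start_times: Optional[List[float]] = None,
        ) -> List[DetectionsTensor]: ...

    class IdentityPostprocessor:
        """Minimal structural match for the updated protocol."""
        def get_detections(
            self,
            output: ModelOutput,
            start_times: Optional[List[float]] = None,
        ) -> List[DetectionsTensor]:
            if start_times is None:
                return [list(clip) for clip in output]
            return [[s + t for t in clip] for clip, s in zip(output, start_times)]

    pp: PostprocessorProtocol = IdentityPostprocessor()
    print(pp.get_detections([[0.1], [0.2]], start_times=[0.0, 5.0]))  # [[0.1], [5.2]]
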