diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py
index a4cef4d..75c2c93 100644
--- a/src/batdetect2/evaluate/evaluate.py
+++ b/src/batdetect2/evaluate/evaluate.py
@@ -66,7 +66,8 @@ def evaluate(

     predictions = get_raw_predictions(
         outputs,
-        clips=[
-            clip_annotation.clip for clip_annotation in clip_annotations
+        start_times=[
+            clip_annotation.clip.start_time
+            for clip_annotation in clip_annotations
         ],
         targets=targets,
diff --git a/src/batdetect2/plotting/evaluation.py b/src/batdetect2/plotting/evaluation.py
index 3e01b08..6e07f1e 100644
--- a/src/batdetect2/plotting/evaluation.py
+++ b/src/batdetect2/plotting/evaluation.py
@@ -100,7 +100,7 @@ def plot_class_examples(
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(false_positives[:n_examples]):
@@ -112,7 +112,7 @@
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(false_negatives[:n_examples]):
@@ -124,7 +124,7 @@
                 preprocessor=preprocessor,
                 duration=duration,
             )
-        except (ValueError, AssertionError, RuntimeError):
+        except (ValueError, AssertionError, RuntimeError, FileNotFoundError):
             continue

     for index, match in enumerate(cross_triggers[:n_examples]):
diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py
index fec9c91..e1be16b 100644
--- a/src/batdetect2/postprocess/__init__.py
+++ b/src/batdetect2/postprocess/__init__.py
@@ -191,7 +191,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
     def get_detections(
         self,
         output: ModelOutput,
-        clips: Optional[List[data.Clip]] = None,
+        start_times: Optional[List[float]] = None,
     ) -> List[DetectionsTensor]:
         width = output.detection_probs.shape[-1]
         duration = width / self.samplerate
@@ -203,7 +203,7 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
             threshold=self.detection_threshold,
         )

-        if clips is None:
+        if start_times is None:
             return detections

         width = output.detection_probs.shape[-1]
@@ -211,12 +211,12 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         duration = width / self.samplerate

         return [
             map_detection_to_clip(
                 detection,
-                start_time=clip.start_time,
-                end_time=clip.start_time + duration,
+                start_time=start_time,
+                end_time=start_time + duration,
                 min_freq=self.min_freq,
                 max_freq=self.max_freq,
             )
-            for detection, clip in zip(detections, clips)
+            for detection, start_time in zip(detections, start_times)
         ]
@@ -224,28 +224,10 @@ def get_raw_predictions(
     output: ModelOutput,
     targets: TargetProtocol,
     postprocessor: PostprocessorProtocol,
-    clips: Optional[List[data.Clip]] = None,
+    start_times: Optional[List[float]] = None,
 ) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch.
-
-    Processes raw model output through remapping, NMS, detection, data
-    extraction, and geometry recovery via the configured
-    `targets.recover_roi`.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[List[RawPrediction]]
-        List of lists (one inner list per input clip). Each inner list
-        contains `RawPrediction` objects for detections in that clip.
- """ - detections = postprocessor.get_detections(output, clips) + """Extract intermediate RawPrediction objects for a batch.""" + detections = postprocessor.get_detections(output, start_times) return [ to_raw_predictions(detection.numpy(), targets=targets) for detection in detections @@ -254,16 +236,16 @@ def get_raw_predictions( def get_sound_event_predictions( output: ModelOutput, - clips: List[data.Clip], targets: TargetProtocol, postprocessor: PostprocessorProtocol, + clips: List[data.Clip], classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, ) -> List[List[BatDetect2Prediction]]: raw_predictions = get_raw_predictions( output, targets=targets, postprocessor=postprocessor, - clips=clips, + start_times=[clip.start_time for clip in clips], ) return [ [ @@ -312,7 +294,7 @@ def get_predictions( output, targets=targets, postprocessor=postprocessor, - clips=clips, + start_times=[clip.start_time for clip in clips], ) return [ convert_raw_predictions_to_clip_prediction( diff --git a/src/batdetect2/train/callbacks.py b/src/batdetect2/train/callbacks.py index bb182b7..c45bf23 100644 --- a/src/batdetect2/train/callbacks.py +++ b/src/batdetect2/train/callbacks.py @@ -126,14 +126,15 @@ class ValidationMetrics(Callback): dataset = self.get_dataset(trainer) clip_annotations = [ - dataset.get_clip_annotation(int(example_idx)) + dataset.clip_annotations[int(example_idx)] for example_idx in batch.idx ] predictions = get_raw_predictions( outputs, - clips=[ - clip_annotation.clip for clip_annotation in clip_annotations + start_times=[ + clip_annotation.clip.start_time + for clip_annotation in clip_annotations ], targets=targets, postprocessor=postprocessor, diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 16e76ad..e72f75d 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -101,8 +101,15 @@ class ValidationDataset(Dataset): return len(self.clip_annotations) def __getitem__(self, idx) -> TrainExample: - wav, clip_annotation = self.load_audio(idx) + clip_annotation = self.clip_annotations[idx] + + if self.clipper is not None: + clip_annotation = self.clipper(clip_annotation) + clip = clip_annotation.clip + wav = torch.tensor( + self.audio_loader.load_clip(clip, audio_dir=self.audio_dir) + ).unsqueeze(0) spectrogram = self.preprocessor(wav) @@ -117,17 +124,3 @@ class ValidationDataset(Dataset): start_time=torch.tensor(clip.start_time), end_time=torch.tensor(clip.end_time), ) - - def get_clip_annotation(self, idx: int) -> data.ClipAnnotation: - clip_annotation = self.clip_annotations[idx] - - if self.clipper is not None: - clip_annotation = self.clipper(clip_annotation) - - return clip_annotation - - def load_audio(self, idx: int) -> Tuple[torch.Tensor, data.ClipAnnotation]: - clip_annotation = self.get_clip_annotation(idx) - clip = clip_annotation.clip - wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir) - return torch.tensor(wav).unsqueeze(0), clip_annotation diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py index 0654853..5f2addf 100644 --- a/src/batdetect2/typing/postprocess.py +++ b/src/batdetect2/typing/postprocess.py @@ -97,5 +97,5 @@ class PostprocessorProtocol(Protocol): def get_detections( self, output: ModelOutput, - clips: Optional[List[data.Clip]] = None, + start_times: Optional[List[float]] = None, ) -> List[DetectionsTensor]: ...