diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 9f013b8..d7a8733 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -303,6 +303,7 @@ class BatDetect2API: self, audio_file: data.PathLike, batch_size: int | None = None, + detection_threshold: float | None = None, ) -> ClipDetections: recording = data.Recording.from_file(audio_file, compute_hash=False) @@ -313,6 +314,7 @@ class BatDetect2API: if batch_size is not None else self.inference_config.loader.batch_size ), + detection_threshold=detection_threshold, ) detections = [ detection @@ -333,14 +335,19 @@ class BatDetect2API: def process_audio( self, audio: np.ndarray, + detection_threshold: float | None = None, ) -> list[Detection]: spec = self.generate_spectrogram(audio) - return self.process_spectrogram(spec) + return self.process_spectrogram( + spec, + detection_threshold=detection_threshold, + ) def process_spectrogram( self, spec: torch.Tensor, start_time: float = 0, + detection_threshold: float | None = None, ) -> list[Detection]: if spec.ndim == 4 and spec.shape[0] > 1: raise ValueError("Batched spectrograms not supported.") @@ -352,6 +359,7 @@ class BatDetect2API: detections = self.postprocessor( outputs, + detection_threshold=detection_threshold, )[0] return self.output_transform.to_detections( detections=detections, @@ -361,9 +369,13 @@ class BatDetect2API: def process_directory( self, audio_dir: data.PathLike, + detection_threshold: float | None = None, ) -> list[ClipDetections]: files = list(get_audio_files(audio_dir)) - return self.process_files(files) + return self.process_files( + files, + detection_threshold=detection_threshold, + ) def process_files( self, @@ -373,6 +385,7 @@ class BatDetect2API: audio_config: AudioConfig | None = None, inference_config: InferenceConfig | None = None, output_config: OutputsConfig | None = None, + detection_threshold: float | None = None, ) -> list[ClipDetections]: return process_file_list( self.model, @@ -386,6 +399,7 @@ class BatDetect2API: audio_config=audio_config or self.audio_config, inference_config=inference_config or self.inference_config, output_config=output_config or self.outputs_config, + detection_threshold=detection_threshold, ) def process_clips( @@ -396,6 +410,7 @@ class BatDetect2API: audio_config: AudioConfig | None = None, inference_config: InferenceConfig | None = None, output_config: OutputsConfig | None = None, + detection_threshold: float | None = None, ) -> list[ClipDetections]: return run_batch_inference( self.model, @@ -409,6 +424,7 @@ class BatDetect2API: audio_config=audio_config or self.audio_config, inference_config=inference_config or self.inference_config, output_config=output_config or self.outputs_config, + detection_threshold=detection_threshold, ) def save_predictions( diff --git a/src/batdetect2/cli/inference.py b/src/batdetect2/cli/inference.py index dac9ab5..50d50f0 100644 --- a/src/batdetect2/cli/inference.py +++ b/src/batdetect2/cli/inference.py @@ -87,6 +87,15 @@ def common_predict_options(func): "the default output format is used." ), ) + @click.option( + "--detection-threshold", + type=click.FloatRange(min=0.0, max=1.0), + default=None, + help=( + "Optional detection score threshold override. If omitted, " + "the model default threshold is used." + ), + ) @wraps(func) def wrapped(*args, **kwargs): return func(*args, **kwargs) @@ -147,6 +156,7 @@ def _run_prediction( batch_size: int | None, num_workers: int, format_name: str | None, + detection_threshold: float | None, ) -> None: logger.info("Initiating prediction process...") @@ -167,6 +177,7 @@ def _run_prediction( audio_config=audio_conf, inference_config=inference_conf, output_config=outputs_conf, + detection_threshold=detection_threshold, ) common_path = audio_files[0].parent if audio_files else None @@ -201,6 +212,7 @@ def predict_directory_command( batch_size: int | None, num_workers: int, format_name: str | None, + detection_threshold: float | None, ) -> None: """Predict on all audio files in a directory. @@ -219,6 +231,7 @@ def predict_directory_command( batch_size=batch_size, num_workers=num_workers, format_name=format_name, + detection_threshold=detection_threshold, ) @@ -241,6 +254,7 @@ def predict_file_list_command( batch_size: int | None, num_workers: int, format_name: str | None, + detection_threshold: float | None, ) -> None: """Predict on audio files listed in a text file. @@ -265,6 +279,7 @@ def predict_file_list_command( batch_size=batch_size, num_workers=num_workers, format_name=format_name, + detection_threshold=detection_threshold, ) @@ -287,6 +302,7 @@ def predict_dataset_command( batch_size: int | None, num_workers: int, format_name: str | None, + detection_threshold: float | None, ) -> None: """Predict on recordings referenced in an annotation dataset. @@ -313,4 +329,5 @@ def predict_dataset_command( batch_size=batch_size, num_workers=num_workers, format_name=format_name, + detection_threshold=detection_threshold, ) diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index 119b449..7ccf71d 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -31,6 +31,7 @@ def run_batch_inference( output_transform: OutputTransformProtocol | None = None, output_config: OutputsConfig | None = None, inference_config: InferenceConfig | None = None, + detection_threshold: float | None = None, num_workers: int = 1, batch_size: int | None = None, ) -> list[ClipDetections]: @@ -62,6 +63,7 @@ def run_batch_inference( module = InferenceModule( model, output_transform=output_transform, + detection_threshold=detection_threshold, ) trainer = Trainer(enable_checkpointing=False, logger=False) outputs = trainer.predict(module, loader) @@ -82,6 +84,7 @@ def process_file_list( inference_config: InferenceConfig | None = None, output_config: OutputsConfig | None = None, output_transform: OutputTransformProtocol | None = None, + detection_threshold: float | None = None, batch_size: int | None = None, num_workers: int = 0, ) -> list[ClipDetections]: @@ -106,4 +109,5 @@ def process_file_list( audio_config=audio_config, output_transform=output_transform, inference_config=inference_config, + detection_threshold=detection_threshold, ) diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py index 2c853a5..7e4b058 100644 --- a/src/batdetect2/inference/lightning.py +++ b/src/batdetect2/inference/lightning.py @@ -14,9 +14,11 @@ class InferenceModule(LightningModule): self, model: Model, output_transform: OutputTransformProtocol | None = None, + detection_threshold: float | None = None, ): super().__init__() self.model = model + self.detection_threshold = detection_threshold self.output_transform = output_transform or build_output_transform( targets=model.targets ) @@ -33,7 +35,10 @@ class InferenceModule(LightningModule): outputs = self.model.detector(batch.spec) - clip_detections = self.model.postprocessor(outputs) + clip_detections = self.model.postprocessor( + outputs, + detection_threshold=self.detection_threshold, + ) return [ self.output_transform.to_clip_detections(