Add detection threshold option to API and CLI

This commit is contained in:
mbsantiago 2026-03-28 14:26:10 +00:00
parent dfe19c68f0
commit b253a54cc8
4 changed files with 45 additions and 3 deletions

View File

@ -303,6 +303,7 @@ class BatDetect2API:
self, self,
audio_file: data.PathLike, audio_file: data.PathLike,
batch_size: int | None = None, batch_size: int | None = None,
detection_threshold: float | None = None,
) -> ClipDetections: ) -> ClipDetections:
recording = data.Recording.from_file(audio_file, compute_hash=False) recording = data.Recording.from_file(audio_file, compute_hash=False)
@ -313,6 +314,7 @@ class BatDetect2API:
if batch_size is not None if batch_size is not None
else self.inference_config.loader.batch_size else self.inference_config.loader.batch_size
), ),
detection_threshold=detection_threshold,
) )
detections = [ detections = [
detection detection
@ -333,14 +335,19 @@ class BatDetect2API:
def process_audio( def process_audio(
self, self,
audio: np.ndarray, audio: np.ndarray,
detection_threshold: float | None = None,
) -> list[Detection]: ) -> list[Detection]:
spec = self.generate_spectrogram(audio) spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec) return self.process_spectrogram(
spec,
detection_threshold=detection_threshold,
)
def process_spectrogram( def process_spectrogram(
self, self,
spec: torch.Tensor, spec: torch.Tensor,
start_time: float = 0, start_time: float = 0,
detection_threshold: float | None = None,
) -> list[Detection]: ) -> list[Detection]:
if spec.ndim == 4 and spec.shape[0] > 1: if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.") raise ValueError("Batched spectrograms not supported.")
@ -352,6 +359,7 @@ class BatDetect2API:
detections = self.postprocessor( detections = self.postprocessor(
outputs, outputs,
detection_threshold=detection_threshold,
)[0] )[0]
return self.output_transform.to_detections( return self.output_transform.to_detections(
detections=detections, detections=detections,
@ -361,9 +369,13 @@ class BatDetect2API:
def process_directory( def process_directory(
self, self,
audio_dir: data.PathLike, audio_dir: data.PathLike,
detection_threshold: float | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
files = list(get_audio_files(audio_dir)) files = list(get_audio_files(audio_dir))
return self.process_files(files) return self.process_files(
files,
detection_threshold=detection_threshold,
)
def process_files( def process_files(
self, self,
@ -373,6 +385,7 @@ class BatDetect2API:
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None, inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None, output_config: OutputsConfig | None = None,
detection_threshold: float | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
return process_file_list( return process_file_list(
self.model, self.model,
@ -386,6 +399,7 @@ class BatDetect2API:
audio_config=audio_config or self.audio_config, audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config, inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config, output_config=output_config or self.outputs_config,
detection_threshold=detection_threshold,
) )
def process_clips( def process_clips(
@ -396,6 +410,7 @@ class BatDetect2API:
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
inference_config: InferenceConfig | None = None, inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None, output_config: OutputsConfig | None = None,
detection_threshold: float | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
return run_batch_inference( return run_batch_inference(
self.model, self.model,
@ -409,6 +424,7 @@ class BatDetect2API:
audio_config=audio_config or self.audio_config, audio_config=audio_config or self.audio_config,
inference_config=inference_config or self.inference_config, inference_config=inference_config or self.inference_config,
output_config=output_config or self.outputs_config, output_config=output_config or self.outputs_config,
detection_threshold=detection_threshold,
) )
def save_predictions( def save_predictions(

View File

@ -87,6 +87,15 @@ def common_predict_options(func):
"the default output format is used." "the default output format is used."
), ),
) )
@click.option(
"--detection-threshold",
type=click.FloatRange(min=0.0, max=1.0),
default=None,
help=(
"Optional detection score threshold override. If omitted, "
"the model default threshold is used."
),
)
@wraps(func) @wraps(func)
def wrapped(*args, **kwargs): def wrapped(*args, **kwargs):
return func(*args, **kwargs) return func(*args, **kwargs)
@ -147,6 +156,7 @@ def _run_prediction(
batch_size: int | None, batch_size: int | None,
num_workers: int, num_workers: int,
format_name: str | None, format_name: str | None,
detection_threshold: float | None,
) -> None: ) -> None:
logger.info("Initiating prediction process...") logger.info("Initiating prediction process...")
@ -167,6 +177,7 @@ def _run_prediction(
audio_config=audio_conf, audio_config=audio_conf,
inference_config=inference_conf, inference_config=inference_conf,
output_config=outputs_conf, output_config=outputs_conf,
detection_threshold=detection_threshold,
) )
common_path = audio_files[0].parent if audio_files else None common_path = audio_files[0].parent if audio_files else None
@ -201,6 +212,7 @@ def predict_directory_command(
batch_size: int | None, batch_size: int | None,
num_workers: int, num_workers: int,
format_name: str | None, format_name: str | None,
detection_threshold: float | None,
) -> None: ) -> None:
"""Predict on all audio files in a directory. """Predict on all audio files in a directory.
@ -219,6 +231,7 @@ def predict_directory_command(
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
format_name=format_name, format_name=format_name,
detection_threshold=detection_threshold,
) )
@ -241,6 +254,7 @@ def predict_file_list_command(
batch_size: int | None, batch_size: int | None,
num_workers: int, num_workers: int,
format_name: str | None, format_name: str | None,
detection_threshold: float | None,
) -> None: ) -> None:
"""Predict on audio files listed in a text file. """Predict on audio files listed in a text file.
@ -265,6 +279,7 @@ def predict_file_list_command(
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
format_name=format_name, format_name=format_name,
detection_threshold=detection_threshold,
) )
@ -287,6 +302,7 @@ def predict_dataset_command(
batch_size: int | None, batch_size: int | None,
num_workers: int, num_workers: int,
format_name: str | None, format_name: str | None,
detection_threshold: float | None,
) -> None: ) -> None:
"""Predict on recordings referenced in an annotation dataset. """Predict on recordings referenced in an annotation dataset.
@ -313,4 +329,5 @@ def predict_dataset_command(
batch_size=batch_size, batch_size=batch_size,
num_workers=num_workers, num_workers=num_workers,
format_name=format_name, format_name=format_name,
detection_threshold=detection_threshold,
) )

View File

@ -31,6 +31,7 @@ def run_batch_inference(
output_transform: OutputTransformProtocol | None = None, output_transform: OutputTransformProtocol | None = None,
output_config: OutputsConfig | None = None, output_config: OutputsConfig | None = None,
inference_config: InferenceConfig | None = None, inference_config: InferenceConfig | None = None,
detection_threshold: float | None = None,
num_workers: int = 1, num_workers: int = 1,
batch_size: int | None = None, batch_size: int | None = None,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
@ -62,6 +63,7 @@ def run_batch_inference(
module = InferenceModule( module = InferenceModule(
model, model,
output_transform=output_transform, output_transform=output_transform,
detection_threshold=detection_threshold,
) )
trainer = Trainer(enable_checkpointing=False, logger=False) trainer = Trainer(enable_checkpointing=False, logger=False)
outputs = trainer.predict(module, loader) outputs = trainer.predict(module, loader)
@ -82,6 +84,7 @@ def process_file_list(
inference_config: InferenceConfig | None = None, inference_config: InferenceConfig | None = None,
output_config: OutputsConfig | None = None, output_config: OutputsConfig | None = None,
output_transform: OutputTransformProtocol | None = None, output_transform: OutputTransformProtocol | None = None,
detection_threshold: float | None = None,
batch_size: int | None = None, batch_size: int | None = None,
num_workers: int = 0, num_workers: int = 0,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
@ -106,4 +109,5 @@ def process_file_list(
audio_config=audio_config, audio_config=audio_config,
output_transform=output_transform, output_transform=output_transform,
inference_config=inference_config, inference_config=inference_config,
detection_threshold=detection_threshold,
) )

View File

@ -14,9 +14,11 @@ class InferenceModule(LightningModule):
self, self,
model: Model, model: Model,
output_transform: OutputTransformProtocol | None = None, output_transform: OutputTransformProtocol | None = None,
detection_threshold: float | None = None,
): ):
super().__init__() super().__init__()
self.model = model self.model = model
self.detection_threshold = detection_threshold
self.output_transform = output_transform or build_output_transform( self.output_transform = output_transform or build_output_transform(
targets=model.targets targets=model.targets
) )
@ -33,7 +35,10 @@ class InferenceModule(LightningModule):
outputs = self.model.detector(batch.spec) outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(outputs) clip_detections = self.model.postprocessor(
outputs,
detection_threshold=self.detection_threshold,
)
return [ return [
self.output_transform.to_clip_detections( self.output_transform.to_clip_detections(