mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add detection threshold override to API and CLI
This commit is contained in:
parent
dfe19c68f0
commit
b253a54cc8
@ -303,6 +303,7 @@ class BatDetect2API:
|
||||
self,
|
||||
audio_file: data.PathLike,
|
||||
batch_size: int | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> ClipDetections:
|
||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||
|
||||
@ -313,6 +314,7 @@ class BatDetect2API:
|
||||
if batch_size is not None
|
||||
else self.inference_config.loader.batch_size
|
||||
),
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
detections = [
|
||||
detection
|
||||
@ -333,14 +335,19 @@ class BatDetect2API:
|
||||
def process_audio(
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[Detection]:
|
||||
spec = self.generate_spectrogram(audio)
|
||||
return self.process_spectrogram(spec)
|
||||
return self.process_spectrogram(
|
||||
spec,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
def process_spectrogram(
|
||||
self,
|
||||
spec: torch.Tensor,
|
||||
start_time: float = 0,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[Detection]:
|
||||
if spec.ndim == 4 and spec.shape[0] > 1:
|
||||
raise ValueError("Batched spectrograms not supported.")
|
||||
@ -352,6 +359,7 @@ class BatDetect2API:
|
||||
|
||||
detections = self.postprocessor(
|
||||
outputs,
|
||||
detection_threshold=detection_threshold,
|
||||
)[0]
|
||||
return self.output_transform.to_detections(
|
||||
detections=detections,
|
||||
@ -361,9 +369,13 @@ class BatDetect2API:
|
||||
def process_directory(
|
||||
self,
|
||||
audio_dir: data.PathLike,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
files = list(get_audio_files(audio_dir))
|
||||
return self.process_files(files)
|
||||
return self.process_files(
|
||||
files,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
def process_files(
|
||||
self,
|
||||
@ -373,6 +385,7 @@ class BatDetect2API:
|
||||
audio_config: AudioConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
return process_file_list(
|
||||
self.model,
|
||||
@ -386,6 +399,7 @@ class BatDetect2API:
|
||||
audio_config=audio_config or self.audio_config,
|
||||
inference_config=inference_config or self.inference_config,
|
||||
output_config=output_config or self.outputs_config,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
def process_clips(
|
||||
@ -396,6 +410,7 @@ class BatDetect2API:
|
||||
audio_config: AudioConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
@ -409,6 +424,7 @@ class BatDetect2API:
|
||||
audio_config=audio_config or self.audio_config,
|
||||
inference_config=inference_config or self.inference_config,
|
||||
output_config=output_config or self.outputs_config,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
def save_predictions(
|
||||
|
||||
@ -87,6 +87,15 @@ def common_predict_options(func):
|
||||
"the default output format is used."
|
||||
),
|
||||
)
|
||||
@click.option(
|
||||
"--detection-threshold",
|
||||
type=click.FloatRange(min=0.0, max=1.0),
|
||||
default=None,
|
||||
help=(
|
||||
"Optional detection score threshold override. If omitted, "
|
||||
"the model default threshold is used."
|
||||
),
|
||||
)
|
||||
@wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
return func(*args, **kwargs)
|
||||
@ -147,6 +156,7 @@ def _run_prediction(
|
||||
batch_size: int | None,
|
||||
num_workers: int,
|
||||
format_name: str | None,
|
||||
detection_threshold: float | None,
|
||||
) -> None:
|
||||
logger.info("Initiating prediction process...")
|
||||
|
||||
@ -167,6 +177,7 @@ def _run_prediction(
|
||||
audio_config=audio_conf,
|
||||
inference_config=inference_conf,
|
||||
output_config=outputs_conf,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
common_path = audio_files[0].parent if audio_files else None
|
||||
@ -201,6 +212,7 @@ def predict_directory_command(
|
||||
batch_size: int | None,
|
||||
num_workers: int,
|
||||
format_name: str | None,
|
||||
detection_threshold: float | None,
|
||||
) -> None:
|
||||
"""Predict on all audio files in a directory.
|
||||
|
||||
@ -219,6 +231,7 @@ def predict_directory_command(
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
format_name=format_name,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
|
||||
@ -241,6 +254,7 @@ def predict_file_list_command(
|
||||
batch_size: int | None,
|
||||
num_workers: int,
|
||||
format_name: str | None,
|
||||
detection_threshold: float | None,
|
||||
) -> None:
|
||||
"""Predict on audio files listed in a text file.
|
||||
|
||||
@ -265,6 +279,7 @@ def predict_file_list_command(
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
format_name=format_name,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
|
||||
@ -287,6 +302,7 @@ def predict_dataset_command(
|
||||
batch_size: int | None,
|
||||
num_workers: int,
|
||||
format_name: str | None,
|
||||
detection_threshold: float | None,
|
||||
) -> None:
|
||||
"""Predict on recordings referenced in an annotation dataset.
|
||||
|
||||
@ -313,4 +329,5 @@ def predict_dataset_command(
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
format_name=format_name,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
@ -31,6 +31,7 @@ def run_batch_inference(
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
num_workers: int = 1,
|
||||
batch_size: int | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
@ -62,6 +63,7 @@ def run_batch_inference(
|
||||
module = InferenceModule(
|
||||
model,
|
||||
output_transform=output_transform,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
trainer = Trainer(enable_checkpointing=False, logger=False)
|
||||
outputs = trainer.predict(module, loader)
|
||||
@ -82,6 +84,7 @@ def process_file_list(
|
||||
inference_config: InferenceConfig | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
batch_size: int | None = None,
|
||||
num_workers: int = 0,
|
||||
) -> list[ClipDetections]:
|
||||
@ -106,4 +109,5 @@ def process_file_list(
|
||||
audio_config=audio_config,
|
||||
output_transform=output_transform,
|
||||
inference_config=inference_config,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
@ -14,9 +14,11 @@ class InferenceModule(LightningModule):
|
||||
self,
|
||||
model: Model,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.detection_threshold = detection_threshold
|
||||
self.output_transform = output_transform or build_output_transform(
|
||||
targets=model.targets
|
||||
)
|
||||
@ -33,7 +35,10 @@ class InferenceModule(LightningModule):
|
||||
|
||||
outputs = self.model.detector(batch.spec)
|
||||
|
||||
clip_detections = self.model.postprocessor(outputs)
|
||||
clip_detections = self.model.postprocessor(
|
||||
outputs,
|
||||
detection_threshold=self.detection_threshold,
|
||||
)
|
||||
|
||||
return [
|
||||
self.output_transform.to_clip_detections(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user