mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add detection threshold override to API and CLI
This commit is contained in:
parent
dfe19c68f0
commit
b253a54cc8
@ -303,6 +303,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
audio_file: data.PathLike,
|
audio_file: data.PathLike,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
) -> ClipDetections:
|
) -> ClipDetections:
|
||||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||||
|
|
||||||
@ -313,6 +314,7 @@ class BatDetect2API:
|
|||||||
if batch_size is not None
|
if batch_size is not None
|
||||||
else self.inference_config.loader.batch_size
|
else self.inference_config.loader.batch_size
|
||||||
),
|
),
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
detections = [
|
detections = [
|
||||||
detection
|
detection
|
||||||
@ -333,14 +335,19 @@ class BatDetect2API:
|
|||||||
def process_audio(
|
def process_audio(
|
||||||
self,
|
self,
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
) -> list[Detection]:
|
) -> list[Detection]:
|
||||||
spec = self.generate_spectrogram(audio)
|
spec = self.generate_spectrogram(audio)
|
||||||
return self.process_spectrogram(spec)
|
return self.process_spectrogram(
|
||||||
|
spec,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
def process_spectrogram(
|
def process_spectrogram(
|
||||||
self,
|
self,
|
||||||
spec: torch.Tensor,
|
spec: torch.Tensor,
|
||||||
start_time: float = 0,
|
start_time: float = 0,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
) -> list[Detection]:
|
) -> list[Detection]:
|
||||||
if spec.ndim == 4 and spec.shape[0] > 1:
|
if spec.ndim == 4 and spec.shape[0] > 1:
|
||||||
raise ValueError("Batched spectrograms not supported.")
|
raise ValueError("Batched spectrograms not supported.")
|
||||||
@ -352,6 +359,7 @@ class BatDetect2API:
|
|||||||
|
|
||||||
detections = self.postprocessor(
|
detections = self.postprocessor(
|
||||||
outputs,
|
outputs,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)[0]
|
)[0]
|
||||||
return self.output_transform.to_detections(
|
return self.output_transform.to_detections(
|
||||||
detections=detections,
|
detections=detections,
|
||||||
@ -361,9 +369,13 @@ class BatDetect2API:
|
|||||||
def process_directory(
|
def process_directory(
|
||||||
self,
|
self,
|
||||||
audio_dir: data.PathLike,
|
audio_dir: data.PathLike,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
files = list(get_audio_files(audio_dir))
|
files = list(get_audio_files(audio_dir))
|
||||||
return self.process_files(files)
|
return self.process_files(
|
||||||
|
files,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
def process_files(
|
def process_files(
|
||||||
self,
|
self,
|
||||||
@ -373,6 +385,7 @@ class BatDetect2API:
|
|||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
inference_config: InferenceConfig | None = None,
|
inference_config: InferenceConfig | None = None,
|
||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
return process_file_list(
|
return process_file_list(
|
||||||
self.model,
|
self.model,
|
||||||
@ -386,6 +399,7 @@ class BatDetect2API:
|
|||||||
audio_config=audio_config or self.audio_config,
|
audio_config=audio_config or self.audio_config,
|
||||||
inference_config=inference_config or self.inference_config,
|
inference_config=inference_config or self.inference_config,
|
||||||
output_config=output_config or self.outputs_config,
|
output_config=output_config or self.outputs_config,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
def process_clips(
|
def process_clips(
|
||||||
@ -396,6 +410,7 @@ class BatDetect2API:
|
|||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
inference_config: InferenceConfig | None = None,
|
inference_config: InferenceConfig | None = None,
|
||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
self.model,
|
self.model,
|
||||||
@ -409,6 +424,7 @@ class BatDetect2API:
|
|||||||
audio_config=audio_config or self.audio_config,
|
audio_config=audio_config or self.audio_config,
|
||||||
inference_config=inference_config or self.inference_config,
|
inference_config=inference_config or self.inference_config,
|
||||||
output_config=output_config or self.outputs_config,
|
output_config=output_config or self.outputs_config,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_predictions(
|
def save_predictions(
|
||||||
|
|||||||
@ -87,6 +87,15 @@ def common_predict_options(func):
|
|||||||
"the default output format is used."
|
"the default output format is used."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--detection-threshold",
|
||||||
|
type=click.FloatRange(min=0.0, max=1.0),
|
||||||
|
default=None,
|
||||||
|
help=(
|
||||||
|
"Optional detection score threshold override. If omitted, "
|
||||||
|
"the model default threshold is used."
|
||||||
|
),
|
||||||
|
)
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapped(*args, **kwargs):
|
def wrapped(*args, **kwargs):
|
||||||
return func(*args, **kwargs)
|
return func(*args, **kwargs)
|
||||||
@ -147,6 +156,7 @@ def _run_prediction(
|
|||||||
batch_size: int | None,
|
batch_size: int | None,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
|
detection_threshold: float | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
logger.info("Initiating prediction process...")
|
logger.info("Initiating prediction process...")
|
||||||
|
|
||||||
@ -167,6 +177,7 @@ def _run_prediction(
|
|||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
inference_config=inference_conf,
|
inference_config=inference_conf,
|
||||||
output_config=outputs_conf,
|
output_config=outputs_conf,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
common_path = audio_files[0].parent if audio_files else None
|
common_path = audio_files[0].parent if audio_files else None
|
||||||
@ -201,6 +212,7 @@ def predict_directory_command(
|
|||||||
batch_size: int | None,
|
batch_size: int | None,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
|
detection_threshold: float | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Predict on all audio files in a directory.
|
"""Predict on all audio files in a directory.
|
||||||
|
|
||||||
@ -219,6 +231,7 @@ def predict_directory_command(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
format_name=format_name,
|
format_name=format_name,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -241,6 +254,7 @@ def predict_file_list_command(
|
|||||||
batch_size: int | None,
|
batch_size: int | None,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
|
detection_threshold: float | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Predict on audio files listed in a text file.
|
"""Predict on audio files listed in a text file.
|
||||||
|
|
||||||
@ -265,6 +279,7 @@ def predict_file_list_command(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
format_name=format_name,
|
format_name=format_name,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -287,6 +302,7 @@ def predict_dataset_command(
|
|||||||
batch_size: int | None,
|
batch_size: int | None,
|
||||||
num_workers: int,
|
num_workers: int,
|
||||||
format_name: str | None,
|
format_name: str | None,
|
||||||
|
detection_threshold: float | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Predict on recordings referenced in an annotation dataset.
|
"""Predict on recordings referenced in an annotation dataset.
|
||||||
|
|
||||||
@ -313,4 +329,5 @@ def predict_dataset_command(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
format_name=format_name,
|
format_name=format_name,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -31,6 +31,7 @@ def run_batch_inference(
|
|||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
inference_config: InferenceConfig | None = None,
|
inference_config: InferenceConfig | None = None,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
num_workers: int = 1,
|
num_workers: int = 1,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
@ -62,6 +63,7 @@ def run_batch_inference(
|
|||||||
module = InferenceModule(
|
module = InferenceModule(
|
||||||
model,
|
model,
|
||||||
output_transform=output_transform,
|
output_transform=output_transform,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
trainer = Trainer(enable_checkpointing=False, logger=False)
|
trainer = Trainer(enable_checkpointing=False, logger=False)
|
||||||
outputs = trainer.predict(module, loader)
|
outputs = trainer.predict(module, loader)
|
||||||
@ -82,6 +84,7 @@ def process_file_list(
|
|||||||
inference_config: InferenceConfig | None = None,
|
inference_config: InferenceConfig | None = None,
|
||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
num_workers: int = 0,
|
num_workers: int = 0,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
@ -106,4 +109,5 @@ def process_file_list(
|
|||||||
audio_config=audio_config,
|
audio_config=audio_config,
|
||||||
output_transform=output_transform,
|
output_transform=output_transform,
|
||||||
inference_config=inference_config,
|
inference_config=inference_config,
|
||||||
|
detection_threshold=detection_threshold,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -14,9 +14,11 @@ class InferenceModule(LightningModule):
|
|||||||
self,
|
self,
|
||||||
model: Model,
|
model: Model,
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
|
detection_threshold: float | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.detection_threshold = detection_threshold
|
||||||
self.output_transform = output_transform or build_output_transform(
|
self.output_transform = output_transform or build_output_transform(
|
||||||
targets=model.targets
|
targets=model.targets
|
||||||
)
|
)
|
||||||
@ -33,7 +35,10 @@ class InferenceModule(LightningModule):
|
|||||||
|
|
||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
|
|
||||||
clip_detections = self.model.postprocessor(outputs)
|
clip_detections = self.model.postprocessor(
|
||||||
|
outputs,
|
||||||
|
detection_threshold=self.detection_threshold,
|
||||||
|
)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
self.output_transform.to_clip_detections(
|
self.output_transform.to_clip_detections(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user