Moved decoding to outputs

This commit is contained in:
mbsantiago 2026-03-18 01:35:34 +00:00
parent 31d4f92359
commit daff74fdde
17 changed files with 429 additions and 195 deletions

View File

@ -26,7 +26,7 @@ from batdetect2.outputs import (
get_output_formatter,
)
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.types import (
ClipDetections,
Detection,
@ -215,13 +215,8 @@ class BatDetect2API:
detections = self.model.postprocessor(
outputs,
)[0]
raw_predictions = to_raw_predictions(
detections.numpy(),
targets=self.targets,
)
return self.output_transform.transform_detections(
raw_predictions,
return self.output_transform.to_detections(
detections=detections,
start_time=start_time,
)
@ -321,7 +316,8 @@ class BatDetect2API:
config=config.outputs.format,
)
output_transform = build_output_transform(
config=config.outputs.transform
config=config.outputs.transform,
targets=targets,
)
return cls(
@ -375,7 +371,8 @@ class BatDetect2API:
config=config.outputs.format,
)
output_transform = build_output_transform(
config=config.outputs.transform
config=config.outputs.transform,
targets=targets,
)
return cls(

View File

@ -61,7 +61,10 @@ def run_evaluate(
experiment_name=experiment_name,
run_name=run_name,
)
output_transform = build_output_transform(config=output_config.transform)
output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
)
module = EvaluationModule(
model,
evaluator,

View File

@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.postprocess.types import ClipDetections
@ -24,7 +23,9 @@ class EvaluationModule(LightningModule):
self.model = model
self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform()
self.output_transform = output_transform or build_output_transform(
targets=evaluator.targets
)
self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[ClipDetections] = []
@ -39,18 +40,14 @@ class EvaluationModule(LightningModule):
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(outputs)
predictions = [
ClipDetections(
self.output_transform.to_clip_detections(
detections=clip_dets,
clip=clip_annotation.clip,
detections=to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
),
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections, strict=False
)
]
predictions = self.output_transform(predictions)
self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions)

View File

@ -44,6 +44,7 @@ def run_batch_inference(
targets = targets or build_targets()
output_transform = output_transform or build_output_transform(
config=config.outputs.transform,
targets=targets,
)
loader = build_inference_loader(

View File

@ -6,7 +6,6 @@ from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.postprocess.types import ClipDetections
@ -18,7 +17,9 @@ class InferenceModule(LightningModule):
):
super().__init__()
self.model = model
self.output_transform = output_transform or build_output_transform()
self.output_transform = output_transform or build_output_transform(
targets=model.targets
)
def predict_step(
self,
@ -34,19 +35,14 @@ class InferenceModule(LightningModule):
clip_detections = self.model.postprocessor(outputs)
predictions = [
ClipDetections(
return [
self.output_transform.to_clip_detections(
detections=clip_dets,
clip=clip,
detections=to_raw_predictions(
clip_dets.numpy(),
targets=self.model.targets,
),
)
for clip, clip_dets in zip(clips, clip_detections, strict=False)
for clip, clip_dets in zip(clips, clip_detections, strict=True)
]
return self.output_transform(predictions)
def get_dataset(self) -> InferenceDataset:
dataloaders = self.trainer.predict_dataloaders
assert isinstance(dataloaders, DataLoader)

View File

@ -11,9 +11,9 @@ from batdetect2.outputs.formats import (
)
from batdetect2.outputs.transforms import (
OutputTransformConfig,
OutputTransformProtocol,
build_output_transform,
)
from batdetect2.outputs.types import OutputTransformProtocol
__all__ = [
"BatDetect2OutputConfig",

View File

@ -1,89 +0,0 @@
from collections.abc import Sequence
from dataclasses import replace
from typing import Protocol
from soundevent.geometry import shift_geometry
from batdetect2.core.configs import BaseConfig
from batdetect2.postprocess.types import ClipDetections, Detection
__all__ = [
"OutputTransform",
"OutputTransformConfig",
"OutputTransformProtocol",
"build_output_transform",
]
class OutputTransformConfig(BaseConfig):
shift_time_to_clip_start: bool = True
class OutputTransformProtocol(Protocol):
def __call__(
self,
predictions: Sequence[ClipDetections],
) -> list[ClipDetections]: ...
def transform_detections(
self,
detections: Sequence[Detection],
start_time: float = 0,
) -> list[Detection]: ...
def shift_detection_time(detection: Detection, time: float) -> Detection:
geometry = shift_geometry(detection.geometry, time=time)
return replace(detection, geometry=geometry)
class OutputTransform(OutputTransformProtocol):
def __init__(self, shift_time_to_clip_start: bool = True):
self.shift_time_to_clip_start = shift_time_to_clip_start
def __call__(
self,
predictions: Sequence[ClipDetections],
) -> list[ClipDetections]:
return [
self.transform_prediction(prediction) for prediction in predictions
]
def transform_prediction(
self, prediction: ClipDetections
) -> ClipDetections:
if not self.shift_time_to_clip_start:
return prediction
detections = self.transform_detections(
prediction.detections,
start_time=prediction.clip.start_time,
)
return ClipDetections(clip=prediction.clip, detections=detections)
def transform_detections(
self,
detections: Sequence[Detection],
start_time: float = 0,
) -> list[Detection]:
if not self.shift_time_to_clip_start or start_time == 0:
return list(detections)
return [
shift_detection_time(detection, time=start_time)
for detection in detections
]
def build_output_transform(
config: OutputTransformConfig | dict | None = None,
) -> OutputTransformProtocol:
if config is None:
config = OutputTransformConfig()
if not isinstance(config, OutputTransformConfig):
config = OutputTransformConfig.model_validate(config)
return OutputTransform(
shift_time_to_clip_start=config.shift_time_to_clip_start,
)

View File

@ -0,0 +1,173 @@
from collections.abc import Sequence
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.outputs.transforms.clip_transforms import (
ClipDetectionsTransformConfig,
)
from batdetect2.outputs.transforms.clip_transforms import (
clip_transforms as clip_transform_registry,
)
from batdetect2.outputs.transforms.decoding import to_detections
from batdetect2.outputs.transforms.detection_transforms import (
DetectionTransformConfig,
shift_detections_to_start_time,
)
from batdetect2.outputs.transforms.detection_transforms import (
detection_transforms as detection_transform_registry,
)
from batdetect2.outputs.types import (
ClipDetectionsTransform,
DetectionTransform,
OutputTransformProtocol,
)
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClipDetectionsTransformConfig",
"DetectionTransformConfig",
"OutputTransform",
"OutputTransformConfig",
"build_output_transform",
]
class OutputTransformConfig(BaseConfig):
    """Configuration for building an :class:`OutputTransform`.

    Both lists default to empty, in which case the built transform only
    performs decoding and the time shift to the clip start.
    """

    # Per-detection transform steps, built via the detection transform
    # registry and applied in order to each decoded detection.
    detection_transforms: list[DetectionTransformConfig] = Field(
        default_factory=list
    )
    # Clip-level transform steps, built via the clip transform registry and
    # applied in order to each ClipDetections.
    clip_transforms: list[ClipDetectionsTransformConfig] = Field(
        default_factory=list
    )
class OutputTransform(OutputTransformProtocol):
    """Decode raw detection tensors and apply configured transform steps.

    The transform pipeline is: decode (using ``targets``) -> shift detection
    times from clip-relative to recording-relative -> per-detection transform
    steps (which may drop detections) -> clip-level transform steps.
    """

    detection_transform_steps: list[DetectionTransform]
    clip_transform_steps: list[ClipDetectionsTransform]

    def __init__(
        self,
        targets: TargetProtocol,
        detection_transform_steps: Sequence[DetectionTransform] = (),
        clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
    ):
        # `targets` supplies the class decoding used by `to_detections`.
        self.targets = targets
        self.detection_transform_steps = list(detection_transform_steps)
        self.clip_transform_steps = list(clip_transform_steps)

    def __call__(
        self,
        predictions: Sequence[ClipDetections],
    ) -> list[ClipDetections]:
        """Transform a batch of already-decoded clip predictions."""
        return [
            self._transform_prediction(prediction)
            for prediction in predictions
        ]

    def _transform_prediction(
        self,
        prediction: ClipDetections,
    ) -> ClipDetections:
        # Shift geometries so detection times are relative to the recording
        # start rather than the clip start, then run the transform steps.
        detections = shift_detections_to_start_time(
            prediction.detections,
            start_time=prediction.clip.start_time,
        )
        detections = self.transform_detections(detections)
        return self.transform_clip_detections(
            ClipDetections(clip=prediction.clip, detections=detections)
        )

    def to_detections(
        self,
        detections: ClipDetectionsTensor,
        start_time: float = 0,
    ) -> list[Detection]:
        """Decode a detections tensor into transformed Detection objects.

        `start_time` is added to each detection's time coordinates (use the
        clip start time to obtain recording-relative detections).
        """
        decoded = to_detections(detections.numpy(), targets=self.targets)
        shifted = shift_detections_to_start_time(
            decoded,
            start_time=start_time,
        )
        return self.transform_detections(shifted)

    def to_clip_detections(
        self,
        detections: ClipDetectionsTensor,
        clip: data.Clip,
    ) -> ClipDetections:
        """Decode a detections tensor into a fully transformed ClipDetections."""
        prediction = ClipDetections(
            clip=clip,
            detections=self.to_detections(
                detections,
                start_time=clip.start_time,
            ),
        )
        return self.transform_clip_detections(prediction)

    def transform_detections(
        self,
        detections: Sequence[Detection],
    ) -> list[Detection]:
        """Apply detection transform steps; detections filtered to None are dropped."""
        out: list[Detection] = []
        for detection in detections:
            transformed = self.transform_detection(detection)
            if transformed is None:
                continue
            out.append(transformed)
        # Bug fix: previously returned the literal `[]`, silently discarding
        # every detection that survived the transform steps.
        return out

    def transform_detection(
        self,
        detection: Detection,
    ) -> Detection | None:
        """Run one detection through each step in order.

        Returns None as soon as any step filters the detection out.
        """
        for transform in self.detection_transform_steps:
            detection = transform(detection)  # type: ignore
            if detection is None:
                return None
        return detection

    def transform_clip_detections(
        self,
        prediction: ClipDetections,
    ) -> ClipDetections:
        """Apply each clip-level transform step in order."""
        for transform in self.clip_transform_steps:
            prediction = transform(prediction)
        return prediction
def build_output_transform(
    config: OutputTransformConfig | dict | None = None,
    targets: TargetProtocol | None = None,
) -> OutputTransformProtocol:
    """Build an :class:`OutputTransform` from an optional config.

    A missing config yields the default (empty) transform pipeline; a dict
    config is validated into an :class:`OutputTransformConfig`. When
    ``targets`` is not given, the default targets are built.
    """
    # Imported lazily to avoid a circular import with batdetect2.targets.
    from batdetect2.targets import build_targets

    if config is None:
        config = OutputTransformConfig()
    elif not isinstance(config, OutputTransformConfig):
        config = OutputTransformConfig.model_validate(config)

    targets = targets or build_targets()

    detection_steps = [
        detection_transform_registry.build(step_config)
        for step_config in config.detection_transforms
    ]
    clip_steps = [
        clip_transform_registry.build(step_config)
        for step_config in config.clip_transforms
    ]
    return OutputTransform(
        targets=targets,
        detection_transform_steps=detection_steps,
        clip_transform_steps=clip_steps,
    )

View File

@ -0,0 +1,31 @@
from typing import Annotated, Literal
from pydantic import Field
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.outputs.types import ClipDetectionsTransform
__all__ = [
"ClipDetectionsTransformConfig",
"clip_transforms",
]
# Registry of clip-level transform builders, keyed by config `name`.
clip_transforms: Registry[ClipDetectionsTransform, []] = Registry(
    "clip_detection_transform"
)


@add_import_config(clip_transforms)
class ClipDetectionsTransformImportConfig(ImportConfig):
    """Config entry that imports a clip transform from a dotted path."""

    name: Literal["import"] = "import"


# Discriminated union of all registered clip transform configs; currently
# only the import-based config is registered.
ClipDetectionsTransformConfig = Annotated[
    ClipDetectionsTransformImportConfig,
    Field(discriminator="name"),
]

View File

@ -1,33 +1,28 @@
"""Decodes extracted detection data into standard soundevent predictions."""
"""Decode extracted tensors into output-friendly detection objects."""
from typing import List
import numpy as np
from soundevent import data
from batdetect2.postprocess.types import (
ClipDetectionsArray,
Detection,
)
from batdetect2.postprocess.types import ClipDetectionsArray, Detection
from batdetect2.targets.types import TargetProtocol
__all__ = [
"to_raw_predictions",
"convert_raw_predictions_to_clip_prediction",
"convert_raw_prediction_to_sound_event_prediction",
"DEFAULT_CLASSIFICATION_THRESHOLD",
"convert_raw_prediction_to_sound_event_prediction",
"convert_raw_predictions_to_clip_prediction",
"get_class_tags",
"get_generic_tags",
"get_prediction_features",
"to_detections",
]
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
"""Default threshold applied to classification scores.
Class predictions with scores below this value are typically ignored during
decoding.
"""
def to_raw_predictions(
def to_detections(
detections: ClipDetectionsArray,
targets: TargetProtocol,
) -> List[Detection]:
@ -69,7 +64,6 @@ def convert_raw_predictions_to_clip_prediction(
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
) -> data.ClipPrediction:
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
return data.ClipPrediction(
clip=clip,
sound_events=[
@ -92,7 +86,6 @@ def convert_raw_prediction_to_sound_event_prediction(
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False,
):
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
sound_event = data.SoundEvent(
recording=recording,
geometry=raw_prediction.geometry,
@ -123,7 +116,6 @@ def get_generic_tags(
detection_score: float,
generic_class_tags: List[data.Tag],
) -> List[data.PredictedTag]:
"""Create PredictedTag objects for the generic category."""
return [
data.PredictedTag(tag=tag, score=detection_score)
for tag in generic_class_tags
@ -131,7 +123,6 @@ def get_generic_tags(
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
"""Convert an extracted feature vector DataArray into soundevent Features."""
return [
data.Feature(
term=data.Term(
@ -151,39 +142,11 @@ def get_class_tags(
top_class_only: bool = False,
threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]:
"""Generate specific PredictedTags based on class scores and decoder.
Filters class scores by the threshold, sorts remaining scores descending,
decodes the class name(s) into base tags using the `sound_event_decoder`,
and creates `PredictedTag` objects associating the class score. Stops after
the first (top) class if `top_class_only` is True.
Parameters
----------
class_scores : xr.DataArray
A 1D xarray DataArray containing class probabilities/scores, indexed
by a 'category' coordinate holding the class names.
sound_event_decoder : SoundEventDecoder
Function to map a class name string to a list of base `data.Tag`
objects.
top_class_only : bool, default=False
If True, only generate tags for the single highest-scoring class above
the threshold.
threshold : float, optional
Minimum score for a class to be considered. If None, all classes are
processed (or top-1 if `top_class_only` is True). Defaults to
`DEFAULT_CLASSIFICATION_THRESHOLD`.
Returns
-------
List[data.PredictedTag]
A list of `PredictedTag` objects for the class(es) that passed the
threshold, ordered by score if `top_class_only` is False.
"""
tags = []
for class_name, score in _iterate_sorted(
class_scores, targets.class_names
class_scores,
targets.class_names,
):
if threshold is not None and score < threshold:
continue

View File

@ -0,0 +1,55 @@
from collections.abc import Sequence
from dataclasses import replace
from typing import Annotated, Literal
from pydantic import Field
from soundevent.geometry import shift_geometry
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.outputs.types import DetectionTransform
from batdetect2.postprocess.types import Detection
__all__ = [
"DetectionTransformConfig",
"detection_transforms",
"shift_detection_time",
"shift_detections_to_start_time",
]
# Registry of per-detection transform builders, keyed by config `name`.
detection_transforms: Registry[DetectionTransform, []] = Registry(
    "detection_transform"
)


@add_import_config(detection_transforms)
class DetectionTransformImportConfig(ImportConfig):
    """Config entry that imports a detection transform from a dotted path."""

    name: Literal["import"] = "import"


# Discriminated union of all registered detection transform configs;
# currently only the import-based config is registered.
DetectionTransformConfig = Annotated[
    DetectionTransformImportConfig,
    Field(discriminator="name"),
]
def shift_detection_time(detection: Detection, time: float) -> Detection:
    """Return a copy of *detection* with its geometry shifted by *time* seconds."""
    return replace(
        detection,
        geometry=shift_geometry(detection.geometry, time=time),
    )
def shift_detections_to_start_time(
    detections: Sequence[Detection],
    start_time: float = 0,
) -> list[Detection]:
    """Shift every detection's geometry forward by *start_time* seconds.

    A zero start time is a no-op (beyond copying into a new list).
    """
    if start_time == 0:
        return list(detections)
    shifted: list[Detection] = []
    for detection in detections:
        shifted.append(shift_detection_time(detection, time=start_time))
    return shifted

View File

@ -1,12 +1,20 @@
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import Generic, Protocol, TypeVar
from soundevent import data
from soundevent.data import PathLike
from batdetect2.postprocess.types import ClipDetections
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
__all__ = [
"ClipDetectionsTransform",
"DetectionTransform",
"OutputFormatterProtocol",
"OutputTransformProtocol",
]
T = TypeVar("T")
@ -23,3 +31,31 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
) -> None: ...
def load(self, path: PathLike) -> list[T]: ...
DetectionTransform = Callable[[Detection], Detection | None]
ClipDetectionsTransform = Callable[[ClipDetections], ClipDetections]
class OutputTransformProtocol(Protocol):
def to_detections(
self,
detections: ClipDetectionsTensor,
start_time: float = 0,
) -> list[Detection]: ...
def to_clip_detections(
self,
detections: ClipDetectionsTensor,
clip: data.Clip,
) -> ClipDetections: ...
def transform_detections(
self,
detections: Sequence[Detection],
) -> list[Detection]: ...
def transform_clip_detections(
self,
prediction: ClipDetections,
) -> ClipDetections: ...

View File

@ -4,10 +4,6 @@ from batdetect2.postprocess.config import (
PostprocessConfig,
load_postprocess_config,
)
from batdetect2.postprocess.decoding import (
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.nms import non_max_suppression
from batdetect2.postprocess.postprocessor import (
Postprocessor,
@ -18,8 +14,6 @@ __all__ = [
"PostprocessConfig",
"Postprocessor",
"build_postprocessor",
"convert_raw_predictions_to_clip_prediction",
"to_raw_predictions",
"load_postprocess_config",
"non_max_suppression",
]

View File

@ -2,7 +2,6 @@ from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
__all__ = [
@ -11,6 +10,7 @@ __all__ = [
]
DEFAULT_DETECTION_THRESHOLD = 0.01
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
TOP_K_PER_SEC = 100

View File

@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models.types import ModelOutput
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.postprocess.types import ClipDetections
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
@ -25,7 +24,7 @@ class ValidationMetrics(Callback):
super().__init__()
self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform()
self.output_transform = output_transform
self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[ClipDetections] = []
@ -92,6 +91,14 @@ class ValidationMetrics(Callback):
dataloader_idx: int = 0,
) -> None:
model = pl_module.model
if self.output_transform is None:
self.output_transform = build_output_transform(
targets=model.targets
)
output_transform = self.output_transform
assert output_transform is not None
dataset = self.get_dataset(trainer)
clip_annotations = [
@ -101,17 +108,14 @@ class ValidationMetrics(Callback):
clip_detections = model.postprocessor(outputs)
predictions = [
ClipDetections(
output_transform.to_clip_detections(
detections=clip_dets,
clip=clip_annotation.clip,
detections=to_raw_predictions(
clip_dets.numpy(), targets=model.targets
),
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections, strict=False
)
]
predictions = self.output_transform(predictions)
self._clip_annotations.extend(clip_annotations)
self._predictions.extend(predictions)

View File

@ -1,12 +1,35 @@
from dataclasses import replace
import numpy as np
import torch
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.outputs import build_output_transform
from batdetect2.postprocess.types import ClipDetections, Detection
from batdetect2.outputs.transforms import OutputTransform
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol
def test_shift_time_to_clip_start(clip: data.Clip):
def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
return ClipDetectionsTensor(
scores=torch.tensor([0.9], dtype=torch.float32),
sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32),
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
times=torch.tensor([0.2], dtype=torch.float32),
frequencies=torch.tensor([60_000.0], dtype=torch.float32),
features=torch.tensor([[1.0, 2.0]], dtype=torch.float32),
)
def test_shift_time_to_clip_start(
clip: data.Clip,
sample_targets: TargetProtocol,
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
detection = Detection(
@ -16,7 +39,7 @@ def test_shift_time_to_clip_start(clip: data.Clip):
features=np.array([1.0, 2.0]),
)
transformed = build_output_transform()(
transformed = OutputTransform(targets=sample_targets)(
[ClipDetections(clip=clip, detections=[detection])]
)[0]
@ -28,21 +51,71 @@ def test_shift_time_to_clip_start(clip: data.Clip):
assert np.isclose(end_time, 2.7)
def test_transform_identity_when_disabled(clip: data.Clip):
def test_to_clip_detections_shifts_by_clip_start(
clip: data.Clip,
sample_targets: TargetProtocol,
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
transform = build_output_transform(targets=sample_targets)
raw = _mock_clip_detections_tensor()
shifted = transform.to_clip_detections(detections=raw, clip=clip)
unshifted = transform.to_detections(detections=raw, start_time=0)
shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry)
unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
assert np.isclose(shifted_start - unshifted_start, clip.start_time)
def test_detection_and_clip_transforms_applied_in_order(
clip: data.Clip,
sample_targets: TargetProtocol,
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
detection = Detection(
detection_1 = Detection(
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
detection_score=0.9,
detection_score=0.5,
class_scores=np.array([0.9]),
features=np.array([1.0, 2.0]),
)
detection_2 = Detection(
geometry=data.BoundingBox(coordinates=[0.2, 10_000, 0.3, 12_000]),
detection_score=0.7,
class_scores=np.array([0.9]),
features=np.array([1.0, 2.0]),
)
transform = build_output_transform(
config={"shift_time_to_clip_start": False}
def boost_score(detection: Detection) -> Detection:
return replace(
detection,
detection_score=detection.detection_score + 0.2,
)
def keep_high_score(detection: Detection) -> Detection | None:
if detection.detection_score < 0.8:
return None
return detection
def tag_clip_transform(prediction: ClipDetections) -> ClipDetections:
detections = [
replace(detection, detection_score=1.0)
for detection in prediction.detections
]
return replace(prediction, detections=detections)
transform = OutputTransform(
targets=sample_targets,
detection_transform_steps=[boost_score, keep_high_score],
clip_transform_steps=[tag_clip_transform],
)
transformed = transform(
[ClipDetections(clip=clip, detections=[detection])]
[ClipDetections(clip=clip, detections=[detection_1, detection_2])]
)[0]
assert transformed.detections[0].geometry == detection.geometry
assert len(transformed.detections) == 1
assert transformed.detections[0].detection_score == 1.0
start_time, _, _, _ = compute_bounds(transformed.detections[0].geometry)
assert np.isclose(start_time, 2.7)

View File

@ -6,7 +6,7 @@ import pytest
import xarray as xr
from soundevent import data
from batdetect2.postprocess.decoding import (
from batdetect2.outputs.transforms.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,