Moved decoding to outputs

This commit is contained in:
mbsantiago 2026-03-18 01:35:34 +00:00
parent 31d4f92359
commit daff74fdde
17 changed files with 429 additions and 195 deletions

View File

@ -26,7 +26,7 @@ from batdetect2.outputs import (
get_output_formatter, get_output_formatter,
) )
from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess import build_postprocessor, to_raw_predictions from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.types import ( from batdetect2.postprocess.types import (
ClipDetections, ClipDetections,
Detection, Detection,
@ -215,13 +215,8 @@ class BatDetect2API:
detections = self.model.postprocessor( detections = self.model.postprocessor(
outputs, outputs,
)[0] )[0]
raw_predictions = to_raw_predictions( return self.output_transform.to_detections(
detections.numpy(), detections=detections,
targets=self.targets,
)
return self.output_transform.transform_detections(
raw_predictions,
start_time=start_time, start_time=start_time,
) )
@ -321,7 +316,8 @@ class BatDetect2API:
config=config.outputs.format, config=config.outputs.format,
) )
output_transform = build_output_transform( output_transform = build_output_transform(
config=config.outputs.transform config=config.outputs.transform,
targets=targets,
) )
return cls( return cls(
@ -375,7 +371,8 @@ class BatDetect2API:
config=config.outputs.format, config=config.outputs.format,
) )
output_transform = build_output_transform( output_transform = build_output_transform(
config=config.outputs.transform config=config.outputs.transform,
targets=targets,
) )
return cls( return cls(

View File

@ -61,7 +61,10 @@ def run_evaluate(
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
) )
output_transform = build_output_transform(config=output_config.transform) output_transform = build_output_transform(
config=output_config.transform,
targets=targets,
)
module = EvaluationModule( module = EvaluationModule(
model, model,
evaluator, evaluator,

View File

@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger from batdetect2.logging import get_image_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
@ -24,7 +23,9 @@ class EvaluationModule(LightningModule):
self.model = model self.model = model
self.evaluator = evaluator self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform() self.output_transform = output_transform or build_output_transform(
targets=evaluator.targets
)
self.clip_annotations: List[data.ClipAnnotation] = [] self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[ClipDetections] = [] self.predictions: List[ClipDetections] = []
@ -39,18 +40,14 @@ class EvaluationModule(LightningModule):
outputs = self.model.detector(batch.spec) outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(outputs) clip_detections = self.model.postprocessor(outputs)
predictions = [ predictions = [
ClipDetections( self.output_transform.to_clip_detections(
detections=clip_dets,
clip=clip_annotation.clip, clip=clip_annotation.clip,
detections=to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
),
) )
for clip_annotation, clip_dets in zip( for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections, strict=False clip_annotations, clip_detections, strict=False
) )
] ]
predictions = self.output_transform(predictions)
self.clip_annotations.extend(clip_annotations) self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions) self.predictions.extend(predictions)

View File

@ -44,6 +44,7 @@ def run_batch_inference(
targets = targets or build_targets() targets = targets or build_targets()
output_transform = output_transform or build_output_transform( output_transform = output_transform or build_output_transform(
config=config.outputs.transform, config=config.outputs.transform,
targets=targets,
) )
loader = build_inference_loader( loader = build_inference_loader(

View File

@ -6,7 +6,6 @@ from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
@ -18,7 +17,9 @@ class InferenceModule(LightningModule):
): ):
super().__init__() super().__init__()
self.model = model self.model = model
self.output_transform = output_transform or build_output_transform() self.output_transform = output_transform or build_output_transform(
targets=model.targets
)
def predict_step( def predict_step(
self, self,
@ -34,19 +35,14 @@ class InferenceModule(LightningModule):
clip_detections = self.model.postprocessor(outputs) clip_detections = self.model.postprocessor(outputs)
predictions = [ return [
ClipDetections( self.output_transform.to_clip_detections(
detections=clip_dets,
clip=clip, clip=clip,
detections=to_raw_predictions(
clip_dets.numpy(),
targets=self.model.targets,
),
) )
for clip, clip_dets in zip(clips, clip_detections, strict=False) for clip, clip_dets in zip(clips, clip_detections, strict=True)
] ]
return self.output_transform(predictions)
def get_dataset(self) -> InferenceDataset: def get_dataset(self) -> InferenceDataset:
dataloaders = self.trainer.predict_dataloaders dataloaders = self.trainer.predict_dataloaders
assert isinstance(dataloaders, DataLoader) assert isinstance(dataloaders, DataLoader)

View File

@ -11,9 +11,9 @@ from batdetect2.outputs.formats import (
) )
from batdetect2.outputs.transforms import ( from batdetect2.outputs.transforms import (
OutputTransformConfig, OutputTransformConfig,
OutputTransformProtocol,
build_output_transform, build_output_transform,
) )
from batdetect2.outputs.types import OutputTransformProtocol
__all__ = [ __all__ = [
"BatDetect2OutputConfig", "BatDetect2OutputConfig",

View File

@ -1,89 +0,0 @@
from collections.abc import Sequence
from dataclasses import replace
from typing import Protocol
from soundevent.geometry import shift_geometry
from batdetect2.core.configs import BaseConfig
from batdetect2.postprocess.types import ClipDetections, Detection
__all__ = [
"OutputTransform",
"OutputTransformConfig",
"OutputTransformProtocol",
"build_output_transform",
]
class OutputTransformConfig(BaseConfig):
shift_time_to_clip_start: bool = True
class OutputTransformProtocol(Protocol):
def __call__(
self,
predictions: Sequence[ClipDetections],
) -> list[ClipDetections]: ...
def transform_detections(
self,
detections: Sequence[Detection],
start_time: float = 0,
) -> list[Detection]: ...
def shift_detection_time(detection: Detection, time: float) -> Detection:
geometry = shift_geometry(detection.geometry, time=time)
return replace(detection, geometry=geometry)
class OutputTransform(OutputTransformProtocol):
def __init__(self, shift_time_to_clip_start: bool = True):
self.shift_time_to_clip_start = shift_time_to_clip_start
def __call__(
self,
predictions: Sequence[ClipDetections],
) -> list[ClipDetections]:
return [
self.transform_prediction(prediction) for prediction in predictions
]
def transform_prediction(
self, prediction: ClipDetections
) -> ClipDetections:
if not self.shift_time_to_clip_start:
return prediction
detections = self.transform_detections(
prediction.detections,
start_time=prediction.clip.start_time,
)
return ClipDetections(clip=prediction.clip, detections=detections)
def transform_detections(
self,
detections: Sequence[Detection],
start_time: float = 0,
) -> list[Detection]:
if not self.shift_time_to_clip_start or start_time == 0:
return list(detections)
return [
shift_detection_time(detection, time=start_time)
for detection in detections
]
def build_output_transform(
config: OutputTransformConfig | dict | None = None,
) -> OutputTransformProtocol:
if config is None:
config = OutputTransformConfig()
if not isinstance(config, OutputTransformConfig):
config = OutputTransformConfig.model_validate(config)
return OutputTransform(
shift_time_to_clip_start=config.shift_time_to_clip_start,
)

View File

@ -0,0 +1,173 @@
from collections.abc import Sequence
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.outputs.transforms.clip_transforms import (
ClipDetectionsTransformConfig,
)
from batdetect2.outputs.transforms.clip_transforms import (
clip_transforms as clip_transform_registry,
)
from batdetect2.outputs.transforms.decoding import to_detections
from batdetect2.outputs.transforms.detection_transforms import (
DetectionTransformConfig,
shift_detections_to_start_time,
)
from batdetect2.outputs.transforms.detection_transforms import (
detection_transforms as detection_transform_registry,
)
from batdetect2.outputs.types import (
ClipDetectionsTransform,
DetectionTransform,
OutputTransformProtocol,
)
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol
__all__ = [
"ClipDetectionsTransformConfig",
"DetectionTransformConfig",
"OutputTransform",
"OutputTransformConfig",
"build_output_transform",
]
# Configuration for the output transform pipeline. Both lists default to
# empty, in which case build_output_transform creates no transform steps.
class OutputTransformConfig(BaseConfig):
# Configs for per-detection steps; build_output_transform turns each entry
# into a callable through the detection transform registry.
detection_transforms: list[DetectionTransformConfig] = Field(
default_factory=list
)
# Configs for per-clip steps; build_output_transform turns each entry into
# a callable through the clip transform registry.
clip_transforms: list[ClipDetectionsTransformConfig] = Field(
default_factory=list
)
class OutputTransform(OutputTransformProtocol):
detection_transform_steps: list[DetectionTransform]
clip_transform_steps: list[ClipDetectionsTransform]
# Store the targets and the two ordered lists of transform steps.
def __init__(
self,
targets: TargetProtocol,
detection_transform_steps: Sequence[DetectionTransform] = (),
clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
):
# Targets are passed to the decoding helper in `to_detections`.
self.targets = targets
# Copy the incoming sequences into new lists so the instance holds its own
# list objects regardless of what the caller passed in.
self.detection_transform_steps = list(detection_transform_steps)
self.clip_transform_steps = list(clip_transform_steps)
# Transform a sequence of already-decoded clip predictions, one output per
# input and in the same order.
def __call__(
self,
predictions: Sequence[ClipDetections],
) -> list[ClipDetections]:
# Each prediction is handled independently by `_transform_prediction`,
# which shifts detections by the clip start time and then applies the
# detection-level and clip-level steps.
return [
self._transform_prediction(prediction)
for prediction in predictions
]
def _transform_prediction(
    self,
    prediction: ClipDetections,
) -> ClipDetections:
    """Run the full transform pipeline on one decoded clip prediction.

    The detections are first shifted by the clip start time, then passed
    through the detection-level steps, and finally the rebuilt
    ``ClipDetections`` is passed through the clip-level steps.

    Parameters
    ----------
    prediction : ClipDetections
        Decoded detections together with the clip they belong to.

    Returns
    -------
    ClipDetections
        The prediction after all transform steps.
    """
    clip = prediction.clip
    shifted = shift_detections_to_start_time(
        prediction.detections,
        start_time=clip.start_time,
    )
    transformed = self.transform_detections(shifted)
    rebuilt = ClipDetections(clip=clip, detections=transformed)
    return self.transform_clip_detections(rebuilt)
def to_detections(
    self,
    detections: ClipDetectionsTensor,
    start_time: float = 0,
) -> list[Detection]:
    """Decode a detections tensor and apply the detection-level steps.

    Parameters
    ----------
    detections : ClipDetectionsTensor
        Raw postprocessor output for a single clip.
    start_time : float, default=0
        Offset added to the time coordinates of every decoded detection.

    Returns
    -------
    list[Detection]
        Decoded, time-shifted detections after the detection-level steps.
    """
    # The bare name below is the module-level decoding helper, not this
    # method: method names are not in scope inside a method body.
    arrays = detections.numpy()
    decoded = to_detections(arrays, targets=self.targets)
    return self.transform_detections(
        shift_detections_to_start_time(decoded, start_time=start_time)
    )
def to_clip_detections(
    self,
    detections: ClipDetectionsTensor,
    clip: data.Clip,
) -> ClipDetections:
    """Decode a detections tensor into a transformed ``ClipDetections``.

    Parameters
    ----------
    detections : ClipDetectionsTensor
        Raw postprocessor output for ``clip``.
    clip : data.Clip
        Clip the detections belong to; its ``start_time`` is used as the
        time offset when decoding.

    Returns
    -------
    ClipDetections
        The decoded prediction after the clip-level steps.
    """
    decoded = self.to_detections(detections, start_time=clip.start_time)
    wrapped = ClipDetections(clip=clip, detections=decoded)
    return self.transform_clip_detections(wrapped)
def transform_detections(
    self,
    detections: Sequence[Detection],
) -> list[Detection]:
    """Apply the detection-level steps to every detection.

    Each detection is passed through ``transform_detection``. Detections
    for which that call returns ``None`` are dropped; the remaining ones
    keep their input order.

    Parameters
    ----------
    detections : Sequence[Detection]
        Detections to transform.

    Returns
    -------
    list[Detection]
        The transformed detections that were not filtered out.
    """
    # Return the collected results. The previous implementation built the
    # list and then returned a fresh empty list, discarding every detection.
    transformed = (
        self.transform_detection(detection) for detection in detections
    )
    return [detection for detection in transformed if detection is not None]
def transform_detection(
    self,
    detection: Detection,
) -> Detection | None:
    """Pass one detection through the detection-level steps in order.

    Parameters
    ----------
    detection : Detection
        Detection to transform.

    Returns
    -------
    Detection | None
        The result of the last step, or ``None`` as soon as any step
        returns ``None`` (later steps are then not called).
    """
    current = detection
    for step in self.detection_transform_steps:
        current = step(current)  # type: ignore
        if current is None:
            # A step rejected the detection; stop the chain here.
            return None
    return current
def transform_clip_detections(
    self,
    prediction: ClipDetections,
) -> ClipDetections:
    """Pass one clip prediction through the clip-level steps in order.

    Parameters
    ----------
    prediction : ClipDetections
        Prediction to transform.

    Returns
    -------
    ClipDetections
        The output of the last step, or ``prediction`` itself when no
        clip-level steps are configured.
    """
    result = prediction
    for step in self.clip_transform_steps:
        result = step(result)
    return result
def build_output_transform(
    config: OutputTransformConfig | dict | None = None,
    targets: TargetProtocol | None = None,
) -> OutputTransformProtocol:
    """Build an ``OutputTransform`` from a config and optional targets.

    Parameters
    ----------
    config : OutputTransformConfig | dict | None, optional
        Transform configuration. ``None`` uses a default
        ``OutputTransformConfig``; any other non-config value is validated
        with ``OutputTransformConfig.model_validate``.
    targets : TargetProtocol | None, optional
        Targets used for decoding. When falsy, ``build_targets()`` is
        called to create them.

    Returns
    -------
    OutputTransformProtocol
        An ``OutputTransform`` whose steps are built from the config through
        the detection and clip transform registries.
    """
    # Kept as a function-level import, as in the original.
    # NOTE(review): presumably this avoids a circular import - confirm.
    from batdetect2.targets import build_targets

    if config is None:
        resolved_config = OutputTransformConfig()
    elif isinstance(config, OutputTransformConfig):
        resolved_config = config
    else:
        resolved_config = OutputTransformConfig.model_validate(config)

    resolved_targets = targets or build_targets()

    detection_steps = [
        detection_transform_registry.build(step_config)
        for step_config in resolved_config.detection_transforms
    ]
    clip_steps = [
        clip_transform_registry.build(step_config)
        for step_config in resolved_config.clip_transforms
    ]
    return OutputTransform(
        targets=resolved_targets,
        detection_transform_steps=detection_steps,
        clip_transform_steps=clip_steps,
    )

View File

@ -0,0 +1,31 @@
from typing import Annotated, Literal
from pydantic import Field
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.outputs.types import ClipDetectionsTransform
__all__ = [
"ClipDetectionsTransformConfig",
"clip_transforms",
]
clip_transforms: Registry[ClipDetectionsTransform, []] = Registry(
"clip_detection_transform"
)
@add_import_config(clip_transforms)
class ClipDetectionsTransformImportConfig(ImportConfig):
name: Literal["import"] = "import"
ClipDetectionsTransformConfig = Annotated[
ClipDetectionsTransformImportConfig,
Field(discriminator="name"),
]

View File

@ -1,33 +1,28 @@
"""Decodes extracted detection data into standard soundevent predictions.""" """Decode extracted tensors into output-friendly detection objects."""
from typing import List from typing import List
import numpy as np import numpy as np
from soundevent import data from soundevent import data
from batdetect2.postprocess.types import ( from batdetect2.postprocess.types import ClipDetectionsArray, Detection
ClipDetectionsArray,
Detection,
)
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
__all__ = [ __all__ = [
"to_raw_predictions",
"convert_raw_predictions_to_clip_prediction",
"convert_raw_prediction_to_sound_event_prediction",
"DEFAULT_CLASSIFICATION_THRESHOLD", "DEFAULT_CLASSIFICATION_THRESHOLD",
"convert_raw_prediction_to_sound_event_prediction",
"convert_raw_predictions_to_clip_prediction",
"get_class_tags",
"get_generic_tags",
"get_prediction_features",
"to_detections",
] ]
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1 DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
"""Default threshold applied to classification scores.
Class predictions with scores below this value are typically ignored during
decoding.
"""
def to_raw_predictions( def to_detections(
detections: ClipDetectionsArray, detections: ClipDetectionsArray,
targets: TargetProtocol, targets: TargetProtocol,
) -> List[Detection]: ) -> List[Detection]:
@ -69,7 +64,6 @@ def convert_raw_predictions_to_clip_prediction(
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False, top_class_only: bool = False,
) -> data.ClipPrediction: ) -> data.ClipPrediction:
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
return data.ClipPrediction( return data.ClipPrediction(
clip=clip, clip=clip,
sound_events=[ sound_events=[
@ -92,7 +86,6 @@ def convert_raw_prediction_to_sound_event_prediction(
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD, classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
top_class_only: bool = False, top_class_only: bool = False,
): ):
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
sound_event = data.SoundEvent( sound_event = data.SoundEvent(
recording=recording, recording=recording,
geometry=raw_prediction.geometry, geometry=raw_prediction.geometry,
@ -123,7 +116,6 @@ def get_generic_tags(
detection_score: float, detection_score: float,
generic_class_tags: List[data.Tag], generic_class_tags: List[data.Tag],
) -> List[data.PredictedTag]: ) -> List[data.PredictedTag]:
"""Create PredictedTag objects for the generic category."""
return [ return [
data.PredictedTag(tag=tag, score=detection_score) data.PredictedTag(tag=tag, score=detection_score)
for tag in generic_class_tags for tag in generic_class_tags
@ -131,7 +123,6 @@ def get_generic_tags(
def get_prediction_features(features: np.ndarray) -> List[data.Feature]: def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
"""Convert an extracted feature vector DataArray into soundevent Features."""
return [ return [
data.Feature( data.Feature(
term=data.Term( term=data.Term(
@ -151,39 +142,11 @@ def get_class_tags(
top_class_only: bool = False, top_class_only: bool = False,
threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD, threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.PredictedTag]: ) -> List[data.PredictedTag]:
"""Generate specific PredictedTags based on class scores and decoder.
Filters class scores by the threshold, sorts remaining scores descending,
decodes the class name(s) into base tags using the `sound_event_decoder`,
and creates `PredictedTag` objects associating the class score. Stops after
the first (top) class if `top_class_only` is True.
Parameters
----------
class_scores : xr.DataArray
A 1D xarray DataArray containing class probabilities/scores, indexed
by a 'category' coordinate holding the class names.
sound_event_decoder : SoundEventDecoder
Function to map a class name string to a list of base `data.Tag`
objects.
top_class_only : bool, default=False
If True, only generate tags for the single highest-scoring class above
the threshold.
threshold : float, optional
Minimum score for a class to be considered. If None, all classes are
processed (or top-1 if `top_class_only` is True). Defaults to
`DEFAULT_CLASSIFICATION_THRESHOLD`.
Returns
-------
List[data.PredictedTag]
A list of `PredictedTag` objects for the class(es) that passed the
threshold, ordered by score if `top_class_only` is False.
"""
tags = [] tags = []
for class_name, score in _iterate_sorted( for class_name, score in _iterate_sorted(
class_scores, targets.class_names class_scores,
targets.class_names,
): ):
if threshold is not None and score < threshold: if threshold is not None and score < threshold:
continue continue

View File

@ -0,0 +1,55 @@
from collections.abc import Sequence
from dataclasses import replace
from typing import Annotated, Literal
from pydantic import Field
from soundevent.geometry import shift_geometry
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.outputs.types import DetectionTransform
from batdetect2.postprocess.types import Detection
__all__ = [
"DetectionTransformConfig",
"detection_transforms",
"shift_detection_time",
"shift_detections_to_start_time",
]
detection_transforms: Registry[DetectionTransform, []] = Registry(
"detection_transform"
)
@add_import_config(detection_transforms)
class DetectionTransformImportConfig(ImportConfig):
name: Literal["import"] = "import"
DetectionTransformConfig = Annotated[
DetectionTransformImportConfig,
Field(discriminator="name"),
]
def shift_detection_time(detection: Detection, time: float) -> Detection:
    """Return a copy of ``detection`` whose geometry is shifted in time.

    Parameters
    ----------
    detection : Detection
        Detection to shift; it is not modified.
    time : float
        Offset passed to ``shift_geometry`` as its ``time`` argument.

    Returns
    -------
    Detection
        A new detection, identical except for the shifted geometry.
    """
    return replace(
        detection,
        geometry=shift_geometry(detection.geometry, time=time),
    )
def shift_detections_to_start_time(
    detections: Sequence[Detection],
    start_time: float = 0,
) -> list[Detection]:
    """Shift every detection by ``start_time``.

    Parameters
    ----------
    detections : Sequence[Detection]
        Detections to shift.
    start_time : float, default=0
        Time offset applied to each detection.

    Returns
    -------
    list[Detection]
        A new list. When ``start_time`` equals zero the detections are
        copied into the list unchanged; otherwise each one is replaced by
        the result of ``shift_detection_time``.
    """
    if start_time != 0:
        return [
            shift_detection_time(item, time=start_time)
            for item in detections
        ]
    # Nothing to shift; still return a fresh list rather than the input.
    return list(detections)

View File

@ -1,12 +1,20 @@
from collections.abc import Sequence from collections.abc import Callable, Sequence
from typing import Generic, Protocol, TypeVar from typing import Generic, Protocol, TypeVar
from soundevent import data
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
__all__ = [ __all__ = [
"ClipDetectionsTransform",
"DetectionTransform",
"OutputFormatterProtocol", "OutputFormatterProtocol",
"OutputTransformProtocol",
] ]
T = TypeVar("T") T = TypeVar("T")
@ -23,3 +31,31 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
) -> None: ... ) -> None: ...
def load(self, path: PathLike) -> list[T]: ... def load(self, path: PathLike) -> list[T]: ...
DetectionTransform = Callable[[Detection], Detection | None]
ClipDetectionsTransform = Callable[[ClipDetections], ClipDetections]
class OutputTransformProtocol(Protocol):
def to_detections(
self,
detections: ClipDetectionsTensor,
start_time: float = 0,
) -> list[Detection]: ...
def to_clip_detections(
self,
detections: ClipDetectionsTensor,
clip: data.Clip,
) -> ClipDetections: ...
def transform_detections(
self,
detections: Sequence[Detection],
) -> list[Detection]: ...
def transform_clip_detections(
self,
prediction: ClipDetections,
) -> ClipDetections: ...

View File

@ -4,10 +4,6 @@ from batdetect2.postprocess.config import (
PostprocessConfig, PostprocessConfig,
load_postprocess_config, load_postprocess_config,
) )
from batdetect2.postprocess.decoding import (
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.nms import non_max_suppression from batdetect2.postprocess.nms import non_max_suppression
from batdetect2.postprocess.postprocessor import ( from batdetect2.postprocess.postprocessor import (
Postprocessor, Postprocessor,
@ -18,8 +14,6 @@ __all__ = [
"PostprocessConfig", "PostprocessConfig",
"Postprocessor", "Postprocessor",
"build_postprocessor", "build_postprocessor",
"convert_raw_predictions_to_clip_prediction",
"to_raw_predictions",
"load_postprocess_config", "load_postprocess_config",
"non_max_suppression", "non_max_suppression",
] ]

View File

@ -2,7 +2,6 @@ from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
__all__ = [ __all__ = [
@ -11,6 +10,7 @@ __all__ = [
] ]
DEFAULT_DETECTION_THRESHOLD = 0.01 DEFAULT_DETECTION_THRESHOLD = 0.01
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
TOP_K_PER_SEC = 100 TOP_K_PER_SEC = 100

View File

@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger from batdetect2.logging import get_image_logger
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput
from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
from batdetect2.train.dataset import ValidationDataset from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
@ -25,7 +24,7 @@ class ValidationMetrics(Callback):
super().__init__() super().__init__()
self.evaluator = evaluator self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform() self.output_transform = output_transform
self._clip_annotations: List[data.ClipAnnotation] = [] self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[ClipDetections] = [] self._predictions: List[ClipDetections] = []
@ -92,6 +91,14 @@ class ValidationMetrics(Callback):
dataloader_idx: int = 0, dataloader_idx: int = 0,
) -> None: ) -> None:
model = pl_module.model model = pl_module.model
if self.output_transform is None:
self.output_transform = build_output_transform(
targets=model.targets
)
output_transform = self.output_transform
assert output_transform is not None
dataset = self.get_dataset(trainer) dataset = self.get_dataset(trainer)
clip_annotations = [ clip_annotations = [
@ -101,17 +108,14 @@ class ValidationMetrics(Callback):
clip_detections = model.postprocessor(outputs) clip_detections = model.postprocessor(outputs)
predictions = [ predictions = [
ClipDetections( output_transform.to_clip_detections(
detections=clip_dets,
clip=clip_annotation.clip, clip=clip_annotation.clip,
detections=to_raw_predictions(
clip_dets.numpy(), targets=model.targets
),
) )
for clip_annotation, clip_dets in zip( for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections, strict=False clip_annotations, clip_detections, strict=False
) )
] ]
predictions = self.output_transform(predictions)
self._clip_annotations.extend(clip_annotations) self._clip_annotations.extend(clip_annotations)
self._predictions.extend(predictions) self._predictions.extend(predictions)

View File

@ -1,12 +1,35 @@
from dataclasses import replace
import numpy as np import numpy as np
import torch
from soundevent import data from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.outputs import build_output_transform from batdetect2.outputs import build_output_transform
from batdetect2.postprocess.types import ClipDetections, Detection from batdetect2.outputs.transforms import OutputTransform
from batdetect2.postprocess.types import (
ClipDetections,
ClipDetectionsTensor,
Detection,
)
from batdetect2.targets.types import TargetProtocol
def test_shift_time_to_clip_start(clip: data.Clip): def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
return ClipDetectionsTensor(
scores=torch.tensor([0.9], dtype=torch.float32),
sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32),
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
times=torch.tensor([0.2], dtype=torch.float32),
frequencies=torch.tensor([60_000.0], dtype=torch.float32),
features=torch.tensor([[1.0, 2.0]], dtype=torch.float32),
)
def test_shift_time_to_clip_start(
clip: data.Clip,
sample_targets: TargetProtocol,
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
detection = Detection( detection = Detection(
@ -16,7 +39,7 @@ def test_shift_time_to_clip_start(clip: data.Clip):
features=np.array([1.0, 2.0]), features=np.array([1.0, 2.0]),
) )
transformed = build_output_transform()( transformed = OutputTransform(targets=sample_targets)(
[ClipDetections(clip=clip, detections=[detection])] [ClipDetections(clip=clip, detections=[detection])]
)[0] )[0]
@ -28,21 +51,71 @@ def test_shift_time_to_clip_start(clip: data.Clip):
assert np.isclose(end_time, 2.7) assert np.isclose(end_time, 2.7)
def test_transform_identity_when_disabled(clip: data.Clip): def test_to_clip_detections_shifts_by_clip_start(
clip: data.Clip,
sample_targets: TargetProtocol,
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
transform = build_output_transform(targets=sample_targets)
raw = _mock_clip_detections_tensor()
shifted = transform.to_clip_detections(detections=raw, clip=clip)
unshifted = transform.to_detections(detections=raw, start_time=0)
shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry)
unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
assert np.isclose(shifted_start - unshifted_start, clip.start_time)
def test_detection_and_clip_transforms_applied_in_order(
clip: data.Clip,
sample_targets: TargetProtocol,
):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0}) clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
detection = Detection( detection_1 = Detection(
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]), geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
detection_score=0.9, detection_score=0.5,
class_scores=np.array([0.9]),
features=np.array([1.0, 2.0]),
)
detection_2 = Detection(
geometry=data.BoundingBox(coordinates=[0.2, 10_000, 0.3, 12_000]),
detection_score=0.7,
class_scores=np.array([0.9]), class_scores=np.array([0.9]),
features=np.array([1.0, 2.0]), features=np.array([1.0, 2.0]),
) )
transform = build_output_transform( def boost_score(detection: Detection) -> Detection:
config={"shift_time_to_clip_start": False} return replace(
detection,
detection_score=detection.detection_score + 0.2,
)
def keep_high_score(detection: Detection) -> Detection | None:
if detection.detection_score < 0.8:
return None
return detection
def tag_clip_transform(prediction: ClipDetections) -> ClipDetections:
detections = [
replace(detection, detection_score=1.0)
for detection in prediction.detections
]
return replace(prediction, detections=detections)
transform = OutputTransform(
targets=sample_targets,
detection_transform_steps=[boost_score, keep_high_score],
clip_transform_steps=[tag_clip_transform],
) )
transformed = transform( transformed = transform(
[ClipDetections(clip=clip, detections=[detection])] [ClipDetections(clip=clip, detections=[detection_1, detection_2])]
)[0] )[0]
assert transformed.detections[0].geometry == detection.geometry assert len(transformed.detections) == 1
assert transformed.detections[0].detection_score == 1.0
start_time, _, _, _ = compute_bounds(transformed.detections[0].geometry)
assert np.isclose(start_time, 2.7)

View File

@ -6,7 +6,7 @@ import pytest
import xarray as xr import xarray as xr
from soundevent import data from soundevent import data
from batdetect2.postprocess.decoding import ( from batdetect2.outputs.transforms.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD, DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction, convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction, convert_raw_predictions_to_clip_prediction,