mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Moved decoding to outputs
This commit is contained in:
parent
31d4f92359
commit
daff74fdde
@ -26,7 +26,7 @@ from batdetect2.outputs import (
|
||||
get_output_formatter,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
@ -215,13 +215,8 @@ class BatDetect2API:
|
||||
detections = self.model.postprocessor(
|
||||
outputs,
|
||||
)[0]
|
||||
raw_predictions = to_raw_predictions(
|
||||
detections.numpy(),
|
||||
targets=self.targets,
|
||||
)
|
||||
|
||||
return self.output_transform.transform_detections(
|
||||
raw_predictions,
|
||||
return self.output_transform.to_detections(
|
||||
detections=detections,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
@ -321,7 +316,8 @@ class BatDetect2API:
|
||||
config=config.outputs.format,
|
||||
)
|
||||
output_transform = build_output_transform(
|
||||
config=config.outputs.transform
|
||||
config=config.outputs.transform,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
return cls(
|
||||
@ -375,7 +371,8 @@ class BatDetect2API:
|
||||
config=config.outputs.format,
|
||||
)
|
||||
output_transform = build_output_transform(
|
||||
config=config.outputs.transform
|
||||
config=config.outputs.transform,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
return cls(
|
||||
|
||||
@ -61,7 +61,10 @@ def run_evaluate(
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
output_transform = build_output_transform(config=output_config.transform)
|
||||
output_transform = build_output_transform(
|
||||
config=output_config.transform,
|
||||
targets=targets,
|
||||
)
|
||||
module = EvaluationModule(
|
||||
model,
|
||||
evaluator,
|
||||
|
||||
@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
@ -24,7 +23,9 @@ class EvaluationModule(LightningModule):
|
||||
|
||||
self.model = model
|
||||
self.evaluator = evaluator
|
||||
self.output_transform = output_transform or build_output_transform()
|
||||
self.output_transform = output_transform or build_output_transform(
|
||||
targets=evaluator.targets
|
||||
)
|
||||
|
||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||
self.predictions: List[ClipDetections] = []
|
||||
@ -39,18 +40,14 @@ class EvaluationModule(LightningModule):
|
||||
outputs = self.model.detector(batch.spec)
|
||||
clip_detections = self.model.postprocessor(outputs)
|
||||
predictions = [
|
||||
ClipDetections(
|
||||
self.output_transform.to_clip_detections(
|
||||
detections=clip_dets,
|
||||
clip=clip_annotation.clip,
|
||||
detections=to_raw_predictions(
|
||||
clip_dets.numpy(),
|
||||
targets=self.evaluator.targets,
|
||||
),
|
||||
)
|
||||
for clip_annotation, clip_dets in zip(
|
||||
clip_annotations, clip_detections, strict=False
|
||||
)
|
||||
]
|
||||
predictions = self.output_transform(predictions)
|
||||
|
||||
self.clip_annotations.extend(clip_annotations)
|
||||
self.predictions.extend(predictions)
|
||||
|
||||
@ -44,6 +44,7 @@ def run_batch_inference(
|
||||
targets = targets or build_targets()
|
||||
output_transform = output_transform or build_output_transform(
|
||||
config=config.outputs.transform,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
loader = build_inference_loader(
|
||||
|
||||
@ -6,7 +6,6 @@ from torch.utils.data import DataLoader
|
||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
@ -18,7 +17,9 @@ class InferenceModule(LightningModule):
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.output_transform = output_transform or build_output_transform()
|
||||
self.output_transform = output_transform or build_output_transform(
|
||||
targets=model.targets
|
||||
)
|
||||
|
||||
def predict_step(
|
||||
self,
|
||||
@ -34,19 +35,14 @@ class InferenceModule(LightningModule):
|
||||
|
||||
clip_detections = self.model.postprocessor(outputs)
|
||||
|
||||
predictions = [
|
||||
ClipDetections(
|
||||
return [
|
||||
self.output_transform.to_clip_detections(
|
||||
detections=clip_dets,
|
||||
clip=clip,
|
||||
detections=to_raw_predictions(
|
||||
clip_dets.numpy(),
|
||||
targets=self.model.targets,
|
||||
),
|
||||
)
|
||||
for clip, clip_dets in zip(clips, clip_detections, strict=False)
|
||||
for clip, clip_dets in zip(clips, clip_detections, strict=True)
|
||||
]
|
||||
|
||||
return self.output_transform(predictions)
|
||||
|
||||
def get_dataset(self) -> InferenceDataset:
|
||||
dataloaders = self.trainer.predict_dataloaders
|
||||
assert isinstance(dataloaders, DataLoader)
|
||||
|
||||
@ -11,9 +11,9 @@ from batdetect2.outputs.formats import (
|
||||
)
|
||||
from batdetect2.outputs.transforms import (
|
||||
OutputTransformConfig,
|
||||
OutputTransformProtocol,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.outputs.types import OutputTransformProtocol
|
||||
|
||||
__all__ = [
|
||||
"BatDetect2OutputConfig",
|
||||
|
||||
@ -1,89 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import replace
|
||||
from typing import Protocol
|
||||
|
||||
from soundevent.geometry import shift_geometry
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
|
||||
__all__ = [
|
||||
"OutputTransform",
|
||||
"OutputTransformConfig",
|
||||
"OutputTransformProtocol",
|
||||
"build_output_transform",
|
||||
]
|
||||
|
||||
|
||||
class OutputTransformConfig(BaseConfig):
|
||||
shift_time_to_clip_start: bool = True
|
||||
|
||||
|
||||
class OutputTransformProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> list[ClipDetections]: ...
|
||||
|
||||
def transform_detections(
|
||||
self,
|
||||
detections: Sequence[Detection],
|
||||
start_time: float = 0,
|
||||
) -> list[Detection]: ...
|
||||
|
||||
|
||||
def shift_detection_time(detection: Detection, time: float) -> Detection:
|
||||
geometry = shift_geometry(detection.geometry, time=time)
|
||||
return replace(detection, geometry=geometry)
|
||||
|
||||
|
||||
class OutputTransform(OutputTransformProtocol):
|
||||
def __init__(self, shift_time_to_clip_start: bool = True):
|
||||
self.shift_time_to_clip_start = shift_time_to_clip_start
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> list[ClipDetections]:
|
||||
return [
|
||||
self.transform_prediction(prediction) for prediction in predictions
|
||||
]
|
||||
|
||||
def transform_prediction(
|
||||
self, prediction: ClipDetections
|
||||
) -> ClipDetections:
|
||||
if not self.shift_time_to_clip_start:
|
||||
return prediction
|
||||
|
||||
detections = self.transform_detections(
|
||||
prediction.detections,
|
||||
start_time=prediction.clip.start_time,
|
||||
)
|
||||
return ClipDetections(clip=prediction.clip, detections=detections)
|
||||
|
||||
def transform_detections(
|
||||
self,
|
||||
detections: Sequence[Detection],
|
||||
start_time: float = 0,
|
||||
) -> list[Detection]:
|
||||
if not self.shift_time_to_clip_start or start_time == 0:
|
||||
return list(detections)
|
||||
|
||||
return [
|
||||
shift_detection_time(detection, time=start_time)
|
||||
for detection in detections
|
||||
]
|
||||
|
||||
|
||||
def build_output_transform(
|
||||
config: OutputTransformConfig | dict | None = None,
|
||||
) -> OutputTransformProtocol:
|
||||
if config is None:
|
||||
config = OutputTransformConfig()
|
||||
|
||||
if not isinstance(config, OutputTransformConfig):
|
||||
config = OutputTransformConfig.model_validate(config)
|
||||
|
||||
return OutputTransform(
|
||||
shift_time_to_clip_start=config.shift_time_to_clip_start,
|
||||
)
|
||||
173
src/batdetect2/outputs/transforms/__init__.py
Normal file
173
src/batdetect2/outputs/transforms/__init__.py
Normal file
@ -0,0 +1,173 @@
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.outputs.transforms.clip_transforms import (
|
||||
ClipDetectionsTransformConfig,
|
||||
)
|
||||
from batdetect2.outputs.transforms.clip_transforms import (
|
||||
clip_transforms as clip_transform_registry,
|
||||
)
|
||||
from batdetect2.outputs.transforms.decoding import to_detections
|
||||
from batdetect2.outputs.transforms.detection_transforms import (
|
||||
DetectionTransformConfig,
|
||||
shift_detections_to_start_time,
|
||||
)
|
||||
from batdetect2.outputs.transforms.detection_transforms import (
|
||||
detection_transforms as detection_transform_registry,
|
||||
)
|
||||
from batdetect2.outputs.types import (
|
||||
ClipDetectionsTransform,
|
||||
DetectionTransform,
|
||||
OutputTransformProtocol,
|
||||
)
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetections,
|
||||
ClipDetectionsTensor,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"ClipDetectionsTransformConfig",
|
||||
"DetectionTransformConfig",
|
||||
"OutputTransform",
|
||||
"OutputTransformConfig",
|
||||
"build_output_transform",
|
||||
]
|
||||
|
||||
|
||||
class OutputTransformConfig(BaseConfig):
|
||||
detection_transforms: list[DetectionTransformConfig] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
clip_transforms: list[ClipDetectionsTransformConfig] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
|
||||
|
||||
class OutputTransform(OutputTransformProtocol):
|
||||
detection_transform_steps: list[DetectionTransform]
|
||||
clip_transform_steps: list[ClipDetectionsTransform]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
targets: TargetProtocol,
|
||||
detection_transform_steps: Sequence[DetectionTransform] = (),
|
||||
clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
|
||||
):
|
||||
self.targets = targets
|
||||
self.detection_transform_steps = list(detection_transform_steps)
|
||||
self.clip_transform_steps = list(clip_transform_steps)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
predictions: Sequence[ClipDetections],
|
||||
) -> list[ClipDetections]:
|
||||
return [
|
||||
self._transform_prediction(prediction)
|
||||
for prediction in predictions
|
||||
]
|
||||
|
||||
def _transform_prediction(
|
||||
self,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipDetections:
|
||||
detections = shift_detections_to_start_time(
|
||||
prediction.detections,
|
||||
start_time=prediction.clip.start_time,
|
||||
)
|
||||
detections = self.transform_detections(detections)
|
||||
return self.transform_clip_detections(
|
||||
ClipDetections(clip=prediction.clip, detections=detections)
|
||||
)
|
||||
|
||||
def to_detections(
|
||||
self,
|
||||
detections: ClipDetectionsTensor,
|
||||
start_time: float = 0,
|
||||
) -> list[Detection]:
|
||||
decoded = to_detections(detections.numpy(), targets=self.targets)
|
||||
shifted = shift_detections_to_start_time(
|
||||
decoded,
|
||||
start_time=start_time,
|
||||
)
|
||||
return self.transform_detections(shifted)
|
||||
|
||||
def to_clip_detections(
|
||||
self,
|
||||
detections: ClipDetectionsTensor,
|
||||
clip: data.Clip,
|
||||
) -> ClipDetections:
|
||||
prediction = ClipDetections(
|
||||
clip=clip,
|
||||
detections=self.to_detections(
|
||||
detections,
|
||||
start_time=clip.start_time,
|
||||
),
|
||||
)
|
||||
return self.transform_clip_detections(prediction)
|
||||
|
||||
def transform_detections(
|
||||
self,
|
||||
detections: Sequence[Detection],
|
||||
) -> list[Detection]:
|
||||
out: list[Detection] = []
|
||||
for detection in detections:
|
||||
transformed = self.transform_detection(detection)
|
||||
|
||||
if transformed is None:
|
||||
continue
|
||||
|
||||
out.append(transformed)
|
||||
|
||||
return []
|
||||
|
||||
def transform_detection(
|
||||
self,
|
||||
detection: Detection,
|
||||
) -> Detection | None:
|
||||
for transform in self.detection_transform_steps:
|
||||
detection = transform(detection) # type: ignore
|
||||
|
||||
if detection is None:
|
||||
return None
|
||||
|
||||
return detection
|
||||
|
||||
def transform_clip_detections(
|
||||
self,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipDetections:
|
||||
for transform in self.clip_transform_steps:
|
||||
prediction = transform(prediction)
|
||||
return prediction
|
||||
|
||||
|
||||
def build_output_transform(
|
||||
config: OutputTransformConfig | dict | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
) -> OutputTransformProtocol:
|
||||
from batdetect2.targets import build_targets
|
||||
|
||||
if config is None:
|
||||
config = OutputTransformConfig()
|
||||
|
||||
if not isinstance(config, OutputTransformConfig):
|
||||
config = OutputTransformConfig.model_validate(config)
|
||||
|
||||
targets = targets or build_targets()
|
||||
|
||||
return OutputTransform(
|
||||
targets=targets,
|
||||
detection_transform_steps=[
|
||||
detection_transform_registry.build(transform_config)
|
||||
for transform_config in config.detection_transforms
|
||||
],
|
||||
clip_transform_steps=[
|
||||
clip_transform_registry.build(transform_config)
|
||||
for transform_config in config.clip_transforms
|
||||
],
|
||||
)
|
||||
31
src/batdetect2/outputs/transforms/clip_transforms.py
Normal file
31
src/batdetect2/outputs/transforms/clip_transforms.py
Normal file
@ -0,0 +1,31 @@
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.outputs.types import ClipDetectionsTransform
|
||||
|
||||
__all__ = [
|
||||
"ClipDetectionsTransformConfig",
|
||||
"clip_transforms",
|
||||
]
|
||||
|
||||
|
||||
clip_transforms: Registry[ClipDetectionsTransform, []] = Registry(
|
||||
"clip_detection_transform"
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(clip_transforms)
|
||||
class ClipDetectionsTransformImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
ClipDetectionsTransformConfig = Annotated[
|
||||
ClipDetectionsTransformImportConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
@ -1,33 +1,28 @@
|
||||
"""Decodes extracted detection data into standard soundevent predictions."""
|
||||
"""Decode extracted tensors into output-friendly detection objects."""
|
||||
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetectionsArray,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.postprocess.types import ClipDetectionsArray, Detection
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"to_raw_predictions",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"convert_raw_prediction_to_sound_event_prediction",
|
||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||
"convert_raw_prediction_to_sound_event_prediction",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"get_class_tags",
|
||||
"get_generic_tags",
|
||||
"get_prediction_features",
|
||||
"to_detections",
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
||||
"""Default threshold applied to classification scores.
|
||||
|
||||
Class predictions with scores below this value are typically ignored during
|
||||
decoding.
|
||||
"""
|
||||
|
||||
|
||||
def to_raw_predictions(
|
||||
def to_detections(
|
||||
detections: ClipDetectionsArray,
|
||||
targets: TargetProtocol,
|
||||
) -> List[Detection]:
|
||||
@ -69,7 +64,6 @@ def convert_raw_predictions_to_clip_prediction(
|
||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
) -> data.ClipPrediction:
|
||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
|
||||
return data.ClipPrediction(
|
||||
clip=clip,
|
||||
sound_events=[
|
||||
@ -92,7 +86,6 @@ def convert_raw_prediction_to_sound_event_prediction(
|
||||
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
top_class_only: bool = False,
|
||||
):
|
||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
||||
sound_event = data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=raw_prediction.geometry,
|
||||
@ -123,7 +116,6 @@ def get_generic_tags(
|
||||
detection_score: float,
|
||||
generic_class_tags: List[data.Tag],
|
||||
) -> List[data.PredictedTag]:
|
||||
"""Create PredictedTag objects for the generic category."""
|
||||
return [
|
||||
data.PredictedTag(tag=tag, score=detection_score)
|
||||
for tag in generic_class_tags
|
||||
@ -131,7 +123,6 @@ def get_generic_tags(
|
||||
|
||||
|
||||
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||
"""Convert an extracted feature vector DataArray into soundevent Features."""
|
||||
return [
|
||||
data.Feature(
|
||||
term=data.Term(
|
||||
@ -151,39 +142,11 @@ def get_class_tags(
|
||||
top_class_only: bool = False,
|
||||
threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
) -> List[data.PredictedTag]:
|
||||
"""Generate specific PredictedTags based on class scores and decoder.
|
||||
|
||||
Filters class scores by the threshold, sorts remaining scores descending,
|
||||
decodes the class name(s) into base tags using the `sound_event_decoder`,
|
||||
and creates `PredictedTag` objects associating the class score. Stops after
|
||||
the first (top) class if `top_class_only` is True.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
class_scores : xr.DataArray
|
||||
A 1D xarray DataArray containing class probabilities/scores, indexed
|
||||
by a 'category' coordinate holding the class names.
|
||||
sound_event_decoder : SoundEventDecoder
|
||||
Function to map a class name string to a list of base `data.Tag`
|
||||
objects.
|
||||
top_class_only : bool, default=False
|
||||
If True, only generate tags for the single highest-scoring class above
|
||||
the threshold.
|
||||
threshold : float, optional
|
||||
Minimum score for a class to be considered. If None, all classes are
|
||||
processed (or top-1 if `top_class_only` is True). Defaults to
|
||||
`DEFAULT_CLASSIFICATION_THRESHOLD`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
List[data.PredictedTag]
|
||||
A list of `PredictedTag` objects for the class(es) that passed the
|
||||
threshold, ordered by score if `top_class_only` is False.
|
||||
"""
|
||||
tags = []
|
||||
|
||||
for class_name, score in _iterate_sorted(
|
||||
class_scores, targets.class_names
|
||||
class_scores,
|
||||
targets.class_names,
|
||||
):
|
||||
if threshold is not None and score < threshold:
|
||||
continue
|
||||
55
src/batdetect2/outputs/transforms/detection_transforms.py
Normal file
55
src/batdetect2/outputs/transforms/detection_transforms.py
Normal file
@ -0,0 +1,55 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import replace
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent.geometry import shift_geometry
|
||||
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.outputs.types import DetectionTransform
|
||||
from batdetect2.postprocess.types import Detection
|
||||
|
||||
__all__ = [
|
||||
"DetectionTransformConfig",
|
||||
"detection_transforms",
|
||||
"shift_detection_time",
|
||||
"shift_detections_to_start_time",
|
||||
]
|
||||
|
||||
|
||||
detection_transforms: Registry[DetectionTransform, []] = Registry(
|
||||
"detection_transform"
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(detection_transforms)
|
||||
class DetectionTransformImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
DetectionTransformConfig = Annotated[
|
||||
DetectionTransformImportConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def shift_detection_time(detection: Detection, time: float) -> Detection:
|
||||
geometry = shift_geometry(detection.geometry, time=time)
|
||||
return replace(detection, geometry=geometry)
|
||||
|
||||
|
||||
def shift_detections_to_start_time(
|
||||
detections: Sequence[Detection],
|
||||
start_time: float = 0,
|
||||
) -> list[Detection]:
|
||||
if start_time == 0:
|
||||
return list(detections)
|
||||
|
||||
return [
|
||||
shift_detection_time(detection, time=start_time)
|
||||
for detection in detections
|
||||
]
|
||||
@ -1,12 +1,20 @@
|
||||
from collections.abc import Sequence
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Generic, Protocol, TypeVar
|
||||
|
||||
from soundevent import data
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetections,
|
||||
ClipDetectionsTensor,
|
||||
Detection,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ClipDetectionsTransform",
|
||||
"DetectionTransform",
|
||||
"OutputFormatterProtocol",
|
||||
"OutputTransformProtocol",
|
||||
]
|
||||
|
||||
T = TypeVar("T")
|
||||
@ -23,3 +31,31 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
|
||||
) -> None: ...
|
||||
|
||||
def load(self, path: PathLike) -> list[T]: ...
|
||||
|
||||
|
||||
DetectionTransform = Callable[[Detection], Detection | None]
|
||||
ClipDetectionsTransform = Callable[[ClipDetections], ClipDetections]
|
||||
|
||||
|
||||
class OutputTransformProtocol(Protocol):
|
||||
def to_detections(
|
||||
self,
|
||||
detections: ClipDetectionsTensor,
|
||||
start_time: float = 0,
|
||||
) -> list[Detection]: ...
|
||||
|
||||
def to_clip_detections(
|
||||
self,
|
||||
detections: ClipDetectionsTensor,
|
||||
clip: data.Clip,
|
||||
) -> ClipDetections: ...
|
||||
|
||||
def transform_detections(
|
||||
self,
|
||||
detections: Sequence[Detection],
|
||||
) -> list[Detection]: ...
|
||||
|
||||
def transform_clip_detections(
|
||||
self,
|
||||
prediction: ClipDetections,
|
||||
) -> ClipDetections: ...
|
||||
|
||||
@ -4,10 +4,6 @@ from batdetect2.postprocess.config import (
|
||||
PostprocessConfig,
|
||||
load_postprocess_config,
|
||||
)
|
||||
from batdetect2.postprocess.decoding import (
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
to_raw_predictions,
|
||||
)
|
||||
from batdetect2.postprocess.nms import non_max_suppression
|
||||
from batdetect2.postprocess.postprocessor import (
|
||||
Postprocessor,
|
||||
@ -18,8 +14,6 @@ __all__ = [
|
||||
"PostprocessConfig",
|
||||
"Postprocessor",
|
||||
"build_postprocessor",
|
||||
"convert_raw_predictions_to_clip_prediction",
|
||||
"to_raw_predictions",
|
||||
"load_postprocess_config",
|
||||
"non_max_suppression",
|
||||
]
|
||||
|
||||
@ -2,7 +2,6 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
|
||||
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
|
||||
|
||||
__all__ = [
|
||||
@ -11,6 +10,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
||||
|
||||
|
||||
TOP_K_PER_SEC = 100
|
||||
|
||||
@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
@ -25,7 +24,7 @@ class ValidationMetrics(Callback):
|
||||
super().__init__()
|
||||
|
||||
self.evaluator = evaluator
|
||||
self.output_transform = output_transform or build_output_transform()
|
||||
self.output_transform = output_transform
|
||||
|
||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||
self._predictions: List[ClipDetections] = []
|
||||
@ -92,6 +91,14 @@ class ValidationMetrics(Callback):
|
||||
dataloader_idx: int = 0,
|
||||
) -> None:
|
||||
model = pl_module.model
|
||||
if self.output_transform is None:
|
||||
self.output_transform = build_output_transform(
|
||||
targets=model.targets
|
||||
)
|
||||
|
||||
output_transform = self.output_transform
|
||||
assert output_transform is not None
|
||||
|
||||
dataset = self.get_dataset(trainer)
|
||||
|
||||
clip_annotations = [
|
||||
@ -101,17 +108,14 @@ class ValidationMetrics(Callback):
|
||||
|
||||
clip_detections = model.postprocessor(outputs)
|
||||
predictions = [
|
||||
ClipDetections(
|
||||
output_transform.to_clip_detections(
|
||||
detections=clip_dets,
|
||||
clip=clip_annotation.clip,
|
||||
detections=to_raw_predictions(
|
||||
clip_dets.numpy(), targets=model.targets
|
||||
),
|
||||
)
|
||||
for clip_annotation, clip_dets in zip(
|
||||
clip_annotations, clip_detections, strict=False
|
||||
)
|
||||
]
|
||||
predictions = self.output_transform(predictions)
|
||||
|
||||
self._clip_annotations.extend(clip_annotations)
|
||||
self._predictions.extend(predictions)
|
||||
|
||||
@ -1,12 +1,35 @@
|
||||
from dataclasses import replace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.outputs import build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
||||
from batdetect2.outputs.transforms import OutputTransform
|
||||
from batdetect2.postprocess.types import (
|
||||
ClipDetections,
|
||||
ClipDetectionsTensor,
|
||||
Detection,
|
||||
)
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
|
||||
def test_shift_time_to_clip_start(clip: data.Clip):
|
||||
def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
||||
return ClipDetectionsTensor(
|
||||
scores=torch.tensor([0.9], dtype=torch.float32),
|
||||
sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32),
|
||||
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
|
||||
times=torch.tensor([0.2], dtype=torch.float32),
|
||||
frequencies=torch.tensor([60_000.0], dtype=torch.float32),
|
||||
features=torch.tensor([[1.0, 2.0]], dtype=torch.float32),
|
||||
)
|
||||
|
||||
|
||||
def test_shift_time_to_clip_start(
|
||||
clip: data.Clip,
|
||||
sample_targets: TargetProtocol,
|
||||
):
|
||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||
|
||||
detection = Detection(
|
||||
@ -16,7 +39,7 @@ def test_shift_time_to_clip_start(clip: data.Clip):
|
||||
features=np.array([1.0, 2.0]),
|
||||
)
|
||||
|
||||
transformed = build_output_transform()(
|
||||
transformed = OutputTransform(targets=sample_targets)(
|
||||
[ClipDetections(clip=clip, detections=[detection])]
|
||||
)[0]
|
||||
|
||||
@ -28,21 +51,71 @@ def test_shift_time_to_clip_start(clip: data.Clip):
|
||||
assert np.isclose(end_time, 2.7)
|
||||
|
||||
|
||||
def test_transform_identity_when_disabled(clip: data.Clip):
|
||||
def test_to_clip_detections_shifts_by_clip_start(
|
||||
clip: data.Clip,
|
||||
sample_targets: TargetProtocol,
|
||||
):
|
||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||
transform = build_output_transform(targets=sample_targets)
|
||||
raw = _mock_clip_detections_tensor()
|
||||
|
||||
shifted = transform.to_clip_detections(detections=raw, clip=clip)
|
||||
unshifted = transform.to_detections(detections=raw, start_time=0)
|
||||
|
||||
shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry)
|
||||
unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
|
||||
|
||||
assert np.isclose(shifted_start - unshifted_start, clip.start_time)
|
||||
|
||||
|
||||
def test_detection_and_clip_transforms_applied_in_order(
|
||||
clip: data.Clip,
|
||||
sample_targets: TargetProtocol,
|
||||
):
|
||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||
|
||||
detection = Detection(
|
||||
detection_1 = Detection(
|
||||
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
|
||||
detection_score=0.9,
|
||||
detection_score=0.5,
|
||||
class_scores=np.array([0.9]),
|
||||
features=np.array([1.0, 2.0]),
|
||||
)
|
||||
detection_2 = Detection(
|
||||
geometry=data.BoundingBox(coordinates=[0.2, 10_000, 0.3, 12_000]),
|
||||
detection_score=0.7,
|
||||
class_scores=np.array([0.9]),
|
||||
features=np.array([1.0, 2.0]),
|
||||
)
|
||||
|
||||
transform = build_output_transform(
|
||||
config={"shift_time_to_clip_start": False}
|
||||
def boost_score(detection: Detection) -> Detection:
|
||||
return replace(
|
||||
detection,
|
||||
detection_score=detection.detection_score + 0.2,
|
||||
)
|
||||
|
||||
def keep_high_score(detection: Detection) -> Detection | None:
|
||||
if detection.detection_score < 0.8:
|
||||
return None
|
||||
return detection
|
||||
|
||||
def tag_clip_transform(prediction: ClipDetections) -> ClipDetections:
|
||||
detections = [
|
||||
replace(detection, detection_score=1.0)
|
||||
for detection in prediction.detections
|
||||
]
|
||||
return replace(prediction, detections=detections)
|
||||
|
||||
transform = OutputTransform(
|
||||
targets=sample_targets,
|
||||
detection_transform_steps=[boost_score, keep_high_score],
|
||||
clip_transform_steps=[tag_clip_transform],
|
||||
)
|
||||
transformed = transform(
|
||||
[ClipDetections(clip=clip, detections=[detection])]
|
||||
[ClipDetections(clip=clip, detections=[detection_1, detection_2])]
|
||||
)[0]
|
||||
|
||||
assert transformed.detections[0].geometry == detection.geometry
|
||||
assert len(transformed.detections) == 1
|
||||
assert transformed.detections[0].detection_score == 1.0
|
||||
|
||||
start_time, _, _, _ = compute_bounds(transformed.detections[0].geometry)
|
||||
assert np.isclose(start_time, 2.7)
|
||||
|
||||
@ -6,7 +6,7 @@ import pytest
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.postprocess.decoding import (
|
||||
from batdetect2.outputs.transforms.decoding import (
|
||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||
convert_raw_prediction_to_sound_event_prediction,
|
||||
convert_raw_predictions_to_clip_prediction,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user