mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Moved decoding to outputs
This commit is contained in:
parent
31d4f92359
commit
daff74fdde
@ -26,7 +26,7 @@ from batdetect2.outputs import (
|
|||||||
get_output_formatter,
|
get_output_formatter,
|
||||||
)
|
)
|
||||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.postprocess.types import (
|
from batdetect2.postprocess.types import (
|
||||||
ClipDetections,
|
ClipDetections,
|
||||||
Detection,
|
Detection,
|
||||||
@ -215,13 +215,8 @@ class BatDetect2API:
|
|||||||
detections = self.model.postprocessor(
|
detections = self.model.postprocessor(
|
||||||
outputs,
|
outputs,
|
||||||
)[0]
|
)[0]
|
||||||
raw_predictions = to_raw_predictions(
|
return self.output_transform.to_detections(
|
||||||
detections.numpy(),
|
detections=detections,
|
||||||
targets=self.targets,
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.output_transform.transform_detections(
|
|
||||||
raw_predictions,
|
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -321,7 +316,8 @@ class BatDetect2API:
|
|||||||
config=config.outputs.format,
|
config=config.outputs.format,
|
||||||
)
|
)
|
||||||
output_transform = build_output_transform(
|
output_transform = build_output_transform(
|
||||||
config=config.outputs.transform
|
config=config.outputs.transform,
|
||||||
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
@ -375,7 +371,8 @@ class BatDetect2API:
|
|||||||
config=config.outputs.format,
|
config=config.outputs.format,
|
||||||
)
|
)
|
||||||
output_transform = build_output_transform(
|
output_transform = build_output_transform(
|
||||||
config=config.outputs.transform
|
config=config.outputs.transform,
|
||||||
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
|
|||||||
@ -61,7 +61,10 @@ def run_evaluate(
|
|||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
output_transform = build_output_transform(config=output_config.transform)
|
output_transform = build_output_transform(
|
||||||
|
config=output_config.transform,
|
||||||
|
targets=targets,
|
||||||
|
)
|
||||||
module = EvaluationModule(
|
module = EvaluationModule(
|
||||||
model,
|
model,
|
||||||
evaluator,
|
evaluator,
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol
|
|||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
@ -24,7 +23,9 @@ class EvaluationModule(LightningModule):
|
|||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.output_transform = output_transform or build_output_transform()
|
self.output_transform = output_transform or build_output_transform(
|
||||||
|
targets=evaluator.targets
|
||||||
|
)
|
||||||
|
|
||||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self.predictions: List[ClipDetections] = []
|
self.predictions: List[ClipDetections] = []
|
||||||
@ -39,18 +40,14 @@ class EvaluationModule(LightningModule):
|
|||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
clip_detections = self.model.postprocessor(outputs)
|
clip_detections = self.model.postprocessor(outputs)
|
||||||
predictions = [
|
predictions = [
|
||||||
ClipDetections(
|
self.output_transform.to_clip_detections(
|
||||||
|
detections=clip_dets,
|
||||||
clip=clip_annotation.clip,
|
clip=clip_annotation.clip,
|
||||||
detections=to_raw_predictions(
|
|
||||||
clip_dets.numpy(),
|
|
||||||
targets=self.evaluator.targets,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for clip_annotation, clip_dets in zip(
|
for clip_annotation, clip_dets in zip(
|
||||||
clip_annotations, clip_detections, strict=False
|
clip_annotations, clip_detections, strict=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
predictions = self.output_transform(predictions)
|
|
||||||
|
|
||||||
self.clip_annotations.extend(clip_annotations)
|
self.clip_annotations.extend(clip_annotations)
|
||||||
self.predictions.extend(predictions)
|
self.predictions.extend(predictions)
|
||||||
|
|||||||
@ -44,6 +44,7 @@ def run_batch_inference(
|
|||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
output_transform = output_transform or build_output_transform(
|
output_transform = output_transform or build_output_transform(
|
||||||
config=config.outputs.transform,
|
config=config.outputs.transform,
|
||||||
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
loader = build_inference_loader(
|
loader = build_inference_loader(
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from torch.utils.data import DataLoader
|
|||||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
@ -18,7 +17,9 @@ class InferenceModule(LightningModule):
|
|||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.output_transform = output_transform or build_output_transform()
|
self.output_transform = output_transform or build_output_transform(
|
||||||
|
targets=model.targets
|
||||||
|
)
|
||||||
|
|
||||||
def predict_step(
|
def predict_step(
|
||||||
self,
|
self,
|
||||||
@ -34,19 +35,14 @@ class InferenceModule(LightningModule):
|
|||||||
|
|
||||||
clip_detections = self.model.postprocessor(outputs)
|
clip_detections = self.model.postprocessor(outputs)
|
||||||
|
|
||||||
predictions = [
|
return [
|
||||||
ClipDetections(
|
self.output_transform.to_clip_detections(
|
||||||
|
detections=clip_dets,
|
||||||
clip=clip,
|
clip=clip,
|
||||||
detections=to_raw_predictions(
|
|
||||||
clip_dets.numpy(),
|
|
||||||
targets=self.model.targets,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for clip, clip_dets in zip(clips, clip_detections, strict=False)
|
for clip, clip_dets in zip(clips, clip_detections, strict=True)
|
||||||
]
|
]
|
||||||
|
|
||||||
return self.output_transform(predictions)
|
|
||||||
|
|
||||||
def get_dataset(self) -> InferenceDataset:
|
def get_dataset(self) -> InferenceDataset:
|
||||||
dataloaders = self.trainer.predict_dataloaders
|
dataloaders = self.trainer.predict_dataloaders
|
||||||
assert isinstance(dataloaders, DataLoader)
|
assert isinstance(dataloaders, DataLoader)
|
||||||
|
|||||||
@ -11,9 +11,9 @@ from batdetect2.outputs.formats import (
|
|||||||
)
|
)
|
||||||
from batdetect2.outputs.transforms import (
|
from batdetect2.outputs.transforms import (
|
||||||
OutputTransformConfig,
|
OutputTransformConfig,
|
||||||
OutputTransformProtocol,
|
|
||||||
build_output_transform,
|
build_output_transform,
|
||||||
)
|
)
|
||||||
|
from batdetect2.outputs.types import OutputTransformProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BatDetect2OutputConfig",
|
"BatDetect2OutputConfig",
|
||||||
|
|||||||
@ -1,89 +0,0 @@
|
|||||||
from collections.abc import Sequence
|
|
||||||
from dataclasses import replace
|
|
||||||
from typing import Protocol
|
|
||||||
|
|
||||||
from soundevent.geometry import shift_geometry
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"OutputTransform",
|
|
||||||
"OutputTransformConfig",
|
|
||||||
"OutputTransformProtocol",
|
|
||||||
"build_output_transform",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class OutputTransformConfig(BaseConfig):
|
|
||||||
shift_time_to_clip_start: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
class OutputTransformProtocol(Protocol):
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
predictions: Sequence[ClipDetections],
|
|
||||||
) -> list[ClipDetections]: ...
|
|
||||||
|
|
||||||
def transform_detections(
|
|
||||||
self,
|
|
||||||
detections: Sequence[Detection],
|
|
||||||
start_time: float = 0,
|
|
||||||
) -> list[Detection]: ...
|
|
||||||
|
|
||||||
|
|
||||||
def shift_detection_time(detection: Detection, time: float) -> Detection:
|
|
||||||
geometry = shift_geometry(detection.geometry, time=time)
|
|
||||||
return replace(detection, geometry=geometry)
|
|
||||||
|
|
||||||
|
|
||||||
class OutputTransform(OutputTransformProtocol):
|
|
||||||
def __init__(self, shift_time_to_clip_start: bool = True):
|
|
||||||
self.shift_time_to_clip_start = shift_time_to_clip_start
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
predictions: Sequence[ClipDetections],
|
|
||||||
) -> list[ClipDetections]:
|
|
||||||
return [
|
|
||||||
self.transform_prediction(prediction) for prediction in predictions
|
|
||||||
]
|
|
||||||
|
|
||||||
def transform_prediction(
|
|
||||||
self, prediction: ClipDetections
|
|
||||||
) -> ClipDetections:
|
|
||||||
if not self.shift_time_to_clip_start:
|
|
||||||
return prediction
|
|
||||||
|
|
||||||
detections = self.transform_detections(
|
|
||||||
prediction.detections,
|
|
||||||
start_time=prediction.clip.start_time,
|
|
||||||
)
|
|
||||||
return ClipDetections(clip=prediction.clip, detections=detections)
|
|
||||||
|
|
||||||
def transform_detections(
|
|
||||||
self,
|
|
||||||
detections: Sequence[Detection],
|
|
||||||
start_time: float = 0,
|
|
||||||
) -> list[Detection]:
|
|
||||||
if not self.shift_time_to_clip_start or start_time == 0:
|
|
||||||
return list(detections)
|
|
||||||
|
|
||||||
return [
|
|
||||||
shift_detection_time(detection, time=start_time)
|
|
||||||
for detection in detections
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_output_transform(
|
|
||||||
config: OutputTransformConfig | dict | None = None,
|
|
||||||
) -> OutputTransformProtocol:
|
|
||||||
if config is None:
|
|
||||||
config = OutputTransformConfig()
|
|
||||||
|
|
||||||
if not isinstance(config, OutputTransformConfig):
|
|
||||||
config = OutputTransformConfig.model_validate(config)
|
|
||||||
|
|
||||||
return OutputTransform(
|
|
||||||
shift_time_to_clip_start=config.shift_time_to_clip_start,
|
|
||||||
)
|
|
||||||
173
src/batdetect2/outputs/transforms/__init__.py
Normal file
173
src/batdetect2/outputs/transforms/__init__.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.outputs.transforms.clip_transforms import (
|
||||||
|
ClipDetectionsTransformConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.outputs.transforms.clip_transforms import (
|
||||||
|
clip_transforms as clip_transform_registry,
|
||||||
|
)
|
||||||
|
from batdetect2.outputs.transforms.decoding import to_detections
|
||||||
|
from batdetect2.outputs.transforms.detection_transforms import (
|
||||||
|
DetectionTransformConfig,
|
||||||
|
shift_detections_to_start_time,
|
||||||
|
)
|
||||||
|
from batdetect2.outputs.transforms.detection_transforms import (
|
||||||
|
detection_transforms as detection_transform_registry,
|
||||||
|
)
|
||||||
|
from batdetect2.outputs.types import (
|
||||||
|
ClipDetectionsTransform,
|
||||||
|
DetectionTransform,
|
||||||
|
OutputTransformProtocol,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess.types import (
|
||||||
|
ClipDetections,
|
||||||
|
ClipDetectionsTensor,
|
||||||
|
Detection,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ClipDetectionsTransformConfig",
|
||||||
|
"DetectionTransformConfig",
|
||||||
|
"OutputTransform",
|
||||||
|
"OutputTransformConfig",
|
||||||
|
"build_output_transform",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class OutputTransformConfig(BaseConfig):
|
||||||
|
detection_transforms: list[DetectionTransformConfig] = Field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
clip_transforms: list[ClipDetectionsTransformConfig] = Field(
|
||||||
|
default_factory=list
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputTransform(OutputTransformProtocol):
|
||||||
|
detection_transform_steps: list[DetectionTransform]
|
||||||
|
clip_transform_steps: list[ClipDetectionsTransform]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
detection_transform_steps: Sequence[DetectionTransform] = (),
|
||||||
|
clip_transform_steps: Sequence[ClipDetectionsTransform] = (),
|
||||||
|
):
|
||||||
|
self.targets = targets
|
||||||
|
self.detection_transform_steps = list(detection_transform_steps)
|
||||||
|
self.clip_transform_steps = list(clip_transform_steps)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
predictions: Sequence[ClipDetections],
|
||||||
|
) -> list[ClipDetections]:
|
||||||
|
return [
|
||||||
|
self._transform_prediction(prediction)
|
||||||
|
for prediction in predictions
|
||||||
|
]
|
||||||
|
|
||||||
|
def _transform_prediction(
|
||||||
|
self,
|
||||||
|
prediction: ClipDetections,
|
||||||
|
) -> ClipDetections:
|
||||||
|
detections = shift_detections_to_start_time(
|
||||||
|
prediction.detections,
|
||||||
|
start_time=prediction.clip.start_time,
|
||||||
|
)
|
||||||
|
detections = self.transform_detections(detections)
|
||||||
|
return self.transform_clip_detections(
|
||||||
|
ClipDetections(clip=prediction.clip, detections=detections)
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_detections(
|
||||||
|
self,
|
||||||
|
detections: ClipDetectionsTensor,
|
||||||
|
start_time: float = 0,
|
||||||
|
) -> list[Detection]:
|
||||||
|
decoded = to_detections(detections.numpy(), targets=self.targets)
|
||||||
|
shifted = shift_detections_to_start_time(
|
||||||
|
decoded,
|
||||||
|
start_time=start_time,
|
||||||
|
)
|
||||||
|
return self.transform_detections(shifted)
|
||||||
|
|
||||||
|
def to_clip_detections(
|
||||||
|
self,
|
||||||
|
detections: ClipDetectionsTensor,
|
||||||
|
clip: data.Clip,
|
||||||
|
) -> ClipDetections:
|
||||||
|
prediction = ClipDetections(
|
||||||
|
clip=clip,
|
||||||
|
detections=self.to_detections(
|
||||||
|
detections,
|
||||||
|
start_time=clip.start_time,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
return self.transform_clip_detections(prediction)
|
||||||
|
|
||||||
|
def transform_detections(
|
||||||
|
self,
|
||||||
|
detections: Sequence[Detection],
|
||||||
|
) -> list[Detection]:
|
||||||
|
out: list[Detection] = []
|
||||||
|
for detection in detections:
|
||||||
|
transformed = self.transform_detection(detection)
|
||||||
|
|
||||||
|
if transformed is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
out.append(transformed)
|
||||||
|
|
||||||
|
return []
|
||||||
|
|
||||||
|
def transform_detection(
|
||||||
|
self,
|
||||||
|
detection: Detection,
|
||||||
|
) -> Detection | None:
|
||||||
|
for transform in self.detection_transform_steps:
|
||||||
|
detection = transform(detection) # type: ignore
|
||||||
|
|
||||||
|
if detection is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return detection
|
||||||
|
|
||||||
|
def transform_clip_detections(
|
||||||
|
self,
|
||||||
|
prediction: ClipDetections,
|
||||||
|
) -> ClipDetections:
|
||||||
|
for transform in self.clip_transform_steps:
|
||||||
|
prediction = transform(prediction)
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
|
||||||
|
def build_output_transform(
|
||||||
|
config: OutputTransformConfig | dict | None = None,
|
||||||
|
targets: TargetProtocol | None = None,
|
||||||
|
) -> OutputTransformProtocol:
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
|
|
||||||
|
if config is None:
|
||||||
|
config = OutputTransformConfig()
|
||||||
|
|
||||||
|
if not isinstance(config, OutputTransformConfig):
|
||||||
|
config = OutputTransformConfig.model_validate(config)
|
||||||
|
|
||||||
|
targets = targets or build_targets()
|
||||||
|
|
||||||
|
return OutputTransform(
|
||||||
|
targets=targets,
|
||||||
|
detection_transform_steps=[
|
||||||
|
detection_transform_registry.build(transform_config)
|
||||||
|
for transform_config in config.detection_transforms
|
||||||
|
],
|
||||||
|
clip_transform_steps=[
|
||||||
|
clip_transform_registry.build(transform_config)
|
||||||
|
for transform_config in config.clip_transforms
|
||||||
|
],
|
||||||
|
)
|
||||||
31
src/batdetect2/outputs/transforms/clip_transforms.py
Normal file
31
src/batdetect2/outputs/transforms/clip_transforms.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.core.registries import (
|
||||||
|
ImportConfig,
|
||||||
|
Registry,
|
||||||
|
add_import_config,
|
||||||
|
)
|
||||||
|
from batdetect2.outputs.types import ClipDetectionsTransform
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ClipDetectionsTransformConfig",
|
||||||
|
"clip_transforms",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
clip_transforms: Registry[ClipDetectionsTransform, []] = Registry(
|
||||||
|
"clip_detection_transform"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_import_config(clip_transforms)
|
||||||
|
class ClipDetectionsTransformImportConfig(ImportConfig):
|
||||||
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
|
ClipDetectionsTransformConfig = Annotated[
|
||||||
|
ClipDetectionsTransformImportConfig,
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
@ -1,33 +1,28 @@
|
|||||||
"""Decodes extracted detection data into standard soundevent predictions."""
|
"""Decode extracted tensors into output-friendly detection objects."""
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.types import (
|
from batdetect2.postprocess.types import ClipDetectionsArray, Detection
|
||||||
ClipDetectionsArray,
|
|
||||||
Detection,
|
|
||||||
)
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"to_raw_predictions",
|
|
||||||
"convert_raw_predictions_to_clip_prediction",
|
|
||||||
"convert_raw_prediction_to_sound_event_prediction",
|
|
||||||
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
"DEFAULT_CLASSIFICATION_THRESHOLD",
|
||||||
|
"convert_raw_prediction_to_sound_event_prediction",
|
||||||
|
"convert_raw_predictions_to_clip_prediction",
|
||||||
|
"get_class_tags",
|
||||||
|
"get_generic_tags",
|
||||||
|
"get_prediction_features",
|
||||||
|
"to_detections",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
||||||
"""Default threshold applied to classification scores.
|
|
||||||
|
|
||||||
Class predictions with scores below this value are typically ignored during
|
|
||||||
decoding.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def to_raw_predictions(
|
def to_detections(
|
||||||
detections: ClipDetectionsArray,
|
detections: ClipDetectionsArray,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> List[Detection]:
|
) -> List[Detection]:
|
||||||
@ -69,7 +64,6 @@ def convert_raw_predictions_to_clip_prediction(
|
|||||||
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
) -> data.ClipPrediction:
|
) -> data.ClipPrediction:
|
||||||
"""Convert a list of RawPredictions into a soundevent ClipPrediction."""
|
|
||||||
return data.ClipPrediction(
|
return data.ClipPrediction(
|
||||||
clip=clip,
|
clip=clip,
|
||||||
sound_events=[
|
sound_events=[
|
||||||
@ -92,7 +86,6 @@ def convert_raw_prediction_to_sound_event_prediction(
|
|||||||
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
classification_threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
):
|
):
|
||||||
"""Convert a single RawPrediction into a soundevent SoundEventPrediction."""
|
|
||||||
sound_event = data.SoundEvent(
|
sound_event = data.SoundEvent(
|
||||||
recording=recording,
|
recording=recording,
|
||||||
geometry=raw_prediction.geometry,
|
geometry=raw_prediction.geometry,
|
||||||
@ -123,7 +116,6 @@ def get_generic_tags(
|
|||||||
detection_score: float,
|
detection_score: float,
|
||||||
generic_class_tags: List[data.Tag],
|
generic_class_tags: List[data.Tag],
|
||||||
) -> List[data.PredictedTag]:
|
) -> List[data.PredictedTag]:
|
||||||
"""Create PredictedTag objects for the generic category."""
|
|
||||||
return [
|
return [
|
||||||
data.PredictedTag(tag=tag, score=detection_score)
|
data.PredictedTag(tag=tag, score=detection_score)
|
||||||
for tag in generic_class_tags
|
for tag in generic_class_tags
|
||||||
@ -131,7 +123,6 @@ def get_generic_tags(
|
|||||||
|
|
||||||
|
|
||||||
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
def get_prediction_features(features: np.ndarray) -> List[data.Feature]:
|
||||||
"""Convert an extracted feature vector DataArray into soundevent Features."""
|
|
||||||
return [
|
return [
|
||||||
data.Feature(
|
data.Feature(
|
||||||
term=data.Term(
|
term=data.Term(
|
||||||
@ -151,39 +142,11 @@ def get_class_tags(
|
|||||||
top_class_only: bool = False,
|
top_class_only: bool = False,
|
||||||
threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
threshold: float | None = DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
) -> List[data.PredictedTag]:
|
) -> List[data.PredictedTag]:
|
||||||
"""Generate specific PredictedTags based on class scores and decoder.
|
|
||||||
|
|
||||||
Filters class scores by the threshold, sorts remaining scores descending,
|
|
||||||
decodes the class name(s) into base tags using the `sound_event_decoder`,
|
|
||||||
and creates `PredictedTag` objects associating the class score. Stops after
|
|
||||||
the first (top) class if `top_class_only` is True.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
class_scores : xr.DataArray
|
|
||||||
A 1D xarray DataArray containing class probabilities/scores, indexed
|
|
||||||
by a 'category' coordinate holding the class names.
|
|
||||||
sound_event_decoder : SoundEventDecoder
|
|
||||||
Function to map a class name string to a list of base `data.Tag`
|
|
||||||
objects.
|
|
||||||
top_class_only : bool, default=False
|
|
||||||
If True, only generate tags for the single highest-scoring class above
|
|
||||||
the threshold.
|
|
||||||
threshold : float, optional
|
|
||||||
Minimum score for a class to be considered. If None, all classes are
|
|
||||||
processed (or top-1 if `top_class_only` is True). Defaults to
|
|
||||||
`DEFAULT_CLASSIFICATION_THRESHOLD`.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
List[data.PredictedTag]
|
|
||||||
A list of `PredictedTag` objects for the class(es) that passed the
|
|
||||||
threshold, ordered by score if `top_class_only` is False.
|
|
||||||
"""
|
|
||||||
tags = []
|
tags = []
|
||||||
|
|
||||||
for class_name, score in _iterate_sorted(
|
for class_name, score in _iterate_sorted(
|
||||||
class_scores, targets.class_names
|
class_scores,
|
||||||
|
targets.class_names,
|
||||||
):
|
):
|
||||||
if threshold is not None and score < threshold:
|
if threshold is not None and score < threshold:
|
||||||
continue
|
continue
|
||||||
55
src/batdetect2/outputs/transforms/detection_transforms.py
Normal file
55
src/batdetect2/outputs/transforms/detection_transforms.py
Normal file
@ -0,0 +1,55 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import replace
|
||||||
|
from typing import Annotated, Literal
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent.geometry import shift_geometry
|
||||||
|
|
||||||
|
from batdetect2.core.registries import (
|
||||||
|
ImportConfig,
|
||||||
|
Registry,
|
||||||
|
add_import_config,
|
||||||
|
)
|
||||||
|
from batdetect2.outputs.types import DetectionTransform
|
||||||
|
from batdetect2.postprocess.types import Detection
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DetectionTransformConfig",
|
||||||
|
"detection_transforms",
|
||||||
|
"shift_detection_time",
|
||||||
|
"shift_detections_to_start_time",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
detection_transforms: Registry[DetectionTransform, []] = Registry(
|
||||||
|
"detection_transform"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@add_import_config(detection_transforms)
|
||||||
|
class DetectionTransformImportConfig(ImportConfig):
|
||||||
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
|
DetectionTransformConfig = Annotated[
|
||||||
|
DetectionTransformImportConfig,
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def shift_detection_time(detection: Detection, time: float) -> Detection:
|
||||||
|
geometry = shift_geometry(detection.geometry, time=time)
|
||||||
|
return replace(detection, geometry=geometry)
|
||||||
|
|
||||||
|
|
||||||
|
def shift_detections_to_start_time(
|
||||||
|
detections: Sequence[Detection],
|
||||||
|
start_time: float = 0,
|
||||||
|
) -> list[Detection]:
|
||||||
|
if start_time == 0:
|
||||||
|
return list(detections)
|
||||||
|
|
||||||
|
return [
|
||||||
|
shift_detection_time(detection, time=start_time)
|
||||||
|
for detection in detections
|
||||||
|
]
|
||||||
@ -1,12 +1,20 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Callable, Sequence
|
||||||
from typing import Generic, Protocol, TypeVar
|
from typing import Generic, Protocol, TypeVar
|
||||||
|
|
||||||
|
from soundevent import data
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import (
|
||||||
|
ClipDetections,
|
||||||
|
ClipDetectionsTensor,
|
||||||
|
Detection,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"ClipDetectionsTransform",
|
||||||
|
"DetectionTransform",
|
||||||
"OutputFormatterProtocol",
|
"OutputFormatterProtocol",
|
||||||
|
"OutputTransformProtocol",
|
||||||
]
|
]
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
@ -23,3 +31,31 @@ class OutputFormatterProtocol(Protocol, Generic[T]):
|
|||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
def load(self, path: PathLike) -> list[T]: ...
|
def load(self, path: PathLike) -> list[T]: ...
|
||||||
|
|
||||||
|
|
||||||
|
DetectionTransform = Callable[[Detection], Detection | None]
|
||||||
|
ClipDetectionsTransform = Callable[[ClipDetections], ClipDetections]
|
||||||
|
|
||||||
|
|
||||||
|
class OutputTransformProtocol(Protocol):
|
||||||
|
def to_detections(
|
||||||
|
self,
|
||||||
|
detections: ClipDetectionsTensor,
|
||||||
|
start_time: float = 0,
|
||||||
|
) -> list[Detection]: ...
|
||||||
|
|
||||||
|
def to_clip_detections(
|
||||||
|
self,
|
||||||
|
detections: ClipDetectionsTensor,
|
||||||
|
clip: data.Clip,
|
||||||
|
) -> ClipDetections: ...
|
||||||
|
|
||||||
|
def transform_detections(
|
||||||
|
self,
|
||||||
|
detections: Sequence[Detection],
|
||||||
|
) -> list[Detection]: ...
|
||||||
|
|
||||||
|
def transform_clip_detections(
|
||||||
|
self,
|
||||||
|
prediction: ClipDetections,
|
||||||
|
) -> ClipDetections: ...
|
||||||
|
|||||||
@ -4,10 +4,6 @@ from batdetect2.postprocess.config import (
|
|||||||
PostprocessConfig,
|
PostprocessConfig,
|
||||||
load_postprocess_config,
|
load_postprocess_config,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess.decoding import (
|
|
||||||
convert_raw_predictions_to_clip_prediction,
|
|
||||||
to_raw_predictions,
|
|
||||||
)
|
|
||||||
from batdetect2.postprocess.nms import non_max_suppression
|
from batdetect2.postprocess.nms import non_max_suppression
|
||||||
from batdetect2.postprocess.postprocessor import (
|
from batdetect2.postprocess.postprocessor import (
|
||||||
Postprocessor,
|
Postprocessor,
|
||||||
@ -18,8 +14,6 @@ __all__ = [
|
|||||||
"PostprocessConfig",
|
"PostprocessConfig",
|
||||||
"Postprocessor",
|
"Postprocessor",
|
||||||
"build_postprocessor",
|
"build_postprocessor",
|
||||||
"convert_raw_predictions_to_clip_prediction",
|
|
||||||
"to_raw_predictions",
|
|
||||||
"load_postprocess_config",
|
"load_postprocess_config",
|
||||||
"non_max_suppression",
|
"non_max_suppression",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -2,7 +2,6 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.postprocess.decoding import DEFAULT_CLASSIFICATION_THRESHOLD
|
|
||||||
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
|
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -11,6 +10,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_DETECTION_THRESHOLD = 0.01
|
DEFAULT_DETECTION_THRESHOLD = 0.01
|
||||||
|
DEFAULT_CLASSIFICATION_THRESHOLD = 0.1
|
||||||
|
|
||||||
|
|
||||||
TOP_K_PER_SEC = 100
|
TOP_K_PER_SEC = 100
|
||||||
|
|||||||
@ -9,7 +9,6 @@ from batdetect2.evaluate.types import EvaluatorProtocol
|
|||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
@ -25,7 +24,7 @@ class ValidationMetrics(Callback):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.output_transform = output_transform or build_output_transform()
|
self.output_transform = output_transform
|
||||||
|
|
||||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self._predictions: List[ClipDetections] = []
|
self._predictions: List[ClipDetections] = []
|
||||||
@ -92,6 +91,14 @@ class ValidationMetrics(Callback):
|
|||||||
dataloader_idx: int = 0,
|
dataloader_idx: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
model = pl_module.model
|
model = pl_module.model
|
||||||
|
if self.output_transform is None:
|
||||||
|
self.output_transform = build_output_transform(
|
||||||
|
targets=model.targets
|
||||||
|
)
|
||||||
|
|
||||||
|
output_transform = self.output_transform
|
||||||
|
assert output_transform is not None
|
||||||
|
|
||||||
dataset = self.get_dataset(trainer)
|
dataset = self.get_dataset(trainer)
|
||||||
|
|
||||||
clip_annotations = [
|
clip_annotations = [
|
||||||
@ -101,17 +108,14 @@ class ValidationMetrics(Callback):
|
|||||||
|
|
||||||
clip_detections = model.postprocessor(outputs)
|
clip_detections = model.postprocessor(outputs)
|
||||||
predictions = [
|
predictions = [
|
||||||
ClipDetections(
|
output_transform.to_clip_detections(
|
||||||
|
detections=clip_dets,
|
||||||
clip=clip_annotation.clip,
|
clip=clip_annotation.clip,
|
||||||
detections=to_raw_predictions(
|
|
||||||
clip_dets.numpy(), targets=model.targets
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
for clip_annotation, clip_dets in zip(
|
for clip_annotation, clip_dets in zip(
|
||||||
clip_annotations, clip_detections, strict=False
|
clip_annotations, clip_detections, strict=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
predictions = self.output_transform(predictions)
|
|
||||||
|
|
||||||
self._clip_annotations.extend(clip_annotations)
|
self._clip_annotations.extend(clip_annotations)
|
||||||
self._predictions.extend(predictions)
|
self._predictions.extend(predictions)
|
||||||
|
|||||||
@ -1,12 +1,35 @@
|
|||||||
|
from dataclasses import replace
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.outputs import build_output_transform
|
from batdetect2.outputs import build_output_transform
|
||||||
from batdetect2.postprocess.types import ClipDetections, Detection
|
from batdetect2.outputs.transforms import OutputTransform
|
||||||
|
from batdetect2.postprocess.types import (
|
||||||
|
ClipDetections,
|
||||||
|
ClipDetectionsTensor,
|
||||||
|
Detection,
|
||||||
|
)
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
def test_shift_time_to_clip_start(clip: data.Clip):
|
def _mock_clip_detections_tensor() -> ClipDetectionsTensor:
|
||||||
|
return ClipDetectionsTensor(
|
||||||
|
scores=torch.tensor([0.9], dtype=torch.float32),
|
||||||
|
sizes=torch.tensor([[0.1, 1_000.0]], dtype=torch.float32),
|
||||||
|
class_scores=torch.tensor([[0.8, 0.2]], dtype=torch.float32),
|
||||||
|
times=torch.tensor([0.2], dtype=torch.float32),
|
||||||
|
frequencies=torch.tensor([60_000.0], dtype=torch.float32),
|
||||||
|
features=torch.tensor([[1.0, 2.0]], dtype=torch.float32),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_shift_time_to_clip_start(
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
):
|
||||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||||
|
|
||||||
detection = Detection(
|
detection = Detection(
|
||||||
@ -16,7 +39,7 @@ def test_shift_time_to_clip_start(clip: data.Clip):
|
|||||||
features=np.array([1.0, 2.0]),
|
features=np.array([1.0, 2.0]),
|
||||||
)
|
)
|
||||||
|
|
||||||
transformed = build_output_transform()(
|
transformed = OutputTransform(targets=sample_targets)(
|
||||||
[ClipDetections(clip=clip, detections=[detection])]
|
[ClipDetections(clip=clip, detections=[detection])]
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
@ -28,21 +51,71 @@ def test_shift_time_to_clip_start(clip: data.Clip):
|
|||||||
assert np.isclose(end_time, 2.7)
|
assert np.isclose(end_time, 2.7)
|
||||||
|
|
||||||
|
|
||||||
def test_transform_identity_when_disabled(clip: data.Clip):
|
def test_to_clip_detections_shifts_by_clip_start(
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
):
|
||||||
|
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||||
|
transform = build_output_transform(targets=sample_targets)
|
||||||
|
raw = _mock_clip_detections_tensor()
|
||||||
|
|
||||||
|
shifted = transform.to_clip_detections(detections=raw, clip=clip)
|
||||||
|
unshifted = transform.to_detections(detections=raw, start_time=0)
|
||||||
|
|
||||||
|
shifted_start, _, _, _ = compute_bounds(shifted.detections[0].geometry)
|
||||||
|
unshifted_start, _, _, _ = compute_bounds(unshifted[0].geometry)
|
||||||
|
|
||||||
|
assert np.isclose(shifted_start - unshifted_start, clip.start_time)
|
||||||
|
|
||||||
|
|
||||||
|
def test_detection_and_clip_transforms_applied_in_order(
|
||||||
|
clip: data.Clip,
|
||||||
|
sample_targets: TargetProtocol,
|
||||||
|
):
|
||||||
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||||
|
|
||||||
detection = Detection(
|
detection_1 = Detection(
|
||||||
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
|
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
|
||||||
detection_score=0.9,
|
detection_score=0.5,
|
||||||
|
class_scores=np.array([0.9]),
|
||||||
|
features=np.array([1.0, 2.0]),
|
||||||
|
)
|
||||||
|
detection_2 = Detection(
|
||||||
|
geometry=data.BoundingBox(coordinates=[0.2, 10_000, 0.3, 12_000]),
|
||||||
|
detection_score=0.7,
|
||||||
class_scores=np.array([0.9]),
|
class_scores=np.array([0.9]),
|
||||||
features=np.array([1.0, 2.0]),
|
features=np.array([1.0, 2.0]),
|
||||||
)
|
)
|
||||||
|
|
||||||
transform = build_output_transform(
|
def boost_score(detection: Detection) -> Detection:
|
||||||
config={"shift_time_to_clip_start": False}
|
return replace(
|
||||||
|
detection,
|
||||||
|
detection_score=detection.detection_score + 0.2,
|
||||||
|
)
|
||||||
|
|
||||||
|
def keep_high_score(detection: Detection) -> Detection | None:
|
||||||
|
if detection.detection_score < 0.8:
|
||||||
|
return None
|
||||||
|
return detection
|
||||||
|
|
||||||
|
def tag_clip_transform(prediction: ClipDetections) -> ClipDetections:
|
||||||
|
detections = [
|
||||||
|
replace(detection, detection_score=1.0)
|
||||||
|
for detection in prediction.detections
|
||||||
|
]
|
||||||
|
return replace(prediction, detections=detections)
|
||||||
|
|
||||||
|
transform = OutputTransform(
|
||||||
|
targets=sample_targets,
|
||||||
|
detection_transform_steps=[boost_score, keep_high_score],
|
||||||
|
clip_transform_steps=[tag_clip_transform],
|
||||||
)
|
)
|
||||||
transformed = transform(
|
transformed = transform(
|
||||||
[ClipDetections(clip=clip, detections=[detection])]
|
[ClipDetections(clip=clip, detections=[detection_1, detection_2])]
|
||||||
)[0]
|
)[0]
|
||||||
|
|
||||||
assert transformed.detections[0].geometry == detection.geometry
|
assert len(transformed.detections) == 1
|
||||||
|
assert transformed.detections[0].detection_score == 1.0
|
||||||
|
|
||||||
|
start_time, _, _, _ = compute_bounds(transformed.detections[0].geometry)
|
||||||
|
assert np.isclose(start_time, 2.7)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import pytest
|
|||||||
import xarray as xr
|
import xarray as xr
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.postprocess.decoding import (
|
from batdetect2.outputs.transforms.decoding import (
|
||||||
DEFAULT_CLASSIFICATION_THRESHOLD,
|
DEFAULT_CLASSIFICATION_THRESHOLD,
|
||||||
convert_raw_prediction_to_sound_event_prediction,
|
convert_raw_prediction_to_sound_event_prediction,
|
||||||
convert_raw_predictions_to_clip_prediction,
|
convert_raw_predictions_to_clip_prediction,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user