Add outputs module

mbsantiago 2026-03-17 23:00:26 +00:00
parent 3b47c688dd
commit c226dc3f2b
23 changed files with 322 additions and 149 deletions

View File

@ -102,6 +102,7 @@ exclude = [
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
"src/batdetect2/utils",
]
[tool.ruff.format]
@ -121,4 +122,5 @@ exclude = [
"src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune",
"src/batdetect2/utils",
]

View File

@ -11,17 +11,20 @@ from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import (
OutputFormatConfig,
build_output_formatter,
get_output_formatter,
load_dataset_from_config,
)
from batdetect2.data.datasets import Dataset
from batdetect2.data.predictions.base import OutputFormatterProtocol
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
from batdetect2.outputs import (
OutputFormatConfig,
OutputTransformProtocol,
build_output_formatter,
build_output_transform,
get_output_formatter,
)
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
@ -35,6 +38,7 @@ from batdetect2.typing import (
ClipDetections,
Detection,
EvaluatorProtocol,
OutputFormatterProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
@ -51,6 +55,7 @@ class BatDetect2API:
postprocessor: PostprocessorProtocol,
evaluator: EvaluatorProtocol,
formatter: OutputFormatterProtocol,
output_transform: OutputTransformProtocol,
model: Model,
):
self.config = config
@ -61,6 +66,7 @@ class BatDetect2API:
self.evaluator = evaluator
self.model = model
self.formatter = formatter
self.output_transform = output_transform
self.model.eval()
@ -208,10 +214,16 @@ class BatDetect2API:
detections = self.model.postprocessor(
outputs,
start_times=[start_time],
)[0]
raw_predictions = to_raw_predictions(
detections.numpy(),
targets=self.targets,
)
return to_raw_predictions(detections.numpy(), targets=self.targets)
return self.output_transform.transform_detections(
raw_predictions,
start_time=start_time,
)
def process_directory(
self,
@ -304,7 +316,13 @@ class BatDetect2API:
# postprocessor as these may be moved to another device.
model = build_model(config=config.model)
formatter = build_output_formatter(targets, config=config.output)
formatter = build_output_formatter(
targets,
config=config.outputs.format,
)
output_transform = build_output_transform(
config=config.outputs.transform
)
return cls(
config=config,
@ -315,6 +333,7 @@ class BatDetect2API:
evaluator=evaluator,
model=model,
formatter=formatter,
output_transform=output_transform,
)
@classmethod
@ -351,7 +370,13 @@ class BatDetect2API:
evaluator = build_evaluator(config=config.evaluation, targets=targets)
formatter = build_output_formatter(targets, config=config.output)
formatter = build_output_formatter(
targets,
config=config.outputs.format,
)
output_transform = build_output_transform(
config=config.outputs.transform
)
return cls(
config=config,
@ -362,4 +387,5 @@ class BatDetect2API:
evaluator=evaluator,
model=model,
formatter=formatter,
output_transform=output_transform,
)
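
Taken together, the api.py changes split output handling in two: the formatter still controls serialization, while the new output_transform owns coordinate fix-ups such as shifting detection times from clip-relative to recording-relative. A minimal sketch of the updated single-clip flow, assuming a constructed api: BatDetect2API and a preprocessed spectrogram tensor spec for a clip starting at start_time seconds; only the attribute and function names shown in this diff are real, the surrounding setup is illustrative:

    outputs = api.model.detector(spec)
    # Postprocessing now yields clip-relative detections (no start_times arg).
    detections = api.model.postprocessor(outputs)[0]
    raw = to_raw_predictions(detections.numpy(), targets=api.targets)
    # The transform shifts times into the recording frame in one place.
    shifted = api.output_transform.transform_detections(raw, start_time=start_time)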

View File

@ -5,14 +5,13 @@ from soundevent.data import PathLike
from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.predictions import OutputFormatConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.evaluate.config import (
EvaluationConfig,
get_default_eval_config,
)
from batdetect2.inference.config import InferenceConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.train.config import TrainingConfig
__all__ = [
@ -32,7 +31,7 @@ class BatDetect2Config(BaseConfig):
model: ModelConfig = Field(default_factory=ModelConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
def validate_config(config: dict | None) -> BatDetect2Config:

View File

@ -12,15 +12,6 @@ from batdetect2.data.datasets import (
load_dataset_config,
load_dataset_from_config,
)
from batdetect2.data.predictions import (
BatDetect2OutputConfig,
OutputFormatConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
load_predictions,
)
from batdetect2.data.summary import (
compute_class_summary,
extract_recordings_df,
@ -36,6 +27,7 @@ __all__ = [
"BatDetect2OutputConfig",
"DatasetConfig",
"OutputFormatConfig",
"ParquetOutputConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
"build_output_formatter",

View File

@ -10,6 +10,7 @@ from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.outputs import build_output_transform
from batdetect2.typing import Detection
if TYPE_CHECKING:
@ -61,7 +62,12 @@ def evaluate(
experiment_name=experiment_name,
run_name=run_name,
)
module = EvaluationModule(model, evaluator)
output_transform = build_output_transform(config=config.outputs.transform)
module = EvaluationModule(
model,
evaluator,
output_transform=output_transform,
)
trainer = Trainer(logger=logger, enable_checkpointing=False)
metrics = trainer.test(module, loader)

View File

@ -7,6 +7,7 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import ClipDetections
@ -17,11 +18,13 @@ class EvaluationModule(LightningModule):
self,
model: Model,
evaluator: EvaluatorProtocol,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__()
self.model = model
self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform()
self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[ClipDetections] = []
@ -34,10 +37,7 @@ class EvaluationModule(LightningModule):
]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
clip_detections = self.model.postprocessor(outputs)
predictions = [
ClipDetections(
clip=clip_annotation.clip,
@ -50,6 +50,7 @@ class EvaluationModule(LightningModule):
clip_annotations, clip_detections, strict=False
)
]
predictions = self.output_transform(predictions)
self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions)

View File

@ -8,6 +8,7 @@ from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import ClipDetections
@ -28,6 +29,7 @@ def run_batch_inference(
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
output_transform: Optional[OutputTransformProtocol] = None,
num_workers: int | None = None,
batch_size: int | None = None,
) -> List[ClipDetections]:
@ -42,6 +44,9 @@ def run_batch_inference(
)
targets = targets or build_targets()
output_transform = output_transform or build_output_transform(
config=config.outputs.transform,
)
loader = build_inference_loader(
clips,
@ -52,7 +57,10 @@ def run_batch_inference(
batch_size=batch_size,
)
module = InferenceModule(model)
module = InferenceModule(
model,
output_transform=output_transform,
)
trainer = Trainer(enable_checkpointing=False, logger=False)
outputs = trainer.predict(module, loader)
return [

View File

@ -5,14 +5,20 @@ from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import ClipDetections
class InferenceModule(LightningModule):
def __init__(self, model: Model):
def __init__(
self,
model: Model,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__()
self.model = model
self.output_transform = output_transform or build_output_transform()
def predict_step(
self,
@ -26,10 +32,7 @@ class InferenceModule(LightningModule):
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[clip.start_time for clip in clips],
)
clip_detections = self.model.postprocessor(outputs)
predictions = [
ClipDetections(
@ -42,7 +45,7 @@ class InferenceModule(LightningModule):
for clip, clip_dets in zip(clips, clip_detections, strict=False)
]
return predictions
return self.output_transform(predictions)
def get_dataset(self) -> InferenceDataset:
dataloaders = self.trainer.predict_dataloaders

View File

@ -0,0 +1,31 @@
from batdetect2.outputs.config import OutputsConfig
from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
OutputFormatConfig,
ParquetOutputConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
load_predictions,
)
from batdetect2.outputs.transforms import (
OutputTransformConfig,
OutputTransformProtocol,
build_output_transform,
)
__all__ = [
"BatDetect2OutputConfig",
"OutputFormatConfig",
"OutputTransformConfig",
"OutputTransformProtocol",
"OutputsConfig",
"ParquetOutputConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
"build_output_formatter",
"build_output_transform",
"get_output_formatter",
"load_predictions",
]
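
The new package re-exports everything output-related under one namespace. A minimal sketch using only the defaults documented in this commit (RawOutputConfig as the default format, time shifting enabled):

    from batdetect2.outputs import build_output_formatter, build_output_transform

    formatter = build_output_formatter()  # falls back to RawOutputConfig + default targets
    transform = build_output_transform()  # shift_time_to_clip_start=True by default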

View File

@ -0,0 +1,15 @@
from pydantic import Field
from batdetect2.core.configs import BaseConfig
from batdetect2.outputs.formats import OutputFormatConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.transforms import OutputTransformConfig
__all__ = ["OutputsConfig"]
class OutputsConfig(BaseConfig):
format: OutputFormatConfig = Field(default_factory=RawOutputConfig)
transform: OutputTransformConfig = Field(
default_factory=OutputTransformConfig
)
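
This is the object behind the new top-level outputs key in BatDetect2Config (replacing the old output key): format picks a serializer, transform controls post-hoc adjustments. A hedged sketch of validating a dict, assuming the format union discriminates on its name field, as the Literal["batdetect2"]-style fields elsewhere in this commit suggest:

    cfg = OutputsConfig.model_validate(
        {
            # "parquet" is inferred from ParquetOutputConfig's naming pattern.
            "format": {"name": "parquet"},
            "transform": {"shift_time_to_clip_start": False},
        }
    )
    assert cfg.transform.shift_time_to_clip_start is False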

View File

@ -3,23 +3,25 @@ from typing import Annotated
from pydantic import Field
from soundevent.data import PathLike
from batdetect2.data.predictions.base import (
from batdetect2.outputs.formats.base import (
OutputFormatterProtocol,
prediction_formatters,
output_formatters,
)
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
from batdetect2.data.predictions.parquet import ParquetOutputConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
from batdetect2.typing import TargetProtocol
__all__ = [
"build_output_formatter",
"get_output_formatter",
"BatDetect2OutputConfig",
"OutputFormatConfig",
"ParquetOutputConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
"build_output_formatter",
"get_output_formatter",
"load_predictions",
]
@ -42,7 +44,7 @@ def build_output_formatter(
config = config or RawOutputConfig()
targets = targets or build_targets()
return prediction_formatters.build(config, targets)
return output_formatters.build(config, targets)
def get_output_formatter(
@ -56,7 +58,7 @@ def get_output_formatter(
if name is None:
raise ValueError("Either config or name must be provided.")
config_class = prediction_formatters.get_config_type(name)
config_class = output_formatters.get_config_type(name)
config = config_class() # type: ignore
if config.name != name: # type: ignore

View File

@ -9,6 +9,13 @@ from batdetect2.typing import (
TargetProtocol,
)
__all__ = [
"OutputFormatterProtocol",
"PredictionFormatterImportConfig",
"make_path_relative",
"output_formatters",
]
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
path = Path(path)
@ -25,12 +32,12 @@ def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
return path
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
output_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
Registry(name="output_formatter")
)
@add_import_config(prediction_formatters)
@add_import_config(output_formatters)
class PredictionFormatterImportConfig(ImportConfig):
"""Use any callable as a prediction formatter.

View File

@ -7,9 +7,9 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import (
from batdetect2.outputs.formats.base import (
make_path_relative,
prediction_formatters,
output_formatters,
)
from batdetect2.targets import terms
from batdetect2.typing import (
@ -28,76 +28,32 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass):
"""Format of annotations.
This is the format of a single annotation as expected by the
annotation tool.
"""
start_time: float
"""Start time in seconds."""
end_time: float
"""End time in seconds."""
low_freq: float
"""Low frequency in Hz."""
high_freq: float
"""High frequency in Hz."""
class_prob: float
"""Probability of class assignment."""
det_prob: float
"""Probability of detection."""
individual: str
"""Individual ID."""
event: str
"""Type of detected event."""
class FileAnnotation(TypedDict):
"""Format of results.
This is the format of the results expected by the annotation tool.
"""
id: str
"""File ID."""
annotated: bool
"""Whether file has been annotated."""
duration: float
"""Duration of audio file."""
issues: bool
"""Whether file has issues."""
time_exp: float
"""Time expansion factor."""
class_name: str
"""Class predicted at file level."""
notes: str
"""Notes of file."""
annotation: List[Annotation]
"""List of annotations."""
file_path: NotRequired[str] # ty: ignore[invalid-type-form]
"""Path to file."""
class BatDetect2OutputConfig(BaseConfig):
name: Literal["batdetect2"] = "batdetect2"
event_name: str = "Echolocation"
annotation_note: str = "Automatically generated."
@ -156,8 +112,6 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
]
def get_recording_class(self, annotations: List[Annotation]) -> str:
"""Get class of recording from annotations."""
if not annotations:
return ""
@ -215,7 +169,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
**{"class": top_class},
)
@prediction_formatters.register(BatDetect2OutputConfig)
@output_formatters.register(BatDetect2OutputConfig)
@staticmethod
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
return BatDetect2Formatter(

View File

@ -9,9 +9,9 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import (
from batdetect2.outputs.formats.base import (
make_path_relative,
prediction_formatters,
output_formatters,
)
from batdetect2.typing import (
ClipDetections,
@ -59,10 +59,7 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
if not path.parent.exists():
path.parent.mkdir(parents=True)
# Ensure the file has .parquet extension if it's a file path
if path.suffix != ".parquet":
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
if path.is_dir() or not path.suffix:
path = path / "predictions.parquet"
@ -90,7 +87,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
}
if self.include_geometry:
# Store geometry as [start_time, low_freq, end_time, high_freq]
start_time, low_freq, end_time, high_freq = compute_bounds(
pred.geometry
)
@ -98,8 +94,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
row["low_freq"] = low_freq
row["end_time"] = end_time
row["high_freq"] = high_freq
# Store full geometry as JSON
row["geometry"] = pred.geometry.model_dump_json()
if self.include_class_scores:
@ -121,11 +115,9 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
def load(self, path: data.PathLike) -> List[ClipDetections]:
path = Path(path)
if path.is_dir():
# Try to find parquet files
files = list(path.glob("*.parquet"))
if not files:
return []
# Read all and concatenate
dfs = [pd.read_parquet(f) for f in files]
df = pd.concat(dfs, ignore_index=True)
else:
@ -148,7 +140,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
)
predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
# Reconstruct geometry
if "geometry" in row and row["geometry"]:
geometry = data.geometry_validate(row["geometry"])
else:
@ -182,13 +173,14 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
for clip_data in predictions_by_clip.values():
results.append(
ClipDetections(
clip=clip_data["clip"], detections=clip_data["preds"]
clip=clip_data["clip"],
detections=clip_data["preds"],
)
)
return results
@prediction_formatters.register(ParquetOutputConfig)
@output_formatters.register(ParquetOutputConfig)
@staticmethod
def from_config(config: ParquetOutputConfig, targets: TargetProtocol):
return ParquetFormatter(
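
The tightened path logic above makes load mirror the save side: a single .parquet file is read directly, while a directory is scanned for *.parquet files and concatenated (an empty list is returned when none are found). A short sketch, assuming ParquetOutputConfig is constructible with its defaults:

    from batdetect2.outputs.formats import ParquetOutputConfig, build_output_formatter

    formatter = build_output_formatter(config=ParquetOutputConfig())
    clip_detections = formatter.load("predictions/")  # directory or single file path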

View File

@ -10,9 +10,9 @@ from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import (
from batdetect2.outputs.formats.base import (
make_path_relative,
prediction_formatters,
output_formatters,
)
from batdetect2.typing import (
ClipDetections,
@ -95,56 +95,56 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
update=dict(path=make_path_relative(recording.path, audio_dir))
)
data = defaultdict(list)
values = defaultdict(list)
for pred in prediction.detections:
detection_id = str(uuid4())
data["detection_id"].append(detection_id)
data["detection_score"].append(pred.detection_score)
values["detection_id"].append(detection_id)
values["detection_score"].append(pred.detection_score)
start_time, low_freq, end_time, high_freq = compute_bounds(
pred.geometry
)
data["start_time"].append(start_time)
data["end_time"].append(end_time)
data["low_freq"].append(low_freq)
data["high_freq"].append(high_freq)
values["start_time"].append(start_time)
values["end_time"].append(end_time)
values["low_freq"].append(low_freq)
values["high_freq"].append(high_freq)
data["geometry"].append(pred.geometry.model_dump_json())
values["geometry"].append(pred.geometry.model_dump_json())
top_class_index = int(np.argmax(pred.class_scores))
top_class_score = float(pred.class_scores[top_class_index])
top_class = self.targets.class_names[top_class_index]
data["top_class"].append(top_class)
data["top_class_score"].append(top_class_score)
values["top_class"].append(top_class)
values["top_class_score"].append(top_class_score)
data["class_scores"].append(pred.class_scores)
data["features"].append(pred.features)
values["class_scores"].append(pred.class_scores)
values["features"].append(pred.features)
num_features = len(pred.features)
data_vars = {
"score": (["detection"], data["detection_score"]),
"start_time": (["detection"], data["start_time"]),
"end_time": (["detection"], data["end_time"]),
"low_freq": (["detection"], data["low_freq"]),
"high_freq": (["detection"], data["high_freq"]),
"top_class": (["detection"], data["top_class"]),
"top_class_score": (["detection"], data["top_class_score"]),
"score": (["detection"], values["detection_score"]),
"start_time": (["detection"], values["start_time"]),
"end_time": (["detection"], values["end_time"]),
"low_freq": (["detection"], values["low_freq"]),
"high_freq": (["detection"], values["high_freq"]),
"top_class": (["detection"], values["top_class"]),
"top_class_score": (["detection"], values["top_class_score"]),
}
coords = {
"detection": ("detection", data["detection_id"]),
"detection": ("detection", values["detection_id"]),
"clip_start": clip.start_time,
"clip_end": clip.end_time,
"clip_id": str(clip.uuid),
}
if self.include_class_scores:
class_scores = np.stack(data["class_scores"], axis=0)
class_scores = np.stack(values["class_scores"], axis=0)
data_vars["class_scores"] = (
["detection", "classes"],
class_scores,
@ -152,12 +152,12 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
coords["classes"] = ("classes", self.targets.class_names)
if self.include_features:
features = np.stack(data["features"], axis=0)
features = np.stack(values["features"], axis=0)
data_vars["features"] = (["detection", "feature"], features)
coords["feature"] = ("feature", np.arange(num_features))
if self.include_geometry:
data_vars["geometry"] = (["detection"], data["geometry"])
data_vars["geometry"] = (["detection"], values["geometry"])
return xr.Dataset(
data_vars=data_vars,
@ -169,7 +169,6 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
clip_data = dataset
clip_id = clip_data.clip_id.item()
recording = data.Recording.model_validate_json(
clip_data.attrs["recording"]
@ -232,7 +231,7 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
detections=sound_events,
)
@prediction_formatters.register(RawOutputConfig)
@output_formatters.register(RawOutputConfig)
@staticmethod
def from_config(config: RawOutputConfig, targets: TargetProtocol):
return RawFormatter(
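
Since RawFormatter emits an xarray Dataset indexed by a detection dimension, downstream filtering stays declarative. A hedged sketch over a dataset ds produced by pred_to_xr above; the variable names come from the data_vars built in this hunk:

    # Keep only confident detections; `score` is per-detection.
    confident = ds.where(ds.score > 0.5, drop=True)

    # class_scores and features are optional, gated by the include_* flags.
    if "class_scores" in ds:
        top_per_class = ds.class_scores.max("detection")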

View File

@ -5,8 +5,8 @@ import numpy as np
from soundevent import data, io
from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import (
prediction_formatters,
from batdetect2.outputs.formats.base import (
output_formatters,
)
from batdetect2.typing import (
ClipDetections,
@ -121,7 +121,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
return tags
@prediction_formatters.register(SoundEventOutputConfig)
@output_formatters.register(SoundEventOutputConfig)
@staticmethod
def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
return SoundEventOutputFormatter(

View File

@ -0,0 +1,89 @@
from collections.abc import Sequence
from dataclasses import replace
from typing import Protocol
from soundevent.geometry import shift_geometry
from batdetect2.core.configs import BaseConfig
from batdetect2.typing import ClipDetections, Detection
__all__ = [
"OutputTransform",
"OutputTransformConfig",
"OutputTransformProtocol",
"build_output_transform",
]
class OutputTransformConfig(BaseConfig):
shift_time_to_clip_start: bool = True
class OutputTransformProtocol(Protocol):
def __call__(
self,
predictions: Sequence[ClipDetections],
) -> list[ClipDetections]: ...
def transform_detections(
self,
detections: Sequence[Detection],
start_time: float = 0,
) -> list[Detection]: ...
def shift_detection_time(detection: Detection, time: float) -> Detection:
geometry = shift_geometry(detection.geometry, time=time)
return replace(detection, geometry=geometry)
class OutputTransform(OutputTransformProtocol):
def __init__(self, shift_time_to_clip_start: bool = True):
self.shift_time_to_clip_start = shift_time_to_clip_start
def __call__(
self,
predictions: Sequence[ClipDetections],
) -> list[ClipDetections]:
return [
self.transform_prediction(prediction) for prediction in predictions
]
def transform_prediction(
self, prediction: ClipDetections
) -> ClipDetections:
if not self.shift_time_to_clip_start:
return prediction
detections = self.transform_detections(
prediction.detections,
start_time=prediction.clip.start_time,
)
return ClipDetections(clip=prediction.clip, detections=detections)
def transform_detections(
self,
detections: Sequence[Detection],
start_time: float = 0,
) -> list[Detection]:
if not self.shift_time_to_clip_start or start_time == 0:
return list(detections)
return [
shift_detection_time(detection, time=start_time)
for detection in detections
]
def build_output_transform(
config: OutputTransformConfig | dict | None = None,
) -> OutputTransformProtocol:
if config is None:
config = OutputTransformConfig()
if not isinstance(config, OutputTransformConfig):
config = OutputTransformConfig.model_validate(config)
return OutputTransform(
shift_time_to_clip_start=config.shift_time_to_clip_start,
)
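
The transform exposes two entry points: __call__ for batches of ClipDetections (each shifted by its own clip.start_time) and transform_detections for detections already separated from their clip, which is the path api.py uses. A minimal sketch of the latter, reusing the Detection shape from this commit's tests:

    import numpy as np
    from soundevent import data

    from batdetect2.outputs import build_output_transform
    from batdetect2.typing import Detection

    det = Detection(
        geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
        detection_score=0.9,
        class_scores=np.array([0.9]),
        features=np.array([1.0, 2.0]),
    )

    transform = build_output_transform()
    # The clip began 2.5 s into the recording: 0.1-0.2 s becomes 2.6-2.7 s.
    shifted = transform.transform_detections([det], start_time=2.5)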

View File

@ -64,7 +64,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
def forward(
self,
output: ModelOutput,
start_times: list[float] | None = None,
) -> list[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),
@ -83,9 +82,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
threshold=self.detection_threshold,
)
if start_times is None:
start_times = [0 for _ in range(len(detections))]
return [
map_detection_to_clip(
detection,
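
This is the behavioral core of the commit: the postprocessor loses its start_times parameter (PostprocessorProtocol below drops it too), so detections stay clip-relative until an OutputTransform shifts them. A before/after sketch, with outputs and clips as assumed locals:

    # Before: time shifting was buried inside postprocessing.
    detections = model.postprocessor(
        outputs, start_times=[clip.start_time for clip in clips]
    )

    # After: clip-relative postprocessing; the shift happens once, later,
    # e.g. predictions = self.output_transform(predictions) in the modules above.
    detections = model.postprocessor(outputs)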

View File

@ -6,6 +6,7 @@ from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.logging import get_image_logger
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
@ -18,10 +19,15 @@ from batdetect2.typing import (
class ValidationMetrics(Callback):
def __init__(self, evaluator: EvaluatorProtocol):
def __init__(
self,
evaluator: EvaluatorProtocol,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__()
self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform()
self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[ClipDetections] = []
@ -95,10 +101,7 @@ class ValidationMetrics(Callback):
for example_idx in batch.idx
]
clip_detections = model.postprocessor(
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
clip_detections = model.postprocessor(outputs)
predictions = [
ClipDetections(
clip=clip_annotation.clip,
@ -110,6 +113,7 @@ class ValidationMetrics(Callback):
clip_annotations, clip_detections, strict=False
)
]
predictions = self.output_transform(predictions)
self._clip_annotations.extend(clip_annotations)
self._predictions.extend(predictions)

View File

@ -12,7 +12,7 @@ system that deal with model predictions.
"""
from dataclasses import dataclass
from typing import List, NamedTuple, Protocol, Sequence
from typing import List, NamedTuple, Protocol
import numpy as np
import torch
@ -101,5 +101,4 @@ class PostprocessorProtocol(Protocol):
def __call__(
self,
output: ModelOutput,
start_times: Sequence[float] | None = None,
) -> List[ClipDetectionsTensor]: ...

View File

@ -5,7 +5,7 @@ import numpy as np
import pytest
from soundevent import data
from batdetect2.data.predictions import (
from batdetect2.outputs.formats import (
ParquetOutputConfig,
build_output_formatter,
)

View File

@ -4,7 +4,7 @@ import numpy as np
import pytest
from soundevent import data
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter
from batdetect2.typing import (
ClipDetections,
Detection,

View File

@ -0,0 +1,48 @@
import numpy as np
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.outputs import build_output_transform
from batdetect2.typing import ClipDetections, Detection
def test_shift_time_to_clip_start(clip: data.Clip):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
detection = Detection(
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
detection_score=0.9,
class_scores=np.array([0.9]),
features=np.array([1.0, 2.0]),
)
transformed = build_output_transform()(
[ClipDetections(clip=clip, detections=[detection])]
)[0]
start_time, _, end_time, _ = compute_bounds(
transformed.detections[0].geometry
)
assert np.isclose(start_time, 2.6)
assert np.isclose(end_time, 2.7)
def test_transform_identity_when_disabled(clip: data.Clip):
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
detection = Detection(
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
detection_score=0.9,
class_scores=np.array([0.9]),
features=np.array([1.0, 2.0]),
)
transform = build_output_transform(
config={"shift_time_to_clip_start": False}
)
transformed = transform(
[ClipDetections(clip=clip, detections=[detection])]
)[0]
assert transformed.detections[0].geometry == detection.geometry