Add outputs module

This commit is contained in:
mbsantiago 2026-03-17 23:00:26 +00:00
parent 3b47c688dd
commit c226dc3f2b
23 changed files with 322 additions and 149 deletions

View File

@ -102,6 +102,7 @@ exclude = [
"src/batdetect2/plotting/legacy", "src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy", "src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune", "src/batdetect2/finetune",
"src/batdetect2/utils",
] ]
[tool.ruff.format] [tool.ruff.format]
@ -121,4 +122,5 @@ exclude = [
"src/batdetect2/plotting/legacy", "src/batdetect2/plotting/legacy",
"src/batdetect2/evaluate/legacy", "src/batdetect2/evaluate/legacy",
"src/batdetect2/finetune", "src/batdetect2/finetune",
"src/batdetect2/utils",
] ]

View File

@ -11,17 +11,20 @@ from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs from batdetect2.core import merge_configs
from batdetect2.data import ( from batdetect2.data import (
OutputFormatConfig,
build_output_formatter,
get_output_formatter,
load_dataset_from_config, load_dataset_from_config,
) )
from batdetect2.data.datasets import Dataset from batdetect2.data.datasets import Dataset
from batdetect2.data.predictions.base import OutputFormatterProtocol
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model from batdetect2.models import Model, build_model
from batdetect2.outputs import (
OutputFormatConfig,
OutputTransformProtocol,
build_output_formatter,
build_output_transform,
get_output_formatter,
)
from batdetect2.postprocess import build_postprocessor, to_raw_predictions from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
@ -35,6 +38,7 @@ from batdetect2.typing import (
ClipDetections, ClipDetections,
Detection, Detection,
EvaluatorProtocol, EvaluatorProtocol,
OutputFormatterProtocol,
PostprocessorProtocol, PostprocessorProtocol,
PreprocessorProtocol, PreprocessorProtocol,
TargetProtocol, TargetProtocol,
@ -51,6 +55,7 @@ class BatDetect2API:
postprocessor: PostprocessorProtocol, postprocessor: PostprocessorProtocol,
evaluator: EvaluatorProtocol, evaluator: EvaluatorProtocol,
formatter: OutputFormatterProtocol, formatter: OutputFormatterProtocol,
output_transform: OutputTransformProtocol,
model: Model, model: Model,
): ):
self.config = config self.config = config
@ -61,6 +66,7 @@ class BatDetect2API:
self.evaluator = evaluator self.evaluator = evaluator
self.model = model self.model = model
self.formatter = formatter self.formatter = formatter
self.output_transform = output_transform
self.model.eval() self.model.eval()
@ -208,10 +214,16 @@ class BatDetect2API:
detections = self.model.postprocessor( detections = self.model.postprocessor(
outputs, outputs,
start_times=[start_time],
)[0] )[0]
raw_predictions = to_raw_predictions(
detections.numpy(),
targets=self.targets,
)
return to_raw_predictions(detections.numpy(), targets=self.targets) return self.output_transform.transform_detections(
raw_predictions,
start_time=start_time,
)
def process_directory( def process_directory(
self, self,
@ -304,7 +316,13 @@ class BatDetect2API:
# postprocessor as these may be moved to another device. # postprocessor as these may be moved to another device.
model = build_model(config=config.model) model = build_model(config=config.model)
formatter = build_output_formatter(targets, config=config.output) formatter = build_output_formatter(
targets,
config=config.outputs.format,
)
output_transform = build_output_transform(
config=config.outputs.transform
)
return cls( return cls(
config=config, config=config,
@ -315,6 +333,7 @@ class BatDetect2API:
evaluator=evaluator, evaluator=evaluator,
model=model, model=model,
formatter=formatter, formatter=formatter,
output_transform=output_transform,
) )
@classmethod @classmethod
@ -351,7 +370,13 @@ class BatDetect2API:
evaluator = build_evaluator(config=config.evaluation, targets=targets) evaluator = build_evaluator(config=config.evaluation, targets=targets)
formatter = build_output_formatter(targets, config=config.output) formatter = build_output_formatter(
targets,
config=config.outputs.format,
)
output_transform = build_output_transform(
config=config.outputs.transform
)
return cls( return cls(
config=config, config=config,
@ -362,4 +387,5 @@ class BatDetect2API:
evaluator=evaluator, evaluator=evaluator,
model=model, model=model,
formatter=formatter, formatter=formatter,
output_transform=output_transform,
) )

View File

@ -5,14 +5,13 @@ from soundevent.data import PathLike
from batdetect2.audio import AudioConfig from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.predictions import OutputFormatConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.evaluate.config import ( from batdetect2.evaluate.config import (
EvaluationConfig, EvaluationConfig,
get_default_eval_config, get_default_eval_config,
) )
from batdetect2.inference.config import InferenceConfig from batdetect2.inference.config import InferenceConfig
from batdetect2.models import ModelConfig from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
__all__ = [ __all__ = [
@ -32,7 +31,7 @@ class BatDetect2Config(BaseConfig):
model: ModelConfig = Field(default_factory=ModelConfig) model: ModelConfig = Field(default_factory=ModelConfig)
audio: AudioConfig = Field(default_factory=AudioConfig) audio: AudioConfig = Field(default_factory=AudioConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig)
output: OutputFormatConfig = Field(default_factory=RawOutputConfig) outputs: OutputsConfig = Field(default_factory=OutputsConfig)
def validate_config(config: dict | None) -> BatDetect2Config: def validate_config(config: dict | None) -> BatDetect2Config:

View File

@ -12,15 +12,6 @@ from batdetect2.data.datasets import (
load_dataset_config, load_dataset_config,
load_dataset_from_config, load_dataset_from_config,
) )
from batdetect2.data.predictions import (
BatDetect2OutputConfig,
OutputFormatConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
load_predictions,
)
from batdetect2.data.summary import ( from batdetect2.data.summary import (
compute_class_summary, compute_class_summary,
extract_recordings_df, extract_recordings_df,
@ -36,6 +27,7 @@ __all__ = [
"BatDetect2OutputConfig", "BatDetect2OutputConfig",
"DatasetConfig", "DatasetConfig",
"OutputFormatConfig", "OutputFormatConfig",
"ParquetOutputConfig",
"RawOutputConfig", "RawOutputConfig",
"SoundEventOutputConfig", "SoundEventOutputConfig",
"build_output_formatter", "build_output_formatter",

View File

@ -10,6 +10,7 @@ from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger from batdetect2.logging import build_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import build_output_transform
from batdetect2.typing import Detection from batdetect2.typing import Detection
if TYPE_CHECKING: if TYPE_CHECKING:
@ -61,7 +62,12 @@ def evaluate(
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
) )
module = EvaluationModule(model, evaluator) output_transform = build_output_transform(config=config.outputs.transform)
module = EvaluationModule(
model,
evaluator,
output_transform=output_transform,
)
trainer = Trainer(logger=logger, enable_checkpointing=False) trainer = Trainer(logger=logger, enable_checkpointing=False)
metrics = trainer.test(module, loader) metrics = trainer.test(module, loader)

View File

@ -7,6 +7,7 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.logging import get_image_logger from batdetect2.logging import get_image_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import ClipDetections from batdetect2.typing.postprocess import ClipDetections
@ -17,11 +18,13 @@ class EvaluationModule(LightningModule):
self, self,
model: Model, model: Model,
evaluator: EvaluatorProtocol, evaluator: EvaluatorProtocol,
output_transform: OutputTransformProtocol | None = None,
): ):
super().__init__() super().__init__()
self.model = model self.model = model
self.evaluator = evaluator self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform()
self.clip_annotations: List[data.ClipAnnotation] = [] self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[ClipDetections] = [] self.predictions: List[ClipDetections] = []
@ -34,10 +37,7 @@ class EvaluationModule(LightningModule):
] ]
outputs = self.model.detector(batch.spec) outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor( clip_detections = self.model.postprocessor(outputs)
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [ predictions = [
ClipDetections( ClipDetections(
clip=clip_annotation.clip, clip=clip_annotation.clip,
@ -50,6 +50,7 @@ class EvaluationModule(LightningModule):
clip_annotations, clip_detections, strict=False clip_annotations, clip_detections, strict=False
) )
] ]
predictions = self.output_transform(predictions)
self.clip_annotations.extend(clip_annotations) self.clip_annotations.extend(clip_annotations)
self.predictions.extend(predictions) self.predictions.extend(predictions)

View File

@ -8,6 +8,7 @@ from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.preprocess.preprocessor import build_preprocessor from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import ClipDetections from batdetect2.typing.postprocess import ClipDetections
@ -28,6 +29,7 @@ def run_batch_inference(
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None, config: Optional["BatDetect2Config"] = None,
output_transform: Optional[OutputTransformProtocol] = None,
num_workers: int | None = None, num_workers: int | None = None,
batch_size: int | None = None, batch_size: int | None = None,
) -> List[ClipDetections]: ) -> List[ClipDetections]:
@ -42,6 +44,9 @@ def run_batch_inference(
) )
targets = targets or build_targets() targets = targets or build_targets()
output_transform = output_transform or build_output_transform(
config=config.outputs.transform,
)
loader = build_inference_loader( loader = build_inference_loader(
clips, clips,
@ -52,7 +57,10 @@ def run_batch_inference(
batch_size=batch_size, batch_size=batch_size,
) )
module = InferenceModule(model) module = InferenceModule(
model,
output_transform=output_transform,
)
trainer = Trainer(enable_checkpointing=False, logger=False) trainer = Trainer(enable_checkpointing=False, logger=False)
outputs = trainer.predict(module, loader) outputs = trainer.predict(module, loader)
return [ return [

View File

@ -5,14 +5,20 @@ from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import ClipDetections from batdetect2.typing.postprocess import ClipDetections
class InferenceModule(LightningModule): class InferenceModule(LightningModule):
def __init__(self, model: Model): def __init__(
self,
model: Model,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__() super().__init__()
self.model = model self.model = model
self.output_transform = output_transform or build_output_transform()
def predict_step( def predict_step(
self, self,
@ -26,10 +32,7 @@ class InferenceModule(LightningModule):
outputs = self.model.detector(batch.spec) outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor( clip_detections = self.model.postprocessor(outputs)
outputs,
start_times=[clip.start_time for clip in clips],
)
predictions = [ predictions = [
ClipDetections( ClipDetections(
@ -42,7 +45,7 @@ class InferenceModule(LightningModule):
for clip, clip_dets in zip(clips, clip_detections, strict=False) for clip, clip_dets in zip(clips, clip_detections, strict=False)
] ]
return predictions return self.output_transform(predictions)
def get_dataset(self) -> InferenceDataset: def get_dataset(self) -> InferenceDataset:
dataloaders = self.trainer.predict_dataloaders dataloaders = self.trainer.predict_dataloaders

View File

@ -0,0 +1,31 @@
from batdetect2.outputs.config import OutputsConfig
from batdetect2.outputs.formats import (
BatDetect2OutputConfig,
OutputFormatConfig,
ParquetOutputConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
load_predictions,
)
from batdetect2.outputs.transforms import (
OutputTransformConfig,
OutputTransformProtocol,
build_output_transform,
)
__all__ = [
"BatDetect2OutputConfig",
"OutputFormatConfig",
"OutputTransformConfig",
"OutputTransformProtocol",
"OutputsConfig",
"ParquetOutputConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
"build_output_formatter",
"build_output_transform",
"get_output_formatter",
"load_predictions",
]

View File

@ -0,0 +1,15 @@
from pydantic import Field
from batdetect2.core.configs import BaseConfig
from batdetect2.outputs.formats import OutputFormatConfig
from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.outputs.transforms import OutputTransformConfig
__all__ = ["OutputsConfig"]
class OutputsConfig(BaseConfig):
format: OutputFormatConfig = Field(default_factory=RawOutputConfig)
transform: OutputTransformConfig = Field(
default_factory=OutputTransformConfig
)

View File

@ -3,23 +3,25 @@ from typing import Annotated
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.data.predictions.base import ( from batdetect2.outputs.formats.base import (
OutputFormatterProtocol, OutputFormatterProtocol,
prediction_formatters, output_formatters,
) )
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
from batdetect2.data.predictions.parquet import ParquetOutputConfig from batdetect2.outputs.formats.parquet import ParquetOutputConfig
from batdetect2.data.predictions.raw import RawOutputConfig from batdetect2.outputs.formats.raw import RawOutputConfig
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
from batdetect2.typing import TargetProtocol from batdetect2.typing import TargetProtocol
__all__ = [ __all__ = [
"build_output_formatter",
"get_output_formatter",
"BatDetect2OutputConfig", "BatDetect2OutputConfig",
"OutputFormatConfig",
"ParquetOutputConfig", "ParquetOutputConfig",
"RawOutputConfig", "RawOutputConfig",
"SoundEventOutputConfig", "SoundEventOutputConfig",
"build_output_formatter",
"get_output_formatter",
"load_predictions",
] ]
@ -42,7 +44,7 @@ def build_output_formatter(
config = config or RawOutputConfig() config = config or RawOutputConfig()
targets = targets or build_targets() targets = targets or build_targets()
return prediction_formatters.build(config, targets) return output_formatters.build(config, targets)
def get_output_formatter( def get_output_formatter(
@ -56,7 +58,7 @@ def get_output_formatter(
if name is None: if name is None:
raise ValueError("Either config or name must be provided.") raise ValueError("Either config or name must be provided.")
config_class = prediction_formatters.get_config_type(name) config_class = output_formatters.get_config_type(name)
config = config_class() # type: ignore config = config_class() # type: ignore
if config.name != name: # type: ignore if config.name != name: # type: ignore

View File

@ -9,6 +9,13 @@ from batdetect2.typing import (
TargetProtocol, TargetProtocol,
) )
__all__ = [
"OutputFormatterProtocol",
"PredictionFormatterImportConfig",
"make_path_relative",
"output_formatters",
]
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path: def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
path = Path(path) path = Path(path)
@ -25,12 +32,12 @@ def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
return path return path
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = ( output_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
Registry(name="output_formatter") Registry(name="output_formatter")
) )
@add_import_config(prediction_formatters) @add_import_config(output_formatters)
class PredictionFormatterImportConfig(ImportConfig): class PredictionFormatterImportConfig(ImportConfig):
"""Use any callable as a prediction formatter. """Use any callable as a prediction formatter.

View File

@ -7,9 +7,9 @@ from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import ( from batdetect2.outputs.formats.base import (
make_path_relative, make_path_relative,
prediction_formatters, output_formatters,
) )
from batdetect2.targets import terms from batdetect2.targets import terms
from batdetect2.typing import ( from batdetect2.typing import (
@ -28,76 +28,32 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass): class Annotation(DictWithClass):
"""Format of annotations.
This is the format of a single annotation as expected by the
annotation tool.
"""
start_time: float start_time: float
"""Start time in seconds."""
end_time: float end_time: float
"""End time in seconds."""
low_freq: float low_freq: float
"""Low frequency in Hz."""
high_freq: float high_freq: float
"""High frequency in Hz."""
class_prob: float class_prob: float
"""Probability of class assignment."""
det_prob: float det_prob: float
"""Probability of detection."""
individual: str individual: str
"""Individual ID."""
event: str event: str
"""Type of detected event."""
class FileAnnotation(TypedDict): class FileAnnotation(TypedDict):
"""Format of results.
This is the format of the results expected by the annotation tool.
"""
id: str id: str
"""File ID."""
annotated: bool annotated: bool
"""Whether file has been annotated."""
duration: float duration: float
"""Duration of audio file."""
issues: bool issues: bool
"""Whether file has issues."""
time_exp: float time_exp: float
"""Time expansion factor."""
class_name: str class_name: str
"""Class predicted at file level."""
notes: str notes: str
"""Notes of file."""
annotation: List[Annotation] annotation: List[Annotation]
"""List of annotations."""
file_path: NotRequired[str] # ty: ignore[invalid-type-form] file_path: NotRequired[str] # ty: ignore[invalid-type-form]
"""Path to file."""
class BatDetect2OutputConfig(BaseConfig): class BatDetect2OutputConfig(BaseConfig):
name: Literal["batdetect2"] = "batdetect2" name: Literal["batdetect2"] = "batdetect2"
event_name: str = "Echolocation" event_name: str = "Echolocation"
annotation_note: str = "Automatically generated." annotation_note: str = "Automatically generated."
@ -156,8 +112,6 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
] ]
def get_recording_class(self, annotations: List[Annotation]) -> str: def get_recording_class(self, annotations: List[Annotation]) -> str:
"""Get class of recording from annotations."""
if not annotations: if not annotations:
return "" return ""
@ -215,7 +169,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
**{"class": top_class}, **{"class": top_class},
) )
@prediction_formatters.register(BatDetect2OutputConfig) @output_formatters.register(BatDetect2OutputConfig)
@staticmethod @staticmethod
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol): def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
return BatDetect2Formatter( return BatDetect2Formatter(

View File

@ -9,9 +9,9 @@ from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import ( from batdetect2.outputs.formats.base import (
make_path_relative, make_path_relative,
prediction_formatters, output_formatters,
) )
from batdetect2.typing import ( from batdetect2.typing import (
ClipDetections, ClipDetections,
@ -59,10 +59,7 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
if not path.parent.exists(): if not path.parent.exists():
path.parent.mkdir(parents=True) path.parent.mkdir(parents=True)
# Ensure the file has .parquet extension if it's a file path
if path.suffix != ".parquet": if path.suffix != ".parquet":
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
if path.is_dir() or not path.suffix: if path.is_dir() or not path.suffix:
path = path / "predictions.parquet" path = path / "predictions.parquet"
@ -90,7 +87,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
} }
if self.include_geometry: if self.include_geometry:
# Store geometry as [start_time, low_freq, end_time, high_freq]
start_time, low_freq, end_time, high_freq = compute_bounds( start_time, low_freq, end_time, high_freq = compute_bounds(
pred.geometry pred.geometry
) )
@ -98,8 +94,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
row["low_freq"] = low_freq row["low_freq"] = low_freq
row["end_time"] = end_time row["end_time"] = end_time
row["high_freq"] = high_freq row["high_freq"] = high_freq
# Store full geometry as JSON
row["geometry"] = pred.geometry.model_dump_json() row["geometry"] = pred.geometry.model_dump_json()
if self.include_class_scores: if self.include_class_scores:
@ -121,11 +115,9 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
def load(self, path: data.PathLike) -> List[ClipDetections]: def load(self, path: data.PathLike) -> List[ClipDetections]:
path = Path(path) path = Path(path)
if path.is_dir(): if path.is_dir():
# Try to find parquet files
files = list(path.glob("*.parquet")) files = list(path.glob("*.parquet"))
if not files: if not files:
return [] return []
# Read all and concatenate
dfs = [pd.read_parquet(f) for f in files] dfs = [pd.read_parquet(f) for f in files]
df = pd.concat(dfs, ignore_index=True) df = pd.concat(dfs, ignore_index=True)
else: else:
@ -148,7 +140,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
) )
predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []} predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
# Reconstruct geometry
if "geometry" in row and row["geometry"]: if "geometry" in row and row["geometry"]:
geometry = data.geometry_validate(row["geometry"]) geometry = data.geometry_validate(row["geometry"])
else: else:
@ -182,13 +173,14 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
for clip_data in predictions_by_clip.values(): for clip_data in predictions_by_clip.values():
results.append( results.append(
ClipDetections( ClipDetections(
clip=clip_data["clip"], detections=clip_data["preds"] clip=clip_data["clip"],
detections=clip_data["preds"],
) )
) )
return results return results
@prediction_formatters.register(ParquetOutputConfig) @output_formatters.register(ParquetOutputConfig)
@staticmethod @staticmethod
def from_config(config: ParquetOutputConfig, targets: TargetProtocol): def from_config(config: ParquetOutputConfig, targets: TargetProtocol):
return ParquetFormatter( return ParquetFormatter(

View File

@ -10,9 +10,9 @@ from soundevent import data
from soundevent.geometry import compute_bounds from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import ( from batdetect2.outputs.formats.base import (
make_path_relative, make_path_relative,
prediction_formatters, output_formatters,
) )
from batdetect2.typing import ( from batdetect2.typing import (
ClipDetections, ClipDetections,
@ -95,56 +95,56 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
update=dict(path=make_path_relative(recording.path, audio_dir)) update=dict(path=make_path_relative(recording.path, audio_dir))
) )
data = defaultdict(list) values = defaultdict(list)
for pred in prediction.detections: for pred in prediction.detections:
detection_id = str(uuid4()) detection_id = str(uuid4())
data["detection_id"].append(detection_id) values["detection_id"].append(detection_id)
data["detection_score"].append(pred.detection_score) values["detection_score"].append(pred.detection_score)
start_time, low_freq, end_time, high_freq = compute_bounds( start_time, low_freq, end_time, high_freq = compute_bounds(
pred.geometry pred.geometry
) )
data["start_time"].append(start_time) values["start_time"].append(start_time)
data["end_time"].append(end_time) values["end_time"].append(end_time)
data["low_freq"].append(low_freq) values["low_freq"].append(low_freq)
data["high_freq"].append(high_freq) values["high_freq"].append(high_freq)
data["geometry"].append(pred.geometry.model_dump_json()) values["geometry"].append(pred.geometry.model_dump_json())
top_class_index = int(np.argmax(pred.class_scores)) top_class_index = int(np.argmax(pred.class_scores))
top_class_score = float(pred.class_scores[top_class_index]) top_class_score = float(pred.class_scores[top_class_index])
top_class = self.targets.class_names[top_class_index] top_class = self.targets.class_names[top_class_index]
data["top_class"].append(top_class) values["top_class"].append(top_class)
data["top_class_score"].append(top_class_score) values["top_class_score"].append(top_class_score)
data["class_scores"].append(pred.class_scores) values["class_scores"].append(pred.class_scores)
data["features"].append(pred.features) values["features"].append(pred.features)
num_features = len(pred.features) num_features = len(pred.features)
data_vars = { data_vars = {
"score": (["detection"], data["detection_score"]), "score": (["detection"], values["detection_score"]),
"start_time": (["detection"], data["start_time"]), "start_time": (["detection"], values["start_time"]),
"end_time": (["detection"], data["end_time"]), "end_time": (["detection"], values["end_time"]),
"low_freq": (["detection"], data["low_freq"]), "low_freq": (["detection"], values["low_freq"]),
"high_freq": (["detection"], data["high_freq"]), "high_freq": (["detection"], values["high_freq"]),
"top_class": (["detection"], data["top_class"]), "top_class": (["detection"], values["top_class"]),
"top_class_score": (["detection"], data["top_class_score"]), "top_class_score": (["detection"], values["top_class_score"]),
} }
coords = { coords = {
"detection": ("detection", data["detection_id"]), "detection": ("detection", values["detection_id"]),
"clip_start": clip.start_time, "clip_start": clip.start_time,
"clip_end": clip.end_time, "clip_end": clip.end_time,
"clip_id": str(clip.uuid), "clip_id": str(clip.uuid),
} }
if self.include_class_scores: if self.include_class_scores:
class_scores = np.stack(data["class_scores"], axis=0) class_scores = np.stack(values["class_scores"], axis=0)
data_vars["class_scores"] = ( data_vars["class_scores"] = (
["detection", "classes"], ["detection", "classes"],
class_scores, class_scores,
@ -152,12 +152,12 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
coords["classes"] = ("classes", self.targets.class_names) coords["classes"] = ("classes", self.targets.class_names)
if self.include_features: if self.include_features:
features = np.stack(data["features"], axis=0) features = np.stack(values["features"], axis=0)
data_vars["features"] = (["detection", "feature"], features) data_vars["features"] = (["detection", "feature"], features)
coords["feature"] = ("feature", np.arange(num_features)) coords["feature"] = ("feature", np.arange(num_features))
if self.include_geometry: if self.include_geometry:
data_vars["geometry"] = (["detection"], data["geometry"]) data_vars["geometry"] = (["detection"], values["geometry"])
return xr.Dataset( return xr.Dataset(
data_vars=data_vars, data_vars=data_vars,
@ -169,7 +169,6 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections: def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
clip_data = dataset clip_data = dataset
clip_id = clip_data.clip_id.item()
recording = data.Recording.model_validate_json( recording = data.Recording.model_validate_json(
clip_data.attrs["recording"] clip_data.attrs["recording"]
@ -232,7 +231,7 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
detections=sound_events, detections=sound_events,
) )
@prediction_formatters.register(RawOutputConfig) @output_formatters.register(RawOutputConfig)
@staticmethod @staticmethod
def from_config(config: RawOutputConfig, targets: TargetProtocol): def from_config(config: RawOutputConfig, targets: TargetProtocol):
return RawFormatter( return RawFormatter(

View File

@ -5,8 +5,8 @@ import numpy as np
from soundevent import data, io from soundevent import data, io
from batdetect2.core import BaseConfig from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import ( from batdetect2.outputs.formats.base import (
prediction_formatters, output_formatters,
) )
from batdetect2.typing import ( from batdetect2.typing import (
ClipDetections, ClipDetections,
@ -121,7 +121,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
return tags return tags
@prediction_formatters.register(SoundEventOutputConfig) @output_formatters.register(SoundEventOutputConfig)
@staticmethod @staticmethod
def from_config(config: SoundEventOutputConfig, targets: TargetProtocol): def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
return SoundEventOutputFormatter( return SoundEventOutputFormatter(

View File

@ -0,0 +1,89 @@
from collections.abc import Sequence
from dataclasses import replace
from typing import Protocol
from soundevent.geometry import shift_geometry
from batdetect2.core.configs import BaseConfig
from batdetect2.typing import ClipDetections, Detection
__all__ = [
"OutputTransform",
"OutputTransformConfig",
"OutputTransformProtocol",
"build_output_transform",
]
class OutputTransformConfig(BaseConfig):
    """Configuration for :func:`build_output_transform`."""

    # When True, detection geometries are shifted from clip-relative time to
    # recording-relative time by adding the clip's start time.
    shift_time_to_clip_start: bool = True
class OutputTransformProtocol(Protocol):
    """Structural interface for post-prediction output transforms.

    Implementations take the raw per-clip predictions produced by the
    postprocessor and return transformed copies (e.g. with detection times
    re-anchored to the recording timeline).
    """

    def __call__(
        self,
        predictions: Sequence[ClipDetections],
    ) -> list[ClipDetections]:
        """Transform a batch of clip predictions, returning a new list."""
        ...

    def transform_detections(
        self,
        detections: Sequence[Detection],
        start_time: float = 0,
    ) -> list[Detection]:
        """Transform detections, optionally offsetting times by *start_time*."""
        ...
def shift_detection_time(detection: Detection, time: float) -> Detection:
    """Return a copy of *detection* whose geometry is shifted by *time* seconds."""
    return replace(
        detection,
        geometry=shift_geometry(detection.geometry, time=time),
    )
class OutputTransform(OutputTransformProtocol):
    """Default output transform.

    When ``shift_time_to_clip_start`` is enabled, every detection geometry is
    moved forward by its clip's start time, re-anchoring clip-relative
    detections onto the recording timeline. Otherwise predictions pass
    through unchanged.
    """

    def __init__(self, shift_time_to_clip_start: bool = True):
        # Whether to re-anchor detections to the recording timeline.
        self.shift_time_to_clip_start = shift_time_to_clip_start

    def __call__(
        self,
        predictions: Sequence[ClipDetections],
    ) -> list[ClipDetections]:
        """Transform each clip prediction and return the results as a list."""
        return list(map(self.transform_prediction, predictions))

    def transform_prediction(
        self, prediction: ClipDetections
    ) -> ClipDetections:
        """Apply the configured transforms to a single clip prediction."""
        if not self.shift_time_to_clip_start:
            # Shifting disabled: return the prediction untouched.
            return prediction
        shifted = self.transform_detections(
            prediction.detections,
            start_time=prediction.clip.start_time,
        )
        return ClipDetections(clip=prediction.clip, detections=shifted)

    def transform_detections(
        self,
        detections: Sequence[Detection],
        start_time: float = 0,
    ) -> list[Detection]:
        """Shift *detections* forward by *start_time* seconds.

        A plain copy is returned when shifting is disabled or the offset
        is zero.
        """
        if self.shift_time_to_clip_start and start_time != 0:
            return [
                shift_detection_time(detection, time=start_time)
                for detection in detections
            ]
        return list(detections)
def build_output_transform(
    config: OutputTransformConfig | dict | None = None,
) -> OutputTransformProtocol:
    """Construct an :class:`OutputTransform` from *config*.

    Accepts an already-validated :class:`OutputTransformConfig`, a plain
    dict (validated via pydantic), or ``None`` for defaults.
    """
    if not isinstance(config, OutputTransformConfig):
        config = (
            OutputTransformConfig()
            if config is None
            else OutputTransformConfig.model_validate(config)
        )
    return OutputTransform(
        shift_time_to_clip_start=config.shift_time_to_clip_start,
    )

View File

@ -64,7 +64,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
def forward( def forward(
self, self,
output: ModelOutput, output: ModelOutput,
start_times: list[float] | None = None,
) -> list[ClipDetectionsTensor]: ) -> list[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression( detection_heatmap = non_max_suppression(
output.detection_probs.detach(), output.detection_probs.detach(),
@ -83,9 +82,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
threshold=self.detection_threshold, threshold=self.detection_threshold,
) )
if start_times is None:
start_times = [0 for _ in range(len(detections))]
return [ return [
map_detection_to_clip( map_detection_to_clip(
detection, detection,

View File

@ -6,6 +6,7 @@ from soundevent import data
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.logging import get_image_logger from batdetect2.logging import get_image_logger
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess import to_raw_predictions from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
@ -18,10 +19,15 @@ from batdetect2.typing import (
class ValidationMetrics(Callback): class ValidationMetrics(Callback):
def __init__(self, evaluator: EvaluatorProtocol): def __init__(
self,
evaluator: EvaluatorProtocol,
output_transform: OutputTransformProtocol | None = None,
):
super().__init__() super().__init__()
self.evaluator = evaluator self.evaluator = evaluator
self.output_transform = output_transform or build_output_transform()
self._clip_annotations: List[data.ClipAnnotation] = [] self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[ClipDetections] = [] self._predictions: List[ClipDetections] = []
@ -95,10 +101,7 @@ class ValidationMetrics(Callback):
for example_idx in batch.idx for example_idx in batch.idx
] ]
clip_detections = model.postprocessor( clip_detections = model.postprocessor(outputs)
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [ predictions = [
ClipDetections( ClipDetections(
clip=clip_annotation.clip, clip=clip_annotation.clip,
@ -110,6 +113,7 @@ class ValidationMetrics(Callback):
clip_annotations, clip_detections, strict=False clip_annotations, clip_detections, strict=False
) )
] ]
predictions = self.output_transform(predictions)
self._clip_annotations.extend(clip_annotations) self._clip_annotations.extend(clip_annotations)
self._predictions.extend(predictions) self._predictions.extend(predictions)

View File

@ -12,7 +12,7 @@ system that deal with model predictions.
""" """
from dataclasses import dataclass from dataclasses import dataclass
from typing import List, NamedTuple, Protocol, Sequence from typing import List, NamedTuple, Protocol
import numpy as np import numpy as np
import torch import torch
@ -101,5 +101,4 @@ class PostprocessorProtocol(Protocol):
def __call__( def __call__(
self, self,
output: ModelOutput, output: ModelOutput,
start_times: Sequence[float] | None = None,
) -> List[ClipDetectionsTensor]: ... ) -> List[ClipDetectionsTensor]: ...

View File

@ -5,7 +5,7 @@ import numpy as np
import pytest import pytest
from soundevent import data from soundevent import data
from batdetect2.data.predictions import ( from batdetect2.outputs.formats import (
ParquetOutputConfig, ParquetOutputConfig,
build_output_formatter, build_output_formatter,
) )

View File

@ -4,7 +4,7 @@ import numpy as np
import pytest import pytest
from soundevent import data from soundevent import data
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter
from batdetect2.typing import ( from batdetect2.typing import (
ClipDetections, ClipDetections,
Detection, Detection,

View File

@ -0,0 +1,48 @@
import numpy as np
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.outputs import build_output_transform
from batdetect2.typing import ClipDetections, Detection
def test_shift_time_to_clip_start(clip: data.Clip):
    """Detections should be re-anchored to the clip's absolute start time."""
    clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
    detection = Detection(
        geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
        detection_score=0.9,
        class_scores=np.array([0.9]),
        features=np.array([1.0, 2.0]),
    )
    transform = build_output_transform()
    (result,) = transform(
        [ClipDetections(clip=clip, detections=[detection])]
    )
    low_time, _, high_time, _ = compute_bounds(result.detections[0].geometry)
    # 0.1/0.2 clip-relative times shifted by the 2.5 s clip start.
    assert np.isclose(low_time, 2.6)
    assert np.isclose(high_time, 2.7)
def test_transform_identity_when_disabled(clip: data.Clip):
    """With shifting disabled the detection geometry must be untouched."""
    clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
    original = Detection(
        geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
        detection_score=0.9,
        class_scores=np.array([0.9]),
        features=np.array([1.0, 2.0]),
    )
    transform = build_output_transform(
        config={"shift_time_to_clip_start": False}
    )
    (result,) = transform(
        [ClipDetections(clip=clip, detections=[original])]
    )
    assert result.detections[0].geometry == original.geometry