mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add outputs module
This commit is contained in:
parent
3b47c688dd
commit
c226dc3f2b
@ -102,6 +102,7 @@ exclude = [
|
||||
"src/batdetect2/plotting/legacy",
|
||||
"src/batdetect2/evaluate/legacy",
|
||||
"src/batdetect2/finetune",
|
||||
"src/batdetect2/utils",
|
||||
]
|
||||
|
||||
[tool.ruff.format]
|
||||
@ -121,4 +122,5 @@ exclude = [
|
||||
"src/batdetect2/plotting/legacy",
|
||||
"src/batdetect2/evaluate/legacy",
|
||||
"src/batdetect2/finetune",
|
||||
"src/batdetect2/utils",
|
||||
]
|
||||
|
||||
@ -11,17 +11,20 @@ from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.data import (
|
||||
OutputFormatConfig,
|
||||
build_output_formatter,
|
||||
get_output_formatter,
|
||||
load_dataset_from_config,
|
||||
)
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.data.predictions.base import OutputFormatterProtocol
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputTransformProtocol,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
get_output_formatter,
|
||||
)
|
||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
@ -35,6 +38,7 @@ from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
EvaluatorProtocol,
|
||||
OutputFormatterProtocol,
|
||||
PostprocessorProtocol,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
@ -51,6 +55,7 @@ class BatDetect2API:
|
||||
postprocessor: PostprocessorProtocol,
|
||||
evaluator: EvaluatorProtocol,
|
||||
formatter: OutputFormatterProtocol,
|
||||
output_transform: OutputTransformProtocol,
|
||||
model: Model,
|
||||
):
|
||||
self.config = config
|
||||
@ -61,6 +66,7 @@ class BatDetect2API:
|
||||
self.evaluator = evaluator
|
||||
self.model = model
|
||||
self.formatter = formatter
|
||||
self.output_transform = output_transform
|
||||
|
||||
self.model.eval()
|
||||
|
||||
@ -208,10 +214,16 @@ class BatDetect2API:
|
||||
|
||||
detections = self.model.postprocessor(
|
||||
outputs,
|
||||
start_times=[start_time],
|
||||
)[0]
|
||||
raw_predictions = to_raw_predictions(
|
||||
detections.numpy(),
|
||||
targets=self.targets,
|
||||
)
|
||||
|
||||
return to_raw_predictions(detections.numpy(), targets=self.targets)
|
||||
return self.output_transform.transform_detections(
|
||||
raw_predictions,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
def process_directory(
|
||||
self,
|
||||
@ -304,7 +316,13 @@ class BatDetect2API:
|
||||
# postprocessor as these may be moved to another device.
|
||||
model = build_model(config=config.model)
|
||||
|
||||
formatter = build_output_formatter(targets, config=config.output)
|
||||
formatter = build_output_formatter(
|
||||
targets,
|
||||
config=config.outputs.format,
|
||||
)
|
||||
output_transform = build_output_transform(
|
||||
config=config.outputs.transform
|
||||
)
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
@ -315,6 +333,7 @@ class BatDetect2API:
|
||||
evaluator=evaluator,
|
||||
model=model,
|
||||
formatter=formatter,
|
||||
output_transform=output_transform,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -351,7 +370,13 @@ class BatDetect2API:
|
||||
|
||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||
|
||||
formatter = build_output_formatter(targets, config=config.output)
|
||||
formatter = build_output_formatter(
|
||||
targets,
|
||||
config=config.outputs.format,
|
||||
)
|
||||
output_transform = build_output_transform(
|
||||
config=config.outputs.transform
|
||||
)
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
@ -362,4 +387,5 @@ class BatDetect2API:
|
||||
evaluator=evaluator,
|
||||
model=model,
|
||||
formatter=formatter,
|
||||
output_transform=output_transform,
|
||||
)
|
||||
|
||||
@ -5,14 +5,13 @@ from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.data.predictions import OutputFormatConfig
|
||||
from batdetect2.data.predictions.raw import RawOutputConfig
|
||||
from batdetect2.evaluate.config import (
|
||||
EvaluationConfig,
|
||||
get_default_eval_config,
|
||||
)
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
|
||||
__all__ = [
|
||||
@ -32,7 +31,7 @@ class BatDetect2Config(BaseConfig):
|
||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
||||
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||
|
||||
|
||||
def validate_config(config: dict | None) -> BatDetect2Config:
|
||||
|
||||
@ -12,15 +12,6 @@ from batdetect2.data.datasets import (
|
||||
load_dataset_config,
|
||||
load_dataset_from_config,
|
||||
)
|
||||
from batdetect2.data.predictions import (
|
||||
BatDetect2OutputConfig,
|
||||
OutputFormatConfig,
|
||||
RawOutputConfig,
|
||||
SoundEventOutputConfig,
|
||||
build_output_formatter,
|
||||
get_output_formatter,
|
||||
load_predictions,
|
||||
)
|
||||
from batdetect2.data.summary import (
|
||||
compute_class_summary,
|
||||
extract_recordings_df,
|
||||
@ -36,6 +27,7 @@ __all__ = [
|
||||
"BatDetect2OutputConfig",
|
||||
"DatasetConfig",
|
||||
"OutputFormatConfig",
|
||||
"ParquetOutputConfig",
|
||||
"RawOutputConfig",
|
||||
"SoundEventOutputConfig",
|
||||
"build_output_formatter",
|
||||
|
||||
@ -10,6 +10,7 @@ from batdetect2.evaluate.evaluator import build_evaluator
|
||||
from batdetect2.evaluate.lightning import EvaluationModule
|
||||
from batdetect2.logging import build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import build_output_transform
|
||||
from batdetect2.typing import Detection
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -61,7 +62,12 @@ def evaluate(
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
module = EvaluationModule(model, evaluator)
|
||||
output_transform = build_output_transform(config=config.outputs.transform)
|
||||
module = EvaluationModule(
|
||||
model,
|
||||
evaluator,
|
||||
output_transform=output_transform,
|
||||
)
|
||||
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
||||
metrics = trainer.test(module, loader)
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from torch.utils.data import DataLoader
|
||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.typing import EvaluatorProtocol
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
@ -17,11 +18,13 @@ class EvaluationModule(LightningModule):
|
||||
self,
|
||||
model: Model,
|
||||
evaluator: EvaluatorProtocol,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.model = model
|
||||
self.evaluator = evaluator
|
||||
self.output_transform = output_transform or build_output_transform()
|
||||
|
||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||
self.predictions: List[ClipDetections] = []
|
||||
@ -34,10 +37,7 @@ class EvaluationModule(LightningModule):
|
||||
]
|
||||
|
||||
outputs = self.model.detector(batch.spec)
|
||||
clip_detections = self.model.postprocessor(
|
||||
outputs,
|
||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||
)
|
||||
clip_detections = self.model.postprocessor(outputs)
|
||||
predictions = [
|
||||
ClipDetections(
|
||||
clip=clip_annotation.clip,
|
||||
@ -50,6 +50,7 @@ class EvaluationModule(LightningModule):
|
||||
clip_annotations, clip_detections, strict=False
|
||||
)
|
||||
]
|
||||
predictions = self.output_transform(predictions)
|
||||
|
||||
self.clip_annotations.extend(clip_annotations)
|
||||
self.predictions.extend(predictions)
|
||||
|
||||
@ -8,6 +8,7 @@ from batdetect2.inference.clips import get_clips_from_files
|
||||
from batdetect2.inference.dataset import build_inference_loader
|
||||
from batdetect2.inference.lightning import InferenceModule
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||
from batdetect2.targets.targets import build_targets
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
@ -28,6 +29,7 @@ def run_batch_inference(
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
output_transform: Optional[OutputTransformProtocol] = None,
|
||||
num_workers: int | None = None,
|
||||
batch_size: int | None = None,
|
||||
) -> List[ClipDetections]:
|
||||
@ -42,6 +44,9 @@ def run_batch_inference(
|
||||
)
|
||||
|
||||
targets = targets or build_targets()
|
||||
output_transform = output_transform or build_output_transform(
|
||||
config=config.outputs.transform,
|
||||
)
|
||||
|
||||
loader = build_inference_loader(
|
||||
clips,
|
||||
@ -52,7 +57,10 @@ def run_batch_inference(
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
module = InferenceModule(model)
|
||||
module = InferenceModule(
|
||||
model,
|
||||
output_transform=output_transform,
|
||||
)
|
||||
trainer = Trainer(enable_checkpointing=False, logger=False)
|
||||
outputs = trainer.predict(module, loader)
|
||||
return [
|
||||
|
||||
@ -5,14 +5,20 @@ from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.typing.postprocess import ClipDetections
|
||||
|
||||
|
||||
class InferenceModule(LightningModule):
|
||||
def __init__(self, model: Model):
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.output_transform = output_transform or build_output_transform()
|
||||
|
||||
def predict_step(
|
||||
self,
|
||||
@ -26,10 +32,7 @@ class InferenceModule(LightningModule):
|
||||
|
||||
outputs = self.model.detector(batch.spec)
|
||||
|
||||
clip_detections = self.model.postprocessor(
|
||||
outputs,
|
||||
start_times=[clip.start_time for clip in clips],
|
||||
)
|
||||
clip_detections = self.model.postprocessor(outputs)
|
||||
|
||||
predictions = [
|
||||
ClipDetections(
|
||||
@ -42,7 +45,7 @@ class InferenceModule(LightningModule):
|
||||
for clip, clip_dets in zip(clips, clip_detections, strict=False)
|
||||
]
|
||||
|
||||
return predictions
|
||||
return self.output_transform(predictions)
|
||||
|
||||
def get_dataset(self) -> InferenceDataset:
|
||||
dataloaders = self.trainer.predict_dataloaders
|
||||
|
||||
31
src/batdetect2/outputs/__init__.py
Normal file
31
src/batdetect2/outputs/__init__.py
Normal file
@ -0,0 +1,31 @@
|
||||
from batdetect2.outputs.config import OutputsConfig
|
||||
from batdetect2.outputs.formats import (
|
||||
BatDetect2OutputConfig,
|
||||
OutputFormatConfig,
|
||||
ParquetOutputConfig,
|
||||
RawOutputConfig,
|
||||
SoundEventOutputConfig,
|
||||
build_output_formatter,
|
||||
get_output_formatter,
|
||||
load_predictions,
|
||||
)
|
||||
from batdetect2.outputs.transforms import (
|
||||
OutputTransformConfig,
|
||||
OutputTransformProtocol,
|
||||
build_output_transform,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"BatDetect2OutputConfig",
|
||||
"OutputFormatConfig",
|
||||
"OutputTransformConfig",
|
||||
"OutputTransformProtocol",
|
||||
"OutputsConfig",
|
||||
"ParquetOutputConfig",
|
||||
"RawOutputConfig",
|
||||
"SoundEventOutputConfig",
|
||||
"build_output_formatter",
|
||||
"build_output_transform",
|
||||
"get_output_formatter",
|
||||
"load_predictions",
|
||||
]
|
||||
15
src/batdetect2/outputs/config.py
Normal file
15
src/batdetect2/outputs/config.py
Normal file
@ -0,0 +1,15 @@
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.outputs.formats import OutputFormatConfig
|
||||
from batdetect2.outputs.formats.raw import RawOutputConfig
|
||||
from batdetect2.outputs.transforms import OutputTransformConfig
|
||||
|
||||
__all__ = ["OutputsConfig"]
|
||||
|
||||
|
||||
class OutputsConfig(BaseConfig):
    """Top-level configuration for how predictions are output.

    Groups the two output-related settings under a single section:
    the on-disk format of the predictions and the transform applied
    to predictions before they are returned or saved.
    """

    # Output format configuration; defaults to a ``RawOutputConfig``.
    format: OutputFormatConfig = Field(default_factory=RawOutputConfig)

    # Output transform configuration; defaults to a default-constructed
    # ``OutputTransformConfig``.
    transform: OutputTransformConfig = Field(
        default_factory=OutputTransformConfig
    )
|
||||
@ -3,23 +3,25 @@ from typing import Annotated
|
||||
from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.data.predictions.base import (
|
||||
from batdetect2.outputs.formats.base import (
|
||||
OutputFormatterProtocol,
|
||||
prediction_formatters,
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
|
||||
from batdetect2.data.predictions.parquet import ParquetOutputConfig
|
||||
from batdetect2.data.predictions.raw import RawOutputConfig
|
||||
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
|
||||
from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
|
||||
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
|
||||
from batdetect2.outputs.formats.raw import RawOutputConfig
|
||||
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
|
||||
from batdetect2.typing import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"build_output_formatter",
|
||||
"get_output_formatter",
|
||||
"BatDetect2OutputConfig",
|
||||
"OutputFormatConfig",
|
||||
"ParquetOutputConfig",
|
||||
"RawOutputConfig",
|
||||
"SoundEventOutputConfig",
|
||||
"build_output_formatter",
|
||||
"get_output_formatter",
|
||||
"load_predictions",
|
||||
]
|
||||
|
||||
|
||||
@ -42,7 +44,7 @@ def build_output_formatter(
|
||||
config = config or RawOutputConfig()
|
||||
|
||||
targets = targets or build_targets()
|
||||
return prediction_formatters.build(config, targets)
|
||||
return output_formatters.build(config, targets)
|
||||
|
||||
|
||||
def get_output_formatter(
|
||||
@ -56,7 +58,7 @@ def get_output_formatter(
|
||||
if name is None:
|
||||
raise ValueError("Either config or name must be provided.")
|
||||
|
||||
config_class = prediction_formatters.get_config_type(name)
|
||||
config_class = output_formatters.get_config_type(name)
|
||||
config = config_class() # type: ignore
|
||||
|
||||
if config.name != name: # type: ignore
|
||||
@ -9,6 +9,13 @@ from batdetect2.typing import (
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"OutputFormatterProtocol",
|
||||
"PredictionFormatterImportConfig",
|
||||
"make_path_relative",
|
||||
"output_formatters",
|
||||
]
|
||||
|
||||
|
||||
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
|
||||
path = Path(path)
|
||||
@ -25,12 +32,12 @@ def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
|
||||
return path
|
||||
|
||||
|
||||
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
|
||||
output_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
|
||||
Registry(name="output_formatter")
|
||||
)
|
||||
|
||||
|
||||
@add_import_config(prediction_formatters)
|
||||
@add_import_config(output_formatters)
|
||||
class PredictionFormatterImportConfig(ImportConfig):
|
||||
"""Use any callable as a prediction formatter.
|
||||
|
||||
@ -7,9 +7,9 @@ from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.data.predictions.base import (
|
||||
from batdetect2.outputs.formats.base import (
|
||||
make_path_relative,
|
||||
prediction_formatters,
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.targets import terms
|
||||
from batdetect2.typing import (
|
||||
@ -28,76 +28,32 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
|
||||
|
||||
class Annotation(DictWithClass):
    """Format of annotations.

    This is the format of a single annotation as expected by the
    annotation tool.
    """

    # Start time in seconds.
    start_time: float

    # End time in seconds.
    end_time: float

    # Low frequency in Hz.
    low_freq: float

    # High frequency in Hz.
    high_freq: float

    # Probability of class assignment.
    class_prob: float

    # Probability of detection.
    det_prob: float

    # Individual ID.
    individual: str

    # Type of detected event.
    event: str
|
||||
|
||||
|
||||
class FileAnnotation(TypedDict):
    """Format of results.

    This is the format of the results expected by the annotation tool.
    """

    # File ID.
    id: str

    # Whether file has been annotated.
    annotated: bool

    # Duration of audio file.
    duration: float

    # Whether file has issues.
    issues: bool

    # Time expansion factor.
    time_exp: float

    # Class predicted at file level.
    class_name: str

    # Notes of file.
    notes: str

    # List of annotations.
    annotation: List[Annotation]

    # Path to file (optional key).
    file_path: NotRequired[str]  # ty: ignore[invalid-type-form]
|
||||
|
||||
|
||||
class BatDetect2OutputConfig(BaseConfig):
|
||||
name: Literal["batdetect2"] = "batdetect2"
|
||||
|
||||
event_name: str = "Echolocation"
|
||||
|
||||
annotation_note: str = "Automatically generated."
|
||||
|
||||
|
||||
@ -156,8 +112,6 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
]
|
||||
|
||||
def get_recording_class(self, annotations: List[Annotation]) -> str:
|
||||
"""Get class of recording from annotations."""
|
||||
|
||||
if not annotations:
|
||||
return ""
|
||||
|
||||
@ -215,7 +169,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||
**{"class": top_class},
|
||||
)
|
||||
|
||||
@prediction_formatters.register(BatDetect2OutputConfig)
|
||||
@output_formatters.register(BatDetect2OutputConfig)
|
||||
@staticmethod
|
||||
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
|
||||
return BatDetect2Formatter(
|
||||
@ -9,9 +9,9 @@ from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.data.predictions.base import (
|
||||
from batdetect2.outputs.formats.base import (
|
||||
make_path_relative,
|
||||
prediction_formatters,
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
@ -59,10 +59,7 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
if not path.parent.exists():
|
||||
path.parent.mkdir(parents=True)
|
||||
|
||||
# Ensure the file has .parquet extension if it's a file path
|
||||
if path.suffix != ".parquet":
|
||||
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
|
||||
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
|
||||
if path.is_dir() or not path.suffix:
|
||||
path = path / "predictions.parquet"
|
||||
|
||||
@ -90,7 +87,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
}
|
||||
|
||||
if self.include_geometry:
|
||||
# Store geometry as [start_time, low_freq, end_time, high_freq]
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||
pred.geometry
|
||||
)
|
||||
@ -98,8 +94,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
row["low_freq"] = low_freq
|
||||
row["end_time"] = end_time
|
||||
row["high_freq"] = high_freq
|
||||
|
||||
# Store full geometry as JSON
|
||||
row["geometry"] = pred.geometry.model_dump_json()
|
||||
|
||||
if self.include_class_scores:
|
||||
@ -121,11 +115,9 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
def load(self, path: data.PathLike) -> List[ClipDetections]:
|
||||
path = Path(path)
|
||||
if path.is_dir():
|
||||
# Try to find parquet files
|
||||
files = list(path.glob("*.parquet"))
|
||||
if not files:
|
||||
return []
|
||||
# Read all and concatenate
|
||||
dfs = [pd.read_parquet(f) for f in files]
|
||||
df = pd.concat(dfs, ignore_index=True)
|
||||
else:
|
||||
@ -148,7 +140,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
)
|
||||
predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
|
||||
|
||||
# Reconstruct geometry
|
||||
if "geometry" in row and row["geometry"]:
|
||||
geometry = data.geometry_validate(row["geometry"])
|
||||
else:
|
||||
@ -182,13 +173,14 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
for clip_data in predictions_by_clip.values():
|
||||
results.append(
|
||||
ClipDetections(
|
||||
clip=clip_data["clip"], detections=clip_data["preds"]
|
||||
clip=clip_data["clip"],
|
||||
detections=clip_data["preds"],
|
||||
)
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
@prediction_formatters.register(ParquetOutputConfig)
|
||||
@output_formatters.register(ParquetOutputConfig)
|
||||
@staticmethod
|
||||
def from_config(config: ParquetOutputConfig, targets: TargetProtocol):
|
||||
return ParquetFormatter(
|
||||
@ -10,9 +10,9 @@ from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.data.predictions.base import (
|
||||
from batdetect2.outputs.formats.base import (
|
||||
make_path_relative,
|
||||
prediction_formatters,
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
@ -95,56 +95,56 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
update=dict(path=make_path_relative(recording.path, audio_dir))
|
||||
)
|
||||
|
||||
data = defaultdict(list)
|
||||
values = defaultdict(list)
|
||||
|
||||
for pred in prediction.detections:
|
||||
detection_id = str(uuid4())
|
||||
|
||||
data["detection_id"].append(detection_id)
|
||||
data["detection_score"].append(pred.detection_score)
|
||||
values["detection_id"].append(detection_id)
|
||||
values["detection_score"].append(pred.detection_score)
|
||||
|
||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||
pred.geometry
|
||||
)
|
||||
|
||||
data["start_time"].append(start_time)
|
||||
data["end_time"].append(end_time)
|
||||
data["low_freq"].append(low_freq)
|
||||
data["high_freq"].append(high_freq)
|
||||
values["start_time"].append(start_time)
|
||||
values["end_time"].append(end_time)
|
||||
values["low_freq"].append(low_freq)
|
||||
values["high_freq"].append(high_freq)
|
||||
|
||||
data["geometry"].append(pred.geometry.model_dump_json())
|
||||
values["geometry"].append(pred.geometry.model_dump_json())
|
||||
|
||||
top_class_index = int(np.argmax(pred.class_scores))
|
||||
top_class_score = float(pred.class_scores[top_class_index])
|
||||
top_class = self.targets.class_names[top_class_index]
|
||||
|
||||
data["top_class"].append(top_class)
|
||||
data["top_class_score"].append(top_class_score)
|
||||
values["top_class"].append(top_class)
|
||||
values["top_class_score"].append(top_class_score)
|
||||
|
||||
data["class_scores"].append(pred.class_scores)
|
||||
data["features"].append(pred.features)
|
||||
values["class_scores"].append(pred.class_scores)
|
||||
values["features"].append(pred.features)
|
||||
|
||||
num_features = len(pred.features)
|
||||
|
||||
data_vars = {
|
||||
"score": (["detection"], data["detection_score"]),
|
||||
"start_time": (["detection"], data["start_time"]),
|
||||
"end_time": (["detection"], data["end_time"]),
|
||||
"low_freq": (["detection"], data["low_freq"]),
|
||||
"high_freq": (["detection"], data["high_freq"]),
|
||||
"top_class": (["detection"], data["top_class"]),
|
||||
"top_class_score": (["detection"], data["top_class_score"]),
|
||||
"score": (["detection"], values["detection_score"]),
|
||||
"start_time": (["detection"], values["start_time"]),
|
||||
"end_time": (["detection"], values["end_time"]),
|
||||
"low_freq": (["detection"], values["low_freq"]),
|
||||
"high_freq": (["detection"], values["high_freq"]),
|
||||
"top_class": (["detection"], values["top_class"]),
|
||||
"top_class_score": (["detection"], values["top_class_score"]),
|
||||
}
|
||||
|
||||
coords = {
|
||||
"detection": ("detection", data["detection_id"]),
|
||||
"detection": ("detection", values["detection_id"]),
|
||||
"clip_start": clip.start_time,
|
||||
"clip_end": clip.end_time,
|
||||
"clip_id": str(clip.uuid),
|
||||
}
|
||||
|
||||
if self.include_class_scores:
|
||||
class_scores = np.stack(data["class_scores"], axis=0)
|
||||
class_scores = np.stack(values["class_scores"], axis=0)
|
||||
data_vars["class_scores"] = (
|
||||
["detection", "classes"],
|
||||
class_scores,
|
||||
@ -152,12 +152,12 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
coords["classes"] = ("classes", self.targets.class_names)
|
||||
|
||||
if self.include_features:
|
||||
features = np.stack(data["features"], axis=0)
|
||||
features = np.stack(values["features"], axis=0)
|
||||
data_vars["features"] = (["detection", "feature"], features)
|
||||
coords["feature"] = ("feature", np.arange(num_features))
|
||||
|
||||
if self.include_geometry:
|
||||
data_vars["geometry"] = (["detection"], data["geometry"])
|
||||
data_vars["geometry"] = (["detection"], values["geometry"])
|
||||
|
||||
return xr.Dataset(
|
||||
data_vars=data_vars,
|
||||
@ -169,7 +169,6 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
|
||||
def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
|
||||
clip_data = dataset
|
||||
clip_id = clip_data.clip_id.item()
|
||||
|
||||
recording = data.Recording.model_validate_json(
|
||||
clip_data.attrs["recording"]
|
||||
@ -232,7 +231,7 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
||||
detections=sound_events,
|
||||
)
|
||||
|
||||
@prediction_formatters.register(RawOutputConfig)
|
||||
@output_formatters.register(RawOutputConfig)
|
||||
@staticmethod
|
||||
def from_config(config: RawOutputConfig, targets: TargetProtocol):
|
||||
return RawFormatter(
|
||||
@ -5,8 +5,8 @@ import numpy as np
|
||||
from soundevent import data, io
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.data.predictions.base import (
|
||||
prediction_formatters,
|
||||
from batdetect2.outputs.formats.base import (
|
||||
output_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
@ -121,7 +121,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
||||
|
||||
return tags
|
||||
|
||||
@prediction_formatters.register(SoundEventOutputConfig)
|
||||
@output_formatters.register(SoundEventOutputConfig)
|
||||
@staticmethod
|
||||
def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
|
||||
return SoundEventOutputFormatter(
|
||||
89
src/batdetect2/outputs/transforms.py
Normal file
89
src/batdetect2/outputs/transforms.py
Normal file
@ -0,0 +1,89 @@
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import replace
|
||||
from typing import Protocol
|
||||
|
||||
from soundevent.geometry import shift_geometry
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.typing import ClipDetections, Detection
|
||||
|
||||
__all__ = [
|
||||
"OutputTransform",
|
||||
"OutputTransformConfig",
|
||||
"OutputTransformProtocol",
|
||||
"build_output_transform",
|
||||
]
|
||||
|
||||
|
||||
class OutputTransformConfig(BaseConfig):
    """Configuration for the transform applied to model predictions."""

    # When True, the geometry of each detection is shifted in time by the
    # start time of the clip it belongs to. Enabled by default.
    shift_time_to_clip_start: bool = True
|
||||
|
||||
|
||||
class OutputTransformProtocol(Protocol):
    """Structural interface for objects that transform predictions.

    Implementations must be callable on a sequence of clip-level
    predictions and must also expose ``transform_detections`` for
    transforming the detections of a single clip.
    """

    def __call__(
        self,
        predictions: Sequence[ClipDetections],
    ) -> list[ClipDetections]:
        """Transform a sequence of clip predictions and return a list."""
        ...

    def transform_detections(
        self,
        detections: Sequence[Detection],
        start_time: float = 0,
    ) -> list[Detection]:
        """Transform the detections of a single clip.

        Parameters
        ----------
        detections
            Detections to transform.
        start_time
            Start time associated with the detections. Defaults to 0.
        """
        ...
|
||||
|
||||
|
||||
def shift_detection_time(detection: Detection, time: float) -> Detection:
    """Return a copy of ``detection`` with its geometry shifted in time.

    Parameters
    ----------
    detection
        Detection whose geometry should be shifted.
    time
        Time offset passed to ``shift_geometry``.
        NOTE(review): presumably in seconds -- confirm against callers.

    Returns
    -------
    Detection
        A new detection produced with ``dataclasses.replace``; only the
        ``geometry`` field differs from the input.
    """
    geometry = shift_geometry(detection.geometry, time=time)
    return replace(detection, geometry=geometry)
|
||||
|
||||
|
||||
class OutputTransform(OutputTransformProtocol):
    """Output transform that optionally shifts detections in time.

    When ``shift_time_to_clip_start`` is enabled, the geometry of every
    detection is shifted by the start time of the clip it belongs to.
    When disabled, predictions are passed through without shifting.

    Parameters
    ----------
    shift_time_to_clip_start
        Whether to shift detection geometries by the clip start time.
    """

    def __init__(self, shift_time_to_clip_start: bool = True):
        self.shift_time_to_clip_start = shift_time_to_clip_start

    def __call__(
        self,
        predictions: Sequence[ClipDetections],
    ) -> list[ClipDetections]:
        """Apply ``transform_prediction`` to each prediction, in order."""
        return [
            self.transform_prediction(prediction) for prediction in predictions
        ]

    def transform_prediction(
        self, prediction: ClipDetections
    ) -> ClipDetections:
        """Transform the detections of a single clip prediction.

        If shifting is disabled the input object itself is returned.
        Otherwise a new ``ClipDetections`` is built for the same clip,
        holding detections shifted by ``prediction.clip.start_time``.
        """
        if not self.shift_time_to_clip_start:
            return prediction

        detections = self.transform_detections(
            prediction.detections,
            start_time=prediction.clip.start_time,
        )
        return ClipDetections(clip=prediction.clip, detections=detections)

    def transform_detections(
        self,
        detections: Sequence[Detection],
        start_time: float = 0,
    ) -> list[Detection]:
        """Shift a sequence of detections by ``start_time``.

        Parameters
        ----------
        detections
            Detections to transform.
        start_time
            Time offset applied to each detection geometry. Defaults to 0.

        Returns
        -------
        list[Detection]
            Always a new list. The detections themselves are left
            untouched when shifting is disabled or the offset is zero.
        """
        # Nothing to shift: skip building new detection objects.
        if not self.shift_time_to_clip_start or start_time == 0:
            return list(detections)

        return [
            shift_detection_time(detection, time=start_time)
            for detection in detections
        ]
|
||||
|
||||
|
||||
def build_output_transform(
    config: OutputTransformConfig | dict | None = None,
) -> OutputTransformProtocol:
    """Build an output transform from a configuration.

    Parameters
    ----------
    config
        An ``OutputTransformConfig``, any other object accepted by
        ``OutputTransformConfig.model_validate`` (e.g. a dict), or
        ``None`` to use the default configuration.

    Returns
    -------
    OutputTransformProtocol
        An ``OutputTransform`` configured with the
        ``shift_time_to_clip_start`` setting of the resolved config.
    """
    if config is None:
        config = OutputTransformConfig()

    # Anything that is not already a config instance is validated into one.
    if not isinstance(config, OutputTransformConfig):
        config = OutputTransformConfig.model_validate(config)

    return OutputTransform(
        shift_time_to_clip_start=config.shift_time_to_clip_start,
    )
|
||||
@ -64,7 +64,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
def forward(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
start_times: list[float] | None = None,
|
||||
) -> list[ClipDetectionsTensor]:
|
||||
detection_heatmap = non_max_suppression(
|
||||
output.detection_probs.detach(),
|
||||
@ -83,9 +82,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
||||
threshold=self.detection_threshold,
|
||||
)
|
||||
|
||||
if start_times is None:
|
||||
start_times = [0 for _ in range(len(detections))]
|
||||
|
||||
return [
|
||||
map_detection_to_clip(
|
||||
detection,
|
||||
|
||||
@ -6,6 +6,7 @@ from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
@ -18,10 +19,15 @@ from batdetect2.typing import (
|
||||
|
||||
|
||||
class ValidationMetrics(Callback):
|
||||
def __init__(self, evaluator: EvaluatorProtocol):
|
||||
def __init__(
|
||||
self,
|
||||
evaluator: EvaluatorProtocol,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.evaluator = evaluator
|
||||
self.output_transform = output_transform or build_output_transform()
|
||||
|
||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||
self._predictions: List[ClipDetections] = []
|
||||
@ -95,10 +101,7 @@ class ValidationMetrics(Callback):
|
||||
for example_idx in batch.idx
|
||||
]
|
||||
|
||||
clip_detections = model.postprocessor(
|
||||
outputs,
|
||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||
)
|
||||
clip_detections = model.postprocessor(outputs)
|
||||
predictions = [
|
||||
ClipDetections(
|
||||
clip=clip_annotation.clip,
|
||||
@ -110,6 +113,7 @@ class ValidationMetrics(Callback):
|
||||
clip_annotations, clip_detections, strict=False
|
||||
)
|
||||
]
|
||||
predictions = self.output_transform(predictions)
|
||||
|
||||
self._clip_annotations.extend(clip_annotations)
|
||||
self._predictions.extend(predictions)
|
||||
|
||||
@ -12,7 +12,7 @@ system that deal with model predictions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, NamedTuple, Protocol, Sequence
|
||||
from typing import List, NamedTuple, Protocol
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -101,5 +101,4 @@ class PostprocessorProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
output: ModelOutput,
|
||||
start_times: Sequence[float] | None = None,
|
||||
) -> List[ClipDetectionsTensor]: ...
|
||||
|
||||
@ -5,7 +5,7 @@ import numpy as np
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.predictions import (
|
||||
from batdetect2.outputs.formats import (
|
||||
ParquetOutputConfig,
|
||||
build_output_formatter,
|
||||
)
|
||||
|
||||
@ -4,7 +4,7 @@ import numpy as np
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
|
||||
from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter
|
||||
from batdetect2.typing import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
|
||||
48
tests/test_outputs/test_transform/test_transform.py
Normal file
48
tests/test_outputs/test_transform/test_transform.py
Normal file
@ -0,0 +1,48 @@
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.outputs import build_output_transform
|
||||
from batdetect2.typing import ClipDetections, Detection
|
||||
|
||||
|
||||
def test_shift_time_to_clip_start(clip: data.Clip):
    """The default transform offsets detection times by the clip start."""
    shifted_clip = clip.model_copy(
        update={"start_time": 2.5, "end_time": 3.0}
    )
    box = data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000])
    detection = Detection(
        geometry=box,
        detection_score=0.9,
        class_scores=np.array([0.9]),
        features=np.array([1.0, 2.0]),
    )
    prediction = ClipDetections(clip=shifted_clip, detections=[detection])

    transform = build_output_transform()
    transformed = transform([prediction])[0]

    onset, _, offset, _ = compute_bounds(transformed.detections[0].geometry)

    # 0.1 s and 0.2 s inside the clip, plus the 2.5 s clip start.
    assert np.isclose(onset, 2.6)
    assert np.isclose(offset, 2.7)
|
||||
|
||||
|
||||
def test_transform_identity_when_disabled(clip: data.Clip):
    """With shifting disabled the detection geometry is left unchanged."""
    shifted_clip = clip.model_copy(
        update={"start_time": 2.5, "end_time": 3.0}
    )
    box = data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000])
    detection = Detection(
        geometry=box,
        detection_score=0.9,
        class_scores=np.array([0.9]),
        features=np.array([1.0, 2.0]),
    )
    prediction = ClipDetections(clip=shifted_clip, detections=[detection])

    disabled = build_output_transform(
        config={"shift_time_to_clip_start": False}
    )
    results = disabled([prediction])
    transformed = results[0]

    assert transformed.detections[0].geometry == detection.geometry
|
||||
Loading…
Reference in New Issue
Block a user