mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Add outputs module
This commit is contained in:
parent
3b47c688dd
commit
c226dc3f2b
@ -102,6 +102,7 @@ exclude = [
|
|||||||
"src/batdetect2/plotting/legacy",
|
"src/batdetect2/plotting/legacy",
|
||||||
"src/batdetect2/evaluate/legacy",
|
"src/batdetect2/evaluate/legacy",
|
||||||
"src/batdetect2/finetune",
|
"src/batdetect2/finetune",
|
||||||
|
"src/batdetect2/utils",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.ruff.format]
|
[tool.ruff.format]
|
||||||
@ -121,4 +122,5 @@ exclude = [
|
|||||||
"src/batdetect2/plotting/legacy",
|
"src/batdetect2/plotting/legacy",
|
||||||
"src/batdetect2/evaluate/legacy",
|
"src/batdetect2/evaluate/legacy",
|
||||||
"src/batdetect2/finetune",
|
"src/batdetect2/finetune",
|
||||||
|
"src/batdetect2/utils",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -11,17 +11,20 @@ from batdetect2.audio import build_audio_loader
|
|||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.core import merge_configs
|
from batdetect2.core import merge_configs
|
||||||
from batdetect2.data import (
|
from batdetect2.data import (
|
||||||
OutputFormatConfig,
|
|
||||||
build_output_formatter,
|
|
||||||
get_output_formatter,
|
|
||||||
load_dataset_from_config,
|
load_dataset_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
from batdetect2.data.predictions.base import OutputFormatterProtocol
|
|
||||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||||
from batdetect2.inference import process_file_list, run_batch_inference
|
from batdetect2.inference import process_file_list, run_batch_inference
|
||||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model, build_model
|
||||||
|
from batdetect2.outputs import (
|
||||||
|
OutputFormatConfig,
|
||||||
|
OutputTransformProtocol,
|
||||||
|
build_output_formatter,
|
||||||
|
build_output_transform,
|
||||||
|
get_output_formatter,
|
||||||
|
)
|
||||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
@ -35,6 +38,7 @@ from batdetect2.typing import (
|
|||||||
ClipDetections,
|
ClipDetections,
|
||||||
Detection,
|
Detection,
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
|
OutputFormatterProtocol,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
@ -51,6 +55,7 @@ class BatDetect2API:
|
|||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
formatter: OutputFormatterProtocol,
|
formatter: OutputFormatterProtocol,
|
||||||
|
output_transform: OutputTransformProtocol,
|
||||||
model: Model,
|
model: Model,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -61,6 +66,7 @@ class BatDetect2API:
|
|||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.model = model
|
self.model = model
|
||||||
self.formatter = formatter
|
self.formatter = formatter
|
||||||
|
self.output_transform = output_transform
|
||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
@ -208,10 +214,16 @@ class BatDetect2API:
|
|||||||
|
|
||||||
detections = self.model.postprocessor(
|
detections = self.model.postprocessor(
|
||||||
outputs,
|
outputs,
|
||||||
start_times=[start_time],
|
|
||||||
)[0]
|
)[0]
|
||||||
|
raw_predictions = to_raw_predictions(
|
||||||
|
detections.numpy(),
|
||||||
|
targets=self.targets,
|
||||||
|
)
|
||||||
|
|
||||||
return to_raw_predictions(detections.numpy(), targets=self.targets)
|
return self.output_transform.transform_detections(
|
||||||
|
raw_predictions,
|
||||||
|
start_time=start_time,
|
||||||
|
)
|
||||||
|
|
||||||
def process_directory(
|
def process_directory(
|
||||||
self,
|
self,
|
||||||
@ -304,7 +316,13 @@ class BatDetect2API:
|
|||||||
# postprocessor as these may be moved to another device.
|
# postprocessor as these may be moved to another device.
|
||||||
model = build_model(config=config.model)
|
model = build_model(config=config.model)
|
||||||
|
|
||||||
formatter = build_output_formatter(targets, config=config.output)
|
formatter = build_output_formatter(
|
||||||
|
targets,
|
||||||
|
config=config.outputs.format,
|
||||||
|
)
|
||||||
|
output_transform = build_output_transform(
|
||||||
|
config=config.outputs.transform
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
config=config,
|
config=config,
|
||||||
@ -315,6 +333,7 @@ class BatDetect2API:
|
|||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
model=model,
|
model=model,
|
||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
|
output_transform=output_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -351,7 +370,13 @@ class BatDetect2API:
|
|||||||
|
|
||||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||||
|
|
||||||
formatter = build_output_formatter(targets, config=config.output)
|
formatter = build_output_formatter(
|
||||||
|
targets,
|
||||||
|
config=config.outputs.format,
|
||||||
|
)
|
||||||
|
output_transform = build_output_transform(
|
||||||
|
config=config.outputs.transform
|
||||||
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
config=config,
|
config=config,
|
||||||
@ -362,4 +387,5 @@ class BatDetect2API:
|
|||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
model=model,
|
model=model,
|
||||||
formatter=formatter,
|
formatter=formatter,
|
||||||
|
output_transform=output_transform,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -5,14 +5,13 @@ from soundevent.data import PathLike
|
|||||||
|
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.data.predictions import OutputFormatConfig
|
|
||||||
from batdetect2.data.predictions.raw import RawOutputConfig
|
|
||||||
from batdetect2.evaluate.config import (
|
from batdetect2.evaluate.config import (
|
||||||
EvaluationConfig,
|
EvaluationConfig,
|
||||||
get_default_eval_config,
|
get_default_eval_config,
|
||||||
)
|
)
|
||||||
from batdetect2.inference.config import InferenceConfig
|
from batdetect2.inference.config import InferenceConfig
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.outputs import OutputsConfig
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -32,7 +31,7 @@ class BatDetect2Config(BaseConfig):
|
|||||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||||
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||||
|
|
||||||
|
|
||||||
def validate_config(config: dict | None) -> BatDetect2Config:
|
def validate_config(config: dict | None) -> BatDetect2Config:
|
||||||
|
|||||||
@ -12,15 +12,6 @@ from batdetect2.data.datasets import (
|
|||||||
load_dataset_config,
|
load_dataset_config,
|
||||||
load_dataset_from_config,
|
load_dataset_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.data.predictions import (
|
|
||||||
BatDetect2OutputConfig,
|
|
||||||
OutputFormatConfig,
|
|
||||||
RawOutputConfig,
|
|
||||||
SoundEventOutputConfig,
|
|
||||||
build_output_formatter,
|
|
||||||
get_output_formatter,
|
|
||||||
load_predictions,
|
|
||||||
)
|
|
||||||
from batdetect2.data.summary import (
|
from batdetect2.data.summary import (
|
||||||
compute_class_summary,
|
compute_class_summary,
|
||||||
extract_recordings_df,
|
extract_recordings_df,
|
||||||
@ -36,6 +27,7 @@ __all__ = [
|
|||||||
"BatDetect2OutputConfig",
|
"BatDetect2OutputConfig",
|
||||||
"DatasetConfig",
|
"DatasetConfig",
|
||||||
"OutputFormatConfig",
|
"OutputFormatConfig",
|
||||||
|
"ParquetOutputConfig",
|
||||||
"RawOutputConfig",
|
"RawOutputConfig",
|
||||||
"SoundEventOutputConfig",
|
"SoundEventOutputConfig",
|
||||||
"build_output_formatter",
|
"build_output_formatter",
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from batdetect2.evaluate.evaluator import build_evaluator
|
|||||||
from batdetect2.evaluate.lightning import EvaluationModule
|
from batdetect2.evaluate.lightning import EvaluationModule
|
||||||
from batdetect2.logging import build_logger
|
from batdetect2.logging import build_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
|
from batdetect2.outputs import build_output_transform
|
||||||
from batdetect2.typing import Detection
|
from batdetect2.typing import Detection
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -61,7 +62,12 @@ def evaluate(
|
|||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
module = EvaluationModule(model, evaluator)
|
output_transform = build_output_transform(config=config.outputs.transform)
|
||||||
|
module = EvaluationModule(
|
||||||
|
model,
|
||||||
|
evaluator,
|
||||||
|
output_transform=output_transform,
|
||||||
|
)
|
||||||
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
||||||
metrics = trainer.test(module, loader)
|
metrics = trainer.test(module, loader)
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from torch.utils.data import DataLoader
|
|||||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing import EvaluatorProtocol
|
from batdetect2.typing import EvaluatorProtocol
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
@ -17,11 +18,13 @@ class EvaluationModule(LightningModule):
|
|||||||
self,
|
self,
|
||||||
model: Model,
|
model: Model,
|
||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
self.output_transform = output_transform or build_output_transform()
|
||||||
|
|
||||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self.predictions: List[ClipDetections] = []
|
self.predictions: List[ClipDetections] = []
|
||||||
@ -34,10 +37,7 @@ class EvaluationModule(LightningModule):
|
|||||||
]
|
]
|
||||||
|
|
||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
clip_detections = self.model.postprocessor(
|
clip_detections = self.model.postprocessor(outputs)
|
||||||
outputs,
|
|
||||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
|
||||||
)
|
|
||||||
predictions = [
|
predictions = [
|
||||||
ClipDetections(
|
ClipDetections(
|
||||||
clip=clip_annotation.clip,
|
clip=clip_annotation.clip,
|
||||||
@ -50,6 +50,7 @@ class EvaluationModule(LightningModule):
|
|||||||
clip_annotations, clip_detections, strict=False
|
clip_annotations, clip_detections, strict=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
predictions = self.output_transform(predictions)
|
||||||
|
|
||||||
self.clip_annotations.extend(clip_annotations)
|
self.clip_annotations.extend(clip_annotations)
|
||||||
self.predictions.extend(predictions)
|
self.predictions.extend(predictions)
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from batdetect2.inference.clips import get_clips_from_files
|
|||||||
from batdetect2.inference.dataset import build_inference_loader
|
from batdetect2.inference.dataset import build_inference_loader
|
||||||
from batdetect2.inference.lightning import InferenceModule
|
from batdetect2.inference.lightning import InferenceModule
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.preprocess.preprocessor import build_preprocessor
|
from batdetect2.preprocess.preprocessor import build_preprocessor
|
||||||
from batdetect2.targets.targets import build_targets
|
from batdetect2.targets.targets import build_targets
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
@ -28,6 +29,7 @@ def run_batch_inference(
|
|||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
|
output_transform: Optional[OutputTransformProtocol] = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int | None = None,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
) -> List[ClipDetections]:
|
) -> List[ClipDetections]:
|
||||||
@ -42,6 +44,9 @@ def run_batch_inference(
|
|||||||
)
|
)
|
||||||
|
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
|
output_transform = output_transform or build_output_transform(
|
||||||
|
config=config.outputs.transform,
|
||||||
|
)
|
||||||
|
|
||||||
loader = build_inference_loader(
|
loader = build_inference_loader(
|
||||||
clips,
|
clips,
|
||||||
@ -52,7 +57,10 @@ def run_batch_inference(
|
|||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
module = InferenceModule(model)
|
module = InferenceModule(
|
||||||
|
model,
|
||||||
|
output_transform=output_transform,
|
||||||
|
)
|
||||||
trainer = Trainer(enable_checkpointing=False, logger=False)
|
trainer = Trainer(enable_checkpointing=False, logger=False)
|
||||||
outputs = trainer.predict(module, loader)
|
outputs = trainer.predict(module, loader)
|
||||||
return [
|
return [
|
||||||
|
|||||||
@ -5,14 +5,20 @@ from torch.utils.data import DataLoader
|
|||||||
|
|
||||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing.postprocess import ClipDetections
|
from batdetect2.typing.postprocess import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
class InferenceModule(LightningModule):
|
class InferenceModule(LightningModule):
|
||||||
def __init__(self, model: Model):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Model,
|
||||||
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.output_transform = output_transform or build_output_transform()
|
||||||
|
|
||||||
def predict_step(
|
def predict_step(
|
||||||
self,
|
self,
|
||||||
@ -26,10 +32,7 @@ class InferenceModule(LightningModule):
|
|||||||
|
|
||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
|
|
||||||
clip_detections = self.model.postprocessor(
|
clip_detections = self.model.postprocessor(outputs)
|
||||||
outputs,
|
|
||||||
start_times=[clip.start_time for clip in clips],
|
|
||||||
)
|
|
||||||
|
|
||||||
predictions = [
|
predictions = [
|
||||||
ClipDetections(
|
ClipDetections(
|
||||||
@ -42,7 +45,7 @@ class InferenceModule(LightningModule):
|
|||||||
for clip, clip_dets in zip(clips, clip_detections, strict=False)
|
for clip, clip_dets in zip(clips, clip_detections, strict=False)
|
||||||
]
|
]
|
||||||
|
|
||||||
return predictions
|
return self.output_transform(predictions)
|
||||||
|
|
||||||
def get_dataset(self) -> InferenceDataset:
|
def get_dataset(self) -> InferenceDataset:
|
||||||
dataloaders = self.trainer.predict_dataloaders
|
dataloaders = self.trainer.predict_dataloaders
|
||||||
|
|||||||
31
src/batdetect2/outputs/__init__.py
Normal file
31
src/batdetect2/outputs/__init__.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
from batdetect2.outputs.config import OutputsConfig
|
||||||
|
from batdetect2.outputs.formats import (
|
||||||
|
BatDetect2OutputConfig,
|
||||||
|
OutputFormatConfig,
|
||||||
|
ParquetOutputConfig,
|
||||||
|
RawOutputConfig,
|
||||||
|
SoundEventOutputConfig,
|
||||||
|
build_output_formatter,
|
||||||
|
get_output_formatter,
|
||||||
|
load_predictions,
|
||||||
|
)
|
||||||
|
from batdetect2.outputs.transforms import (
|
||||||
|
OutputTransformConfig,
|
||||||
|
OutputTransformProtocol,
|
||||||
|
build_output_transform,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BatDetect2OutputConfig",
|
||||||
|
"OutputFormatConfig",
|
||||||
|
"OutputTransformConfig",
|
||||||
|
"OutputTransformProtocol",
|
||||||
|
"OutputsConfig",
|
||||||
|
"ParquetOutputConfig",
|
||||||
|
"RawOutputConfig",
|
||||||
|
"SoundEventOutputConfig",
|
||||||
|
"build_output_formatter",
|
||||||
|
"build_output_transform",
|
||||||
|
"get_output_formatter",
|
||||||
|
"load_predictions",
|
||||||
|
]
|
||||||
15
src/batdetect2/outputs/config.py
Normal file
15
src/batdetect2/outputs/config.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.outputs.formats import OutputFormatConfig
|
||||||
|
from batdetect2.outputs.formats.raw import RawOutputConfig
|
||||||
|
from batdetect2.outputs.transforms import OutputTransformConfig
|
||||||
|
|
||||||
|
__all__ = ["OutputsConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
class OutputsConfig(BaseConfig):
|
||||||
|
format: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
||||||
|
transform: OutputTransformConfig = Field(
|
||||||
|
default_factory=OutputTransformConfig
|
||||||
|
)
|
||||||
@ -3,23 +3,25 @@ from typing import Annotated
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.data.predictions.base import (
|
from batdetect2.outputs.formats.base import (
|
||||||
OutputFormatterProtocol,
|
OutputFormatterProtocol,
|
||||||
prediction_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
|
from batdetect2.outputs.formats.batdetect2 import BatDetect2OutputConfig
|
||||||
from batdetect2.data.predictions.parquet import ParquetOutputConfig
|
from batdetect2.outputs.formats.parquet import ParquetOutputConfig
|
||||||
from batdetect2.data.predictions.raw import RawOutputConfig
|
from batdetect2.outputs.formats.raw import RawOutputConfig
|
||||||
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
|
from batdetect2.outputs.formats.soundevent import SoundEventOutputConfig
|
||||||
from batdetect2.typing import TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_output_formatter",
|
|
||||||
"get_output_formatter",
|
|
||||||
"BatDetect2OutputConfig",
|
"BatDetect2OutputConfig",
|
||||||
|
"OutputFormatConfig",
|
||||||
"ParquetOutputConfig",
|
"ParquetOutputConfig",
|
||||||
"RawOutputConfig",
|
"RawOutputConfig",
|
||||||
"SoundEventOutputConfig",
|
"SoundEventOutputConfig",
|
||||||
|
"build_output_formatter",
|
||||||
|
"get_output_formatter",
|
||||||
|
"load_predictions",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -42,7 +44,7 @@ def build_output_formatter(
|
|||||||
config = config or RawOutputConfig()
|
config = config or RawOutputConfig()
|
||||||
|
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
return prediction_formatters.build(config, targets)
|
return output_formatters.build(config, targets)
|
||||||
|
|
||||||
|
|
||||||
def get_output_formatter(
|
def get_output_formatter(
|
||||||
@ -56,7 +58,7 @@ def get_output_formatter(
|
|||||||
if name is None:
|
if name is None:
|
||||||
raise ValueError("Either config or name must be provided.")
|
raise ValueError("Either config or name must be provided.")
|
||||||
|
|
||||||
config_class = prediction_formatters.get_config_type(name)
|
config_class = output_formatters.get_config_type(name)
|
||||||
config = config_class() # type: ignore
|
config = config_class() # type: ignore
|
||||||
|
|
||||||
if config.name != name: # type: ignore
|
if config.name != name: # type: ignore
|
||||||
@ -9,6 +9,13 @@ from batdetect2.typing import (
|
|||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OutputFormatterProtocol",
|
||||||
|
"PredictionFormatterImportConfig",
|
||||||
|
"make_path_relative",
|
||||||
|
"output_formatters",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
|
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
@ -25,12 +32,12 @@ def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
|
output_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
|
||||||
Registry(name="output_formatter")
|
Registry(name="output_formatter")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(prediction_formatters)
|
@add_import_config(output_formatters)
|
||||||
class PredictionFormatterImportConfig(ImportConfig):
|
class PredictionFormatterImportConfig(ImportConfig):
|
||||||
"""Use any callable as a prediction formatter.
|
"""Use any callable as a prediction formatter.
|
||||||
|
|
||||||
@ -7,9 +7,9 @@ from soundevent import data
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.data.predictions.base import (
|
from batdetect2.outputs.formats.base import (
|
||||||
make_path_relative,
|
make_path_relative,
|
||||||
prediction_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.targets import terms
|
from batdetect2.targets import terms
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
@ -28,76 +28,32 @@ DictWithClass = TypedDict("DictWithClass", {"class": str})
|
|||||||
|
|
||||||
|
|
||||||
class Annotation(DictWithClass):
|
class Annotation(DictWithClass):
|
||||||
"""Format of annotations.
|
|
||||||
|
|
||||||
This is the format of a single annotation as expected by the
|
|
||||||
annotation tool.
|
|
||||||
"""
|
|
||||||
|
|
||||||
start_time: float
|
start_time: float
|
||||||
"""Start time in seconds."""
|
|
||||||
|
|
||||||
end_time: float
|
end_time: float
|
||||||
"""End time in seconds."""
|
|
||||||
|
|
||||||
low_freq: float
|
low_freq: float
|
||||||
"""Low frequency in Hz."""
|
|
||||||
|
|
||||||
high_freq: float
|
high_freq: float
|
||||||
"""High frequency in Hz."""
|
|
||||||
|
|
||||||
class_prob: float
|
class_prob: float
|
||||||
"""Probability of class assignment."""
|
|
||||||
|
|
||||||
det_prob: float
|
det_prob: float
|
||||||
"""Probability of detection."""
|
|
||||||
|
|
||||||
individual: str
|
individual: str
|
||||||
"""Individual ID."""
|
|
||||||
|
|
||||||
event: str
|
event: str
|
||||||
"""Type of detected event."""
|
|
||||||
|
|
||||||
|
|
||||||
class FileAnnotation(TypedDict):
|
class FileAnnotation(TypedDict):
|
||||||
"""Format of results.
|
|
||||||
|
|
||||||
This is the format of the results expected by the annotation tool.
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
"""File ID."""
|
|
||||||
|
|
||||||
annotated: bool
|
annotated: bool
|
||||||
"""Whether file has been annotated."""
|
|
||||||
|
|
||||||
duration: float
|
duration: float
|
||||||
"""Duration of audio file."""
|
|
||||||
|
|
||||||
issues: bool
|
issues: bool
|
||||||
"""Whether file has issues."""
|
|
||||||
|
|
||||||
time_exp: float
|
time_exp: float
|
||||||
"""Time expansion factor."""
|
|
||||||
|
|
||||||
class_name: str
|
class_name: str
|
||||||
"""Class predicted at file level."""
|
|
||||||
|
|
||||||
notes: str
|
notes: str
|
||||||
"""Notes of file."""
|
|
||||||
|
|
||||||
annotation: List[Annotation]
|
annotation: List[Annotation]
|
||||||
"""List of annotations."""
|
|
||||||
|
|
||||||
file_path: NotRequired[str] # ty: ignore[invalid-type-form]
|
file_path: NotRequired[str] # ty: ignore[invalid-type-form]
|
||||||
"""Path to file."""
|
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2OutputConfig(BaseConfig):
|
class BatDetect2OutputConfig(BaseConfig):
|
||||||
name: Literal["batdetect2"] = "batdetect2"
|
name: Literal["batdetect2"] = "batdetect2"
|
||||||
|
|
||||||
event_name: str = "Echolocation"
|
event_name: str = "Echolocation"
|
||||||
|
|
||||||
annotation_note: str = "Automatically generated."
|
annotation_note: str = "Automatically generated."
|
||||||
|
|
||||||
|
|
||||||
@ -156,8 +112,6 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
]
|
]
|
||||||
|
|
||||||
def get_recording_class(self, annotations: List[Annotation]) -> str:
|
def get_recording_class(self, annotations: List[Annotation]) -> str:
|
||||||
"""Get class of recording from annotations."""
|
|
||||||
|
|
||||||
if not annotations:
|
if not annotations:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
@ -215,7 +169,7 @@ class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
|||||||
**{"class": top_class},
|
**{"class": top_class},
|
||||||
)
|
)
|
||||||
|
|
||||||
@prediction_formatters.register(BatDetect2OutputConfig)
|
@output_formatters.register(BatDetect2OutputConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
|
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
|
||||||
return BatDetect2Formatter(
|
return BatDetect2Formatter(
|
||||||
@ -9,9 +9,9 @@ from soundevent import data
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.data.predictions.base import (
|
from batdetect2.outputs.formats.base import (
|
||||||
make_path_relative,
|
make_path_relative,
|
||||||
prediction_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
ClipDetections,
|
ClipDetections,
|
||||||
@ -59,10 +59,7 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
if not path.parent.exists():
|
if not path.parent.exists():
|
||||||
path.parent.mkdir(parents=True)
|
path.parent.mkdir(parents=True)
|
||||||
|
|
||||||
# Ensure the file has .parquet extension if it's a file path
|
|
||||||
if path.suffix != ".parquet":
|
if path.suffix != ".parquet":
|
||||||
# If it's a directory, we might want to save as a partitioned dataset or a single file inside
|
|
||||||
# For now, let's assume the user provides a full file path or a directory where we save 'predictions.parquet'
|
|
||||||
if path.is_dir() or not path.suffix:
|
if path.is_dir() or not path.suffix:
|
||||||
path = path / "predictions.parquet"
|
path = path / "predictions.parquet"
|
||||||
|
|
||||||
@ -90,7 +87,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if self.include_geometry:
|
if self.include_geometry:
|
||||||
# Store geometry as [start_time, low_freq, end_time, high_freq]
|
|
||||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||||
pred.geometry
|
pred.geometry
|
||||||
)
|
)
|
||||||
@ -98,8 +94,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
row["low_freq"] = low_freq
|
row["low_freq"] = low_freq
|
||||||
row["end_time"] = end_time
|
row["end_time"] = end_time
|
||||||
row["high_freq"] = high_freq
|
row["high_freq"] = high_freq
|
||||||
|
|
||||||
# Store full geometry as JSON
|
|
||||||
row["geometry"] = pred.geometry.model_dump_json()
|
row["geometry"] = pred.geometry.model_dump_json()
|
||||||
|
|
||||||
if self.include_class_scores:
|
if self.include_class_scores:
|
||||||
@ -121,11 +115,9 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
def load(self, path: data.PathLike) -> List[ClipDetections]:
|
def load(self, path: data.PathLike) -> List[ClipDetections]:
|
||||||
path = Path(path)
|
path = Path(path)
|
||||||
if path.is_dir():
|
if path.is_dir():
|
||||||
# Try to find parquet files
|
|
||||||
files = list(path.glob("*.parquet"))
|
files = list(path.glob("*.parquet"))
|
||||||
if not files:
|
if not files:
|
||||||
return []
|
return []
|
||||||
# Read all and concatenate
|
|
||||||
dfs = [pd.read_parquet(f) for f in files]
|
dfs = [pd.read_parquet(f) for f in files]
|
||||||
df = pd.concat(dfs, ignore_index=True)
|
df = pd.concat(dfs, ignore_index=True)
|
||||||
else:
|
else:
|
||||||
@ -148,7 +140,6 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
)
|
)
|
||||||
predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
|
predictions_by_clip[clip_uuid] = {"clip": clip, "preds": []}
|
||||||
|
|
||||||
# Reconstruct geometry
|
|
||||||
if "geometry" in row and row["geometry"]:
|
if "geometry" in row and row["geometry"]:
|
||||||
geometry = data.geometry_validate(row["geometry"])
|
geometry = data.geometry_validate(row["geometry"])
|
||||||
else:
|
else:
|
||||||
@ -182,13 +173,14 @@ class ParquetFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
for clip_data in predictions_by_clip.values():
|
for clip_data in predictions_by_clip.values():
|
||||||
results.append(
|
results.append(
|
||||||
ClipDetections(
|
ClipDetections(
|
||||||
clip=clip_data["clip"], detections=clip_data["preds"]
|
clip=clip_data["clip"],
|
||||||
|
detections=clip_data["preds"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
@prediction_formatters.register(ParquetOutputConfig)
|
@output_formatters.register(ParquetOutputConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: ParquetOutputConfig, targets: TargetProtocol):
|
def from_config(config: ParquetOutputConfig, targets: TargetProtocol):
|
||||||
return ParquetFormatter(
|
return ParquetFormatter(
|
||||||
@ -10,9 +10,9 @@ from soundevent import data
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.data.predictions.base import (
|
from batdetect2.outputs.formats.base import (
|
||||||
make_path_relative,
|
make_path_relative,
|
||||||
prediction_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
ClipDetections,
|
ClipDetections,
|
||||||
@ -95,56 +95,56 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
update=dict(path=make_path_relative(recording.path, audio_dir))
|
update=dict(path=make_path_relative(recording.path, audio_dir))
|
||||||
)
|
)
|
||||||
|
|
||||||
data = defaultdict(list)
|
values = defaultdict(list)
|
||||||
|
|
||||||
for pred in prediction.detections:
|
for pred in prediction.detections:
|
||||||
detection_id = str(uuid4())
|
detection_id = str(uuid4())
|
||||||
|
|
||||||
data["detection_id"].append(detection_id)
|
values["detection_id"].append(detection_id)
|
||||||
data["detection_score"].append(pred.detection_score)
|
values["detection_score"].append(pred.detection_score)
|
||||||
|
|
||||||
start_time, low_freq, end_time, high_freq = compute_bounds(
|
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||||
pred.geometry
|
pred.geometry
|
||||||
)
|
)
|
||||||
|
|
||||||
data["start_time"].append(start_time)
|
values["start_time"].append(start_time)
|
||||||
data["end_time"].append(end_time)
|
values["end_time"].append(end_time)
|
||||||
data["low_freq"].append(low_freq)
|
values["low_freq"].append(low_freq)
|
||||||
data["high_freq"].append(high_freq)
|
values["high_freq"].append(high_freq)
|
||||||
|
|
||||||
data["geometry"].append(pred.geometry.model_dump_json())
|
values["geometry"].append(pred.geometry.model_dump_json())
|
||||||
|
|
||||||
top_class_index = int(np.argmax(pred.class_scores))
|
top_class_index = int(np.argmax(pred.class_scores))
|
||||||
top_class_score = float(pred.class_scores[top_class_index])
|
top_class_score = float(pred.class_scores[top_class_index])
|
||||||
top_class = self.targets.class_names[top_class_index]
|
top_class = self.targets.class_names[top_class_index]
|
||||||
|
|
||||||
data["top_class"].append(top_class)
|
values["top_class"].append(top_class)
|
||||||
data["top_class_score"].append(top_class_score)
|
values["top_class_score"].append(top_class_score)
|
||||||
|
|
||||||
data["class_scores"].append(pred.class_scores)
|
values["class_scores"].append(pred.class_scores)
|
||||||
data["features"].append(pred.features)
|
values["features"].append(pred.features)
|
||||||
|
|
||||||
num_features = len(pred.features)
|
num_features = len(pred.features)
|
||||||
|
|
||||||
data_vars = {
|
data_vars = {
|
||||||
"score": (["detection"], data["detection_score"]),
|
"score": (["detection"], values["detection_score"]),
|
||||||
"start_time": (["detection"], data["start_time"]),
|
"start_time": (["detection"], values["start_time"]),
|
||||||
"end_time": (["detection"], data["end_time"]),
|
"end_time": (["detection"], values["end_time"]),
|
||||||
"low_freq": (["detection"], data["low_freq"]),
|
"low_freq": (["detection"], values["low_freq"]),
|
||||||
"high_freq": (["detection"], data["high_freq"]),
|
"high_freq": (["detection"], values["high_freq"]),
|
||||||
"top_class": (["detection"], data["top_class"]),
|
"top_class": (["detection"], values["top_class"]),
|
||||||
"top_class_score": (["detection"], data["top_class_score"]),
|
"top_class_score": (["detection"], values["top_class_score"]),
|
||||||
}
|
}
|
||||||
|
|
||||||
coords = {
|
coords = {
|
||||||
"detection": ("detection", data["detection_id"]),
|
"detection": ("detection", values["detection_id"]),
|
||||||
"clip_start": clip.start_time,
|
"clip_start": clip.start_time,
|
||||||
"clip_end": clip.end_time,
|
"clip_end": clip.end_time,
|
||||||
"clip_id": str(clip.uuid),
|
"clip_id": str(clip.uuid),
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.include_class_scores:
|
if self.include_class_scores:
|
||||||
class_scores = np.stack(data["class_scores"], axis=0)
|
class_scores = np.stack(values["class_scores"], axis=0)
|
||||||
data_vars["class_scores"] = (
|
data_vars["class_scores"] = (
|
||||||
["detection", "classes"],
|
["detection", "classes"],
|
||||||
class_scores,
|
class_scores,
|
||||||
@ -152,12 +152,12 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
coords["classes"] = ("classes", self.targets.class_names)
|
coords["classes"] = ("classes", self.targets.class_names)
|
||||||
|
|
||||||
if self.include_features:
|
if self.include_features:
|
||||||
features = np.stack(data["features"], axis=0)
|
features = np.stack(values["features"], axis=0)
|
||||||
data_vars["features"] = (["detection", "feature"], features)
|
data_vars["features"] = (["detection", "feature"], features)
|
||||||
coords["feature"] = ("feature", np.arange(num_features))
|
coords["feature"] = ("feature", np.arange(num_features))
|
||||||
|
|
||||||
if self.include_geometry:
|
if self.include_geometry:
|
||||||
data_vars["geometry"] = (["detection"], data["geometry"])
|
data_vars["geometry"] = (["detection"], values["geometry"])
|
||||||
|
|
||||||
return xr.Dataset(
|
return xr.Dataset(
|
||||||
data_vars=data_vars,
|
data_vars=data_vars,
|
||||||
@ -169,7 +169,6 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
|
|
||||||
def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
|
def pred_from_xr(self, dataset: xr.Dataset) -> ClipDetections:
|
||||||
clip_data = dataset
|
clip_data = dataset
|
||||||
clip_id = clip_data.clip_id.item()
|
|
||||||
|
|
||||||
recording = data.Recording.model_validate_json(
|
recording = data.Recording.model_validate_json(
|
||||||
clip_data.attrs["recording"]
|
clip_data.attrs["recording"]
|
||||||
@ -232,7 +231,7 @@ class RawFormatter(OutputFormatterProtocol[ClipDetections]):
|
|||||||
detections=sound_events,
|
detections=sound_events,
|
||||||
)
|
)
|
||||||
|
|
||||||
@prediction_formatters.register(RawOutputConfig)
|
@output_formatters.register(RawOutputConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: RawOutputConfig, targets: TargetProtocol):
|
def from_config(config: RawOutputConfig, targets: TargetProtocol):
|
||||||
return RawFormatter(
|
return RawFormatter(
|
||||||
@ -5,8 +5,8 @@ import numpy as np
|
|||||||
from soundevent import data, io
|
from soundevent import data, io
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core import BaseConfig
|
||||||
from batdetect2.data.predictions.base import (
|
from batdetect2.outputs.formats.base import (
|
||||||
prediction_formatters,
|
output_formatters,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
ClipDetections,
|
ClipDetections,
|
||||||
@ -121,7 +121,7 @@ class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
|
|||||||
|
|
||||||
return tags
|
return tags
|
||||||
|
|
||||||
@prediction_formatters.register(SoundEventOutputConfig)
|
@output_formatters.register(SoundEventOutputConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
|
def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
|
||||||
return SoundEventOutputFormatter(
|
return SoundEventOutputFormatter(
|
||||||
89
src/batdetect2/outputs/transforms.py
Normal file
89
src/batdetect2/outputs/transforms.py
Normal file
@ -0,0 +1,89 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import replace
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
from soundevent.geometry import shift_geometry
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.typing import ClipDetections, Detection
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OutputTransform",
|
||||||
|
"OutputTransformConfig",
|
||||||
|
"OutputTransformProtocol",
|
||||||
|
"build_output_transform",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class OutputTransformConfig(BaseConfig):
|
||||||
|
shift_time_to_clip_start: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class OutputTransformProtocol(Protocol):
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
predictions: Sequence[ClipDetections],
|
||||||
|
) -> list[ClipDetections]: ...
|
||||||
|
|
||||||
|
def transform_detections(
|
||||||
|
self,
|
||||||
|
detections: Sequence[Detection],
|
||||||
|
start_time: float = 0,
|
||||||
|
) -> list[Detection]: ...
|
||||||
|
|
||||||
|
|
||||||
|
def shift_detection_time(detection: Detection, time: float) -> Detection:
|
||||||
|
geometry = shift_geometry(detection.geometry, time=time)
|
||||||
|
return replace(detection, geometry=geometry)
|
||||||
|
|
||||||
|
|
||||||
|
class OutputTransform(OutputTransformProtocol):
|
||||||
|
def __init__(self, shift_time_to_clip_start: bool = True):
|
||||||
|
self.shift_time_to_clip_start = shift_time_to_clip_start
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
predictions: Sequence[ClipDetections],
|
||||||
|
) -> list[ClipDetections]:
|
||||||
|
return [
|
||||||
|
self.transform_prediction(prediction) for prediction in predictions
|
||||||
|
]
|
||||||
|
|
||||||
|
def transform_prediction(
|
||||||
|
self, prediction: ClipDetections
|
||||||
|
) -> ClipDetections:
|
||||||
|
if not self.shift_time_to_clip_start:
|
||||||
|
return prediction
|
||||||
|
|
||||||
|
detections = self.transform_detections(
|
||||||
|
prediction.detections,
|
||||||
|
start_time=prediction.clip.start_time,
|
||||||
|
)
|
||||||
|
return ClipDetections(clip=prediction.clip, detections=detections)
|
||||||
|
|
||||||
|
def transform_detections(
|
||||||
|
self,
|
||||||
|
detections: Sequence[Detection],
|
||||||
|
start_time: float = 0,
|
||||||
|
) -> list[Detection]:
|
||||||
|
if not self.shift_time_to_clip_start or start_time == 0:
|
||||||
|
return list(detections)
|
||||||
|
|
||||||
|
return [
|
||||||
|
shift_detection_time(detection, time=start_time)
|
||||||
|
for detection in detections
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_output_transform(
|
||||||
|
config: OutputTransformConfig | dict | None = None,
|
||||||
|
) -> OutputTransformProtocol:
|
||||||
|
if config is None:
|
||||||
|
config = OutputTransformConfig()
|
||||||
|
|
||||||
|
if not isinstance(config, OutputTransformConfig):
|
||||||
|
config = OutputTransformConfig.model_validate(config)
|
||||||
|
|
||||||
|
return OutputTransform(
|
||||||
|
shift_time_to_clip_start=config.shift_time_to_clip_start,
|
||||||
|
)
|
||||||
@ -64,7 +64,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
start_times: list[float] | None = None,
|
|
||||||
) -> list[ClipDetectionsTensor]:
|
) -> list[ClipDetectionsTensor]:
|
||||||
detection_heatmap = non_max_suppression(
|
detection_heatmap = non_max_suppression(
|
||||||
output.detection_probs.detach(),
|
output.detection_probs.detach(),
|
||||||
@ -83,9 +82,6 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
|
|||||||
threshold=self.detection_threshold,
|
threshold=self.detection_threshold,
|
||||||
)
|
)
|
||||||
|
|
||||||
if start_times is None:
|
|
||||||
start_times = [0 for _ in range(len(detections))]
|
|
||||||
|
|
||||||
return [
|
return [
|
||||||
map_detection_to_clip(
|
map_detection_to_clip(
|
||||||
detection,
|
detection,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from soundevent import data
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
@ -18,10 +19,15 @@ from batdetect2.typing import (
|
|||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
def __init__(self, evaluator: EvaluatorProtocol):
|
def __init__(
|
||||||
|
self,
|
||||||
|
evaluator: EvaluatorProtocol,
|
||||||
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
self.output_transform = output_transform or build_output_transform()
|
||||||
|
|
||||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self._predictions: List[ClipDetections] = []
|
self._predictions: List[ClipDetections] = []
|
||||||
@ -95,10 +101,7 @@ class ValidationMetrics(Callback):
|
|||||||
for example_idx in batch.idx
|
for example_idx in batch.idx
|
||||||
]
|
]
|
||||||
|
|
||||||
clip_detections = model.postprocessor(
|
clip_detections = model.postprocessor(outputs)
|
||||||
outputs,
|
|
||||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
|
||||||
)
|
|
||||||
predictions = [
|
predictions = [
|
||||||
ClipDetections(
|
ClipDetections(
|
||||||
clip=clip_annotation.clip,
|
clip=clip_annotation.clip,
|
||||||
@ -110,6 +113,7 @@ class ValidationMetrics(Callback):
|
|||||||
clip_annotations, clip_detections, strict=False
|
clip_annotations, clip_detections, strict=False
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
|
predictions = self.output_transform(predictions)
|
||||||
|
|
||||||
self._clip_annotations.extend(clip_annotations)
|
self._clip_annotations.extend(clip_annotations)
|
||||||
self._predictions.extend(predictions)
|
self._predictions.extend(predictions)
|
||||||
|
|||||||
@ -12,7 +12,7 @@ system that deal with model predictions.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import List, NamedTuple, Protocol, Sequence
|
from typing import List, NamedTuple, Protocol
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -101,5 +101,4 @@ class PostprocessorProtocol(Protocol):
|
|||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
output: ModelOutput,
|
output: ModelOutput,
|
||||||
start_times: Sequence[float] | None = None,
|
|
||||||
) -> List[ClipDetectionsTensor]: ...
|
) -> List[ClipDetectionsTensor]: ...
|
||||||
|
|||||||
@ -5,7 +5,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data.predictions import (
|
from batdetect2.outputs.formats import (
|
||||||
ParquetOutputConfig,
|
ParquetOutputConfig,
|
||||||
build_output_formatter,
|
build_output_formatter,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data.predictions import RawOutputConfig, build_output_formatter
|
from batdetect2.outputs.formats import RawOutputConfig, build_output_formatter
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
ClipDetections,
|
ClipDetections,
|
||||||
Detection,
|
Detection,
|
||||||
|
|||||||
48
tests/test_outputs/test_transform/test_transform.py
Normal file
48
tests/test_outputs/test_transform/test_transform.py
Normal file
@ -0,0 +1,48 @@
|
|||||||
|
import numpy as np
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.outputs import build_output_transform
|
||||||
|
from batdetect2.typing import ClipDetections, Detection
|
||||||
|
|
||||||
|
|
||||||
|
def test_shift_time_to_clip_start(clip: data.Clip):
|
||||||
|
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||||
|
|
||||||
|
detection = Detection(
|
||||||
|
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
|
||||||
|
detection_score=0.9,
|
||||||
|
class_scores=np.array([0.9]),
|
||||||
|
features=np.array([1.0, 2.0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
transformed = build_output_transform()(
|
||||||
|
[ClipDetections(clip=clip, detections=[detection])]
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
start_time, _, end_time, _ = compute_bounds(
|
||||||
|
transformed.detections[0].geometry
|
||||||
|
)
|
||||||
|
|
||||||
|
assert np.isclose(start_time, 2.6)
|
||||||
|
assert np.isclose(end_time, 2.7)
|
||||||
|
|
||||||
|
|
||||||
|
def test_transform_identity_when_disabled(clip: data.Clip):
|
||||||
|
clip = clip.model_copy(update={"start_time": 2.5, "end_time": 3.0})
|
||||||
|
|
||||||
|
detection = Detection(
|
||||||
|
geometry=data.BoundingBox(coordinates=[0.1, 10_000, 0.2, 12_000]),
|
||||||
|
detection_score=0.9,
|
||||||
|
class_scores=np.array([0.9]),
|
||||||
|
features=np.array([1.0, 2.0]),
|
||||||
|
)
|
||||||
|
|
||||||
|
transform = build_output_transform(
|
||||||
|
config={"shift_time_to_clip_start": False}
|
||||||
|
)
|
||||||
|
transformed = transform(
|
||||||
|
[ClipDetections(clip=clip, detections=[detection])]
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
assert transformed.detections[0].geometry == detection.geometry
|
||||||
Loading…
Reference in New Issue
Block a user