Add other output formats

mbsantiago 2025-10-27 19:39:20 +00:00
parent 6d0a73dda6
commit 401a3832ce
35 changed files with 1103 additions and 106 deletions

View File

@@ -1,5 +1,6 @@
import json
from pathlib import Path
from typing import List, Optional, Sequence
from typing import Dict, List, Optional, Sequence, Tuple
import numpy as np
import torch
@@ -9,6 +10,14 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.data import (
OutputFormatConfig,
build_output_formatter,
get_output_formatter,
load_dataset_from_config,
)
from batdetect2.data.datasets import Dataset
from batdetect2.data.predictions.base import OutputFormatterProtocol
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
@@ -41,6 +50,7 @@ class BatDetect2API:
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
evaluator: EvaluatorProtocol,
formatter: OutputFormatterProtocol,
model: Model,
):
self.config = config
@@ -50,9 +60,17 @@ class BatDetect2API:
self.postprocessor = postprocessor
self.evaluator = evaluator
self.model = model
self.formatter = formatter
self.model.eval()
def load_annotations(
self,
path: data.PathLike,
base_dir: Optional[data.PathLike] = None,
) -> Dataset:
return load_dataset_from_config(path, base_dir=base_dir)
def train(
self,
train_annotations: Sequence[data.ClipAnnotation],
@@ -91,7 +109,8 @@ class BatDetect2API:
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
save_predictions: bool = True,
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
return evaluate(
self.model,
test_annotations,
@@ -103,8 +122,41 @@ class BatDetect2API:
output_dir=output_dir,
experiment_name=experiment_name,
run_name=run_name,
formatter=self.formatter if save_predictions else None,
)
def evaluate_predictions(
self,
annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[BatDetect2Prediction],
output_dir: Optional[data.PathLike] = None,
):
clip_evals = self.evaluator.evaluate(
annotations,
predictions,
)
metrics = self.evaluator.compute_metrics(clip_evals)
if output_dir is not None:
output_dir = Path(output_dir)
if not output_dir.is_dir():
output_dir.mkdir(parents=True)
metrics_path = output_dir / "metrics.json"
metrics_path.write_text(json.dumps(metrics))
for figure_name, fig in self.evaluator.generate_plots(clip_evals):
fig_path = output_dir / figure_name
if not fig_path.parent.is_dir():
fig_path.parent.mkdir(parents=True)
fig.savefig(fig_path)
return metrics
def load_audio(self, path: data.PathLike) -> np.ndarray:
return self.audio_loader.load_file(path)
@@ -194,8 +246,38 @@ class BatDetect2API:
config=self.config,
)
def save_predictions(
self,
predictions: Sequence[BatDetect2Prediction],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
format: Optional[str] = None,
config: Optional[OutputFormatConfig] = None,
):
formatter = self.formatter
if format is not None or config is not None:
format = format or config.name # type: ignore
formatter = get_output_formatter(
name=format,
targets=self.targets,
config=config,
)
outs = formatter.format(predictions)
formatter.save(outs, audio_dir=audio_dir, path=path)
def load_predictions(
self,
path: data.PathLike,
) -> List[BatDetect2Prediction]:
return self.formatter.load(path)
@classmethod
def from_config(cls, config: BatDetect2Config):
def from_config(
cls,
config: BatDetect2Config,
):
targets = build_targets(config=config.targets)
audio_loader = build_audio_loader(config=config.audio)
@@ -228,6 +310,8 @@ class BatDetect2API:
),
)
formatter = build_output_formatter(targets, config=config.output)
return cls(
config=config,
targets=targets,
@@ -236,6 +320,7 @@ class BatDetect2API:
postprocessor=postprocessor,
evaluator=evaluator,
model=model,
formatter=formatter,
)
@classmethod
@@ -266,6 +351,8 @@ class BatDetect2API:
evaluator = build_evaluator(config=config.evaluation, targets=targets)
formatter = build_output_formatter(targets, config=config.output)
return cls(
config=config,
targets=targets,
@@ -274,4 +361,5 @@ class BatDetect2API:
postprocessor=postprocessor,
evaluator=evaluator,
model=model,
formatter=formatter,
)
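
Taken together, these API changes give predictions a save/load round trip. A minimal usage sketch (paths illustrative; only methods added above are used):

from batdetect2.api import BatDetect2API
from batdetect2.config import BatDetect2Config

# The default config uses the "raw" NetCDF-backed output format
# (config.output defaults to RawOutputConfig).
api = BatDetect2API.from_config(BatDetect2Config())

# Load predictions previously saved with the default formatter...
predictions = api.load_predictions("outputs/predictions.nc")

# ...and re-save them in the legacy BatDetect2 JSON format, with
# recording paths rewritten relative to the audio directory.
api.save_predictions(
    predictions,
    path="outputs/predictions_json",
    format="batdetect2",
    audio_dir="data/audio",
)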

View File

@@ -1,9 +1,8 @@
"""BatDetect2 command line interface."""
import sys
import click
from loguru import logger
from batdetect2.logging import enable_logging
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
@@ -27,22 +26,9 @@ BatDetect2 - Detection and Classification
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def cli(
verbose: int = 0,
):
def cli(verbose: int = 0):
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.enable("batdetect2")
enable_logging(verbose)
# click.echo(BATDETECT_ASCII_ART)

View File

@@ -43,3 +43,55 @@ def summary(
)
print(f"Number of annotated clips: {len(dataset)}")
@data.command()
@click.argument(
"dataset_config",
type=click.Path(exists=True),
)
@click.option(
"--field",
type=str,
help="If the dataset info is in a nested field please specify here.",
)
@click.option(
"--output",
type=click.Path(exists=False),
default="annotations.json",
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help="The base directory to which all recording and annotations paths are relative to.",
)
def convert(
dataset_config: Path,
field: Optional[str] = None,
output: Path = Path("annotations.json"),
base_dir: Optional[Path] = None,
):
"""Convert a dataset config file to soundevent format."""
from soundevent import data, io
from batdetect2.data import load_dataset, load_dataset_config
base_dir = base_dir or Path.cwd()
config = load_dataset_config(
dataset_config,
field=field,
)
dataset = load_dataset(
config,
base_dir=base_dir,
)
annotation_set = data.AnnotationSet(
clip_annotations=list(dataset),
name=config.name,
description=config.description,
)
io.save(annotation_set, output)
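
Assuming the `data` command group is registered on the top-level CLI, an invocation would look something like `batdetect2 data convert dataset.yaml --output annotations.json`; `--base-dir` defaults to the current working directory.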

View File

@@ -4,9 +4,13 @@ from pydantic import Field
from soundevent.data import PathLike
from batdetect2.audio import AudioConfig
from batdetect2.core import BaseConfig
from batdetect2.core.configs import load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.data.predictions import OutputFormatConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.evaluate.config import (
EvaluationConfig,
get_default_eval_config,
)
from batdetect2.inference.config import InferenceConfig
from batdetect2.models.config import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
@@ -17,6 +21,7 @@ from batdetect2.train.config import TrainingConfig
__all__ = [
"BatDetect2Config",
"load_full_config",
"validate_config",
]
@@ -24,7 +29,9 @@ class BatDetect2Config(BaseConfig):
config_version: Literal["v1"] = "v1"
train: TrainingConfig = Field(default_factory=TrainingConfig)
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
evaluation: EvaluationConfig = Field(
default_factory=get_default_eval_config
)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
@@ -33,6 +40,14 @@ class BatDetect2Config(BaseConfig):
audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
def validate_config(config: Optional[dict]) -> BatDetect2Config:
if config is None:
return BatDetect2Config()
return BatDetect2Config.model_validate(config)
def load_full_config(
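
A quick sketch of `validate_config` (field names as defined above; the `output` union discriminates on `name`):

from batdetect2.config import validate_config

config = validate_config(None)  # falls back to a default BatDetect2Config
assert config.output.name == "raw"  # output defaults to RawOutputConfig

config = validate_config({"output": {"name": "soundevent", "top_k": 3}})
assert config.output.name == "soundevent"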

View File

@@ -29,7 +29,7 @@ class BaseConfig(BaseModel):
and serialization capabilities.
"""
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="ignore")
def to_yaml_string(
self,

View File

@@ -77,6 +77,15 @@ class Registry(Generic[T_Type, P_Type]):
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
return tuple(self._config_types.values())
def get_config_type(self, name: str) -> Type[BaseModel]:
try:
return self._config_types[name]
except KeyError as err:
raise ValueError(
f"No config type with name '{name}' is registered. "
f"Existing config types: {list(self._config_types.keys())}"
) from err
def build(
self,
config: BaseModel,
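
A sketch of the new `get_config_type` lookup against the prediction-formatter registry introduced later in this commit (assuming formatters are keyed by their config's `name` literal):

from batdetect2.data.predictions.base import prediction_formatters

config_cls = prediction_formatters.get_config_type("raw")
config = config_cls()  # RawOutputConfig with default fields

prediction_formatters.get_config_type("nope")
# ValueError: No config type with name 'nope' is registered. ...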

View File

@@ -1,5 +1,6 @@
from batdetect2.data.annotations import (
AnnotatedDataset,
AnnotationFormats,
AOEFAnnotations,
BatDetect2FilesAnnotations,
BatDetect2MergedAnnotations,
@@ -11,6 +12,14 @@ from batdetect2.data.datasets import (
load_dataset_config,
load_dataset_from_config,
)
from batdetect2.data.predictions import (
BatDetect2OutputConfig,
OutputFormatConfig,
RawOutputConfig,
SoundEventOutputConfig,
build_output_formatter,
get_output_formatter,
)
from batdetect2.data.summary import (
compute_class_summary,
extract_recordings_df,
@@ -20,12 +29,19 @@ from batdetect2.data.summary import (
__all__ = [
"AOEFAnnotations",
"AnnotatedDataset",
"AnnotationFormats",
"BatDetect2FilesAnnotations",
"BatDetect2MergedAnnotations",
"BatDetect2OutputConfig",
"DatasetConfig",
"OutputFormatConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
"build_output_formatter",
"compute_class_summary",
"extract_recordings_df",
"extract_sound_events_df",
"get_output_formatter",
"load_annotated_dataset",
"load_dataset",
"load_dataset_config",

View File

@@ -13,7 +13,6 @@ format-specific loading function to retrieve the annotations as a standard
`soundevent.data.AnnotationSet`.
"""
from pathlib import Path
from typing import Annotated, Optional, Union
from pydantic import Field
@@ -64,7 +63,7 @@ source configuration represents.
def load_annotated_dataset(
dataset: AnnotatedDataset,
base_dir: Optional[Path] = None,
base_dir: Optional[data.PathLike] = None,
) -> data.AnnotationSet:
"""Load annotations for a single data source based on its configuration.
@@ -97,6 +96,7 @@ def load_annotated_dataset(
known format-specific loading functions implemented in the dispatch
logic.
"""
if isinstance(dataset, AOEFAnnotations):
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)

View File

@@ -84,7 +84,7 @@ class AOEFAnnotations(AnnotatedDataset):
def load_aoef_annotated_dataset(
dataset: AOEFAnnotations,
base_dir: Optional[Path] = None,
base_dir: Optional[data.PathLike] = None,
) -> data.AnnotationSet:
"""Load annotations from an AnnotationSet or AnnotationProject file.

View File

@@ -19,7 +19,7 @@ The core components are:
"""
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Sequence
from loguru import logger
from pydantic import Field
@@ -53,7 +53,7 @@ __all__ = [
]
Dataset = List[data.ClipAnnotation]
Dataset = Sequence[data.ClipAnnotation]
"""Type alias for a loaded dataset representation.
Represents an entire dataset *after loading* as a flat Python list containing
@@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
def load_dataset(
config: DatasetConfig,
base_dir: Optional[Path] = None,
base_dir: Optional[data.PathLike] = None,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig."""
clip_annotations = []
@@ -168,7 +168,7 @@ def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
def load_dataset_from_config(
path: data.PathLike,
field: Optional[str] = None,
base_dir: Optional[Path] = None,
base_dir: Optional[data.PathLike] = None,
) -> Dataset:
"""Load dataset annotation metadata from a configuration file.
@@ -250,6 +250,6 @@ def save_dataset(
annotation_set = data.AnnotationSet(
name=name,
description=description,
clip_annotations=dataset,
clip_annotations=list(dataset),
)
io.save(annotation_set, path, audio_dir=audio_dir)

View File

@@ -0,0 +1,58 @@
from typing import Annotated, Optional, Union
from pydantic import Field
from batdetect2.data.predictions.base import (
OutputFormatterProtocol,
prediction_formatters,
)
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
from batdetect2.data.predictions.raw import RawOutputConfig
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
from batdetect2.typing import TargetProtocol
__all__ = [
"build_output_formatter",
"get_output_formatter",
"BatDetect2OutputConfig",
"RawOutputConfig",
"SoundEventOutputConfig",
]
OutputFormatConfig = Annotated[
Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig],
Field(discriminator="name"),
]
def build_output_formatter(
targets: Optional[TargetProtocol] = None,
config: Optional[OutputFormatConfig] = None,
) -> OutputFormatterProtocol:
"""Construct the final output formatter."""
from batdetect2.targets import build_targets
config = config or RawOutputConfig()
targets = targets or build_targets()
return prediction_formatters.build(config, targets)
def get_output_formatter(
name: str,
targets: Optional[TargetProtocol] = None,
config: Optional[OutputFormatConfig] = None,
) -> OutputFormatterProtocol:
"""Get the output formatter by name."""
if config is None:
config_class = prediction_formatters.get_config_type(name)
config = config_class() # type: ignore
if config.name != name: # type: ignore
raise ValueError(
f"Config name {config.name} does not match formatter name {name}" # type: ignore
)
return build_output_formatter(targets, config)
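
Note the guard in `get_output_formatter`: an explicit `config` must agree with the requested `name`. For example:

from batdetect2.data.predictions import RawOutputConfig, get_output_formatter

formatter = get_output_formatter("raw")  # builds a default RawOutputConfig

get_output_formatter("batdetect2", config=RawOutputConfig())
# ValueError: Config name raw does not match formatter name batdetect2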

View File

@@ -0,0 +1,29 @@
from pathlib import Path
from soundevent.data import PathLike
from batdetect2.core import Registry
from batdetect2.typing import (
OutputFormatterProtocol,
TargetProtocol,
)
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
path = Path(path)
audio_dir = Path(audio_dir)
if path.is_absolute():
if not path.is_relative_to(audio_dir):
raise ValueError(
f"Audio file {path} is not in audio_dir {audio_dir}"
)
return path.relative_to(audio_dir)
return path
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
Registry(name="output_formatter")
)
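
The behavior of `make_path_relative`, for reference:

from batdetect2.data.predictions.base import make_path_relative

make_path_relative("/data/audio/night1/rec.wav", "/data/audio")
# Path("night1/rec.wav")

make_path_relative("rec.wav", "/data/audio")
# Path("rec.wav"), already relative so returned unchanged

make_path_relative("/elsewhere/rec.wav", "/data/audio")
# ValueError: Audio file /elsewhere/rec.wav is not in audio_dir /data/audio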

View File

@@ -0,0 +1,227 @@
import json
from pathlib import Path
from typing import List, Literal, Optional, Sequence, TypedDict
import numpy as np
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import (
make_path_relative,
prediction_formatters,
)
from batdetect2.targets import terms
from batdetect2.typing import (
BatDetect2Prediction,
OutputFormatterProtocol,
RawPrediction,
TargetProtocol,
)
try:
from typing import NotRequired # type: ignore
except ImportError:
from typing_extensions import NotRequired
DictWithClass = TypedDict("DictWithClass", {"class": str})
class Annotation(DictWithClass):
"""Format of annotations.
This is the format of a single annotation as expected by the
annotation tool.
"""
start_time: float
"""Start time in seconds."""
end_time: float
"""End time in seconds."""
low_freq: float
"""Low frequency in Hz."""
high_freq: float
"""High frequency in Hz."""
class_prob: float
"""Probability of class assignment."""
det_prob: float
"""Probability of detection."""
individual: str
"""Individual ID."""
event: str
"""Type of detected event."""
class FileAnnotation(TypedDict):
"""Format of results.
This is the format of the results expected by the annotation tool.
"""
id: str
"""File ID."""
annotated: bool
"""Whether file has been annotated."""
duration: float
"""Duration of audio file."""
issues: bool
"""Whether file has issues."""
time_exp: float
"""Time expansion factor."""
class_name: str
"""Class predicted at file level."""
notes: str
"""Notes of file."""
annotation: List[Annotation]
"""List of annotations."""
file_path: NotRequired[str]
"""Path to file."""
class BatDetect2OutputConfig(BaseConfig):
name: Literal["batdetect2"] = "batdetect2"
event_name: str = "Echolocation"
annotation_note: str = "Automatically generated."
class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
def __init__(
self,
targets: TargetProtocol,
event_name: str,
annotation_note: str,
):
self.targets = targets
self.event_name = event_name
self.annotation_note = annotation_note
def format(
self, predictions: Sequence[BatDetect2Prediction]
) -> List[FileAnnotation]:
return [
self.format_prediction(prediction) for prediction in predictions
]
def save(
self,
predictions: Sequence[FileAnnotation],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> None:
path = Path(path)
if not path.is_dir():
path.mkdir(parents=True)
for prediction in predictions:
pred_path = path / (prediction["id"] + ".json")
if audio_dir is not None and "file_path" in prediction:
prediction["file_path"] = str(
make_path_relative(
prediction["file_path"],
audio_dir,
)
)
pred_path.write_text(json.dumps(prediction))
def load(self, path: data.PathLike) -> List[FileAnnotation]:
path = Path(path)
files = list(path.glob("*.json"))
if not files:
return []
return [
json.loads(file.read_text()) for file in files if file.is_file()
]
def get_recording_class(self, annotations: List[Annotation]) -> str:
"""Get class of recording from annotations."""
if not annotations:
return ""
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
return highest_scoring["class"]
def format_prediction(
self, prediction: BatDetect2Prediction
) -> FileAnnotation:
recording = prediction.clip.recording
annotations = [
self.format_sound_event_prediction(pred)
for pred in prediction.predictions
]
return FileAnnotation(
id=recording.path.name,
file_path=str(recording.path),
annotated=False,
duration=recording.duration,
issues=False,
time_exp=recording.time_expansion,
class_name=self.get_recording_class(annotations),
notes=self.annotation_note,
annotation=annotations,
)
def get_class_name(self, class_index: int) -> str:
class_name = self.targets.class_names[class_index]
tags = self.targets.decode_class(class_name)
return data.find_tag_value(
tags,
term=terms.generic_class,
default=class_name,
) # type: ignore
def format_sound_event_prediction(
self, prediction: RawPrediction
) -> Annotation:
start_time, low_freq, end_time, high_freq = compute_bounds(
prediction.geometry
)
top_class_index = int(np.argmax(prediction.class_scores))
top_class_score = float(prediction.class_scores[top_class_index])
top_class = self.get_class_name(top_class_index)
return Annotation(
start_time=start_time,
end_time=end_time,
low_freq=low_freq,
high_freq=high_freq,
class_prob=top_class_score,
det_prob=float(prediction.detection_score),
individual="",
event=self.event_name,
**{"class": top_class},
)
@prediction_formatters.register(BatDetect2OutputConfig)
@staticmethod
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
return BatDetect2Formatter(
targets,
event_name=config.event_name,
annotation_note=config.annotation_note,
)
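
For orientation, the JSON payload this formatter writes per file is a `FileAnnotation`, roughly like the following (all values illustrative):

file_annotation: FileAnnotation = {
    "id": "rec.wav",
    "file_path": "night1/rec.wav",
    "annotated": False,
    "duration": 5.0,
    "issues": False,
    "time_exp": 1.0,
    "class_name": "Myotis myotis",
    "notes": "Automatically generated.",
    "annotation": [
        {
            "class": "Myotis myotis",
            "start_time": 1.2,
            "end_time": 1.25,
            "low_freq": 30000.0,
            "high_freq": 80000.0,
            "class_prob": 0.87,
            "det_prob": 0.91,
            "individual": "",
            "event": "Echolocation",
        }
    ],
}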

View File

@@ -0,0 +1,246 @@
from collections import defaultdict
from pathlib import Path
from typing import List, Literal, Optional, Sequence
from uuid import UUID, uuid4
import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import (
make_path_relative,
prediction_formatters,
)
from batdetect2.typing import (
BatDetect2Prediction,
OutputFormatterProtocol,
RawPrediction,
TargetProtocol,
)
class RawOutputConfig(BaseConfig):
name: Literal["raw"] = "raw"
include_class_scores: bool = True
include_features: bool = True
include_geometry: bool = True
class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
def __init__(
self,
targets: TargetProtocol,
include_class_scores: bool = True,
include_features: bool = True,
include_geometry: bool = True,
):
self.targets = targets
self.include_class_scores = include_class_scores
self.include_features = include_features
self.include_geometry = include_geometry
def format(
self,
predictions: Sequence[BatDetect2Prediction],
) -> List[BatDetect2Prediction]:
return list(predictions)
def save(
self,
predictions: Sequence[BatDetect2Prediction],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> None:
num_features = 0
tree = xr.DataTree()
for prediction in predictions:
clip = prediction.clip
recording = clip.recording
if audio_dir is not None:
recording = recording.model_copy(
update=dict(
path=make_path_relative(recording.path, audio_dir)
)
)
clip_data = defaultdict(list)
for pred in prediction.predictions:
detection_id = str(uuid4())
clip_data["detection_id"].append(detection_id)
clip_data["detection_score"].append(pred.detection_score)
start_time, low_freq, end_time, high_freq = compute_bounds(
pred.geometry
)
clip_data["start_time"].append(start_time)
clip_data["end_time"].append(end_time)
clip_data["low_freq"].append(low_freq)
clip_data["high_freq"].append(high_freq)
clip_data["geometry"].append(pred.geometry.model_dump_json())
top_class_index = int(np.argmax(pred.class_scores))
top_class_score = float(pred.class_scores[top_class_index])
top_class = self.targets.class_names[top_class_index]
clip_data["top_class"].append(top_class)
clip_data["top_class_score"].append(top_class_score)
clip_data["class_scores"].append(pred.class_scores)
clip_data["features"].append(pred.features)
num_features = len(pred.features)
data_vars = {
"score": (["detection"], clip_data["detection_score"]),
"start_time": (["detection"], clip_data["start_time"]),
"end_time": (["detection"], clip_data["end_time"]),
"low_freq": (["detection"], clip_data["low_freq"]),
"high_freq": (["detection"], clip_data["high_freq"]),
"top_class": (["detection"], clip_data["top_class"]),
"top_class_score": (
["detection"],
clip_data["top_class_score"],
),
}
coords = {
"detection": ("detection", clip_data["detection_id"]),
"clip_start": clip.start_time,
"clip_end": clip.end_time,
"clip_id": str(clip.uuid),
}
if self.include_class_scores:
data_vars["class_scores"] = (
["detection", "classes"],
clip_data["class_scores"],
)
coords["classes"] = ("classes", self.targets.class_names)
if self.include_features:
data_vars["features"] = (
["detection", "feature"],
clip_data["features"],
)
coords["feature"] = ("feature", np.arange(num_features))
if self.include_geometry:
data_vars["geometry"] = (["detection"], clip_data["geometry"])
dataset = xr.Dataset(
data_vars=data_vars,
coords=coords,
attrs={
"recording": recording.model_dump_json(exclude_none=True),
},
)
tree = tree.assign(
{
str(clip.uuid): xr.DataTree(
dataset=dataset,
name=str(clip.uuid),
)
}
)
path = Path(path)
if not path.suffix == ".nc":
path = Path(path).with_suffix(".nc")
tree.to_netcdf(path)
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
path = Path(path)
root = xr.load_datatree(path)
predictions: List[BatDetect2Prediction] = []
for _, clip_data in root.items():
recording = data.Recording.model_validate_json(
clip_data.attrs["recording"]
)
clip_id = clip_data.clip_id.item()
clip = data.Clip(
recording=recording,
uuid=UUID(clip_id),
start_time=clip_data.clip_start,
end_time=clip_data.clip_end,
)
sound_events = []
for detection in clip_data.detection:
score = clip_data.score.sel(detection=detection).item()
if "geometry" in clip_data:
geometry = data.geometry_validate(
clip_data.geometry.sel(detection=detection).item()
)
else:
start_time = clip_data.start_time.sel(detection=detection)
end_time = clip_data.end_time.sel(detection=detection)
low_freq = clip_data.low_freq.sel(detection=detection)
high_freq = clip_data.high_freq.sel(detection=detection)
geometry = data.BoundingBox(
coordinates=[start_time, low_freq, end_time, high_freq]
)
if "class_scores" in clip_data:
class_scores = clip_data.class_scores.sel(
detection=detection
).data
else:
class_scores = np.zeros(len(self.targets.class_names))
class_index = self.targets.class_names.index(
clip_data.top_class.sel(detection=detection).item()
)
class_scores[class_index] = clip_data.top_class_score.sel(
detection=detection
).item()
if "features" in clip_data:
features = clip_data.features.sel(detection=detection).data
else:
features = np.zeros(0)
sound_events.append(
RawPrediction(
geometry=geometry,
detection_score=score,
class_scores=class_scores,
features=features,
)
)
predictions.append(
BatDetect2Prediction(
clip=clip,
predictions=sound_events,
)
)
return predictions
@prediction_formatters.register(RawOutputConfig)
@staticmethod
def from_config(config: RawOutputConfig, targets: TargetProtocol):
return RawFormatter(
targets,
include_class_scores=config.include_class_scores,
include_features=config.include_features,
include_geometry=config.include_geometry,
)
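
Because the raw format stores one xarray group per clip, a saved file can be inspected directly with the same datatree API used in `load` above (path illustrative):

import xarray as xr

tree = xr.load_datatree("predictions.nc")
for clip_id, clip_data in tree.items():
    print(clip_id, clip_data["score"].values, clip_data["top_class"].values)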

View File

@@ -0,0 +1,131 @@
from pathlib import Path
from typing import List, Literal, Optional, Sequence
import numpy as np
from soundevent import data, io
from batdetect2.core import BaseConfig
from batdetect2.data.predictions.base import (
prediction_formatters,
)
from batdetect2.typing import (
BatDetect2Prediction,
OutputFormatterProtocol,
RawPrediction,
TargetProtocol,
)
class SoundEventOutputConfig(BaseConfig):
name: Literal["soundevent"] = "soundevent"
top_k: Optional[int] = 1
min_score: Optional[float] = None
class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
def __init__(
self,
targets: TargetProtocol,
top_k: Optional[int] = 1,
min_score: Optional[float] = 0,
):
self.targets = targets
self.top_k = top_k
self.min_score = min_score
def format(
self,
predictions: Sequence[BatDetect2Prediction],
) -> List[data.ClipPrediction]:
return [
self.format_prediction(prediction) for prediction in predictions
]
def save(
self,
predictions: Sequence[data.ClipPrediction],
path: data.PathLike,
audio_dir: Optional[data.PathLike] = None,
) -> None:
run = data.PredictionSet(clip_predictions=list(predictions))
path = Path(path)
if not path.suffix == ".json":
path = Path(path).with_suffix(".json")
io.save(run, path, audio_dir=audio_dir)
def load(self, path: data.PathLike) -> List[data.ClipPrediction]:
path = Path(path)
run = io.load(path, type="prediction_set")
return run.clip_predictions
def format_prediction(
self,
prediction: BatDetect2Prediction,
) -> data.ClipPrediction:
recording = prediction.clip.recording
return data.ClipPrediction(
clip=prediction.clip,
sound_events=[
self.format_sound_event_prediction(pred, recording)
for pred in prediction.predictions
],
)
def format_sound_event_prediction(
self,
prediction: RawPrediction,
recording: data.Recording,
) -> data.SoundEventPrediction:
return data.SoundEventPrediction(
sound_event=data.SoundEvent(
recording=recording,
geometry=prediction.geometry,
),
score=prediction.detection_score,
tags=self.get_sound_event_tags(prediction),
)
def get_sound_event_tags(
self, prediction: RawPrediction
) -> List[data.PredictedTag]:
sorted_indices = np.argsort(prediction.class_scores)[::-1]
tags = [
data.PredictedTag(
tag=tag,
score=prediction.detection_score,
)
for tag in self.targets.detection_class_tags
]
top_k = self.top_k or len(sorted_indices)
for ind in sorted_indices[:top_k]:
score = float(prediction.class_scores[ind])
if self.min_score is not None and score < self.min_score:
break
class_name = self.targets.class_names[ind]
class_tags = self.targets.decode_class(class_name)
tags.extend(
data.PredictedTag(
tag=tag,
score=score,
)
for tag in class_tags
)
return tags
@prediction_formatters.register(SoundEventOutputConfig)
@staticmethod
def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
return SoundEventOutputFormatter(
targets,
top_k=config.top_k,
min_score=config.min_score,
)

View File

@@ -159,6 +159,7 @@ def compute_class_summary(
exclude_generic=False,
exclude_non_target=True,
)
recordings = extract_recordings_df(dataset)
num_calls = (

View File

@@ -27,6 +27,28 @@ class EvaluationConfig(BaseConfig):
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def get_default_eval_config() -> EvaluationConfig:
return EvaluationConfig.model_validate(
{
"tasks": [
{
"name": "sound_event_detection",
"plots": [
{"name": "pr_curve"},
{"name": "score_distribution"},
],
},
{
"name": "sound_event_classification",
"plots": [
{"name": "pr_curve"},
],
},
]
}
)
def load_evaluation_config(
path: data.PathLike,
field: Optional[str] = None,

View File

@@ -1,5 +1,5 @@
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Sequence
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
from lightning import Trainer
from soundevent import data
@@ -12,11 +12,13 @@ from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.postprocess import RawPrediction
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
OutputFormatterProtocol,
PreprocessorProtocol,
TargetProtocol,
)
@@ -31,11 +33,12 @@ def evaluate(
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
formatter: Optional["OutputFormatterProtocol"] = None,
num_workers: Optional[int] = None,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
@@ -66,4 +69,12 @@ def evaluate(
)
module = EvaluationModule(model, evaluator)
trainer = Trainer(logger=logger, enable_checkpointing=False)
return trainer.test(module, loader)
metrics = trainer.test(module, loader)
if formatter is not None and logger.log_dir is not None:
formatter.save(
module.predictions,
path=Path(logger.log_dir) / "predictions",
)
return metrics, module.predictions # type: ignore
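
With the widened return type, callers unpack both metrics and per-clip predictions; passing a `formatter` additionally writes the predictions under the logger's directory. A sketch:

metrics, predictions = evaluate(
    model,
    test_annotations,
    formatter=formatter,  # optional; also saves <log_dir>/predictions
)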

View File

@@ -6,7 +6,8 @@ from soundevent import data
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.tasks import build_task
from batdetect2.targets import build_targets
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
__all__ = [
"Evaluator",
@@ -26,7 +27,7 @@ class Evaluator:
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
predictions: Sequence[BatDetect2Prediction],
) -> List[Any]:
return [
task.evaluate(clip_annotations, predictions) for task in self.tasks

View File

@@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import EvaluatorProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.postprocess import BatDetect2Prediction
class EvaluationModule(LightningModule):
@@ -24,7 +24,7 @@ class EvaluationModule(LightningModule):
self.evaluator = evaluator
self.clip_annotations: List[data.ClipAnnotation] = []
self.predictions: List[List[RawPrediction]] = []
self.predictions: List[BatDetect2Prediction] = []
def test_step(self, batch: TestExample, batch_idx: int):
dataset = self.get_dataset()
@@ -39,11 +39,16 @@ class EvaluationModule(LightningModule):
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
BatDetect2Prediction(
clip=clip_annotation.clip,
predictions=to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
),
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections
)
for clip_dets in clip_detections
]
self.clip_annotations.extend(clip_annotations)

View File

@@ -23,7 +23,7 @@ from batdetect2.evaluate.match import (
build_matcher,
)
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
@@ -99,7 +99,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
predictions: Sequence[BatDetect2Prediction],
) -> List[T_Output]:
return [
self.evaluate_clip(clip_annotation, preds)
@@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
prediction: BatDetect2Prediction,
) -> T_Output: ...
def include_sound_event_annotation(

View File

@@ -1,7 +1,6 @@
from typing import (
List,
Literal,
Sequence,
)
from pydantic import Field
@@ -23,7 +22,7 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
class ClassificationTaskConfig(BaseTaskConfig):
@@ -49,12 +48,14 @@ class ClassificationTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
prediction: BatDetect2Prediction,
) -> ClipEval:
clip = clip_annotation.clip
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
pred
for pred in prediction.predictions
if self.include_prediction(pred, clip)
]
all_gts = [

View File

@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import List, Literal, Sequence
from typing import List, Literal
from pydantic import Field
from soundevent import data
@@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class ClipClassificationTaskConfig(BaseTaskConfig):
@@ -37,7 +38,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
prediction: BatDetect2Prediction,
) -> ClipEval:
clip = clip_annotation.clip
@@ -54,7 +55,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
gt_classes.add(class_name)
pred_scores = defaultdict(float)
for pred in predictions:
for pred in prediction.predictions:
if not self.include_prediction(pred, clip):
continue

View File

@@ -1,4 +1,4 @@
from typing import List, Literal, Sequence
from typing import List, Literal
from pydantic import Field
from soundevent import data
@@ -18,7 +18,8 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class ClipDetectionTaskConfig(BaseTaskConfig):
@@ -36,7 +37,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
prediction: BatDetect2Prediction,
) -> ClipEval:
clip = clip_annotation.clip
@@ -46,7 +47,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
)
pred_score = 0
for pred in predictions:
for pred in prediction.predictions:
if not self.include_prediction(pred, clip):
continue

View File

@@ -1,4 +1,4 @@
from typing import List, Literal, Sequence
from typing import List, Literal
from pydantic import Field
from soundevent import data
@@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class DetectionTaskConfig(BaseTaskConfig):
@@ -35,7 +36,7 @@ class DetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
prediction: BatDetect2Prediction,
) -> ClipEval:
clip = clip_annotation.clip
@@ -45,7 +46,9 @@ class DetectionTask(BaseTask[ClipEval]):
if self.include_sound_event_annotation(sound_event, clip)
]
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
pred
for pred in prediction.predictions
if self.include_prediction(pred, clip)
]
scores = [pred.detection_score for pred in preds]

View File

@@ -1,4 +1,4 @@
from typing import List, Literal, Sequence
from typing import List, Literal
from pydantic import Field
from soundevent import data
@@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
BaseTaskConfig,
tasks_registry,
)
from batdetect2.typing import RawPrediction, TargetProtocol
from batdetect2.typing import TargetProtocol
from batdetect2.typing.postprocess import BatDetect2Prediction
class TopClassDetectionTaskConfig(BaseTaskConfig):
@@ -35,7 +36,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
def evaluate_clip(
self,
clip_annotation: data.ClipAnnotation,
predictions: Sequence[RawPrediction],
prediction: BatDetect2Prediction,
) -> ClipEval:
clip = clip_annotation.clip
@@ -45,7 +46,9 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
if self.include_sound_event_annotation(sound_event, clip)
]
preds = [
pred for pred in predictions if self.include_prediction(pred, clip)
pred
for pred in prediction.predictions
if self.include_prediction(pred, clip)
]
# Take the highest score for each prediction
scores = [pred.class_scores.max() for pred in preds]

View File

@@ -1,4 +1,5 @@
import io
import sys
from collections.abc import Callable
from functools import partial
from pathlib import Path
@@ -32,6 +33,20 @@ from batdetect2.core.configs import BaseConfig
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
def enable_logging(level: int):
logger.remove()
if level == 0:
log_level = "WARNING"
elif level == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.enable("batdetect2")
class BaseLoggerConfig(BaseConfig):
log_dir: Path = DEFAULT_LOGS_DIR
experiment_name: Optional[str] = None
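
The verbosity mapping extracted from the CLI (see the CLI diff above):

from batdetect2.logging import enable_logging

enable_logging(0)  # WARNING
enable_logging(1)  # INFO
enable_logging(2)  # DEBUG (any value above 1)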

View File

@@ -30,10 +30,7 @@ from typing import List, Optional
import torch
from batdetect2.models.backbones import (
Backbone,
build_backbone,
)
from batdetect2.models.backbones import Backbone, build_backbone
from batdetect2.models.blocks import (
ConvConfig,
FreqCoordConvDownConfig,
@@ -62,16 +59,13 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import (
from batdetect2.typing import (
ClipDetectionsTensor,
DetectionModel,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"BBoxHead",
@@ -133,6 +127,10 @@ def build_model(
preprocessor: Optional[PreprocessorProtocol] = None,
postprocessor: Optional[PostprocessorProtocol] = None,
):
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
config = config or BackboneConfig()
targets = targets or build_targets()
preprocessor = preprocessor or build_preprocessor()

View File

@@ -12,7 +12,7 @@ classification probability maps, size prediction maps, and potentially
intermediate features.
"""
from typing import List
from typing import Dict, List, Optional, Union
import numpy as np
import torch
@@ -30,6 +30,55 @@ __all__ = [
]
def to_xarray(
array: Union[torch.Tensor, np.ndarray],
start_time: float,
end_time: float,
min_freq: float = MIN_FREQ,
max_freq: float = MAX_FREQ,
name: str = "xarray",
extra_dims: Optional[List[str]] = None,
extra_coords: Optional[Dict[str, np.ndarray]] = None,
) -> xr.DataArray:
if isinstance(array, torch.Tensor):
array = array.detach().cpu().numpy()
extra_ndims = array.ndim - 2
if extra_ndims < 0:
raise ValueError(
"Input array must have at least 2 dimensions, "
f"got shape {array.shape}"
)
width = array.shape[-1]
height = array.shape[-2]
times = np.linspace(start_time, end_time, width, endpoint=False)
freqs = np.linspace(min_freq, max_freq, height, endpoint=False)
if extra_dims is None:
extra_dims = [f"dim_{i}" for i in range(extra_ndims)]
if extra_coords is None:
extra_coords = {}
return xr.DataArray(
data=array,
dims=[
*extra_dims,
Dimensions.frequency.value,
Dimensions.time.value,
],
coords={
**extra_coords,
Dimensions.frequency.value: freqs,
Dimensions.time.value: times,
},
name=name,
)
def map_detection_to_clip(
detections: ClipDetectionsTensor,
start_time: float,
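
A sketch of `to_xarray` on a stack of spectrogram-shaped arrays (import path assumed; `Dimensions.frequency.value` and `Dimensions.time.value` assumed to be "frequency" and "time"):

import numpy as np

from batdetect2.postprocess import to_xarray  # assumed export

array = np.random.rand(2, 128, 512)  # (channel, frequency, time)
da = to_xarray(
    array,
    start_time=0.0,
    end_time=5.0,
    extra_dims=["channel"],
    extra_coords={"channel": np.arange(2)},
)
# da.dims == ("channel", "frequency", "time"); time and frequency
# coordinates are evenly spaced over [start_time, end_time) and
# [min_freq, max_freq), both with endpoint=False.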

View File

@@ -10,9 +10,9 @@ from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.typing import (
BatDetect2Prediction,
EvaluatorProtocol,
ModelOutput,
RawPrediction,
TrainExample,
)
@@ -24,7 +24,7 @@ class ValidationMetrics(Callback):
self.evaluator = evaluator
self._clip_annotations: List[data.ClipAnnotation] = []
self._predictions: List[List[RawPrediction]] = []
self._predictions: List[BatDetect2Prediction] = []
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
dataloaders = trainer.val_dataloaders
@@ -100,8 +100,15 @@ class ValidationMetrics(Callback):
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
to_raw_predictions(clip_dets.numpy(), targets=model.targets)
for clip_dets in clip_detections
BatDetect2Prediction(
clip=clip_annotation.clip,
predictions=to_raw_predictions(
clip_dets.numpy(), targets=model.targets
),
)
for clip_annotation, clip_dets in zip(
clip_annotations, clip_detections
)
]
self._clip_annotations.extend(clip_annotations)

View File

@@ -6,11 +6,10 @@ from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.audio import TARGET_SAMPLERATE_HZ
from batdetect2.models import Model, build_model
from batdetect2.plotting.clips import build_preprocessor
from batdetect2.postprocess import build_postprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample
@@ -27,20 +26,20 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
config: "BatDetect2Config",
input_samplerate: int = TARGET_SAMPLERATE_HZ,
learning_rate: float = 0.001,
config: Optional[dict] = None,
t_max: int = 100,
model: Optional[Model] = None,
loss: Optional[torch.nn.Module] = None,
):
from batdetect2.config import validate_config
super().__init__()
self.save_hyperparameters(logger=False)
self.input_samplerate = input_samplerate
self.config = config
self.learning_rate = learning_rate
self.config = validate_config(config)
self.input_samplerate = self.config.audio.samplerate
self.learning_rate = self.config.train.optimizer.learning_rate
self.t_max = t_max
if loss is None:
@@ -104,14 +103,7 @@ def load_model_from_checkpoint(
def build_training_module(
config: Optional["BatDetect2Config"] = None,
config: Optional[dict] = None,
t_max: int = 200,
) -> TrainingModule:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
return TrainingModule(
config=config,
learning_rate=config.train.optimizer.learning_rate,
t_max=t_max,
)
return TrainingModule(config=config, t_max=t_max)

View File

@@ -97,7 +97,7 @@ def train(
)
module = build_training_module(
config,
config.model_dump(mode="json"),
t_max=config.train.optimizer.t_max * len(train_dataloader),
)

View File

@@ -1,3 +1,4 @@
from batdetect2.typing.data import OutputFormatterProtocol
from batdetect2.typing.evaluate import (
AffinityFunction,
ClipMatches,
@@ -10,6 +11,7 @@ from batdetect2.typing.evaluate import (
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
ClipDetectionsTensor,
GeometryDecoder,
PostprocessorProtocol,
RawPrediction,
@@ -43,8 +45,9 @@ __all__ = [
"Augmentation",
"BackboneModel",
"BatDetect2Prediction",
"ClipMatches",
"ClipDetectionsTensor",
"ClipLabeller",
"ClipMatches",
"ClipperProtocol",
"DetectionModel",
"EvaluatorProtocol",
@@ -56,6 +59,7 @@ __all__ = [
"MatcherProtocol",
"MetricsProtocol",
"ModelOutput",
"OutputFormatterProtocol",
"PlotterProtocol",
"Position",
"PostprocessorProtocol",

View File

@@ -0,0 +1,26 @@
from typing import Generic, List, Optional, Protocol, Sequence, TypeVar
from soundevent.data import PathLike
from batdetect2.typing.postprocess import BatDetect2Prediction
__all__ = [
"OutputFormatterProtocol",
]
T = TypeVar("T")
class OutputFormatterProtocol(Protocol, Generic[T]):
def format(
self, predictions: Sequence[BatDetect2Prediction]
) -> List[T]: ...
def save(
self,
predictions: Sequence[T],
path: PathLike,
audio_dir: Optional[PathLike] = None,
) -> None: ...
def load(self, path: PathLike) -> List[T]: ...
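
A minimal conforming implementation, for illustration only (a hypothetical score-only formatter, not part of this commit):

import json
from pathlib import Path
from typing import List, Optional, Sequence

from soundevent.data import PathLike

from batdetect2.typing import BatDetect2Prediction, OutputFormatterProtocol

class ScoreOnlyFormatter(OutputFormatterProtocol[dict]):
    """Hypothetical formatter keeping only clip ids and detection scores."""

    def format(
        self, predictions: Sequence[BatDetect2Prediction]
    ) -> List[dict]:
        return [
            {
                "clip": str(p.clip.uuid),
                "scores": [float(d.detection_score) for d in p.predictions],
            }
            for p in predictions
        ]

    def save(
        self,
        predictions: Sequence[dict],
        path: PathLike,
        audio_dir: Optional[PathLike] = None,
    ) -> None:
        Path(path).write_text(json.dumps(list(predictions)))

    def load(self, path: PathLike) -> List[dict]:
        return json.loads(Path(path).read_text())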

View File

@@ -14,7 +14,7 @@ from typing import (
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
@@ -115,7 +115,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
predictions: Sequence[BatDetect2Prediction],
) -> EvaluationOutput: ...
def compute_metrics(