mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
3 Commits
6d0a73dda6
...
8366410332
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8366410332 | ||
|
|
aa4ad68958 | ||
|
|
401a3832ce |
@ -7,30 +7,30 @@ authors = [
|
|||||||
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
|
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
|
||||||
]
|
]
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"cf-xarray>=0.9.0",
|
||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
|
"deepmerge>=2.0",
|
||||||
|
"hydra-core>=1.3.2",
|
||||||
"librosa>=0.10.1",
|
"librosa>=0.10.1",
|
||||||
|
"lightning[extra]==2.5.0",
|
||||||
|
"loguru>=0.7.3",
|
||||||
"matplotlib>=3.7.1",
|
"matplotlib>=3.7.1",
|
||||||
|
"netcdf4>=1.6.5",
|
||||||
|
"numba>=0.60",
|
||||||
"numpy>=1.23.5",
|
"numpy>=1.23.5",
|
||||||
|
"omegaconf>=2.3.0",
|
||||||
|
"onnx>=1.16.0",
|
||||||
"pandas>=1.5.3",
|
"pandas>=1.5.3",
|
||||||
|
"pyyaml>=6.0.2",
|
||||||
"scikit-learn>=1.2.2",
|
"scikit-learn>=1.2.2",
|
||||||
"scipy>=1.10.1",
|
"scipy>=1.10.1",
|
||||||
|
"seaborn>=0.13.2",
|
||||||
|
"soundevent[audio,geometry,plot]>=2.9.1",
|
||||||
|
"tensorboard>=2.16.2",
|
||||||
"torch>=1.13.1,<2.5.0",
|
"torch>=1.13.1,<2.5.0",
|
||||||
"torchaudio>=1.13.1,<2.5.0",
|
"torchaudio>=1.13.1,<2.5.0",
|
||||||
"torchvision>=0.14.0",
|
"torchvision>=0.14.0",
|
||||||
"soundevent[audio,geometry,plot]>=2.9.1",
|
|
||||||
"click>=8.1.7",
|
|
||||||
"netcdf4>=1.6.5",
|
|
||||||
"tqdm>=4.66.2",
|
"tqdm>=4.66.2",
|
||||||
"cf-xarray>=0.9.0",
|
|
||||||
"onnx>=1.16.0",
|
|
||||||
"lightning[extra]==2.5.0",
|
|
||||||
"tensorboard>=2.16.2",
|
|
||||||
"omegaconf>=2.3.0",
|
|
||||||
"pyyaml>=6.0.2",
|
|
||||||
"hydra-core>=1.3.2",
|
|
||||||
"numba>=0.60",
|
|
||||||
"loguru>=0.7.3",
|
|
||||||
"deepmerge>=2.0",
|
|
||||||
]
|
]
|
||||||
requires-python = ">=3.9,<3.13"
|
requires-python = ">=3.9,<3.13"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
@ -89,9 +89,7 @@ dev = [
|
|||||||
"python-lsp-server>=1.13.0",
|
"python-lsp-server>=1.13.0",
|
||||||
]
|
]
|
||||||
dvclive = ["dvclive>=3.48.2"]
|
dvclive = ["dvclive>=3.48.2"]
|
||||||
mlflow = [
|
mlflow = ["mlflow>=3.1.1"]
|
||||||
"mlflow>=3.1.1",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 79
|
line-length = 79
|
||||||
@ -115,9 +113,7 @@ exclude = [
|
|||||||
"src/batdetect2/detector/",
|
"src/batdetect2/detector/",
|
||||||
"src/batdetect2/finetune",
|
"src/batdetect2/finetune",
|
||||||
"src/batdetect2/utils",
|
"src/batdetect2/utils",
|
||||||
"src/batdetect2/plotting",
|
|
||||||
"src/batdetect2/plot",
|
"src/batdetect2/plot",
|
||||||
"src/batdetect2/api",
|
|
||||||
"src/batdetect2/evaluate/legacy",
|
"src/batdetect2/evaluate/legacy",
|
||||||
"src/batdetect2/train/legacy",
|
"src/batdetect2/train/legacy",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Sequence
|
from typing import Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -9,6 +10,14 @@ from soundevent.audio.files import get_audio_files
|
|||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.core import merge_configs
|
from batdetect2.core import merge_configs
|
||||||
|
from batdetect2.data import (
|
||||||
|
OutputFormatConfig,
|
||||||
|
build_output_formatter,
|
||||||
|
get_output_formatter,
|
||||||
|
load_dataset_from_config,
|
||||||
|
)
|
||||||
|
from batdetect2.data.datasets import Dataset
|
||||||
|
from batdetect2.data.predictions.base import OutputFormatterProtocol
|
||||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||||
from batdetect2.inference import process_file_list, run_batch_inference
|
from batdetect2.inference import process_file_list, run_batch_inference
|
||||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||||
@ -41,6 +50,7 @@ class BatDetect2API:
|
|||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
|
formatter: OutputFormatterProtocol,
|
||||||
model: Model,
|
model: Model,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
@ -50,9 +60,17 @@ class BatDetect2API:
|
|||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
self.model = model
|
self.model = model
|
||||||
|
self.formatter = formatter
|
||||||
|
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
|
def load_annotations(
    self,
    path: data.PathLike,
    base_dir: Optional[data.PathLike] = None,
) -> Dataset:
    """Load the clip annotations described by a dataset config file.

    Thin convenience wrapper around ``load_dataset_from_config``.

    Parameters
    ----------
    path
        Path to the dataset configuration file.
    base_dir
        Forwarded unchanged as the ``base_dir`` argument of
        ``load_dataset_from_config``.

    Returns
    -------
    Dataset
        Whatever ``load_dataset_from_config`` returns for these arguments.
    """
    dataset = load_dataset_from_config(path, base_dir=base_dir)
    return dataset
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
self,
|
self,
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
@ -91,7 +109,8 @@ class BatDetect2API:
|
|||||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
):
|
save_predictions: bool = True,
|
||||||
|
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||||
return evaluate(
|
return evaluate(
|
||||||
self.model,
|
self.model,
|
||||||
test_annotations,
|
test_annotations,
|
||||||
@ -103,8 +122,41 @@ class BatDetect2API:
|
|||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
|
formatter=self.formatter if save_predictions else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def evaluate_predictions(
    self,
    annotations: Sequence[data.ClipAnnotation],
    predictions: Sequence[BatDetect2Prediction],
    output_dir: Optional[data.PathLike] = None,
):
    """Evaluate already-computed predictions against annotations.

    The configured evaluator compares ``predictions`` with ``annotations``
    and summarises the per-clip results into metrics. When ``output_dir``
    is given, the metrics are also written to ``metrics.json`` inside it
    and every plot produced by the evaluator is saved below it.

    Parameters
    ----------
    annotations
        Ground-truth clip annotations.
    predictions
        Predictions to evaluate.
    output_dir
        Optional directory in which to store ``metrics.json`` and the
        evaluator plots. Nothing is written when it is ``None``.

    Returns
    -------
    The metrics object returned by ``self.evaluator.compute_metrics``.
    """
    clip_evals = self.evaluator.evaluate(annotations, predictions)
    metrics = self.evaluator.compute_metrics(clip_evals)

    if output_dir is None:
        return metrics

    output_dir = Path(output_dir)

    # ``exist_ok=True`` replaces the check-then-create pattern, which could
    # fail if the directory appeared between the check and the mkdir call.
    output_dir.mkdir(parents=True, exist_ok=True)
    (output_dir / "metrics.json").write_text(json.dumps(metrics))

    for figure_name, fig in self.evaluator.generate_plots(clip_evals):
        fig_path = output_dir / figure_name

        # Figure names may contain sub-directories.
        fig_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(fig_path)

    return metrics
|
||||||
|
|
||||||
def load_audio(self, path: data.PathLike) -> np.ndarray:
|
def load_audio(self, path: data.PathLike) -> np.ndarray:
|
||||||
return self.audio_loader.load_file(path)
|
return self.audio_loader.load_file(path)
|
||||||
|
|
||||||
@ -194,8 +246,38 @@ class BatDetect2API:
|
|||||||
config=self.config,
|
config=self.config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def save_predictions(
    self,
    predictions: Sequence[BatDetect2Prediction],
    path: data.PathLike,
    audio_dir: Optional[data.PathLike] = None,
    format: Optional[str] = None,
    config: Optional[OutputFormatConfig] = None,
):
    """Format predictions and write them to ``path``.

    The formatter configured on the API is used unless ``format`` and/or
    ``config`` is supplied, in which case a formatter is looked up by name
    with ``get_output_formatter``.

    Parameters
    ----------
    predictions
        Predictions to store.
    path
        Destination handed to the formatter's ``save`` method.
    audio_dir
        Forwarded to the formatter's ``save`` method.
    format
        Name of the output format to use instead of the default one.
        When omitted but ``config`` is given, ``config.name`` is used.
    config
        Configuration for the selected output format.
    """
    formatter = self.formatter

    use_override = not (format is None and config is None)
    if use_override:
        # Fall back on the name carried by the config when no explicit
        # format name was requested.
        format = format or config.name  # type: ignore
        formatter = get_output_formatter(
            name=format,
            targets=self.targets,
            config=config,
        )

    formatted = formatter.format(predictions)
    formatter.save(formatted, audio_dir=audio_dir, path=path)
|
||||||
|
|
||||||
|
def load_predictions(
    self,
    path: data.PathLike,
) -> List[BatDetect2Prediction]:
    """Load stored predictions using the API's configured formatter.

    Parameters
    ----------
    path
        Location passed to ``self.formatter.load``.

    Returns
    -------
    The value returned by ``self.formatter.load(path)``.
    """
    loaded = self.formatter.load(path)
    return loaded
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: BatDetect2Config):
|
def from_config(
|
||||||
|
cls,
|
||||||
|
config: BatDetect2Config,
|
||||||
|
):
|
||||||
targets = build_targets(config=config.targets)
|
targets = build_targets(config=config.targets)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
@ -228,6 +310,8 @@ class BatDetect2API:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
formatter = build_output_formatter(targets, config=config.output)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
config=config,
|
config=config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -236,6 +320,7 @@ class BatDetect2API:
|
|||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
model=model,
|
model=model,
|
||||||
|
formatter=formatter,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -266,6 +351,8 @@ class BatDetect2API:
|
|||||||
|
|
||||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||||
|
|
||||||
|
formatter = build_output_formatter(targets, config=config.output)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
config=config,
|
config=config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -274,4 +361,5 @@ class BatDetect2API:
|
|||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
model=model,
|
model=model,
|
||||||
|
formatter=formatter,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -1,9 +1,8 @@
|
|||||||
"""BatDetect2 command line interface."""
|
"""BatDetect2 command line interface."""
|
||||||
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import click
|
import click
|
||||||
from loguru import logger
|
|
||||||
|
from batdetect2.logging import enable_logging
|
||||||
|
|
||||||
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
||||||
|
|
||||||
@ -27,22 +26,9 @@ BatDetect2 - Detection and Classification
|
|||||||
count=True,
|
count=True,
|
||||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||||
)
|
)
|
||||||
def cli(
|
def cli(verbose: int = 0):
|
||||||
verbose: int = 0,
|
|
||||||
):
|
|
||||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||||
click.echo(INFO_STR)
|
click.echo(INFO_STR)
|
||||||
|
|
||||||
logger.remove()
|
enable_logging(verbose)
|
||||||
|
|
||||||
if verbose == 0:
|
|
||||||
log_level = "WARNING"
|
|
||||||
elif verbose == 1:
|
|
||||||
log_level = "INFO"
|
|
||||||
else:
|
|
||||||
log_level = "DEBUG"
|
|
||||||
|
|
||||||
logger.add(sys.stderr, level=log_level)
|
|
||||||
|
|
||||||
logger.enable("batdetect2")
|
|
||||||
# click.echo(BATDETECT_ASCII_ART)
|
# click.echo(BATDETECT_ASCII_ART)
|
||||||
|
|||||||
@ -43,3 +43,55 @@ def summary(
|
|||||||
)
|
)
|
||||||
|
|
||||||
print(f"Number of annotated clips: {len(dataset)}")
|
print(f"Number of annotated clips: {len(dataset)}")
|
||||||
|
|
||||||
|
|
||||||
|
@data.command()
@click.argument(
    "dataset_config",
    # ``path_type=Path`` makes click hand over ``pathlib.Path`` objects so
    # the runtime values match the annotations on the function below.
    type=click.Path(exists=True, path_type=Path),
)
@click.option(
    "--field",
    type=str,
    help="If the dataset info is in a nested field please specify here.",
)
@click.option(
    "--output",
    type=click.Path(exists=False, path_type=Path),
    default="annotations.json",
    show_default=True,
    help="Path of the annotation file to write.",
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True, path_type=Path),
    help=(
        "The base directory to which all recording and annotations "
        "paths are relative to."
    ),
)
def convert(
    dataset_config: Path,
    field: Optional[str] = None,
    output: Path = Path("annotations.json"),
    base_dir: Optional[Path] = None,
):
    """Convert a dataset config file to soundevent format."""
    # NOTE: the docstring above doubles as the command help text shown by
    # click, so it is intentionally kept short.

    # Imported inside the command, as in the original. ``data`` here
    # shadows the click group of the same name only within this function.
    from soundevent import data, io

    from batdetect2.data import load_dataset, load_dataset_config

    # Default to the current working directory when no base dir is given.
    base_dir = base_dir or Path.cwd()

    config = load_dataset_config(dataset_config, field=field)
    dataset = load_dataset(config, base_dir=base_dir)

    annotation_set = data.AnnotationSet(
        clip_annotations=list(dataset),
        name=config.name,
        description=config.description,
    )

    io.save(annotation_set, output)
|
||||||
|
|||||||
@ -4,9 +4,13 @@ from pydantic import Field
|
|||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.core import BaseConfig
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.core.configs import load_config
|
from batdetect2.data.predictions import OutputFormatConfig
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
from batdetect2.data.predictions.raw import RawOutputConfig
|
||||||
|
from batdetect2.evaluate.config import (
|
||||||
|
EvaluationConfig,
|
||||||
|
get_default_eval_config,
|
||||||
|
)
|
||||||
from batdetect2.inference.config import InferenceConfig
|
from batdetect2.inference.config import InferenceConfig
|
||||||
from batdetect2.models.config import BackboneConfig
|
from batdetect2.models.config import BackboneConfig
|
||||||
from batdetect2.postprocess.config import PostprocessConfig
|
from batdetect2.postprocess.config import PostprocessConfig
|
||||||
@ -17,6 +21,7 @@ from batdetect2.train.config import TrainingConfig
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"BatDetect2Config",
|
"BatDetect2Config",
|
||||||
"load_full_config",
|
"load_full_config",
|
||||||
|
"validate_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -24,7 +29,9 @@ class BatDetect2Config(BaseConfig):
|
|||||||
config_version: Literal["v1"] = "v1"
|
config_version: Literal["v1"] = "v1"
|
||||||
|
|
||||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||||
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
evaluation: EvaluationConfig = Field(
|
||||||
|
default_factory=get_default_eval_config
|
||||||
|
)
|
||||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||||
preprocess: PreprocessingConfig = Field(
|
preprocess: PreprocessingConfig = Field(
|
||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
@ -33,6 +40,14 @@ class BatDetect2Config(BaseConfig):
|
|||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||||
|
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_config(config: Optional[dict]) -> BatDetect2Config:
    """Build a ``BatDetect2Config`` from an optional plain dictionary.

    Parameters
    ----------
    config
        Mapping to validate, or ``None`` to request the default
        configuration.

    Returns
    -------
    BatDetect2Config
        ``BatDetect2Config()`` when ``config`` is ``None``, otherwise the
        result of ``BatDetect2Config.model_validate(config)``.
    """
    if config is not None:
        return BatDetect2Config.model_validate(config)

    return BatDetect2Config()
|
||||||
|
|
||||||
|
|
||||||
def load_full_config(
|
def load_full_config(
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class BaseConfig(BaseModel):
|
|||||||
and serialization capabilities.
|
and serialization capabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
def to_yaml_string(
|
def to_yaml_string(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -77,6 +77,15 @@ class Registry(Generic[T_Type, P_Type]):
|
|||||||
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
|
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
|
||||||
return tuple(self._config_types.values())
|
return tuple(self._config_types.values())
|
||||||
|
|
||||||
|
def get_config_type(self, name: str) -> Type[BaseModel]:
    """Return the config class registered under ``name``.

    Parameters
    ----------
    name
        Key to look up in the registry's config types.

    Raises
    ------
    ValueError
        If nothing is registered under ``name``. The message lists the
        registered names and the original ``KeyError`` is chained.
    """
    registered = self._config_types

    try:
        config_type = registered[name]
    except KeyError as err:
        known_names = list(registered.keys())
        raise ValueError(
            f"No config type with name '{name}' is registered. "
            f"Existing config types: {known_names}"
        ) from err

    return config_type
|
||||||
|
|
||||||
def build(
|
def build(
|
||||||
self,
|
self,
|
||||||
config: BaseModel,
|
config: BaseModel,
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from batdetect2.data.annotations import (
|
from batdetect2.data.annotations import (
|
||||||
AnnotatedDataset,
|
AnnotatedDataset,
|
||||||
|
AnnotationFormats,
|
||||||
AOEFAnnotations,
|
AOEFAnnotations,
|
||||||
BatDetect2FilesAnnotations,
|
BatDetect2FilesAnnotations,
|
||||||
BatDetect2MergedAnnotations,
|
BatDetect2MergedAnnotations,
|
||||||
@ -11,6 +12,14 @@ from batdetect2.data.datasets import (
|
|||||||
load_dataset_config,
|
load_dataset_config,
|
||||||
load_dataset_from_config,
|
load_dataset_from_config,
|
||||||
)
|
)
|
||||||
|
from batdetect2.data.predictions import (
|
||||||
|
BatDetect2OutputConfig,
|
||||||
|
OutputFormatConfig,
|
||||||
|
RawOutputConfig,
|
||||||
|
SoundEventOutputConfig,
|
||||||
|
build_output_formatter,
|
||||||
|
get_output_formatter,
|
||||||
|
)
|
||||||
from batdetect2.data.summary import (
|
from batdetect2.data.summary import (
|
||||||
compute_class_summary,
|
compute_class_summary,
|
||||||
extract_recordings_df,
|
extract_recordings_df,
|
||||||
@ -20,12 +29,19 @@ from batdetect2.data.summary import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
"AOEFAnnotations",
|
"AOEFAnnotations",
|
||||||
"AnnotatedDataset",
|
"AnnotatedDataset",
|
||||||
|
"AnnotationFormats",
|
||||||
"BatDetect2FilesAnnotations",
|
"BatDetect2FilesAnnotations",
|
||||||
"BatDetect2MergedAnnotations",
|
"BatDetect2MergedAnnotations",
|
||||||
|
"BatDetect2OutputConfig",
|
||||||
"DatasetConfig",
|
"DatasetConfig",
|
||||||
|
"OutputFormatConfig",
|
||||||
|
"RawOutputConfig",
|
||||||
|
"SoundEventOutputConfig",
|
||||||
|
"build_output_formatter",
|
||||||
"compute_class_summary",
|
"compute_class_summary",
|
||||||
"extract_recordings_df",
|
"extract_recordings_df",
|
||||||
"extract_sound_events_df",
|
"extract_sound_events_df",
|
||||||
|
"get_output_formatter",
|
||||||
"load_annotated_dataset",
|
"load_annotated_dataset",
|
||||||
"load_dataset",
|
"load_dataset",
|
||||||
"load_dataset_config",
|
"load_dataset_config",
|
||||||
|
|||||||
@ -13,7 +13,6 @@ format-specific loading function to retrieve the annotations as a standard
|
|||||||
`soundevent.data.AnnotationSet`.
|
`soundevent.data.AnnotationSet`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated, Optional, Union
|
from typing import Annotated, Optional, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -64,7 +63,7 @@ source configuration represents.
|
|||||||
|
|
||||||
def load_annotated_dataset(
|
def load_annotated_dataset(
|
||||||
dataset: AnnotatedDataset,
|
dataset: AnnotatedDataset,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[data.PathLike] = None,
|
||||||
) -> data.AnnotationSet:
|
) -> data.AnnotationSet:
|
||||||
"""Load annotations for a single data source based on its configuration.
|
"""Load annotations for a single data source based on its configuration.
|
||||||
|
|
||||||
@ -97,6 +96,7 @@ def load_annotated_dataset(
|
|||||||
known format-specific loading functions implemented in the dispatch
|
known format-specific loading functions implemented in the dispatch
|
||||||
logic.
|
logic.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(dataset, AOEFAnnotations):
|
if isinstance(dataset, AOEFAnnotations):
|
||||||
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
|
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
|
||||||
|
|
||||||
|
|||||||
@ -84,7 +84,7 @@ class AOEFAnnotations(AnnotatedDataset):
|
|||||||
|
|
||||||
def load_aoef_annotated_dataset(
|
def load_aoef_annotated_dataset(
|
||||||
dataset: AOEFAnnotations,
|
dataset: AOEFAnnotations,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[data.PathLike] = None,
|
||||||
) -> data.AnnotationSet:
|
) -> data.AnnotationSet:
|
||||||
"""Load annotations from an AnnotationSet or AnnotationProject file.
|
"""Load annotations from an AnnotationSet or AnnotationProject file.
|
||||||
|
|
||||||
|
|||||||
@ -19,7 +19,7 @@ The core components are:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -53,7 +53,7 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
Dataset = List[data.ClipAnnotation]
|
Dataset = Sequence[data.ClipAnnotation]
|
||||||
"""Type alias for a loaded dataset representation.
|
"""Type alias for a loaded dataset representation.
|
||||||
|
|
||||||
Represents an entire dataset *after loading* as a flat Python list containing
|
Represents an entire dataset *after loading* as a flat Python list containing
|
||||||
@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
|
|||||||
|
|
||||||
def load_dataset(
|
def load_dataset(
|
||||||
config: DatasetConfig,
|
config: DatasetConfig,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[data.PathLike] = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||||
clip_annotations = []
|
clip_annotations = []
|
||||||
@ -168,7 +168,7 @@ def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
|||||||
def load_dataset_from_config(
|
def load_dataset_from_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[data.PathLike] = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load dataset annotation metadata from a configuration file.
|
"""Load dataset annotation metadata from a configuration file.
|
||||||
|
|
||||||
@ -250,6 +250,6 @@ def save_dataset(
|
|||||||
annotation_set = data.AnnotationSet(
|
annotation_set = data.AnnotationSet(
|
||||||
name=name,
|
name=name,
|
||||||
description=description,
|
description=description,
|
||||||
clip_annotations=dataset,
|
clip_annotations=list(dataset),
|
||||||
)
|
)
|
||||||
io.save(annotation_set, path, audio_dir=audio_dir)
|
io.save(annotation_set, path, audio_dir=audio_dir)
|
||||||
|
|||||||
58
src/batdetect2/data/predictions/__init__.py
Normal file
58
src/batdetect2/data/predictions/__init__.py
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
from typing import Annotated, Optional, Union
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.data.predictions.base import (
|
||||||
|
OutputFormatterProtocol,
|
||||||
|
prediction_formatters,
|
||||||
|
)
|
||||||
|
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
|
||||||
|
from batdetect2.data.predictions.raw import RawOutputConfig
|
||||||
|
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
|
||||||
|
from batdetect2.typing import TargetProtocol
|
||||||
|
|
||||||
|
# Public names of the predictions package. ``OutputFormatConfig`` is defined
# in this module and imported from here by other batdetect2 modules, so it is
# listed explicitly alongside the concrete config classes.
__all__ = [
    "BatDetect2OutputConfig",
    "OutputFormatConfig",
    "RawOutputConfig",
    "SoundEventOutputConfig",
    "build_output_formatter",
    "get_output_formatter",
]
|
||||||
|
|
||||||
|
|
||||||
|
OutputFormatConfig = Annotated[
|
||||||
|
Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig],
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_output_formatter(
    targets: Optional[TargetProtocol] = None,
    config: Optional[OutputFormatConfig] = None,
) -> OutputFormatterProtocol:
    """Construct the final output formatter.

    Parameters
    ----------
    targets
        Target definitions handed to the formatter. When ``None``, the
        result of ``build_targets()`` is used.
    config
        Output format configuration. When ``None``, a default
        ``RawOutputConfig`` is used.

    Returns
    -------
    OutputFormatterProtocol
        The formatter built by the ``prediction_formatters`` registry for
        the given configuration.
    """
    # Imported at call time, as in the original (presumably to avoid a
    # circular import between the data and targets packages -- confirm).
    from batdetect2.targets import build_targets

    # Explicit ``is None`` checks: an argument that was supplied must never
    # be replaced by the default just because it evaluates as falsy.
    if config is None:
        config = RawOutputConfig()

    if targets is None:
        targets = build_targets()

    return prediction_formatters.build(config, targets)
|
||||||
|
|
||||||
|
|
||||||
|
def get_output_formatter(
    name: str,
    targets: Optional[TargetProtocol] = None,
    config: Optional[OutputFormatConfig] = None,
) -> OutputFormatterProtocol:
    """Get the output formatter by name.

    Parameters
    ----------
    name
        Registered name of the output format.
    targets
        Forwarded to ``build_output_formatter``.
    config
        Configuration for the format. When ``None``, the config class
        registered under ``name`` is instantiated with its defaults.

    Raises
    ------
    ValueError
        If the name carried by the config differs from ``name``.
    """
    resolved = config

    if resolved is None:
        default_config_type = prediction_formatters.get_config_type(name)
        resolved = default_config_type()  # type: ignore

    # Guard against a config that belongs to a different format.
    if resolved.name != name:  # type: ignore
        raise ValueError(
            f"Config name {resolved.name} does not match formatter name {name}"  # type: ignore
        )

    return build_output_formatter(targets, resolved)
|
||||||
29
src/batdetect2/data/predictions/base.py
Normal file
29
src/batdetect2/data/predictions/base.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.core import Registry
|
||||||
|
from batdetect2.typing import (
|
||||||
|
OutputFormatterProtocol,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
    """Express an absolute audio path relative to ``audio_dir``.

    Parameters
    ----------
    path
        Path of the audio file.
    audio_dir
        Directory the returned path should be relative to.

    Returns
    -------
    Path
        ``path`` unchanged (as a ``Path``) when it is not absolute,
        otherwise ``path`` relative to ``audio_dir``.

    Raises
    ------
    ValueError
        If ``path`` is absolute but does not lie inside ``audio_dir``.
    """
    file_path = Path(path)
    root = Path(audio_dir)

    # Relative paths are returned as given.
    if not file_path.is_absolute():
        return file_path

    if file_path.is_relative_to(root):
        return file_path.relative_to(root)

    raise ValueError(f"Audio file {file_path} is not in audio_dir {root}")
|
||||||
|
|
||||||
|
|
||||||
|
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
|
||||||
|
Registry(name="output_formatter")
|
||||||
|
)
|
||||||
227
src/batdetect2/data/predictions/batdetect2.py
Normal file
227
src/batdetect2/data/predictions/batdetect2.py
Normal file
@ -0,0 +1,227 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Literal, Optional, Sequence, TypedDict
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.data.predictions.base import (
|
||||||
|
make_path_relative,
|
||||||
|
prediction_formatters,
|
||||||
|
)
|
||||||
|
from batdetect2.targets import terms
|
||||||
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
OutputFormatterProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
from typing import NotRequired # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
from typing_extensions import NotRequired
|
||||||
|
|
||||||
|
# ``class`` is a reserved word, so this key can only be declared with the
# functional ``TypedDict`` syntax; ``Annotation`` inherits it from here.
DictWithClass = TypedDict("DictWithClass", {"class": str})


class Annotation(DictWithClass):
    """Single annotation in the format expected by the annotation tool."""

    #: Start time in seconds.
    start_time: float

    #: End time in seconds.
    end_time: float

    #: Low frequency in Hz.
    low_freq: float

    #: High frequency in Hz.
    high_freq: float

    #: Probability of class assignment.
    class_prob: float

    #: Probability of detection.
    det_prob: float

    #: Individual ID.
    individual: str

    #: Type of detected event.
    event: str
|
||||||
|
|
||||||
|
|
||||||
|
class FileAnnotation(TypedDict):
    """Per-file results in the format expected by the annotation tool."""

    #: File ID.
    id: str

    #: Whether file has been annotated.
    annotated: bool

    #: Duration of audio file.
    duration: float

    #: Whether file has issues.
    issues: bool

    #: Time expansion factor.
    time_exp: float

    #: Class predicted at file level.
    class_name: str

    #: Notes of file.
    notes: str

    #: List of annotations.
    annotation: List[Annotation]

    #: Path to file. This is the only key that may be omitted.
    file_path: NotRequired[str]
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2OutputConfig(BaseConfig):
    """Configuration of the ``batdetect2`` output format."""

    # Fixed format name identifying this configuration.
    name: Literal["batdetect2"] = "batdetect2"

    # Event label for the generated annotations (presumably written to
    # the ``event`` field of each annotation -- confirm in the formatter).
    event_name: str = "Echolocation"

    # Note text attached to generated file annotations.
    annotation_note: str = "Automatically generated."
|
||||||
|
|
||||||
|
|
||||||
|
class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
|
||||||
|
def __init__(
    self,
    targets: TargetProtocol,
    event_name: str,
    annotation_note: str,
):
    """Store the settings used when formatting predictions.

    Parameters
    ----------
    targets
        Target definitions, used to resolve class names.
    event_name
        Event label for generated annotations.
    annotation_note
        Note text written into each formatted file annotation.
    """
    self.annotation_note = annotation_note
    self.event_name = event_name
    self.targets = targets
|
||||||
|
|
||||||
|
def format(
    self, predictions: Sequence[BatDetect2Prediction]
) -> List[FileAnnotation]:
    """Format every clip prediction, preserving the input order.

    Each element of ``predictions`` is converted with
    ``self.format_prediction``.
    """
    return list(map(self.format_prediction, predictions))
|
||||||
|
|
||||||
|
def save(
    self,
    predictions: Sequence[FileAnnotation],
    path: data.PathLike,
    audio_dir: Optional[data.PathLike] = None,
) -> None:
    """Write each file annotation to its own JSON file inside ``path``.

    Parameters
    ----------
    predictions
        Formatted file annotations. They are not modified.
    path
        Output directory; created (with parents) if it does not exist.
        Each annotation is stored as ``<id>.json`` within it.
    audio_dir
        When given, the ``file_path`` of every annotation that has one is
        stored relative to this directory, as computed by
        ``make_path_relative``.
    """
    path = Path(path)

    # ``exist_ok=True`` replaces the race-prone ``is_dir()`` pre-check.
    path.mkdir(parents=True, exist_ok=True)

    for prediction in predictions:
        if audio_dir is not None and "file_path" in prediction:
            # Work on a shallow copy so the caller's dictionary keeps its
            # original path. Overriding an existing key keeps its
            # position, so the serialised key order is unchanged.
            prediction = {
                **prediction,
                "file_path": str(
                    make_path_relative(
                        prediction["file_path"],
                        audio_dir,
                    )
                ),
            }

        pred_path = path / (prediction["id"] + ".json")
        pred_path.write_text(json.dumps(prediction))
|
||||||
|
|
||||||
|
def load(self, path: data.PathLike) -> List[FileAnnotation]:
|
||||||
|
path = Path(path)
|
||||||
|
|
||||||
|
files = list(path.glob("*.json"))
|
||||||
|
|
||||||
|
if not files:
|
||||||
|
return []
|
||||||
|
|
||||||
|
return [
|
||||||
|
json.loads(file.read_text()) for file in files if file.is_file()
|
||||||
|
]
|
||||||
|
|
||||||
|
def get_recording_class(self, annotations: List[Annotation]) -> str:
|
||||||
|
"""Get class of recording from annotations."""
|
||||||
|
|
||||||
|
if not annotations:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
highest_scoring = max(annotations, key=lambda x: x["class_prob"])
|
||||||
|
return highest_scoring["class"]
|
||||||
|
|
||||||
|
def format_prediction(
|
||||||
|
self, prediction: BatDetect2Prediction
|
||||||
|
) -> FileAnnotation:
|
||||||
|
recording = prediction.clip.recording
|
||||||
|
|
||||||
|
annotations = [
|
||||||
|
self.format_sound_event_prediction(pred)
|
||||||
|
for pred in prediction.predictions
|
||||||
|
]
|
||||||
|
|
||||||
|
return FileAnnotation(
|
||||||
|
id=recording.path.name,
|
||||||
|
file_path=str(recording.path),
|
||||||
|
annotated=False,
|
||||||
|
duration=recording.duration,
|
||||||
|
issues=False,
|
||||||
|
time_exp=recording.time_expansion,
|
||||||
|
class_name=self.get_recording_class(annotations),
|
||||||
|
notes=self.annotation_note,
|
||||||
|
annotation=annotations,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_class_name(self, class_index: int) -> str:
|
||||||
|
class_name = self.targets.class_names[class_index]
|
||||||
|
tags = self.targets.decode_class(class_name)
|
||||||
|
return data.find_tag_value(
|
||||||
|
tags,
|
||||||
|
term=terms.generic_class,
|
||||||
|
default=class_name,
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
def format_sound_event_prediction(
|
||||||
|
self, prediction: RawPrediction
|
||||||
|
) -> Annotation:
|
||||||
|
start_time, low_freq, end_time, high_freq = compute_bounds(
|
||||||
|
prediction.geometry
|
||||||
|
)
|
||||||
|
|
||||||
|
top_class_index = int(np.argmax(prediction.class_scores))
|
||||||
|
top_class_score = float(prediction.class_scores[top_class_index])
|
||||||
|
top_class = self.get_class_name(top_class_index)
|
||||||
|
return Annotation(
|
||||||
|
start_time=start_time,
|
||||||
|
end_time=end_time,
|
||||||
|
low_freq=low_freq,
|
||||||
|
high_freq=high_freq,
|
||||||
|
class_prob=top_class_score,
|
||||||
|
det_prob=float(prediction.detection_score),
|
||||||
|
individual="",
|
||||||
|
event=self.event_name,
|
||||||
|
**{"class": top_class},
|
||||||
|
)
|
||||||
|
|
||||||
|
@prediction_formatters.register(BatDetect2OutputConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
|
||||||
|
return BatDetect2Formatter(
|
||||||
|
targets,
|
||||||
|
event_name=config.event_name,
|
||||||
|
annotation_note=config.annotation_note,
|
||||||
|
)
|
||||||
246
src/batdetect2/data/predictions/raw.py
Normal file
246
src/batdetect2/data/predictions/raw.py
Normal file
@ -0,0 +1,246 @@
|
|||||||
|
from collections import defaultdict
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Literal, Optional, Sequence
|
||||||
|
from uuid import UUID, uuid4
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import xarray as xr
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.data.predictions.base import (
|
||||||
|
make_path_relative,
|
||||||
|
prediction_formatters,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
OutputFormatterProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration for the "raw" output format. Each flag is forwarded to
# ``RawFormatter`` by its ``from_config`` factory and controls whether the
# corresponding variable is written to the saved dataset.
class RawOutputConfig(BaseConfig):
    # Discriminator value identifying this output format.
    name: Literal["raw"] = "raw"

    # Store the full per-class score matrix.
    include_class_scores: bool = True
    # Store the per-detection feature vectors.
    include_features: bool = True
    # Store each detection geometry as a JSON string.
    include_geometry: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
    """Store raw predictions in a netCDF file with one group per clip.

    Every clip becomes a child node of an ``xarray.DataTree`` named after
    the clip UUID. The node dataset is indexed by a ``detection``
    dimension and carries the recording as a JSON string attribute.
    """

    def __init__(
        self,
        targets: TargetProtocol,
        include_class_scores: bool = True,
        include_features: bool = True,
        include_geometry: bool = True,
    ):
        """Store the targets and the flags selecting optional variables.

        Parameters
        ----------
        targets
            Provides ``class_names``, used as the ``classes`` coordinate
            and to map the stored top class back to an index on load.
        include_class_scores
            Write the ``class_scores`` variable.
        include_features
            Write the ``features`` variable.
        include_geometry
            Write the ``geometry`` variable (JSON string per detection).
        """
        self.targets = targets
        self.include_class_scores = include_class_scores
        self.include_features = include_features
        self.include_geometry = include_geometry

    def format(
        self,
        predictions: Sequence[BatDetect2Prediction],
    ) -> List[BatDetect2Prediction]:
        """Return the predictions unchanged, as a new list."""
        return list(predictions)

    @staticmethod
    def _stack_rows(rows: list, width: int) -> np.ndarray:
        """Return ``rows`` as a 2-D array, also when ``rows`` is empty.

        An empty list converts to a 1-D array, which cannot be declared
        with two dimensions; an explicit ``(0, width)`` array is used
        instead.
        """
        if not rows:
            return np.zeros((0, width))
        return np.asarray(rows)

    def save(
        self,
        predictions: Sequence[BatDetect2Prediction],
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> None:
        """Write all predictions to a single ``.nc`` file.

        Parameters
        ----------
        predictions
            Clip predictions to store. Clips without detections are stored
            as groups with an empty ``detection`` dimension.
        path
            Output file. The suffix is replaced by ``.nc`` when different.
        audio_dir
            When given, the recording path stored in each group is passed
            through ``make_path_relative`` together with this directory.
        """
        num_classes = len(self.targets.class_names)
        # Carried across clips so that an empty clip reuses the feature
        # width of the detections seen before it.
        num_features = 0

        tree = xr.DataTree()

        for prediction in predictions:
            clip = prediction.clip
            recording = clip.recording

            if audio_dir is not None:
                recording = recording.model_copy(
                    update=dict(
                        path=make_path_relative(recording.path, audio_dir)
                    )
                )

            clip_data = defaultdict(list)

            for pred in prediction.predictions:
                detection_id = str(uuid4())

                clip_data["detection_id"].append(detection_id)
                clip_data["detection_score"].append(pred.detection_score)

                start_time, low_freq, end_time, high_freq = compute_bounds(
                    pred.geometry
                )

                clip_data["start_time"].append(start_time)
                clip_data["end_time"].append(end_time)
                clip_data["low_freq"].append(low_freq)
                clip_data["high_freq"].append(high_freq)

                clip_data["geometry"].append(pred.geometry.model_dump_json())

                top_class_index = int(np.argmax(pred.class_scores))
                top_class_score = float(pred.class_scores[top_class_index])
                top_class = self.targets.class_names[top_class_index]

                clip_data["top_class"].append(top_class)
                clip_data["top_class_score"].append(top_class_score)

                clip_data["class_scores"].append(pred.class_scores)
                clip_data["features"].append(pred.features)

                num_features = len(pred.features)

            data_vars = {
                "score": (["detection"], clip_data["detection_score"]),
                "start_time": (["detection"], clip_data["start_time"]),
                "end_time": (["detection"], clip_data["end_time"]),
                "low_freq": (["detection"], clip_data["low_freq"]),
                "high_freq": (["detection"], clip_data["high_freq"]),
                "top_class": (["detection"], clip_data["top_class"]),
                "top_class_score": (
                    ["detection"],
                    clip_data["top_class_score"],
                ),
            }

            coords = {
                "detection": ("detection", clip_data["detection_id"]),
                "clip_start": clip.start_time,
                "clip_end": clip.end_time,
                "clip_id": str(clip.uuid),
            }

            if self.include_class_scores:
                data_vars["class_scores"] = (
                    ["detection", "classes"],
                    self._stack_rows(clip_data["class_scores"], num_classes),
                )
                coords["classes"] = ("classes", self.targets.class_names)

            if self.include_features:
                data_vars["features"] = (
                    ["detection", "feature"],
                    self._stack_rows(clip_data["features"], num_features),
                )
                coords["feature"] = ("feature", np.arange(num_features))

            if self.include_geometry:
                data_vars["geometry"] = (["detection"], clip_data["geometry"])

            dataset = xr.Dataset(
                data_vars=data_vars,
                coords=coords,
                attrs={
                    "recording": recording.model_dump_json(exclude_none=True),
                },
            )

            tree = tree.assign(
                {
                    str(clip.uuid): xr.DataTree(
                        dataset=dataset,
                        name=str(clip.uuid),
                    )
                }
            )

        path = Path(path)

        if path.suffix != ".nc":
            path = path.with_suffix(".nc")

        tree.to_netcdf(path)

    def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
        """Rebuild the predictions stored by ``save``.

        Variables that were not written are replaced by fallbacks: a
        bounding box built from the stored bounds when ``geometry`` is
        missing, a zero vector holding only the top class score when
        ``class_scores`` is missing, and an empty array when ``features``
        is missing.
        """
        path = Path(path)

        root = xr.load_datatree(path)

        predictions: List[BatDetect2Prediction] = []

        for _, clip_data in root.items():
            recording = data.Recording.model_validate_json(
                clip_data.attrs["recording"]
            )

            clip_id = clip_data.clip_id.item()
            # ``.item()`` hands plain Python scalars to the data models
            # instead of 0-d DataArray objects.
            clip = data.Clip(
                recording=recording,
                uuid=UUID(clip_id),
                start_time=clip_data.clip_start.item(),
                end_time=clip_data.clip_end.item(),
            )

            sound_events = []

            for detection in clip_data.detection:
                score = clip_data.score.sel(detection=detection).item()

                if "geometry" in clip_data:
                    geometry = data.geometry_validate(
                        clip_data.geometry.sel(detection=detection).item()
                    )
                else:
                    start_time = clip_data.start_time.sel(
                        detection=detection
                    ).item()
                    end_time = clip_data.end_time.sel(
                        detection=detection
                    ).item()
                    low_freq = clip_data.low_freq.sel(
                        detection=detection
                    ).item()
                    high_freq = clip_data.high_freq.sel(
                        detection=detection
                    ).item()
                    geometry = data.BoundingBox(
                        coordinates=[start_time, low_freq, end_time, high_freq]
                    )

                if "class_scores" in clip_data:
                    class_scores = clip_data.class_scores.sel(
                        detection=detection
                    ).data
                else:
                    # Only the top class was stored: every other class
                    # gets a score of zero.
                    class_scores = np.zeros(len(self.targets.class_names))
                    class_index = self.targets.class_names.index(
                        clip_data.top_class.sel(detection=detection).item()
                    )
                    class_scores[class_index] = clip_data.top_class_score.sel(
                        detection=detection
                    ).item()

                if "features" in clip_data:
                    features = clip_data.features.sel(detection=detection).data
                else:
                    features = np.zeros(0)

                sound_events.append(
                    RawPrediction(
                        geometry=geometry,
                        detection_score=score,
                        class_scores=class_scores,
                        features=features,
                    )
                )

            predictions.append(
                BatDetect2Prediction(
                    clip=clip,
                    predictions=sound_events,
                )
            )

        return predictions

    @prediction_formatters.register(RawOutputConfig)
    @staticmethod
    def from_config(config: RawOutputConfig, targets: TargetProtocol):
        """Create a formatter from its configuration object."""
        return RawFormatter(
            targets,
            include_class_scores=config.include_class_scores,
            include_features=config.include_features,
            include_geometry=config.include_geometry,
        )
|
||||||
131
src/batdetect2/data/predictions/soundevent.py
Normal file
131
src/batdetect2/data/predictions/soundevent.py
Normal file
@ -0,0 +1,131 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import List, Literal, Optional, Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from soundevent import data, io
|
||||||
|
|
||||||
|
from batdetect2.core import BaseConfig
|
||||||
|
from batdetect2.data.predictions.base import (
|
||||||
|
prediction_formatters,
|
||||||
|
)
|
||||||
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
|
OutputFormatterProtocol,
|
||||||
|
RawPrediction,
|
||||||
|
TargetProtocol,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Configuration for the "soundevent" output format. Both optional fields
# are forwarded to ``SoundEventOutputFormatter`` by its ``from_config``
# factory.
class SoundEventOutputConfig(BaseConfig):
    # Discriminator value identifying this output format.
    name: Literal["soundevent"] = "soundevent"
    # Number of top scoring classes whose tags are attached per detection.
    top_k: Optional[int] = 1
    # Class score threshold below which class tags are not attached.
    min_score: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
    """Convert predictions into ``soundevent`` clip predictions.

    Each detection becomes a ``SoundEventPrediction`` tagged with the
    detection-level tags and with the tags of its best scoring classes.
    """

    def __init__(
        self,
        targets: TargetProtocol,
        top_k: Optional[int] = 1,
        min_score: Optional[float] = 0,
    ):
        """Store the targets and the class tag selection settings.

        Parameters
        ----------
        targets
            Provides ``detection_class_tags``, ``class_names`` and
            ``decode_class``.
        top_k
            How many of the best scoring classes contribute tags. A falsy
            value (``None`` or ``0``) means all classes.
        min_score
            Classes scoring below this value contribute no tags. ``None``
            disables the threshold.
        """
        self.targets = targets
        self.top_k = top_k
        self.min_score = min_score

    def format(
        self,
        predictions: Sequence[BatDetect2Prediction],
    ) -> List[data.ClipPrediction]:
        """Format every clip prediction, preserving the input order."""
        formatted = []
        for clip_prediction in predictions:
            formatted.append(self.format_prediction(clip_prediction))
        return formatted

    def save(
        self,
        predictions: Sequence[data.ClipPrediction],
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> None:
        """Save the clip predictions as one prediction set file.

        The suffix of ``path`` is replaced by ``.json`` when different.
        ``audio_dir`` is forwarded to ``soundevent.io.save``.
        """
        target = Path(path)
        if target.suffix != ".json":
            target = target.with_suffix(".json")

        prediction_set = data.PredictionSet(
            clip_predictions=list(predictions)
        )
        io.save(prediction_set, target, audio_dir=audio_dir)

    def load(self, path: data.PathLike) -> List[data.ClipPrediction]:
        """Load the clip predictions of a stored prediction set."""
        prediction_set = io.load(Path(path), type="prediction_set")
        return prediction_set.clip_predictions

    def format_prediction(
        self,
        prediction: BatDetect2Prediction,
    ) -> data.ClipPrediction:
        """Build the ``ClipPrediction`` for a single clip."""
        clip = prediction.clip
        sound_events = [
            self.format_sound_event_prediction(raw, clip.recording)
            for raw in prediction.predictions
        ]
        return data.ClipPrediction(clip=clip, sound_events=sound_events)

    def format_sound_event_prediction(
        self,
        prediction: RawPrediction,
        recording: data.Recording,
    ) -> data.SoundEventPrediction:
        """Build the ``SoundEventPrediction`` for a single detection."""
        sound_event = data.SoundEvent(
            recording=recording,
            geometry=prediction.geometry,
        )
        return data.SoundEventPrediction(
            sound_event=sound_event,
            score=prediction.detection_score,
            tags=self.get_sound_event_tags(prediction),
        )

    def get_sound_event_tags(
        self, prediction: RawPrediction
    ) -> List[data.PredictedTag]:
        """Collect the predicted tags of a detection.

        The detection-level tags come first, scored with the detection
        score. They are followed by the tags of the best scoring classes,
        in descending score order, each scored with its class score.
        """
        class_scores = prediction.class_scores
        # Class indices ordered from highest to lowest score.
        ranking = np.argsort(class_scores)[::-1]

        tags = []
        for detection_tag in self.targets.detection_class_tags:
            tags.append(
                data.PredictedTag(
                    tag=detection_tag,
                    score=prediction.detection_score,
                )
            )

        limit = self.top_k if self.top_k else len(ranking)

        for class_index in ranking[:limit]:
            class_score = float(class_scores[class_index])

            # Scores are visited in descending order, so every remaining
            # class is below the threshold as well.
            if self.min_score is not None and class_score < self.min_score:
                break

            class_name = self.targets.class_names[class_index]
            for class_tag in self.targets.decode_class(class_name):
                tags.append(
                    data.PredictedTag(tag=class_tag, score=class_score)
                )

        return tags

    @prediction_formatters.register(SoundEventOutputConfig)
    @staticmethod
    def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
        """Create a formatter from its configuration object."""
        return SoundEventOutputFormatter(
            targets,
            top_k=config.top_k,
            min_score=config.min_score,
        )
|
||||||
@ -159,6 +159,7 @@ def compute_class_summary(
|
|||||||
exclude_generic=False,
|
exclude_generic=False,
|
||||||
exclude_non_target=True,
|
exclude_non_target=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
recordings = extract_recordings_df(dataset)
|
recordings = extract_recordings_df(dataset)
|
||||||
|
|
||||||
num_calls = (
|
num_calls = (
|
||||||
|
|||||||
@ -27,6 +27,28 @@ class EvaluationConfig(BaseConfig):
|
|||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def get_default_eval_config() -> EvaluationConfig:
    """Return the default evaluation configuration.

    The configuration holds two tasks: sound event detection, with a
    precision-recall curve and a score distribution plot, and sound event
    classification, with a precision-recall curve.
    """
    detection_task = {
        "name": "sound_event_detection",
        "plots": [
            {"name": "pr_curve"},
            {"name": "score_distribution"},
        ],
    }
    classification_task = {
        "name": "sound_event_classification",
        "plots": [
            {"name": "pr_curve"},
        ],
    }
    return EvaluationConfig.model_validate(
        {"tasks": [detection_task, classification_task]}
    )
|
||||||
|
|
||||||
|
|
||||||
def load_evaluation_config(
|
def load_evaluation_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Optional, Sequence
|
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
|
||||||
|
|
||||||
from lightning import Trainer
|
from lightning import Trainer
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -12,11 +12,13 @@ from batdetect2.logging import build_logger
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
|
from batdetect2.typing.postprocess import RawPrediction
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
AudioLoader,
|
AudioLoader,
|
||||||
|
OutputFormatterProtocol,
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
@ -31,11 +33,12 @@ def evaluate(
|
|||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
|
formatter: Optional["OutputFormatterProtocol"] = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
):
|
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
config = config or BatDetect2Config()
|
config = config or BatDetect2Config()
|
||||||
@ -66,4 +69,12 @@ def evaluate(
|
|||||||
)
|
)
|
||||||
module = EvaluationModule(model, evaluator)
|
module = EvaluationModule(model, evaluator)
|
||||||
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
||||||
return trainer.test(module, loader)
|
metrics = trainer.test(module, loader)
|
||||||
|
|
||||||
|
if formatter is not None and logger.log_dir is not None:
|
||||||
|
formatter.save(
|
||||||
|
module.predictions,
|
||||||
|
path=Path(logger.log_dir) / "predictions",
|
||||||
|
)
|
||||||
|
|
||||||
|
return metrics, module.predictions # type: ignore
|
||||||
|
|||||||
@ -6,7 +6,8 @@ from soundevent import data
|
|||||||
from batdetect2.evaluate.config import EvaluationConfig
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.evaluate.tasks import build_task
|
from batdetect2.evaluate.tasks import build_task
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol
|
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
||||||
|
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Evaluator",
|
"Evaluator",
|
||||||
@ -26,7 +27,7 @@ class Evaluator:
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[RawPrediction]],
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
) -> List[Any]:
|
) -> List[Any]:
|
||||||
return [
|
return [
|
||||||
task.evaluate(clip_annotations, predictions) for task in self.tasks
|
task.evaluate(clip_annotations, predictions) for task in self.tasks
|
||||||
|
|||||||
@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing import EvaluatorProtocol
|
from batdetect2.typing import EvaluatorProtocol
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||||
|
|
||||||
|
|
||||||
class EvaluationModule(LightningModule):
|
class EvaluationModule(LightningModule):
|
||||||
@ -24,7 +24,7 @@ class EvaluationModule(LightningModule):
|
|||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
|
||||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self.predictions: List[List[RawPrediction]] = []
|
self.predictions: List[BatDetect2Prediction] = []
|
||||||
|
|
||||||
def test_step(self, batch: TestExample, batch_idx: int):
|
def test_step(self, batch: TestExample, batch_idx: int):
|
||||||
dataset = self.get_dataset()
|
dataset = self.get_dataset()
|
||||||
@ -39,11 +39,16 @@ class EvaluationModule(LightningModule):
|
|||||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||||
)
|
)
|
||||||
predictions = [
|
predictions = [
|
||||||
to_raw_predictions(
|
BatDetect2Prediction(
|
||||||
|
clip=clip_annotation.clip,
|
||||||
|
predictions=to_raw_predictions(
|
||||||
clip_dets.numpy(),
|
clip_dets.numpy(),
|
||||||
targets=self.evaluator.targets,
|
targets=self.evaluator.targets,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for clip_annotation, clip_dets in zip(
|
||||||
|
clip_annotations, clip_detections
|
||||||
)
|
)
|
||||||
for clip_dets in clip_detections
|
|
||||||
]
|
]
|
||||||
|
|
||||||
self.clip_annotations.extend(clip_annotations)
|
self.clip_annotations.extend(clip_annotations)
|
||||||
|
|||||||
@ -23,7 +23,7 @@ from batdetect2.evaluate.match import (
|
|||||||
build_matcher,
|
build_matcher,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
|
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -99,7 +99,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[RawPrediction]],
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
) -> List[T_Output]:
|
) -> List[T_Output]:
|
||||||
return [
|
return [
|
||||||
self.evaluate_clip(clip_annotation, preds)
|
self.evaluate_clip(clip_annotation, preds)
|
||||||
@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
predictions: Sequence[RawPrediction],
|
prediction: BatDetect2Prediction,
|
||||||
) -> T_Output: ...
|
) -> T_Output: ...
|
||||||
|
|
||||||
def include_sound_event_annotation(
|
def include_sound_event_annotation(
|
||||||
|
|||||||
@ -1,7 +1,6 @@
|
|||||||
from typing import (
|
from typing import (
|
||||||
List,
|
List,
|
||||||
Literal,
|
Literal,
|
||||||
Sequence,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -23,7 +22,7 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class ClassificationTaskConfig(BaseTaskConfig):
|
class ClassificationTaskConfig(BaseTaskConfig):
|
||||||
@ -49,12 +48,14 @@ class ClassificationTask(BaseTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
predictions: Sequence[RawPrediction],
|
prediction: BatDetect2Prediction,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
preds = [
|
preds = [
|
||||||
pred for pred in predictions if self.include_prediction(pred, clip)
|
pred
|
||||||
|
for pred in prediction.predictions
|
||||||
|
if self.include_prediction(pred, clip)
|
||||||
]
|
]
|
||||||
|
|
||||||
all_gts = [
|
all_gts = [
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List, Literal, Sequence
|
from typing import List, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
|
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||||
|
|
||||||
|
|
||||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||||
@ -37,7 +38,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
predictions: Sequence[RawPrediction],
|
prediction: BatDetect2Prediction,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
@ -54,7 +55,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
|||||||
gt_classes.add(class_name)
|
gt_classes.add(class_name)
|
||||||
|
|
||||||
pred_scores = defaultdict(float)
|
pred_scores = defaultdict(float)
|
||||||
for pred in predictions:
|
for pred in prediction.predictions:
|
||||||
if not self.include_prediction(pred, clip):
|
if not self.include_prediction(pred, clip):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Literal, Sequence
|
from typing import List, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -18,7 +18,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
|
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||||
|
|
||||||
|
|
||||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||||
@ -36,7 +37,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
predictions: Sequence[RawPrediction],
|
prediction: BatDetect2Prediction,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
@ -46,7 +47,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
|||||||
)
|
)
|
||||||
|
|
||||||
pred_score = 0
|
pred_score = 0
|
||||||
for pred in predictions:
|
for pred in prediction.predictions:
|
||||||
if not self.include_prediction(pred, clip):
|
if not self.include_prediction(pred, clip):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Literal, Sequence
|
from typing import List, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
|
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||||
|
|
||||||
|
|
||||||
class DetectionTaskConfig(BaseTaskConfig):
|
class DetectionTaskConfig(BaseTaskConfig):
|
||||||
@ -35,7 +36,7 @@ class DetectionTask(BaseTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
predictions: Sequence[RawPrediction],
|
prediction: BatDetect2Prediction,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
@ -45,7 +46,9 @@ class DetectionTask(BaseTask[ClipEval]):
|
|||||||
if self.include_sound_event_annotation(sound_event, clip)
|
if self.include_sound_event_annotation(sound_event, clip)
|
||||||
]
|
]
|
||||||
preds = [
|
preds = [
|
||||||
pred for pred in predictions if self.include_prediction(pred, clip)
|
pred
|
||||||
|
for pred in prediction.predictions
|
||||||
|
if self.include_prediction(pred, clip)
|
||||||
]
|
]
|
||||||
scores = [pred.detection_score for pred in preds]
|
scores = [pred.detection_score for pred in preds]
|
||||||
|
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
from typing import List, Literal, Sequence
|
from typing import List, Literal
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
|
|||||||
BaseTaskConfig,
|
BaseTaskConfig,
|
||||||
tasks_registry,
|
tasks_registry,
|
||||||
)
|
)
|
||||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
from batdetect2.typing import TargetProtocol
|
||||||
|
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||||
|
|
||||||
|
|
||||||
class TopClassDetectionTaskConfig(BaseTaskConfig):
|
class TopClassDetectionTaskConfig(BaseTaskConfig):
|
||||||
@ -35,7 +36,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
|||||||
def evaluate_clip(
|
def evaluate_clip(
|
||||||
self,
|
self,
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
predictions: Sequence[RawPrediction],
|
prediction: BatDetect2Prediction,
|
||||||
) -> ClipEval:
|
) -> ClipEval:
|
||||||
clip = clip_annotation.clip
|
clip = clip_annotation.clip
|
||||||
|
|
||||||
@ -45,7 +46,9 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
|||||||
if self.include_sound_event_annotation(sound_event, clip)
|
if self.include_sound_event_annotation(sound_event, clip)
|
||||||
]
|
]
|
||||||
preds = [
|
preds = [
|
||||||
pred for pred in predictions if self.include_prediction(pred, clip)
|
pred
|
||||||
|
for pred in prediction.predictions
|
||||||
|
if self.include_prediction(pred, clip)
|
||||||
]
|
]
|
||||||
# Take the highest score for each prediction
|
# Take the highest score for each prediction
|
||||||
scores = [pred.class_scores.max() for pred in preds]
|
scores = [pred.class_scores.max() for pred in preds]
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
import io
|
import io
|
||||||
|
import sys
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -32,6 +33,20 @@ from batdetect2.core.configs import BaseConfig
|
|||||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||||
|
|
||||||
|
|
||||||
|
def enable_logging(level: int):
    """Send ``batdetect2`` log output to stderr at a verbosity-based level.

    The logger is reset with ``logger.remove()`` before a single stderr
    sink is added, and logging is then enabled for the ``"batdetect2"``
    name.

    Parameters
    ----------
    level : int
        Verbosity selector. ``0`` selects ``"WARNING"``, ``1`` selects
        ``"INFO"`` and every other value selects ``"DEBUG"``.
    """
    # NOTE(review): ``logger`` is presumably the loguru logger imported at
    # module level -- confirm against the file's import block.
    threshold = (
        "WARNING" if level == 0 else "INFO" if level == 1 else "DEBUG"
    )

    logger.remove()
    logger.add(sys.stderr, level=threshold)
    logger.enable("batdetect2")
|
||||||
|
|
||||||
|
|
||||||
class BaseLoggerConfig(BaseConfig):
|
class BaseLoggerConfig(BaseConfig):
|
||||||
log_dir: Path = DEFAULT_LOGS_DIR
|
log_dir: Path = DEFAULT_LOGS_DIR
|
||||||
experiment_name: Optional[str] = None
|
experiment_name: Optional[str] = None
|
||||||
|
|||||||
@ -30,10 +30,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import Backbone, build_backbone
|
||||||
Backbone,
|
|
||||||
build_backbone,
|
|
||||||
)
|
|
||||||
from batdetect2.models.blocks import (
|
from batdetect2.models.blocks import (
|
||||||
ConvConfig,
|
ConvConfig,
|
||||||
FreqCoordConvDownConfig,
|
FreqCoordConvDownConfig,
|
||||||
@ -62,16 +59,13 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.typing import (
|
||||||
from batdetect2.preprocess import build_preprocessor
|
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.typing.models import DetectionModel
|
|
||||||
from batdetect2.typing.postprocess import (
|
|
||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
|
DetectionModel,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BBoxHead",
|
"BBoxHead",
|
||||||
@ -133,6 +127,10 @@ def build_model(
|
|||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
postprocessor: Optional[PostprocessorProtocol] = None,
|
postprocessor: Optional[PostprocessorProtocol] = None,
|
||||||
):
|
):
|
||||||
|
from batdetect2.postprocess import build_postprocessor
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
|
|
||||||
config = config or BackboneConfig()
|
config = config or BackboneConfig()
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets()
|
||||||
preprocessor = preprocessor or build_preprocessor()
|
preprocessor = preprocessor or build_preprocessor()
|
||||||
|
|||||||
@ -12,7 +12,7 @@ classification probability maps, size prediction maps, and potentially
|
|||||||
intermediate features.
|
intermediate features.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -30,6 +30,55 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def to_xarray(
    array: Union[torch.Tensor, np.ndarray],
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
    name: str = "xarray",
    extra_dims: Optional[List[str]] = None,
    extra_coords: Optional[Dict[str, np.ndarray]] = None,
) -> xr.DataArray:
    """Wrap an array in an ``xr.DataArray`` with frequency/time coordinates.

    The last axis of ``array`` is labelled as time and the second-to-last
    as frequency. Any leading axes are kept and named via ``extra_dims``.

    Parameters
    ----------
    array : Union[torch.Tensor, np.ndarray]
        Data with at least two dimensions. Torch tensors are detached,
        moved to the CPU and converted to NumPy first.
    start_time, end_time : float
        Bounds of the time coordinate. It is evenly spaced with one value
        per column and excludes ``end_time``.
    min_freq, max_freq : float
        Bounds of the frequency coordinate. It is evenly spaced with one
        value per row and excludes ``max_freq``.
    name : str
        Name given to the resulting ``DataArray``.
    extra_dims : Optional[List[str]]
        Names for the leading dimensions. Defaults to ``dim_0``,
        ``dim_1``, ... (one per leading axis).
    extra_coords : Optional[Dict[str, np.ndarray]]
        Additional coordinates. Entries keyed by the frequency or time
        dimension names are overridden by the generated coordinates.

    Returns
    -------
    xr.DataArray
        The labelled array.

    Raises
    ------
    ValueError
        If ``array`` has fewer than two dimensions.
    """
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()

    num_leading = array.ndim - 2
    if num_leading < 0:
        raise ValueError(
            "Input array must have at least 2 dimensions, "
            f"got shape {array.shape}"
        )

    num_freq_bins, num_time_bins = array.shape[-2], array.shape[-1]

    freq_dim = Dimensions.frequency.value
    time_dim = Dimensions.time.value

    leading_names = (
        [f"dim_{index}" for index in range(num_leading)]
        if extra_dims is None
        else extra_dims
    )
    other_coords = {} if extra_coords is None else extra_coords

    # endpoint=False: each coordinate marks the start of its bin, so the
    # upper bound itself is never included.
    time_coord = np.linspace(
        start_time,
        end_time,
        num_time_bins,
        endpoint=False,
    )
    freq_coord = np.linspace(
        min_freq,
        max_freq,
        num_freq_bins,
        endpoint=False,
    )

    return xr.DataArray(
        data=array,
        dims=[*leading_names, freq_dim, time_dim],
        coords={
            **other_coords,
            freq_dim: freq_coord,
            time_dim: time_coord,
        },
        name=name,
    )
|
||||||
|
|
||||||
|
|
||||||
def map_detection_to_clip(
|
def map_detection_to_clip(
|
||||||
detections: ClipDetectionsTensor,
|
detections: ClipDetectionsTensor,
|
||||||
start_time: float,
|
start_time: float,
|
||||||
|
|||||||
@ -10,9 +10,9 @@ from batdetect2.postprocess import to_raw_predictions
|
|||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
|
BatDetect2Prediction,
|
||||||
EvaluatorProtocol,
|
EvaluatorProtocol,
|
||||||
ModelOutput,
|
ModelOutput,
|
||||||
RawPrediction,
|
|
||||||
TrainExample,
|
TrainExample,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -24,7 +24,7 @@ class ValidationMetrics(Callback):
|
|||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
|
||||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||||
self._predictions: List[List[RawPrediction]] = []
|
self._predictions: List[BatDetect2Prediction] = []
|
||||||
|
|
||||||
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
|
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
|
||||||
dataloaders = trainer.val_dataloaders
|
dataloaders = trainer.val_dataloaders
|
||||||
@ -100,8 +100,15 @@ class ValidationMetrics(Callback):
|
|||||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||||
)
|
)
|
||||||
predictions = [
|
predictions = [
|
||||||
to_raw_predictions(clip_dets.numpy(), targets=model.targets)
|
BatDetect2Prediction(
|
||||||
for clip_dets in clip_detections
|
clip=clip_annotation.clip,
|
||||||
|
predictions=to_raw_predictions(
|
||||||
|
clip_dets.numpy(), targets=model.targets
|
||||||
|
),
|
||||||
|
)
|
||||||
|
for clip_annotation, clip_dets in zip(
|
||||||
|
clip_annotations, clip_detections
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
self._clip_annotations.extend(clip_annotations)
|
self._clip_annotations.extend(clip_annotations)
|
||||||
|
|||||||
@ -6,11 +6,10 @@ from soundevent.data import PathLike
|
|||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model, build_model
|
||||||
from batdetect2.plotting.clips import build_preprocessor
|
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.targets.targets import build_targets
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.typing import ModelOutput, TrainExample
|
from batdetect2.typing import ModelOutput, TrainExample
|
||||||
|
|
||||||
@ -27,20 +26,20 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: "BatDetect2Config",
|
config: Optional[dict] = None,
|
||||||
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
|
||||||
learning_rate: float = 0.001,
|
|
||||||
t_max: int = 100,
|
t_max: int = 100,
|
||||||
model: Optional[Model] = None,
|
model: Optional[Model] = None,
|
||||||
loss: Optional[torch.nn.Module] = None,
|
loss: Optional[torch.nn.Module] = None,
|
||||||
):
|
):
|
||||||
|
from batdetect2.config import validate_config
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.save_hyperparameters(logger=False)
|
self.save_hyperparameters(logger=False)
|
||||||
|
|
||||||
self.input_samplerate = input_samplerate
|
self.config = validate_config(config)
|
||||||
self.config = config
|
self.input_samplerate = self.config.audio.samplerate
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = self.config.train.optimizer.learning_rate
|
||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
|
|
||||||
if loss is None:
|
if loss is None:
|
||||||
@ -104,14 +103,7 @@ def load_model_from_checkpoint(
|
|||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional[dict] = None,
|
||||||
t_max: int = 200,
|
t_max: int = 200,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
from batdetect2.config import BatDetect2Config
|
return TrainingModule(config=config, t_max=t_max)
|
||||||
|
|
||||||
config = config or BatDetect2Config()
|
|
||||||
return TrainingModule(
|
|
||||||
config=config,
|
|
||||||
learning_rate=config.train.optimizer.learning_rate,
|
|
||||||
t_max=t_max,
|
|
||||||
)
|
|
||||||
|
|||||||
@ -97,7 +97,7 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
config,
|
config.model_dump(mode="json"),
|
||||||
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from batdetect2.typing.data import OutputFormatterProtocol
|
||||||
from batdetect2.typing.evaluate import (
|
from batdetect2.typing.evaluate import (
|
||||||
AffinityFunction,
|
AffinityFunction,
|
||||||
ClipMatches,
|
ClipMatches,
|
||||||
@ -10,6 +11,7 @@ from batdetect2.typing.evaluate import (
|
|||||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||||
from batdetect2.typing.postprocess import (
|
from batdetect2.typing.postprocess import (
|
||||||
BatDetect2Prediction,
|
BatDetect2Prediction,
|
||||||
|
ClipDetectionsTensor,
|
||||||
GeometryDecoder,
|
GeometryDecoder,
|
||||||
PostprocessorProtocol,
|
PostprocessorProtocol,
|
||||||
RawPrediction,
|
RawPrediction,
|
||||||
@ -43,8 +45,9 @@ __all__ = [
|
|||||||
"Augmentation",
|
"Augmentation",
|
||||||
"BackboneModel",
|
"BackboneModel",
|
||||||
"BatDetect2Prediction",
|
"BatDetect2Prediction",
|
||||||
"ClipMatches",
|
"ClipDetectionsTensor",
|
||||||
"ClipLabeller",
|
"ClipLabeller",
|
||||||
|
"ClipMatches",
|
||||||
"ClipperProtocol",
|
"ClipperProtocol",
|
||||||
"DetectionModel",
|
"DetectionModel",
|
||||||
"EvaluatorProtocol",
|
"EvaluatorProtocol",
|
||||||
@ -56,6 +59,7 @@ __all__ = [
|
|||||||
"MatcherProtocol",
|
"MatcherProtocol",
|
||||||
"MetricsProtocol",
|
"MetricsProtocol",
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
|
"OutputFormatterProtocol",
|
||||||
"PlotterProtocol",
|
"PlotterProtocol",
|
||||||
"Position",
|
"Position",
|
||||||
"PostprocessorProtocol",
|
"PostprocessorProtocol",
|
||||||
|
|||||||
26
src/batdetect2/typing/data.py
Normal file
26
src/batdetect2/typing/data.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from typing import Generic, List, Optional, Protocol, Sequence, TypeVar
|
||||||
|
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"OutputFormatterProtocol",
|
||||||
|
]
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class OutputFormatterProtocol(Protocol, Generic[T]):
    """Structural interface for formatting, saving and loading predictions.

    ``T`` is the formatter-specific item type: ``format`` produces a list
    of ``T``, ``save`` consumes a sequence of ``T`` and ``load`` returns a
    list of ``T``.
    """

    def format(
        self,
        predictions: Sequence[BatDetect2Prediction],
    ) -> List[T]:
        """Convert a sequence of predictions into a list of ``T`` items."""
        ...

    def save(
        self,
        predictions: Sequence[T],
        path: PathLike,
        audio_dir: Optional[PathLike] = None,
    ) -> None:
        """Write formatted items to ``path``.

        NOTE(review): the role of ``audio_dir`` is not visible from this
        protocol -- presumably a base directory for audio file paths;
        confirm against the implementations.
        """
        ...

    def load(self, path: PathLike) -> List[T]:
        """Read items back from ``path`` and return them as a list."""
        ...
|
||||||
@ -14,7 +14,7 @@ from typing import (
|
|||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.typing.postprocess import RawPrediction
|
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
|
||||||
from batdetect2.typing.targets import TargetProtocol
|
from batdetect2.typing.targets import TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -115,7 +115,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
clip_annotations: Sequence[data.ClipAnnotation],
|
||||||
predictions: Sequence[Sequence[RawPrediction]],
|
predictions: Sequence[BatDetect2Prediction],
|
||||||
) -> EvaluationOutput: ...
|
) -> EvaluationOutput: ...
|
||||||
|
|
||||||
def compute_metrics(
|
def compute_metrics(
|
||||||
|
|||||||
@ -4,16 +4,15 @@ import lightning as L
|
|||||||
import torch
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.models import build_model
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.train import FullTrainingConfig, TrainingModule
|
from batdetect2.train import TrainingModule
|
||||||
from batdetect2.train.train import build_training_module
|
from batdetect2.train.train import build_training_module
|
||||||
from batdetect2.typing.preprocess import AudioLoader
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
|
|
||||||
|
|
||||||
def build_default_module():
|
def build_default_module():
|
||||||
model = build_model()
|
config = BatDetect2Config()
|
||||||
config = FullTrainingConfig()
|
return build_training_module(config=config.model_dump())
|
||||||
return build_training_module(model, config=config)
|
|
||||||
|
|
||||||
|
|
||||||
def test_can_initialize_default_module():
|
def test_can_initialize_default_module():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user