mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Compare commits
3 Commits
6d0a73dda6
...
8366410332
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8366410332 | ||
|
|
aa4ad68958 | ||
|
|
401a3832ce |
@ -7,30 +7,30 @@ authors = [
|
||||
{ "name" = "Santiago Martinez Balvanera", "email" = "santiago.balvanera.20@ucl.ac.uk" },
|
||||
]
|
||||
dependencies = [
|
||||
"cf-xarray>=0.9.0",
|
||||
"click>=8.1.7",
|
||||
"deepmerge>=2.0",
|
||||
"hydra-core>=1.3.2",
|
||||
"librosa>=0.10.1",
|
||||
"lightning[extra]==2.5.0",
|
||||
"loguru>=0.7.3",
|
||||
"matplotlib>=3.7.1",
|
||||
"netcdf4>=1.6.5",
|
||||
"numba>=0.60",
|
||||
"numpy>=1.23.5",
|
||||
"omegaconf>=2.3.0",
|
||||
"onnx>=1.16.0",
|
||||
"pandas>=1.5.3",
|
||||
"pyyaml>=6.0.2",
|
||||
"scikit-learn>=1.2.2",
|
||||
"scipy>=1.10.1",
|
||||
"seaborn>=0.13.2",
|
||||
"soundevent[audio,geometry,plot]>=2.9.1",
|
||||
"tensorboard>=2.16.2",
|
||||
"torch>=1.13.1,<2.5.0",
|
||||
"torchaudio>=1.13.1,<2.5.0",
|
||||
"torchvision>=0.14.0",
|
||||
"soundevent[audio,geometry,plot]>=2.9.1",
|
||||
"click>=8.1.7",
|
||||
"netcdf4>=1.6.5",
|
||||
"tqdm>=4.66.2",
|
||||
"cf-xarray>=0.9.0",
|
||||
"onnx>=1.16.0",
|
||||
"lightning[extra]==2.5.0",
|
||||
"tensorboard>=2.16.2",
|
||||
"omegaconf>=2.3.0",
|
||||
"pyyaml>=6.0.2",
|
||||
"hydra-core>=1.3.2",
|
||||
"numba>=0.60",
|
||||
"loguru>=0.7.3",
|
||||
"deepmerge>=2.0",
|
||||
]
|
||||
requires-python = ">=3.9,<3.13"
|
||||
readme = "README.md"
|
||||
@ -89,9 +89,7 @@ dev = [
|
||||
"python-lsp-server>=1.13.0",
|
||||
]
|
||||
dvclive = ["dvclive>=3.48.2"]
|
||||
mlflow = [
|
||||
"mlflow>=3.1.1",
|
||||
]
|
||||
mlflow = ["mlflow>=3.1.1"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
@ -115,9 +113,7 @@ exclude = [
|
||||
"src/batdetect2/detector/",
|
||||
"src/batdetect2/finetune",
|
||||
"src/batdetect2/utils",
|
||||
"src/batdetect2/plotting",
|
||||
"src/batdetect2/plot",
|
||||
"src/batdetect2/api",
|
||||
"src/batdetect2/evaluate/legacy",
|
||||
"src/batdetect2/train/legacy",
|
||||
]
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Sequence
|
||||
from typing import Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -9,6 +10,14 @@ from soundevent.audio.files import get_audio_files
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.data import (
|
||||
OutputFormatConfig,
|
||||
build_output_formatter,
|
||||
get_output_formatter,
|
||||
load_dataset_from_config,
|
||||
)
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.data.predictions.base import OutputFormatterProtocol
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
@ -41,6 +50,7 @@ class BatDetect2API:
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
evaluator: EvaluatorProtocol,
|
||||
formatter: OutputFormatterProtocol,
|
||||
model: Model,
|
||||
):
|
||||
self.config = config
|
||||
@ -50,9 +60,17 @@ class BatDetect2API:
|
||||
self.postprocessor = postprocessor
|
||||
self.evaluator = evaluator
|
||||
self.model = model
|
||||
self.formatter = formatter
|
||||
|
||||
self.model.eval()
|
||||
|
||||
def load_annotations(
    self,
    path: data.PathLike,
    base_dir: Optional[data.PathLike] = None,
) -> Dataset:
    """Load the clip annotations described by a dataset config file.

    Convenience wrapper that forwards its arguments to
    ``load_dataset_from_config``.

    Parameters
    ----------
    path
        Location of the dataset configuration file.
    base_dir
        Passed through unchanged as the ``base_dir`` keyword of
        ``load_dataset_from_config``.

    Returns
    -------
    Dataset
        Whatever ``load_dataset_from_config`` returns for the given file.
    """
    dataset = load_dataset_from_config(path, base_dir=base_dir)
    return dataset
|
||||
|
||||
def train(
|
||||
self,
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
@ -91,7 +109,8 @@ class BatDetect2API:
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
):
|
||||
save_predictions: bool = True,
|
||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||
return evaluate(
|
||||
self.model,
|
||||
test_annotations,
|
||||
@ -103,8 +122,41 @@ class BatDetect2API:
|
||||
output_dir=output_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
formatter=self.formatter if save_predictions else None,
|
||||
)
|
||||
|
||||
def evaluate_predictions(
    self,
    annotations: Sequence[data.ClipAnnotation],
    predictions: Sequence[BatDetect2Prediction],
    output_dir: Optional[data.PathLike] = None,
):
    """Score existing predictions against annotations.

    The configured evaluator matches ``predictions`` to ``annotations``
    and reduces the per-clip results to a metrics mapping. When
    ``output_dir`` is given, the metrics are additionally written to
    ``metrics.json`` and every figure yielded by the evaluator is saved
    below that directory.

    Parameters
    ----------
    annotations
        Ground-truth clip annotations.
    predictions
        Predictions to evaluate.
    output_dir
        Optional directory for ``metrics.json`` and the generated plots.
        It is created (including parents) when missing.

    Returns
    -------
    The metrics object produced by ``self.evaluator.compute_metrics``.
    """
    clip_evals = self.evaluator.evaluate(
        annotations,
        predictions,
    )

    metrics = self.evaluator.compute_metrics(clip_evals)

    if output_dir is None:
        return metrics

    output_dir = Path(output_dir)

    # ``exist_ok=True`` replaces the separate ``is_dir()`` probe, so a
    # directory created concurrently no longer triggers FileExistsError.
    output_dir.mkdir(parents=True, exist_ok=True)

    metrics_path = output_dir / "metrics.json"
    metrics_path.write_text(json.dumps(metrics))

    for figure_name, fig in self.evaluator.generate_plots(clip_evals):
        fig_path = output_dir / figure_name

        # Figure names may contain sub-directories of their own.
        fig_path.parent.mkdir(parents=True, exist_ok=True)

        fig.savefig(fig_path)

    return metrics
|
||||
|
||||
def load_audio(self, path: data.PathLike) -> np.ndarray:
    """Read the audio file at ``path`` with the API's audio loader."""
    loader = self.audio_loader
    return loader.load_file(path)
|
||||
|
||||
@ -194,8 +246,38 @@ class BatDetect2API:
|
||||
config=self.config,
|
||||
)
|
||||
|
||||
def save_predictions(
    self,
    predictions: Sequence[BatDetect2Prediction],
    path: data.PathLike,
    audio_dir: Optional[data.PathLike] = None,
    format: Optional[str] = None,
    config: Optional[OutputFormatConfig] = None,
):
    """Format ``predictions`` and write them to ``path``.

    The API's own formatter is used unless ``format`` and/or ``config``
    is supplied, in which case a formatter is obtained from
    ``get_output_formatter`` for this call only.

    Parameters
    ----------
    predictions
        Predictions to store.
    path
        Destination handed to the formatter's ``save``.
    audio_dir
        Forwarded to the formatter's ``save`` as ``audio_dir``.
    format
        Name of the output format to use instead of the default one.
    config
        Output format configuration; when ``format`` is omitted its
        ``name`` selects the format.
    """
    formatter = self.formatter

    override_requested = not (format is None and config is None)

    if override_requested:
        # With only a config given, the config's name picks the format.
        format = format or config.name  # type: ignore
        formatter = get_output_formatter(
            name=format,
            targets=self.targets,
            config=config,
        )

    formatted = formatter.format(predictions)
    formatter.save(formatted, audio_dir=audio_dir, path=path)
|
||||
|
||||
def load_predictions(
    self,
    path: data.PathLike,
) -> List[BatDetect2Prediction]:
    """Read stored predictions from ``path`` with the API's formatter."""
    loaded = self.formatter.load(path)
    return loaded
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: BatDetect2Config):
|
||||
def from_config(
|
||||
cls,
|
||||
config: BatDetect2Config,
|
||||
):
|
||||
targets = build_targets(config=config.targets)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
@ -228,6 +310,8 @@ class BatDetect2API:
|
||||
),
|
||||
)
|
||||
|
||||
formatter = build_output_formatter(targets, config=config.output)
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
targets=targets,
|
||||
@ -236,6 +320,7 @@ class BatDetect2API:
|
||||
postprocessor=postprocessor,
|
||||
evaluator=evaluator,
|
||||
model=model,
|
||||
formatter=formatter,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -266,6 +351,8 @@ class BatDetect2API:
|
||||
|
||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||
|
||||
formatter = build_output_formatter(targets, config=config.output)
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
targets=targets,
|
||||
@ -274,4 +361,5 @@ class BatDetect2API:
|
||||
postprocessor=postprocessor,
|
||||
evaluator=evaluator,
|
||||
model=model,
|
||||
formatter=formatter,
|
||||
)
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
"""BatDetect2 command line interface."""
|
||||
|
||||
import sys
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.logging import enable_logging
|
||||
|
||||
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
||||
|
||||
@ -27,22 +26,9 @@ BatDetect2 - Detection and Classification
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def cli(
|
||||
verbose: int = 0,
|
||||
):
|
||||
def cli(verbose: int = 0):
|
||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||
click.echo(INFO_STR)
|
||||
|
||||
logger.remove()
|
||||
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
logger.enable("batdetect2")
|
||||
enable_logging(verbose)
|
||||
# click.echo(BATDETECT_ASCII_ART)
|
||||
|
||||
@ -43,3 +43,55 @@ def summary(
|
||||
)
|
||||
|
||||
print(f"Number of annotated clips: {len(dataset)}")
|
||||
|
||||
|
||||
@data.command()
@click.argument("dataset_config", type=click.Path(exists=True))
@click.option(
    "--field",
    type=str,
    help="If the dataset info is in a nested field please specify here.",
)
@click.option(
    "--output",
    type=click.Path(exists=False),
    default="annotations.json",
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help="The base directory to which all recording and annotations paths are relative to.",
)
def convert(
    dataset_config: Path,
    field: Optional[str] = None,
    output: Path = Path("annotations.json"),
    base_dir: Optional[Path] = None,
):
    """Convert a dataset config file to soundevent format."""
    # Imported here rather than at module level, as in the original.
    from soundevent import data, io

    from batdetect2.data import load_dataset, load_dataset_config

    # Without --base-dir, paths are taken relative to the working directory.
    if not base_dir:
        base_dir = Path.cwd()

    config = load_dataset_config(dataset_config, field=field)
    clip_annotations = list(load_dataset(config, base_dir=base_dir))

    # The annotation set carries over the config's name and description.
    annotation_set = data.AnnotationSet(
        clip_annotations=clip_annotations,
        name=config.name,
        description=config.description,
    )
    io.save(annotation_set, output)
|
||||
|
||||
@ -4,9 +4,13 @@ from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.core.configs import load_config
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.data.predictions import OutputFormatConfig
|
||||
from batdetect2.data.predictions.raw import RawOutputConfig
|
||||
from batdetect2.evaluate.config import (
|
||||
EvaluationConfig,
|
||||
get_default_eval_config,
|
||||
)
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.models.config import BackboneConfig
|
||||
from batdetect2.postprocess.config import PostprocessConfig
|
||||
@ -17,6 +21,7 @@ from batdetect2.train.config import TrainingConfig
|
||||
__all__ = [
|
||||
"BatDetect2Config",
|
||||
"load_full_config",
|
||||
"validate_config",
|
||||
]
|
||||
|
||||
|
||||
@ -24,7 +29,9 @@ class BatDetect2Config(BaseConfig):
|
||||
config_version: Literal["v1"] = "v1"
|
||||
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||
evaluation: EvaluationConfig = Field(
|
||||
default_factory=get_default_eval_config
|
||||
)
|
||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||
preprocess: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
@ -33,6 +40,14 @@ class BatDetect2Config(BaseConfig):
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
||||
|
||||
|
||||
def validate_config(config: Optional[dict]) -> BatDetect2Config:
    """Turn an optional plain mapping into a ``BatDetect2Config``.

    ``None`` yields a configuration built entirely from defaults; any
    other value is passed to ``BatDetect2Config.model_validate``.
    """
    if config is not None:
        return BatDetect2Config.model_validate(config)

    return BatDetect2Config()
|
||||
|
||||
|
||||
def load_full_config(
|
||||
|
||||
@ -29,7 +29,7 @@ class BaseConfig(BaseModel):
|
||||
and serialization capabilities.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
model_config = ConfigDict(extra="ignore")
|
||||
|
||||
def to_yaml_string(
|
||||
self,
|
||||
|
||||
@ -77,6 +77,15 @@ class Registry(Generic[T_Type, P_Type]):
|
||||
def get_config_types(self) -> Tuple[Type[BaseModel], ...]:
|
||||
return tuple(self._config_types.values())
|
||||
|
||||
def get_config_type(self, name: str) -> Type[BaseModel]:
    """Return the config class registered under ``name``.

    Raises
    ------
    ValueError
        If nothing is registered under ``name``; the message lists the
        registered names and the original ``KeyError`` is chained.
    """
    try:
        config_type = self._config_types[name]
    except KeyError as err:
        message = (
            f"No config type with name '{name}' is registered. "
            f"Existing config types: {list(self._config_types.keys())}"
        )
        raise ValueError(message) from err
    else:
        return config_type
|
||||
|
||||
def build(
|
||||
self,
|
||||
config: BaseModel,
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
from batdetect2.data.annotations import (
|
||||
AnnotatedDataset,
|
||||
AnnotationFormats,
|
||||
AOEFAnnotations,
|
||||
BatDetect2FilesAnnotations,
|
||||
BatDetect2MergedAnnotations,
|
||||
@ -11,6 +12,14 @@ from batdetect2.data.datasets import (
|
||||
load_dataset_config,
|
||||
load_dataset_from_config,
|
||||
)
|
||||
from batdetect2.data.predictions import (
|
||||
BatDetect2OutputConfig,
|
||||
OutputFormatConfig,
|
||||
RawOutputConfig,
|
||||
SoundEventOutputConfig,
|
||||
build_output_formatter,
|
||||
get_output_formatter,
|
||||
)
|
||||
from batdetect2.data.summary import (
|
||||
compute_class_summary,
|
||||
extract_recordings_df,
|
||||
@ -20,12 +29,19 @@ from batdetect2.data.summary import (
|
||||
__all__ = [
|
||||
"AOEFAnnotations",
|
||||
"AnnotatedDataset",
|
||||
"AnnotationFormats",
|
||||
"BatDetect2FilesAnnotations",
|
||||
"BatDetect2MergedAnnotations",
|
||||
"BatDetect2OutputConfig",
|
||||
"DatasetConfig",
|
||||
"OutputFormatConfig",
|
||||
"RawOutputConfig",
|
||||
"SoundEventOutputConfig",
|
||||
"build_output_formatter",
|
||||
"compute_class_summary",
|
||||
"extract_recordings_df",
|
||||
"extract_sound_events_df",
|
||||
"get_output_formatter",
|
||||
"load_annotated_dataset",
|
||||
"load_dataset",
|
||||
"load_dataset_config",
|
||||
|
||||
@ -13,7 +13,6 @@ format-specific loading function to retrieve the annotations as a standard
|
||||
`soundevent.data.AnnotationSet`.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
@ -64,7 +63,7 @@ source configuration represents.
|
||||
|
||||
def load_annotated_dataset(
|
||||
dataset: AnnotatedDataset,
|
||||
base_dir: Optional[Path] = None,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
) -> data.AnnotationSet:
|
||||
"""Load annotations for a single data source based on its configuration.
|
||||
|
||||
@ -97,6 +96,7 @@ def load_annotated_dataset(
|
||||
known format-specific loading functions implemented in the dispatch
|
||||
logic.
|
||||
"""
|
||||
|
||||
if isinstance(dataset, AOEFAnnotations):
|
||||
return load_aoef_annotated_dataset(dataset, base_dir=base_dir)
|
||||
|
||||
|
||||
@ -84,7 +84,7 @@ class AOEFAnnotations(AnnotatedDataset):
|
||||
|
||||
def load_aoef_annotated_dataset(
|
||||
dataset: AOEFAnnotations,
|
||||
base_dir: Optional[Path] = None,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
) -> data.AnnotationSet:
|
||||
"""Load annotations from an AnnotationSet or AnnotationProject file.
|
||||
|
||||
|
||||
@ -19,7 +19,7 @@ The core components are:
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Sequence
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field
|
||||
@ -53,7 +53,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
Dataset = List[data.ClipAnnotation]
|
||||
Dataset = Sequence[data.ClipAnnotation]
|
||||
"""Type alias for a loaded dataset representation.
|
||||
|
||||
Represents an entire dataset *after loading* as a flat Python list containing
|
||||
@ -77,7 +77,7 @@ class DatasetConfig(BaseConfig):
|
||||
|
||||
def load_dataset(
|
||||
config: DatasetConfig,
|
||||
base_dir: Optional[Path] = None,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
) -> Dataset:
|
||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||
clip_annotations = []
|
||||
@ -168,7 +168,7 @@ def load_dataset_config(path: data.PathLike, field: Optional[str] = None):
|
||||
def load_dataset_from_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
base_dir: Optional[Path] = None,
|
||||
base_dir: Optional[data.PathLike] = None,
|
||||
) -> Dataset:
|
||||
"""Load dataset annotation metadata from a configuration file.
|
||||
|
||||
@ -250,6 +250,6 @@ def save_dataset(
|
||||
annotation_set = data.AnnotationSet(
|
||||
name=name,
|
||||
description=description,
|
||||
clip_annotations=dataset,
|
||||
clip_annotations=list(dataset),
|
||||
)
|
||||
io.save(annotation_set, path, audio_dir=audio_dir)
|
||||
|
||||
58
src/batdetect2/data/predictions/__init__.py
Normal file
58
src/batdetect2/data/predictions/__init__.py
Normal file
@ -0,0 +1,58 @@
|
||||
from typing import Annotated, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from batdetect2.data.predictions.base import (
|
||||
OutputFormatterProtocol,
|
||||
prediction_formatters,
|
||||
)
|
||||
from batdetect2.data.predictions.batdetect2 import BatDetect2OutputConfig
|
||||
from batdetect2.data.predictions.raw import RawOutputConfig
|
||||
from batdetect2.data.predictions.soundevent import SoundEventOutputConfig
|
||||
from batdetect2.typing import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"build_output_formatter",
|
||||
"get_output_formatter",
|
||||
"BatDetect2OutputConfig",
|
||||
"RawOutputConfig",
|
||||
"SoundEventOutputConfig",
|
||||
]
|
||||
|
||||
|
||||
OutputFormatConfig = Annotated[
|
||||
Union[BatDetect2OutputConfig, SoundEventOutputConfig, RawOutputConfig],
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_output_formatter(
    targets: Optional[TargetProtocol] = None,
    config: Optional[OutputFormatConfig] = None,
) -> OutputFormatterProtocol:
    """Construct the final output formatter.

    A missing ``config`` falls back to ``RawOutputConfig()`` and missing
    ``targets`` fall back to ``build_targets()``; the formatter itself is
    created by the ``prediction_formatters`` registry.
    """
    # Imported inside the function, as in the original.
    from batdetect2.targets import build_targets

    if not config:
        config = RawOutputConfig()

    if not targets:
        targets = build_targets()

    return prediction_formatters.build(config, targets)
|
||||
|
||||
|
||||
def get_output_formatter(
    name: str,
    targets: Optional[TargetProtocol] = None,
    config: Optional[OutputFormatConfig] = None,
) -> OutputFormatterProtocol:
    """Get the output formatter by name.

    Without a ``config``, the config class registered under ``name`` is
    instantiated with its defaults. A supplied ``config`` must carry the
    same ``name``.

    Raises
    ------
    ValueError
        If ``config.name`` differs from ``name``.
    """
    if config is None:
        config = prediction_formatters.get_config_type(name)()  # type: ignore

    names_agree = config.name == name  # type: ignore

    if names_agree:
        return build_output_formatter(targets, config)

    raise ValueError(
        f"Config name {config.name} does not match formatter name {name}"  # type: ignore
    )
|
||||
29
src/batdetect2/data/predictions/base.py
Normal file
29
src/batdetect2/data/predictions/base.py
Normal file
@ -0,0 +1,29 @@
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.core import Registry
|
||||
from batdetect2.typing import (
|
||||
OutputFormatterProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
def make_path_relative(path: PathLike, audio_dir: PathLike) -> Path:
    """Express ``path`` relative to ``audio_dir``.

    Relative inputs are returned as they are (converted to ``Path``).
    Absolute inputs must lie inside ``audio_dir``.

    Raises
    ------
    ValueError
        If ``path`` is absolute but not located under ``audio_dir``.
    """
    candidate = Path(path)
    root = Path(audio_dir)

    if not candidate.is_absolute():
        return candidate

    if candidate.is_relative_to(root):
        return candidate.relative_to(root)

    raise ValueError(
        f"Audio file {candidate} is not in audio_dir {root}"
    )
|
||||
|
||||
|
||||
prediction_formatters: Registry[OutputFormatterProtocol, [TargetProtocol]] = (
|
||||
Registry(name="output_formatter")
|
||||
)
|
||||
227
src/batdetect2/data/predictions/batdetect2.py
Normal file
227
src/batdetect2/data/predictions/batdetect2.py
Normal file
@ -0,0 +1,227 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Sequence, TypedDict
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.data.predictions.base import (
|
||||
make_path_relative,
|
||||
prediction_formatters,
|
||||
)
|
||||
from batdetect2.targets import terms
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
OutputFormatterProtocol,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
try:
|
||||
from typing import NotRequired # type: ignore
|
||||
except ImportError:
|
||||
from typing_extensions import NotRequired
|
||||
|
||||
DictWithClass = TypedDict("DictWithClass", {"class": str})
|
||||
|
||||
|
||||
class Annotation(DictWithClass):
    """A single annotated sound event.

    Mirrors the per-annotation layout expected by the annotation tool.
    The ``class`` key is inherited from ``DictWithClass`` because it
    cannot be declared with class syntax.
    """

    start_time: float
    """Onset of the event, in seconds."""

    end_time: float
    """Offset of the event, in seconds."""

    low_freq: float
    """Lower frequency bound, in Hz."""

    high_freq: float
    """Upper frequency bound, in Hz."""

    class_prob: float
    """Probability assigned to the predicted class."""

    det_prob: float
    """Detection probability."""

    individual: str
    """Identifier of the individual."""

    event: str
    """Kind of event that was detected."""
|
||||
|
||||
|
||||
class FileAnnotation(TypedDict):
    """Per-file result record.

    Mirrors the results layout expected by the annotation tool.
    """

    id: str
    """Identifier of the file."""

    annotated: bool
    """True once the file has been annotated."""

    duration: float
    """Length of the audio file."""

    issues: bool
    """True when the file has issues."""

    time_exp: float
    """Time expansion factor of the recording."""

    class_name: str
    """Class predicted for the file as a whole."""

    notes: str
    """Free-text notes attached to the file."""

    annotation: List[Annotation]
    """The individual annotations found in the file."""

    file_path: NotRequired[str]
    """Path to the file (optional key)."""
|
||||
|
||||
|
||||
class BatDetect2OutputConfig(BaseConfig):
    """Settings for the ``"batdetect2"`` output format."""

    # Fixed tag identifying this output format.
    name: Literal["batdetect2"] = "batdetect2"

    # Value written to the ``event`` field of every annotation.
    event_name: str = "Echolocation"

    # Value written to the ``notes`` field of every file record.
    annotation_note: str = "Automatically generated."
|
||||
|
||||
|
||||
class BatDetect2Formatter(OutputFormatterProtocol[FileAnnotation]):
    """Formatter that stores predictions as one JSON file per recording.

    Each prediction is converted to a ``FileAnnotation`` dictionary and
    written to ``<path>/<id>.json``, where ``id`` is the file name of the
    recording.
    """

    def __init__(
        self,
        targets: TargetProtocol,
        event_name: str,
        annotation_note: str,
    ):
        """Store the targets and the constant field values.

        Parameters
        ----------
        targets
            Provides ``class_names`` and ``decode_class`` for naming the
            predicted classes.
        event_name
            Written to the ``event`` field of every annotation.
        annotation_note
            Written to the ``notes`` field of every file record.
        """
        self.targets = targets
        self.event_name = event_name
        self.annotation_note = annotation_note

    def format(
        self, predictions: Sequence[BatDetect2Prediction]
    ) -> List[FileAnnotation]:
        """Convert every prediction into a ``FileAnnotation``."""
        return [
            self.format_prediction(prediction) for prediction in predictions
        ]

    def save(
        self,
        predictions: Sequence[FileAnnotation],
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> None:
        """Write each record to ``<path>/<id>.json``.

        Parameters
        ----------
        predictions
            Records as produced by ``format``. They are not modified.
        path
            Output directory; created (including parents) when missing.
        audio_dir
            When given, the ``file_path`` stored in each file is made
            relative to this directory.

        Raises
        ------
        ValueError
            If a record's absolute ``file_path`` is outside ``audio_dir``.
        """
        path = Path(path)

        # ``exist_ok=True`` replaces the separate ``is_dir()`` probe and
        # so avoids failing when the directory appears in between.
        path.mkdir(parents=True, exist_ok=True)

        for prediction in predictions:
            # Shallow copy: rewriting ``file_path`` must not leak into the
            # caller's dictionaries.
            record = dict(prediction)

            if audio_dir is not None and "file_path" in record:
                record["file_path"] = str(
                    make_path_relative(
                        record["file_path"],
                        audio_dir,
                    )
                )

            pred_path = path / (record["id"] + ".json")
            pred_path.write_text(json.dumps(record))

    def load(self, path: data.PathLike) -> List[FileAnnotation]:
        """Read every ``*.json`` file directly inside ``path``.

        Files are read in sorted order so the result is deterministic.
        An empty list is returned when no JSON file is found.
        """
        path = Path(path)

        return [
            json.loads(file.read_text())
            for file in sorted(path.glob("*.json"))
            if file.is_file()
        ]

    def get_recording_class(self, annotations: List[Annotation]) -> str:
        """Get class of recording from annotations.

        Returns the class of the annotation with the highest
        ``class_prob``, or an empty string when there are no annotations.
        """
        if not annotations:
            return ""

        highest_scoring = max(annotations, key=lambda x: x["class_prob"])
        return highest_scoring["class"]

    def format_prediction(
        self, prediction: BatDetect2Prediction
    ) -> FileAnnotation:
        """Build the file-level record for one clip prediction."""
        recording = prediction.clip.recording

        annotations = [
            self.format_sound_event_prediction(pred)
            for pred in prediction.predictions
        ]

        return FileAnnotation(
            id=recording.path.name,
            file_path=str(recording.path),
            annotated=False,
            duration=recording.duration,
            issues=False,
            time_exp=recording.time_expansion,
            class_name=self.get_recording_class(annotations),
            notes=self.annotation_note,
            annotation=annotations,
        )

    def get_class_name(self, class_index: int) -> str:
        """Return the name to report for the class at ``class_index``.

        The class name is decoded into tags and the value of the
        ``generic_class`` term is used; the raw class name is the
        fallback when no such tag exists.
        """
        class_name = self.targets.class_names[class_index]
        tags = self.targets.decode_class(class_name)
        return data.find_tag_value(
            tags,
            term=terms.generic_class,
            default=class_name,
        )  # type: ignore

    def format_sound_event_prediction(
        self, prediction: RawPrediction
    ) -> Annotation:
        """Build the annotation dictionary for a single detection."""
        start_time, low_freq, end_time, high_freq = compute_bounds(
            prediction.geometry
        )

        top_class_index = int(np.argmax(prediction.class_scores))
        top_class_score = float(prediction.class_scores[top_class_index])
        top_class = self.get_class_name(top_class_index)
        return Annotation(
            start_time=start_time,
            end_time=end_time,
            low_freq=low_freq,
            high_freq=high_freq,
            class_prob=top_class_score,
            det_prob=float(prediction.detection_score),
            individual="",
            event=self.event_name,
            # ``class`` is a keyword, so it has to be passed via unpacking.
            **{"class": top_class},
        )

    @prediction_formatters.register(BatDetect2OutputConfig)
    @staticmethod
    def from_config(config: BatDetect2OutputConfig, targets: TargetProtocol):
        """Create a formatter from its config; registered as its builder."""
        return BatDetect2Formatter(
            targets,
            event_name=config.event_name,
            annotation_note=config.annotation_note,
        )
|
||||
246
src/batdetect2/data/predictions/raw.py
Normal file
246
src/batdetect2/data/predictions/raw.py
Normal file
@ -0,0 +1,246 @@
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Sequence
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.data.predictions.base import (
|
||||
make_path_relative,
|
||||
prediction_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
OutputFormatterProtocol,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
class RawOutputConfig(BaseConfig):
    """Settings for the ``"raw"`` output format."""

    # Fixed tag identifying this output format.
    name: Literal["raw"] = "raw"

    # Optional parts of the output; all enabled by default.
    # NOTE(review): presumably forwarded to RawFormatter's flags of the
    # same names - confirm against the registered builder.
    include_class_scores: bool = True
    include_features: bool = True
    include_geometry: bool = True
|
||||
|
||||
|
||||
class RawFormatter(OutputFormatterProtocol[BatDetect2Prediction]):
|
||||
def __init__(
    self,
    targets: TargetProtocol,
    include_class_scores: bool = True,
    include_features: bool = True,
    include_geometry: bool = True,
):
    """Remember the targets and which optional parts are enabled.

    Parameters
    ----------
    targets
        Kept on the instance as ``self.targets``.
    include_class_scores
        Stored flag for the ``class_scores`` part of the output.
    include_features
        Stored flag for the ``features`` part of the output.
    include_geometry
        Stored flag for the ``geometry`` part of the output.
    """
    self.targets = targets

    # Plain flags; they are only stored here.
    (
        self.include_class_scores,
        self.include_features,
        self.include_geometry,
    ) = (include_class_scores, include_features, include_geometry)
|
||||
|
||||
def format(
    self,
    predictions: Sequence[BatDetect2Prediction],
) -> List[BatDetect2Prediction]:
    """Return the predictions unchanged, copied into a new list."""
    return [*predictions]
|
||||
|
||||
def save(
    self,
    predictions: Sequence[BatDetect2Prediction],
    path: data.PathLike,
    audio_dir: Optional[data.PathLike] = None,
) -> None:
    """Write predictions to a netCDF file, one tree node per clip.

    For every prediction a dataset is built along a ``detection``
    dimension (score, time/frequency bounds, top class and its score)
    and attached to the tree under the clip's UUID. Class scores,
    features and the JSON geometry are added according to the
    ``include_*`` flags. The recording is stored as JSON in the
    dataset's ``recording`` attribute.

    Clips without any detection are supported: their two-dimensional
    variables are written with zero rows.

    Parameters
    ----------
    predictions
        Predictions to store.
    path
        Output file; the suffix is replaced by ``.nc`` when different.
    audio_dir
        When given, the stored recording path is made relative to it.

    Raises
    ------
    ValueError
        If a recording's absolute path is outside ``audio_dir``.
    """
    class_names = self.targets.class_names

    # Width of the feature axis. It is kept across clips so that a clip
    # without detections reuses the width seen so far.
    num_features = 0

    tree = xr.DataTree()

    for prediction in predictions:
        clip = prediction.clip
        recording = clip.recording

        if audio_dir is not None:
            recording = recording.model_copy(
                update=dict(
                    path=make_path_relative(recording.path, audio_dir)
                )
            )

        clip_data = defaultdict(list)

        for pred in prediction.predictions:
            # Detections get a fresh identifier on every save.
            detection_id = str(uuid4())

            clip_data["detection_id"].append(detection_id)
            clip_data["detection_score"].append(pred.detection_score)

            start_time, low_freq, end_time, high_freq = compute_bounds(
                pred.geometry
            )

            clip_data["start_time"].append(start_time)
            clip_data["end_time"].append(end_time)
            clip_data["low_freq"].append(low_freq)
            clip_data["high_freq"].append(high_freq)

            clip_data["geometry"].append(pred.geometry.model_dump_json())

            top_class_index = int(np.argmax(pred.class_scores))
            top_class_score = float(pred.class_scores[top_class_index])
            top_class = class_names[top_class_index]

            clip_data["top_class"].append(top_class)
            clip_data["top_class_score"].append(top_class_score)

            clip_data["class_scores"].append(pred.class_scores)
            clip_data["features"].append(pred.features)

            num_features = len(pred.features)

        num_detections = len(clip_data["detection_id"])

        data_vars = {
            "score": (["detection"], clip_data["detection_score"]),
            "start_time": (["detection"], clip_data["start_time"]),
            "end_time": (["detection"], clip_data["end_time"]),
            "low_freq": (["detection"], clip_data["low_freq"]),
            "high_freq": (["detection"], clip_data["high_freq"]),
            "top_class": (["detection"], clip_data["top_class"]),
            "top_class_score": (
                ["detection"],
                clip_data["top_class_score"],
            ),
        }

        coords = {
            "detection": ("detection", clip_data["detection_id"]),
            "clip_start": clip.start_time,
            "clip_end": clip.end_time,
            "clip_id": str(clip.uuid),
        }

        if self.include_class_scores:
            # The explicit reshape keeps the array two-dimensional even
            # with zero detections; an empty list would be 1-D and would
            # not match the two declared dimensions.
            data_vars["class_scores"] = (
                ["detection", "classes"],
                np.asarray(clip_data["class_scores"]).reshape(
                    num_detections, len(class_names)
                ),
            )
            coords["classes"] = ("classes", class_names)

        if self.include_features:
            data_vars["features"] = (
                ["detection", "feature"],
                np.asarray(clip_data["features"]).reshape(
                    num_detections, num_features
                ),
            )
            coords["feature"] = ("feature", np.arange(num_features))

        if self.include_geometry:
            data_vars["geometry"] = (["detection"], clip_data["geometry"])

        dataset = xr.Dataset(
            data_vars=data_vars,
            coords=coords,
            attrs={
                "recording": recording.model_dump_json(exclude_none=True),
            },
        )

        tree = tree.assign(
            {
                str(clip.uuid): xr.DataTree(
                    dataset=dataset,
                    name=str(clip.uuid),
                )
            }
        )

    path = Path(path)

    if path.suffix != ".nc":
        path = path.with_suffix(".nc")

    tree.to_netcdf(path)
|
||||
|
||||
def load(self, path: data.PathLike) -> List[BatDetect2Prediction]:
    """Read predictions back from a NetCDF data tree.

    Every child node of the tree is read as one clip: the recording is
    restored from the JSON stored in the node's ``recording`` attribute
    and the clip from the ``clip_id``, ``clip_start`` and ``clip_end``
    coordinates. One ``RawPrediction`` is built per entry of the
    ``detection`` coordinate.

    Optional variables are handled as follows:

    * ``geometry`` missing: a bounding box is rebuilt from ``start_time``,
      ``low_freq``, ``end_time`` and ``high_freq``.
    * ``class_scores`` missing: a zero vector over
      ``self.targets.class_names`` is created and only the entry of the
      stored ``top_class`` is filled with ``top_class_score``.
    * ``features`` missing: an empty array is used.

    Parameters
    ----------
    path
        Location of the NetCDF file to read.

    Returns
    -------
    List[BatDetect2Prediction]
        One prediction object per clip node found in the file.
    """
    root = xr.load_datatree(Path(path))

    predictions: List[BatDetect2Prediction] = []

    for _, clip_data in root.items():
        recording = data.Recording.model_validate_json(
            clip_data.attrs["recording"]
        )

        # Scalar coordinates are 0-d arrays; ``.item()`` hands plain Python
        # values to the model, as is already done for ``clip_id``.
        clip = data.Clip(
            recording=recording,
            uuid=UUID(clip_data.clip_id.item()),
            start_time=clip_data.clip_start.item(),
            end_time=clip_data.clip_end.item(),
        )

        # Which optional variables exist does not change between
        # detections of the same clip, so check once per clip.
        has_geometry = "geometry" in clip_data
        has_class_scores = "class_scores" in clip_data
        has_features = "features" in clip_data

        sound_events = []

        for detection in clip_data.detection:
            score = clip_data.score.sel(detection=detection).item()

            if has_geometry:
                geometry = data.geometry_validate(
                    clip_data.geometry.sel(detection=detection).item()
                )
            else:
                start_time = clip_data.start_time.sel(
                    detection=detection
                ).item()
                end_time = clip_data.end_time.sel(
                    detection=detection
                ).item()
                low_freq = clip_data.low_freq.sel(
                    detection=detection
                ).item()
                high_freq = clip_data.high_freq.sel(
                    detection=detection
                ).item()
                geometry = data.BoundingBox(
                    coordinates=[start_time, low_freq, end_time, high_freq]
                )

            if has_class_scores:
                class_scores = clip_data.class_scores.sel(
                    detection=detection
                ).data
            else:
                # Only the top class was stored: rebuild a score vector
                # that is zero everywhere except at that class.
                class_scores = np.zeros(len(self.targets.class_names))
                class_index = self.targets.class_names.index(
                    clip_data.top_class.sel(detection=detection).item()
                )
                class_scores[class_index] = clip_data.top_class_score.sel(
                    detection=detection
                ).item()

            if has_features:
                features = clip_data.features.sel(detection=detection).data
            else:
                features = np.zeros(0)

            sound_events.append(
                RawPrediction(
                    geometry=geometry,
                    detection_score=score,
                    class_scores=class_scores,
                    features=features,
                )
            )

        predictions.append(
            BatDetect2Prediction(
                clip=clip,
                predictions=sound_events,
            )
        )

    return predictions
|
||||
|
||||
@prediction_formatters.register(RawOutputConfig)
@staticmethod
def from_config(config: RawOutputConfig, targets: TargetProtocol):
    """Create a ``RawFormatter`` from a ``RawOutputConfig``.

    The three ``include_*`` switches are copied from the config and the
    targets object is forwarded unchanged.
    """
    # NOTE(review): ``register`` is applied to the ``staticmethod`` object
    # itself; such objects are only directly callable on Python >= 3.10.
    # Confirm how the registry invokes the builder if 3.9 is supported.
    option_names = (
        "include_class_scores",
        "include_features",
        "include_geometry",
    )
    options = {option: getattr(config, option) for option in option_names}
    return RawFormatter(targets, **options)
|
||||
131
src/batdetect2/data/predictions/soundevent.py
Normal file
131
src/batdetect2/data/predictions/soundevent.py
Normal file
@ -0,0 +1,131 @@
|
||||
from pathlib import Path
|
||||
from typing import List, Literal, Optional, Sequence
|
||||
|
||||
import numpy as np
|
||||
from soundevent import data, io
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
from batdetect2.data.predictions.base import (
|
||||
prediction_formatters,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
OutputFormatterProtocol,
|
||||
RawPrediction,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
class SoundEventOutputConfig(BaseConfig):
    """Settings for the ``soundevent`` prediction output format."""

    # Fixed literal identifying this output format.
    name: Literal["soundevent"] = "soundevent"

    # Forwarded to ``SoundEventOutputFormatter`` as ``top_k``.
    top_k: Optional[int] = 1

    # Forwarded to ``SoundEventOutputFormatter`` as ``min_score``.
    min_score: Optional[float] = None
|
||||
|
||||
|
||||
class SoundEventOutputFormatter(OutputFormatterProtocol[data.ClipPrediction]):
    """Format BatDetect2 predictions as soundevent ``ClipPrediction`` objects.

    Parameters
    ----------
    targets
        Provides ``class_names``, ``detection_class_tags`` and
        ``decode_class`` used to turn scores into predicted tags.
    top_k
        Maximum number of highest-scoring classes to emit tags for.
        ``None`` means no limit.
    min_score
        Classes scoring below this value get no tags. ``None`` disables
        the threshold.
    """

    def __init__(
        self,
        targets: TargetProtocol,
        top_k: Optional[int] = 1,
        min_score: Optional[float] = 0,
    ):
        self.targets = targets
        self.top_k = top_k
        self.min_score = min_score

    def format(
        self,
        predictions: Sequence[BatDetect2Prediction],
    ) -> List[data.ClipPrediction]:
        """Convert each clip-level prediction into a ``ClipPrediction``."""
        return [
            self.format_prediction(prediction) for prediction in predictions
        ]

    def save(
        self,
        predictions: Sequence[data.ClipPrediction],
        path: data.PathLike,
        audio_dir: Optional[data.PathLike] = None,
    ) -> None:
        """Write the predictions as a prediction set to a ``.json`` file.

        If ``path`` does not end in ``.json`` its suffix is replaced with
        ``.json`` before saving. ``audio_dir`` is forwarded to ``io.save``.
        """
        run = data.PredictionSet(clip_predictions=list(predictions))

        path = Path(path)
        if path.suffix != ".json":
            path = path.with_suffix(".json")

        io.save(run, path, audio_dir=audio_dir)

    def load(self, path: data.PathLike) -> List[data.ClipPrediction]:
        """Load a saved prediction set and return its clip predictions."""
        run = io.load(Path(path), type="prediction_set")
        return run.clip_predictions

    def format_prediction(
        self,
        prediction: BatDetect2Prediction,
    ) -> data.ClipPrediction:
        """Build the ``ClipPrediction`` for a single clip."""
        recording = prediction.clip.recording
        return data.ClipPrediction(
            clip=prediction.clip,
            sound_events=[
                self.format_sound_event_prediction(pred, recording)
                for pred in prediction.predictions
            ],
        )

    def format_sound_event_prediction(
        self,
        prediction: RawPrediction,
        recording: data.Recording,
    ) -> data.SoundEventPrediction:
        """Wrap one raw detection as a ``SoundEventPrediction``.

        The detection score becomes the prediction score and the tags come
        from ``get_sound_event_tags``.
        """
        return data.SoundEventPrediction(
            sound_event=data.SoundEvent(
                recording=recording,
                geometry=prediction.geometry,
            ),
            score=prediction.detection_score,
            tags=self.get_sound_event_tags(prediction),
        )

    def get_sound_event_tags(
        self, prediction: RawPrediction
    ) -> List[data.PredictedTag]:
        """Return the predicted tags for one raw detection.

        The detection-level tags (``targets.detection_class_tags``) are
        always included, scored with the detection score. They are followed
        by the decoded tags of the best classes, in descending score order,
        limited by ``top_k`` and ``min_score``.
        """
        # Class indices ordered from highest to lowest score.
        ranked = np.argsort(prediction.class_scores)[::-1]

        # Only ``None`` means "no limit"; an explicit ``top_k=0`` selects
        # no class at all instead of silently selecting every class.
        if self.top_k is not None:
            ranked = ranked[: self.top_k]

        tags = [
            data.PredictedTag(
                tag=tag,
                score=prediction.detection_score,
            )
            for tag in self.targets.detection_class_tags
        ]

        for ind in ranked:
            score = float(prediction.class_scores[ind])

            # Scores are visited in descending order, so every remaining
            # class is below the threshold as well.
            if self.min_score is not None and score < self.min_score:
                break

            class_name = self.targets.class_names[ind]
            class_tags = self.targets.decode_class(class_name)
            tags.extend(
                data.PredictedTag(
                    tag=tag,
                    score=score,
                )
                for tag in class_tags
            )

        return tags

    # NOTE(review): ``register`` is applied to the ``staticmethod`` object
    # itself; such objects are only directly callable on Python >= 3.10.
    # Confirm how the registry invokes the builder if 3.9 is supported.
    @prediction_formatters.register(SoundEventOutputConfig)
    @staticmethod
    def from_config(config: SoundEventOutputConfig, targets: TargetProtocol):
        """Create a formatter from a ``SoundEventOutputConfig``."""
        return SoundEventOutputFormatter(
            targets,
            top_k=config.top_k,
            min_score=config.min_score,
        )
|
||||
@ -159,6 +159,7 @@ def compute_class_summary(
|
||||
exclude_generic=False,
|
||||
exclude_non_target=True,
|
||||
)
|
||||
|
||||
recordings = extract_recordings_df(dataset)
|
||||
|
||||
num_calls = (
|
||||
|
||||
@ -27,6 +27,28 @@ class EvaluationConfig(BaseConfig):
|
||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||
|
||||
|
||||
def get_default_eval_config() -> EvaluationConfig:
    """Return an ``EvaluationConfig`` with two preset tasks.

    The config contains a sound event detection task (precision-recall
    curve and score distribution plots) and a sound event classification
    task (precision-recall curve plot), validated through
    ``EvaluationConfig.model_validate``.
    """
    detection_task = {
        "name": "sound_event_detection",
        "plots": [
            {"name": "pr_curve"},
            {"name": "score_distribution"},
        ],
    }
    classification_task = {
        "name": "sound_event_classification",
        "plots": [
            {"name": "pr_curve"},
        ],
    }
    raw_config = {"tasks": [detection_task, classification_task]}
    return EvaluationConfig.model_validate(raw_config)
|
||||
|
||||
|
||||
def load_evaluation_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Optional, Sequence
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
|
||||
|
||||
from lightning import Trainer
|
||||
from soundevent import data
|
||||
@ -12,11 +12,13 @@ from batdetect2.logging import build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing.postprocess import RawPrediction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
OutputFormatterProtocol,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
@ -31,11 +33,12 @@ def evaluate(
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
formatter: Optional["OutputFormatterProtocol"] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
):
|
||||
) -> Tuple[Dict[str, float], List[List[RawPrediction]]]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
config = config or BatDetect2Config()
|
||||
@ -66,4 +69,12 @@ def evaluate(
|
||||
)
|
||||
module = EvaluationModule(model, evaluator)
|
||||
trainer = Trainer(logger=logger, enable_checkpointing=False)
|
||||
return trainer.test(module, loader)
|
||||
metrics = trainer.test(module, loader)
|
||||
|
||||
if formatter is not None and logger.log_dir is not None:
|
||||
formatter.save(
|
||||
module.predictions,
|
||||
path=Path(logger.log_dir) / "predictions",
|
||||
)
|
||||
|
||||
return metrics, module.predictions # type: ignore
|
||||
|
||||
@ -6,7 +6,8 @@ from soundevent import data
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.evaluate.tasks import build_task
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing import EvaluatorProtocol, RawPrediction, TargetProtocol
|
||||
from batdetect2.typing import EvaluatorProtocol, TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
|
||||
__all__ = [
|
||||
"Evaluator",
|
||||
@ -26,7 +27,7 @@ class Evaluator:
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
) -> List[Any]:
|
||||
return [
|
||||
task.evaluate(clip_annotations, predictions) for task in self.tasks
|
||||
|
||||
@ -9,7 +9,7 @@ from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.typing import EvaluatorProtocol
|
||||
from batdetect2.typing.postprocess import RawPrediction
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
|
||||
|
||||
class EvaluationModule(LightningModule):
|
||||
@ -24,7 +24,7 @@ class EvaluationModule(LightningModule):
|
||||
self.evaluator = evaluator
|
||||
|
||||
self.clip_annotations: List[data.ClipAnnotation] = []
|
||||
self.predictions: List[List[RawPrediction]] = []
|
||||
self.predictions: List[BatDetect2Prediction] = []
|
||||
|
||||
def test_step(self, batch: TestExample, batch_idx: int):
|
||||
dataset = self.get_dataset()
|
||||
@ -39,11 +39,16 @@ class EvaluationModule(LightningModule):
|
||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||
)
|
||||
predictions = [
|
||||
to_raw_predictions(
|
||||
clip_dets.numpy(),
|
||||
targets=self.evaluator.targets,
|
||||
BatDetect2Prediction(
|
||||
clip=clip_annotation.clip,
|
||||
predictions=to_raw_predictions(
|
||||
clip_dets.numpy(),
|
||||
targets=self.evaluator.targets,
|
||||
),
|
||||
)
|
||||
for clip_annotation, clip_dets in zip(
|
||||
clip_annotations, clip_detections
|
||||
)
|
||||
for clip_dets in clip_detections
|
||||
]
|
||||
|
||||
self.clip_annotations.extend(clip_annotations)
|
||||
|
||||
@ -23,7 +23,7 @@ from batdetect2.evaluate.match import (
|
||||
build_matcher,
|
||||
)
|
||||
from batdetect2.typing.evaluate import EvaluatorProtocol, MatcherProtocol
|
||||
from batdetect2.typing.postprocess import RawPrediction
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -99,7 +99,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
) -> List[T_Output]:
|
||||
return [
|
||||
self.evaluate_clip(clip_annotation, preds)
|
||||
@ -109,7 +109,7 @@ class BaseTask(EvaluatorProtocol, Generic[T_Output]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
prediction: BatDetect2Prediction,
|
||||
) -> T_Output: ...
|
||||
|
||||
def include_sound_event_annotation(
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from typing import (
|
||||
List,
|
||||
Literal,
|
||||
Sequence,
|
||||
)
|
||||
|
||||
from pydantic import Field
|
||||
@ -23,7 +22,7 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
from batdetect2.typing import BatDetect2Prediction, TargetProtocol
|
||||
|
||||
|
||||
class ClassificationTaskConfig(BaseTaskConfig):
|
||||
@ -49,12 +48,14 @@ class ClassificationTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
prediction: BatDetect2Prediction,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
preds = [
|
||||
pred for pred in predictions if self.include_prediction(pred, clip)
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
|
||||
all_gts = [
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from collections import defaultdict
|
||||
from typing import List, Literal, Sequence
|
||||
from typing import List, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
|
||||
|
||||
class ClipClassificationTaskConfig(BaseTaskConfig):
|
||||
@ -37,7 +38,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
prediction: BatDetect2Prediction,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -54,7 +55,7 @@ class ClipClassificationTask(BaseTask[ClipEval]):
|
||||
gt_classes.add(class_name)
|
||||
|
||||
pred_scores = defaultdict(float)
|
||||
for pred in predictions:
|
||||
for pred in prediction.predictions:
|
||||
if not self.include_prediction(pred, clip):
|
||||
continue
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Literal, Sequence
|
||||
from typing import List, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -18,7 +18,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
|
||||
|
||||
class ClipDetectionTaskConfig(BaseTaskConfig):
|
||||
@ -36,7 +37,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
prediction: BatDetect2Prediction,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -46,7 +47,7 @@ class ClipDetectionTask(BaseTask[ClipEval]):
|
||||
)
|
||||
|
||||
pred_score = 0
|
||||
for pred in predictions:
|
||||
for pred in prediction.predictions:
|
||||
if not self.include_prediction(pred, clip):
|
||||
continue
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Literal, Sequence
|
||||
from typing import List, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
|
||||
|
||||
class DetectionTaskConfig(BaseTaskConfig):
|
||||
@ -35,7 +36,7 @@ class DetectionTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
prediction: BatDetect2Prediction,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -45,7 +46,9 @@ class DetectionTask(BaseTask[ClipEval]):
|
||||
if self.include_sound_event_annotation(sound_event, clip)
|
||||
]
|
||||
preds = [
|
||||
pred for pred in predictions if self.include_prediction(pred, clip)
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
scores = [pred.detection_score for pred in preds]
|
||||
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import List, Literal, Sequence
|
||||
from typing import List, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -19,7 +19,8 @@ from batdetect2.evaluate.tasks.base import (
|
||||
BaseTaskConfig,
|
||||
tasks_registry,
|
||||
)
|
||||
from batdetect2.typing import RawPrediction, TargetProtocol
|
||||
from batdetect2.typing import TargetProtocol
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
|
||||
|
||||
class TopClassDetectionTaskConfig(BaseTaskConfig):
|
||||
@ -35,7 +36,7 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
def evaluate_clip(
|
||||
self,
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
predictions: Sequence[RawPrediction],
|
||||
prediction: BatDetect2Prediction,
|
||||
) -> ClipEval:
|
||||
clip = clip_annotation.clip
|
||||
|
||||
@ -45,7 +46,9 @@ class TopClassDetectionTask(BaseTask[ClipEval]):
|
||||
if self.include_sound_event_annotation(sound_event, clip)
|
||||
]
|
||||
preds = [
|
||||
pred for pred in predictions if self.include_prediction(pred, clip)
|
||||
pred
|
||||
for pred in prediction.predictions
|
||||
if self.include_prediction(pred, clip)
|
||||
]
|
||||
# Take the highest score for each prediction
|
||||
scores = [pred.class_scores.max() for pred in preds]
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import io
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
@ -32,6 +33,20 @@ from batdetect2.core.configs import BaseConfig
|
||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||
|
||||
|
||||
def enable_logging(level: int):
    """Send ``batdetect2`` log messages to stderr at the chosen verbosity.

    ``level`` 0 selects ``WARNING``, 1 selects ``INFO`` and any other value
    selects ``DEBUG``. Existing handlers are removed from ``logger`` before
    the stderr handler is added.
    """
    # Drop whatever handlers are currently registered (presumably the
    # loguru default sink; confirm against the module imports).
    logger.remove()

    log_level = (
        "WARNING" if level == 0 else "INFO" if level == 1 else "DEBUG"
    )

    logger.add(sys.stderr, level=log_level)
    logger.enable("batdetect2")
|
||||
|
||||
|
||||
class BaseLoggerConfig(BaseConfig):
|
||||
log_dir: Path = DEFAULT_LOGS_DIR
|
||||
experiment_name: Optional[str] = None
|
||||
|
||||
@ -30,10 +30,7 @@ from typing import List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.models.backbones import (
|
||||
Backbone,
|
||||
build_backbone,
|
||||
)
|
||||
from batdetect2.models.backbones import Backbone, build_backbone
|
||||
from batdetect2.models.blocks import (
|
||||
ConvConfig,
|
||||
FreqCoordConvDownConfig,
|
||||
@ -62,16 +59,13 @@ from batdetect2.models.encoder import (
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.typing.models import DetectionModel
|
||||
from batdetect2.typing.postprocess import (
|
||||
from batdetect2.typing import (
|
||||
ClipDetectionsTensor,
|
||||
DetectionModel,
|
||||
PostprocessorProtocol,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.typing.preprocess import PreprocessorProtocol
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
"BBoxHead",
|
||||
@ -133,6 +127,10 @@ def build_model(
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
postprocessor: Optional[PostprocessorProtocol] = None,
|
||||
):
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
|
||||
config = config or BackboneConfig()
|
||||
targets = targets or build_targets()
|
||||
preprocessor = preprocessor or build_preprocessor()
|
||||
|
||||
@ -12,7 +12,7 @@ classification probability maps, size prediction maps, and potentially
|
||||
intermediate features.
|
||||
"""
|
||||
|
||||
from typing import List
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -30,6 +30,55 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
def to_xarray(
    array: Union[torch.Tensor, np.ndarray],
    start_time: float,
    end_time: float,
    min_freq: float = MIN_FREQ,
    max_freq: float = MAX_FREQ,
    name: str = "xarray",
    extra_dims: Optional[List[str]] = None,
    extra_coords: Optional[Dict[str, np.ndarray]] = None,
) -> xr.DataArray:
    """Wrap an array as a labelled ``xr.DataArray`` over frequency and time.

    The last axis is labelled as time and the second-to-last as frequency.
    Coordinates are evenly spaced over ``[start_time, end_time)`` and
    ``[min_freq, max_freq)`` (the end point is excluded). Any leading axes
    are named by ``extra_dims`` or, when that is ``None``, ``dim_0``,
    ``dim_1``, and so on.

    Parameters
    ----------
    array
        Tensor or array with at least two dimensions. Tensors are detached
        and moved to the CPU before conversion to numpy.
    start_time, end_time
        Range covered by the time axis.
    min_freq, max_freq
        Range covered by the frequency axis.
    name
        Name given to the resulting data array.
    extra_dims
        Names for the axes that precede frequency and time.
    extra_coords
        Additional coordinates to attach to the data array.

    Raises
    ------
    ValueError
        If ``array`` has fewer than two dimensions.
    """
    if isinstance(array, torch.Tensor):
        array = array.detach().cpu().numpy()

    num_leading = array.ndim - 2
    if num_leading < 0:
        raise ValueError(
            "Input array must have at least 2 dimensions, "
            f"got shape {array.shape}"
        )

    height, width = array.shape[-2], array.shape[-1]
    time_axis = np.linspace(start_time, end_time, width, endpoint=False)
    freq_axis = np.linspace(min_freq, max_freq, height, endpoint=False)

    if extra_dims is None:
        leading_dims = [f"dim_{i}" for i in range(num_leading)]
    else:
        leading_dims = extra_dims

    freq_dim = Dimensions.frequency.value
    time_dim = Dimensions.time.value

    # Frequency and time are written last so they take precedence over
    # any same-named entries in ``extra_coords``.
    coords = {} if extra_coords is None else dict(extra_coords)
    coords[freq_dim] = freq_axis
    coords[time_dim] = time_axis

    return xr.DataArray(
        data=array,
        dims=[*leading_dims, freq_dim, time_dim],
        coords=coords,
        name=name,
    )
|
||||
|
||||
|
||||
def map_detection_to_clip(
|
||||
detections: ClipDetectionsTensor,
|
||||
start_time: float,
|
||||
|
||||
@ -10,9 +10,9 @@ from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.train.dataset import ValidationDataset
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
from batdetect2.typing import (
|
||||
BatDetect2Prediction,
|
||||
EvaluatorProtocol,
|
||||
ModelOutput,
|
||||
RawPrediction,
|
||||
TrainExample,
|
||||
)
|
||||
|
||||
@ -24,7 +24,7 @@ class ValidationMetrics(Callback):
|
||||
self.evaluator = evaluator
|
||||
|
||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||
self._predictions: List[List[RawPrediction]] = []
|
||||
self._predictions: List[BatDetect2Prediction] = []
|
||||
|
||||
def get_dataset(self, trainer: Trainer) -> ValidationDataset:
|
||||
dataloaders = trainer.val_dataloaders
|
||||
@ -100,8 +100,15 @@ class ValidationMetrics(Callback):
|
||||
start_times=[ca.clip.start_time for ca in clip_annotations],
|
||||
)
|
||||
predictions = [
|
||||
to_raw_predictions(clip_dets.numpy(), targets=model.targets)
|
||||
for clip_dets in clip_detections
|
||||
BatDetect2Prediction(
|
||||
clip=clip_annotation.clip,
|
||||
predictions=to_raw_predictions(
|
||||
clip_dets.numpy(), targets=model.targets
|
||||
),
|
||||
)
|
||||
for clip_annotation, clip_dets in zip(
|
||||
clip_annotations, clip_detections
|
||||
)
|
||||
]
|
||||
|
||||
self._clip_annotations.extend(clip_annotations)
|
||||
|
||||
@ -6,11 +6,10 @@ from soundevent.data import PathLike
|
||||
from torch.optim.adam import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from batdetect2.audio import TARGET_SAMPLERATE_HZ
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.plotting.clips import build_preprocessor
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.targets.targets import build_targets
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.train.losses import build_loss
|
||||
from batdetect2.typing import ModelOutput, TrainExample
|
||||
|
||||
@ -27,20 +26,20 @@ class TrainingModule(L.LightningModule):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: "BatDetect2Config",
|
||||
input_samplerate: int = TARGET_SAMPLERATE_HZ,
|
||||
learning_rate: float = 0.001,
|
||||
config: Optional[dict] = None,
|
||||
t_max: int = 100,
|
||||
model: Optional[Model] = None,
|
||||
loss: Optional[torch.nn.Module] = None,
|
||||
):
|
||||
from batdetect2.config import validate_config
|
||||
|
||||
super().__init__()
|
||||
|
||||
self.save_hyperparameters(logger=False)
|
||||
|
||||
self.input_samplerate = input_samplerate
|
||||
self.config = config
|
||||
self.learning_rate = learning_rate
|
||||
self.config = validate_config(config)
|
||||
self.input_samplerate = self.config.audio.samplerate
|
||||
self.learning_rate = self.config.train.optimizer.learning_rate
|
||||
self.t_max = t_max
|
||||
|
||||
if loss is None:
|
||||
@ -104,14 +103,7 @@ def load_model_from_checkpoint(
|
||||
|
||||
|
||||
def build_training_module(
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
config: Optional[dict] = None,
|
||||
t_max: int = 200,
|
||||
) -> TrainingModule:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
config = config or BatDetect2Config()
|
||||
return TrainingModule(
|
||||
config=config,
|
||||
learning_rate=config.train.optimizer.learning_rate,
|
||||
t_max=t_max,
|
||||
)
|
||||
return TrainingModule(config=config, t_max=t_max)
|
||||
|
||||
@ -97,7 +97,7 @@ def train(
|
||||
)
|
||||
|
||||
module = build_training_module(
|
||||
config,
|
||||
config.model_dump(mode="json"),
|
||||
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
||||
)
|
||||
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from batdetect2.typing.data import OutputFormatterProtocol
|
||||
from batdetect2.typing.evaluate import (
|
||||
AffinityFunction,
|
||||
ClipMatches,
|
||||
@ -10,6 +11,7 @@ from batdetect2.typing.evaluate import (
|
||||
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
|
||||
from batdetect2.typing.postprocess import (
|
||||
BatDetect2Prediction,
|
||||
ClipDetectionsTensor,
|
||||
GeometryDecoder,
|
||||
PostprocessorProtocol,
|
||||
RawPrediction,
|
||||
@ -43,8 +45,9 @@ __all__ = [
|
||||
"Augmentation",
|
||||
"BackboneModel",
|
||||
"BatDetect2Prediction",
|
||||
"ClipMatches",
|
||||
"ClipDetectionsTensor",
|
||||
"ClipLabeller",
|
||||
"ClipMatches",
|
||||
"ClipperProtocol",
|
||||
"DetectionModel",
|
||||
"EvaluatorProtocol",
|
||||
@ -56,6 +59,7 @@ __all__ = [
|
||||
"MatcherProtocol",
|
||||
"MetricsProtocol",
|
||||
"ModelOutput",
|
||||
"OutputFormatterProtocol",
|
||||
"PlotterProtocol",
|
||||
"Position",
|
||||
"PostprocessorProtocol",
|
||||
|
||||
26
src/batdetect2/typing/data.py
Normal file
26
src/batdetect2/typing/data.py
Normal file
@ -0,0 +1,26 @@
|
||||
from typing import Generic, List, Optional, Protocol, Sequence, TypeVar
|
||||
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction
|
||||
|
||||
__all__ = [
|
||||
"OutputFormatterProtocol",
|
||||
]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class OutputFormatterProtocol(Protocol, Generic[T]):
    """Structural interface for prediction output formatters.

    ``T`` is the type of a single formatted output item: the type that
    ``format`` produces, ``save`` accepts and ``load`` returns.
    """

    def format(
        self,
        predictions: Sequence[BatDetect2Prediction],
    ) -> List[T]:
        """Convert predictions into a list of formatted items."""
        ...

    def save(
        self,
        predictions: Sequence[T],
        path: PathLike,
        audio_dir: Optional[PathLike] = None,
    ) -> None:
        """Store formatted items at ``path``."""
        ...

    def load(self, path: PathLike) -> List[T]:
        """Read formatted items back from ``path``."""
        ...
|
||||
@ -14,7 +14,7 @@ from typing import (
|
||||
from matplotlib.figure import Figure
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.typing.postprocess import RawPrediction
|
||||
from batdetect2.typing.postprocess import BatDetect2Prediction, RawPrediction
|
||||
from batdetect2.typing.targets import TargetProtocol
|
||||
|
||||
__all__ = [
|
||||
@ -115,7 +115,7 @@ class EvaluatorProtocol(Protocol, Generic[EvaluationOutput]):
|
||||
def evaluate(
|
||||
self,
|
||||
clip_annotations: Sequence[data.ClipAnnotation],
|
||||
predictions: Sequence[Sequence[RawPrediction]],
|
||||
predictions: Sequence[BatDetect2Prediction],
|
||||
) -> EvaluationOutput: ...
|
||||
|
||||
def compute_metrics(
|
||||
|
||||
@ -4,16 +4,15 @@ import lightning as L
|
||||
import torch
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.train import FullTrainingConfig, TrainingModule
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.train import TrainingModule
|
||||
from batdetect2.train.train import build_training_module
|
||||
from batdetect2.typing.preprocess import AudioLoader
|
||||
|
||||
|
||||
def build_default_module():
|
||||
model = build_model()
|
||||
config = FullTrainingConfig()
|
||||
return build_training_module(model, config=config)
|
||||
config = BatDetect2Config()
|
||||
return build_training_module(config=config.model_dump())
|
||||
|
||||
|
||||
def test_can_initialize_default_module():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user