Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-09 16:59:33 +01:00)

Commit e65df81db2 (parent 6c25787123)
Evaluate using Lightning too to handle device changes
@@ -5,7 +5,7 @@ from soundevent import data
 from batdetect2.audio import build_audio_loader
 from batdetect2.config import BatDetect2Config
-from batdetect2.evaluate import Evaluator, build_evaluator
+from batdetect2.evaluate import build_evaluator, evaluate
 from batdetect2.models import Model, build_model
 from batdetect2.postprocess import build_postprocessor
 from batdetect2.preprocess import build_preprocessor
@@ -14,6 +14,7 @@ from batdetect2.train import train
 from batdetect2.train.lightning import load_model_from_checkpoint
 from batdetect2.typing import (
     AudioLoader,
+    EvaluatorProtocol,
     PostprocessorProtocol,
     PreprocessorProtocol,
     TargetProtocol,
@@ -28,7 +29,7 @@ class BatDetect2API:
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         postprocessor: PostprocessorProtocol,
-        evaluator: Evaluator,
+        evaluator: EvaluatorProtocol,
         model: Model,
     ):
         self.config = config
@@ -70,6 +71,27 @@ class BatDetect2API:
         )
         return self

+    def evaluate(
+        self,
+        test_annotations: Sequence[data.ClipAnnotation],
+        num_workers: Optional[int] = None,
+        output_dir: data.PathLike = ".",
+        experiment_name: Optional[str] = None,
+        run_name: Optional[str] = None,
+    ):
+        return evaluate(
+            self.model,
+            test_annotations,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            config=self.config,
+            num_workers=num_workers,
+            output_dir=output_dir,
+            experiment_name=experiment_name,
+            run_name=run_name,
+        )
+
     @classmethod
     def from_config(cls, config: BatDetect2Config):
         targets = build_targets(config=config.targets)
@@ -118,8 +140,14 @@ class BatDetect2API:
         )

     @classmethod
-    def from_checkpoint(cls, path: data.PathLike):
-        model, config = load_model_from_checkpoint(path)
+    def from_checkpoint(
+        cls,
+        path: data.PathLike,
+        config: Optional[BatDetect2Config] = None,
+    ):
+        model, stored_config = load_model_from_checkpoint(path)
+
+        config = config or stored_config

         targets = build_targets(config=config.targets)
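Taken together, these API changes let an evaluation run be driven entirely from a checkpoint. A minimal usage sketch; the annotation loading is elided, and the file paths and names are placeholders, not values from the commit:

    from batdetect2.api.base import BatDetect2API
    from batdetect2.config import load_full_config

    # Optional: override the config stored in the checkpoint.
    config = load_full_config("config.yaml")  # placeholder path

    api = BatDetect2API.from_checkpoint("best.ckpt", config=config)

    # test_annotations: Sequence[data.ClipAnnotation], loaded elsewhere.
    api.evaluate(
        test_annotations,
        num_workers=4,
        output_dir="outputs/evaluation",
        experiment_name="bat-eval",  # placeholder name
        run_name="run-001",          # placeholder name
    )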
@@ -10,11 +10,17 @@ from batdetect2.cli.base import cli
 __all__ = ["evaluate_command"]


+DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
+
+
 @cli.command(name="evaluate")
 @click.argument("model-path", type=click.Path(exists=True))
 @click.argument("test_dataset", type=click.Path(exists=True))
-@click.option("--output-dir", type=click.Path())
-@click.option("--workers", type=int)
+@click.option("--config", "config_path", type=click.Path())
+@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
+@click.option("--experiment-name", type=str)
+@click.option("--run-name", type=str)
+@click.option("--workers", "num_workers", type=int)
 @click.option(
     "-v",
     "--verbose",
@@ -24,13 +30,16 @@ __all__ = ["evaluate_command"]
 def evaluate_command(
     model_path: Path,
     test_dataset: Path,
-    output_dir: Optional[Path] = None,
-    workers: Optional[int] = None,
+    config_path: Optional[Path],
+    output_dir: Path = DEFAULT_OUTPUT_DIR,
+    num_workers: Optional[int] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
     verbose: int = 0,
 ):
     from batdetect2.api.base import BatDetect2API
     from batdetect2.config import load_full_config
     from batdetect2.data import load_dataset_from_config
     from batdetect2.evaluate.evaluate import evaluate
     from batdetect2.train.lightning import load_model_from_checkpoint

     logger.remove()
     if verbose == 0:
@@ -49,16 +58,16 @@ def evaluate_command(
         num_annotations=len(test_annotations),
     )

-    model, train_config = load_model_from_checkpoint(model_path)
+    config = None
+    if config_path is not None:
+        config = load_full_config(config_path)

-    df, results = evaluate(
-        model,
+    api = BatDetect2API.from_checkpoint(model_path, config=config)
+
+    api.evaluate(
         test_annotations,
-        config=train_config,
-        num_workers=workers,
+        num_workers=num_workers,
         output_dir=output_dir,
+        experiment_name=experiment_name,
+        run_name=run_name,
     )
-    print(results)
-
-    if output_dir:
-        df.to_csv(output_dir / "results.csv")
@@ -2,6 +2,7 @@ import sys
 from typing import Generic, Protocol, Type, TypeVar

+from pydantic import BaseModel
 from typing_extensions import assert_type

 if sys.version_info >= (3, 10):
     from typing import ParamSpec
@@ -44,7 +45,6 @@ class Registry(Generic[T_Type, P_Type]):
         config_cls: Type[T_Config],
         logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
     ) -> None:
         """A decorator factory to register a new item."""
         fields = config_cls.model_fields
-
         if "name" not in fields:
@@ -1,9 +1,11 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
+from batdetect2.evaluate.evaluate import evaluate
 from batdetect2.evaluate.evaluator import Evaluator, build_evaluator

 __all__ = [
     "EvaluationConfig",
     "load_evaluation_config",
+    "evaluate",
     "Evaluator",
     "build_evaluator",
 ]
@@ -11,6 +11,7 @@ from batdetect2.evaluate.metrics import (
     MetricConfig,
 )
 from batdetect2.evaluate.plots import PlotConfig
+from batdetect2.logging import CSVLoggerConfig, LoggerConfig

 __all__ = [
     "EvaluationConfig",
@@ -28,6 +29,7 @@ class EvaluationConfig(BaseConfig):
         ]
     )
     plots: List[PlotConfig] = Field(default_factory=list)
+    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)


 def load_evaluation_config(
src/batdetect2/evaluate/dataset.py (new file, 144 lines)
@@ -0,0 +1,144 @@
from typing import List, NamedTuple, Optional, Sequence

import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset

from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
    AudioLoader,
    ClipperProtocol,
    PreprocessorProtocol,
)

__all__ = [
    "TestDataset",
    "build_test_dataset",
    "build_test_loader",
]


class TestExample(NamedTuple):
    spec: torch.Tensor
    idx: torch.Tensor
    start_time: torch.Tensor
    end_time: torch.Tensor


class TestDataset(Dataset[TestExample]):
    clip_annotations: List[data.ClipAnnotation]

    def __init__(
        self,
        clip_annotations: Sequence[data.ClipAnnotation],
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        clipper: Optional[ClipperProtocol] = None,
        audio_dir: Optional[data.PathLike] = None,
    ):
        self.clip_annotations = list(clip_annotations)
        self.clipper = clipper
        self.preprocessor = preprocessor
        self.audio_loader = audio_loader
        self.audio_dir = audio_dir

    def __len__(self):
        return len(self.clip_annotations)

    def __getitem__(self, idx: int) -> TestExample:
        clip_annotation = self.clip_annotations[idx]

        if self.clipper is not None:
            clip_annotation = self.clipper(clip_annotation)

        clip = clip_annotation.clip
        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
        wav_tensor = torch.tensor(wav).unsqueeze(0)
        spectrogram = self.preprocessor(wav_tensor)
        return TestExample(
            spec=spectrogram,
            idx=torch.tensor(idx),
            start_time=torch.tensor(clip.start_time),
            end_time=torch.tensor(clip.end_time),
        )


class TestLoaderConfig(BaseConfig):
    num_workers: int = 0
    clipping_strategy: ClipConfig = Field(
        default_factory=lambda: PaddedClipConfig()
    )


def build_test_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
    num_workers: Optional[int] = None,
) -> DataLoader[TestExample]:
    logger.info("Building test data loader...")
    config = config or TestLoaderConfig()
    logger.opt(lazy=True).debug(
        "Test data loader config: \n{config}",
        config=lambda: config.to_yaml_string(exclude_none=True),
    )

    test_dataset = build_test_dataset(
        clip_annotations,
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        config=config,
    )

    num_workers = num_workers or config.num_workers
    return DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_collate_fn,
    )


def build_test_dataset(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TestLoaderConfig] = None,
) -> TestDataset:
    logger.info("Building test dataset...")
    config = config or TestLoaderConfig()

    clipper = build_clipper(config=config.clipping_strategy)

    if audio_loader is None:
        audio_loader = build_audio_loader()

    if preprocessor is None:
        preprocessor = build_preprocessor()

    return TestDataset(
        clip_annotations,
        audio_loader=audio_loader,
        clipper=clipper,
        preprocessor=preprocessor,
    )


def _collate_fn(batch: List[TestExample]) -> TestExample:
    max_width = max(item.spec.shape[-1] for item in batch)
    return TestExample(
        spec=torch.stack(
            [adjust_width(item.spec, max_width) for item in batch]
        ),
        idx=torch.stack([item.idx for item in batch]),
        start_time=torch.stack([item.start_time for item in batch]),
        end_time=torch.stack([item.end_time for item in batch]),
    )
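A short sketch of driving the new loader, assuming test_annotations is already loaded. Note batch_size=1 together with the padding collate function, which lets spectrograms of unequal width share a batch:

    from batdetect2.evaluate.dataset import build_test_loader

    # test_annotations: Sequence[data.ClipAnnotation], loaded elsewhere.
    loader = build_test_loader(test_annotations, num_workers=2)

    for batch in loader:
        # batch.spec is padded to the widest spectrogram in the batch;
        # batch.idx maps each example back to its ClipAnnotation.
        print(batch.spec.shape, batch.start_time, batch.end_time)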
@@ -1,35 +1,44 @@
-from typing import List, Optional, Tuple
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional, Sequence

-import pandas as pd
+from lightning import Trainer
 from soundevent import data

 from batdetect2.audio import build_audio_loader
-from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.evaluate.dataframe import extract_matches_dataframe
+from batdetect2.evaluate.dataset import build_test_loader
 from batdetect2.evaluate.evaluator import build_evaluator
-from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
+from batdetect2.evaluate.lightning import EvaluationModule
+from batdetect2.logging import build_logger
 from batdetect2.models import Model
-from batdetect2.plotting.clips import AudioLoader, PreprocessorProtocol
-from batdetect2.postprocess import get_raw_predictions
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
-from batdetect2.train.dataset import ValidationDataset
-from batdetect2.train.labels import build_clip_labeler
-from batdetect2.train.train import build_val_loader
-from batdetect2.typing import ClipLabeller, TargetProtocol
+
+if TYPE_CHECKING:
+    from batdetect2.config import BatDetect2Config
+    from batdetect2.typing import (
+        AudioLoader,
+        PreprocessorProtocol,
+        TargetProtocol,
+    )
+
+DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"


 def evaluate(
     model: Model,
-    test_annotations: List[data.ClipAnnotation],
-    targets: Optional[TargetProtocol] = None,
-    audio_loader: Optional[AudioLoader] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    labeller: Optional[ClipLabeller] = None,
-    config: Optional[EvaluationConfig] = None,
+    test_annotations: Sequence[data.ClipAnnotation],
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    config: Optional["BatDetect2Config"] = None,
     num_workers: Optional[int] = None,
-) -> Tuple[pd.DataFrame, dict]:
-    config = config or EvaluationConfig()
+    output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+):
+    from batdetect2.config import BatDetect2Config
+
+    config = config or BatDetect2Config()

     audio_loader = audio_loader or build_audio_loader()

@@ -39,60 +48,21 @@ def evaluate(

     targets = targets or build_targets()

-    labeller = labeller or build_clip_labeler(
-        targets,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
-    )
-
-    loader = build_val_loader(
+    loader = build_test_loader(
         test_annotations,
         audio_loader=audio_loader,
-        labeller=labeller,
         preprocessor=preprocessor,
         num_workers=num_workers,
     )

-    dataset: ValidationDataset = loader.dataset  # type: ignore
+    evaluator = build_evaluator(config=config.evaluation, targets=targets)

-    clip_annotations = []
-    predictions = []
-
-    evaluator = build_evaluator(config=config, targets=targets)
-
-    for batch in loader:
-        outputs = model.detector(batch.spec)
-
-        clip_annotations = [
-            dataset.clip_annotations[int(example_idx)]
-            for example_idx in batch.idx
-        ]
-
-        predictions = get_raw_predictions(
-            outputs,
-            start_times=[
-                clip_annotation.clip.start_time
-                for clip_annotation in clip_annotations
-            ],
-            targets=targets,
-            postprocessor=model.postprocessor,
-        )
-
-        clip_annotations.extend(clip_annotations)
-        predictions.extend(predictions)
-
-    matches = evaluator.evaluate(clip_annotations, predictions)
-    df = extract_matches_dataframe(matches)
-
-    metrics = [
-        DetectionAP(),
-        ClassificationAP(class_names=targets.class_names),
-    ]
-
-    results = {
-        name: value
-        for metric in metrics
-        for name, value in metric(matches).items()
-    }
-
-    return df, results
+    logger = build_logger(
+        config.evaluation.logger,
+        log_dir=Path(output_dir),
+        experiment_name=experiment_name,
+        run_name=run_name,
+    )
+    module = EvaluationModule(model, evaluator)
+    trainer = Trainer(logger=logger, enable_checkpointing=False)
+    return trainer.test(module, loader)
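This rewrite is the point of the commit: the old hand-rolled loop ran the model on whatever device the tensors happened to be on, while Trainer.test moves both the LightningModule and every batch to the selected accelerator automatically. A self-contained sketch of that behaviour in isolation, using a toy model and data rather than BatDetect2 code:

    import torch
    from lightning import LightningModule, Trainer
    from torch.utils.data import DataLoader, TensorDataset


    class TinyEvalModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(8, 1)

        def test_step(self, batch, batch_idx):
            (x,) = batch
            # x has already been moved to the trainer's device.
            self.log("mean_score", self.layer(x).mean())


    loader = DataLoader(TensorDataset(torch.randn(16, 8)), batch_size=4)

    # accelerator="auto" picks CPU/GPU/MPS; no manual .to(device) calls needed.
    trainer = Trainer(accelerator="auto", logger=False, enable_checkpointing=False)
    trainer.test(TinyEvalModule(), loader)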
@@ -11,6 +11,7 @@ from batdetect2.evaluate.plots import build_plotter
 from batdetect2.targets import build_targets
 from batdetect2.typing.evaluate import (
     ClipEvaluation,
+    EvaluatorProtocol,
     MatcherProtocol,
     MetricsProtocol,
     PlotterProtocol,
@@ -135,7 +136,7 @@ def build_evaluator(
     matcher: Optional[MatcherProtocol] = None,
     plots: Optional[List[PlotterProtocol]] = None,
     metrics: Optional[List[MetricsProtocol]] = None,
-) -> Evaluator:
+) -> EvaluatorProtocol:
     config = config or EvaluationConfig()
     targets = targets or build_targets()
     matcher = matcher or build_matcher(config.match_strategy)
src/batdetect2/evaluate/lightning.py (new file, 86 lines)
@@ -0,0 +1,86 @@
from typing import Sequence

from lightning import LightningModule
from torch.utils.data import DataLoader

from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.tables import FullEvaluationTable
from batdetect2.logging import get_image_logger, get_table_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol


class EvaluationModule(LightningModule):
    def __init__(
        self,
        model: Model,
        evaluator: EvaluatorProtocol,
    ):
        super().__init__()

        self.model = model
        self.evaluator = evaluator

        self.clip_evaluations = []

    def test_step(self, batch: TestExample):
        dataset = self.get_dataset()
        clip_annotations = [
            dataset.clip_annotations[int(example_idx)]
            for example_idx in batch.idx
        ]

        outputs = self.model.detector(batch.spec)
        clip_detections = self.model.postprocessor(
            outputs,
            start_times=[ca.clip.start_time for ca in clip_annotations],
        )
        predictions = [
            to_raw_predictions(
                clip_dets.numpy(),
                targets=self.evaluator.targets,
            )
            for clip_dets in clip_detections
        ]

        self.clip_evaluations.extend(
            self.evaluator.evaluate(clip_annotations, predictions)
        )

    def on_test_epoch_start(self):
        self.clip_evaluations = []

    def on_test_epoch_end(self):
        self.log_metrics(self.clip_evaluations)
        self.plot_examples(self.clip_evaluations)
        self.log_table(self.clip_evaluations)

    def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
        table_logger = get_table_logger(self.logger)  # type: ignore

        if table_logger is None:
            return

        df = FullEvaluationTable()(evaluated_clips)
        table_logger("full_evaluation", df, 0)

    def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
        plotter = get_image_logger(self.logger)  # type: ignore

        if plotter is None:
            return

        for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
            plotter(figure_name, fig, self.global_step)

    def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
        metrics = self.evaluator.compute_metrics(evaluated_clips)
        self.log_dict(metrics)

    def get_dataset(self) -> TestDataset:
        dataloaders = self.trainer.test_dataloaders
        assert isinstance(dataloaders, DataLoader)
        dataset = dataloaders.dataset
        assert isinstance(dataset, TestDataset)
        return dataset
@@ -15,10 +15,8 @@ import numpy as np
 from pydantic import Field
 from sklearn import metrics, preprocessing

-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
-from batdetect2.typing import MetricsProtocol
-from batdetect2.typing.evaluate import ClipEvaluation
+from batdetect2.core import BaseConfig, Registry
+from batdetect2.typing import ClipEvaluation, MetricsProtocol

 __all__ = ["DetectionAP", "ClassificationAP"]

@@ -375,14 +373,96 @@ metrics_registry.register(
 )


-class ClipAPConfig(BaseConfig):
-    name: Literal["clip_ap"] = "clip_ap"
+class ClipDetectionAPConfig(BaseConfig):
+    name: Literal["clip_detection_ap"] = "clip_detection_ap"
     ap_implementation: APImplementation = "pascal_voc"
+
+
+class ClipDetectionAP(MetricsProtocol):
+    def __init__(
+        self,
+        implementation: APImplementation,
+    ):
+        self.implementation = implementation
+        self.metric = _ap_impl_mapping[self.implementation]
+
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]:
+        y_true = []
+        y_score = []
+
+        for clip_eval in clip_evaluations:
+            clip_det = []
+            clip_scores = []
+
+            for match in clip_eval.matches:
+                clip_det.append(match.gt_det)
+                clip_scores.append(match.pred_score)
+
+            y_true.append(any(clip_det))
+            y_score.append(max(clip_scores or [0]))
+
+        return {"clip_detection_ap": self.metric(y_true, y_score)}
+
+    @classmethod
+    def from_config(
+        cls,
+        config: ClipDetectionAPConfig,
+        class_names: List[str],
+    ):
+        return cls(implementation=config.ap_implementation)
+
+
+metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
+
+
+class ClipDetectionROCAUCConfig(BaseConfig):
+    name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
+
+
+class ClipDetectionROCAUC(MetricsProtocol):
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]:
+        y_true = []
+        y_score = []
+
+        for clip_eval in clip_evaluations:
+            clip_det = []
+            clip_scores = []
+
+            for match in clip_eval.matches:
+                clip_det.append(match.gt_det)
+                clip_scores.append(match.pred_score)
+
+            y_true.append(any(clip_det))
+            y_score.append(max(clip_scores or [0]))
+
+        return {
+            "clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score))
+        }
+
+    @classmethod
+    def from_config(
+        cls,
+        config: ClipDetectionROCAUCConfig,
+        class_names: List[str],
+    ):
+        return cls()
+
+
+metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
+
+
+class ClipMulticlassAPConfig(BaseConfig):
+    name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
+    ap_implementation: APImplementation = "pascal_voc"
     include: Optional[List[str]] = None
     exclude: Optional[List[str]] = None


-class ClipAP(MetricsProtocol):
+class ClipMulticlassAP(MetricsProtocol):
     def __init__(
         self,
         class_names: List[str],
@@ -454,15 +534,17 @@ class ClipAP(MetricsProtocol):
             [value for value in class_scores.values() if value != 0]
         )
         return {
-            "clip_mAP": float(mean_ap),
+            "clip_multiclass_mAP": float(mean_ap),
             **{
-                f"clip_AP/{class_name}": class_scores[class_name]
+                f"clip_multiclass_AP/{class_name}": class_scores[class_name]
                 for class_name in self.selected
             },
         }

    @classmethod
-    def from_config(cls, config: ClipAPConfig, class_names: List[str]):
+    def from_config(
+        cls, config: ClipMulticlassAPConfig, class_names: List[str]
+    ):
         return cls(
             implementation=config.ap_implementation,
             include=config.include,
@@ -471,16 +553,16 @@ class ClipAP(MetricsProtocol):
         )


-metrics_registry.register(ClipAPConfig, ClipAP)
+metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)


-class ClipROCAUCConfig(BaseConfig):
-    name: Literal["clip_roc_auc"] = "clip_roc_auc"
+class ClipMulticlassROCAUCConfig(BaseConfig):
+    name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
     include: Optional[List[str]] = None
     exclude: Optional[List[str]] = None


-class ClipROCAUC(MetricsProtocol):
+class ClipMulticlassROCAUC(MetricsProtocol):
     def __init__(
         self,
         class_names: List[str],
@@ -548,9 +630,11 @@ class ClipROCAUC(MetricsProtocol):
             [value for value in class_scores.values() if value != 0]
         )
         return {
-            "clip_macro_ROC_AUC": float(mean_roc_auc),
+            "clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
             **{
-                f"clip_ROC_AUC/{class_name}": class_scores[class_name]
+                f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
+                    class_name
+                ]
                 for class_name in self.selected
             },
         }
@@ -558,7 +642,7 @@ class ClipROCAUC(MetricsProtocol):
     @classmethod
     def from_config(
         cls,
-        config: ClipROCAUCConfig,
+        config: ClipMulticlassROCAUCConfig,
         class_names: List[str],
     ):
         return cls(
@@ -568,7 +652,7 @@ class ClipROCAUC(MetricsProtocol):
     )


-metrics_registry.register(ClipROCAUCConfig, ClipROCAUC)
+metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)

 MetricConfig = Annotated[
     Union[
@@ -578,8 +662,10 @@ MetricConfig = Annotated[
         ClassificationROCAUCConfig,
         TopClassAPConfig,
         ClassificationBalancedAccuracyConfig,
-        ClipAPConfig,
-        ClipROCAUCConfig,
+        ClipDetectionAPConfig,
+        ClipDetectionROCAUCConfig,
+        ClipMulticlassAPConfig,
+        ClipMulticlassROCAUCConfig,
     ],
     Field(discriminator="name"),
 ]
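The clip-level detection metrics reduce each clip to one binary label and one score: a clip counts as positive if any match carries a ground-truth detection, and its score is the highest prediction score (0 when there are no matches). A toy check of that reduction, using sklearn's average_precision_score as a stand-in for the registry's "pascal_voc" implementation. Incidentally, ClipDetectionROCAUC above appears to reuse the "clip_detection_ap" key for its result, which looks like a copy-paste slip:

    from sklearn.metrics import average_precision_score

    # One (gt_det, pred_score) list per clip, mimicking ClipEvaluation.matches.
    clips = [
        [(True, 0.9), (False, 0.3)],  # positive clip, top score 0.9
        [(False, 0.2)],               # negative clip, top score 0.2
        [],                           # no matches at all -> score 0
    ]

    y_true = [any(det for det, _ in matches) for matches in clips]
    y_score = [max([score for _, score in matches] or [0]) for matches in clips]

    print(average_precision_score(y_true, y_score))  # 1.0: positives rank first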
@@ -10,19 +10,18 @@ from pydantic import Field
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize

-from batdetect2.audio import AudioConfig
-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
-from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
+from batdetect2.audio import AudioConfig, build_audio_loader
+from batdetect2.core import BaseConfig, Registry
 from batdetect2.plotting.gallery import plot_match_gallery
 from batdetect2.plotting.matches import plot_matches
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.typing.evaluate import (
+from batdetect2.typing import (
+    AudioLoader,
     ClipEvaluation,
     MatchEvaluation,
     PlotterProtocol,
+    PreprocessorProtocol,
 )
-from batdetect2.typing.preprocess import AudioLoader

 __all__ = [
     "build_plotter",
@@ -431,6 +430,62 @@ class ClassificationROCCurves(PlotterProtocol):
 plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)


+class ConfusionMatrixConfig(BaseConfig):
+    name: Literal["confusion_matrix"] = "confusion_matrix"
+    background_class: str = "noise"
+
+
+class ConfusionMatrix(PlotterProtocol):
+    def __init__(self, background_class: str, class_names: List[str]):
+        self.background_class = background_class
+        self.class_names = class_names
+
+    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
+        y_true = []
+        y_pred = []
+
+        for clip_eval in clip_evaluations:
+            for match in clip_eval.matches:
+                # Ignore generic unclassified targets
+                if match.gt_det and match.gt_class is None:
+                    continue
+
+                y_true.append(
+                    match.gt_class
+                    if match.gt_class is not None
+                    else self.background_class
+                )
+
+                top_class = match.pred_class
+                y_pred.append(
+                    top_class
+                    if top_class is not None
+                    else self.background_class
+                )
+
+        display = metrics.ConfusionMatrixDisplay.from_predictions(
+            y_true,
+            y_pred,
+            labels=[*self.class_names, self.background_class],
+        )
+
+        yield "confusion_matrix", display.figure_
+
+    @classmethod
+    def from_config(
+        cls,
+        config: ConfusionMatrixConfig,
+        class_names: List[str],
+    ):
+        return cls(
+            background_class=config.background_class,
+            class_names=class_names,
+        )
+
+
+plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
+
+
 PlotConfig = Annotated[
     Union[
         ExampleGalleryConfig,
@@ -439,6 +494,7 @@ PlotConfig = Annotated[
         ClassificationPRCurvesConfig,
         DetectionROCCurveConfig,
         ClassificationROCCurvesConfig,
+        ConfusionMatrixConfig,
     ],
     Field(discriminator="name"),
 ]
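A quick sanity check of the sklearn call at the heart of the new plotter, with made-up labels; from_predictions draws the matrix and returns a display whose figure_ attribute is a Matplotlib figure:

    from sklearn import metrics

    class_names = ["myotis", "pipistrellus"]  # hypothetical classes
    background = "noise"

    y_true = ["myotis", "noise", "pipistrellus", "myotis"]
    y_pred = ["myotis", "myotis", "pipistrellus", "noise"]

    display = metrics.ConfusionMatrixDisplay.from_predictions(
        y_true,
        y_pred,
        labels=[*class_names, background],
    )
    display.figure_.savefig("confusion_matrix.png")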
@@ -1,18 +1,49 @@
-from typing import List
+from typing import Annotated, Callable, Literal, Sequence, Union

 import pandas as pd
+from pydantic import Field
 from soundevent.geometry import compute_bounds

-from batdetect2.typing.evaluate import ClipEvaluation
+from batdetect2.core import BaseConfig, Registry
+from batdetect2.typing import ClipEvaluation

+EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]

-def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
+
+tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
+    "evaluation_table"
+)
+
+
+class FullEvaluationTableConfig(BaseConfig):
+    name: Literal["full_evaluation"] = "full_evaluation"
+
+
+class FullEvaluationTable:
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> pd.DataFrame:
+        return extract_matches_dataframe(clip_evaluations)
+
+    @classmethod
+    def from_config(cls, config: FullEvaluationTableConfig):
+        return cls()
+
+
+tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
+
+
+def extract_matches_dataframe(
+    clip_evaluations: Sequence[ClipEvaluation],
+) -> pd.DataFrame:
     data = []

     for clip_evaluation in clip_evaluations:
         for match in clip_evaluation.matches:
             gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
-            pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
+            pred_start_time = pred_low_freq = pred_end_time = (
+                pred_high_freq
+            ) = None

             sound_event_annotation = match.sound_event_annotation

@@ -24,9 +55,12 @@ def extract_matches_dataframe(
             )

             if match.pred_geometry is not None:
-                pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
-                    compute_bounds(match.pred_geometry)
-                )
+                (
+                    pred_start_time,
+                    pred_low_freq,
+                    pred_end_time,
+                    pred_high_freq,
+                ) = compute_bounds(match.pred_geometry)

             data.append(
                 {
@@ -61,3 +95,14 @@ def extract_matches_dataframe(
     df = pd.DataFrame(data)
     df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
     return df


+EvaluationTableConfig = Annotated[
+    Union[FullEvaluationTableConfig,], Field(discriminator="name")
+]
+
+
+def build_table_generator(
+    config: EvaluationTableConfig,
+) -> EvaluationTableGenerator:
+    return tables_registry.build(config)
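A sketch of the registry round-trip, assuming clip_evaluations comes from an evaluator run; build_table_generator resolves the discriminated config to the registered callable, and the module path matches the import in lightning.py above:

    from batdetect2.evaluate.tables import (
        FullEvaluationTableConfig,
        build_table_generator,
    )

    generator = build_table_generator(FullEvaluationTableConfig())

    # clip_evaluations: Sequence[ClipEvaluation], produced elsewhere.
    df = generator(clip_evaluations)
    df.to_csv("full_evaluation.csv")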
@@ -67,7 +67,7 @@ from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.typing.models import DetectionModel
 from batdetect2.typing.postprocess import (
-    DetectionsTensor,
+    ClipDetectionsTensor,
     PostprocessorProtocol,
 )
 from batdetect2.typing.preprocess import PreprocessorProtocol
@@ -121,7 +121,7 @@ class Model(torch.nn.Module):
         self.postprocessor = postprocessor
         self.targets = targets

-    def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
+    def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
         spec = self.preprocessor(wav)
         outputs = self.detector(spec)
         return self.postprocessor(outputs)
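Model.forward still chains the three stages, waveform to spectrogram to heatmaps to per-clip detection tensors; only the return type name changed. Calling it directly looks roughly like this (the waveform shape is an assumption, check the preprocessor's expectations):

    import torch

    # model: Model, e.g. BatDetect2API.from_checkpoint("best.ckpt").model
    wav = torch.randn(1, 1, 44100)  # (batch, channels, samples); shape assumed

    with torch.no_grad():
        detections = model(wav)  # List[ClipDetectionsTensor], one per clip

    raw = detections[0].numpy()  # ClipDetectionsArray, moved to the CPU
    print(raw.scores.shape)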
@@ -8,13 +8,10 @@ from batdetect2.postprocess.decoding import (
     convert_raw_predictions_to_clip_prediction,
     to_raw_predictions,
 )
-from batdetect2.postprocess.nms import (
-    non_max_suppression,
-)
+from batdetect2.postprocess.nms import non_max_suppression
 from batdetect2.postprocess.postprocessor import (
     Postprocessor,
     build_postprocessor,
-    get_raw_predictions,
 )

 __all__ = [
@@ -25,5 +22,4 @@ __all__ = [
     "to_raw_predictions",
     "load_postprocess_config",
     "non_max_suppression",
-    "get_raw_predictions",
 ]
@@ -6,7 +6,7 @@ import numpy as np
 from soundevent import data

 from batdetect2.typing.postprocess import (
-    DetectionsArray,
+    ClipDetectionsArray,
     RawPrediction,
 )
 from batdetect2.typing.targets import TargetProtocol
@@ -28,7 +28,7 @@ decoding.


 def to_raw_predictions(
-    detections: DetectionsArray,
+    detections: ClipDetectionsArray,
     targets: TargetProtocol,
 ) -> List[RawPrediction]:
     predictions = []
@@ -15,32 +15,25 @@ precise time-frequency location of each detection. The final output aggregates
 all extracted information into a structured `xarray.Dataset`.
 """

-from typing import List, Optional, Tuple, Union
+from typing import List, Optional

 import torch

-from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
-from batdetect2.typing.postprocess import (
-    DetectionsTensor,
-    ModelOutput,
-)
+from batdetect2.typing.postprocess import ClipDetectionsTensor

 __all__ = [
-    "extract_prediction_tensor",
+    "extract_detection_peaks",
 ]


-def extract_prediction_tensor(
-    output: ModelOutput,
+def extract_detection_peaks(
+    detection_heatmap: torch.Tensor,
+    size_heatmap: torch.Tensor,
+    feature_heatmap: torch.Tensor,
+    classification_heatmap: torch.Tensor,
     max_detections: int = 200,
     threshold: Optional[float] = None,
-    nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
-) -> List[DetectionsTensor]:
-    detection_heatmap = non_max_suppression(
-        output.detection_probs.detach(),
-        kernel_size=nms_kernel_size,
-    )
-
+) -> List[ClipDetectionsTensor]:
     height = detection_heatmap.shape[-2]
     width = detection_heatmap.shape[-1]
@@ -53,9 +46,9 @@ def extract_detection_peaks(
     freqs = freqs.flatten().to(detection_heatmap.device)
     times = times.flatten().to(detection_heatmap.device)

-    output_size_preds = output.size_preds.detach()
-    output_features = output.features.detach()
-    output_class_probs = output.class_probs.detach()
+    output_size_preds = size_heatmap.detach()
+    output_features = feature_heatmap.detach()
+    output_class_probs = classification_heatmap.detach()

     predictions = []
     for idx, item in enumerate(detection_heatmap):
@@ -65,23 +58,25 @@ def extract_detection_peaks(
         detection_scores = item.take(indices)
         detection_freqs = freqs.take(indices)
         detection_times = times.take(indices)
-        sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
-        features = output_features[idx, :, detection_freqs, detection_times].T
-        class_scores = output_class_probs[
-            idx, :, detection_freqs, detection_times
-        ].T

         if threshold is not None:
             mask = detection_scores >= threshold

             detection_scores = detection_scores[mask]
-            sizes = sizes[mask]
             detection_times = detection_times[mask]
             detection_freqs = detection_freqs[mask]
-            features = features[mask]
-            class_scores = class_scores[mask]
+
+        sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
+        features = output_features[idx, :, detection_freqs, detection_times].T
+        class_scores = output_class_probs[
+            idx,
+            :,
+            detection_freqs,
+            detection_times,
+        ].T

         predictions.append(
-            DetectionsTensor(
+            ClipDetectionsTensor(
                 scores=detection_scores,
                 sizes=sizes,
                 features=features,
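Extraction no longer runs NMS itself; the caller suppresses non-peak responses first and then passes the individual heatmaps in. The new call order, sketched with a ModelOutput produced elsewhere (the threshold value is assumed to mirror Postprocessor's default):

    from batdetect2.postprocess.extraction import extract_detection_peaks
    from batdetect2.postprocess.nms import non_max_suppression

    # output: ModelOutput from model.detector(spec), obtained elsewhere.
    suppressed = non_max_suppression(output.detection_probs.detach())

    detections = extract_detection_peaks(
        suppressed,
        size_heatmap=output.size_preds,
        feature_heatmap=output.features,
        classification_heatmap=output.class_probs,
        max_detections=200,
        threshold=0.01,  # assumed to match the Postprocessor default
    )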
@@ -1,29 +1,20 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union

 import torch
 from loguru import logger
-from soundevent import data

 from batdetect2.postprocess.config import (
     PostprocessConfig,
 )
-from batdetect2.postprocess.decoding import (
-    DEFAULT_CLASSIFICATION_THRESHOLD,
-    convert_raw_prediction_to_sound_event_prediction,
-    convert_raw_predictions_to_clip_prediction,
-    to_raw_predictions,
-)
-from batdetect2.postprocess.extraction import extract_prediction_tensor
+from batdetect2.postprocess.extraction import extract_detection_peaks
+from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
 from batdetect2.postprocess.remapping import map_detection_to_clip
 from batdetect2.typing import ModelOutput
 from batdetect2.typing.postprocess import (
-    BatDetect2Prediction,
-    DetectionsTensor,
+    ClipDetectionsTensor,
     PostprocessorProtocol,
-    RawPrediction,
 )
 from batdetect2.typing.preprocess import PreprocessorProtocol
-from batdetect2.typing.targets import TargetProtocol

 __all__ = [
     "build_postprocessor",
@@ -60,24 +51,43 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         max_freq: float,
         top_k_per_sec: int = 200,
         detection_threshold: float = 0.01,
+        nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
     ):
         """Initialize the Postprocessor."""
         super().__init__()
-        self.samplerate = samplerate
+        self.output_samplerate = samplerate
         self.min_freq = min_freq
         self.max_freq = max_freq
         self.top_k_per_sec = top_k_per_sec
         self.detection_threshold = detection_threshold
+        self.nms_kernel_size = nms_kernel_size

-    def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
+    def forward(
+        self,
+        output: ModelOutput,
+        start_times: Optional[List[float]] = None,
+    ) -> List[ClipDetectionsTensor]:
+        detection_heatmap = non_max_suppression(
+            output.detection_probs.detach(),
+            kernel_size=self.nms_kernel_size,
+        )
+
         width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
+        duration = width / self.output_samplerate
         max_detections = int(self.top_k_per_sec * duration)
-        detections = extract_prediction_tensor(
-            output,
+        detections = extract_detection_peaks(
+            detection_heatmap,
+            size_heatmap=output.size_preds,
+            feature_heatmap=output.features,
+            classification_heatmap=output.class_probs,
             max_detections=max_detections,
             threshold=self.detection_threshold,
         )
+
+        if start_times is None:
+            start_times = [0 for _ in range(len(detections))]
+
         return [
             map_detection_to_clip(
                 detection,
@@ -88,121 +98,3 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
             )
             for detection in detections
         ]
-
-    def get_detections(
-        self,
-        output: ModelOutput,
-        start_times: Optional[List[float]] = None,
-    ) -> List[DetectionsTensor]:
-        width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
-        max_detections = int(self.top_k_per_sec * duration)
-
-        detections = extract_prediction_tensor(
-            output,
-            max_detections=max_detections,
-            threshold=self.detection_threshold,
-        )
-
-        if start_times is None:
-            return detections
-
-        width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
-        return [
-            map_detection_to_clip(
-                detection,
-                start_time=start_time,
-                end_time=start_time + duration,
-                min_freq=self.min_freq,
-                max_freq=self.max_freq,
-            )
-            for detection, start_time in zip(detections, start_times)
-        ]
-
-
-def get_raw_predictions(
-    output: ModelOutput,
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    start_times: Optional[List[float]] = None,
-) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch."""
-    detections = postprocessor.get_detections(output, start_times)
-    return [
-        to_raw_predictions(detection.numpy(), targets=targets)
-        for detection in detections
-    ]
-
-
-def get_sound_event_predictions(
-    output: ModelOutput,
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    clips: List[data.Clip],
-    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
-) -> List[List[BatDetect2Prediction]]:
-    raw_predictions = get_raw_predictions(
-        output,
-        targets=targets,
-        postprocessor=postprocessor,
-        start_times=[clip.start_time for clip in clips],
-    )
-    return [
-        [
-            BatDetect2Prediction(
-                raw=raw,
-                sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
-                    raw,
-                    recording=clip.recording,
-                    targets=targets,
-                    classification_threshold=classification_threshold,
-                ),
-            )
-            for raw in predictions
-        ]
-        for predictions, clip in zip(raw_predictions, clips)
-    ]
-
-
-def get_predictions(
-    output: ModelOutput,
-    clips: List[data.Clip],
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
-) -> List[data.ClipPrediction]:
-    """Perform the full postprocessing pipeline for a batch.
-
-    Takes raw model output and corresponding clips, applies the entire
-    configured chain (NMS, remapping, extraction, geometry recovery, class
-    decoding), producing final `soundevent.data.ClipPrediction` objects.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[data.ClipPrediction]
-        List containing one `ClipPrediction` object for each input clip,
-        populated with `SoundEventPrediction` objects.
-    """
-    raw_predictions = get_raw_predictions(
-        output,
-        targets=targets,
-        postprocessor=postprocessor,
-        start_times=[clip.start_time for clip in clips],
-    )
-    return [
-        convert_raw_predictions_to_clip_prediction(
-            prediction,
-            clip,
-            targets=targets,
-            classification_threshold=classification_threshold,
-        )
-        for prediction, clip in zip(raw_predictions, clips)
-    ]
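With the module-level helpers gone, callers invoke the postprocessor directly and decode with to_raw_predictions, exactly as the new EvaluationModule.test_step does. Condensed, with model, clips, and outputs assumed to exist:

    from batdetect2.postprocess import to_raw_predictions

    # outputs: ModelOutput from model.detector(spec); clips: List[data.Clip].
    clip_detections = model.postprocessor(
        outputs,
        start_times=[clip.start_time for clip in clips],
    )

    predictions = [
        to_raw_predictions(dets.numpy(), targets=model.targets)
        for dets in clip_detections
    ]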
@@ -20,7 +20,7 @@ import xarray as xr
 from soundevent.arrays import Dimensions

 from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.typing.postprocess import DetectionsTensor
+from batdetect2.typing.postprocess import ClipDetectionsTensor

 __all__ = [
     "features_to_xarray",
@@ -31,15 +31,15 @@ __all__ = [


 def map_detection_to_clip(
-    detections: DetectionsTensor,
+    detections: ClipDetectionsTensor,
     start_time: float,
     end_time: float,
     min_freq: float,
     max_freq: float,
-) -> DetectionsTensor:
+) -> ClipDetectionsTensor:
     duration = end_time - start_time
     bandwidth = max_freq - min_freq
-    return DetectionsTensor(
+    return ClipDetectionsTensor(
         scores=detections.scores,
         sizes=detections.sizes,
         features=detections.features,
@@ -14,7 +14,6 @@ from batdetect2.train.augmentations import (
     scale_volume,
     warp_spectrogram,
 )
-from batdetect2.train.clips import build_clipper, select_subclip
 from batdetect2.train.config import (
     PLTrainerConfig,
     TrainingConfig,
@@ -61,7 +60,6 @@ __all__ = [
     "add_echo",
     "build_augmentations",
     "build_clip_labeler",
-    "build_clipper",
     "build_loss",
     "build_train_dataset",
     "build_train_loader",
@@ -74,7 +72,6 @@ __all__ = [
     "mask_time",
     "mix_audio",
     "scale_volume",
-    "select_subclip",
     "train",
     "warp_spectrogram",
 ]
@@ -11,9 +11,9 @@ from pydantic import Field
 from soundevent import data
 from soundevent.geometry import scale_geometry, shift_geometry

+from batdetect2.audio.clips import get_subclip_annotation
 from batdetect2.core.arrays import adjust_width
 from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.train.clips import get_subclip_annotation
 from batdetect2.typing import AudioLoader, Augmentation

 __all__ = [
@@ -5,13 +5,13 @@ from lightning.pytorch.callbacks import Callback
 from soundevent import data
 from torch.utils.data import DataLoader

-from batdetect2.evaluate import Evaluator
-from batdetect2.postprocess import get_raw_predictions
+from batdetect2.logging import get_image_logger
+from batdetect2.postprocess import to_raw_predictions
 from batdetect2.train.dataset import ValidationDataset
 from batdetect2.train.lightning import TrainingModule
-from batdetect2.train.logging import get_image_plotter
 from batdetect2.typing import (
     ClipEvaluation,
+    EvaluatorProtocol,
     ModelOutput,
     RawPrediction,
     TrainExample,
@@ -19,7 +19,7 @@ from batdetect2.typing import (


 class ValidationMetrics(Callback):
-    def __init__(self, evaluator: Evaluator):
+    def __init__(self, evaluator: EvaluatorProtocol):
         super().__init__()

         self.evaluator = evaluator
@@ -34,12 +34,12 @@ class ValidationMetrics(Callback):
         assert isinstance(dataset, ValidationDataset)
         return dataset

-    def plot_examples(
+    def generate_plots(
         self,
         pl_module: LightningModule,
         evaluated_clips: List[ClipEvaluation],
     ):
-        plotter = get_image_plotter(pl_module.logger)  # type: ignore
+        plotter = get_image_logger(pl_module.logger)  # type: ignore

         if plotter is None:
             return
@@ -66,7 +66,7 @@ class ValidationMetrics(Callback):
         )

         self.log_metrics(pl_module, clip_evaluations)
-        self.plot_examples(pl_module, clip_evaluations)
+        self.generate_plots(pl_module, clip_evaluations)

         return super().on_validation_epoch_end(trainer, pl_module)

@@ -88,8 +88,7 @@ class ValidationMetrics(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        postprocessor = pl_module.model.postprocessor
-        targets = pl_module.model.targets
+        model = pl_module.model
         dataset = self.get_dataset(trainer)

         clip_annotations = [
@@ -97,15 +96,14 @@ class ValidationMetrics(Callback):
             for example_idx in batch.idx
         ]

-        predictions = get_raw_predictions(
+        clip_detections = model.postprocessor(
             outputs,
-            start_times=[
-                clip_annotation.clip.start_time
-                for clip_annotation in clip_annotations
-            ],
-            targets=targets,
-            postprocessor=postprocessor,
+            start_times=[ca.clip.start_time for ca in clip_annotations],
         )
+        predictions = [
+            to_raw_predictions(clip_dets.numpy(), targets=model.targets)
+            for clip_dets in clip_detections
+        ]

         self._clip_annotations.extend(clip_annotations)
         self._predictions.extend(predictions)
@@ -5,17 +5,9 @@ from soundevent import data
 from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.train.augmentations import (
-    DEFAULT_AUGMENTATION_CONFIG,
-    AugmentationsConfig,
-)
-from batdetect2.train.clips import (
-    ClipConfig,
-    PaddedClipConfig,
-    RandomClipConfig,
-)
+from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
+from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
-from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig

 __all__ = [
@@ -45,30 +37,6 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None


-class ValLoaderConfig(BaseConfig):
-    num_workers: int = 0
-
-    clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: PaddedClipConfig()
-    )
-
-
-class TrainLoaderConfig(BaseConfig):
-    num_workers: int = 0
-
-    batch_size: int = 8
-
-    shuffle: bool = False
-
-    augmentations: AugmentationsConfig = Field(
-        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
-    )
-
-    clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: PaddedClipConfig()
-    )
-
-
 class OptimizerConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
@@ -79,9 +47,8 @@ class TrainingConfig(BaseConfig):
     val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
     optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
-    cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
-    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
+    logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
     validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
@@ -2,18 +2,21 @@ from typing import List, Optional, Sequence

 import torch
 from loguru import logger
+from pydantic import Field
 from soundevent import data
 from torch.utils.data import DataLoader, Dataset

-from batdetect2.audio import build_audio_loader
+from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
+from batdetect2.audio.clips import PaddedClipConfig
+from batdetect2.core import BaseConfig
 from batdetect2.core.arrays import adjust_width
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
     RandomAudioSource,
     build_augmentations,
 )
-from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.typing import (
     AudioLoader,
@@ -144,6 +147,22 @@ class ValidationDataset(Dataset):
         )


+class TrainLoaderConfig(BaseConfig):
+    num_workers: int = 0
+
+    batch_size: int = 8
+
+    shuffle: bool = False
+
+    augmentations: AugmentationsConfig = Field(
+        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
+    )
+
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: PaddedClipConfig()
+    )
+
+
 def build_train_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: Optional[AudioLoader] = None,
@@ -178,6 +197,14 @@ def build_train_loader(
     )


+class ValLoaderConfig(BaseConfig):
+    num_workers: int = 0
+
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: PaddedClipConfig()
+    )
+
+
 def build_val_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: Optional[AudioLoader] = None,
@@ -8,20 +8,21 @@ from loguru import logger
 from soundevent import data

 from batdetect2.audio import build_audio_loader
-from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.evaluator import build_evaluator
+from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
-from batdetect2.train.logging import build_logger

 if TYPE_CHECKING:
     from batdetect2.config import BatDetect2Config
     from batdetect2.typing import (
         AudioLoader,
         ClipLabeller,
+        EvaluatorProtocol,
         PreprocessorProtocol,
         TargetProtocol,
     )
@@ -122,7 +123,7 @@ def train(

 def build_trainer_callbacks(
     targets: "TargetProtocol",
-    evaluator: Optional[Evaluator] = None,
+    evaluator: Optional["EvaluatorProtocol"] = None,
     checkpoint_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
@@ -142,7 +143,6 @@ def build_trainer_callbacks(
         ModelCheckpoint(
             dirpath=str(checkpoint_dir),
             save_top_k=1,
-            filename="best-{epoch:02d}-{val_loss:.0f}",
             monitor="total_loss/val",
         ),
         ValidationMetrics(evaluator),
@@ -152,7 +152,7 @@ def build_trainer_callbacks(
 def build_trainer(
     conf: "BatDetect2Config",
     targets: "TargetProtocol",
-    evaluator: Optional[Evaluator] = None,
+    evaluator: Optional["EvaluatorProtocol"] = None,
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
@@ -1,7 +1,9 @@
 from batdetect2.typing.evaluate import (
     ClipEvaluation,
+    EvaluatorProtocol,
     MatchEvaluation,
     MetricsProtocol,
+    PlotterProtocol,
 )
 from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
 from batdetect2.typing.postprocess import (
@@ -49,6 +51,7 @@ __all__ = [
     "MatchEvaluation",
     "MetricsProtocol",
     "ModelOutput",
+    "PlotterProtocol",
     "Position",
     "PostprocessorProtocol",
     "PreprocessorProtocol",
@@ -60,4 +63,5 @@ __all__ = [
     "SoundEventFilter",
     "TargetProtocol",
     "TrainExample",
+    "EvaluatorProtocol",
 ]
@@ -14,7 +14,11 @@ from typing import (
 from matplotlib.figure import Figure
 from soundevent import data

+from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
+
 __all__ = [
+    "EvaluatorProtocol",
     "MetricsProtocol",
     "MatchEvaluation",
 ]
@@ -107,3 +111,21 @@ class PlotterProtocol(Protocol):
     def __call__(
         self, clip_evaluations: Sequence[ClipEvaluation]
     ) -> Iterable[Tuple[str, Figure]]: ...


+class EvaluatorProtocol(Protocol):
+    targets: TargetProtocol
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> List[ClipEvaluation]: ...
+
+    def compute_metrics(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]: ...
+
+    def generate_plots(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Iterable[Tuple[str, Figure]]: ...
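Because EvaluatorProtocol is a structural Protocol, anything exposing these three methods plus a targets attribute can be handed to EvaluationModule or ValidationMetrics; no subclassing of the concrete Evaluator is required. A minimal conforming stub, purely illustrative:

    from typing import Dict, Iterable, List, Sequence, Tuple

    from matplotlib.figure import Figure
    from soundevent import data

    from batdetect2.typing import ClipEvaluation, EvaluatorProtocol, RawPrediction


    class NullEvaluator:
        # Satisfies EvaluatorProtocol while doing nothing.

        def __init__(self, targets):
            self.targets = targets

        def evaluate(
            self,
            clip_annotations: Sequence[data.ClipAnnotation],
            predictions: Sequence[Sequence[RawPrediction]],
        ) -> List[ClipEvaluation]:
            return []

        def compute_metrics(
            self, clip_evaluations: Sequence[ClipEvaluation]
        ) -> Dict[str, float]:
            return {}

        def generate_plots(
            self, clip_evaluations: Sequence[ClipEvaluation]
        ) -> Iterable[Tuple[str, Figure]]:
            return []


    evaluator: EvaluatorProtocol = NullEvaluator(targets=None)  # stand-in targets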
@@ -12,7 +12,7 @@ system that deal with model predictions.
 """

 from dataclasses import dataclass
-from typing import List, NamedTuple, Optional, Protocol
+from typing import List, NamedTuple, Optional, Protocol, Sequence

 import numpy as np
 import torch
@@ -47,15 +47,13 @@ class GeometryDecoder(Protocol):


 class RawPrediction(NamedTuple):
-    """Intermediate representation of a single detected sound event."""
-
     geometry: data.Geometry
     detection_score: float
     class_scores: np.ndarray
     features: np.ndarray


-class DetectionsArray(NamedTuple):
+class ClipDetectionsArray(NamedTuple):
     scores: np.ndarray
     sizes: np.ndarray
     class_scores: np.ndarray
@@ -64,7 +62,7 @@ class ClipDetectionsArray(NamedTuple):
     features: np.ndarray


-class DetectionsTensor(NamedTuple):
+class ClipDetectionsTensor(NamedTuple):
     scores: torch.Tensor
     sizes: torch.Tensor
     class_scores: torch.Tensor
@@ -72,8 +70,8 @@ class ClipDetectionsTensor(NamedTuple):
     frequencies: torch.Tensor
     features: torch.Tensor

-    def numpy(self) -> DetectionsArray:
-        return DetectionsArray(
+    def numpy(self) -> ClipDetectionsArray:
+        return ClipDetectionsArray(
             scores=self.scores.detach().cpu().numpy(),
             sizes=self.sizes.detach().cpu().numpy(),
             class_scores=self.class_scores.detach().cpu().numpy(),
@@ -92,10 +90,8 @@ class BatDetect2Prediction:
 class PostprocessorProtocol(Protocol):
     """Protocol defining the interface for the full postprocessing pipeline."""

-    def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
-
-    def get_detections(
+    def __call__(
         self,
         output: ModelOutput,
-        start_times: Optional[List[float]] = None,
-    ) -> List[DetectionsTensor]: ...
+        start_times: Optional[Sequence[float]] = None,
+    ) -> List[ClipDetectionsTensor]: ...