Evaluate using Lightning too, to handle device changes

mbsantiago 2025-09-18 09:28:21 +01:00
parent 6c25787123
commit e65df81db2
27 changed files with 700 additions and 377 deletions

View File

@@ -5,7 +5,7 @@ from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
-from batdetect2.evaluate import Evaluator, build_evaluator
+from batdetect2.evaluate import build_evaluator, evaluate
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
@@ -14,6 +14,7 @@ from batdetect2.train import train
from batdetect2.train.lightning import load_model_from_checkpoint
from batdetect2.typing import (
AudioLoader,
+EvaluatorProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
@@ -28,7 +29,7 @@ class BatDetect2API:
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
-evaluator: Evaluator,
+evaluator: EvaluatorProtocol,
model: Model,
):
self.config = config
@@ -70,6 +71,27 @@ class BatDetect2API:
)
return self
def evaluate(
self,
test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None,
output_dir: data.PathLike = ".",
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
return evaluate(
self.model,
test_annotations,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
run_name=run_name,
)
@classmethod
def from_config(cls, config: BatDetect2Config):
targets = build_targets(config=config.targets)
@@ -118,8 +140,14 @@
)
@classmethod
-def from_checkpoint(cls, path: data.PathLike):
-model, config = load_model_from_checkpoint(path)
+def from_checkpoint(
+cls,
+path: data.PathLike,
+config: Optional[BatDetect2Config] = None,
+):
+model, stored_config = load_model_from_checkpoint(path)
+config = config or stored_config
targets = build_targets(config=config.targets)
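Taken together, the checkpoint loader and the new evaluate method give a two-call evaluation entry point. A minimal sketch (the checkpoint path and the annotation list are placeholders; the keyword arguments mirror the signature added above):

from batdetect2.api.base import BatDetect2API

api = BatDetect2API.from_checkpoint("checkpoints/best.ckpt")  # pass config=... to override the stored config
results = api.evaluate(
    test_annotations,                 # Sequence[data.ClipAnnotation], loaded elsewhere
    num_workers=4,
    output_dir="outputs/evaluation",
    experiment_name="batdetect2",
    run_name="test-set",
)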

View File

@@ -10,11 +10,17 @@ from batdetect2.cli.base import cli
__all__ = ["evaluate_command"]
+DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
-@click.option("--output-dir", type=click.Path())
-@click.option("--workers", type=int)
+@click.option("--config", "config_path", type=click.Path())
+@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
+@click.option("--experiment-name", type=str)
+@click.option("--run-name", type=str)
+@click.option("--workers", "num_workers", type=int)
@click.option(
"-v",
"--verbose",
@@ -24,13 +30,16 @@ __all__ = ["evaluate_command"]
def evaluate_command(
model_path: Path,
test_dataset: Path,
-output_dir: Optional[Path] = None,
-workers: Optional[int] = None,
+config_path: Optional[Path],
+output_dir: Path = DEFAULT_OUTPUT_DIR,
+num_workers: Optional[int] = None,
+experiment_name: Optional[str] = None,
+run_name: Optional[str] = None,
verbose: int = 0,
):
+from batdetect2.api.base import BatDetect2API
+from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config
-from batdetect2.evaluate.evaluate import evaluate
-from batdetect2.train.lightning import load_model_from_checkpoint
logger.remove()
if verbose == 0:
@@ -49,16 +58,16 @@ def evaluate_command(
num_annotations=len(test_annotations),
)
-model, train_config = load_model_from_checkpoint(model_path)
-df, results = evaluate(
-model,
+config = None
+if config_path is not None:
+config = load_full_config(config_path)
+api = BatDetect2API.from_checkpoint(model_path, config=config)
+api.evaluate(
test_annotations,
-config=train_config,
-num_workers=workers,
+num_workers=num_workers,
+output_dir=output_dir,
+experiment_name=experiment_name,
+run_name=run_name,
)
-print(results)
-if output_dir:
-df.to_csv(output_dir / "results.csv")
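For a quick check of the rewired command without a shell, click's test runner can drive it directly. A hedged sketch, assuming the command module is importable as batdetect2.cli.evaluate and the file paths exist:

from click.testing import CliRunner

from batdetect2.cli.evaluate import evaluate_command  # assumed module path

runner = CliRunner()
result = runner.invoke(
    evaluate_command,
    ["model.ckpt", "test_dataset.yaml", "--config", "config.yaml", "--workers", "2"],
)
print(result.exit_code, result.output)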

View File

@@ -2,6 +2,7 @@ import sys
from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel
+from typing_extensions import assert_type
if sys.version_info >= (3, 10):
from typing import ParamSpec
@@ -44,7 +45,6 @@ class Registry(Generic[T_Type, P_Type]):
config_cls: Type[T_Config],
logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
) -> None:
-"""A decorator factory to register a new item."""
fields = config_cls.model_fields
if "name" not in fields:

View File

@@ -1,9 +1,11 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
+from batdetect2.evaluate.evaluate import evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
+"evaluate",
"Evaluator",
"build_evaluator",
]

View File

@@ -11,6 +11,7 @@ from batdetect2.evaluate.metrics import (
MetricConfig,
)
from batdetect2.evaluate.plots import PlotConfig
+from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [
"EvaluationConfig",
@@ -28,6 +29,7 @@ class EvaluationConfig(BaseConfig):
]
)
plots: List[PlotConfig] = Field(default_factory=list)
+logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def load_evaluation_config(

View File

@@ -0,0 +1,144 @@
from typing import List, NamedTuple, Optional, Sequence
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipperProtocol,
PreprocessorProtocol,
)
__all__ = [
"TestDataset",
"build_test_dataset",
"build_test_loader",
]
class TestExample(NamedTuple):
spec: torch.Tensor
idx: torch.Tensor
start_time: torch.Tensor
end_time: torch.Tensor
class TestDataset(Dataset[TestExample]):
clip_annotations: List[data.ClipAnnotation]
def __init__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
clipper: Optional[ClipperProtocol] = None,
audio_dir: Optional[data.PathLike] = None,
):
self.clip_annotations = list(clip_annotations)
self.clipper = clipper
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_dir = audio_dir
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx: int) -> TestExample:
clip_annotation = self.clip_annotations[idx]
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
clip = clip_annotation.clip
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
return TestExample(
spec=spectrogram,
idx=torch.tensor(idx),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class TestLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_test_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader[TestExample]:
logger.info("Building test data loader...")
config = config or TestLoaderConfig()
logger.opt(lazy=True).debug(
"Test data loader config: \n{config}",
config=lambda: config.to_yaml_string(exclude_none=True),
)
test_dataset = build_test_dataset(
clip_annotations,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_test_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
) -> TestDataset:
logger.info("Building training dataset...")
config = config or TestLoaderConfig()
clipper = build_clipper(config=config.clipping_strategy)
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
return TestDataset(
clip_annotations,
audio_loader=audio_loader,
clipper=clipper,
preprocessor=preprocessor,
)
def _collate_fn(batch: List[TestExample]) -> TestExample:
max_width = max(item.spec.shape[-1] for item in batch)
return TestExample(
spec=torch.stack(
[adjust_width(item.spec, max_width) for item in batch]
),
idx=torch.stack([item.idx for item in batch]),
start_time=torch.stack([item.start_time for item in batch]),
end_time=torch.stack([item.end_time for item in batch]),
)
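A short sketch of using the new loader on its own (the annotation list is a placeholder; the audio loader and preprocessor fall back to the package defaults when omitted):

loader = build_test_loader(test_annotations, num_workers=2)
example = next(iter(loader))          # a TestExample; batch_size is fixed at 1
print(example.spec.shape, example.idx, example.start_time, example.end_time)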

View File

@@ -1,35 +1,44 @@
-from typing import List, Optional, Tuple
-import pandas as pd
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional, Sequence
+from lightning import Trainer
from soundevent import data
from batdetect2.audio import build_audio_loader
-from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.evaluate.dataframe import extract_matches_dataframe
+from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
-from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
+from batdetect2.evaluate.lightning import EvaluationModule
+from batdetect2.logging import build_logger
from batdetect2.models import Model
-from batdetect2.plotting.clips import AudioLoader, PreprocessorProtocol
-from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
-from batdetect2.train.dataset import ValidationDataset
-from batdetect2.train.labels import build_clip_labeler
-from batdetect2.train.train import build_val_loader
-from batdetect2.typing import ClipLabeller, TargetProtocol
+if TYPE_CHECKING:
+from batdetect2.config import BatDetect2Config
+from batdetect2.typing import (
+AudioLoader,
+PreprocessorProtocol,
+TargetProtocol,
+)
+DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"
def evaluate(
model: Model,
-test_annotations: List[data.ClipAnnotation],
-targets: Optional[TargetProtocol] = None,
-audio_loader: Optional[AudioLoader] = None,
-preprocessor: Optional[PreprocessorProtocol] = None,
-labeller: Optional[ClipLabeller] = None,
-config: Optional[EvaluationConfig] = None,
+test_annotations: Sequence[data.ClipAnnotation],
+targets: Optional["TargetProtocol"] = None,
+audio_loader: Optional["AudioLoader"] = None,
+preprocessor: Optional["PreprocessorProtocol"] = None,
+config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
-) -> Tuple[pd.DataFrame, dict]:
-config = config or EvaluationConfig()
+output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
+experiment_name: Optional[str] = None,
+run_name: Optional[str] = None,
+):
+from batdetect2.config import BatDetect2Config
+config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader()
@@ -39,60 +48,21 @@ def evaluate(
targets = targets or build_targets()
-labeller = labeller or build_clip_labeler(
-targets,
-min_freq=preprocessor.min_freq,
-max_freq=preprocessor.max_freq,
-)
-loader = build_val_loader(
+loader = build_test_loader(
test_annotations,
audio_loader=audio_loader,
-labeller=labeller,
preprocessor=preprocessor,
num_workers=num_workers,
)
-dataset: ValidationDataset = loader.dataset  # type: ignore
-clip_annotations = []
-predictions = []
-evaluator = build_evaluator(config=config, targets=targets)
-for batch in loader:
-outputs = model.detector(batch.spec)
-clip_annotations = [
-dataset.clip_annotations[int(example_idx)]
-for example_idx in batch.idx
-]
-predictions = get_raw_predictions(
-outputs,
-start_times=[
-clip_annotation.clip.start_time
-for clip_annotation in clip_annotations
-],
-targets=targets,
-postprocessor=model.postprocessor,
-)
-clip_annotations.extend(clip_annotations)
-predictions.extend(predictions)
-matches = evaluator.evaluate(clip_annotations, predictions)
-df = extract_matches_dataframe(matches)
-metrics = [
-DetectionAP(),
-ClassificationAP(class_names=targets.class_names),
-]
-results = {
-name: value
-for metric in metrics
-for name, value in metric(matches).items()
-}
-return df, results
+evaluator = build_evaluator(config=config.evaluation, targets=targets)
+logger = build_logger(
+config.evaluation.logger,
+log_dir=Path(output_dir),
+experiment_name=experiment_name,
+run_name=run_name,
+)
+module = EvaluationModule(model, evaluator)
+trainer = Trainer(logger=logger, enable_checkpointing=False)
+return trainer.test(module, loader)
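A hedged usage sketch of the rewritten entry point; because batches now flow through a Lightning Trainer, device placement is handled for the caller and no manual .to(device) calls are needed. Model, annotations, and config are placeholders:

results = evaluate(
    model,                          # a built batdetect2 Model
    test_annotations,
    config=config,                  # optional full BatDetect2Config
    output_dir="outputs/evaluations",
    experiment_name="batdetect2",
    run_name="lightning-eval",
)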

View File

@@ -11,6 +11,7 @@ from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing.evaluate import (
ClipEvaluation,
+EvaluatorProtocol,
MatcherProtocol,
MetricsProtocol,
PlotterProtocol,
@@ -135,7 +136,7 @@ def build_evaluator(
matcher: Optional[MatcherProtocol] = None,
plots: Optional[List[PlotterProtocol]] = None,
metrics: Optional[List[MetricsProtocol]] = None,
-) -> Evaluator:
+) -> EvaluatorProtocol:
config = config or EvaluationConfig()
targets = targets or build_targets()
matcher = matcher or build_matcher(config.match_strategy)

View File

@@ -0,0 +1,86 @@
from typing import Sequence
from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.tables import FullEvaluationTable
from batdetect2.logging import get_image_logger, get_table_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
class EvaluationModule(LightningModule):
def __init__(
self,
model: Model,
evaluator: EvaluatorProtocol,
):
super().__init__()
self.model = model
self.evaluator = evaluator
self.clip_evaluations = []
def test_step(self, batch: TestExample):
dataset = self.get_dataset()
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
for example_idx in batch.idx
]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
)
for clip_dets in clip_detections
]
self.clip_evaluations.extend(
self.evaluator.evaluate(clip_annotations, predictions)
)
def on_test_epoch_start(self):
self.clip_evaluations = []
def on_test_epoch_end(self):
self.log_metrics(self.clip_evaluations)
self.plot_examples(self.clip_evaluations)
self.log_table(self.clip_evaluations)
def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
table_logger = get_table_logger(self.logger) # type: ignore
if table_logger is None:
return
df = FullEvaluationTable()(evaluated_clips)
table_logger("full_evaluation", df, 0)
def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
plotter = get_image_logger(self.logger) # type: ignore
if plotter is None:
return
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, self.global_step)
def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
metrics = self.evaluator.compute_metrics(evaluated_clips)
self.log_dict(metrics)
def get_dataset(self) -> TestDataset:
dataloaders = self.trainer.test_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, TestDataset)
return dataset
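The module can also be driven manually; a minimal sketch, assuming a model, an evaluator, and the test loader built earlier are already in scope:

from lightning import Trainer

module = EvaluationModule(model, evaluator)
trainer = Trainer(logger=False, enable_checkpointing=False)  # the Trainer moves batches to the right device
trainer.test(module, dataloaders=loader)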

View File

@@ -15,10 +15,8 @@ import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
-from batdetect2.typing import MetricsProtocol
-from batdetect2.typing.evaluate import ClipEvaluation
+from batdetect2.core import BaseConfig, Registry
+from batdetect2.typing import ClipEvaluation, MetricsProtocol
__all__ = ["DetectionAP", "ClassificationAP"]
@@ -375,14 +373,96 @@
)
-class ClipAPConfig(BaseConfig):
-name: Literal["clip_ap"] = "clip_ap"
+class ClipDetectionAPConfig(BaseConfig):
+name: Literal["clip_detection_ap"] = "clip_detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class ClipDetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {"clip_detection_ap": self.metric(y_true, y_score)}
@classmethod
def from_config(
cls,
config: ClipDetectionAPConfig,
class_names: List[str],
):
return cls(implementation=config.ap_implementation)
metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
class ClipDetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {
"clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score))
}
@classmethod
def from_config(
cls,
config: ClipDetectionROCAUCConfig,
class_names: List[str],
):
return cls()
metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
class ClipMulticlassAPConfig(BaseConfig):
name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
-class ClipAP(MetricsProtocol):
+class ClipMulticlassAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
@@ -454,15 +534,17 @@ class ClipAP(MetricsProtocol):
[value for value in class_scores.values() if value != 0]
)
return {
-"clip_mAP": float(mean_ap),
+"clip_multiclass_mAP": float(mean_ap),
**{
-f"clip_AP/{class_name}": class_scores[class_name]
+f"clip_multiclass_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
-def from_config(cls, config: ClipAPConfig, class_names: List[str]):
+def from_config(
+cls, config: ClipMulticlassAPConfig, class_names: List[str]
+):
return cls(
implementation=config.ap_implementation,
include=config.include,
@@ -471,16 +553,16 @@ class ClipAP(MetricsProtocol):
)
-metrics_registry.register(ClipAPConfig, ClipAP)
+metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
-class ClipROCAUCConfig(BaseConfig):
-name: Literal["clip_roc_auc"] = "clip_roc_auc"
+class ClipMulticlassROCAUCConfig(BaseConfig):
+name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
-class ClipROCAUC(MetricsProtocol):
+class ClipMulticlassROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
@@ -548,9 +630,11 @@ class ClipROCAUC(MetricsProtocol):
[value for value in class_scores.values() if value != 0]
)
return {
-"clip_macro_ROC_AUC": float(mean_roc_auc),
+"clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
**{
-f"clip_ROC_AUC/{class_name}": class_scores[class_name]
+f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
+class_name
+]
for class_name in self.selected
},
}
@@ -558,7 +642,7 @@ class ClipROCAUC(MetricsProtocol):
@classmethod
def from_config(
cls,
-config: ClipROCAUCConfig,
+config: ClipMulticlassROCAUCConfig,
class_names: List[str],
):
return cls(
@@ -568,7 +652,7 @@ class ClipROCAUC(MetricsProtocol):
)
-metrics_registry.register(ClipROCAUCConfig, ClipROCAUC)
+metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)
MetricConfig = Annotated[
Union[
@@ -578,8 +662,10 @@ MetricConfig = Annotated[
ClassificationROCAUCConfig,
TopClassAPConfig,
ClassificationBalancedAccuracyConfig,
-ClipAPConfig,
-ClipROCAUCConfig,
+ClipDetectionAPConfig,
+ClipDetectionROCAUCConfig,
+ClipMulticlassAPConfig,
+ClipMulticlassROCAUCConfig,
],
Field(discriminator="name"),
]
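The renamed clip-level metrics keep the MetricsProtocol call signature, so they can be built from their configs and applied directly to a list of ClipEvaluation results; a sketch with placeholder class names and evaluations:

det_ap = ClipDetectionAP.from_config(ClipDetectionAPConfig(), class_names=class_names)
mc_ap = ClipMulticlassAP.from_config(ClipMulticlassAPConfig(), class_names=class_names)
scores = {**det_ap(clip_evaluations), **mc_ap(clip_evaluations)}
# keys follow the new naming, e.g. "clip_detection_ap", "clip_multiclass_mAP",
# plus one "clip_multiclass_AP/<class_name>" entry per selected class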

View File

@@ -10,19 +10,18 @@ from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
-from batdetect2.audio import AudioConfig
-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
-from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
+from batdetect2.audio import AudioConfig, build_audio_loader
+from batdetect2.core import BaseConfig, Registry
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.typing.evaluate import (
+from batdetect2.typing import (
+AudioLoader,
ClipEvaluation,
MatchEvaluation,
PlotterProtocol,
+PreprocessorProtocol,
)
-from batdetect2.typing.preprocess import AudioLoader
__all__ = [
"build_plotter",
@@ -431,6 +430,62 @@ class ClassificationROCCurves(PlotterProtocol):
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
class ConfusionMatrixConfig(BaseConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
background_class: str = "noise"
class ConfusionMatrix(PlotterProtocol):
def __init__(self, background_class: str, class_names: List[str]):
self.background_class = background_class
self.class_names = class_names
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else self.background_class
)
top_class = match.pred_class
y_pred.append(
top_class
if top_class is not None
else self.background_class
)
display = metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=[*self.class_names, self.background_class],
)
yield "confusion_matrix", display.figure_
@classmethod
def from_config(
cls,
config: ConfusionMatrixConfig,
class_names: List[str],
):
return cls(
background_class=config.background_class,
class_names=class_names,
)
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
PlotConfig = Annotated[
Union[
ExampleGalleryConfig,
@@ -439,6 +494,7 @@ PlotConfig = Annotated[
ClassificationPRCurvesConfig,
DetectionROCCurveConfig,
ClassificationROCCurvesConfig,
+ConfusionMatrixConfig,
],
Field(discriminator="name"),
]
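The new confusion-matrix plot follows the same PlotterProtocol as the existing plots, so it can be used standalone as well as through the registry; a sketch with placeholder class names and evaluations:

plotter = ConfusionMatrix.from_config(ConfusionMatrixConfig(), class_names=class_names)
for name, figure in plotter(clip_evaluations):
    figure.savefig(f"{name}.png")     # yields a single ("confusion_matrix", Figure) pair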

View File

@@ -1,18 +1,49 @@
-from typing import List
+from typing import Annotated, Callable, Literal, Sequence, Union
import pandas as pd
+from pydantic import Field
from soundevent.geometry import compute_bounds
-from batdetect2.typing.evaluate import ClipEvaluation
+from batdetect2.core import BaseConfig, Registry
+from batdetect2.typing import ClipEvaluation
+EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]
-def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
+tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
"evaluation_table"
)
class FullEvaluationTableConfig(BaseConfig):
name: Literal["full_evaluation"] = "full_evaluation"
class FullEvaluationTable:
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> pd.DataFrame:
return extract_matches_dataframe(clip_evaluations)
@classmethod
def from_config(cls, config: FullEvaluationTableConfig):
return cls()
tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
def extract_matches_dataframe(
clip_evaluations: Sequence[ClipEvaluation],
) -> pd.DataFrame:
data = []
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
-pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
+pred_start_time = pred_low_freq = pred_end_time = (
+pred_high_freq
+) = None
sound_event_annotation = match.sound_event_annotation
@@ -24,9 +55,12 @@ def extract_matches_dataframe(
)
if match.pred_geometry is not None:
-pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
-compute_bounds(match.pred_geometry)
-)
+(
+pred_start_time,
+pred_low_freq,
+pred_end_time,
+pred_high_freq,
+) = compute_bounds(match.pred_geometry)
data.append(
{
@@ -61,3 +95,14 @@ def extract_matches_dataframe(
df = pd.DataFrame(data)
df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
return df
EvaluationTableConfig = Annotated[
Union[FullEvaluationTableConfig,], Field(discriminator="name")
]
def build_table_generator(
config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
return tables_registry.build(config)
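A short sketch of the new table registry in use (clip_evaluations is a placeholder list of ClipEvaluation results):

table = build_table_generator(FullEvaluationTableConfig())
df = table(clip_evaluations)          # same frame as extract_matches_dataframe(...)
df.to_csv("full_evaluation.csv")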

View File

@@ -67,7 +67,7 @@ from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import (
-DetectionsTensor,
+ClipDetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
@@ -121,7 +121,7 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor
self.targets = targets
-def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
+def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)

View File

@@ -8,13 +8,10 @@ from batdetect2.postprocess.decoding import (
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
-from batdetect2.postprocess.nms import (
-non_max_suppression,
-)
+from batdetect2.postprocess.nms import non_max_suppression
from batdetect2.postprocess.postprocessor import (
Postprocessor,
build_postprocessor,
-get_raw_predictions,
)
__all__ = [
@@ -25,5 +22,4 @@ __all__ = [
"to_raw_predictions",
"load_postprocess_config",
"non_max_suppression",
-"get_raw_predictions",
]

View File

@@ -6,7 +6,7 @@ import numpy as np
from soundevent import data
from batdetect2.typing.postprocess import (
-DetectionsArray,
+ClipDetectionsArray,
RawPrediction,
)
from batdetect2.typing.targets import TargetProtocol
@@ -28,7 +28,7 @@ decoding.
def to_raw_predictions(
-detections: DetectionsArray,
+detections: ClipDetectionsArray,
targets: TargetProtocol,
) -> List[RawPrediction]:
predictions = []

View File

@@ -15,32 +15,25 @@ precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`.
"""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional
import torch
-from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
-from batdetect2.typing.postprocess import (
-DetectionsTensor,
-ModelOutput,
-)
+from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
-"extract_prediction_tensor",
+"extract_detection_peaks",
]
-def extract_prediction_tensor(
-output: ModelOutput,
+def extract_detection_peaks(
+detection_heatmap: torch.Tensor,
+size_heatmap: torch.Tensor,
+feature_heatmap: torch.Tensor,
+classification_heatmap: torch.Tensor,
max_detections: int = 200,
threshold: Optional[float] = None,
-nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
-) -> List[DetectionsTensor]:
-detection_heatmap = non_max_suppression(
-output.detection_probs.detach(),
-kernel_size=nms_kernel_size,
-)
+) -> List[ClipDetectionsTensor]:
height = detection_heatmap.shape[-2]
width = detection_heatmap.shape[-1]
@@ -53,9 +46,9 @@ def extract_prediction_tensor(
freqs = freqs.flatten().to(detection_heatmap.device)
times = times.flatten().to(detection_heatmap.device)
-output_size_preds = output.size_preds.detach()
-output_features = output.features.detach()
-output_class_probs = output.class_probs.detach()
+output_size_preds = size_heatmap.detach()
+output_features = feature_heatmap.detach()
+output_class_probs = classification_heatmap.detach()
predictions = []
for idx, item in enumerate(detection_heatmap):
@@ -65,23 +58,25 @@ def extract_prediction_tensor(
detection_scores = item.take(indices)
detection_freqs = freqs.take(indices)
detection_times = times.take(indices)
-sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
-features = output_features[idx, :, detection_freqs, detection_times].T
-class_scores = output_class_probs[
-idx, :, detection_freqs, detection_times
-].T
if threshold is not None:
mask = detection_scores >= threshold
detection_scores = detection_scores[mask]
-sizes = sizes[mask]
detection_times = detection_times[mask]
detection_freqs = detection_freqs[mask]
-features = features[mask]
-class_scores = class_scores[mask]
+sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
+features = output_features[idx, :, detection_freqs, detection_times].T
+class_scores = output_class_probs[
+idx,
+:,
+detection_freqs,
+detection_times,
+].T
predictions.append(
-DetectionsTensor(
+ClipDetectionsTensor(
scores=detection_scores,
sizes=sizes,
features=features,

View File

@@ -1,29 +1,20 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union
import torch
from loguru import logger
-from soundevent import data
from batdetect2.postprocess.config import (
PostprocessConfig,
)
-from batdetect2.postprocess.decoding import (
-DEFAULT_CLASSIFICATION_THRESHOLD,
-convert_raw_prediction_to_sound_event_prediction,
-convert_raw_predictions_to_clip_prediction,
-to_raw_predictions,
-)
-from batdetect2.postprocess.extraction import extract_prediction_tensor
+from batdetect2.postprocess.extraction import extract_detection_peaks
+from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
-BatDetect2Prediction,
-DetectionsTensor,
+ClipDetectionsTensor,
PostprocessorProtocol,
-RawPrediction,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
-from batdetect2.typing.targets import TargetProtocol
__all__ = [
"build_postprocessor",
@@ -60,24 +51,43 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
+nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
):
"""Initialize the Postprocessor."""
super().__init__()
-self.samplerate = samplerate
+self.output_samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
+self.nms_kernel_size = nms_kernel_size
+def forward(
+self,
+output: ModelOutput,
+start_times: Optional[List[float]] = None,
+) -> List[ClipDetectionsTensor]:
+detection_heatmap = non_max_suppression(
+output.detection_probs.detach(),
+kernel_size=self.nms_kernel_size,
+)
-def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
-duration = width / self.samplerate
+duration = width / self.output_samplerate
max_detections = int(self.top_k_per_sec * duration)
-detections = extract_prediction_tensor(
-output,
+detections = extract_detection_peaks(
+detection_heatmap,
+size_heatmap=output.size_preds,
+feature_heatmap=output.features,
+classification_heatmap=output.class_probs,
max_detections=max_detections,
threshold=self.detection_threshold,
)
+if start_times is None:
+start_times = [0 for _ in range(len(detections))]
return [
map_detection_to_clip(
detection,
@@ -88,121 +98,3 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
)
for detection in detections
]
def get_detections(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
if start_times is None:
return detections
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
return [
map_detection_to_clip(
detection,
start_time=start_time,
end_time=start_time + duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, start_time in zip(detections, start_times)
]
def get_raw_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
start_times: Optional[List[float]] = None,
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch."""
detections = postprocessor.get_detections(output, start_times)
return [
to_raw_predictions(detection.numpy(), targets=targets)
for detection in detections
]
def get_sound_event_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
clips: List[data.Clip],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=targets,
classification_threshold=classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=targets,
classification_threshold=classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]
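With the helper functions above removed, callers invoke the postprocessor directly and decode afterwards, as the updated validation callback and EvaluationModule do. A sketch with placeholder model output and clips (to_raw_predictions comes from batdetect2.postprocess):

outputs = model.detector(spec)        # raw ModelOutput for a batch
detections = model.postprocessor(outputs, start_times=[clip.start_time for clip in clips])
raw = [to_raw_predictions(det.numpy(), targets=model.targets) for det in detections]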

View File

@@ -20,7 +20,7 @@ import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
-from batdetect2.typing.postprocess import DetectionsTensor
+from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
"features_to_xarray",
@@ -31,15 +31,15 @@ __all__ = [
def map_detection_to_clip(
-detections: DetectionsTensor,
+detections: ClipDetectionsTensor,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
-) -> DetectionsTensor:
+) -> ClipDetectionsTensor:
duration = end_time - start_time
bandwidth = max_freq - min_freq
-return DetectionsTensor(
+return ClipDetectionsTensor(
scores=detections.scores,
sizes=detections.sizes,
features=detections.features,

View File

@@ -14,7 +14,6 @@ from batdetect2.train.augmentations import (
scale_volume,
warp_spectrogram,
)
-from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import (
PLTrainerConfig,
TrainingConfig,
@@ -61,7 +60,6 @@ __all__ = [
"add_echo",
"build_augmentations",
"build_clip_labeler",
-"build_clipper",
"build_loss",
"build_train_dataset",
"build_train_loader",
@@ -74,7 +72,6 @@ __all__ = [
"mask_time",
"mix_audio",
"scale_volume",
-"select_subclip",
"train",
"warp_spectrogram",
]

View File

@@ -11,9 +11,9 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry
+from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.train.clips import get_subclip_annotation
from batdetect2.typing import AudioLoader, Augmentation
__all__ = [

View File

@@ -5,13 +5,13 @@ from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader
-from batdetect2.evaluate import Evaluator
-from batdetect2.postprocess import get_raw_predictions
+from batdetect2.logging import get_image_logger
+from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
-from batdetect2.train.logging import get_image_plotter
from batdetect2.typing import (
ClipEvaluation,
+EvaluatorProtocol,
ModelOutput,
RawPrediction,
TrainExample,
@@ -19,7 +19,7 @@ from batdetect2.typing import (
class ValidationMetrics(Callback):
-def __init__(self, evaluator: Evaluator):
+def __init__(self, evaluator: EvaluatorProtocol):
super().__init__()
self.evaluator = evaluator
@@ -34,12 +34,12 @@ class ValidationMetrics(Callback):
assert isinstance(dataset, ValidationDataset)
return dataset
-def plot_examples(
+def generate_plots(
self,
pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
):
-plotter = get_image_plotter(pl_module.logger)  # type: ignore
+plotter = get_image_logger(pl_module.logger)  # type: ignore
if plotter is None:
return
@@ -66,7 +66,7 @@ class ValidationMetrics(Callback):
)
self.log_metrics(pl_module, clip_evaluations)
-self.plot_examples(pl_module, clip_evaluations)
+self.generate_plots(pl_module, clip_evaluations)
return super().on_validation_epoch_end(trainer, pl_module)
@@ -88,8 +88,7 @@
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
-postprocessor = pl_module.model.postprocessor
-targets = pl_module.model.targets
+model = pl_module.model
dataset = self.get_dataset(trainer)
clip_annotations = [
@@ -97,15 +96,14 @@
for example_idx in batch.idx
]
-predictions = get_raw_predictions(
+clip_detections = model.postprocessor(
outputs,
-start_times=[
-clip_annotation.clip.start_time
-for clip_annotation in clip_annotations
-],
-targets=targets,
-postprocessor=postprocessor,
+start_times=[ca.clip.start_time for ca in clip_annotations],
)
+predictions = [
+to_raw_predictions(clip_dets.numpy(), targets=model.targets)
+for clip_dets in clip_detections
+]
self._clip_annotations.extend(clip_annotations)
self._predictions.extend(predictions)

View File

@@ -5,17 +5,9 @@ from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.train.augmentations import (
-DEFAULT_AUGMENTATION_CONFIG,
-AugmentationsConfig,
-)
-from batdetect2.train.clips import (
-ClipConfig,
-PaddedClipConfig,
-RandomClipConfig,
-)
+from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
+from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import LabelConfig
-from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
from batdetect2.train.losses import LossConfig
__all__ = [
@@ -45,30 +37,6 @@ class PLTrainerConfig(BaseConfig):
val_check_interval: Optional[Union[int, float]] = None
-class ValLoaderConfig(BaseConfig):
-num_workers: int = 0
-clipping_strategy: ClipConfig = Field(
-default_factory=lambda: PaddedClipConfig()
-)
-class TrainLoaderConfig(BaseConfig):
-num_workers: int = 0
-batch_size: int = 8
-shuffle: bool = False
-augmentations: AugmentationsConfig = Field(
-default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
-)
-clipping_strategy: ClipConfig = Field(
-default_factory=lambda: PaddedClipConfig()
-)
class OptimizerConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
@@ -79,9 +47,8 @@ class TrainingConfig(BaseConfig):
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
loss: LossConfig = Field(default_factory=LossConfig)
-cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
-logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
+logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)

View File

@@ -2,18 +2,21 @@ from typing import List, Optional, Sequence
import torch
from loguru import logger
+from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset
-from batdetect2.audio import build_audio_loader
+from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
+from batdetect2.audio.clips import PaddedClipConfig
+from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.train.augmentations import (
+DEFAULT_AUGMENTATION_CONFIG,
+AugmentationsConfig,
RandomAudioSource,
build_augmentations,
)
-from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import (
AudioLoader,
@@ -144,6 +147,22 @@ class ValidationDataset(Dataset):
)
class TrainLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
shuffle: bool = False
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
)
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_train_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
@@ -178,6 +197,14 @@ def build_train_loader(
)
class ValLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_val_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,

View File

@@ -8,20 +8,21 @@ from loguru import logger
from soundevent import data
from batdetect2.audio import build_audio_loader
-from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.evaluator import build_evaluator
+from batdetect2.logging import build_logger
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
-from batdetect2.train.logging import build_logger
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
ClipLabeller,
+EvaluatorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
@@ -122,7 +123,7 @@ def train(
def build_trainer_callbacks(
targets: "TargetProtocol",
-evaluator: Optional[Evaluator] = None,
+evaluator: Optional["EvaluatorProtocol"] = None,
checkpoint_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
@@ -142,7 +143,6 @@ def build_trainer_callbacks(
ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=1,
-filename="best-{epoch:02d}-{val_loss:.0f}",
monitor="total_loss/val",
),
ValidationMetrics(evaluator),
@@ -152,7 +152,7 @@
def build_trainer(
conf: "BatDetect2Config",
targets: "TargetProtocol",
-evaluator: Optional[Evaluator] = None,
+evaluator: Optional["EvaluatorProtocol"] = None,
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,

View File

@@ -1,7 +1,9 @@
from batdetect2.typing.evaluate import (
ClipEvaluation,
+EvaluatorProtocol,
MatchEvaluation,
MetricsProtocol,
+PlotterProtocol,
)
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
@@ -49,6 +51,7 @@ __all__ = [
"MatchEvaluation",
"MetricsProtocol",
"ModelOutput",
+"PlotterProtocol",
"Position",
"PostprocessorProtocol",
"PreprocessorProtocol",
@@ -60,4 +63,5 @@ __all__ = [
"SoundEventFilter",
"TargetProtocol",
"TrainExample",
+"EvaluatorProtocol",
]

View File

@@ -14,7 +14,11 @@ from typing import (
from matplotlib.figure import Figure
from soundevent import data
+from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
__all__ = [
+"EvaluatorProtocol",
"MetricsProtocol",
"MatchEvaluation",
]
@@ -107,3 +111,21 @@ class PlotterProtocol(Protocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol):
targets: TargetProtocol
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipEvaluation]: ...
def compute_metrics(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]: ...
def generate_plots(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...
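Any object with this shape satisfies the protocol and can be passed to EvaluationModule or ValidationMetrics; a hedged sketch of a trivial stand-in (the class and its return values are made up for illustration):

class NullEvaluator:
    """Placeholder evaluator used only to illustrate the protocol shape."""

    def __init__(self, targets: TargetProtocol):
        self.targets = targets

    def evaluate(self, clip_annotations, predictions):
        return []                     # no ClipEvaluation results

    def compute_metrics(self, clip_evaluations):
        return {"num_clips": float(len(clip_evaluations))}

    def generate_plots(self, clip_evaluations):
        return []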

View File

@@ -12,7 +12,7 @@ system that deal with model predictions.
"""
from dataclasses import dataclass
-from typing import List, NamedTuple, Optional, Protocol
+from typing import List, NamedTuple, Optional, Protocol, Sequence
import numpy as np
import torch
@@ -47,15 +47,13 @@ class GeometryDecoder(Protocol):
class RawPrediction(NamedTuple):
-"""Intermediate representation of a single detected sound event."""
geometry: data.Geometry
detection_score: float
class_scores: np.ndarray
features: np.ndarray
-class DetectionsArray(NamedTuple):
+class ClipDetectionsArray(NamedTuple):
scores: np.ndarray
sizes: np.ndarray
class_scores: np.ndarray
@@ -64,7 +62,7 @@ class DetectionsArray(NamedTuple):
features: np.ndarray
-class DetectionsTensor(NamedTuple):
+class ClipDetectionsTensor(NamedTuple):
scores: torch.Tensor
sizes: torch.Tensor
class_scores: torch.Tensor
@@ -72,8 +70,8 @@ class DetectionsTensor(NamedTuple):
frequencies: torch.Tensor
features: torch.Tensor
-def numpy(self) -> DetectionsArray:
-return DetectionsArray(
+def numpy(self) -> ClipDetectionsArray:
+return ClipDetectionsArray(
scores=self.scores.detach().cpu().numpy(),
sizes=self.sizes.detach().cpu().numpy(),
class_scores=self.class_scores.detach().cpu().numpy(),
@@ -92,10 +90,8 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline."""
-def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
-def get_detections(
+def __call__(
self,
output: ModelOutput,
-start_times: Optional[List[float]] = None,
+start_times: Optional[Sequence[float]] = None,
-) -> List[DetectionsTensor]: ...
+) -> List[ClipDetectionsTensor]: ...