Evaluate using Lightning too, to handle device changes

mbsantiago 2025-09-18 09:28:21 +01:00
parent 6c25787123
commit e65df81db2
27 changed files with 700 additions and 377 deletions

View File

@ -5,7 +5,7 @@ from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import Evaluator, build_evaluator
from batdetect2.evaluate import build_evaluator, evaluate
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
@ -14,6 +14,7 @@ from batdetect2.train import train
from batdetect2.train.lightning import load_model_from_checkpoint
from batdetect2.typing import (
AudioLoader,
EvaluatorProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
@ -28,7 +29,7 @@ class BatDetect2API:
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
evaluator: Evaluator,
evaluator: EvaluatorProtocol,
model: Model,
):
self.config = config
@ -70,6 +71,27 @@ class BatDetect2API:
)
return self
def evaluate(
self,
test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None,
output_dir: data.PathLike = ".",
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
return evaluate(
self.model,
test_annotations,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
run_name=run_name,
)
@classmethod
def from_config(cls, config: BatDetect2Config):
targets = build_targets(config=config.targets)
@ -118,8 +140,14 @@ class BatDetect2API:
)
@classmethod
def from_checkpoint(cls, path: data.PathLike):
model, config = load_model_from_checkpoint(path)
def from_checkpoint(
cls,
path: data.PathLike,
config: Optional[BatDetect2Config] = None,
):
model, stored_config = load_model_from_checkpoint(path)
config = config or stored_config
targets = build_targets(config=config.targets)
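
A minimal usage sketch of the new `evaluate` method added to the API above. The checkpoint path, annotations, and experiment/run names are illustrative; the method delegates to the Lightning-based `evaluate` function further down in this diff, so device placement is handled by the Trainer.

from batdetect2.api.base import BatDetect2API

api = BatDetect2API.from_checkpoint("checkpoints/best.ckpt")  # hypothetical path
test_annotations = [...]  # a prepared Sequence[data.ClipAnnotation]

# Moving the model and batches to the right device is handled by Lightning,
# which is the motivation stated in the commit message.
api.evaluate(
    test_annotations,
    num_workers=4,
    output_dir="outputs/evaluation",
    experiment_name="bats",   # illustrative
    run_name="baseline",      # illustrative
)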

View File

@ -10,11 +10,17 @@ from batdetect2.cli.base import cli
__all__ = ["evaluate_command"]
DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--output-dir", type=click.Path())
@click.option("--workers", type=int)
@click.option("--config", "config_path", type=click.Path())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int)
@click.option(
"-v",
"--verbose",
@ -24,13 +30,16 @@ __all__ = ["evaluate_command"]
def evaluate_command(
model_path: Path,
test_dataset: Path,
output_dir: Optional[Path] = None,
workers: Optional[int] = None,
config_path: Optional[Path],
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: Optional[int] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api.base import BatDetect2API
from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
logger.remove()
if verbose == 0:
@ -49,16 +58,16 @@ def evaluate_command(
num_annotations=len(test_annotations),
)
model, train_config = load_model_from_checkpoint(model_path)
config = None
if config_path is not None:
config = load_full_config(config_path)
df, results = evaluate(
model,
api = BatDetect2API.from_checkpoint(model_path, config=config)
api.evaluate(
test_annotations,
config=train_config,
num_workers=workers,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
run_name=run_name,
)
print(results)
if output_dir:
df.to_csv(output_dir / "results.csv")
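
A hedged sketch of the config-override path the updated command now uses: a configuration file, when supplied, replaces the one stored in the checkpoint before the API object is built. The file paths are illustrative, and the executable name in the comment is assumed.

from batdetect2.api.base import BatDetect2API
from batdetect2.config import load_full_config

# Roughly equivalent CLI call (executable name assumed):
#   batdetect2 evaluate model.ckpt test_dataset.yaml --config configs/eval.yaml
config = load_full_config("configs/eval.yaml")  # omit to keep the stored config
api = BatDetect2API.from_checkpoint("model.ckpt", config=config)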

View File

@ -2,6 +2,7 @@ import sys
from typing import Generic, Protocol, Type, TypeVar
from pydantic import BaseModel
from typing_extensions import assert_type
if sys.version_info >= (3, 10):
from typing import ParamSpec
@ -44,7 +45,6 @@ class Registry(Generic[T_Type, P_Type]):
config_cls: Type[T_Config],
logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
) -> None:
"""A decorator factory to register a new item."""
fields = config_cls.model_fields
if "name" not in fields:

View File

@ -1,9 +1,11 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
__all__ = [
"EvaluationConfig",
"load_evaluation_config",
"evaluate",
"Evaluator",
"build_evaluator",
]

View File

@ -11,6 +11,7 @@ from batdetect2.evaluate.metrics import (
MetricConfig,
)
from batdetect2.evaluate.plots import PlotConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [
"EvaluationConfig",
@ -28,6 +29,7 @@ class EvaluationConfig(BaseConfig):
]
)
plots: List[PlotConfig] = Field(default_factory=list)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def load_evaluation_config(

View File

@ -0,0 +1,144 @@
from typing import List, NamedTuple, Optional, Sequence
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import (
AudioLoader,
ClipperProtocol,
PreprocessorProtocol,
)
__all__ = [
"TestDataset",
"build_test_dataset",
"build_test_loader",
]
class TestExample(NamedTuple):
spec: torch.Tensor
idx: torch.Tensor
start_time: torch.Tensor
end_time: torch.Tensor
class TestDataset(Dataset[TestExample]):
clip_annotations: List[data.ClipAnnotation]
def __init__(
self,
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
clipper: Optional[ClipperProtocol] = None,
audio_dir: Optional[data.PathLike] = None,
):
self.clip_annotations = list(clip_annotations)
self.clipper = clipper
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_dir = audio_dir
def __len__(self):
return len(self.clip_annotations)
def __getitem__(self, idx: int) -> TestExample:
clip_annotation = self.clip_annotations[idx]
if self.clipper is not None:
clip_annotation = self.clipper(clip_annotation)
clip = clip_annotation.clip
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
return TestExample(
spec=spectrogram,
idx=torch.tensor(idx),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class TestLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_test_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader[TestExample]:
logger.info("Building test data loader...")
config = config or TestLoaderConfig()
logger.opt(lazy=True).debug(
"Test data loader config: \n{config}",
config=lambda: config.to_yaml_string(exclude_none=True),
)
test_dataset = build_test_dataset(
clip_annotations,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_test_dataset(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[TestLoaderConfig] = None,
) -> TestDataset:
logger.info("Building training dataset...")
config = config or TestLoaderConfig()
clipper = build_clipper(config=config.clipping_strategy)
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
return TestDataset(
clip_annotations,
audio_loader=audio_loader,
clipper=clipper,
preprocessor=preprocessor,
)
def _collate_fn(batch: List[TestExample]) -> TestExample:
max_width = max(item.spec.shape[-1] for item in batch)
return TestExample(
spec=torch.stack(
[adjust_width(item.spec, max_width) for item in batch]
),
idx=torch.stack([item.idx for item in batch]),
start_time=torch.stack([item.start_time for item in batch]),
end_time=torch.stack([item.end_time for item in batch]),
)
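
A self-contained illustration of what _collate_fn above does when a batch holds more than one example: spectrograms may differ in time width, so each one is padded to the widest example before stacking. adjust_width from batdetect2 plays the role of torch.nn.functional.pad in this sketch.

import torch
import torch.nn.functional as F

specs = [torch.rand(1, 128, 60), torch.rand(1, 128, 75)]  # (channels, freq, time)
max_width = max(s.shape[-1] for s in specs)
# Pad each spectrogram on the right of the time axis, then stack into a batch.
batch = torch.stack([F.pad(s, (0, max_width - s.shape[-1])) for s in specs])
print(batch.shape)  # torch.Size([2, 1, 128, 75])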

View File

@ -1,35 +1,44 @@
from typing import List, Optional, Tuple
from pathlib import Path
from typing import TYPE_CHECKING, Optional, Sequence
import pandas as pd
from lightning import Trainer
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.plotting.clips import AudioLoader, PreprocessorProtocol
from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_val_loader
from batdetect2.typing import ClipLabeller, TargetProtocol
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
TargetProtocol,
)
DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"
def evaluate(
model: Model,
test_annotations: List[data.ClipAnnotation],
targets: Optional[TargetProtocol] = None,
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
labeller: Optional[ClipLabeller] = None,
config: Optional[EvaluationConfig] = None,
test_annotations: Sequence[data.ClipAnnotation],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
config = config or EvaluationConfig()
output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader()
@ -39,60 +48,21 @@ def evaluate(
targets = targets or build_targets()
labeller = labeller or build_clip_labeler(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
loader = build_val_loader(
loader = build_test_loader(
test_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=preprocessor,
num_workers=num_workers,
)
dataset: ValidationDataset = loader.dataset # type: ignore
evaluator = build_evaluator(config=config.evaluation, targets=targets)
clip_annotations = []
predictions = []
evaluator = build_evaluator(config=config, targets=targets)
for batch in loader:
outputs = model.detector(batch.spec)
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
for example_idx in batch.idx
]
predictions = get_raw_predictions(
outputs,
start_times=[
clip_annotation.clip.start_time
for clip_annotation in clip_annotations
],
targets=targets,
postprocessor=model.postprocessor,
)
clip_annotations.extend(clip_annotations)
predictions.extend(predictions)
matches = evaluator.evaluate(clip_annotations, predictions)
df = extract_matches_dataframe(matches)
metrics = [
DetectionAP(),
ClassificationAP(class_names=targets.class_names),
]
results = {
name: value
for metric in metrics
for name, value in metric(matches).items()
}
return df, results
logger = build_logger(
config.evaluation.logger,
log_dir=Path(output_dir),
experiment_name=experiment_name,
run_name=run_name,
)
module = EvaluationModule(model, evaluator)
trainer = Trainer(logger=logger, enable_checkpointing=False)
return trainer.test(module, loader)

View File

@ -11,6 +11,7 @@ from batdetect2.evaluate.plots import build_plotter
from batdetect2.targets import build_targets
from batdetect2.typing.evaluate import (
ClipEvaluation,
EvaluatorProtocol,
MatcherProtocol,
MetricsProtocol,
PlotterProtocol,
@ -135,7 +136,7 @@ def build_evaluator(
matcher: Optional[MatcherProtocol] = None,
plots: Optional[List[PlotterProtocol]] = None,
metrics: Optional[List[MetricsProtocol]] = None,
) -> Evaluator:
) -> EvaluatorProtocol:
config = config or EvaluationConfig()
targets = targets or build_targets()
matcher = matcher or build_matcher(config.match_strategy)

View File

@ -0,0 +1,86 @@
from typing import Sequence
from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.tables import FullEvaluationTable
from batdetect2.logging import get_image_logger, get_table_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
class EvaluationModule(LightningModule):
def __init__(
self,
model: Model,
evaluator: EvaluatorProtocol,
):
super().__init__()
self.model = model
self.evaluator = evaluator
self.clip_evaluations = []
def test_step(self, batch: TestExample):
dataset = self.get_dataset()
clip_annotations = [
dataset.clip_annotations[int(example_idx)]
for example_idx in batch.idx
]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
to_raw_predictions(
clip_dets.numpy(),
targets=self.evaluator.targets,
)
for clip_dets in clip_detections
]
self.clip_evaluations.extend(
self.evaluator.evaluate(clip_annotations, predictions)
)
def on_test_epoch_start(self):
self.clip_evaluations = []
def on_test_epoch_end(self):
self.log_metrics(self.clip_evaluations)
self.plot_examples(self.clip_evaluations)
self.log_table(self.clip_evaluations)
def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
table_logger = get_table_logger(self.logger) # type: ignore
if table_logger is None:
return
df = FullEvaluationTable()(evaluated_clips)
table_logger("full_evaluation", df, 0)
def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
plotter = get_image_logger(self.logger) # type: ignore
if plotter is None:
return
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, self.global_step)
def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
metrics = self.evaluator.compute_metrics(evaluated_clips)
self.log_dict(metrics)
def get_dataset(self) -> TestDataset:
dataloaders = self.trainer.test_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, TestDataset)
return dataset
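
A hedged sketch of how this module is driven, mirroring the reworked evaluate() earlier in the diff: Lightning's Trainer moves the model and each TestExample batch to the selected device before test_step runs, which is the point of this commit. The model, evaluator, logger, and loader are built as shown in the other hunks.

from lightning import Trainer

module = EvaluationModule(model, evaluator)           # names from the hunks above
trainer = Trainer(logger=logger, enable_checkpointing=False)
results = trainer.test(module, loader)                # list with one dict of the logged metrics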

View File

@ -15,10 +15,8 @@ import numpy as np
from pydantic import Field
from sklearn import metrics, preprocessing
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.typing import MetricsProtocol
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation, MetricsProtocol
__all__ = ["DetectionAP", "ClassificationAP"]
@ -375,14 +373,96 @@ metrics_registry.register(
)
class ClipAPConfig(BaseConfig):
name: Literal["clip_ap"] = "clip_ap"
class ClipDetectionAPConfig(BaseConfig):
name: Literal["clip_detection_ap"] = "clip_detection_ap"
ap_implementation: APImplementation = "pascal_voc"
class ClipDetectionAP(MetricsProtocol):
def __init__(
self,
implementation: APImplementation,
):
self.implementation = implementation
self.metric = _ap_impl_mapping[self.implementation]
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {"clip_detection_ap": self.metric(y_true, y_score)}
@classmethod
def from_config(
cls,
config: ClipDetectionAPConfig,
class_names: List[str],
):
return cls(implementation=config.ap_implementation)
metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
class ClipDetectionROCAUCConfig(BaseConfig):
name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
class ClipDetectionROCAUC(MetricsProtocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]:
y_true = []
y_score = []
for clip_eval in clip_evaluations:
clip_det = []
clip_scores = []
for match in clip_eval.matches:
clip_det.append(match.gt_det)
clip_scores.append(match.pred_score)
y_true.append(any(clip_det))
y_score.append(max(clip_scores or [0]))
return {
"clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score))
}
@classmethod
def from_config(
cls,
config: ClipDetectionROCAUCConfig,
class_names: List[str],
):
return cls()
metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
class ClipMulticlassAPConfig(BaseConfig):
name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
ap_implementation: APImplementation = "pascal_voc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipAP(MetricsProtocol):
class ClipMulticlassAP(MetricsProtocol):
def __init__(
self,
class_names: List[str],
@ -454,15 +534,17 @@ class ClipAP(MetricsProtocol):
[value for value in class_scores.values() if value != 0]
)
return {
"clip_mAP": float(mean_ap),
"clip_multiclass_mAP": float(mean_ap),
**{
f"clip_AP/{class_name}": class_scores[class_name]
f"clip_multiclass_AP/{class_name}": class_scores[class_name]
for class_name in self.selected
},
}
@classmethod
def from_config(cls, config: ClipAPConfig, class_names: List[str]):
def from_config(
cls, config: ClipMulticlassAPConfig, class_names: List[str]
):
return cls(
implementation=config.ap_implementation,
include=config.include,
@ -471,16 +553,16 @@ class ClipAP(MetricsProtocol):
)
metrics_registry.register(ClipAPConfig, ClipAP)
metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)
class ClipROCAUCConfig(BaseConfig):
name: Literal["clip_roc_auc"] = "clip_roc_auc"
class ClipMulticlassROCAUCConfig(BaseConfig):
name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
include: Optional[List[str]] = None
exclude: Optional[List[str]] = None
class ClipROCAUC(MetricsProtocol):
class ClipMulticlassROCAUC(MetricsProtocol):
def __init__(
self,
class_names: List[str],
@ -548,9 +630,11 @@ class ClipROCAUC(MetricsProtocol):
[value for value in class_scores.values() if value != 0]
)
return {
"clip_macro_ROC_AUC": float(mean_roc_auc),
"clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
**{
f"clip_ROC_AUC/{class_name}": class_scores[class_name]
f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
class_name
]
for class_name in self.selected
},
}
@ -558,7 +642,7 @@ class ClipROCAUC(MetricsProtocol):
@classmethod
def from_config(
cls,
config: ClipROCAUCConfig,
config: ClipMulticlassROCAUCConfig,
class_names: List[str],
):
return cls(
@ -568,7 +652,7 @@ class ClipROCAUC(MetricsProtocol):
)
metrics_registry.register(ClipROCAUCConfig, ClipROCAUC)
metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)
MetricConfig = Annotated[
Union[
@ -578,8 +662,10 @@ MetricConfig = Annotated[
ClassificationROCAUCConfig,
TopClassAPConfig,
ClassificationBalancedAccuracyConfig,
ClipAPConfig,
ClipROCAUCConfig,
ClipDetectionAPConfig,
ClipDetectionROCAUCConfig,
ClipMulticlassAPConfig,
ClipMulticlassROCAUCConfig,
],
Field(discriminator="name"),
]
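
To make the new clip-level metrics concrete, here is a self-contained illustration of the reduction ClipDetectionAP and ClipDetectionROCAUC perform: every clip collapses to one (has-any-ground-truth, max-prediction-score) pair. sklearn's average_precision_score stands in here for whichever AP implementation is configured; the values are made up.

from sklearn import metrics

# One entry per clip: whether the clip contained any ground-truth call, and
# the highest detection score predicted anywhere in that clip.
y_true = [True, False, True, False]
y_score = [0.9, 0.2, 0.4, 0.7]

print(metrics.average_precision_score(y_true, y_score))  # clip-level detection AP
print(metrics.roc_auc_score(y_true, y_score))            # clip-level detection ROC AUC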

View File

@ -10,19 +10,18 @@ from pydantic import Field
from sklearn import metrics
from sklearn.preprocessing import label_binarize
from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.core import BaseConfig, Registry
from batdetect2.plotting.gallery import plot_match_gallery
from batdetect2.plotting.matches import plot_matches
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.typing.evaluate import (
from batdetect2.typing import (
AudioLoader,
ClipEvaluation,
MatchEvaluation,
PlotterProtocol,
PreprocessorProtocol,
)
from batdetect2.typing.preprocess import AudioLoader
__all__ = [
"build_plotter",
@ -431,6 +430,62 @@ class ClassificationROCCurves(PlotterProtocol):
plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)
class ConfusionMatrixConfig(BaseConfig):
name: Literal["confusion_matrix"] = "confusion_matrix"
background_class: str = "noise"
class ConfusionMatrix(PlotterProtocol):
def __init__(self, background_class: str, class_names: List[str]):
self.background_class = background_class
self.class_names = class_names
def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
y_true = []
y_pred = []
for clip_eval in clip_evaluations:
for match in clip_eval.matches:
# Ignore generic unclassified targets
if match.gt_det and match.gt_class is None:
continue
y_true.append(
match.gt_class
if match.gt_class is not None
else self.background_class
)
top_class = match.pred_class
y_pred.append(
top_class
if top_class is not None
else self.background_class
)
display = metrics.ConfusionMatrixDisplay.from_predictions(
y_true,
y_pred,
labels=[*self.class_names, self.background_class],
)
yield "confusion_matrix", display.figure_
@classmethod
def from_config(
cls,
config: ConfusionMatrixConfig,
class_names: List[str],
):
return cls(
background_class=config.background_class,
class_names=class_names,
)
plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
PlotConfig = Annotated[
Union[
ExampleGalleryConfig,
@ -439,6 +494,7 @@ PlotConfig = Annotated[
ClassificationPRCurvesConfig,
DetectionROCCurveConfig,
ClassificationROCCurvesConfig,
ConfusionMatrixConfig,
],
Field(discriminator="name"),
]
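
A self-contained sketch of the confusion-matrix plot registered above, using hypothetical class names: matches whose ground truth or prediction carries no class label fall back to the configured background class ("noise" by default).

from sklearn import metrics

class_names = ["Pippip", "Myodau"]   # hypothetical class names
background = "noise"

y_true = ["Pippip", "noise", "Myodau", "Myodau"]
y_pred = ["Pippip", "Myodau", "noise", "Myodau"]

display = metrics.ConfusionMatrixDisplay.from_predictions(
    y_true,
    y_pred,
    labels=[*class_names, background],
)
display.figure_.savefig("confusion_matrix.png")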

View File

@ -1,18 +1,49 @@
from typing import List
from typing import Annotated, Callable, Literal, Sequence, Union
import pandas as pd
from pydantic import Field
from soundevent.geometry import compute_bounds
from batdetect2.typing.evaluate import ClipEvaluation
from batdetect2.core import BaseConfig, Registry
from batdetect2.typing import ClipEvaluation
EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]
def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
"evaluation_table"
)
class FullEvaluationTableConfig(BaseConfig):
name: Literal["full_evaluation"] = "full_evaluation"
class FullEvaluationTable:
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> pd.DataFrame:
return extract_matches_dataframe(clip_evaluations)
@classmethod
def from_config(cls, config: FullEvaluationTableConfig):
return cls()
tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
def extract_matches_dataframe(
clip_evaluations: Sequence[ClipEvaluation],
) -> pd.DataFrame:
data = []
for clip_evaluation in clip_evaluations:
for match in clip_evaluation.matches:
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
pred_start_time = pred_low_freq = pred_end_time = (
pred_high_freq
) = None
sound_event_annotation = match.sound_event_annotation
@ -24,9 +55,12 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.Data
)
if match.pred_geometry is not None:
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
compute_bounds(match.pred_geometry)
)
(
pred_start_time,
pred_low_freq,
pred_end_time,
pred_high_freq,
) = compute_bounds(match.pred_geometry)
data.append(
{
@ -61,3 +95,14 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.Data
df = pd.DataFrame(data)
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
return df
EvaluationTableConfig = Annotated[
Union[FullEvaluationTableConfig,], Field(discriminator="name")
]
def build_table_generator(
config: EvaluationTableConfig,
) -> EvaluationTableGenerator:
return tables_registry.build(config)
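
A short hedged sketch of the new table registry in use: the discriminated config selects a table generator, which turns a sequence of ClipEvaluation objects (produced by an Evaluator) into a pandas DataFrame with MultiIndex columns.

from batdetect2.evaluate.tables import (
    FullEvaluationTableConfig,
    build_table_generator,
)

table = build_table_generator(FullEvaluationTableConfig())
df = table(clip_evaluations)   # clip_evaluations: Sequence[ClipEvaluation]
df.to_csv("full_evaluation.csv")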

View File

@ -67,7 +67,7 @@ from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.typing.models import DetectionModel
from batdetect2.typing.postprocess import (
DetectionsTensor,
ClipDetectionsTensor,
PostprocessorProtocol,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
@ -121,7 +121,7 @@ class Model(torch.nn.Module):
self.postprocessor = postprocessor
self.targets = targets
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
spec = self.preprocessor(wav)
outputs = self.detector(spec)
return self.postprocessor(outputs)

View File

@ -8,13 +8,10 @@ from batdetect2.postprocess.decoding import (
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.nms import (
non_max_suppression,
)
from batdetect2.postprocess.nms import non_max_suppression
from batdetect2.postprocess.postprocessor import (
Postprocessor,
build_postprocessor,
get_raw_predictions,
)
__all__ = [
@ -25,5 +22,4 @@ __all__ = [
"to_raw_predictions",
"load_postprocess_config",
"non_max_suppression",
"get_raw_predictions",
]

View File

@ -6,7 +6,7 @@ import numpy as np
from soundevent import data
from batdetect2.typing.postprocess import (
DetectionsArray,
ClipDetectionsArray,
RawPrediction,
)
from batdetect2.typing.targets import TargetProtocol
@ -28,7 +28,7 @@ decoding.
def to_raw_predictions(
detections: DetectionsArray,
detections: ClipDetectionsArray,
targets: TargetProtocol,
) -> List[RawPrediction]:
predictions = []

View File

@ -15,32 +15,25 @@ precise time-frequency location of each detection. The final output aggregates
all extracted information into a structured `xarray.Dataset`.
"""
from typing import List, Optional, Tuple, Union
from typing import List, Optional
import torch
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.typing.postprocess import (
DetectionsTensor,
ModelOutput,
)
from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
"extract_prediction_tensor",
"extract_detection_peaks",
]
def extract_prediction_tensor(
output: ModelOutput,
def extract_detection_peaks(
detection_heatmap: torch.Tensor,
size_heatmap: torch.Tensor,
feature_heatmap: torch.Tensor,
classification_heatmap: torch.Tensor,
max_detections: int = 200,
threshold: Optional[float] = None,
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> List[DetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),
kernel_size=nms_kernel_size,
)
) -> List[ClipDetectionsTensor]:
height = detection_heatmap.shape[-2]
width = detection_heatmap.shape[-1]
@ -53,9 +46,9 @@ def extract_prediction_tensor(
freqs = freqs.flatten().to(detection_heatmap.device)
times = times.flatten().to(detection_heatmap.device)
output_size_preds = output.size_preds.detach()
output_features = output.features.detach()
output_class_probs = output.class_probs.detach()
output_size_preds = size_heatmap.detach()
output_features = feature_heatmap.detach()
output_class_probs = classification_heatmap.detach()
predictions = []
for idx, item in enumerate(detection_heatmap):
@ -65,23 +58,25 @@ def extract_prediction_tensor(
detection_scores = item.take(indices)
detection_freqs = freqs.take(indices)
detection_times = times.take(indices)
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
features = output_features[idx, :, detection_freqs, detection_times].T
class_scores = output_class_probs[
idx, :, detection_freqs, detection_times
].T
if threshold is not None:
mask = detection_scores >= threshold
detection_scores = detection_scores[mask]
sizes = sizes[mask]
detection_times = detection_times[mask]
detection_freqs = detection_freqs[mask]
features = features[mask]
class_scores = class_scores[mask]
sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
features = output_features[idx, :, detection_freqs, detection_times].T
class_scores = output_class_probs[
idx,
:,
detection_freqs,
detection_times,
].T
predictions.append(
DetectionsTensor(
ClipDetectionsTensor(
scores=detection_scores,
sizes=sizes,
features=features,

View File

@ -1,29 +1,20 @@
from typing import List, Optional
from typing import List, Optional, Tuple, Union
import torch
from loguru import logger
from soundevent import data
from batdetect2.postprocess.config import (
PostprocessConfig,
)
from batdetect2.postprocess.decoding import (
DEFAULT_CLASSIFICATION_THRESHOLD,
convert_raw_prediction_to_sound_event_prediction,
convert_raw_predictions_to_clip_prediction,
to_raw_predictions,
)
from batdetect2.postprocess.extraction import extract_prediction_tensor
from batdetect2.postprocess.extraction import extract_detection_peaks
from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
from batdetect2.postprocess.remapping import map_detection_to_clip
from batdetect2.typing import ModelOutput
from batdetect2.typing.postprocess import (
BatDetect2Prediction,
DetectionsTensor,
ClipDetectionsTensor,
PostprocessorProtocol,
RawPrediction,
)
from batdetect2.typing.preprocess import PreprocessorProtocol
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"build_postprocessor",
@ -60,24 +51,43 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
max_freq: float,
top_k_per_sec: int = 200,
detection_threshold: float = 0.01,
nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
):
"""Initialize the Postprocessor."""
super().__init__()
self.samplerate = samplerate
self.output_samplerate = samplerate
self.min_freq = min_freq
self.max_freq = max_freq
self.top_k_per_sec = top_k_per_sec
self.detection_threshold = detection_threshold
self.nms_kernel_size = nms_kernel_size
def forward(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[ClipDetectionsTensor]:
detection_heatmap = non_max_suppression(
output.detection_probs.detach(),
kernel_size=self.nms_kernel_size,
)
def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
duration = width / self.output_samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
detections = extract_detection_peaks(
detection_heatmap,
size_heatmap=output.size_preds,
feature_heatmap=output.features,
classification_heatmap=output.class_probs,
max_detections=max_detections,
threshold=self.detection_threshold,
)
if start_times is None:
start_times = [0 for _ in range(len(detections))]
return [
map_detection_to_clip(
detection,
@ -88,121 +98,3 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
)
for detection in detections
]
def get_detections(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[DetectionsTensor]:
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
max_detections = int(self.top_k_per_sec * duration)
detections = extract_prediction_tensor(
output,
max_detections=max_detections,
threshold=self.detection_threshold,
)
if start_times is None:
return detections
width = output.detection_probs.shape[-1]
duration = width / self.samplerate
return [
map_detection_to_clip(
detection,
start_time=start_time,
end_time=start_time + duration,
min_freq=self.min_freq,
max_freq=self.max_freq,
)
for detection, start_time in zip(detections, start_times)
]
def get_raw_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
start_times: Optional[List[float]] = None,
) -> List[List[RawPrediction]]:
"""Extract intermediate RawPrediction objects for a batch."""
detections = postprocessor.get_detections(output, start_times)
return [
to_raw_predictions(detection.numpy(), targets=targets)
for detection in detections
]
def get_sound_event_predictions(
output: ModelOutput,
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
clips: List[data.Clip],
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[List[BatDetect2Prediction]]:
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
[
BatDetect2Prediction(
raw=raw,
sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
raw,
recording=clip.recording,
targets=targets,
classification_threshold=classification_threshold,
),
)
for raw in predictions
]
for predictions, clip in zip(raw_predictions, clips)
]
def get_predictions(
output: ModelOutput,
clips: List[data.Clip],
targets: TargetProtocol,
postprocessor: PostprocessorProtocol,
classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
) -> List[data.ClipPrediction]:
"""Perform the full postprocessing pipeline for a batch.
Takes raw model output and corresponding clips, applies the entire
configured chain (NMS, remapping, extraction, geometry recovery, class
decoding), producing final `soundevent.data.ClipPrediction` objects.
Parameters
----------
output : ModelOutput
Raw output from the neural network model for a batch.
clips : List[data.Clip]
List of `soundevent.data.Clip` objects corresponding to the batch.
Returns
-------
List[data.ClipPrediction]
List containing one `ClipPrediction` object for each input clip,
populated with `SoundEventPrediction` objects.
"""
raw_predictions = get_raw_predictions(
output,
targets=targets,
postprocessor=postprocessor,
start_times=[clip.start_time for clip in clips],
)
return [
convert_raw_predictions_to_clip_prediction(
prediction,
clip,
targets=targets,
classification_threshold=classification_threshold,
)
for prediction, clip in zip(raw_predictions, clips)
]

View File

@ -20,7 +20,7 @@ import xarray as xr
from soundevent.arrays import Dimensions
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.typing.postprocess import DetectionsTensor
from batdetect2.typing.postprocess import ClipDetectionsTensor
__all__ = [
"features_to_xarray",
@ -31,15 +31,15 @@ __all__ = [
def map_detection_to_clip(
detections: DetectionsTensor,
detections: ClipDetectionsTensor,
start_time: float,
end_time: float,
min_freq: float,
max_freq: float,
) -> DetectionsTensor:
) -> ClipDetectionsTensor:
duration = end_time - start_time
bandwidth = max_freq - min_freq
return DetectionsTensor(
return ClipDetectionsTensor(
scores=detections.scores,
sizes=detections.sizes,
features=detections.features,

View File

@ -14,7 +14,6 @@ from batdetect2.train.augmentations import (
scale_volume,
warp_spectrogram,
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import (
PLTrainerConfig,
TrainingConfig,
@ -61,7 +60,6 @@ __all__ = [
"add_echo",
"build_augmentations",
"build_clip_labeler",
"build_clipper",
"build_loss",
"build_train_dataset",
"build_train_loader",
@ -74,7 +72,6 @@ __all__ = [
"mask_time",
"mix_audio",
"scale_volume",
"select_subclip",
"train",
"warp_spectrogram",
]

View File

@ -11,9 +11,9 @@ from pydantic import Field
from soundevent import data
from soundevent.geometry import scale_geometry, shift_geometry
from batdetect2.audio.clips import get_subclip_annotation
from batdetect2.core.arrays import adjust_width
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.train.clips import get_subclip_annotation
from batdetect2.typing import AudioLoader, Augmentation
__all__ = [

View File

@ -5,13 +5,13 @@ from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.evaluate import Evaluator
from batdetect2.postprocess import get_raw_predictions
from batdetect2.logging import get_image_logger
from batdetect2.postprocess import to_raw_predictions
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import get_image_plotter
from batdetect2.typing import (
ClipEvaluation,
EvaluatorProtocol,
ModelOutput,
RawPrediction,
TrainExample,
@ -19,7 +19,7 @@ from batdetect2.typing import (
class ValidationMetrics(Callback):
def __init__(self, evaluator: Evaluator):
def __init__(self, evaluator: EvaluatorProtocol):
super().__init__()
self.evaluator = evaluator
@ -34,12 +34,12 @@ class ValidationMetrics(Callback):
assert isinstance(dataset, ValidationDataset)
return dataset
def plot_examples(
def generate_plots(
self,
pl_module: LightningModule,
evaluated_clips: List[ClipEvaluation],
):
plotter = get_image_plotter(pl_module.logger) # type: ignore
plotter = get_image_logger(pl_module.logger) # type: ignore
if plotter is None:
return
@ -66,7 +66,7 @@ class ValidationMetrics(Callback):
)
self.log_metrics(pl_module, clip_evaluations)
self.plot_examples(pl_module, clip_evaluations)
self.generate_plots(pl_module, clip_evaluations)
return super().on_validation_epoch_end(trainer, pl_module)
@ -88,8 +88,7 @@ class ValidationMetrics(Callback):
batch_idx: int,
dataloader_idx: int = 0,
) -> None:
postprocessor = pl_module.model.postprocessor
targets = pl_module.model.targets
model = pl_module.model
dataset = self.get_dataset(trainer)
clip_annotations = [
@ -97,15 +96,14 @@ class ValidationMetrics(Callback):
for example_idx in batch.idx
]
predictions = get_raw_predictions(
clip_detections = model.postprocessor(
outputs,
start_times=[
clip_annotation.clip.start_time
for clip_annotation in clip_annotations
],
targets=targets,
postprocessor=postprocessor,
start_times=[ca.clip.start_time for ca in clip_annotations],
)
predictions = [
to_raw_predictions(clip_dets.numpy(), targets=model.targets)
for clip_dets in clip_detections
]
self._clip_annotations.extend(clip_annotations)
self._predictions.extend(predictions)

View File

@ -5,17 +5,9 @@ from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
)
from batdetect2.train.clips import (
ClipConfig,
PaddedClipConfig,
RandomClipConfig,
)
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import LabelConfig
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
from batdetect2.train.losses import LossConfig
__all__ = [
@ -45,30 +37,6 @@ class PLTrainerConfig(BaseConfig):
val_check_interval: Optional[Union[int, float]] = None
class ValLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
class TrainLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
shuffle: bool = False
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
)
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
class OptimizerConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
@ -79,9 +47,8 @@ class TrainingConfig(BaseConfig):
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
loss: LossConfig = Field(default_factory=LossConfig)
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)

View File

@ -2,18 +2,21 @@ from typing import List, Optional, Sequence
import torch
from loguru import logger
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import build_audio_loader
from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
from batdetect2.audio.clips import PaddedClipConfig
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
RandomAudioSource,
build_augmentations,
)
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import build_clip_labeler
from batdetect2.typing import (
AudioLoader,
@ -144,6 +147,22 @@ class ValidationDataset(Dataset):
)
class TrainLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
shuffle: bool = False
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
)
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_train_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,
@ -178,6 +197,14 @@ def build_train_loader(
)
class ValLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
def build_val_loader(
clip_annotations: Sequence[data.ClipAnnotation],
audio_loader: Optional[AudioLoader] = None,

View File

@ -8,20 +8,21 @@ from loguru import logger
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.logging import build_logger
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
from batdetect2.train.logging import build_logger
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
ClipLabeller,
EvaluatorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
@ -122,7 +123,7 @@ def train(
def build_trainer_callbacks(
targets: "TargetProtocol",
evaluator: Optional[Evaluator] = None,
evaluator: Optional["EvaluatorProtocol"] = None,
checkpoint_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
@ -142,7 +143,6 @@ def build_trainer_callbacks(
ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=1,
filename="best-{epoch:02d}-{val_loss:.0f}",
monitor="total_loss/val",
),
ValidationMetrics(evaluator),
@ -152,7 +152,7 @@ def build_trainer_callbacks(
def build_trainer(
conf: "BatDetect2Config",
targets: "TargetProtocol",
evaluator: Optional[Evaluator] = None,
evaluator: Optional["EvaluatorProtocol"] = None,
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,

View File

@ -1,7 +1,9 @@
from batdetect2.typing.evaluate import (
ClipEvaluation,
EvaluatorProtocol,
MatchEvaluation,
MetricsProtocol,
PlotterProtocol,
)
from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
from batdetect2.typing.postprocess import (
@ -49,6 +51,7 @@ __all__ = [
"MatchEvaluation",
"MetricsProtocol",
"ModelOutput",
"PlotterProtocol",
"Position",
"PostprocessorProtocol",
"PreprocessorProtocol",
@ -60,4 +63,5 @@ __all__ = [
"SoundEventFilter",
"TargetProtocol",
"TrainExample",
"EvaluatorProtocol",
]

View File

@ -14,7 +14,11 @@ from typing import (
from matplotlib.figure import Figure
from soundevent import data
from batdetect2.typing.postprocess import RawPrediction
from batdetect2.typing.targets import TargetProtocol
__all__ = [
"EvaluatorProtocol",
"MetricsProtocol",
"MatchEvaluation",
]
@ -107,3 +111,21 @@ class PlotterProtocol(Protocol):
def __call__(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...
class EvaluatorProtocol(Protocol):
targets: TargetProtocol
def evaluate(
self,
clip_annotations: Sequence[data.ClipAnnotation],
predictions: Sequence[Sequence[RawPrediction]],
) -> List[ClipEvaluation]: ...
def compute_metrics(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Dict[str, float]: ...
def generate_plots(
self, clip_evaluations: Sequence[ClipEvaluation]
) -> Iterable[Tuple[str, Figure]]: ...
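
Because EvaluatorProtocol is a structural Protocol, anything exposing a targets attribute plus these three methods conforms. A do-nothing sketch for illustration only; the real Evaluator in batdetect2.evaluate.evaluator performs the actual matching, metrics, and plotting.

class NullEvaluator:
    """Structurally satisfies EvaluatorProtocol without doing any work."""

    def __init__(self, targets):
        self.targets = targets

    def evaluate(self, clip_annotations, predictions):
        return []

    def compute_metrics(self, clip_evaluations):
        return {}

    def generate_plots(self, clip_evaluations):
        return []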

View File

@ -12,7 +12,7 @@ system that deal with model predictions.
"""
from dataclasses import dataclass
from typing import List, NamedTuple, Optional, Protocol
from typing import List, NamedTuple, Optional, Protocol, Sequence
import numpy as np
import torch
@ -47,15 +47,13 @@ class GeometryDecoder(Protocol):
class RawPrediction(NamedTuple):
"""Intermediate representation of a single detected sound event."""
geometry: data.Geometry
detection_score: float
class_scores: np.ndarray
features: np.ndarray
class DetectionsArray(NamedTuple):
class ClipDetectionsArray(NamedTuple):
scores: np.ndarray
sizes: np.ndarray
class_scores: np.ndarray
@ -64,7 +62,7 @@ class DetectionsArray(NamedTuple):
features: np.ndarray
class DetectionsTensor(NamedTuple):
class ClipDetectionsTensor(NamedTuple):
scores: torch.Tensor
sizes: torch.Tensor
class_scores: torch.Tensor
@ -72,8 +70,8 @@ class DetectionsTensor(NamedTuple):
frequencies: torch.Tensor
features: torch.Tensor
def numpy(self) -> DetectionsArray:
return DetectionsArray(
def numpy(self) -> ClipDetectionsArray:
return ClipDetectionsArray(
scores=self.scores.detach().cpu().numpy(),
sizes=self.sizes.detach().cpu().numpy(),
class_scores=self.class_scores.detach().cpu().numpy(),
@ -92,10 +90,8 @@ class BatDetect2Prediction:
class PostprocessorProtocol(Protocol):
"""Protocol defining the interface for the full postprocessing pipeline."""
def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
def get_detections(
def __call__(
self,
output: ModelOutput,
start_times: Optional[List[float]] = None,
) -> List[DetectionsTensor]: ...
start_times: Optional[Sequence[float]] = None,
) -> List[ClipDetectionsTensor]: ...
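
A hedged sketch of the narrowed interface above: the separate get_detections method is gone, and a single call with optional per-clip start_times yields one ClipDetectionsTensor per clip, which the Lightning module and validation callback in this diff then convert to raw predictions. The model and outputs below refer to the Model and ModelOutput objects from earlier hunks.

from batdetect2.postprocess import to_raw_predictions

clip_detections = model.postprocessor(outputs, start_times=[0.0, 3.0])
raw_predictions = [
    to_raw_predictions(clip_dets.numpy(), targets=model.targets)
    for clip_dets in clip_detections
]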