Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 00:59:34 +01:00)

Evaluate using Lightning too to handle device changes

Commit e65df81db2 (parent 6c25787123)
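The commit replaces a hand-rolled evaluation loop with a Lightning `Trainer`, so batches are moved to the right device (CPU/GPU) by Lightning instead of manually. A minimal sketch of the resulting entry point, pieced together from the diff below; the checkpoint path and the `test_annotations` list are placeholders:

    from batdetect2.api.base import BatDetect2API

    # Loads the model plus the config stored in the checkpoint;
    # an explicit BatDetect2Config can override the stored one.
    api = BatDetect2API.from_checkpoint("model.ckpt")

    # Internally builds a test loader, wraps the model in an
    # EvaluationModule, and runs Trainer.test(), logging metrics,
    # plots, and a full evaluation table under output_dir.
    api.evaluate(
        test_annotations,
        num_workers=4,
        output_dir="outputs/evaluation",
    )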
@@ -5,7 +5,7 @@ from soundevent import data
 from batdetect2.audio import build_audio_loader
 from batdetect2.config import BatDetect2Config
-from batdetect2.evaluate import Evaluator, build_evaluator
+from batdetect2.evaluate import build_evaluator, evaluate
 from batdetect2.models import Model, build_model
 from batdetect2.postprocess import build_postprocessor
 from batdetect2.preprocess import build_preprocessor
@@ -14,6 +14,7 @@ from batdetect2.train import train
 from batdetect2.train.lightning import load_model_from_checkpoint
 from batdetect2.typing import (
     AudioLoader,
+    EvaluatorProtocol,
     PostprocessorProtocol,
     PreprocessorProtocol,
     TargetProtocol,
@@ -28,7 +29,7 @@ class BatDetect2API:
         audio_loader: AudioLoader,
         preprocessor: PreprocessorProtocol,
         postprocessor: PostprocessorProtocol,
-        evaluator: Evaluator,
+        evaluator: EvaluatorProtocol,
         model: Model,
     ):
         self.config = config
@@ -70,6 +71,27 @@ class BatDetect2API:
         )
         return self

+    def evaluate(
+        self,
+        test_annotations: Sequence[data.ClipAnnotation],
+        num_workers: Optional[int] = None,
+        output_dir: data.PathLike = ".",
+        experiment_name: Optional[str] = None,
+        run_name: Optional[str] = None,
+    ):
+        return evaluate(
+            self.model,
+            test_annotations,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            config=self.config,
+            num_workers=num_workers,
+            output_dir=output_dir,
+            experiment_name=experiment_name,
+            run_name=run_name,
+        )
+
     @classmethod
     def from_config(cls, config: BatDetect2Config):
         targets = build_targets(config=config.targets)
@@ -118,8 +140,14 @@ class BatDetect2API:
         )

     @classmethod
-    def from_checkpoint(cls, path: data.PathLike):
-        model, config = load_model_from_checkpoint(path)
+    def from_checkpoint(
+        cls,
+        path: data.PathLike,
+        config: Optional[BatDetect2Config] = None,
+    ):
+        model, stored_config = load_model_from_checkpoint(path)
+
+        config = config or stored_config
+
         targets = build_targets(config=config.targets)

@@ -10,11 +10,17 @@ from batdetect2.cli.base import cli
 __all__ = ["evaluate_command"]

+
+DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
+
+
 @cli.command(name="evaluate")
 @click.argument("model-path", type=click.Path(exists=True))
 @click.argument("test_dataset", type=click.Path(exists=True))
-@click.option("--output-dir", type=click.Path())
-@click.option("--workers", type=int)
+@click.option("--config", "config_path", type=click.Path())
+@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
+@click.option("--experiment-name", type=str)
+@click.option("--run-name", type=str)
+@click.option("--workers", "num_workers", type=int)
 @click.option(
     "-v",
     "--verbose",
@@ -24,13 +30,16 @@ __all__ = ["evaluate_command"]
 def evaluate_command(
     model_path: Path,
     test_dataset: Path,
-    output_dir: Optional[Path] = None,
-    workers: Optional[int] = None,
+    config_path: Optional[Path],
+    output_dir: Path = DEFAULT_OUTPUT_DIR,
+    num_workers: Optional[int] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
     verbose: int = 0,
 ):
+    from batdetect2.api.base import BatDetect2API
+    from batdetect2.config import load_full_config
     from batdetect2.data import load_dataset_from_config
-    from batdetect2.evaluate.evaluate import evaluate
-    from batdetect2.train.lightning import load_model_from_checkpoint

     logger.remove()
     if verbose == 0:
@@ -49,16 +58,16 @@ def evaluate_command(
         num_annotations=len(test_annotations),
     )

-    model, train_config = load_model_from_checkpoint(model_path)
+    config = None
+    if config_path is not None:
+        config = load_full_config(config_path)

-    df, results = evaluate(
-        model,
+    api = BatDetect2API.from_checkpoint(model_path, config=config)
+
+    api.evaluate(
         test_annotations,
-        config=train_config,
-        num_workers=workers,
+        num_workers=num_workers,
+        output_dir=output_dir,
+        experiment_name=experiment_name,
+        run_name=run_name,
     )
-
-    print(results)
-
-    if output_dir:
-        df.to_csv(output_dir / "results.csv")

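The command now defers to `BatDetect2API` and no longer prints results or writes `results.csv` by hand; metrics, plots, and tables go through the configured logger under `--output-dir`. A plausible invocation, assuming the console entry point is called `batdetect2` and the file paths are placeholders:

    batdetect2 evaluate model.ckpt test_dataset.yaml \
        --config config.yaml \
        --output-dir outputs/evaluation \
        --workers 4 \
        --experiment-name baseline \
        --run-name run-01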
@@ -2,6 +2,7 @@ import sys
 from typing import Generic, Protocol, Type, TypeVar

 from pydantic import BaseModel
+from typing_extensions import assert_type

 if sys.version_info >= (3, 10):
     from typing import ParamSpec
@@ -44,7 +45,6 @@ class Registry(Generic[T_Type, P_Type]):
         config_cls: Type[T_Config],
         logic_cls: LogicProtocol[T_Config, T_Type, P_Type],
     ) -> None:
-        """A decorator factory to register a new item."""
         fields = config_cls.model_fields

         if "name" not in fields:

@@ -1,9 +1,11 @@
 from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
+from batdetect2.evaluate.evaluate import evaluate
 from batdetect2.evaluate.evaluator import Evaluator, build_evaluator

 __all__ = [
     "EvaluationConfig",
     "load_evaluation_config",
+    "evaluate",
     "Evaluator",
     "build_evaluator",
 ]

@@ -11,6 +11,7 @@ from batdetect2.evaluate.metrics import (
     MetricConfig,
 )
 from batdetect2.evaluate.plots import PlotConfig
+from batdetect2.logging import CSVLoggerConfig, LoggerConfig

 __all__ = [
     "EvaluationConfig",
@@ -28,6 +29,7 @@ class EvaluationConfig(BaseConfig):
         ]
     )
     plots: List[PlotConfig] = Field(default_factory=list)
+    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)


 def load_evaluation_config(

src/batdetect2/evaluate/dataset.py (new file, 144 lines)
@@ -0,0 +1,144 @@
+from typing import List, NamedTuple, Optional, Sequence
+
+import torch
+from loguru import logger
+from pydantic import Field
+from soundevent import data
+from torch.utils.data import DataLoader, Dataset
+
+from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
+from batdetect2.audio.clips import PaddedClipConfig
+from batdetect2.core import BaseConfig
+from batdetect2.core.arrays import adjust_width
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.typing import (
+    AudioLoader,
+    ClipperProtocol,
+    PreprocessorProtocol,
+)
+
+__all__ = [
+    "TestDataset",
+    "build_test_dataset",
+    "build_test_loader",
+]
+
+
+class TestExample(NamedTuple):
+    spec: torch.Tensor
+    idx: torch.Tensor
+    start_time: torch.Tensor
+    end_time: torch.Tensor
+
+
+class TestDataset(Dataset[TestExample]):
+    clip_annotations: List[data.ClipAnnotation]
+
+    def __init__(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        clipper: Optional[ClipperProtocol] = None,
+        audio_dir: Optional[data.PathLike] = None,
+    ):
+        self.clip_annotations = list(clip_annotations)
+        self.clipper = clipper
+        self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
+        self.audio_dir = audio_dir
+
+    def __len__(self):
+        return len(self.clip_annotations)
+
+    def __getitem__(self, idx: int) -> TestExample:
+        clip_annotation = self.clip_annotations[idx]
+
+        if self.clipper is not None:
+            clip_annotation = self.clipper(clip_annotation)
+
+        clip = clip_annotation.clip
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        wav_tensor = torch.tensor(wav).unsqueeze(0)
+        spectrogram = self.preprocessor(wav_tensor)
+        return TestExample(
+            spec=spectrogram,
+            idx=torch.tensor(idx),
+            start_time=torch.tensor(clip.start_time),
+            end_time=torch.tensor(clip.end_time),
+        )
+
+
+class TestLoaderConfig(BaseConfig):
+    num_workers: int = 0
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: PaddedClipConfig()
+    )
+
+
+def build_test_loader(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[TestLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+) -> DataLoader[TestExample]:
+    logger.info("Building test data loader...")
+    config = config or TestLoaderConfig()
+    logger.opt(lazy=True).debug(
+        "Test data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )
+
+    test_dataset = build_test_dataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        config=config,
+    )
+
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        test_dataset,
+        batch_size=1,
+        shuffle=False,
+        num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def build_test_dataset(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[TestLoaderConfig] = None,
+) -> TestDataset:
+    logger.info("Building training dataset...")
+    config = config or TestLoaderConfig()
+
+    clipper = build_clipper(config=config.clipping_strategy)
+
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    return TestDataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        clipper=clipper,
+        preprocessor=preprocessor,
+    )
+
+
+def _collate_fn(batch: List[TestExample]) -> TestExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TestExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
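Because test clips can have different durations, `_collate_fn` pads every spectrogram in a batch to the widest one with `adjust_width` before stacking. A short usage sketch of the new loader, with `test_annotations` standing in for a real sequence of `ClipAnnotation` objects:

    from batdetect2.evaluate.dataset import TestLoaderConfig, build_test_loader

    # batch_size is fixed at 1 and shuffling is off, so example
    # indices line up with the dataset's clip_annotations list.
    loader = build_test_loader(
        test_annotations,
        config=TestLoaderConfig(num_workers=2),
    )

    for batch in loader:
        # batch is a TestExample NamedTuple
        print(batch.spec.shape, int(batch.idx), float(batch.start_time))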
@@ -1,35 +1,44 @@
-from typing import List, Optional, Tuple
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional, Sequence

-import pandas as pd
+from lightning import Trainer
 from soundevent import data

 from batdetect2.audio import build_audio_loader
-from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.evaluate.dataframe import extract_matches_dataframe
+from batdetect2.evaluate.dataset import build_test_loader
 from batdetect2.evaluate.evaluator import build_evaluator
-from batdetect2.evaluate.metrics import ClassificationAP, DetectionAP
+from batdetect2.evaluate.lightning import EvaluationModule
+from batdetect2.logging import build_logger
 from batdetect2.models import Model
-from batdetect2.plotting.clips import AudioLoader, PreprocessorProtocol
-from batdetect2.postprocess import get_raw_predictions
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
-from batdetect2.train.dataset import ValidationDataset
-from batdetect2.train.labels import build_clip_labeler
-from batdetect2.train.train import build_val_loader
-from batdetect2.typing import ClipLabeller, TargetProtocol
+if TYPE_CHECKING:
+    from batdetect2.config import BatDetect2Config
+    from batdetect2.typing import (
+        AudioLoader,
+        PreprocessorProtocol,
+        TargetProtocol,
+    )
+
+DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"


 def evaluate(
     model: Model,
-    test_annotations: List[data.ClipAnnotation],
-    targets: Optional[TargetProtocol] = None,
-    audio_loader: Optional[AudioLoader] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    labeller: Optional[ClipLabeller] = None,
-    config: Optional[EvaluationConfig] = None,
+    test_annotations: Sequence[data.ClipAnnotation],
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    config: Optional["BatDetect2Config"] = None,
     num_workers: Optional[int] = None,
-) -> Tuple[pd.DataFrame, dict]:
-    config = config or EvaluationConfig()
+    output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+):
+    from batdetect2.config import BatDetect2Config
+
+    config = config or BatDetect2Config()
+
     audio_loader = audio_loader or build_audio_loader()

@@ -39,60 +48,21 @@ def evaluate(

     targets = targets or build_targets()

-    labeller = labeller or build_clip_labeler(
-        targets,
-        min_freq=preprocessor.min_freq,
-        max_freq=preprocessor.max_freq,
-    )
-
-    loader = build_val_loader(
+    loader = build_test_loader(
         test_annotations,
         audio_loader=audio_loader,
-        labeller=labeller,
         preprocessor=preprocessor,
         num_workers=num_workers,
     )

-    dataset: ValidationDataset = loader.dataset  # type: ignore
+    evaluator = build_evaluator(config=config.evaluation, targets=targets)

-    clip_annotations = []
-    predictions = []
-
-    evaluator = build_evaluator(config=config, targets=targets)
-
-    for batch in loader:
-        outputs = model.detector(batch.spec)
-
-        clip_annotations = [
-            dataset.clip_annotations[int(example_idx)]
-            for example_idx in batch.idx
-        ]
-
-        predictions = get_raw_predictions(
-            outputs,
-            start_times=[
-                clip_annotation.clip.start_time
-                for clip_annotation in clip_annotations
-            ],
-            targets=targets,
-            postprocessor=model.postprocessor,
-        )
-
-        clip_annotations.extend(clip_annotations)
-        predictions.extend(predictions)
-
-    matches = evaluator.evaluate(clip_annotations, predictions)
-    df = extract_matches_dataframe(matches)
-
-    metrics = [
-        DetectionAP(),
-        ClassificationAP(class_names=targets.class_names),
-    ]
-
-    results = {
-        name: value
-        for metric in metrics
-        for name, value in metric(matches).items()
-    }
-
-    return df, results
+    logger = build_logger(
+        config.evaluation.logger,
+        log_dir=Path(output_dir),
+        experiment_name=experiment_name,
+        run_name=run_name,
+    )
+    module = EvaluationModule(model, evaluator)
+    trainer = Trainer(logger=logger, enable_checkpointing=False)
+    return trainer.test(module, loader)

@@ -11,6 +11,7 @@ from batdetect2.evaluate.plots import build_plotter
 from batdetect2.targets import build_targets
 from batdetect2.typing.evaluate import (
     ClipEvaluation,
+    EvaluatorProtocol,
     MatcherProtocol,
     MetricsProtocol,
     PlotterProtocol,
@@ -135,7 +136,7 @@ def build_evaluator(
     matcher: Optional[MatcherProtocol] = None,
     plots: Optional[List[PlotterProtocol]] = None,
     metrics: Optional[List[MetricsProtocol]] = None,
-) -> Evaluator:
+) -> EvaluatorProtocol:
     config = config or EvaluationConfig()
     targets = targets or build_targets()
     matcher = matcher or build_matcher(config.match_strategy)

src/batdetect2/evaluate/lightning.py (new file, 86 lines)
@@ -0,0 +1,86 @@
+from typing import Sequence
+
+from lightning import LightningModule
+from torch.utils.data import DataLoader
+
+from batdetect2.evaluate.dataset import TestDataset, TestExample
+from batdetect2.evaluate.tables import FullEvaluationTable
+from batdetect2.logging import get_image_logger, get_table_logger
+from batdetect2.models import Model
+from batdetect2.postprocess import to_raw_predictions
+from batdetect2.typing import ClipEvaluation, EvaluatorProtocol
+
+
+class EvaluationModule(LightningModule):
+    def __init__(
+        self,
+        model: Model,
+        evaluator: EvaluatorProtocol,
+    ):
+        super().__init__()
+
+        self.model = model
+        self.evaluator = evaluator
+
+        self.clip_evaluations = []
+
+    def test_step(self, batch: TestExample):
+        dataset = self.get_dataset()
+        clip_annotations = [
+            dataset.clip_annotations[int(example_idx)]
+            for example_idx in batch.idx
+        ]
+
+        outputs = self.model.detector(batch.spec)
+        clip_detections = self.model.postprocessor(
+            outputs,
+            start_times=[ca.clip.start_time for ca in clip_annotations],
+        )
+        predictions = [
+            to_raw_predictions(
+                clip_dets.numpy(),
+                targets=self.evaluator.targets,
+            )
+            for clip_dets in clip_detections
+        ]
+
+        self.clip_evaluations.extend(
+            self.evaluator.evaluate(clip_annotations, predictions)
+        )
+
+    def on_test_epoch_start(self):
+        self.clip_evaluations = []
+
+    def on_test_epoch_end(self):
+        self.log_metrics(self.clip_evaluations)
+        self.plot_examples(self.clip_evaluations)
+        self.log_table(self.clip_evaluations)
+
+    def log_table(self, evaluated_clips: Sequence[ClipEvaluation]):
+        table_logger = get_table_logger(self.logger)  # type: ignore
+
+        if table_logger is None:
+            return
+
+        df = FullEvaluationTable()(evaluated_clips)
+        table_logger("full_evaluation", df, 0)
+
+    def plot_examples(self, evaluated_clips: Sequence[ClipEvaluation]):
+        plotter = get_image_logger(self.logger)  # type: ignore
+
+        if plotter is None:
+            return
+
+        for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
+            plotter(figure_name, fig, self.global_step)
+
+    def log_metrics(self, evaluated_clips: Sequence[ClipEvaluation]):
+        metrics = self.evaluator.compute_metrics(evaluated_clips)
+        self.log_dict(metrics)
+
+    def get_dataset(self) -> TestDataset:
+        dataloaders = self.trainer.test_dataloaders
+        assert isinstance(dataloaders, DataLoader)
+        dataset = dataloaders.dataset
+        assert isinstance(dataset, TestDataset)
+        return dataset
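`EvaluationModule` is where the device handling from the commit message lands: Lightning moves each `TestExample` batch to the trainer's accelerator before `test_step` runs, the module accumulates `ClipEvaluation` objects across the epoch, and the epoch-end hooks push metrics, plots, and the evaluation table to the attached logger. The new `evaluate()` function drives it exactly like this sketch, where `model`, `evaluator`, `logger`, and `loader` are assumed to be built as in `evaluate.py`:

    from lightning import Trainer

    from batdetect2.evaluate.lightning import EvaluationModule

    module = EvaluationModule(model, evaluator)

    # Lightning selects an accelerator and handles host/device
    # transfers; checkpointing is disabled for pure evaluation.
    trainer = Trainer(logger=logger, enable_checkpointing=False)
    results = trainer.test(module, loader)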
@@ -15,10 +15,8 @@ import numpy as np
 from pydantic import Field
 from sklearn import metrics, preprocessing

-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
-from batdetect2.typing import MetricsProtocol
-from batdetect2.typing.evaluate import ClipEvaluation
+from batdetect2.core import BaseConfig, Registry
+from batdetect2.typing import ClipEvaluation, MetricsProtocol

 __all__ = ["DetectionAP", "ClassificationAP"]

@@ -375,14 +373,96 @@ metrics_registry.register(
 )


-class ClipAPConfig(BaseConfig):
-    name: Literal["clip_ap"] = "clip_ap"
+class ClipDetectionAPConfig(BaseConfig):
+    name: Literal["clip_detection_ap"] = "clip_detection_ap"
+    ap_implementation: APImplementation = "pascal_voc"
+
+
+class ClipDetectionAP(MetricsProtocol):
+    def __init__(
+        self,
+        implementation: APImplementation,
+    ):
+        self.implementation = implementation
+        self.metric = _ap_impl_mapping[self.implementation]
+
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]:
+        y_true = []
+        y_score = []
+
+        for clip_eval in clip_evaluations:
+            clip_det = []
+            clip_scores = []
+
+            for match in clip_eval.matches:
+                clip_det.append(match.gt_det)
+                clip_scores.append(match.pred_score)
+
+            y_true.append(any(clip_det))
+            y_score.append(max(clip_scores or [0]))
+
+        return {"clip_detection_ap": self.metric(y_true, y_score)}
+
+    @classmethod
+    def from_config(
+        cls,
+        config: ClipDetectionAPConfig,
+        class_names: List[str],
+    ):
+        return cls(implementation=config.ap_implementation)
+
+
+metrics_registry.register(ClipDetectionAPConfig, ClipDetectionAP)
+
+
+class ClipDetectionROCAUCConfig(BaseConfig):
+    name: Literal["clip_detection_roc_auc"] = "clip_detection_roc_auc"
+
+
+class ClipDetectionROCAUC(MetricsProtocol):
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]:
+        y_true = []
+        y_score = []
+
+        for clip_eval in clip_evaluations:
+            clip_det = []
+            clip_scores = []
+
+            for match in clip_eval.matches:
+                clip_det.append(match.gt_det)
+                clip_scores.append(match.pred_score)
+
+            y_true.append(any(clip_det))
+            y_score.append(max(clip_scores or [0]))
+
+        return {
+            "clip_detection_ap": float(metrics.roc_auc_score(y_true, y_score))
+        }
+
+    @classmethod
+    def from_config(
+        cls,
+        config: ClipDetectionROCAUCConfig,
+        class_names: List[str],
+    ):
+        return cls()
+
+
+metrics_registry.register(ClipDetectionROCAUCConfig, ClipDetectionROCAUC)
+
+
+class ClipMulticlassAPConfig(BaseConfig):
+    name: Literal["clip_multiclass_ap"] = "clip_multiclass_ap"
     ap_implementation: APImplementation = "pascal_voc"
     include: Optional[List[str]] = None
     exclude: Optional[List[str]] = None


-class ClipAP(MetricsProtocol):
+class ClipMulticlassAP(MetricsProtocol):
     def __init__(
         self,
         class_names: List[str],
@@ -454,15 +534,17 @@ class ClipAP(MetricsProtocol):
             [value for value in class_scores.values() if value != 0]
         )
         return {
-            "clip_mAP": float(mean_ap),
+            "clip_multiclass_mAP": float(mean_ap),
             **{
-                f"clip_AP/{class_name}": class_scores[class_name]
+                f"clip_multiclass_AP/{class_name}": class_scores[class_name]
                 for class_name in self.selected
             },
         }

     @classmethod
-    def from_config(cls, config: ClipAPConfig, class_names: List[str]):
+    def from_config(
+        cls, config: ClipMulticlassAPConfig, class_names: List[str]
+    ):
         return cls(
             implementation=config.ap_implementation,
             include=config.include,
@@ -471,16 +553,16 @@ class ClipAP(MetricsProtocol):
         )


-metrics_registry.register(ClipAPConfig, ClipAP)
+metrics_registry.register(ClipMulticlassAPConfig, ClipMulticlassAP)


-class ClipROCAUCConfig(BaseConfig):
-    name: Literal["clip_roc_auc"] = "clip_roc_auc"
+class ClipMulticlassROCAUCConfig(BaseConfig):
+    name: Literal["clip_multiclass_roc_auc"] = "clip_multiclass_roc_auc"
     include: Optional[List[str]] = None
     exclude: Optional[List[str]] = None


-class ClipROCAUC(MetricsProtocol):
+class ClipMulticlassROCAUC(MetricsProtocol):
     def __init__(
         self,
         class_names: List[str],
@@ -548,9 +630,11 @@ class ClipROCAUC(MetricsProtocol):
             [value for value in class_scores.values() if value != 0]
         )
         return {
-            "clip_macro_ROC_AUC": float(mean_roc_auc),
+            "clip_multiclass_macro_ROC_AUC": float(mean_roc_auc),
             **{
-                f"clip_ROC_AUC/{class_name}": class_scores[class_name]
+                f"clip_multiclass_ROC_AUC/{class_name}": class_scores[
+                    class_name
+                ]
                 for class_name in self.selected
             },
         }
@@ -558,7 +642,7 @@ class ClipROCAUC(MetricsProtocol):
     @classmethod
     def from_config(
         cls,
-        config: ClipROCAUCConfig,
+        config: ClipMulticlassROCAUCConfig,
         class_names: List[str],
     ):
         return cls(
@@ -568,7 +652,7 @@ class ClipROCAUC(MetricsProtocol):
         )


-metrics_registry.register(ClipROCAUCConfig, ClipROCAUC)
+metrics_registry.register(ClipMulticlassROCAUCConfig, ClipMulticlassROCAUC)

 MetricConfig = Annotated[
     Union[
@@ -578,8 +662,10 @@ MetricConfig = Annotated[
         ClassificationROCAUCConfig,
         TopClassAPConfig,
         ClassificationBalancedAccuracyConfig,
-        ClipAPConfig,
-        ClipROCAUCConfig,
+        ClipDetectionAPConfig,
+        ClipDetectionROCAUCConfig,
+        ClipMulticlassAPConfig,
+        ClipMulticlassROCAUCConfig,
     ],
     Field(discriminator="name"),
 ]

@@ -10,19 +10,18 @@ from pydantic import Field
 from sklearn import metrics
 from sklearn.preprocessing import label_binarize

-from batdetect2.audio import AudioConfig
-from batdetect2.core.configs import BaseConfig
-from batdetect2.core.registries import Registry
-from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
+from batdetect2.audio import AudioConfig, build_audio_loader
+from batdetect2.core import BaseConfig, Registry
 from batdetect2.plotting.gallery import plot_match_gallery
 from batdetect2.plotting.matches import plot_matches
 from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
-from batdetect2.typing.evaluate import (
+from batdetect2.typing import (
+    AudioLoader,
     ClipEvaluation,
     MatchEvaluation,
     PlotterProtocol,
+    PreprocessorProtocol,
 )
-from batdetect2.typing.preprocess import AudioLoader

 __all__ = [
     "build_plotter",
@@ -431,6 +430,62 @@ class ClassificationROCCurves(PlotterProtocol):
 plots_registry.register(ClassificationROCCurvesConfig, ClassificationROCCurves)


+class ConfusionMatrixConfig(BaseConfig):
+    name: Literal["confusion_matrix"] = "confusion_matrix"
+    background_class: str = "noise"
+
+
+class ConfusionMatrix(PlotterProtocol):
+    def __init__(self, background_class: str, class_names: List[str]):
+        self.background_class = background_class
+        self.class_names = class_names
+
+    def __call__(self, clip_evaluations: Sequence[ClipEvaluation]):
+        y_true = []
+        y_pred = []
+
+        for clip_eval in clip_evaluations:
+            for match in clip_eval.matches:
+                # Ignore generic unclassified targets
+                if match.gt_det and match.gt_class is None:
+                    continue
+
+                y_true.append(
+                    match.gt_class
+                    if match.gt_class is not None
+                    else self.background_class
+                )
+
+                top_class = match.pred_class
+                y_pred.append(
+                    top_class
+                    if top_class is not None
+                    else self.background_class
+                )
+
+        display = metrics.ConfusionMatrixDisplay.from_predictions(
+            y_true,
+            y_pred,
+            labels=[*self.class_names, self.background_class],
+        )
+
+        yield "confusion_matrix", display.figure_
+
+    @classmethod
+    def from_config(
+        cls,
+        config: ConfusionMatrixConfig,
+        class_names: List[str],
+    ):
+        return cls(
+            background_class=config.background_class,
+            class_names=class_names,
+        )
+
+
+plots_registry.register(ConfusionMatrixConfig, ConfusionMatrix)
+
+
 PlotConfig = Annotated[
     Union[
         ExampleGalleryConfig,
@@ -439,6 +494,7 @@ PlotConfig = Annotated[
         ClassificationPRCurvesConfig,
         DetectionROCCurveConfig,
         ClassificationROCCurvesConfig,
+        ConfusionMatrixConfig,
     ],
     Field(discriminator="name"),
 ]

@@ -1,18 +1,49 @@
-from typing import List
+from typing import Annotated, Callable, Literal, Sequence, Union

 import pandas as pd
+from pydantic import Field
 from soundevent.geometry import compute_bounds

-from batdetect2.typing.evaluate import ClipEvaluation
+from batdetect2.core import BaseConfig, Registry
+from batdetect2.typing import ClipEvaluation
+
+EvaluationTableGenerator = Callable[[Sequence[ClipEvaluation]], pd.DataFrame]


-def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
+tables_registry: Registry[EvaluationTableGenerator, []] = Registry(
+    "evaluation_table"
+)
+
+
+class FullEvaluationTableConfig(BaseConfig):
+    name: Literal["full_evaluation"] = "full_evaluation"
+
+
+class FullEvaluationTable:
+    def __call__(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> pd.DataFrame:
+        return extract_matches_dataframe(clip_evaluations)
+
+    @classmethod
+    def from_config(cls, config: FullEvaluationTableConfig):
+        return cls()
+
+
+tables_registry.register(FullEvaluationTableConfig, FullEvaluationTable)
+
+
+def extract_matches_dataframe(
+    clip_evaluations: Sequence[ClipEvaluation],
+) -> pd.DataFrame:
     data = []

     for clip_evaluation in clip_evaluations:
         for match in clip_evaluation.matches:
             gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
-            pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
+            pred_start_time = pred_low_freq = pred_end_time = (
+                pred_high_freq
+            ) = None

             sound_event_annotation = match.sound_event_annotation

@@ -24,9 +55,12 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
             )

             if match.pred_geometry is not None:
-                pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
-                    compute_bounds(match.pred_geometry)
-                )
+                (
+                    pred_start_time,
+                    pred_low_freq,
+                    pred_end_time,
+                    pred_high_freq,
+                ) = compute_bounds(match.pred_geometry)

             data.append(
                 {
@@ -61,3 +95,14 @@ def extract_matches_dataframe(clip_evaluations: List[ClipEvaluation]) -> pd.DataFrame:
     df = pd.DataFrame(data)
     df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
     return df
+
+
+EvaluationTableConfig = Annotated[
+    Union[FullEvaluationTableConfig,], Field(discriminator="name")
+]
+
+
+def build_table_generator(
+    config: EvaluationTableConfig,
+) -> EvaluationTableGenerator:
+    return tables_registry.build(config)

@@ -67,7 +67,7 @@ from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.typing.models import DetectionModel
 from batdetect2.typing.postprocess import (
-    DetectionsTensor,
+    ClipDetectionsTensor,
     PostprocessorProtocol,
 )
 from batdetect2.typing.preprocess import PreprocessorProtocol
@@ -121,7 +121,7 @@ class Model(torch.nn.Module):
         self.postprocessor = postprocessor
         self.targets = targets

-    def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
+    def forward(self, wav: torch.Tensor) -> List[ClipDetectionsTensor]:
         spec = self.preprocessor(wav)
         outputs = self.detector(spec)
         return self.postprocessor(outputs)

@@ -8,13 +8,10 @@ from batdetect2.postprocess.decoding import (
     convert_raw_predictions_to_clip_prediction,
     to_raw_predictions,
 )
-from batdetect2.postprocess.nms import (
-    non_max_suppression,
-)
+from batdetect2.postprocess.nms import non_max_suppression
 from batdetect2.postprocess.postprocessor import (
     Postprocessor,
     build_postprocessor,
-    get_raw_predictions,
 )

 __all__ = [
@@ -25,5 +22,4 @@ __all__ = [
     "to_raw_predictions",
     "load_postprocess_config",
     "non_max_suppression",
-    "get_raw_predictions",
 ]

@@ -6,7 +6,7 @@ import numpy as np
 from soundevent import data

 from batdetect2.typing.postprocess import (
-    DetectionsArray,
+    ClipDetectionsArray,
     RawPrediction,
 )
 from batdetect2.typing.targets import TargetProtocol
@@ -28,7 +28,7 @@ decoding.


 def to_raw_predictions(
-    detections: DetectionsArray,
+    detections: ClipDetectionsArray,
     targets: TargetProtocol,
 ) -> List[RawPrediction]:
     predictions = []

@@ -15,32 +15,25 @@ precise time-frequency location of each detection. The final output aggregates
 all extracted information into a structured `xarray.Dataset`.
 """

-from typing import List, Optional, Tuple, Union
+from typing import List, Optional

 import torch

-from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
-from batdetect2.typing.postprocess import (
-    DetectionsTensor,
-    ModelOutput,
-)
+from batdetect2.typing.postprocess import ClipDetectionsTensor

 __all__ = [
-    "extract_prediction_tensor",
+    "extract_detection_peaks",
 ]


-def extract_prediction_tensor(
-    output: ModelOutput,
+def extract_detection_peaks(
+    detection_heatmap: torch.Tensor,
+    size_heatmap: torch.Tensor,
+    feature_heatmap: torch.Tensor,
+    classification_heatmap: torch.Tensor,
     max_detections: int = 200,
     threshold: Optional[float] = None,
-    nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
-) -> List[DetectionsTensor]:
-    detection_heatmap = non_max_suppression(
-        output.detection_probs.detach(),
-        kernel_size=nms_kernel_size,
-    )
-
+) -> List[ClipDetectionsTensor]:
     height = detection_heatmap.shape[-2]
     width = detection_heatmap.shape[-1]

@@ -53,9 +46,9 @@ def extract_prediction_tensor(
     freqs = freqs.flatten().to(detection_heatmap.device)
     times = times.flatten().to(detection_heatmap.device)

-    output_size_preds = output.size_preds.detach()
-    output_features = output.features.detach()
-    output_class_probs = output.class_probs.detach()
+    output_size_preds = size_heatmap.detach()
+    output_features = feature_heatmap.detach()
+    output_class_probs = classification_heatmap.detach()

     predictions = []
     for idx, item in enumerate(detection_heatmap):
@@ -65,23 +58,25 @@ def extract_prediction_tensor(
         detection_scores = item.take(indices)
         detection_freqs = freqs.take(indices)
         detection_times = times.take(indices)
-        sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
-        features = output_features[idx, :, detection_freqs, detection_times].T
-        class_scores = output_class_probs[
-            idx, :, detection_freqs, detection_times
-        ].T

         if threshold is not None:
             mask = detection_scores >= threshold
+
             detection_scores = detection_scores[mask]
-            sizes = sizes[mask]
             detection_times = detection_times[mask]
             detection_freqs = detection_freqs[mask]
-            features = features[mask]
-            class_scores = class_scores[mask]
+
+        sizes = output_size_preds[idx, :, detection_freqs, detection_times].T
+        features = output_features[idx, :, detection_freqs, detection_times].T
+        class_scores = output_class_probs[
+            idx,
+            :,
+            detection_freqs,
+            detection_times,
+        ].T

         predictions.append(
-            DetectionsTensor(
+            ClipDetectionsTensor(
                 scores=detection_scores,
                 sizes=sizes,
                 features=features,

@@ -1,29 +1,20 @@
-from typing import List, Optional
+from typing import List, Optional, Tuple, Union

 import torch
 from loguru import logger
-from soundevent import data

 from batdetect2.postprocess.config import (
     PostprocessConfig,
 )
-from batdetect2.postprocess.decoding import (
-    DEFAULT_CLASSIFICATION_THRESHOLD,
-    convert_raw_prediction_to_sound_event_prediction,
-    convert_raw_predictions_to_clip_prediction,
-    to_raw_predictions,
-)
-from batdetect2.postprocess.extraction import extract_prediction_tensor
+from batdetect2.postprocess.extraction import extract_detection_peaks
+from batdetect2.postprocess.nms import NMS_KERNEL_SIZE, non_max_suppression
 from batdetect2.postprocess.remapping import map_detection_to_clip
 from batdetect2.typing import ModelOutput
 from batdetect2.typing.postprocess import (
-    BatDetect2Prediction,
-    DetectionsTensor,
+    ClipDetectionsTensor,
     PostprocessorProtocol,
-    RawPrediction,
 )
 from batdetect2.typing.preprocess import PreprocessorProtocol
-from batdetect2.typing.targets import TargetProtocol

 __all__ = [
     "build_postprocessor",
@@ -60,24 +51,43 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
         max_freq: float,
         top_k_per_sec: int = 200,
         detection_threshold: float = 0.01,
+        nms_kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
     ):
         """Initialize the Postprocessor."""
         super().__init__()
-        self.samplerate = samplerate
+        self.output_samplerate = samplerate
         self.min_freq = min_freq
         self.max_freq = max_freq
         self.top_k_per_sec = top_k_per_sec
         self.detection_threshold = detection_threshold
+        self.nms_kernel_size = nms_kernel_size

-    def forward(self, output: ModelOutput) -> List[DetectionsTensor]:
+    def forward(
+        self,
+        output: ModelOutput,
+        start_times: Optional[List[float]] = None,
+    ) -> List[ClipDetectionsTensor]:
+        detection_heatmap = non_max_suppression(
+            output.detection_probs.detach(),
+            kernel_size=self.nms_kernel_size,
+        )
+
         width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
+        duration = width / self.output_samplerate
         max_detections = int(self.top_k_per_sec * duration)
-        detections = extract_prediction_tensor(
-            output,
+        detections = extract_detection_peaks(
+            detection_heatmap,
+            size_heatmap=output.size_preds,
+            feature_heatmap=output.features,
+            classification_heatmap=output.class_probs,
             max_detections=max_detections,
             threshold=self.detection_threshold,
         )
+
+        if start_times is None:
+            start_times = [0 for _ in range(len(detections))]
+
         return [
             map_detection_to_clip(
                 detection,
@@ -88,121 +98,3 @@ class Postprocessor(torch.nn.Module, PostprocessorProtocol):
             )
             for detection in detections
         ]
-
-    def get_detections(
-        self,
-        output: ModelOutput,
-        start_times: Optional[List[float]] = None,
-    ) -> List[DetectionsTensor]:
-        width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
-        max_detections = int(self.top_k_per_sec * duration)
-
-        detections = extract_prediction_tensor(
-            output,
-            max_detections=max_detections,
-            threshold=self.detection_threshold,
-        )
-
-        if start_times is None:
-            return detections
-
-        width = output.detection_probs.shape[-1]
-        duration = width / self.samplerate
-        return [
-            map_detection_to_clip(
-                detection,
-                start_time=start_time,
-                end_time=start_time + duration,
-                min_freq=self.min_freq,
-                max_freq=self.max_freq,
-            )
-            for detection, start_time in zip(detections, start_times)
-        ]
-
-
-def get_raw_predictions(
-    output: ModelOutput,
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    start_times: Optional[List[float]] = None,
-) -> List[List[RawPrediction]]:
-    """Extract intermediate RawPrediction objects for a batch."""
-    detections = postprocessor.get_detections(output, start_times)
-    return [
-        to_raw_predictions(detection.numpy(), targets=targets)
-        for detection in detections
-    ]
-
-
-def get_sound_event_predictions(
-    output: ModelOutput,
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    clips: List[data.Clip],
-    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
-) -> List[List[BatDetect2Prediction]]:
-    raw_predictions = get_raw_predictions(
-        output,
-        targets=targets,
-        postprocessor=postprocessor,
-        start_times=[clip.start_time for clip in clips],
-    )
-    return [
-        [
-            BatDetect2Prediction(
-                raw=raw,
-                sound_event_prediction=convert_raw_prediction_to_sound_event_prediction(
-                    raw,
-                    recording=clip.recording,
-                    targets=targets,
-                    classification_threshold=classification_threshold,
-                ),
-            )
-            for raw in predictions
-        ]
-        for predictions, clip in zip(raw_predictions, clips)
-    ]
-
-
-def get_predictions(
-    output: ModelOutput,
-    clips: List[data.Clip],
-    targets: TargetProtocol,
-    postprocessor: PostprocessorProtocol,
-    classification_threshold: float = DEFAULT_CLASSIFICATION_THRESHOLD,
-) -> List[data.ClipPrediction]:
-    """Perform the full postprocessing pipeline for a batch.
-
-    Takes raw model output and corresponding clips, applies the entire
-    configured chain (NMS, remapping, extraction, geometry recovery, class
-    decoding), producing final `soundevent.data.ClipPrediction` objects.
-
-    Parameters
-    ----------
-    output : ModelOutput
-        Raw output from the neural network model for a batch.
-    clips : List[data.Clip]
-        List of `soundevent.data.Clip` objects corresponding to the batch.
-
-    Returns
-    -------
-    List[data.ClipPrediction]
-        List containing one `ClipPrediction` object for each input clip,
-        populated with `SoundEventPrediction` objects.
-    """
-    raw_predictions = get_raw_predictions(
-        output,
-        targets=targets,
-        postprocessor=postprocessor,
-        start_times=[clip.start_time for clip in clips],
-    )
-    return [
-        convert_raw_predictions_to_clip_prediction(
-            prediction,
-            clip,
-            targets=targets,
-            classification_threshold=classification_threshold,
-        )
-        for prediction, clip in zip(raw_predictions, clips)
-    ]

@ -20,7 +20,7 @@ import xarray as xr
|
|||||||
from soundevent.arrays import Dimensions
|
from soundevent.arrays import Dimensions
|
||||||
|
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.typing.postprocess import DetectionsTensor
|
from batdetect2.typing.postprocess import ClipDetectionsTensor
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"features_to_xarray",
|
"features_to_xarray",
|
||||||
@@ -31,15 +31,15 @@ __all__ = [
 
 
 def map_detection_to_clip(
-    detections: DetectionsTensor,
+    detections: ClipDetectionsTensor,
     start_time: float,
     end_time: float,
     min_freq: float,
     max_freq: float,
-) -> DetectionsTensor:
+) -> ClipDetectionsTensor:
     duration = end_time - start_time
     bandwidth = max_freq - min_freq
-    return DetectionsTensor(
+    return ClipDetectionsTensor(
         scores=detections.scores,
         sizes=detections.sizes,
         features=detections.features,
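Note: this hunk only renames DetectionsTensor to ClipDetectionsTensor; the rest of map_detection_to_clip lies outside the hunk. Presumably the duration and bandwidth computed here rescale normalized detection positions into absolute clip coordinates, along these lines (an assumption, not shown in this diff):

    # hypothetical continuation of map_detection_to_clip
    times = start_time + detections.times * duration
    frequencies = min_freq + detections.frequencies * bandwidth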
@@ -14,7 +14,6 @@ from batdetect2.train.augmentations import (
     scale_volume,
     warp_spectrogram,
 )
-from batdetect2.train.clips import build_clipper, select_subclip
 from batdetect2.train.config import (
     PLTrainerConfig,
     TrainingConfig,
@@ -61,7 +60,6 @@ __all__ = [
     "add_echo",
     "build_augmentations",
     "build_clip_labeler",
-    "build_clipper",
     "build_loss",
     "build_train_dataset",
     "build_train_loader",
@@ -74,7 +72,6 @@ __all__ = [
     "mask_time",
     "mix_audio",
     "scale_volume",
-    "select_subclip",
     "train",
     "warp_spectrogram",
 ]
@@ -11,9 +11,9 @@ from pydantic import Field
 from soundevent import data
 from soundevent.geometry import scale_geometry, shift_geometry
 
+from batdetect2.audio.clips import get_subclip_annotation
 from batdetect2.core.arrays import adjust_width
 from batdetect2.core.configs import BaseConfig, load_config
-from batdetect2.train.clips import get_subclip_annotation
 from batdetect2.typing import AudioLoader, Augmentation
 
 __all__ = [
@@ -5,13 +5,13 @@ from lightning.pytorch.callbacks import Callback
 from soundevent import data
 from torch.utils.data import DataLoader
 
-from batdetect2.evaluate import Evaluator
-from batdetect2.postprocess import get_raw_predictions
+from batdetect2.logging import get_image_logger
+from batdetect2.postprocess import to_raw_predictions
 from batdetect2.train.dataset import ValidationDataset
 from batdetect2.train.lightning import TrainingModule
-from batdetect2.train.logging import get_image_plotter
 from batdetect2.typing import (
     ClipEvaluation,
+    EvaluatorProtocol,
     ModelOutput,
     RawPrediction,
     TrainExample,
@@ -19,7 +19,7 @@ from batdetect2.typing import (
 
 
 class ValidationMetrics(Callback):
-    def __init__(self, evaluator: Evaluator):
+    def __init__(self, evaluator: EvaluatorProtocol):
         super().__init__()
 
         self.evaluator = evaluator
@@ -34,12 +34,12 @@ class ValidationMetrics(Callback):
         assert isinstance(dataset, ValidationDataset)
         return dataset
 
-    def plot_examples(
+    def generate_plots(
         self,
         pl_module: LightningModule,
         evaluated_clips: List[ClipEvaluation],
     ):
-        plotter = get_image_plotter(pl_module.logger)  # type: ignore
+        plotter = get_image_logger(pl_module.logger)  # type: ignore
 
         if plotter is None:
             return
@@ -66,7 +66,7 @@ class ValidationMetrics(Callback):
         )
 
         self.log_metrics(pl_module, clip_evaluations)
-        self.plot_examples(pl_module, clip_evaluations)
+        self.generate_plots(pl_module, clip_evaluations)
 
         return super().on_validation_epoch_end(trainer, pl_module)
 
@@ -88,8 +88,7 @@ class ValidationMetrics(Callback):
         batch_idx: int,
         dataloader_idx: int = 0,
     ) -> None:
-        postprocessor = pl_module.model.postprocessor
-        targets = pl_module.model.targets
+        model = pl_module.model
         dataset = self.get_dataset(trainer)
 
         clip_annotations = [
@@ -97,15 +96,14 @@ class ValidationMetrics(Callback):
             for example_idx in batch.idx
         ]
 
-        predictions = get_raw_predictions(
+        clip_detections = model.postprocessor(
             outputs,
-            start_times=[
-                clip_annotation.clip.start_time
-                for clip_annotation in clip_annotations
-            ],
-            targets=targets,
-            postprocessor=postprocessor,
+            start_times=[ca.clip.start_time for ca in clip_annotations],
         )
+        predictions = [
+            to_raw_predictions(clip_dets.numpy(), targets=model.targets)
+            for clip_dets in clip_detections
+        ]
 
         self._clip_annotations.extend(clip_annotations)
         self._predictions.extend(predictions)
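Note: this is where the commit's device handling lands. The postprocessor now runs on whatever device Lightning placed the model on, and ClipDetectionsTensor.numpy() (defined near the end of this diff) detaches every field and copies it to the CPU, so the accumulated validation predictions neither pin GPU memory nor break when the trainer moves devices. Sketch, assuming dets is one ClipDetectionsTensor:

    import numpy as np

    arr = dets.numpy()  # each field is .detach().cpu().numpy()
    assert isinstance(arr.scores, np.ndarray)  # safe to keep across epochs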
@@ -5,17 +5,9 @@ from soundevent import data
 
 from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.evaluate.config import EvaluationConfig
-from batdetect2.train.augmentations import (
-    DEFAULT_AUGMENTATION_CONFIG,
-    AugmentationsConfig,
-)
-from batdetect2.train.clips import (
-    ClipConfig,
-    PaddedClipConfig,
-    RandomClipConfig,
-)
+from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
+from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
-from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig
 
 __all__ = [
@@ -45,30 +37,6 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None
 
 
-class ValLoaderConfig(BaseConfig):
-    num_workers: int = 0
-
-    clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: PaddedClipConfig()
-    )
-
-
-class TrainLoaderConfig(BaseConfig):
-    num_workers: int = 0
-
-    batch_size: int = 8
-
-    shuffle: bool = False
-
-    augmentations: AugmentationsConfig = Field(
-        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
-    )
-
-    clipping_strategy: ClipConfig = Field(
-        default_factory=lambda: PaddedClipConfig()
-    )
-
-
 class OptimizerConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
@@ -79,9 +47,8 @@ class TrainingConfig(BaseConfig):
     val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
     optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
-    cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
-    logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
+    logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
     validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
 
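Note: TrainingConfig drops the misspelled top-level cliping field (clipping now lives on the per-loader configs) and switches the default logger from CSV to TensorBoard. A construction sketch under the new defaults, assuming the config also carries the train_loader field implied by the imports above:

    config = TrainingConfig()
    assert isinstance(config.logger, TensorBoardLoggerConfig)
    # clipping is configured per loader rather than on TrainingConfig itself
    print(config.train_loader.clipping_strategy)  # a PaddedClipConfig
    print(config.val_loader.clipping_strategy)    # likewise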
@@ -2,18 +2,21 @@ from typing import List, Optional, Sequence
 
 import torch
 from loguru import logger
+from pydantic import Field
 from soundevent import data
 from torch.utils.data import DataLoader, Dataset
 
-from batdetect2.audio import build_audio_loader
+from batdetect2.audio import ClipConfig, build_audio_loader, build_clipper
+from batdetect2.audio.clips import PaddedClipConfig
+from batdetect2.core import BaseConfig
 from batdetect2.core.arrays import adjust_width
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.train.augmentations import (
+    DEFAULT_AUGMENTATION_CONFIG,
+    AugmentationsConfig,
     RandomAudioSource,
     build_augmentations,
 )
-from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.typing import (
     AudioLoader,
@@ -144,6 +147,22 @@ class ValidationDataset(Dataset):
     )
 
 
+class TrainLoaderConfig(BaseConfig):
+    num_workers: int = 0
+
+    batch_size: int = 8
+
+    shuffle: bool = False
+
+    augmentations: AugmentationsConfig = Field(
+        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG.model_copy()
+    )
+
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: PaddedClipConfig()
+    )
+
+
 def build_train_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: Optional[AudioLoader] = None,
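Note: TrainLoaderConfig moves here from batdetect2.train.config, so the dataset module owns the settings it consumes. A usage sketch; how build_train_loader accepts the config beyond the parameters visible in this hunk is an assumption:

    config = TrainLoaderConfig(batch_size=16, num_workers=4, shuffle=True)
    loader = build_train_loader(
        clip_annotations,
        audio_loader=audio_loader,
        config=config,  # hypothetical keyword; exact parameter not shown here
    )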
@@ -178,6 +197,14 @@ def build_train_loader(
     )
 
 
+class ValLoaderConfig(BaseConfig):
+    num_workers: int = 0
+
+    clipping_strategy: ClipConfig = Field(
+        default_factory=lambda: PaddedClipConfig()
+    )
+
+
 def build_val_loader(
     clip_annotations: Sequence[data.ClipAnnotation],
     audio_loader: Optional[AudioLoader] = None,
@@ -8,20 +8,21 @@ from loguru import logger
 from soundevent import data
 
 from batdetect2.audio import build_audio_loader
-from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
+from batdetect2.evaluate.evaluator import build_evaluator
+from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
-from batdetect2.train.logging import build_logger
 
 if TYPE_CHECKING:
     from batdetect2.config import BatDetect2Config
     from batdetect2.typing import (
         AudioLoader,
         ClipLabeller,
+        EvaluatorProtocol,
         PreprocessorProtocol,
         TargetProtocol,
     )
@@ -122,7 +123,7 @@ def train(
 
 def build_trainer_callbacks(
     targets: "TargetProtocol",
-    evaluator: Optional[Evaluator] = None,
+    evaluator: Optional["EvaluatorProtocol"] = None,
    checkpoint_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
@@ -142,7 +143,6 @@ def build_trainer_callbacks(
         ModelCheckpoint(
             dirpath=str(checkpoint_dir),
             save_top_k=1,
-            filename="best-{epoch:02d}-{val_loss:.0f}",
             monitor="total_loss/val",
         ),
         ValidationMetrics(evaluator),
@@ -152,7 +152,7 @@ def build_trainer_callbacks(
 def build_trainer(
     conf: "BatDetect2Config",
     targets: "TargetProtocol",
-    evaluator: Optional[Evaluator] = None,
+    evaluator: Optional["EvaluatorProtocol"] = None,
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
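Note: both trainer helpers now accept any EvaluatorProtocol implementation instead of the concrete Evaluator, which is what lets evaluation run through Lightning. A wiring sketch; the build_evaluator arguments are assumed, only its import is shown above:

    evaluator = build_evaluator(config=conf.validation)  # hypothetical arguments
    callbacks = build_trainer_callbacks(
        targets,
        evaluator=evaluator,
        checkpoint_dir=checkpoint_dir,
        experiment_name=experiment_name,
        run_name=run_name,
    )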
@@ -1,7 +1,9 @@
 from batdetect2.typing.evaluate import (
     ClipEvaluation,
+    EvaluatorProtocol,
     MatchEvaluation,
     MetricsProtocol,
+    PlotterProtocol,
 )
 from batdetect2.typing.models import BackboneModel, DetectionModel, ModelOutput
 from batdetect2.typing.postprocess import (
@@ -49,6 +51,7 @@ __all__ = [
     "MatchEvaluation",
     "MetricsProtocol",
     "ModelOutput",
+    "PlotterProtocol",
     "Position",
     "PostprocessorProtocol",
     "PreprocessorProtocol",
@@ -60,4 +63,5 @@ __all__ = [
     "SoundEventFilter",
     "TargetProtocol",
     "TrainExample",
+    "EvaluatorProtocol",
 ]
@@ -14,7 +14,11 @@ from typing import (
 from matplotlib.figure import Figure
 from soundevent import data
 
+from batdetect2.typing.postprocess import RawPrediction
+from batdetect2.typing.targets import TargetProtocol
+
 __all__ = [
+    "EvaluatorProtocol",
     "MetricsProtocol",
     "MatchEvaluation",
 ]
@@ -107,3 +111,21 @@ class PlotterProtocol(Protocol):
     def __call__(
         self, clip_evaluations: Sequence[ClipEvaluation]
     ) -> Iterable[Tuple[str, Figure]]: ...
+
+
+class EvaluatorProtocol(Protocol):
+    targets: TargetProtocol
+
+    def evaluate(
+        self,
+        clip_annotations: Sequence[data.ClipAnnotation],
+        predictions: Sequence[Sequence[RawPrediction]],
+    ) -> List[ClipEvaluation]: ...
+
+    def compute_metrics(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Dict[str, float]: ...
+
+    def generate_plots(
+        self, clip_evaluations: Sequence[ClipEvaluation]
+    ) -> Iterable[Tuple[str, Figure]]: ...
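Note: EvaluatorProtocol is a structural typing.Protocol, so any object exposing a targets attribute and these three methods conforms; no subclassing is required. A minimal no-op stub for illustration:

    class NullEvaluator:
        """Conforms to EvaluatorProtocol without computing anything."""

        def __init__(self, targets: TargetProtocol):
            self.targets = targets

        def evaluate(self, clip_annotations, predictions):
            return []  # no clips evaluated

        def compute_metrics(self, clip_evaluations):
            return {}  # no metrics logged

        def generate_plots(self, clip_evaluations):
            return iter(())  # no figures produced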
@@ -12,7 +12,7 @@ system that deal with model predictions.
 """
 
 from dataclasses import dataclass
-from typing import List, NamedTuple, Optional, Protocol
+from typing import List, NamedTuple, Optional, Protocol, Sequence
 
 import numpy as np
 import torch
@@ -47,15 +47,13 @@ class GeometryDecoder(Protocol):
 
 
 class RawPrediction(NamedTuple):
-    """Intermediate representation of a single detected sound event."""
-
     geometry: data.Geometry
     detection_score: float
     class_scores: np.ndarray
     features: np.ndarray
 
 
-class DetectionsArray(NamedTuple):
+class ClipDetectionsArray(NamedTuple):
     scores: np.ndarray
     sizes: np.ndarray
     class_scores: np.ndarray
@@ -64,7 +62,7 @@ class DetectionsArray(NamedTuple):
     features: np.ndarray
 
 
-class DetectionsTensor(NamedTuple):
+class ClipDetectionsTensor(NamedTuple):
     scores: torch.Tensor
     sizes: torch.Tensor
     class_scores: torch.Tensor
@@ -72,8 +70,8 @@ class DetectionsTensor(NamedTuple):
     frequencies: torch.Tensor
     features: torch.Tensor
 
-    def numpy(self) -> DetectionsArray:
-        return DetectionsArray(
+    def numpy(self) -> ClipDetectionsArray:
+        return ClipDetectionsArray(
             scores=self.scores.detach().cpu().numpy(),
             sizes=self.sizes.detach().cpu().numpy(),
             class_scores=self.class_scores.detach().cpu().numpy(),
@@ -92,10 +90,8 @@ class BatDetect2Prediction:
 class PostprocessorProtocol(Protocol):
     """Protocol defining the interface for the full postprocessing pipeline."""
 
-    def __call__(self, output: ModelOutput) -> List[DetectionsTensor]: ...
-
-    def get_detections(
+    def __call__(
         self,
         output: ModelOutput,
-        start_times: Optional[List[float]] = None,
-    ) -> List[DetectionsTensor]: ...
+        start_times: Optional[Sequence[float]] = None,
+    ) -> List[ClipDetectionsTensor]: ...
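Note: the old two-method surface (__call__ for raw output plus get_detections for clip-aligned output) collapses into a single __call__ whose optional start_times argument controls remapping, presumably via map_detection_to_clip from earlier in this diff. Usage sketch:

    # spectrogram-relative detections
    detections = postprocessor(output)

    # detections remapped onto each clip's absolute time axis
    detections = postprocessor(
        output,
        start_times=[clip.start_time for clip in clips],
    )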