mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
3 Commits
65d13a32b7
...
b997a122f1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b997a122f1 | ||
|
|
c41551b59c | ||
|
|
8aa2d0cd11 |
@ -1,6 +1,7 @@
|
||||
"""BatDetect2 command line interface."""
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
|
||||
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
||||
|
||||
@ -21,4 +22,5 @@ BatDetect2 - Detection and Classification
|
||||
def cli():
|
||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||
click.echo(INFO_STR)
|
||||
logger.enable("batdetect2")
|
||||
# click.echo(BATDETECT_ASCII_ART)
|
||||
|
||||
@ -1,9 +1,15 @@
|
||||
from batdetect2.evaluate.config import (
|
||||
EvaluationConfig,
|
||||
load_evaluation_config,
|
||||
)
|
||||
from batdetect2.evaluate.match import (
|
||||
match_predictions_and_annotations,
|
||||
match_sound_events_and_raw_predictions,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"match_sound_events_and_raw_predictions",
|
||||
"EvaluationConfig",
|
||||
"load_evaluation_config",
|
||||
"match_predictions_and_annotations",
|
||||
"match_sound_events_and_raw_predictions",
|
||||
]
|
||||
|
||||
25
src/batdetect2/evaluate/config.py
Normal file
25
src/batdetect2/evaluate/config.py
Normal file
@ -0,0 +1,25 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.match import DEFAULT_MATCH_CONFIG, MatchConfig
|
||||
|
||||
__all__ = [
|
||||
"EvaluationConfig",
|
||||
"load_evaluation_config",
|
||||
]
|
||||
|
||||
|
||||
class EvaluationConfig(BaseConfig):
|
||||
match: MatchConfig = Field(
|
||||
default_factory=lambda: DEFAULT_MATCH_CONFIG.model_copy(),
|
||||
)
|
||||
|
||||
|
||||
def load_evaluation_config(
|
||||
path: data.PathLike,
|
||||
field: Optional[str] = None,
|
||||
) -> EvaluationConfig:
|
||||
return load_config(path, schema=EvaluationConfig, field=field)
|
||||
@ -19,7 +19,7 @@ class ClassExamples:
|
||||
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
||||
|
||||
|
||||
def plot_examples(
|
||||
def plot_example_gallery(
|
||||
matches: List[MatchEvaluation],
|
||||
preprocessor: PreprocessorProtocol,
|
||||
n_examples: int = 5,
|
||||
|
||||
@ -179,7 +179,7 @@ def plot_false_positive_match(
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
|
||||
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
@ -326,7 +326,7 @@ def plot_true_positive_match(
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
|
||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
@ -407,7 +407,7 @@ def plot_cross_trigger_match(
|
||||
plt.text(
|
||||
start_time,
|
||||
high_freq,
|
||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
|
||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||
va="top",
|
||||
ha="right",
|
||||
color=color,
|
||||
|
||||
@ -1,14 +1,21 @@
|
||||
from typing import List
|
||||
from functools import partial
|
||||
from multiprocessing import Pool
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from lightning import LightningModule, Trainer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from lightning.pytorch.loggers import TensorBoardLogger
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
||||
from batdetect2.evaluate.match import (
|
||||
MatchConfig,
|
||||
match_sound_events_and_raw_predictions,
|
||||
)
|
||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||
from batdetect2.plotting.evaluation import plot_examples
|
||||
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.lightning import TrainingModule
|
||||
@ -16,16 +23,25 @@ from batdetect2.train.types import ModelOutput
|
||||
|
||||
|
||||
class ValidationMetrics(Callback):
|
||||
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
|
||||
def __init__(
|
||||
self,
|
||||
metrics: List[MetricsProtocol],
|
||||
plot: bool = True,
|
||||
match_config: Optional[MatchConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if len(metrics) == 0:
|
||||
raise ValueError("At least one metric needs to be provided")
|
||||
|
||||
self.matches: List[MatchEvaluation] = []
|
||||
self.match_config = match_config
|
||||
self.metrics = metrics
|
||||
self.plot = plot
|
||||
|
||||
self._matches: List[
|
||||
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
|
||||
] = []
|
||||
|
||||
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
|
||||
dataloaders = trainer.val_dataloaders
|
||||
assert isinstance(dataloaders, DataLoader)
|
||||
@ -33,25 +49,33 @@ class ValidationMetrics(Callback):
|
||||
assert isinstance(dataset, LabeledDataset)
|
||||
return dataset
|
||||
|
||||
def plot_examples(self, pl_module: LightningModule):
|
||||
def plot_examples(
|
||||
self,
|
||||
pl_module: LightningModule,
|
||||
matches: List[MatchEvaluation],
|
||||
):
|
||||
if not isinstance(pl_module.logger, TensorBoardLogger):
|
||||
return
|
||||
|
||||
for class_name, fig in plot_examples(
|
||||
self.matches,
|
||||
for class_name, fig in plot_example_gallery(
|
||||
matches,
|
||||
preprocessor=pl_module.preprocessor,
|
||||
n_examples=5,
|
||||
):
|
||||
pl_module.logger.experiment.add_figure(
|
||||
f"{class_name}/examples",
|
||||
f"images/{class_name}_examples",
|
||||
fig,
|
||||
pl_module.global_step,
|
||||
)
|
||||
|
||||
def log_metrics(self, pl_module: LightningModule):
|
||||
def log_metrics(
|
||||
self,
|
||||
pl_module: LightningModule,
|
||||
matches: List[MatchEvaluation],
|
||||
):
|
||||
metrics = {}
|
||||
for metric in self.metrics:
|
||||
metrics.update(metric(self.matches).items())
|
||||
metrics.update(metric(matches).items())
|
||||
|
||||
pl_module.log_dict(metrics)
|
||||
|
||||
@ -60,10 +84,16 @@ class ValidationMetrics(Callback):
|
||||
trainer: Trainer,
|
||||
pl_module: LightningModule,
|
||||
) -> None:
|
||||
self.log_metrics(pl_module)
|
||||
matches = _match_all_collected_examples(
|
||||
self._matches,
|
||||
pl_module.targets,
|
||||
config=self.match_config,
|
||||
)
|
||||
|
||||
self.log_metrics(pl_module, matches)
|
||||
|
||||
if self.plot:
|
||||
self.plot_examples(pl_module)
|
||||
self.plot_examples(pl_module, matches)
|
||||
|
||||
return super().on_validation_epoch_end(trainer, pl_module)
|
||||
|
||||
@ -72,7 +102,7 @@ class ValidationMetrics(Callback):
|
||||
trainer: Trainer,
|
||||
pl_module: LightningModule,
|
||||
) -> None:
|
||||
self.matches = []
|
||||
self._matches = []
|
||||
return super().on_validation_epoch_start(trainer, pl_module)
|
||||
|
||||
def on_validation_batch_end( # type: ignore
|
||||
@ -110,13 +140,26 @@ class ValidationMetrics(Callback):
|
||||
for clip_annotation, clip_predictions in zip(
|
||||
clip_annotations, raw_predictions
|
||||
):
|
||||
self.matches.extend(
|
||||
match_sound_events_and_raw_predictions(
|
||||
clip_annotation=clip_annotation,
|
||||
raw_predictions=clip_predictions,
|
||||
targets=pl_module.targets,
|
||||
)
|
||||
)
|
||||
self._matches.append((clip_annotation, clip_predictions))
|
||||
|
||||
|
||||
def _match_all_collected_examples(
|
||||
pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
|
||||
targets: TargetProtocol,
|
||||
config: Optional[MatchConfig] = None,
|
||||
) -> List[MatchEvaluation]:
|
||||
logger.info("Matching all annotations and predictions")
|
||||
|
||||
with Pool() as p:
|
||||
matches = p.starmap(
|
||||
partial(
|
||||
match_sound_events_and_raw_predictions,
|
||||
targets=targets,
|
||||
config=config,
|
||||
),
|
||||
pre_matches,
|
||||
)
|
||||
return [match for clip_matches in matches for match in clip_matches]
|
||||
|
||||
|
||||
def _is_in_subclip(
|
||||
|
||||
@ -4,6 +4,7 @@ from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.models import BackboneConfig
|
||||
from batdetect2.postprocess import PostprocessConfig
|
||||
from batdetect2.preprocess import PreprocessingConfig
|
||||
@ -94,6 +95,7 @@ class FullTrainingConfig(BaseConfig):
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||
|
||||
|
||||
def load_full_training_config(
|
||||
|
||||
@ -7,6 +7,7 @@ from loguru import logger
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.evaluate.metrics import (
|
||||
ClassificationAccuracy,
|
||||
ClassificationMeanAveragePrecision,
|
||||
@ -82,7 +83,9 @@ def train(
|
||||
logger.info("Training complete.")
|
||||
|
||||
|
||||
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||
def build_trainer_callbacks(
|
||||
targets: TargetProtocol, config: EvaluationConfig
|
||||
) -> List[Callback]:
|
||||
return [
|
||||
ModelCheckpoint(
|
||||
dirpath="outputs/checkpoints",
|
||||
@ -96,7 +99,8 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||
class_names=targets.class_names
|
||||
),
|
||||
ClassificationAccuracy(class_names=targets.class_names),
|
||||
]
|
||||
],
|
||||
match_config=config.match,
|
||||
),
|
||||
]
|
||||
|
||||
@ -113,7 +117,7 @@ def build_trainer(
|
||||
return Trainer(
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
logger=build_logger(conf.train.logger),
|
||||
callbacks=build_trainer_callbacks(targets),
|
||||
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user