mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Use match config for training
This commit is contained in:
parent
8aa2d0cd11
commit
c41551b59c
@ -19,7 +19,7 @@ class ClassExamples:
|
|||||||
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
cross_triggers: List[MatchEvaluation] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
def plot_examples(
|
def plot_example_gallery(
|
||||||
matches: List[MatchEvaluation],
|
matches: List[MatchEvaluation],
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
n_examples: int = 5,
|
n_examples: int = 5,
|
||||||
|
|||||||
@ -179,7 +179,7 @@ def plot_false_positive_match(
|
|||||||
plt.text(
|
plt.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score} ",
|
f"False Positive \nScore: {match.pred_score} \nTop Class: {match.pred_class} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
@ -326,7 +326,7 @@ def plot_true_positive_match(
|
|||||||
plt.text(
|
plt.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
|
f"True Positive \nClass: {match.gt_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
@ -407,7 +407,7 @@ def plot_cross_trigger_match(
|
|||||||
plt.text(
|
plt.text(
|
||||||
start_time,
|
start_time,
|
||||||
high_freq,
|
high_freq,
|
||||||
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score} \nTop Class Score: {match.pred_class_score} ",
|
f"Cross Trigger \nTrue Class: {match.gt_class} \nPred Class: {match.pred_class} \nDet Score: {match.pred_score:.2f} \nTop Class Score: {match.pred_class_score:.2f} ",
|
||||||
va="top",
|
va="top",
|
||||||
ha="right",
|
ha="right",
|
||||||
color=color,
|
color=color,
|
||||||
|
|||||||
@ -1,14 +1,21 @@
|
|||||||
from typing import List
|
from functools import partial
|
||||||
|
from multiprocessing import Pool
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
from lightning import LightningModule, Trainer
|
from lightning import LightningModule, Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate.match import match_sound_events_and_raw_predictions
|
from batdetect2.evaluate.match import (
|
||||||
|
MatchConfig,
|
||||||
|
match_sound_events_and_raw_predictions,
|
||||||
|
)
|
||||||
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
from batdetect2.evaluate.types import MatchEvaluation, MetricsProtocol
|
||||||
from batdetect2.plotting.evaluation import plot_examples
|
from batdetect2.plotting.evaluation import plot_example_gallery
|
||||||
|
from batdetect2.postprocess.types import BatDetect2Prediction
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
@ -16,16 +23,25 @@ from batdetect2.train.types import ModelOutput
|
|||||||
|
|
||||||
|
|
||||||
class ValidationMetrics(Callback):
|
class ValidationMetrics(Callback):
|
||||||
def __init__(self, metrics: List[MetricsProtocol], plot: bool = True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
metrics: List[MetricsProtocol],
|
||||||
|
plot: bool = True,
|
||||||
|
match_config: Optional[MatchConfig] = None,
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if len(metrics) == 0:
|
if len(metrics) == 0:
|
||||||
raise ValueError("At least one metric needs to be provided")
|
raise ValueError("At least one metric needs to be provided")
|
||||||
|
|
||||||
self.matches: List[MatchEvaluation] = []
|
self.match_config = match_config
|
||||||
self.metrics = metrics
|
self.metrics = metrics
|
||||||
self.plot = plot
|
self.plot = plot
|
||||||
|
|
||||||
|
self._matches: List[
|
||||||
|
Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]
|
||||||
|
] = []
|
||||||
|
|
||||||
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
|
def get_dataset(self, trainer: Trainer) -> LabeledDataset:
|
||||||
dataloaders = trainer.val_dataloaders
|
dataloaders = trainer.val_dataloaders
|
||||||
assert isinstance(dataloaders, DataLoader)
|
assert isinstance(dataloaders, DataLoader)
|
||||||
@ -33,25 +49,33 @@ class ValidationMetrics(Callback):
|
|||||||
assert isinstance(dataset, LabeledDataset)
|
assert isinstance(dataset, LabeledDataset)
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
def plot_examples(self, pl_module: LightningModule):
|
def plot_examples(
|
||||||
|
self,
|
||||||
|
pl_module: LightningModule,
|
||||||
|
matches: List[MatchEvaluation],
|
||||||
|
):
|
||||||
if not isinstance(pl_module.logger, TensorBoardLogger):
|
if not isinstance(pl_module.logger, TensorBoardLogger):
|
||||||
return
|
return
|
||||||
|
|
||||||
for class_name, fig in plot_examples(
|
for class_name, fig in plot_example_gallery(
|
||||||
self.matches,
|
matches,
|
||||||
preprocessor=pl_module.preprocessor,
|
preprocessor=pl_module.preprocessor,
|
||||||
n_examples=5,
|
n_examples=5,
|
||||||
):
|
):
|
||||||
pl_module.logger.experiment.add_figure(
|
pl_module.logger.experiment.add_figure(
|
||||||
f"{class_name}/examples",
|
f"images/{class_name}_examples",
|
||||||
fig,
|
fig,
|
||||||
pl_module.global_step,
|
pl_module.global_step,
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_metrics(self, pl_module: LightningModule):
|
def log_metrics(
|
||||||
|
self,
|
||||||
|
pl_module: LightningModule,
|
||||||
|
matches: List[MatchEvaluation],
|
||||||
|
):
|
||||||
metrics = {}
|
metrics = {}
|
||||||
for metric in self.metrics:
|
for metric in self.metrics:
|
||||||
metrics.update(metric(self.matches).items())
|
metrics.update(metric(matches).items())
|
||||||
|
|
||||||
pl_module.log_dict(metrics)
|
pl_module.log_dict(metrics)
|
||||||
|
|
||||||
@ -60,10 +84,16 @@ class ValidationMetrics(Callback):
|
|||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.log_metrics(pl_module)
|
matches = _match_all_collected_examples(
|
||||||
|
self._matches,
|
||||||
|
pl_module.targets,
|
||||||
|
config=self.match_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.log_metrics(pl_module, matches)
|
||||||
|
|
||||||
if self.plot:
|
if self.plot:
|
||||||
self.plot_examples(pl_module)
|
self.plot_examples(pl_module, matches)
|
||||||
|
|
||||||
return super().on_validation_epoch_end(trainer, pl_module)
|
return super().on_validation_epoch_end(trainer, pl_module)
|
||||||
|
|
||||||
@ -72,7 +102,7 @@ class ValidationMetrics(Callback):
|
|||||||
trainer: Trainer,
|
trainer: Trainer,
|
||||||
pl_module: LightningModule,
|
pl_module: LightningModule,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.matches = []
|
self._matches = []
|
||||||
return super().on_validation_epoch_start(trainer, pl_module)
|
return super().on_validation_epoch_start(trainer, pl_module)
|
||||||
|
|
||||||
def on_validation_batch_end( # type: ignore
|
def on_validation_batch_end( # type: ignore
|
||||||
@ -110,13 +140,26 @@ class ValidationMetrics(Callback):
|
|||||||
for clip_annotation, clip_predictions in zip(
|
for clip_annotation, clip_predictions in zip(
|
||||||
clip_annotations, raw_predictions
|
clip_annotations, raw_predictions
|
||||||
):
|
):
|
||||||
self.matches.extend(
|
self._matches.append((clip_annotation, clip_predictions))
|
||||||
match_sound_events_and_raw_predictions(
|
|
||||||
clip_annotation=clip_annotation,
|
|
||||||
raw_predictions=clip_predictions,
|
def _match_all_collected_examples(
|
||||||
targets=pl_module.targets,
|
pre_matches: List[Tuple[data.ClipAnnotation, List[BatDetect2Prediction]]],
|
||||||
)
|
targets: TargetProtocol,
|
||||||
)
|
config: Optional[MatchConfig] = None,
|
||||||
|
) -> List[MatchEvaluation]:
|
||||||
|
logger.info("Matching all annotations and predictions")
|
||||||
|
|
||||||
|
with Pool() as p:
|
||||||
|
matches = p.starmap(
|
||||||
|
partial(
|
||||||
|
match_sound_events_and_raw_predictions,
|
||||||
|
targets=targets,
|
||||||
|
config=config,
|
||||||
|
),
|
||||||
|
pre_matches,
|
||||||
|
)
|
||||||
|
return [match for clip_matches in matches for match in clip_matches]
|
||||||
|
|
||||||
|
|
||||||
def _is_in_subclip(
|
def _is_in_subclip(
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from pydantic import Field
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.models import BackboneConfig
|
from batdetect2.models import BackboneConfig
|
||||||
from batdetect2.postprocess import PostprocessConfig
|
from batdetect2.postprocess import PostprocessConfig
|
||||||
from batdetect2.preprocess import PreprocessingConfig
|
from batdetect2.preprocess import PreprocessingConfig
|
||||||
@ -94,6 +95,7 @@ class FullTrainingConfig(BaseConfig):
|
|||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
evaluation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_full_training_config(
|
def load_full_training_config(
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from loguru import logger
|
|||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.evaluate.metrics import (
|
||||||
ClassificationAccuracy,
|
ClassificationAccuracy,
|
||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
@ -82,7 +83,9 @@ def train(
|
|||||||
logger.info("Training complete.")
|
logger.info("Training complete.")
|
||||||
|
|
||||||
|
|
||||||
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
def build_trainer_callbacks(
|
||||||
|
targets: TargetProtocol, config: EvaluationConfig
|
||||||
|
) -> List[Callback]:
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath="outputs/checkpoints",
|
dirpath="outputs/checkpoints",
|
||||||
@ -96,7 +99,8 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
|||||||
class_names=targets.class_names
|
class_names=targets.class_names
|
||||||
),
|
),
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
ClassificationAccuracy(class_names=targets.class_names),
|
||||||
]
|
],
|
||||||
|
match_config=config.match,
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -113,7 +117,7 @@ def build_trainer(
|
|||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
logger=build_logger(conf.train.logger),
|
logger=build_logger(conf.train.logger),
|
||||||
callbacks=build_trainer_callbacks(targets),
|
callbacks=build_trainer_callbacks(targets, config=conf.evaluation),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user