diff --git a/example_data/config.yaml b/example_data/config.yaml index 8cbf4e5..7d908af 100644 --- a/example_data/config.yaml +++ b/example_data/config.yaml @@ -138,36 +138,49 @@ train: name: csv validation: - metrics: - - name: detection_ap - - name: classification_ap - plots: - - name: example_gallery - - name: example_clip - - name: detection_pr_curve - - name: classification_pr_curves - - name: detection_roc_curve - - name: classification_roc_curves + tasks: + - name: sound_event_detection + metrics: + - name: average_precision + - name: sound_event_classification + metrics: + - name: average_precision evaluation: - match_strategy: - name: start_time_match - distance_threshold: 0.01 - metrics: - - name: classification_ap - - name: detection_ap - - name: detection_roc_auc - - name: classification_roc_auc - - name: top_class_ap - - name: classification_balanced_accuracy - - name: clip_multiclass_ap - - name: clip_multiclass_roc_auc - - name: clip_detection_ap - - name: clip_detection_roc_auc - plots: - - name: example_gallery - - name: example_clip - - name: detection_pr_curve - - name: classification_pr_curves - - name: detection_roc_curve - - name: classification_roc_curves + tasks: + - name: sound_event_detection + metrics: + - name: average_precision + - name: roc_auc + plots: + - name: pr_curve + - name: score_distribution + - name: example_detection + - name: sound_event_classification + metrics: + - name: average_precision + - name: roc_auc + plots: + - name: pr_curve + - name: top_class_detection + metrics: + - name: average_precision + plots: + - name: pr_curve + - name: confusion_matrix + - name: example_classification + - name: clip_detection + metrics: + - name: average_precision + - name: roc_auc + plots: + - name: pr_curve + - name: roc_curve + - name: score_distribution + - name: clip_classification + metrics: + - name: average_precision + - name: roc_auc + plots: + - name: pr_curve + - name: roc_curve diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py index 621c869..ccca917 100644 --- a/src/batdetect2/evaluate/lightning.py +++ b/src/batdetect2/evaluate/lightning.py @@ -4,8 +4,7 @@ from lightning import LightningModule from torch.utils.data import DataLoader from batdetect2.evaluate.dataset import TestDataset, TestExample -from batdetect2.evaluate.tables import FullEvaluationTable -from batdetect2.logging import get_image_logger, get_table_logger +from batdetect2.logging import get_image_logger from batdetect2.models import Model from batdetect2.postprocess import to_raw_predictions from batdetect2.typing import ClipMatches, EvaluatorProtocol @@ -54,16 +53,6 @@ class EvaluationModule(LightningModule): def on_test_epoch_end(self): self.log_metrics(self.clip_evaluations) self.plot_examples(self.clip_evaluations) - self.log_table(self.clip_evaluations) - - def log_table(self, evaluated_clips: Sequence[ClipMatches]): - table_logger = get_table_logger(self.logger) # type: ignore - - if table_logger is None: - return - - df = FullEvaluationTable()(evaluated_clips) - table_logger("full_evaluation", df, 0) def plot_examples(self, evaluated_clips: Sequence[ClipMatches]): plotter = get_image_logger(self.logger) # type: ignore diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py index 7139899..36cba2c 100644 --- a/src/batdetect2/train/augmentations.py +++ b/src/batdetect2/train/augmentations.py @@ -105,6 +105,14 @@ class MixAudio(torch.nn.Module): samplerate: int, source: Optional[AudioSource], ): + if source is None: + warnings.warn( + "Mix audio augmentation ('mix_audio') requires an " + "'example_source' callable to be provided.", + stacklevel=2, + ) + return lambda wav, clip_annotation: (wav, clip_annotation) + return MixAudio( example_source=source, min_weight=config.min_weight, @@ -197,7 +205,9 @@ class AddEcho(torch.nn.Module): @audio_augmentations.register(AddEchoConfig) @staticmethod def from_config( - config: AddEchoConfig, samplerate: int, source: AudioSource + config: AddEchoConfig, + samplerate: int, + source: Optional[AudioSource], ): return AddEcho( samplerate=samplerate, @@ -662,6 +672,30 @@ def build_audio_augmentations( return AugmentationSequence(augmentations) +def build_spectrogram_augmentations( + steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None, +) -> Optional[Augmentation]: + if not steps: + return None + + augmentations = [] + + for step_config in steps: + augmentation = spec_augmentations.build(step_config) + + if augmentation is None: + continue + + augmentations.append( + MaybeApply( + augmentation=augmentation, + probability=step_config.probability, + ) + ) + + return AugmentationSequence(augmentations) + + def build_augmentations( samplerate: int, config: Optional[AugmentationsConfig] = None, @@ -675,16 +709,14 @@ def build_augmentations( lambda: config.to_yaml_string(), ) - audio_augmentation = build_augmentation_sequence( - samplerate, + audio_augmentation = build_audio_augmentations( steps=config.audio, + samplerate=samplerate, audio_source=audio_source, ) - spectrogram_augmentation = build_augmentation_sequence( - samplerate, - steps=config.audio, - audio_source=audio_source, + spectrogram_augmentation = build_spectrogram_augmentations( + steps=config.spectrogram, ) return audio_augmentation, spectrogram_augmentation diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index ff030fe..aed25cd 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -146,7 +146,7 @@ def build_trainer_callbacks( ModelCheckpoint( dirpath=str(checkpoint_dir), save_top_k=1, - monitor="total_loss/val", + monitor="classification/mean_average_precision", ), ValidationMetrics(evaluator), ]