Update example config

mbsantiago 2025-09-28 16:22:21 +01:00
parent c9f0c5c431
commit 30159d64a9
4 changed files with 85 additions and 51 deletions

View File

@@ -138,36 +138,49 @@ train:
     name: csv

 validation:
-  metrics:
-  - name: detection_ap
-  - name: classification_ap
-  plots:
-  - name: example_gallery
-  - name: example_clip
-  - name: detection_pr_curve
-  - name: classification_pr_curves
-  - name: detection_roc_curve
-  - name: classification_roc_curves
+  tasks:
+  - name: sound_event_detection
+    metrics:
+    - name: average_precision
+  - name: sound_event_classification
+    metrics:
+    - name: average_precision

 evaluation:
-  match_strategy:
-    name: start_time_match
-    distance_threshold: 0.01
-  metrics:
-  - name: classification_ap
-  - name: detection_ap
-  - name: detection_roc_auc
-  - name: classification_roc_auc
-  - name: top_class_ap
-  - name: classification_balanced_accuracy
-  - name: clip_multiclass_ap
-  - name: clip_multiclass_roc_auc
-  - name: clip_detection_ap
-  - name: clip_detection_roc_auc
-  plots:
-  - name: example_gallery
-  - name: example_clip
-  - name: detection_pr_curve
-  - name: classification_pr_curves
-  - name: detection_roc_curve
-  - name: classification_roc_curves
+  tasks:
+  - name: sound_event_detection
+    metrics:
+    - name: average_precision
+    - name: roc_auc
+    plots:
+    - name: pr_curve
+    - name: score_distribution
+    - name: example_detection
+  - name: sound_event_classification
+    metrics:
+    - name: average_precision
+    - name: roc_auc
+    plots:
+    - name: pr_curve
+  - name: top_class_detection
+    metrics:
+    - name: average_precision
+    plots:
+    - name: pr_curve
+    - name: confusion_matrix
+    - name: example_classification
+  - name: clip_detection
+    metrics:
+    - name: average_precision
+    - name: roc_auc
+    plots:
+    - name: pr_curve
+    - name: roc_curve
+    - name: score_distribution
+  - name: clip_classification
+    metrics:
+    - name: average_precision
+    - name: roc_auc
+    plots:
+    - name: pr_curve
+    - name: roc_curve
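
For reference, the new layout replaces the flat metrics/plots lists with a tasks list in which each task names its own metrics and plots. Below is a minimal sketch of a schema that would accept this shape, assuming a pydantic-style config system; the class and field names are illustrative, not taken from the repo.

from typing import List

import yaml
from pydantic import BaseModel, Field


class NamedItem(BaseModel):
    name: str


class TaskConfig(BaseModel):
    name: str
    metrics: List[NamedItem] = Field(default_factory=list)
    plots: List[NamedItem] = Field(default_factory=list)


class EvaluationSection(BaseModel):
    tasks: List[TaskConfig] = Field(default_factory=list)


# Quick check against a fragment of the config above.
text = """
tasks:
- name: sound_event_detection
  metrics:
  - name: average_precision
"""
section = EvaluationSection.model_validate(yaml.safe_load(text))
assert section.tasks[0].metrics[0].name == "average_precision"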

View File

@@ -4,8 +4,7 @@
 from lightning import LightningModule
 from torch.utils.data import DataLoader
 from batdetect2.evaluate.dataset import TestDataset, TestExample
-from batdetect2.evaluate.tables import FullEvaluationTable
-from batdetect2.logging import get_image_logger, get_table_logger
+from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
 from batdetect2.typing import ClipMatches, EvaluatorProtocol
@@ -54,16 +53,6 @@ class EvaluationModule(LightningModule):
     def on_test_epoch_end(self):
         self.log_metrics(self.clip_evaluations)
         self.plot_examples(self.clip_evaluations)
-        self.log_table(self.clip_evaluations)
-
-    def log_table(self, evaluated_clips: Sequence[ClipMatches]):
-        table_logger = get_table_logger(self.logger)  # type: ignore
-
-        if table_logger is None:
-            return
-
-        df = FullEvaluationTable()(evaluated_clips)
-        table_logger("full_evaluation", df, 0)

     def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
         plotter = get_image_logger(self.logger)  # type: ignore
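
With the table hook removed, the module keeps the usual Lightning pattern of accumulating per-batch results and consuming them once per test epoch. An illustrative, self-contained version of that pattern follows; nothing in it is this repo's code.

from lightning import LightningModule


class AccumulatingModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.results = []

    def test_step(self, batch, batch_idx):
        # Stand-in for real per-clip evaluation output.
        self.results.append(batch)

    def on_test_epoch_end(self):
        # Consume everything gathered during the epoch, then reset.
        print(f"collected {len(self.results)} batches")
        self.results.clear()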

View File

@@ -105,6 +105,14 @@ class MixAudio(torch.nn.Module):
         samplerate: int,
         source: Optional[AudioSource],
     ):
+        if source is None:
+            warnings.warn(
+                "Mix audio augmentation ('mix_audio') requires an "
+                "'example_source' callable to be provided.",
+                stacklevel=2,
+            )
+            return lambda wav, clip_annotation: (wav, clip_annotation)
+
         return MixAudio(
             example_source=source,
             min_weight=config.min_weight,
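
The guard added above trades a hard failure for a warning plus a pass-through augmentation, so a config that enables mix_audio still runs when no example source is wired in. In isolation the pattern looks like the generic sketch below; none of these names are this repo's API.

import warnings
from typing import Any, Callable, Optional, Tuple

Augmentation = Callable[[Any, Any], Tuple[Any, Any]]


def build_or_passthrough(
    factory: Optional[Callable[[], Augmentation]],
) -> Augmentation:
    if factory is None:
        # Degrade to an identity function so downstream composition still works.
        warnings.warn("missing dependency; augmentation disabled", stacklevel=2)
        return lambda wav, clip_annotation: (wav, clip_annotation)
    return factory()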
@@ -197,7 +205,9 @@ class AddEcho(torch.nn.Module):
     @audio_augmentations.register(AddEchoConfig)
     @staticmethod
     def from_config(
-        config: AddEchoConfig, samplerate: int, source: AudioSource
+        config: AddEchoConfig,
+        samplerate: int,
+        source: Optional[AudioSource],
     ):
         return AddEcho(
             samplerate=samplerate,
@@ -662,6 +672,30 @@ def build_audio_augmentations(
     return AugmentationSequence(augmentations)


+def build_spectrogram_augmentations(
+    steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
+) -> Optional[Augmentation]:
+    if not steps:
+        return None
+
+    augmentations = []
+
+    for step_config in steps:
+        augmentation = spec_augmentations.build(step_config)
+
+        if augmentation is None:
+            continue
+
+        augmentations.append(
+            MaybeApply(
+                augmentation=augmentation,
+                probability=step_config.probability,
+            )
+        )
+
+    return AugmentationSequence(augmentations)
+
+
 def build_augmentations(
     samplerate: int,
     config: Optional[AugmentationsConfig] = None,
@@ -675,16 +709,14 @@ def build_augmentations(
         lambda: config.to_yaml_string(),
     )

-    audio_augmentation = build_augmentation_sequence(
-        samplerate,
+    audio_augmentation = build_audio_augmentations(
         steps=config.audio,
+        samplerate=samplerate,
         audio_source=audio_source,
     )

-    spectrogram_augmentation = build_augmentation_sequence(
-        samplerate,
-        steps=config.audio,
-        audio_source=audio_source,
+    spectrogram_augmentation = build_spectrogram_augmentations(
+        steps=config.spectrogram,
     )

     return audio_augmentation, spectrogram_augmentation
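
The new build_spectrogram_augmentations mirrors the audio builder but drops the samplerate and audio_source arguments it does not need, and build_augmentations now feeds it config.spectrogram rather than config.audio. Both builders gate each step through MaybeApply, which applies an augmentation with a configured probability and otherwise passes the input through. Only the MaybeApply name comes from the diff; the implementation below is an assumed generic sketch.

import random
from typing import Callable, TypeVar

T = TypeVar("T")


class MaybeApply:
    """Apply `augmentation` with probability `probability`, else pass through."""

    def __init__(self, augmentation: Callable[[T], T], probability: float) -> None:
        self.augmentation = augmentation
        self.probability = probability

    def __call__(self, example: T) -> T:
        if random.random() >= self.probability:
            return example
        return self.augmentation(example)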

View File

@@ -146,7 +146,7 @@ def build_trainer_callbacks(
         ModelCheckpoint(
             dirpath=str(checkpoint_dir),
             save_top_k=1,
-            monitor="total_loss/val",
+            monitor="classification/mean_average_precision",
         ),
         ValidationMetrics(evaluator),
     ]
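
A note on the new monitor key: Lightning's ModelCheckpoint defaults to mode="min", which suits a loss like total_loss/val but not an average-precision metric, where higher is better. Unless it is set elsewhere, the callback would want mode="max", roughly as in the snippet below; the monitor key is taken from the diff, everything else is a placeholder.

from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint(
    dirpath="checkpoints",  # placeholder path
    save_top_k=1,
    monitor="classification/mean_average_precision",
    mode="max",  # default "min" would keep the worst checkpoint for a score metric
)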