Update example config

This commit is contained in:
mbsantiago 2025-09-28 16:22:21 +01:00
parent c9f0c5c431
commit 30159d64a9
4 changed files with 85 additions and 51 deletions

View File

@ -138,36 +138,49 @@ train:
name: csv
validation:
metrics:
- name: detection_ap
- name: classification_ap
plots:
- name: example_gallery
- name: example_clip
- name: detection_pr_curve
- name: classification_pr_curves
- name: detection_roc_curve
- name: classification_roc_curves
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: sound_event_classification
metrics:
- name: average_precision
evaluation:
match_strategy:
name: start_time_match
distance_threshold: 0.01
metrics:
- name: classification_ap
- name: detection_ap
- name: detection_roc_auc
- name: classification_roc_auc
- name: top_class_ap
- name: classification_balanced_accuracy
- name: clip_multiclass_ap
- name: clip_multiclass_roc_auc
- name: clip_detection_ap
- name: clip_detection_roc_auc
plots:
- name: example_gallery
- name: example_clip
- name: detection_pr_curve
- name: classification_pr_curves
- name: detection_roc_curve
- name: classification_roc_curves
tasks:
- name: sound_event_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: score_distribution
- name: example_detection
- name: sound_event_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: top_class_detection
metrics:
- name: average_precision
plots:
- name: pr_curve
- name: confusion_matrix
- name: example_classification
- name: clip_detection
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve
- name: score_distribution
- name: clip_classification
metrics:
- name: average_precision
- name: roc_auc
plots:
- name: pr_curve
- name: roc_curve

View File

@ -4,8 +4,7 @@ from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.tables import FullEvaluationTable
from batdetect2.logging import get_image_logger, get_table_logger
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipMatches, EvaluatorProtocol
@ -54,16 +53,6 @@ class EvaluationModule(LightningModule):
def on_test_epoch_end(self):
self.log_metrics(self.clip_evaluations)
self.plot_examples(self.clip_evaluations)
self.log_table(self.clip_evaluations)
def log_table(self, evaluated_clips: Sequence[ClipMatches]):
table_logger = get_table_logger(self.logger) # type: ignore
if table_logger is None:
return
df = FullEvaluationTable()(evaluated_clips)
table_logger("full_evaluation", df, 0)
def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
plotter = get_image_logger(self.logger) # type: ignore

View File

@ -105,6 +105,14 @@ class MixAudio(torch.nn.Module):
samplerate: int,
source: Optional[AudioSource],
):
if source is None:
warnings.warn(
"Mix audio augmentation ('mix_audio') requires an "
"'example_source' callable to be provided.",
stacklevel=2,
)
return lambda wav, clip_annotation: (wav, clip_annotation)
return MixAudio(
example_source=source,
min_weight=config.min_weight,
@ -197,7 +205,9 @@ class AddEcho(torch.nn.Module):
@audio_augmentations.register(AddEchoConfig)
@staticmethod
def from_config(
config: AddEchoConfig, samplerate: int, source: AudioSource
config: AddEchoConfig,
samplerate: int,
source: Optional[AudioSource],
):
return AddEcho(
samplerate=samplerate,
@ -662,6 +672,30 @@ def build_audio_augmentations(
return AugmentationSequence(augmentations)
def build_spectrogram_augmentations(
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
) -> Optional[Augmentation]:
if not steps:
return None
augmentations = []
for step_config in steps:
augmentation = spec_augmentations.build(step_config)
if augmentation is None:
continue
augmentations.append(
MaybeApply(
augmentation=augmentation,
probability=step_config.probability,
)
)
return AugmentationSequence(augmentations)
def build_augmentations(
samplerate: int,
config: Optional[AugmentationsConfig] = None,
@ -675,16 +709,14 @@ def build_augmentations(
lambda: config.to_yaml_string(),
)
audio_augmentation = build_augmentation_sequence(
samplerate,
audio_augmentation = build_audio_augmentations(
steps=config.audio,
samplerate=samplerate,
audio_source=audio_source,
)
spectrogram_augmentation = build_augmentation_sequence(
samplerate,
steps=config.audio,
audio_source=audio_source,
spectrogram_augmentation = build_spectrogram_augmentations(
steps=config.spectrogram,
)
return audio_augmentation, spectrogram_augmentation

View File

@ -146,7 +146,7 @@ def build_trainer_callbacks(
ModelCheckpoint(
dirpath=str(checkpoint_dir),
save_top_k=1,
monitor="total_loss/val",
monitor="classification/mean_average_precision",
),
ValidationMetrics(evaluator),
]