mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Update example config
This commit is contained in:
parent
c9f0c5c431
commit
30159d64a9
@ -138,36 +138,49 @@ train:
|
||||
name: csv
|
||||
|
||||
validation:
|
||||
metrics:
|
||||
- name: detection_ap
|
||||
- name: classification_ap
|
||||
plots:
|
||||
- name: example_gallery
|
||||
- name: example_clip
|
||||
- name: detection_pr_curve
|
||||
- name: classification_pr_curves
|
||||
- name: detection_roc_curve
|
||||
- name: classification_roc_curves
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
|
||||
evaluation:
|
||||
match_strategy:
|
||||
name: start_time_match
|
||||
distance_threshold: 0.01
|
||||
metrics:
|
||||
- name: classification_ap
|
||||
- name: detection_ap
|
||||
- name: detection_roc_auc
|
||||
- name: classification_roc_auc
|
||||
- name: top_class_ap
|
||||
- name: classification_balanced_accuracy
|
||||
- name: clip_multiclass_ap
|
||||
- name: clip_multiclass_roc_auc
|
||||
- name: clip_detection_ap
|
||||
- name: clip_detection_roc_auc
|
||||
plots:
|
||||
- name: example_gallery
|
||||
- name: example_clip
|
||||
- name: detection_pr_curve
|
||||
- name: classification_pr_curves
|
||||
- name: detection_roc_curve
|
||||
- name: classification_roc_curves
|
||||
tasks:
|
||||
- name: sound_event_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: score_distribution
|
||||
- name: example_detection
|
||||
- name: sound_event_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: top_class_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: confusion_matrix
|
||||
- name: example_classification
|
||||
- name: clip_detection
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
- name: score_distribution
|
||||
- name: clip_classification
|
||||
metrics:
|
||||
- name: average_precision
|
||||
- name: roc_auc
|
||||
plots:
|
||||
- name: pr_curve
|
||||
- name: roc_curve
|
||||
|
||||
@ -4,8 +4,7 @@ from lightning import LightningModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||
from batdetect2.evaluate.tables import FullEvaluationTable
|
||||
from batdetect2.logging import get_image_logger, get_table_logger
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.postprocess import to_raw_predictions
|
||||
from batdetect2.typing import ClipMatches, EvaluatorProtocol
|
||||
@ -54,16 +53,6 @@ class EvaluationModule(LightningModule):
|
||||
def on_test_epoch_end(self):
|
||||
self.log_metrics(self.clip_evaluations)
|
||||
self.plot_examples(self.clip_evaluations)
|
||||
self.log_table(self.clip_evaluations)
|
||||
|
||||
def log_table(self, evaluated_clips: Sequence[ClipMatches]):
|
||||
table_logger = get_table_logger(self.logger) # type: ignore
|
||||
|
||||
if table_logger is None:
|
||||
return
|
||||
|
||||
df = FullEvaluationTable()(evaluated_clips)
|
||||
table_logger("full_evaluation", df, 0)
|
||||
|
||||
def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
|
||||
plotter = get_image_logger(self.logger) # type: ignore
|
||||
|
||||
@ -105,6 +105,14 @@ class MixAudio(torch.nn.Module):
|
||||
samplerate: int,
|
||||
source: Optional[AudioSource],
|
||||
):
|
||||
if source is None:
|
||||
warnings.warn(
|
||||
"Mix audio augmentation ('mix_audio') requires an "
|
||||
"'example_source' callable to be provided.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return lambda wav, clip_annotation: (wav, clip_annotation)
|
||||
|
||||
return MixAudio(
|
||||
example_source=source,
|
||||
min_weight=config.min_weight,
|
||||
@ -197,7 +205,9 @@ class AddEcho(torch.nn.Module):
|
||||
@audio_augmentations.register(AddEchoConfig)
|
||||
@staticmethod
|
||||
def from_config(
|
||||
config: AddEchoConfig, samplerate: int, source: AudioSource
|
||||
config: AddEchoConfig,
|
||||
samplerate: int,
|
||||
source: Optional[AudioSource],
|
||||
):
|
||||
return AddEcho(
|
||||
samplerate=samplerate,
|
||||
@ -662,6 +672,30 @@ def build_audio_augmentations(
|
||||
return AugmentationSequence(augmentations)
|
||||
|
||||
|
||||
def build_spectrogram_augmentations(
|
||||
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
|
||||
) -> Optional[Augmentation]:
|
||||
if not steps:
|
||||
return None
|
||||
|
||||
augmentations = []
|
||||
|
||||
for step_config in steps:
|
||||
augmentation = spec_augmentations.build(step_config)
|
||||
|
||||
if augmentation is None:
|
||||
continue
|
||||
|
||||
augmentations.append(
|
||||
MaybeApply(
|
||||
augmentation=augmentation,
|
||||
probability=step_config.probability,
|
||||
)
|
||||
)
|
||||
|
||||
return AugmentationSequence(augmentations)
|
||||
|
||||
|
||||
def build_augmentations(
|
||||
samplerate: int,
|
||||
config: Optional[AugmentationsConfig] = None,
|
||||
@ -675,16 +709,14 @@ def build_augmentations(
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
|
||||
audio_augmentation = build_augmentation_sequence(
|
||||
samplerate,
|
||||
audio_augmentation = build_audio_augmentations(
|
||||
steps=config.audio,
|
||||
samplerate=samplerate,
|
||||
audio_source=audio_source,
|
||||
)
|
||||
|
||||
spectrogram_augmentation = build_augmentation_sequence(
|
||||
samplerate,
|
||||
steps=config.audio,
|
||||
audio_source=audio_source,
|
||||
spectrogram_augmentation = build_spectrogram_augmentations(
|
||||
steps=config.spectrogram,
|
||||
)
|
||||
|
||||
return audio_augmentation, spectrogram_augmentation
|
||||
|
||||
@ -146,7 +146,7 @@ def build_trainer_callbacks(
|
||||
ModelCheckpoint(
|
||||
dirpath=str(checkpoint_dir),
|
||||
save_top_k=1,
|
||||
monitor="total_loss/val",
|
||||
monitor="classification/mean_average_precision",
|
||||
),
|
||||
ValidationMetrics(evaluator),
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user