mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Update example config
This commit is contained in:
parent
c9f0c5c431
commit
30159d64a9
@ -138,36 +138,49 @@ train:
|
|||||||
name: csv
|
name: csv
|
||||||
|
|
||||||
validation:
|
validation:
|
||||||
metrics:
|
tasks:
|
||||||
- name: detection_ap
|
- name: sound_event_detection
|
||||||
- name: classification_ap
|
metrics:
|
||||||
plots:
|
- name: average_precision
|
||||||
- name: example_gallery
|
- name: sound_event_classification
|
||||||
- name: example_clip
|
metrics:
|
||||||
- name: detection_pr_curve
|
- name: average_precision
|
||||||
- name: classification_pr_curves
|
|
||||||
- name: detection_roc_curve
|
|
||||||
- name: classification_roc_curves
|
|
||||||
|
|
||||||
evaluation:
|
evaluation:
|
||||||
match_strategy:
|
tasks:
|
||||||
name: start_time_match
|
- name: sound_event_detection
|
||||||
distance_threshold: 0.01
|
metrics:
|
||||||
metrics:
|
- name: average_precision
|
||||||
- name: classification_ap
|
- name: roc_auc
|
||||||
- name: detection_ap
|
plots:
|
||||||
- name: detection_roc_auc
|
- name: pr_curve
|
||||||
- name: classification_roc_auc
|
- name: score_distribution
|
||||||
- name: top_class_ap
|
- name: example_detection
|
||||||
- name: classification_balanced_accuracy
|
- name: sound_event_classification
|
||||||
- name: clip_multiclass_ap
|
metrics:
|
||||||
- name: clip_multiclass_roc_auc
|
- name: average_precision
|
||||||
- name: clip_detection_ap
|
- name: roc_auc
|
||||||
- name: clip_detection_roc_auc
|
plots:
|
||||||
plots:
|
- name: pr_curve
|
||||||
- name: example_gallery
|
- name: top_class_detection
|
||||||
- name: example_clip
|
metrics:
|
||||||
- name: detection_pr_curve
|
- name: average_precision
|
||||||
- name: classification_pr_curves
|
plots:
|
||||||
- name: detection_roc_curve
|
- name: pr_curve
|
||||||
- name: classification_roc_curves
|
- name: confusion_matrix
|
||||||
|
- name: example_classification
|
||||||
|
- name: clip_detection
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: roc_curve
|
||||||
|
- name: score_distribution
|
||||||
|
- name: clip_classification
|
||||||
|
metrics:
|
||||||
|
- name: average_precision
|
||||||
|
- name: roc_auc
|
||||||
|
plots:
|
||||||
|
- name: pr_curve
|
||||||
|
- name: roc_curve
|
||||||
|
|||||||
@ -4,8 +4,7 @@ from lightning import LightningModule
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||||
from batdetect2.evaluate.tables import FullEvaluationTable
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.logging import get_image_logger, get_table_logger
|
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.postprocess import to_raw_predictions
|
from batdetect2.postprocess import to_raw_predictions
|
||||||
from batdetect2.typing import ClipMatches, EvaluatorProtocol
|
from batdetect2.typing import ClipMatches, EvaluatorProtocol
|
||||||
@ -54,16 +53,6 @@ class EvaluationModule(LightningModule):
|
|||||||
def on_test_epoch_end(self):
|
def on_test_epoch_end(self):
|
||||||
self.log_metrics(self.clip_evaluations)
|
self.log_metrics(self.clip_evaluations)
|
||||||
self.plot_examples(self.clip_evaluations)
|
self.plot_examples(self.clip_evaluations)
|
||||||
self.log_table(self.clip_evaluations)
|
|
||||||
|
|
||||||
def log_table(self, evaluated_clips: Sequence[ClipMatches]):
|
|
||||||
table_logger = get_table_logger(self.logger) # type: ignore
|
|
||||||
|
|
||||||
if table_logger is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
df = FullEvaluationTable()(evaluated_clips)
|
|
||||||
table_logger("full_evaluation", df, 0)
|
|
||||||
|
|
||||||
def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
|
def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
|
||||||
plotter = get_image_logger(self.logger) # type: ignore
|
plotter = get_image_logger(self.logger) # type: ignore
|
||||||
|
|||||||
@ -105,6 +105,14 @@ class MixAudio(torch.nn.Module):
|
|||||||
samplerate: int,
|
samplerate: int,
|
||||||
source: Optional[AudioSource],
|
source: Optional[AudioSource],
|
||||||
):
|
):
|
||||||
|
if source is None:
|
||||||
|
warnings.warn(
|
||||||
|
"Mix audio augmentation ('mix_audio') requires an "
|
||||||
|
"'example_source' callable to be provided.",
|
||||||
|
stacklevel=2,
|
||||||
|
)
|
||||||
|
return lambda wav, clip_annotation: (wav, clip_annotation)
|
||||||
|
|
||||||
return MixAudio(
|
return MixAudio(
|
||||||
example_source=source,
|
example_source=source,
|
||||||
min_weight=config.min_weight,
|
min_weight=config.min_weight,
|
||||||
@ -197,7 +205,9 @@ class AddEcho(torch.nn.Module):
|
|||||||
@audio_augmentations.register(AddEchoConfig)
|
@audio_augmentations.register(AddEchoConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
config: AddEchoConfig, samplerate: int, source: AudioSource
|
config: AddEchoConfig,
|
||||||
|
samplerate: int,
|
||||||
|
source: Optional[AudioSource],
|
||||||
):
|
):
|
||||||
return AddEcho(
|
return AddEcho(
|
||||||
samplerate=samplerate,
|
samplerate=samplerate,
|
||||||
@ -662,6 +672,30 @@ def build_audio_augmentations(
|
|||||||
return AugmentationSequence(augmentations)
|
return AugmentationSequence(augmentations)
|
||||||
|
|
||||||
|
|
||||||
|
def build_spectrogram_augmentations(
|
||||||
|
steps: Optional[Sequence[SpectrogramAugmentationConfig]] = None,
|
||||||
|
) -> Optional[Augmentation]:
|
||||||
|
if not steps:
|
||||||
|
return None
|
||||||
|
|
||||||
|
augmentations = []
|
||||||
|
|
||||||
|
for step_config in steps:
|
||||||
|
augmentation = spec_augmentations.build(step_config)
|
||||||
|
|
||||||
|
if augmentation is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
augmentations.append(
|
||||||
|
MaybeApply(
|
||||||
|
augmentation=augmentation,
|
||||||
|
probability=step_config.probability,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return AugmentationSequence(augmentations)
|
||||||
|
|
||||||
|
|
||||||
def build_augmentations(
|
def build_augmentations(
|
||||||
samplerate: int,
|
samplerate: int,
|
||||||
config: Optional[AugmentationsConfig] = None,
|
config: Optional[AugmentationsConfig] = None,
|
||||||
@ -675,16 +709,14 @@ def build_augmentations(
|
|||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_augmentation = build_augmentation_sequence(
|
audio_augmentation = build_audio_augmentations(
|
||||||
samplerate,
|
|
||||||
steps=config.audio,
|
steps=config.audio,
|
||||||
|
samplerate=samplerate,
|
||||||
audio_source=audio_source,
|
audio_source=audio_source,
|
||||||
)
|
)
|
||||||
|
|
||||||
spectrogram_augmentation = build_augmentation_sequence(
|
spectrogram_augmentation = build_spectrogram_augmentations(
|
||||||
samplerate,
|
steps=config.spectrogram,
|
||||||
steps=config.audio,
|
|
||||||
audio_source=audio_source,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return audio_augmentation, spectrogram_augmentation
|
return audio_augmentation, spectrogram_augmentation
|
||||||
|
|||||||
@ -146,7 +146,7 @@ def build_trainer_callbacks(
|
|||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath=str(checkpoint_dir),
|
dirpath=str(checkpoint_dir),
|
||||||
save_top_k=1,
|
save_top_k=1,
|
||||||
monitor="total_loss/val",
|
monitor="classification/mean_average_precision",
|
||||||
),
|
),
|
||||||
ValidationMetrics(evaluator),
|
ValidationMetrics(evaluator),
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user