Compare commits

...

2 Commits

Author      SHA1        Message                                     Date
mbsantiago  60e922d565  Use pascal voc map computation by default   2025-09-16 10:56:37 +01:00
mbsantiago  704b28292b  Cleaning train module                       2025-09-15 16:50:08 +01:00
6 changed files with 304 additions and 236 deletions

View File

@@ -1,4 +1,14 @@
-from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
+from collections.abc import Callable, Mapping
+from typing import (
+    Annotated,
+    Any,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Sequence,
+    Union,
+)
 
 import numpy as np
 from pydantic import Field
@@ -16,11 +26,61 @@ __all__ = ["DetectionAP", "ClassificationAP"]
 metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
 
+AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
+
 class DetectionAPConfig(BaseConfig):
     name: Literal["detection_ap"] = "detection_ap"
+    implementation: AveragePrecisionImplementation = "pascal_voc"
+
+def pascal_voc_average_precision(y_true, y_score) -> float:
+    y_true = np.array(y_true)
+    y_score = np.array(y_score)
+    sort_ind = np.argsort(y_score)[::-1]
+    y_true_sorted = y_true[sort_ind]
+    num_positives = y_true.sum()
+    false_pos_c = np.cumsum(1 - y_true_sorted)
+    true_pos_c = np.cumsum(y_true_sorted)
+    recall = true_pos_c / num_positives
+    precision = true_pos_c / np.maximum(
+        true_pos_c + false_pos_c,
+        np.finfo(np.float64).eps,
+    )
+    precision[np.isnan(precision)] = 0
+    recall[np.isnan(recall)] = 0
+    # pascal 12 way: VOC-2012 "all points" interpolation of the PR curve
+    mprec = np.hstack((0, precision, 0))
+    mrec = np.hstack((0, recall, 1))
+    for ii in range(mprec.shape[0] - 2, -1, -1):
+        mprec[ii] = np.maximum(mprec[ii], mprec[ii + 1])
+    inds = np.where(np.not_equal(mrec[1:], mrec[:-1]))[0] + 1
+    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()
+    return ave_prec
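
Note: a quick worked example of the VOC-2012 ("all points") computation added above, with hand-checked values; assumes the function is in scope.

# Four detections ranked by score: TP, FP, TP, FP (two positives total).
y_true = [1, 0, 1, 0]
y_score = [0.9, 0.8, 0.7, 0.1]
# Precision per rank: 1, 1/2, 2/3, 1/2; recall per rank: 1/2, 1/2, 1, 1.
# AP = (0.5 - 0.0) * 1.0 + (1.0 - 0.5) * (2/3) = 5/6
assert abs(pascal_voc_average_precision(y_true, y_score) - 5 / 6) < 1e-9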
+
+_ap_impl_mapping: Mapping[
+    AveragePrecisionImplementation, Callable[[Any, Any], float]
+] = {
+    "sklearn": metrics.average_precision_score,
+    "pascal_voc": pascal_voc_average_precision,
+}
+
 class DetectionAP(MetricsProtocol):
+    def __init__(
+        self,
+        implementation: AveragePrecisionImplementation = "pascal_voc",
+    ):
+        self.implementation = implementation
+        self.metric = _ap_impl_mapping[self.implementation]
+
     def __call__(
         self, clip_evaluations: Sequence[ClipEvaluation]
     ) -> Dict[str, float]:
@@ -31,12 +91,12 @@ class DetectionAP(MetricsProtocol):
                 for match in clip_eval.matches
             ]
         )
-        score = float(metrics.average_precision_score(y_true, y_score))
+        score = float(self.metric(y_true, y_score))
         return {"detection_AP": score}
 
     @classmethod
     def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
-        return cls()
+        return cls(implementation=config.implementation)
 
 metrics_registry.register(DetectionAPConfig, DetectionAP)
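
A minimal sketch of selecting the implementation through the registered config (field names as in the diff): pascal_voc is now the default, and the previous sklearn behaviour stays one config field away.

default_metric = DetectionAP.from_config(DetectionAPConfig(), class_names=[])
assert default_metric.implementation == "pascal_voc"

legacy_metric = DetectionAP.from_config(
    DetectionAPConfig(implementation="sklearn"), class_names=[]
)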
@@ -52,9 +112,12 @@ class ClassificationAP(MetricsProtocol):
     def __init__(
         self,
         class_names: List[str],
+        implementation: AveragePrecisionImplementation = "pascal_voc",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
     ):
+        self.implementation = implementation
+        self.metric = _ap_impl_mapping[self.implementation]
         self.class_names = class_names
         self.selected = class_names
@@ -107,10 +170,7 @@ class ClassificationAP(MetricsProtocol):
         for class_index, class_name in enumerate(self.class_names):
             y_true_class = y_true[:, class_index]
             y_pred_class = y_pred[:, class_index]
-            class_ap = metrics.average_precision_score(
-                y_true_class,
-                y_pred_class,
-            )
+            class_ap = self.metric(y_true_class, y_pred_class)
             class_scores[class_name] = float(class_ap)
         mean_ap = np.mean(

View File

@@ -22,7 +22,14 @@ from batdetect2.train.config import (
     load_full_training_config,
     load_train_config,
 )
-from batdetect2.train.dataset import TrainingDataset
+from batdetect2.train.dataset import (
+    TrainingDataset,
+    ValidationDataset,
+    build_train_dataset,
+    build_train_loader,
+    build_val_dataset,
+    build_val_loader,
+)
 from batdetect2.train.labels import build_clip_labeler, load_label_config
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.losses import (
@@ -33,14 +40,7 @@ from batdetect2.train.losses import (
     SizeLossConfig,
     build_loss,
 )
-from batdetect2.train.train import (
-    build_train_dataset,
-    build_train_loader,
-    build_trainer,
-    build_val_dataset,
-    build_val_loader,
-    train,
-)
+from batdetect2.train.train import build_trainer, train
 
 __all__ = [
     "AugmentationsConfig",
@@ -49,7 +49,6 @@ __all__ = [
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
     "FullTrainingConfig",
-    "TrainingDataset",
     "LossConfig",
     "LossFunction",
     "PLTrainerConfig",
@@ -57,7 +56,9 @@ __all__ = [
     "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainingConfig",
+    "TrainingDataset",
     "TrainingModule",
+    "ValidationDataset",
     "VolumeAugmentationConfig",
     "WarpAugmentationConfig",
     "add_echo",

View File

@@ -72,13 +72,16 @@ class TrainLoaderConfig(BaseConfig):
     )
 
-class TrainingConfig(BaseConfig):
+class OptimizerConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
 
+class TrainingConfig(BaseConfig):
     train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
     val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
+    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
     cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
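
With this change the optimizer hyperparameters move out of TrainingConfig into a nested block, so attribute access changes shape; a sketch using the defaults defined above:

config = TrainingConfig()
assert config.optimizer.learning_rate == 1e-3  # was config.learning_rate
assert config.optimizer.t_max == 100           # was config.t_max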

View File

@@ -1,18 +1,31 @@
-from typing import Optional, Sequence, Tuple
+from typing import List, Optional, Sequence
 
 import torch
+from loguru import logger
 from soundevent import data
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
 
+from batdetect2.plotting.clips import build_audio_loader
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.train.augmentations import (
+    RandomAudioSource,
+    build_augmentations,
+)
+from batdetect2.train.clips import build_clipper
+from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
+from batdetect2.train.labels import build_clip_labeler
 from batdetect2.typing import ClipperProtocol, TrainExample
+from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
-from batdetect2.typing.train import (
-    Augmentation,
-    ClipLabeller,
-)
+from batdetect2.typing.train import Augmentation, ClipLabeller
+from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
     "TrainingDataset",
     "ValidationDataset",
+    "build_val_loader",
+    "build_train_loader",
+    "build_train_dataset",
+    "build_val_dataset",
 ]
@@ -124,3 +137,174 @@ class ValidationDataset(Dataset):
             start_time=torch.tensor(clip.start_time),
             end_time=torch.tensor(clip.end_time),
         )
+
+def build_train_loader(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[TrainLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+) -> DataLoader:
+    config = config or TrainLoaderConfig()
+    logger.info("Building training data loader...")
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )
+    train_dataset = build_train_dataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
+        preprocessor=preprocessor,
+        config=config,
+    )
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=config.shuffle,
+        num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+def build_val_loader(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[ValLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+):
+    logger.info("Building validation data loader...")
+    config = config or ValLoaderConfig()
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{config}",
+        config=lambda: config.to_yaml_string(exclude_none=True),
+    )
+    val_dataset = build_val_dataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
+        preprocessor=preprocessor,
+        config=config,
+    )
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        val_dataset,
+        batch_size=1,
+        shuffle=False,
+        num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
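
Typical use of the relocated builders, sketched (train_annotations and val_annotations stand for sequences of data.ClipAnnotation; every other dependency falls back to a config-driven default):

train_loader = build_train_loader(train_annotations, num_workers=4)
val_loader = build_val_loader(val_annotations)  # batch_size is fixed at 1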
+
+def build_train_dataset(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[TrainLoaderConfig] = None,
+) -> TrainingDataset:
+    logger.info("Building training dataset...")
+    config = config or TrainLoaderConfig()
+    clipper = build_clipper(config=config.clipping_strategy)
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+    random_example_source = RandomAudioSource(
+        clip_annotations,
+        audio_loader=audio_loader,
+    )
+    if config.augmentations.enabled:
+        audio_augmentation, spectrogram_augmentation = build_augmentations(
+            samplerate=preprocessor.input_samplerate,
+            config=config.augmentations,
+            audio_source=random_example_source,
+        )
+    else:
+        logger.debug("No augmentations configured for training dataset.")
+        audio_augmentation = None
+        spectrogram_augmentation = None
+    return TrainingDataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
+        clipper=clipper,
+        preprocessor=preprocessor,
+        audio_augmentation=audio_augmentation,
+        spectrogram_augmentation=spectrogram_augmentation,
+    )
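
Augmentations are driven by the loader config: when config.augmentations.enabled is false the dataset is built without them. A sketch, assuming the config models are mutable and AugmentationsConfig exposes the `enabled` flag used above:

cfg = TrainLoaderConfig()
cfg.augmentations.enabled = False
dataset = build_train_dataset(train_annotations, config=cfg)  # no augmentations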
+
+def build_val_dataset(
+    clip_annotations: Sequence[data.ClipAnnotation],
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[ValLoaderConfig] = None,
+) -> ValidationDataset:
+    logger.info("Building validation dataset...")
+    config = config or ValLoaderConfig()
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+    if labeller is None:
+        labeller = build_clip_labeler(
+            min_freq=preprocessor.min_freq,
+            max_freq=preprocessor.max_freq,
+        )
+    clipper = build_clipper(config.clipping_strategy)
+    return ValidationDataset(
+        clip_annotations,
+        audio_loader=audio_loader,
+        labeller=labeller,
+        preprocessor=preprocessor,
+        clipper=clipper,
+    )
+
+def _collate_fn(batch: List[TrainExample]) -> TrainExample:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return TrainExample(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        detection_heatmap=torch.stack(
+            [adjust_width(item.detection_heatmap, max_width) for item in batch]
+        ),
+        size_heatmap=torch.stack(
+            [adjust_width(item.size_heatmap, max_width) for item in batch]
+        ),
+        class_heatmap=torch.stack(
+            [adjust_width(item.class_heatmap, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
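
_collate_fn pads every tensor in the batch to the widest spectrogram before stacking, which lets variable-width clips share a batch. A self-contained sketch of the idea (shapes are hypothetical; pad_end is a stand-in for adjust_width):

import torch
import torch.nn.functional as F

def pad_end(x: torch.Tensor, width: int) -> torch.Tensor:
    # Right-pad the last (time) dimension with zeros up to `width`.
    return F.pad(x, (0, width - x.shape[-1]))

specs = [torch.zeros(1, 128, 100), torch.zeros(1, 128, 80)]
max_width = max(s.shape[-1] for s in specs)
batched = torch.stack([pad_end(s, max_width) for s in specs])
print(batched.shape)  # torch.Size([2, 1, 128, 100])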

View File

@@ -77,3 +77,15 @@ def load_model_from_checkpoint(
 ) -> Tuple[Model, FullTrainingConfig]:
     module = TrainingModule.load_from_checkpoint(path)  # type: ignore
     return module.model, module.config
+
+def build_training_module(
+    config: Optional[FullTrainingConfig] = None,
+    t_max: int = 200,
+) -> TrainingModule:
+    config = config or FullTrainingConfig()
+    return TrainingModule(
+        config=config,
+        learning_rate=config.train.optimizer.learning_rate,
+        t_max=t_max,
+    )
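
Note that t_max here counts scheduler steps: train() (in the last file below) multiplies the configured per-epoch value by the number of batches. A sketch mirroring that call site, with train_dataloader as built in train():

module = build_training_module(
    config,
    t_max=config.train.optimizer.t_max * len(train_dataloader),
)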

View File

@@ -2,47 +2,31 @@ from collections.abc import Sequence
 from pathlib import Path
 from typing import List, Optional
 
-import torch
 from lightning import Trainer, seed_everything
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data
 from torch.utils.data import DataLoader
 
 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.evaluate.evaluator import build_evaluator
-from batdetect2.plotting.clips import AudioLoader, build_audio_loader
+from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
-from batdetect2.train.augmentations import (
-    RandomAudioSource,
-    build_augmentations,
-)
 from batdetect2.train.callbacks import ValidationMetrics
-from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import (
     FullTrainingConfig,
-    TrainLoaderConfig,
-    ValLoaderConfig,
 )
-from batdetect2.train.dataset import TrainingDataset, ValidationDataset
+from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
-from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.lightning import TrainingModule, build_training_module
 from batdetect2.train.logging import build_logger
 from batdetect2.typing import (
-    PreprocessorProtocol,
     TargetProtocol,
-    TrainExample,
 )
+from batdetect2.typing.preprocess import AudioLoader
 from batdetect2.typing.train import ClipLabeller
-from batdetect2.utils.arrays import adjust_width
 
 __all__ = [
-    "build_train_dataset",
-    "build_train_loader",
     "build_trainer",
-    "build_val_dataset",
-    "build_val_loader",
     "train",
 ]
@@ -52,6 +36,11 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
 def train(
     train_annotations: Sequence[data.ClipAnnotation],
     val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
+    trainer: Optional[Trainer] = None,
+    targets: Optional[TargetProtocol] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    audio_loader: Optional[AudioLoader] = None,
+    labeller: Optional[ClipLabeller] = None,
     config: Optional[FullTrainingConfig] = None,
     model_path: Optional[data.PathLike] = None,
     train_workers: Optional[int] = None,
@@ -67,13 +56,15 @@ def train(
     config = config or FullTrainingConfig()
 
-    targets = build_targets(config.targets)
+    targets = targets or build_targets(config.targets)
-    preprocessor = build_preprocessor(config.preprocess)
+    preprocessor = preprocessor or build_preprocessor(config.preprocess)
-    audio_loader = build_audio_loader(config=config.preprocess.audio)
+    audio_loader = audio_loader or build_audio_loader(
+        config=config.preprocess.audio
+    )
-    labeller = build_clip_labeler(
+    labeller = labeller or build_clip_labeler(
         targets,
         min_freq=preprocessor.min_freq,
         max_freq=preprocessor.max_freq,
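
train() now accepts pre-built dependencies and falls back to config-driven construction via the `x = x or build_x(...)` idiom; note that `or` treats any falsy value as missing. A hypothetical call injecting a custom preprocessor:

train(
    train_annotations,
    val_annotations,
    preprocessor=my_preprocessor,  # hypothetical; skips build_preprocessor
)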
@@ -108,10 +99,10 @@ def train(
     else:
         module = build_training_module(
             config,
-            t_max=config.train.t_max * len(train_dataloader),
+            t_max=config.train.optimizer.t_max * len(train_dataloader),
         )
 
-    trainer = build_trainer(
+    trainer = trainer or build_trainer(
         config,
         targets=targets,
         checkpoint_dir=checkpoint_dir,
@@ -129,21 +120,9 @@ def train(
     logger.info("Training complete.")
 
-def build_training_module(
-    config: Optional[FullTrainingConfig] = None,
-    t_max: int = 200,
-) -> TrainingModule:
-    config = config or FullTrainingConfig()
-    return TrainingModule(
-        config=config,
-        learning_rate=config.train.learning_rate,
-        t_max=t_max,
-    )
 
 def build_trainer_callbacks(
     targets: TargetProtocol,
-    config: EvaluationConfig,
+    config: FullTrainingConfig,
     checkpoint_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
@@ -157,7 +136,7 @@ def build_trainer_callbacks(
     if run_name is not None:
         checkpoint_dir = checkpoint_dir / run_name
 
-    evaluator = build_evaluator(config=config, targets=targets)
+    evaluator = build_evaluator(config=config.evaluation, targets=targets)
 
     return [
         ModelCheckpoint(
@@ -202,180 +181,9 @@ def build_trainer(
         logger=train_logger,
         callbacks=build_trainer_callbacks(
             targets,
-            config=conf.evaluation,
+            config=conf,
             checkpoint_dir=checkpoint_dir,
             experiment_name=experiment_name,
             run_name=run_name,
         ),
     )
-
-def build_train_loader(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: Optional[AudioLoader] = None,
-    labeller: Optional[ClipLabeller] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[TrainLoaderConfig] = None,
-    num_workers: Optional[int] = None,
-) -> DataLoader:
-    config = config or TrainLoaderConfig()
-    logger.info("Building training data loader...")
-    logger.opt(lazy=True).debug(
-        "Training data loader config: \n{config}",
-        config=lambda: config.to_yaml_string(exclude_none=True),
-    )
-    train_dataset = build_train_dataset(
-        clip_annotations,
-        audio_loader=audio_loader,
-        labeller=labeller,
-        preprocessor=preprocessor,
-        config=config,
-    )
-    num_workers = num_workers or config.num_workers
-    return DataLoader(
-        train_dataset,
-        batch_size=config.batch_size,
-        shuffle=config.shuffle,
-        num_workers=num_workers,
-        collate_fn=_collate_fn,
-    )
-
-def build_val_loader(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: Optional[AudioLoader] = None,
-    labeller: Optional[ClipLabeller] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[ValLoaderConfig] = None,
-    num_workers: Optional[int] = None,
-):
-    logger.info("Building validation data loader...")
-    config = config or ValLoaderConfig()
-    logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{config}",
-        config=lambda: config.to_yaml_string(exclude_none=True),
-    )
-    val_dataset = build_val_dataset(
-        clip_annotations,
-        audio_loader=audio_loader,
-        labeller=labeller,
-        preprocessor=preprocessor,
-        config=config,
-    )
-    num_workers = num_workers or config.num_workers
-    return DataLoader(
-        val_dataset,
-        batch_size=1,
-        shuffle=False,
-        num_workers=num_workers,
-        collate_fn=_collate_fn,
-    )
-
-def build_train_dataset(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: Optional[AudioLoader] = None,
-    labeller: Optional[ClipLabeller] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[TrainLoaderConfig] = None,
-) -> TrainingDataset:
-    logger.info("Building training dataset...")
-    config = config or TrainLoaderConfig()
-    clipper = build_clipper(config=config.clipping_strategy)
-    if audio_loader is None:
-        audio_loader = build_audio_loader()
-    if preprocessor is None:
-        preprocessor = build_preprocessor()
-    if labeller is None:
-        labeller = build_clip_labeler(
-            min_freq=preprocessor.min_freq,
-            max_freq=preprocessor.max_freq,
-        )
-    random_example_source = RandomAudioSource(
-        clip_annotations,
-        audio_loader=audio_loader,
-    )
-    if config.augmentations.enabled:
-        audio_augmentation, spectrogram_augmentation = build_augmentations(
-            samplerate=preprocessor.input_samplerate,
-            config=config.augmentations,
-            audio_source=random_example_source,
-        )
-    else:
-        logger.debug("No augmentations configured for training dataset.")
-        audio_augmentation = None
-        spectrogram_augmentation = None
-    return TrainingDataset(
-        clip_annotations,
-        audio_loader=audio_loader,
-        labeller=labeller,
-        clipper=clipper,
-        preprocessor=preprocessor,
-        audio_augmentation=audio_augmentation,
-        spectrogram_augmentation=spectrogram_augmentation,
-    )
-
-def build_val_dataset(
-    clip_annotations: Sequence[data.ClipAnnotation],
-    audio_loader: Optional[AudioLoader] = None,
-    labeller: Optional[ClipLabeller] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    config: Optional[ValLoaderConfig] = None,
-) -> ValidationDataset:
-    logger.info("Building validation dataset...")
-    config = config or ValLoaderConfig()
-    if audio_loader is None:
-        audio_loader = build_audio_loader()
-    if preprocessor is None:
-        preprocessor = build_preprocessor()
-    if labeller is None:
-        labeller = build_clip_labeler(
-            min_freq=preprocessor.min_freq,
-            max_freq=preprocessor.max_freq,
-        )
-    clipper = build_clipper(config.clipping_strategy)
-    return ValidationDataset(
-        clip_annotations,
-        audio_loader=audio_loader,
-        labeller=labeller,
-        preprocessor=preprocessor,
-        clipper=clipper,
-    )
-
-def _collate_fn(batch: List[TrainExample]) -> TrainExample:
-    max_width = max(item.spec.shape[-1] for item in batch)
-    return TrainExample(
-        spec=torch.stack(
-            [adjust_width(item.spec, max_width) for item in batch]
-        ),
-        detection_heatmap=torch.stack(
-            [adjust_width(item.detection_heatmap, max_width) for item in batch]
-        ),
-        size_heatmap=torch.stack(
-            [adjust_width(item.size_heatmap, max_width) for item in batch]
-        ),
-        class_heatmap=torch.stack(
-            [adjust_width(item.class_heatmap, max_width) for item in batch]
-        ),
-        idx=torch.stack([item.idx for item in batch]),
-        start_time=torch.stack([item.start_time for item in batch]),
-        end_time=torch.stack([item.end_time for item in batch]),
-    )