mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Compare commits
2 Commits
e752e96b93
...
60e922d565
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
60e922d565 | ||
|
|
704b28292b |
@ -1,4 +1,14 @@
|
|||||||
from typing import Annotated, Dict, List, Literal, Optional, Sequence, Union
|
from collections.abc import Callable, Mapping
|
||||||
|
from typing import (
|
||||||
|
Annotated,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
@ -16,11 +26,61 @@ __all__ = ["DetectionAP", "ClassificationAP"]
|
|||||||
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
|
metrics_registry: Registry[MetricsProtocol, [List[str]]] = Registry("metric")
|
||||||
|
|
||||||
|
|
||||||
|
AveragePrecisionImplementation = Literal["sklearn", "pascal_voc"]
|
||||||
|
|
||||||
|
|
||||||
class DetectionAPConfig(BaseConfig):
|
class DetectionAPConfig(BaseConfig):
|
||||||
name: Literal["detection_ap"] = "detection_ap"
|
name: Literal["detection_ap"] = "detection_ap"
|
||||||
|
implementation: AveragePrecisionImplementation = "pascal_voc"
|
||||||
|
|
||||||
|
|
||||||
|
def pascal_voc_average_precision(y_true, y_score) -> float:
    """Compute average precision as the area under the interpolated PR curve.

    Predictions are ranked by decreasing score, precision and recall are
    derived from the cumulative true/false positive counts, precision is
    replaced by its running maximum taken from the high-recall end, and the
    area is summed over the points where recall changes.

    Parameters
    ----------
    y_true
        Array-like of binary ground-truth labels (1 positive, 0 negative),
        one entry per prediction.
    y_score
        Array-like of prediction scores aligned with ``y_true``.

    Returns
    -------
    float
        The average precision. ``0.0`` is returned when there are no
        positive labels, which includes empty input.

    Raises
    ------
    ValueError
        If ``y_true`` and ``y_score`` do not have the same shape.
    """
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score)

    # A length mismatch would otherwise either raise an opaque IndexError or
    # silently score only a subset of the labels.
    if y_true.shape != y_score.shape:
        raise ValueError(
            "y_true and y_score must have the same shape, got "
            f"{y_true.shape} and {y_score.shape}"
        )

    num_positives = y_true.sum()
    if num_positives == 0:
        # Recall is undefined without positives. The unguarded computation
        # also ends up at 0 here, but only after a 0/0 RuntimeWarning.
        return 0.0

    # Rank predictions from highest to lowest score.
    sort_ind = np.argsort(y_score)[::-1]
    y_true_sorted = y_true[sort_ind]

    true_pos_c = np.cumsum(y_true_sorted)
    false_pos_c = np.cumsum(1 - y_true_sorted)

    recall = true_pos_c / num_positives
    precision = true_pos_c / np.maximum(
        true_pos_c + false_pos_c,
        np.finfo(np.float64).eps,
    )

    # NaNs can only come from NaN labels; treat them as zero contribution.
    precision[np.isnan(precision)] = 0
    recall[np.isnan(recall)] = 0

    # pascal 12 way: pad the curve with sentinels, make precision
    # monotonically non-increasing, then integrate over recall steps.
    mprec = np.hstack((0, precision, 0))
    mrec = np.hstack((0, recall, 1))
    mprec = np.maximum.accumulate(mprec[::-1])[::-1]

    inds = np.flatnonzero(mrec[1:] != mrec[:-1]) + 1
    ave_prec = ((mrec[inds] - mrec[inds - 1]) * mprec[inds]).sum()

    return float(ave_prec)
|
||||||
|
|
||||||
|
|
||||||
|
# Dispatch table from the configured implementation name to the callable that
# scores ``(y_true, y_score)``. The keys are exactly the literals allowed by
# ``AveragePrecisionImplementation``.
_ap_impl_mapping: Mapping[
    AveragePrecisionImplementation,
    Callable[[Any, Any], float],
] = dict(
    sklearn=metrics.average_precision_score,
    pascal_voc=pascal_voc_average_precision,
)
|
||||||
|
|
||||||
|
|
||||||
class DetectionAP(MetricsProtocol):
|
class DetectionAP(MetricsProtocol):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
implementation: AveragePrecisionImplementation = "pascal_voc",
|
||||||
|
):
|
||||||
|
self.implementation = implementation
|
||||||
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, clip_evaluations: Sequence[ClipEvaluation]
|
self, clip_evaluations: Sequence[ClipEvaluation]
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
@ -31,12 +91,12 @@ class DetectionAP(MetricsProtocol):
|
|||||||
for match in clip_eval.matches
|
for match in clip_eval.matches
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
score = float(metrics.average_precision_score(y_true, y_score))
|
score = float(self.metric(y_true, y_score))
|
||||||
return {"detection_AP": score}
|
return {"detection_AP": score}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
|
def from_config(cls, config: DetectionAPConfig, class_names: List[str]):
|
||||||
return cls()
|
return cls(implementation=config.implementation)
|
||||||
|
|
||||||
|
|
||||||
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
metrics_registry.register(DetectionAPConfig, DetectionAP)
|
||||||
@ -52,9 +112,12 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
class_names: List[str],
|
class_names: List[str],
|
||||||
|
implementation: AveragePrecisionImplementation = "pascal_voc",
|
||||||
include: Optional[List[str]] = None,
|
include: Optional[List[str]] = None,
|
||||||
exclude: Optional[List[str]] = None,
|
exclude: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
|
self.implementation = implementation
|
||||||
|
self.metric = _ap_impl_mapping[self.implementation]
|
||||||
self.class_names = class_names
|
self.class_names = class_names
|
||||||
|
|
||||||
self.selected = class_names
|
self.selected = class_names
|
||||||
@ -107,10 +170,7 @@ class ClassificationAP(MetricsProtocol):
|
|||||||
for class_index, class_name in enumerate(self.class_names):
|
for class_index, class_name in enumerate(self.class_names):
|
||||||
y_true_class = y_true[:, class_index]
|
y_true_class = y_true[:, class_index]
|
||||||
y_pred_class = y_pred[:, class_index]
|
y_pred_class = y_pred[:, class_index]
|
||||||
class_ap = metrics.average_precision_score(
|
class_ap = self.metric(y_true_class, y_pred_class)
|
||||||
y_true_class,
|
|
||||||
y_pred_class,
|
|
||||||
)
|
|
||||||
class_scores[class_name] = float(class_ap)
|
class_scores[class_name] = float(class_ap)
|
||||||
|
|
||||||
mean_ap = np.mean(
|
mean_ap = np.mean(
|
||||||
|
|||||||
@ -22,7 +22,14 @@ from batdetect2.train.config import (
|
|||||||
load_full_training_config,
|
load_full_training_config,
|
||||||
load_train_config,
|
load_train_config,
|
||||||
)
|
)
|
||||||
from batdetect2.train.dataset import TrainingDataset
|
from batdetect2.train.dataset import (
|
||||||
|
TrainingDataset,
|
||||||
|
ValidationDataset,
|
||||||
|
build_train_dataset,
|
||||||
|
build_train_loader,
|
||||||
|
build_val_dataset,
|
||||||
|
build_val_loader,
|
||||||
|
)
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.losses import (
|
from batdetect2.train.losses import (
|
||||||
@ -33,14 +40,7 @@ from batdetect2.train.losses import (
|
|||||||
SizeLossConfig,
|
SizeLossConfig,
|
||||||
build_loss,
|
build_loss,
|
||||||
)
|
)
|
||||||
from batdetect2.train.train import (
|
from batdetect2.train.train import build_trainer, train
|
||||||
build_train_dataset,
|
|
||||||
build_train_loader,
|
|
||||||
build_trainer,
|
|
||||||
build_val_dataset,
|
|
||||||
build_val_loader,
|
|
||||||
train,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationsConfig",
|
"AugmentationsConfig",
|
||||||
@ -49,7 +49,6 @@ __all__ = [
|
|||||||
"EchoAugmentationConfig",
|
"EchoAugmentationConfig",
|
||||||
"FrequencyMaskAugmentationConfig",
|
"FrequencyMaskAugmentationConfig",
|
||||||
"FullTrainingConfig",
|
"FullTrainingConfig",
|
||||||
"TrainingDataset",
|
|
||||||
"LossConfig",
|
"LossConfig",
|
||||||
"LossFunction",
|
"LossFunction",
|
||||||
"PLTrainerConfig",
|
"PLTrainerConfig",
|
||||||
@ -57,7 +56,9 @@ __all__ = [
|
|||||||
"SizeLossConfig",
|
"SizeLossConfig",
|
||||||
"TimeMaskAugmentationConfig",
|
"TimeMaskAugmentationConfig",
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
|
"TrainingDataset",
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
|
"ValidationDataset",
|
||||||
"VolumeAugmentationConfig",
|
"VolumeAugmentationConfig",
|
||||||
"WarpAugmentationConfig",
|
"WarpAugmentationConfig",
|
||||||
"add_echo",
|
"add_echo",
|
||||||
|
|||||||
@ -72,13 +72,16 @@ class TrainLoaderConfig(BaseConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TrainingConfig(BaseConfig):
|
class OptimizerConfig(BaseConfig):
|
||||||
learning_rate: float = 1e-3
|
learning_rate: float = 1e-3
|
||||||
t_max: int = 100
|
t_max: int = 100
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingConfig(BaseConfig):
|
||||||
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
|
train_loader: TrainLoaderConfig = Field(default_factory=TrainLoaderConfig)
|
||||||
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
|
val_loader: ValLoaderConfig = Field(default_factory=ValLoaderConfig)
|
||||||
|
|
||||||
|
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
|
cliping: RandomClipConfig = Field(default_factory=RandomClipConfig)
|
||||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||||
|
|||||||
@ -1,18 +1,31 @@
|
|||||||
from typing import Optional, Sequence, Tuple
|
from typing import List, Optional, Sequence
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import DataLoader, Dataset
|
||||||
|
|
||||||
|
from batdetect2.plotting.clips import build_audio_loader
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.train.augmentations import (
|
||||||
|
RandomAudioSource,
|
||||||
|
build_augmentations,
|
||||||
|
)
|
||||||
|
from batdetect2.train.clips import build_clipper
|
||||||
|
from batdetect2.train.config import TrainLoaderConfig, ValLoaderConfig
|
||||||
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.typing import ClipperProtocol, TrainExample
|
from batdetect2.typing import ClipperProtocol, TrainExample
|
||||||
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
from batdetect2.typing.preprocess import AudioLoader, PreprocessorProtocol
|
||||||
from batdetect2.typing.train import (
|
from batdetect2.typing.train import Augmentation, ClipLabeller
|
||||||
Augmentation,
|
from batdetect2.utils.arrays import adjust_width
|
||||||
ClipLabeller,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingDataset",
|
"TrainingDataset",
|
||||||
|
"ValidationDataset",
|
||||||
|
"build_val_loader",
|
||||||
|
"build_train_loader",
|
||||||
|
"build_train_dataset",
|
||||||
|
"build_val_dataset",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -124,3 +137,174 @@ class ValidationDataset(Dataset):
|
|||||||
start_time=torch.tensor(clip.start_time),
|
start_time=torch.tensor(clip.start_time),
|
||||||
end_time=torch.tensor(clip.end_time),
|
end_time=torch.tensor(clip.end_time),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_train_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    labeller: Optional[ClipLabeller] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TrainLoaderConfig] = None,
    num_workers: Optional[int] = None,
) -> DataLoader:
    """Build the ``DataLoader`` used for training.

    The dataset is created with ``build_train_dataset`` and wrapped in a
    ``DataLoader`` that takes its batch size and shuffle flag from ``config``
    and collates examples with ``_collate_fn``.

    Parameters
    ----------
    clip_annotations
        Annotated clips to train on.
    audio_loader, labeller, preprocessor
        Optional components forwarded unchanged to ``build_train_dataset``.
    config
        Loader configuration. A default ``TrainLoaderConfig`` is used when
        omitted.
    num_workers
        Number of loader workers. ``None`` falls back to
        ``config.num_workers``; an explicit ``0`` is honoured.

    Returns
    -------
    DataLoader
        The training data loader.
    """
    config = config or TrainLoaderConfig()

    logger.info("Building training data loader...")
    # Lazy logging: the YAML dump is only rendered if debug is enabled.
    logger.opt(lazy=True).debug(
        "Training data loader config: \n{config}",
        config=lambda: config.to_yaml_string(exclude_none=True),
    )

    train_dataset = build_train_dataset(
        clip_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        config=config,
    )

    # Compare against None rather than truthiness: 0 is a valid request
    # and must not be replaced by the configured value.
    if num_workers is None:
        num_workers = config.num_workers

    return DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=config.shuffle,
        num_workers=num_workers,
        collate_fn=_collate_fn,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_val_loader(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    labeller: Optional[ClipLabeller] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[ValLoaderConfig] = None,
    num_workers: Optional[int] = None,
) -> DataLoader:
    """Build the ``DataLoader`` used for validation.

    The dataset is created with ``build_val_dataset`` and wrapped in a
    ``DataLoader`` with a fixed batch size of 1 and no shuffling, collating
    examples with ``_collate_fn``.

    Parameters
    ----------
    clip_annotations
        Annotated clips to validate on.
    audio_loader, labeller, preprocessor
        Optional components forwarded unchanged to ``build_val_dataset``.
    config
        Loader configuration. A default ``ValLoaderConfig`` is used when
        omitted.
    num_workers
        Number of loader workers. ``None`` falls back to
        ``config.num_workers``; an explicit ``0`` is honoured.

    Returns
    -------
    DataLoader
        The validation data loader.
    """
    logger.info("Building validation data loader...")
    config = config or ValLoaderConfig()
    # Lazy logging: the YAML dump is only rendered if debug is enabled.
    logger.opt(lazy=True).debug(
        "Validation data loader config: \n{config}",
        config=lambda: config.to_yaml_string(exclude_none=True),
    )

    val_dataset = build_val_dataset(
        clip_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        config=config,
    )

    # Compare against None rather than truthiness: 0 is a valid request
    # and must not be replaced by the configured value.
    if num_workers is None:
        num_workers = config.num_workers

    return DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=num_workers,
        collate_fn=_collate_fn,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_train_dataset(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    labeller: Optional[ClipLabeller] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[TrainLoaderConfig] = None,
) -> TrainingDataset:
    """Assemble a ``TrainingDataset`` from annotations and optional parts.

    Any of ``audio_loader``, ``preprocessor`` and ``labeller`` that is not
    supplied is built with its default builder; the default labeller is
    given the preprocessor's ``min_freq``/``max_freq``. Augmentations are
    built only when ``config.augmentations.enabled`` is set, otherwise both
    augmentation slots are ``None``.
    """
    logger.info("Building training dataset...")
    config = config or TrainLoaderConfig()

    clipper = build_clipper(config=config.clipping_strategy)

    audio_loader = (
        build_audio_loader() if audio_loader is None else audio_loader
    )
    preprocessor = (
        build_preprocessor() if preprocessor is None else preprocessor
    )
    if labeller is None:
        freq_bounds = {
            "min_freq": preprocessor.min_freq,
            "max_freq": preprocessor.max_freq,
        }
        labeller = build_clip_labeler(**freq_bounds)

    # Constructed unconditionally, as before, even though it is only handed
    # to the augmentation builder.
    mixing_source = RandomAudioSource(
        clip_annotations,
        audio_loader=audio_loader,
    )

    audio_augmentation = None
    spectrogram_augmentation = None
    if not config.augmentations.enabled:
        logger.debug("No augmentations configured for training dataset.")
    else:
        audio_augmentation, spectrogram_augmentation = build_augmentations(
            samplerate=preprocessor.input_samplerate,
            config=config.augmentations,
            audio_source=mixing_source,
        )

    dataset = TrainingDataset(
        clip_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        clipper=clipper,
        preprocessor=preprocessor,
        audio_augmentation=audio_augmentation,
        spectrogram_augmentation=spectrogram_augmentation,
    )
    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def build_val_dataset(
    clip_annotations: Sequence[data.ClipAnnotation],
    audio_loader: Optional[AudioLoader] = None,
    labeller: Optional[ClipLabeller] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    config: Optional[ValLoaderConfig] = None,
) -> ValidationDataset:
    """Assemble a ``ValidationDataset`` from annotations and optional parts.

    Any of ``audio_loader``, ``preprocessor`` and ``labeller`` that is not
    supplied is built with its default builder; the default labeller is
    given the preprocessor's ``min_freq``/``max_freq``. The clipper comes
    from ``config.clipping_strategy``.
    """
    logger.info("Building validation dataset...")
    config = config or ValLoaderConfig()

    audio_loader = (
        build_audio_loader() if audio_loader is None else audio_loader
    )
    preprocessor = (
        build_preprocessor() if preprocessor is None else preprocessor
    )
    if labeller is None:
        freq_bounds = {
            "min_freq": preprocessor.min_freq,
            "max_freq": preprocessor.max_freq,
        }
        labeller = build_clip_labeler(**freq_bounds)

    clipper = build_clipper(config.clipping_strategy)

    dataset = ValidationDataset(
        clip_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        clipper=clipper,
    )
    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def _collate_fn(batch: List[TrainExample]) -> TrainExample:
    """Merge a list of examples into one batched ``TrainExample``.

    The spectrogram and the three heatmaps of every example are passed
    through ``adjust_width`` with the largest spectrogram width in the batch
    before stacking; ``idx``, ``start_time`` and ``end_time`` are stacked
    as they are.
    """
    target_width = max(example.spec.shape[-1] for example in batch)

    def _stack_adjusted(field_name: str) -> torch.Tensor:
        # Bring every example's tensor to the common width, then stack.
        return torch.stack(
            [
                adjust_width(getattr(example, field_name), target_width)
                for example in batch
            ]
        )

    def _stack_plain(field_name: str) -> torch.Tensor:
        return torch.stack([getattr(example, field_name) for example in batch])

    return TrainExample(
        spec=_stack_adjusted("spec"),
        detection_heatmap=_stack_adjusted("detection_heatmap"),
        size_heatmap=_stack_adjusted("size_heatmap"),
        class_heatmap=_stack_adjusted("class_heatmap"),
        idx=_stack_plain("idx"),
        start_time=_stack_plain("start_time"),
        end_time=_stack_plain("end_time"),
    )
|
||||||
|
|||||||
@ -77,3 +77,15 @@ def load_model_from_checkpoint(
|
|||||||
) -> Tuple[Model, FullTrainingConfig]:
|
) -> Tuple[Model, FullTrainingConfig]:
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
return module.model, module.config
|
return module.model, module.config
|
||||||
|
|
||||||
|
|
||||||
|
def build_training_module(
    config: Optional[FullTrainingConfig] = None,
    t_max: int = 200,
) -> TrainingModule:
    """Create a ``TrainingModule`` from a full training configuration.

    A default ``FullTrainingConfig`` is used when ``config`` is omitted. The
    learning rate is read from ``config.train.optimizer.learning_rate``;
    ``t_max`` is forwarded exactly as given by the caller.
    """
    # NOTE(review): config.train.optimizer.t_max is not consulted here; the
    # caller is expected to pass the value it wants - confirm this is intended.
    resolved_config = config or FullTrainingConfig()
    learning_rate = resolved_config.train.optimizer.learning_rate
    return TrainingModule(
        config=resolved_config,
        learning_rate=learning_rate,
        t_max=t_max,
    )
|
||||||
|
|||||||
@ -2,47 +2,31 @@ from collections.abc import Sequence
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
|
||||||
from lightning import Trainer, seed_everything
|
from lightning import Trainer, seed_everything
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
|
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
|
||||||
from batdetect2.evaluate.evaluator import build_evaluator
|
from batdetect2.evaluate.evaluator import build_evaluator
|
||||||
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
|
from batdetect2.plotting.clips import PreprocessorProtocol, build_audio_loader
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train.augmentations import (
|
|
||||||
RandomAudioSource,
|
|
||||||
build_augmentations,
|
|
||||||
)
|
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.clips import build_clipper
|
|
||||||
from batdetect2.train.config import (
|
from batdetect2.train.config import (
|
||||||
FullTrainingConfig,
|
FullTrainingConfig,
|
||||||
TrainLoaderConfig,
|
|
||||||
ValLoaderConfig,
|
|
||||||
)
|
)
|
||||||
from batdetect2.train.dataset import TrainingDataset, ValidationDataset
|
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule, build_training_module
|
||||||
from batdetect2.train.logging import build_logger
|
from batdetect2.train.logging import build_logger
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
PreprocessorProtocol,
|
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
TrainExample,
|
|
||||||
)
|
)
|
||||||
|
from batdetect2.typing.preprocess import AudioLoader
|
||||||
from batdetect2.typing.train import ClipLabeller
|
from batdetect2.typing.train import ClipLabeller
|
||||||
from batdetect2.utils.arrays import adjust_width
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"build_train_dataset",
|
|
||||||
"build_train_loader",
|
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
"build_val_dataset",
|
|
||||||
"build_val_loader",
|
|
||||||
"train",
|
"train",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -52,6 +36,11 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
|||||||
def train(
|
def train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||||
|
trainer: Optional[Trainer] = None,
|
||||||
|
targets: Optional[TargetProtocol] = None,
|
||||||
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
|
audio_loader: Optional[AudioLoader] = None,
|
||||||
|
labeller: Optional[ClipLabeller] = None,
|
||||||
config: Optional[FullTrainingConfig] = None,
|
config: Optional[FullTrainingConfig] = None,
|
||||||
model_path: Optional[data.PathLike] = None,
|
model_path: Optional[data.PathLike] = None,
|
||||||
train_workers: Optional[int] = None,
|
train_workers: Optional[int] = None,
|
||||||
@ -67,13 +56,15 @@ def train(
|
|||||||
|
|
||||||
config = config or FullTrainingConfig()
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
targets = build_targets(config.targets)
|
targets = targets or build_targets(config.targets)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(config.preprocess)
|
preprocessor = preprocessor or build_preprocessor(config.preprocess)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
audio_loader = audio_loader or build_audio_loader(
|
||||||
|
config=config.preprocess.audio
|
||||||
|
)
|
||||||
|
|
||||||
labeller = build_clip_labeler(
|
labeller = labeller or build_clip_labeler(
|
||||||
targets,
|
targets,
|
||||||
min_freq=preprocessor.min_freq,
|
min_freq=preprocessor.min_freq,
|
||||||
max_freq=preprocessor.max_freq,
|
max_freq=preprocessor.max_freq,
|
||||||
@ -108,10 +99,10 @@ def train(
|
|||||||
else:
|
else:
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
config,
|
config,
|
||||||
t_max=config.train.t_max * len(train_dataloader),
|
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
config,
|
config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
@ -129,21 +120,9 @@ def train(
|
|||||||
logger.info("Training complete.")
|
logger.info("Training complete.")
|
||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
|
||||||
config: Optional[FullTrainingConfig] = None,
|
|
||||||
t_max: int = 200,
|
|
||||||
) -> TrainingModule:
|
|
||||||
config = config or FullTrainingConfig()
|
|
||||||
return TrainingModule(
|
|
||||||
config=config,
|
|
||||||
learning_rate=config.train.learning_rate,
|
|
||||||
t_max=t_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_trainer_callbacks(
|
def build_trainer_callbacks(
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
config: EvaluationConfig,
|
config: FullTrainingConfig,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
@ -157,7 +136,7 @@ def build_trainer_callbacks(
|
|||||||
if run_name is not None:
|
if run_name is not None:
|
||||||
checkpoint_dir = checkpoint_dir / run_name
|
checkpoint_dir = checkpoint_dir / run_name
|
||||||
|
|
||||||
evaluator = build_evaluator(config=config, targets=targets)
|
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
@ -202,180 +181,9 @@ def build_trainer(
|
|||||||
logger=train_logger,
|
logger=train_logger,
|
||||||
callbacks=build_trainer_callbacks(
|
callbacks=build_trainer_callbacks(
|
||||||
targets,
|
targets,
|
||||||
config=conf.evaluation,
|
config=conf,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_train_loader(
|
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
labeller: Optional[ClipLabeller] = None,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
config: Optional[TrainLoaderConfig] = None,
|
|
||||||
num_workers: Optional[int] = None,
|
|
||||||
) -> DataLoader:
|
|
||||||
config = config or TrainLoaderConfig()
|
|
||||||
|
|
||||||
logger.info("Building training data loader...")
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Training data loader config: \n{config}",
|
|
||||||
config=lambda: config.to_yaml_string(exclude_none=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
train_dataset = build_train_dataset(
|
|
||||||
clip_annotations,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
labeller=labeller,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_workers = num_workers or config.num_workers
|
|
||||||
return DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
batch_size=config.batch_size,
|
|
||||||
shuffle=config.shuffle,
|
|
||||||
num_workers=num_workers,
|
|
||||||
collate_fn=_collate_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_val_loader(
|
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
labeller: Optional[ClipLabeller] = None,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
config: Optional[ValLoaderConfig] = None,
|
|
||||||
num_workers: Optional[int] = None,
|
|
||||||
):
|
|
||||||
logger.info("Building validation data loader...")
|
|
||||||
config = config or ValLoaderConfig()
|
|
||||||
logger.opt(lazy=True).debug(
|
|
||||||
"Validation data loader config: \n{config}",
|
|
||||||
config=lambda: config.to_yaml_string(exclude_none=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
val_dataset = build_val_dataset(
|
|
||||||
clip_annotations,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
labeller=labeller,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
config=config,
|
|
||||||
)
|
|
||||||
|
|
||||||
num_workers = num_workers or config.num_workers
|
|
||||||
return DataLoader(
|
|
||||||
val_dataset,
|
|
||||||
batch_size=1,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=num_workers,
|
|
||||||
collate_fn=_collate_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_train_dataset(
|
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
labeller: Optional[ClipLabeller] = None,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
config: Optional[TrainLoaderConfig] = None,
|
|
||||||
) -> TrainingDataset:
|
|
||||||
logger.info("Building training dataset...")
|
|
||||||
config = config or TrainLoaderConfig()
|
|
||||||
|
|
||||||
clipper = build_clipper(config=config.clipping_strategy)
|
|
||||||
|
|
||||||
if audio_loader is None:
|
|
||||||
audio_loader = build_audio_loader()
|
|
||||||
|
|
||||||
if preprocessor is None:
|
|
||||||
preprocessor = build_preprocessor()
|
|
||||||
|
|
||||||
if labeller is None:
|
|
||||||
labeller = build_clip_labeler(
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
|
||||||
|
|
||||||
random_example_source = RandomAudioSource(
|
|
||||||
clip_annotations,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
)
|
|
||||||
|
|
||||||
if config.augmentations.enabled:
|
|
||||||
audio_augmentation, spectrogram_augmentation = build_augmentations(
|
|
||||||
samplerate=preprocessor.input_samplerate,
|
|
||||||
config=config.augmentations,
|
|
||||||
audio_source=random_example_source,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug("No augmentations configured for training dataset.")
|
|
||||||
audio_augmentation = None
|
|
||||||
spectrogram_augmentation = None
|
|
||||||
|
|
||||||
return TrainingDataset(
|
|
||||||
clip_annotations,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
labeller=labeller,
|
|
||||||
clipper=clipper,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
audio_augmentation=audio_augmentation,
|
|
||||||
spectrogram_augmentation=spectrogram_augmentation,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_val_dataset(
|
|
||||||
clip_annotations: Sequence[data.ClipAnnotation],
|
|
||||||
audio_loader: Optional[AudioLoader] = None,
|
|
||||||
labeller: Optional[ClipLabeller] = None,
|
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
config: Optional[ValLoaderConfig] = None,
|
|
||||||
) -> ValidationDataset:
|
|
||||||
logger.info("Building validation dataset...")
|
|
||||||
config = config or ValLoaderConfig()
|
|
||||||
|
|
||||||
if audio_loader is None:
|
|
||||||
audio_loader = build_audio_loader()
|
|
||||||
|
|
||||||
if preprocessor is None:
|
|
||||||
preprocessor = build_preprocessor()
|
|
||||||
|
|
||||||
if labeller is None:
|
|
||||||
labeller = build_clip_labeler(
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
|
||||||
|
|
||||||
clipper = build_clipper(config.clipping_strategy)
|
|
||||||
return ValidationDataset(
|
|
||||||
clip_annotations,
|
|
||||||
audio_loader=audio_loader,
|
|
||||||
labeller=labeller,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
clipper=clipper,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _collate_fn(batch: List[TrainExample]) -> TrainExample:
|
|
||||||
max_width = max(item.spec.shape[-1] for item in batch)
|
|
||||||
return TrainExample(
|
|
||||||
spec=torch.stack(
|
|
||||||
[adjust_width(item.spec, max_width) for item in batch]
|
|
||||||
),
|
|
||||||
detection_heatmap=torch.stack(
|
|
||||||
[adjust_width(item.detection_heatmap, max_width) for item in batch]
|
|
||||||
),
|
|
||||||
size_heatmap=torch.stack(
|
|
||||||
[adjust_width(item.size_heatmap, max_width) for item in batch]
|
|
||||||
),
|
|
||||||
class_heatmap=torch.stack(
|
|
||||||
[adjust_width(item.class_heatmap, max_width) for item in batch]
|
|
||||||
),
|
|
||||||
idx=torch.stack([item.idx for item in batch]),
|
|
||||||
start_time=torch.stack([item.start_time for item in batch]),
|
|
||||||
end_time=torch.stack([item.end_time for item in batch]),
|
|
||||||
)
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user