mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Improved train module
This commit is contained in:
parent
7689580a24
commit
e383a33cbf
@ -0,0 +1,48 @@
|
||||
from batdetect2.train.augmentations import (
|
||||
AugmentationsConfig,
|
||||
add_echo,
|
||||
augment_example,
|
||||
load_agumentation_config,
|
||||
mask_frequency,
|
||||
mask_time,
|
||||
mix_examples,
|
||||
scale_volume,
|
||||
select_subclip,
|
||||
warp_spectrogram,
|
||||
)
|
||||
from batdetect2.train.config import TrainingConfig, load_train_config
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
SubclipConfig,
|
||||
TrainExample,
|
||||
)
|
||||
from batdetect2.train.labels import LabelConfig, load_label_config
|
||||
from batdetect2.train.preprocess import preprocess_annotations
|
||||
from batdetect2.train.targets import TargetConfig, load_target_config
|
||||
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||
|
||||
__all__ = [
|
||||
"AugmentationsConfig",
|
||||
"LabelConfig",
|
||||
"LabeledDataset",
|
||||
"SubclipConfig",
|
||||
"TargetConfig",
|
||||
"TrainExample",
|
||||
"TrainerConfig",
|
||||
"TrainingConfig",
|
||||
"add_echo",
|
||||
"augment_example",
|
||||
"load_agumentation_config",
|
||||
"load_label_config",
|
||||
"load_target_config",
|
||||
"load_train_config",
|
||||
"load_trainer_config",
|
||||
"mask_frequency",
|
||||
"mask_time",
|
||||
"mix_examples",
|
||||
"preprocess_annotations",
|
||||
"scale_volume",
|
||||
"select_subclip",
|
||||
"train",
|
||||
"warp_spectrogram",
|
||||
]
|
@ -3,16 +3,30 @@ from typing import Callable, Optional, Union
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import arrays
|
||||
from soundevent import arrays, data
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
|
||||
from batdetect2.preprocess.arrays import adjust_width
|
||||
|
||||
Augmentation = Callable[[xr.Dataset], xr.Dataset]
|
||||
|
||||
|
||||
class AugmentationConfig(BaseConfig):
|
||||
__all__ = [
|
||||
"AugmentationsConfig",
|
||||
"load_agumentation_config",
|
||||
"select_subclip",
|
||||
"mix_examples",
|
||||
"add_echo",
|
||||
"scale_volume",
|
||||
"warp_spectrogram",
|
||||
"mask_time",
|
||||
"mask_frequency",
|
||||
"augment_example",
|
||||
]
|
||||
|
||||
|
||||
class BaseAugmentationConfig(BaseConfig):
|
||||
enable: bool = True
|
||||
probability: float = 0.2
|
||||
|
||||
@ -63,7 +77,7 @@ def select_subclip(
|
||||
)
|
||||
|
||||
|
||||
class MixAugmentationConfig(AugmentationConfig):
|
||||
class MixAugmentationConfig(BaseAugmentationConfig):
|
||||
min_weight: float = 0.3
|
||||
max_weight: float = 0.7
|
||||
|
||||
@ -133,7 +147,7 @@ def mix_examples(
|
||||
)
|
||||
|
||||
|
||||
class EchoAugmentationConfig(AugmentationConfig):
|
||||
class EchoAugmentationConfig(BaseAugmentationConfig):
|
||||
max_delay: float = 0.005
|
||||
min_weight: float = 0.0
|
||||
max_weight: float = 1.0
|
||||
@ -188,7 +202,7 @@ def add_echo(
|
||||
)
|
||||
|
||||
|
||||
class VolumeAugmentationConfig(AugmentationConfig):
|
||||
class VolumeAugmentationConfig(BaseAugmentationConfig):
|
||||
min_scaling: float = 0.0
|
||||
max_scaling: float = 2.0
|
||||
|
||||
@ -206,7 +220,7 @@ def scale_volume(
|
||||
return example.assign(spectrogram=example["spectrogram"] * factor)
|
||||
|
||||
|
||||
class WarpAugmentationConfig(AugmentationConfig):
|
||||
class WarpAugmentationConfig(BaseAugmentationConfig):
|
||||
delta: float = 0.04
|
||||
|
||||
|
||||
@ -294,7 +308,7 @@ def mask_axis(
|
||||
return array.where(condition, other=mask_value)
|
||||
|
||||
|
||||
class TimeMaskAugmentationConfig(AugmentationConfig):
|
||||
class TimeMaskAugmentationConfig(BaseAugmentationConfig):
|
||||
max_perc: float = 0.05
|
||||
max_masks: int = 3
|
||||
|
||||
@ -318,7 +332,7 @@ def mask_time(
|
||||
return example.assign(spectrogram=spectrogram)
|
||||
|
||||
|
||||
class FrequencyMaskAugmentationConfig(AugmentationConfig):
|
||||
class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
|
||||
max_perc: float = 0.10
|
||||
max_masks: int = 3
|
||||
|
||||
@ -361,7 +375,13 @@ class AugmentationsConfig(BaseConfig):
|
||||
)
|
||||
|
||||
|
||||
def should_apply(config: AugmentationConfig) -> bool:
|
||||
def load_agumentation_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> AugmentationsConfig:
|
||||
return load_config(path, schema=AugmentationsConfig, field=field)
|
||||
|
||||
|
||||
def should_apply(config: BaseAugmentationConfig) -> bool:
|
||||
if not config.enable:
|
||||
return False
|
||||
|
||||
|
@ -43,28 +43,23 @@ class SubclipConfig(BaseConfig):
|
||||
|
||||
class DatasetConfig(BaseConfig):
|
||||
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
augmentation: AugmentationsConfig = Field(
|
||||
default_factory=AugmentationsConfig
|
||||
)
|
||||
|
||||
|
||||
class LabeledDataset(Dataset):
|
||||
config: DatasetConfig
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
filenames: Sequence[PathLike],
|
||||
augment: bool = False,
|
||||
subclip: bool = False,
|
||||
config: Optional[DatasetConfig] = None,
|
||||
subclip: Optional[SubclipConfig] = None,
|
||||
augmentation: Optional[AugmentationsConfig] = None,
|
||||
preprocessing: Optional[PreprocessingConfig] = None,
|
||||
):
|
||||
self.filenames = filenames
|
||||
self.augment = augment
|
||||
self.subclip = subclip
|
||||
self.config = config or DatasetConfig()
|
||||
self.augmentation = augmentation
|
||||
self.preprocessing = preprocessing or PreprocessingConfig()
|
||||
|
||||
def __len__(self):
|
||||
return len(self.filenames)
|
||||
@ -75,16 +70,16 @@ class LabeledDataset(Dataset):
|
||||
if self.subclip:
|
||||
dataset = select_subclip(
|
||||
dataset,
|
||||
duration=self.config.subclip.duration,
|
||||
width=self.config.subclip.width,
|
||||
random=self.config.subclip.random,
|
||||
duration=self.subclip.duration,
|
||||
width=self.subclip.width,
|
||||
random=self.subclip.random,
|
||||
)
|
||||
|
||||
if self.augment:
|
||||
if self.augmentation:
|
||||
dataset = augment_example(
|
||||
dataset,
|
||||
self.config.augmentation,
|
||||
preprocessing_config=self.config.preprocessing,
|
||||
self.augmentation,
|
||||
preprocessing_config=self.preprocessing,
|
||||
others=self.get_random_example,
|
||||
)
|
||||
|
||||
@ -101,15 +96,15 @@ class LabeledDataset(Dataset):
|
||||
cls,
|
||||
directory: PathLike,
|
||||
extension: str = ".nc",
|
||||
config: Optional[DatasetConfig] = None,
|
||||
augment: bool = False,
|
||||
subclip: bool = False,
|
||||
subclip: Optional[SubclipConfig] = None,
|
||||
augmentation: Optional[AugmentationsConfig] = None,
|
||||
preprocessing: Optional[PreprocessingConfig] = None,
|
||||
):
|
||||
return cls(
|
||||
get_files(directory, extension),
|
||||
config=config,
|
||||
augment=augment,
|
||||
subclip=subclip,
|
||||
augmentation=augmentation,
|
||||
preprocessing=preprocessing,
|
||||
)
|
||||
|
||||
def get_random_example(self) -> xr.Dataset:
|
||||
@ -119,9 +114,9 @@ class LabeledDataset(Dataset):
|
||||
if self.subclip:
|
||||
dataset = select_subclip(
|
||||
dataset,
|
||||
duration=self.config.subclip.duration,
|
||||
width=self.config.subclip.width,
|
||||
random=self.config.subclip.random,
|
||||
duration=self.subclip.duration,
|
||||
width=self.subclip.width,
|
||||
random=self.subclip.random,
|
||||
)
|
||||
|
||||
return dataset
|
||||
@ -144,7 +139,7 @@ class LabeledDataset(Dataset):
|
||||
if not self.subclip:
|
||||
return tensor
|
||||
|
||||
width = self.config.subclip.width
|
||||
width = self.subclip.width
|
||||
return adjust_width(tensor, width)
|
||||
|
||||
|
||||
|
@ -3,13 +3,19 @@ from typing import Callable, List, Optional, Sequence, Tuple
|
||||
|
||||
import numpy as np
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from scipy.ndimage import gaussian_filter
|
||||
from soundevent import arrays, data, geometry
|
||||
from soundevent.geometry.operations import Positions
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
|
||||
__all__ = ["generate_heatmaps"]
|
||||
__all__ = [
|
||||
"HeatmapsConfig",
|
||||
"LabelConfig",
|
||||
"generate_heatmaps",
|
||||
"load_label_config",
|
||||
]
|
||||
|
||||
|
||||
class HeatmapsConfig(BaseConfig):
|
||||
@ -19,6 +25,10 @@ class HeatmapsConfig(BaseConfig):
|
||||
frequency_scale: float = 1 / 859.375
|
||||
|
||||
|
||||
class LabelConfig(BaseConfig):
|
||||
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
||||
|
||||
|
||||
def generate_heatmaps(
|
||||
sound_events: Sequence[data.SoundEventAnnotation],
|
||||
spec: xr.DataArray,
|
||||
@ -132,3 +142,9 @@ def generate_heatmaps(
|
||||
).fillna(0.0)
|
||||
|
||||
return detection_heatmap, class_heatmap, size_heatmap
|
||||
|
||||
|
||||
def load_label_config(
|
||||
path: data.PathLike, field: Optional[str] = None
|
||||
) -> LabelConfig:
|
||||
return load_config(path, schema=LabelConfig, field=field)
|
||||
|
@ -6,9 +6,15 @@ from pydantic import Field
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
from batdetect2.plot import detection
|
||||
from batdetect2.train.dataset import TrainExample
|
||||
|
||||
__all__ = [
|
||||
"bbox_size_loss",
|
||||
"compute_loss",
|
||||
"focal_loss",
|
||||
"mse_loss",
|
||||
]
|
||||
|
||||
|
||||
class SizeLossConfig(BaseConfig):
|
||||
weight: float = 0.1
|
||||
|
@ -17,7 +17,7 @@ from batdetect2.preprocess import (
|
||||
compute_spectrogram,
|
||||
load_clip_audio,
|
||||
)
|
||||
from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
|
||||
from batdetect2.train.labels import LabelConfig, generate_heatmaps
|
||||
from batdetect2.train.targets import (
|
||||
TargetConfig,
|
||||
build_encoder,
|
||||
@ -30,6 +30,9 @@ FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||
|
||||
__all__ = [
|
||||
"preprocess_annotations",
|
||||
"preprocess_single_annotation",
|
||||
"generate_train_example",
|
||||
"TrainPreprocessingConfig",
|
||||
]
|
||||
|
||||
|
||||
@ -38,15 +41,21 @@ class TrainPreprocessingConfig(BaseConfig):
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
target: TargetConfig = Field(default_factory=TargetConfig)
|
||||
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||
|
||||
|
||||
def generate_train_example(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
config: Optional[TrainPreprocessingConfig] = None,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
) -> xr.Dataset:
|
||||
"""Generate a training example."""
|
||||
config = config or TrainPreprocessingConfig()
|
||||
config = TrainPreprocessingConfig(
|
||||
preprocessing=preprocessing_config or PreprocessingConfig(),
|
||||
target=target_config or TargetConfig(),
|
||||
labels=label_config or LabelConfig(),
|
||||
)
|
||||
|
||||
wave = load_clip_audio(
|
||||
clip_annotation.clip,
|
||||
@ -78,10 +87,10 @@ def generate_train_example(
|
||||
spectrogram,
|
||||
class_names,
|
||||
encoder,
|
||||
target_sigma=config.heatmaps.sigma,
|
||||
position=config.heatmaps.position,
|
||||
time_scale=config.heatmaps.time_scale,
|
||||
frequency_scale=config.heatmaps.frequency_scale,
|
||||
target_sigma=config.labels.heatmaps.sigma,
|
||||
position=config.labels.heatmaps.position,
|
||||
time_scale=config.labels.heatmaps.time_scale,
|
||||
frequency_scale=config.labels.heatmaps.frequency_scale,
|
||||
)
|
||||
|
||||
dataset = xr.Dataset(
|
||||
@ -133,14 +142,14 @@ def preprocess_annotations(
|
||||
output_dir: PathLike,
|
||||
filename_fn: FilenameFn = _get_filename,
|
||||
replace: bool = False,
|
||||
config: Optional[TrainPreprocessingConfig] = None,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
max_workers: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Preprocess annotations and save to disk."""
|
||||
output_dir = Path(output_dir)
|
||||
|
||||
config = config or TrainPreprocessingConfig()
|
||||
|
||||
if not output_dir.is_dir():
|
||||
output_dir.mkdir(parents=True)
|
||||
|
||||
@ -151,9 +160,11 @@ def preprocess_annotations(
|
||||
partial(
|
||||
preprocess_single_annotation,
|
||||
output_dir=output_dir,
|
||||
config=config,
|
||||
filename_fn=filename_fn,
|
||||
replace=replace,
|
||||
preprocessing_config=preprocessing_config,
|
||||
target_config=target_config,
|
||||
label_config=label_config,
|
||||
),
|
||||
clip_annotations,
|
||||
),
|
||||
@ -165,7 +176,9 @@ def preprocess_annotations(
|
||||
def preprocess_single_annotation(
|
||||
clip_annotation: data.ClipAnnotation,
|
||||
output_dir: PathLike,
|
||||
config: TrainPreprocessingConfig,
|
||||
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||
target_config: Optional[TargetConfig] = None,
|
||||
label_config: Optional[LabelConfig] = None,
|
||||
filename_fn: FilenameFn = _get_filename,
|
||||
replace: bool = False,
|
||||
) -> None:
|
||||
@ -181,7 +194,12 @@ def preprocess_single_annotation(
|
||||
path.unlink()
|
||||
|
||||
try:
|
||||
sample = generate_train_example(clip_annotation, config=config)
|
||||
sample = generate_train_example(
|
||||
clip_annotation,
|
||||
preprocessing_config=preprocessing_config,
|
||||
target_config=target_config,
|
||||
label_config=label_config,
|
||||
)
|
||||
except Exception as error:
|
||||
raise RuntimeError(
|
||||
f"Failed to process annotation: {clip_annotation.uuid}"
|
||||
|
@ -9,6 +9,14 @@ from soundevent import data
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.terms import TagInfo, get_tag_from_info
|
||||
|
||||
__all__ = [
|
||||
"TargetConfig",
|
||||
"load_target_config",
|
||||
"build_encoder",
|
||||
"build_decoder",
|
||||
"filter_sound_event",
|
||||
]
|
||||
|
||||
|
||||
class ReplaceConfig(BaseConfig):
|
||||
"""Configuration for replacing tags."""
|
||||
|
@ -1,82 +1,68 @@
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
from soundevent import data
|
||||
from torch.optim import Adam
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from lightning import LightningModule
|
||||
from lightning.pytorch import Trainer
|
||||
from soundevent.data import PathLike
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.data.datasets import ClipAnnotationDataset
|
||||
from batdetect2.models.typing import DetectionModel
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.train.dataset import LabeledDataset
|
||||
|
||||
__all__ = [
|
||||
"train",
|
||||
"TrainerConfig",
|
||||
"load_trainer_config",
|
||||
]
|
||||
|
||||
|
||||
class TrainInputs(NamedTuple):
|
||||
spec: torch.Tensor
|
||||
detection_heatmap: torch.Tensor
|
||||
class_heatmap: torch.Tensor
|
||||
size_heatmap: torch.Tensor
|
||||
class TrainerConfig(BaseConfig):
|
||||
accelerator: str = "auto"
|
||||
accumulate_grad_batches: int = 1
|
||||
deterministic: bool = True
|
||||
check_val_every_n_epoch: int = 1
|
||||
devices: Union[str, int] = "auto"
|
||||
enable_checkpointing: bool = True
|
||||
gradient_clip_val: Optional[float] = None
|
||||
limit_train_batches: Optional[Union[int, float]] = None
|
||||
limit_test_batches: Optional[Union[int, float]] = None
|
||||
limit_val_batches: Optional[Union[int, float]] = None
|
||||
log_every_n_steps: Optional[int] = None
|
||||
max_epochs: Optional[int] = None
|
||||
min_epochs: Optional[int] = 100
|
||||
max_steps: Optional[int] = None
|
||||
min_steps: Optional[int] = None
|
||||
max_time: Optional[str] = None
|
||||
precision: Optional[str] = None
|
||||
reload_dataloaders_every_n_epochs: Optional[int] = None
|
||||
val_check_interval: Optional[Union[int, float]] = None
|
||||
|
||||
|
||||
def train_loop(
|
||||
model: DetectionModel,
|
||||
train_dataset: ClipAnnotationDataset[TrainInputs],
|
||||
validation_dataset: ClipAnnotationDataset[TrainInputs],
|
||||
device: Optional[torch.device] = None,
|
||||
num_epochs: int = 100,
|
||||
learning_rate: float = 1e-4,
|
||||
def load_trainer_config(path: PathLike, field: Optional[str] = None):
|
||||
return load_config(path, schema=TrainerConfig, field=field)
|
||||
|
||||
|
||||
def train(
|
||||
module: LightningModule,
|
||||
train_dataset: LabeledDataset,
|
||||
trainer_config: Optional[TrainerConfig] = None,
|
||||
dev_run: bool = False,
|
||||
overfit_batches: bool = False,
|
||||
profiler: Optional[str] = None,
|
||||
):
|
||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
||||
validation_loader = DataLoader(validation_dataset, batch_size=32)
|
||||
|
||||
model.to(device)
|
||||
|
||||
optimizer = Adam(model.parameters(), lr=learning_rate)
|
||||
scheduler = CosineAnnealingLR(
|
||||
optimizer,
|
||||
num_epochs * len(train_loader),
|
||||
trainer_config = trainer_config or TrainerConfig()
|
||||
trainer = Trainer(
|
||||
**trainer_config.model_dump(
|
||||
exclude_unset=True,
|
||||
exclude_none=True,
|
||||
),
|
||||
fast_dev_run=dev_run,
|
||||
overfit_batches=overfit_batches,
|
||||
profiler=profiler,
|
||||
)
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
train_loss = train_single_epoch(
|
||||
model,
|
||||
train_loader,
|
||||
optimizer,
|
||||
device,
|
||||
scheduler,
|
||||
)
|
||||
|
||||
|
||||
def train_single_epoch(
|
||||
model: DetectionModel,
|
||||
train_loader: DataLoader,
|
||||
optimizer: Adam,
|
||||
device: torch.device,
|
||||
scheduler: CosineAnnealingLR,
|
||||
):
|
||||
model.train()
|
||||
train_loss = tu.AverageMeter()
|
||||
|
||||
for batch in train_loader:
|
||||
optimizer.zero_grad()
|
||||
|
||||
spec = batch.spec.to(device)
|
||||
detection_heatmap = batch.detection_heatmap.to(device)
|
||||
class_heatmap = batch.class_heatmap.to(device)
|
||||
size_heatmap = batch.size_heatmap.to(device)
|
||||
|
||||
outputs = model(spec)
|
||||
|
||||
loss = loss_fun(
|
||||
outputs,
|
||||
gt_det,
|
||||
gt_size,
|
||||
gt_class,
|
||||
det_criterion,
|
||||
params,
|
||||
class_inv_freq,
|
||||
)
|
||||
|
||||
train_loss.update(loss.item(), data.shape[0])
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=module.config.train.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=7,
|
||||
)
|
||||
trainer.fit(module, train_dataloaders=train_loader)
|
||||
|
Loading…
Reference in New Issue
Block a user