mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Improved train module
This commit is contained in:
parent
7689580a24
commit
e383a33cbf
@ -0,0 +1,48 @@
|
|||||||
|
from batdetect2.train.augmentations import (
|
||||||
|
AugmentationsConfig,
|
||||||
|
add_echo,
|
||||||
|
augment_example,
|
||||||
|
load_agumentation_config,
|
||||||
|
mask_frequency,
|
||||||
|
mask_time,
|
||||||
|
mix_examples,
|
||||||
|
scale_volume,
|
||||||
|
select_subclip,
|
||||||
|
warp_spectrogram,
|
||||||
|
)
|
||||||
|
from batdetect2.train.config import TrainingConfig, load_train_config
|
||||||
|
from batdetect2.train.dataset import (
|
||||||
|
LabeledDataset,
|
||||||
|
SubclipConfig,
|
||||||
|
TrainExample,
|
||||||
|
)
|
||||||
|
from batdetect2.train.labels import LabelConfig, load_label_config
|
||||||
|
from batdetect2.train.preprocess import preprocess_annotations
|
||||||
|
from batdetect2.train.targets import TargetConfig, load_target_config
|
||||||
|
from batdetect2.train.train import TrainerConfig, load_trainer_config, train
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"AugmentationsConfig",
|
||||||
|
"LabelConfig",
|
||||||
|
"LabeledDataset",
|
||||||
|
"SubclipConfig",
|
||||||
|
"TargetConfig",
|
||||||
|
"TrainExample",
|
||||||
|
"TrainerConfig",
|
||||||
|
"TrainingConfig",
|
||||||
|
"add_echo",
|
||||||
|
"augment_example",
|
||||||
|
"load_agumentation_config",
|
||||||
|
"load_label_config",
|
||||||
|
"load_target_config",
|
||||||
|
"load_train_config",
|
||||||
|
"load_trainer_config",
|
||||||
|
"mask_frequency",
|
||||||
|
"mask_time",
|
||||||
|
"mix_examples",
|
||||||
|
"preprocess_annotations",
|
||||||
|
"scale_volume",
|
||||||
|
"select_subclip",
|
||||||
|
"train",
|
||||||
|
"warp_spectrogram",
|
||||||
|
]
|
@ -3,16 +3,30 @@ from typing import Callable, Optional, Union
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent import arrays
|
from soundevent import arrays, data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
|
from batdetect2.preprocess import PreprocessingConfig, compute_spectrogram
|
||||||
from batdetect2.preprocess.arrays import adjust_width
|
from batdetect2.preprocess.arrays import adjust_width
|
||||||
|
|
||||||
Augmentation = Callable[[xr.Dataset], xr.Dataset]
|
Augmentation = Callable[[xr.Dataset], xr.Dataset]
|
||||||
|
|
||||||
|
|
||||||
class AugmentationConfig(BaseConfig):
|
__all__ = [
|
||||||
|
"AugmentationsConfig",
|
||||||
|
"load_agumentation_config",
|
||||||
|
"select_subclip",
|
||||||
|
"mix_examples",
|
||||||
|
"add_echo",
|
||||||
|
"scale_volume",
|
||||||
|
"warp_spectrogram",
|
||||||
|
"mask_time",
|
||||||
|
"mask_frequency",
|
||||||
|
"augment_example",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAugmentationConfig(BaseConfig):
|
||||||
enable: bool = True
|
enable: bool = True
|
||||||
probability: float = 0.2
|
probability: float = 0.2
|
||||||
|
|
||||||
@ -63,7 +77,7 @@ def select_subclip(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MixAugmentationConfig(AugmentationConfig):
|
class MixAugmentationConfig(BaseAugmentationConfig):
|
||||||
min_weight: float = 0.3
|
min_weight: float = 0.3
|
||||||
max_weight: float = 0.7
|
max_weight: float = 0.7
|
||||||
|
|
||||||
@ -133,7 +147,7 @@ def mix_examples(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class EchoAugmentationConfig(AugmentationConfig):
|
class EchoAugmentationConfig(BaseAugmentationConfig):
|
||||||
max_delay: float = 0.005
|
max_delay: float = 0.005
|
||||||
min_weight: float = 0.0
|
min_weight: float = 0.0
|
||||||
max_weight: float = 1.0
|
max_weight: float = 1.0
|
||||||
@ -188,7 +202,7 @@ def add_echo(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class VolumeAugmentationConfig(AugmentationConfig):
|
class VolumeAugmentationConfig(BaseAugmentationConfig):
|
||||||
min_scaling: float = 0.0
|
min_scaling: float = 0.0
|
||||||
max_scaling: float = 2.0
|
max_scaling: float = 2.0
|
||||||
|
|
||||||
@ -206,7 +220,7 @@ def scale_volume(
|
|||||||
return example.assign(spectrogram=example["spectrogram"] * factor)
|
return example.assign(spectrogram=example["spectrogram"] * factor)
|
||||||
|
|
||||||
|
|
||||||
class WarpAugmentationConfig(AugmentationConfig):
|
class WarpAugmentationConfig(BaseAugmentationConfig):
|
||||||
delta: float = 0.04
|
delta: float = 0.04
|
||||||
|
|
||||||
|
|
||||||
@ -294,7 +308,7 @@ def mask_axis(
|
|||||||
return array.where(condition, other=mask_value)
|
return array.where(condition, other=mask_value)
|
||||||
|
|
||||||
|
|
||||||
class TimeMaskAugmentationConfig(AugmentationConfig):
|
class TimeMaskAugmentationConfig(BaseAugmentationConfig):
|
||||||
max_perc: float = 0.05
|
max_perc: float = 0.05
|
||||||
max_masks: int = 3
|
max_masks: int = 3
|
||||||
|
|
||||||
@ -318,7 +332,7 @@ def mask_time(
|
|||||||
return example.assign(spectrogram=spectrogram)
|
return example.assign(spectrogram=spectrogram)
|
||||||
|
|
||||||
|
|
||||||
class FrequencyMaskAugmentationConfig(AugmentationConfig):
|
class FrequencyMaskAugmentationConfig(BaseAugmentationConfig):
|
||||||
max_perc: float = 0.10
|
max_perc: float = 0.10
|
||||||
max_masks: int = 3
|
max_masks: int = 3
|
||||||
|
|
||||||
@ -361,7 +375,13 @@ class AugmentationsConfig(BaseConfig):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def should_apply(config: AugmentationConfig) -> bool:
|
def load_agumentation_config(
|
||||||
|
path: data.PathLike, field: Optional[str] = None
|
||||||
|
) -> AugmentationsConfig:
|
||||||
|
return load_config(path, schema=AugmentationsConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def should_apply(config: BaseAugmentationConfig) -> bool:
|
||||||
if not config.enable:
|
if not config.enable:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -43,28 +43,23 @@ class SubclipConfig(BaseConfig):
|
|||||||
|
|
||||||
class DatasetConfig(BaseConfig):
|
class DatasetConfig(BaseConfig):
|
||||||
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
|
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
|
||||||
preprocessing: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
augmentation: AugmentationsConfig = Field(
|
augmentation: AugmentationsConfig = Field(
|
||||||
default_factory=AugmentationsConfig
|
default_factory=AugmentationsConfig
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LabeledDataset(Dataset):
|
class LabeledDataset(Dataset):
|
||||||
config: DatasetConfig
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
filenames: Sequence[PathLike],
|
filenames: Sequence[PathLike],
|
||||||
augment: bool = False,
|
subclip: Optional[SubclipConfig] = None,
|
||||||
subclip: bool = False,
|
augmentation: Optional[AugmentationsConfig] = None,
|
||||||
config: Optional[DatasetConfig] = None,
|
preprocessing: Optional[PreprocessingConfig] = None,
|
||||||
):
|
):
|
||||||
self.filenames = filenames
|
self.filenames = filenames
|
||||||
self.augment = augment
|
|
||||||
self.subclip = subclip
|
self.subclip = subclip
|
||||||
self.config = config or DatasetConfig()
|
self.augmentation = augmentation
|
||||||
|
self.preprocessing = preprocessing or PreprocessingConfig()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.filenames)
|
return len(self.filenames)
|
||||||
@ -75,16 +70,16 @@ class LabeledDataset(Dataset):
|
|||||||
if self.subclip:
|
if self.subclip:
|
||||||
dataset = select_subclip(
|
dataset = select_subclip(
|
||||||
dataset,
|
dataset,
|
||||||
duration=self.config.subclip.duration,
|
duration=self.subclip.duration,
|
||||||
width=self.config.subclip.width,
|
width=self.subclip.width,
|
||||||
random=self.config.subclip.random,
|
random=self.subclip.random,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.augment:
|
if self.augmentation:
|
||||||
dataset = augment_example(
|
dataset = augment_example(
|
||||||
dataset,
|
dataset,
|
||||||
self.config.augmentation,
|
self.augmentation,
|
||||||
preprocessing_config=self.config.preprocessing,
|
preprocessing_config=self.preprocessing,
|
||||||
others=self.get_random_example,
|
others=self.get_random_example,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -101,15 +96,15 @@ class LabeledDataset(Dataset):
|
|||||||
cls,
|
cls,
|
||||||
directory: PathLike,
|
directory: PathLike,
|
||||||
extension: str = ".nc",
|
extension: str = ".nc",
|
||||||
config: Optional[DatasetConfig] = None,
|
subclip: Optional[SubclipConfig] = None,
|
||||||
augment: bool = False,
|
augmentation: Optional[AugmentationsConfig] = None,
|
||||||
subclip: bool = False,
|
preprocessing: Optional[PreprocessingConfig] = None,
|
||||||
):
|
):
|
||||||
return cls(
|
return cls(
|
||||||
get_files(directory, extension),
|
get_files(directory, extension),
|
||||||
config=config,
|
|
||||||
augment=augment,
|
|
||||||
subclip=subclip,
|
subclip=subclip,
|
||||||
|
augmentation=augmentation,
|
||||||
|
preprocessing=preprocessing,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_random_example(self) -> xr.Dataset:
|
def get_random_example(self) -> xr.Dataset:
|
||||||
@ -119,9 +114,9 @@ class LabeledDataset(Dataset):
|
|||||||
if self.subclip:
|
if self.subclip:
|
||||||
dataset = select_subclip(
|
dataset = select_subclip(
|
||||||
dataset,
|
dataset,
|
||||||
duration=self.config.subclip.duration,
|
duration=self.subclip.duration,
|
||||||
width=self.config.subclip.width,
|
width=self.subclip.width,
|
||||||
random=self.config.subclip.random,
|
random=self.subclip.random,
|
||||||
)
|
)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
@ -144,7 +139,7 @@ class LabeledDataset(Dataset):
|
|||||||
if not self.subclip:
|
if not self.subclip:
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
width = self.config.subclip.width
|
width = self.subclip.width
|
||||||
return adjust_width(tensor, width)
|
return adjust_width(tensor, width)
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,13 +3,19 @@ from typing import Callable, List, Optional, Sequence, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
|
from pydantic import Field
|
||||||
from scipy.ndimage import gaussian_filter
|
from scipy.ndimage import gaussian_filter
|
||||||
from soundevent import arrays, data, geometry
|
from soundevent import arrays, data, geometry
|
||||||
from soundevent.geometry.operations import Positions
|
from soundevent.geometry.operations import Positions
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
|
||||||
__all__ = ["generate_heatmaps"]
|
__all__ = [
|
||||||
|
"HeatmapsConfig",
|
||||||
|
"LabelConfig",
|
||||||
|
"generate_heatmaps",
|
||||||
|
"load_label_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class HeatmapsConfig(BaseConfig):
|
class HeatmapsConfig(BaseConfig):
|
||||||
@ -19,6 +25,10 @@ class HeatmapsConfig(BaseConfig):
|
|||||||
frequency_scale: float = 1 / 859.375
|
frequency_scale: float = 1 / 859.375
|
||||||
|
|
||||||
|
|
||||||
|
class LabelConfig(BaseConfig):
|
||||||
|
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
||||||
|
|
||||||
|
|
||||||
def generate_heatmaps(
|
def generate_heatmaps(
|
||||||
sound_events: Sequence[data.SoundEventAnnotation],
|
sound_events: Sequence[data.SoundEventAnnotation],
|
||||||
spec: xr.DataArray,
|
spec: xr.DataArray,
|
||||||
@ -132,3 +142,9 @@ def generate_heatmaps(
|
|||||||
).fillna(0.0)
|
).fillna(0.0)
|
||||||
|
|
||||||
return detection_heatmap, class_heatmap, size_heatmap
|
return detection_heatmap, class_heatmap, size_heatmap
|
||||||
|
|
||||||
|
|
||||||
|
def load_label_config(
|
||||||
|
path: data.PathLike, field: Optional[str] = None
|
||||||
|
) -> LabelConfig:
|
||||||
|
return load_config(path, schema=LabelConfig, field=field)
|
||||||
|
@ -6,9 +6,15 @@ from pydantic import Field
|
|||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.models.typing import ModelOutput
|
from batdetect2.models.typing import ModelOutput
|
||||||
from batdetect2.plot import detection
|
|
||||||
from batdetect2.train.dataset import TrainExample
|
from batdetect2.train.dataset import TrainExample
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"bbox_size_loss",
|
||||||
|
"compute_loss",
|
||||||
|
"focal_loss",
|
||||||
|
"mse_loss",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class SizeLossConfig(BaseConfig):
|
class SizeLossConfig(BaseConfig):
|
||||||
weight: float = 0.1
|
weight: float = 0.1
|
||||||
|
@ -17,7 +17,7 @@ from batdetect2.preprocess import (
|
|||||||
compute_spectrogram,
|
compute_spectrogram,
|
||||||
load_clip_audio,
|
load_clip_audio,
|
||||||
)
|
)
|
||||||
from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
|
from batdetect2.train.labels import LabelConfig, generate_heatmaps
|
||||||
from batdetect2.train.targets import (
|
from batdetect2.train.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_encoder,
|
build_encoder,
|
||||||
@ -30,6 +30,9 @@ FilenameFn = Callable[[data.ClipAnnotation], str]
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"preprocess_annotations",
|
"preprocess_annotations",
|
||||||
|
"preprocess_single_annotation",
|
||||||
|
"generate_train_example",
|
||||||
|
"TrainPreprocessingConfig",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -38,15 +41,21 @@ class TrainPreprocessingConfig(BaseConfig):
|
|||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
target: TargetConfig = Field(default_factory=TargetConfig)
|
target: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
|
||||||
|
|
||||||
def generate_train_example(
|
def generate_train_example(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
config: Optional[TrainPreprocessingConfig] = None,
|
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||||
|
target_config: Optional[TargetConfig] = None,
|
||||||
|
label_config: Optional[LabelConfig] = None,
|
||||||
) -> xr.Dataset:
|
) -> xr.Dataset:
|
||||||
"""Generate a training example."""
|
"""Generate a training example."""
|
||||||
config = config or TrainPreprocessingConfig()
|
config = TrainPreprocessingConfig(
|
||||||
|
preprocessing=preprocessing_config or PreprocessingConfig(),
|
||||||
|
target=target_config or TargetConfig(),
|
||||||
|
labels=label_config or LabelConfig(),
|
||||||
|
)
|
||||||
|
|
||||||
wave = load_clip_audio(
|
wave = load_clip_audio(
|
||||||
clip_annotation.clip,
|
clip_annotation.clip,
|
||||||
@ -78,10 +87,10 @@ def generate_train_example(
|
|||||||
spectrogram,
|
spectrogram,
|
||||||
class_names,
|
class_names,
|
||||||
encoder,
|
encoder,
|
||||||
target_sigma=config.heatmaps.sigma,
|
target_sigma=config.labels.heatmaps.sigma,
|
||||||
position=config.heatmaps.position,
|
position=config.labels.heatmaps.position,
|
||||||
time_scale=config.heatmaps.time_scale,
|
time_scale=config.labels.heatmaps.time_scale,
|
||||||
frequency_scale=config.heatmaps.frequency_scale,
|
frequency_scale=config.labels.heatmaps.frequency_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = xr.Dataset(
|
dataset = xr.Dataset(
|
||||||
@ -133,14 +142,14 @@ def preprocess_annotations(
|
|||||||
output_dir: PathLike,
|
output_dir: PathLike,
|
||||||
filename_fn: FilenameFn = _get_filename,
|
filename_fn: FilenameFn = _get_filename,
|
||||||
replace: bool = False,
|
replace: bool = False,
|
||||||
config: Optional[TrainPreprocessingConfig] = None,
|
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||||
|
target_config: Optional[TargetConfig] = None,
|
||||||
|
label_config: Optional[LabelConfig] = None,
|
||||||
max_workers: Optional[int] = None,
|
max_workers: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Preprocess annotations and save to disk."""
|
"""Preprocess annotations and save to disk."""
|
||||||
output_dir = Path(output_dir)
|
output_dir = Path(output_dir)
|
||||||
|
|
||||||
config = config or TrainPreprocessingConfig()
|
|
||||||
|
|
||||||
if not output_dir.is_dir():
|
if not output_dir.is_dir():
|
||||||
output_dir.mkdir(parents=True)
|
output_dir.mkdir(parents=True)
|
||||||
|
|
||||||
@ -151,9 +160,11 @@ def preprocess_annotations(
|
|||||||
partial(
|
partial(
|
||||||
preprocess_single_annotation,
|
preprocess_single_annotation,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
config=config,
|
|
||||||
filename_fn=filename_fn,
|
filename_fn=filename_fn,
|
||||||
replace=replace,
|
replace=replace,
|
||||||
|
preprocessing_config=preprocessing_config,
|
||||||
|
target_config=target_config,
|
||||||
|
label_config=label_config,
|
||||||
),
|
),
|
||||||
clip_annotations,
|
clip_annotations,
|
||||||
),
|
),
|
||||||
@ -165,7 +176,9 @@ def preprocess_annotations(
|
|||||||
def preprocess_single_annotation(
|
def preprocess_single_annotation(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
output_dir: PathLike,
|
output_dir: PathLike,
|
||||||
config: TrainPreprocessingConfig,
|
preprocessing_config: Optional[PreprocessingConfig] = None,
|
||||||
|
target_config: Optional[TargetConfig] = None,
|
||||||
|
label_config: Optional[LabelConfig] = None,
|
||||||
filename_fn: FilenameFn = _get_filename,
|
filename_fn: FilenameFn = _get_filename,
|
||||||
replace: bool = False,
|
replace: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -181,7 +194,12 @@ def preprocess_single_annotation(
|
|||||||
path.unlink()
|
path.unlink()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
sample = generate_train_example(clip_annotation, config=config)
|
sample = generate_train_example(
|
||||||
|
clip_annotation,
|
||||||
|
preprocessing_config=preprocessing_config,
|
||||||
|
target_config=target_config,
|
||||||
|
label_config=label_config,
|
||||||
|
)
|
||||||
except Exception as error:
|
except Exception as error:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Failed to process annotation: {clip_annotation.uuid}"
|
f"Failed to process annotation: {clip_annotation.uuid}"
|
||||||
|
@ -9,6 +9,14 @@ from soundevent import data
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.terms import TagInfo, get_tag_from_info
|
from batdetect2.terms import TagInfo, get_tag_from_info
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TargetConfig",
|
||||||
|
"load_target_config",
|
||||||
|
"build_encoder",
|
||||||
|
"build_decoder",
|
||||||
|
"filter_sound_event",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class ReplaceConfig(BaseConfig):
|
class ReplaceConfig(BaseConfig):
|
||||||
"""Configuration for replacing tags."""
|
"""Configuration for replacing tags."""
|
||||||
|
@ -1,82 +1,68 @@
|
|||||||
from typing import Callable, NamedTuple, Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
import torch
|
from lightning import LightningModule
|
||||||
from soundevent import data
|
from lightning.pytorch import Trainer
|
||||||
from torch.optim import Adam
|
from soundevent.data import PathLike
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.data.datasets import ClipAnnotationDataset
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.models.typing import DetectionModel
|
from batdetect2.train.dataset import LabeledDataset
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"train",
|
||||||
|
"TrainerConfig",
|
||||||
|
"load_trainer_config",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TrainInputs(NamedTuple):
|
class TrainerConfig(BaseConfig):
|
||||||
spec: torch.Tensor
|
accelerator: str = "auto"
|
||||||
detection_heatmap: torch.Tensor
|
accumulate_grad_batches: int = 1
|
||||||
class_heatmap: torch.Tensor
|
deterministic: bool = True
|
||||||
size_heatmap: torch.Tensor
|
check_val_every_n_epoch: int = 1
|
||||||
|
devices: Union[str, int] = "auto"
|
||||||
|
enable_checkpointing: bool = True
|
||||||
|
gradient_clip_val: Optional[float] = None
|
||||||
|
limit_train_batches: Optional[Union[int, float]] = None
|
||||||
|
limit_test_batches: Optional[Union[int, float]] = None
|
||||||
|
limit_val_batches: Optional[Union[int, float]] = None
|
||||||
|
log_every_n_steps: Optional[int] = None
|
||||||
|
max_epochs: Optional[int] = None
|
||||||
|
min_epochs: Optional[int] = 100
|
||||||
|
max_steps: Optional[int] = None
|
||||||
|
min_steps: Optional[int] = None
|
||||||
|
max_time: Optional[str] = None
|
||||||
|
precision: Optional[str] = None
|
||||||
|
reload_dataloaders_every_n_epochs: Optional[int] = None
|
||||||
|
val_check_interval: Optional[Union[int, float]] = None
|
||||||
|
|
||||||
|
|
||||||
def train_loop(
|
def load_trainer_config(path: PathLike, field: Optional[str] = None):
|
||||||
model: DetectionModel,
|
return load_config(path, schema=TrainerConfig, field=field)
|
||||||
train_dataset: ClipAnnotationDataset[TrainInputs],
|
|
||||||
validation_dataset: ClipAnnotationDataset[TrainInputs],
|
|
||||||
device: Optional[torch.device] = None,
|
def train(
|
||||||
num_epochs: int = 100,
|
module: LightningModule,
|
||||||
learning_rate: float = 1e-4,
|
train_dataset: LabeledDataset,
|
||||||
|
trainer_config: Optional[TrainerConfig] = None,
|
||||||
|
dev_run: bool = False,
|
||||||
|
overfit_batches: bool = False,
|
||||||
|
profiler: Optional[str] = None,
|
||||||
):
|
):
|
||||||
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
|
trainer_config = trainer_config or TrainerConfig()
|
||||||
validation_loader = DataLoader(validation_dataset, batch_size=32)
|
trainer = Trainer(
|
||||||
|
**trainer_config.model_dump(
|
||||||
model.to(device)
|
exclude_unset=True,
|
||||||
|
exclude_none=True,
|
||||||
optimizer = Adam(model.parameters(), lr=learning_rate)
|
),
|
||||||
scheduler = CosineAnnealingLR(
|
fast_dev_run=dev_run,
|
||||||
optimizer,
|
overfit_batches=overfit_batches,
|
||||||
num_epochs * len(train_loader),
|
profiler=profiler,
|
||||||
)
|
)
|
||||||
|
train_loader = DataLoader(
|
||||||
for epoch in range(num_epochs):
|
train_dataset,
|
||||||
train_loss = train_single_epoch(
|
batch_size=module.config.train.batch_size,
|
||||||
model,
|
shuffle=True,
|
||||||
train_loader,
|
num_workers=7,
|
||||||
optimizer,
|
)
|
||||||
device,
|
trainer.fit(module, train_dataloaders=train_loader)
|
||||||
scheduler,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def train_single_epoch(
|
|
||||||
model: DetectionModel,
|
|
||||||
train_loader: DataLoader,
|
|
||||||
optimizer: Adam,
|
|
||||||
device: torch.device,
|
|
||||||
scheduler: CosineAnnealingLR,
|
|
||||||
):
|
|
||||||
model.train()
|
|
||||||
train_loss = tu.AverageMeter()
|
|
||||||
|
|
||||||
for batch in train_loader:
|
|
||||||
optimizer.zero_grad()
|
|
||||||
|
|
||||||
spec = batch.spec.to(device)
|
|
||||||
detection_heatmap = batch.detection_heatmap.to(device)
|
|
||||||
class_heatmap = batch.class_heatmap.to(device)
|
|
||||||
size_heatmap = batch.size_heatmap.to(device)
|
|
||||||
|
|
||||||
outputs = model(spec)
|
|
||||||
|
|
||||||
loss = loss_fun(
|
|
||||||
outputs,
|
|
||||||
gt_det,
|
|
||||||
gt_size,
|
|
||||||
gt_class,
|
|
||||||
det_criterion,
|
|
||||||
params,
|
|
||||||
class_inv_freq,
|
|
||||||
)
|
|
||||||
|
|
||||||
train_loss.update(loss.item(), data.shape[0])
|
|
||||||
loss.backward()
|
|
||||||
optimizer.step()
|
|
||||||
scheduler.step()
|
|
||||||
|
Loading…
Reference in New Issue
Block a user