diff --git a/batdetect2/train/clips.py b/batdetect2/train/clips.py
index c1d39e8..ec13b8e 100644
--- a/batdetect2/train/clips.py
+++ b/batdetect2/train/clips.py
@@ -11,7 +11,7 @@ DEFAULT_TRAIN_CLIP_DURATION = 0.513
 DEFAULT_MAX_EMPTY_CLIP = 0.1
 
 
-class ClipperConfig(BaseConfig):
+class ClipingConfig(BaseConfig):
     duration: float = DEFAULT_TRAIN_CLIP_DURATION
     random: bool = True
     max_empty: float = DEFAULT_MAX_EMPTY_CLIP
@@ -69,8 +69,8 @@ class Clipper(ClipperProtocol):
         )
 
 
-def build_clipper(config: Optional[ClipperConfig] = None) -> ClipperProtocol:
-    config = config or ClipperConfig()
+def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
+    config = config or ClipingConfig()
     return Clipper(
         duration=config.duration,
         max_empty=config.max_empty,
diff --git a/batdetect2/train/config.py b/batdetect2/train/config.py
index 5663611..e2ab8ec 100644
--- a/batdetect2/train/config.py
+++ b/batdetect2/train/config.py
@@ -4,6 +4,11 @@ from pydantic import Field
 from soundevent.data import PathLike
 
 from batdetect2.configs import BaseConfig, load_config
+from batdetect2.train.augmentations import (
+    DEFAULT_AUGMENTATION_CONFIG,
+    AugmentationsConfig,
+)
+from batdetect2.train.clips import ClipingConfig
 from batdetect2.train.losses import LossConfig
 
 __all__ = [
@@ -20,9 +25,17 @@ class OptimizerConfig(BaseConfig):
 
 class TrainingConfig(BaseConfig):
     batch_size: int = 32
+
     loss: LossConfig = Field(default_factory=LossConfig)
+
     optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
 
+    augmentations: AugmentationsConfig = Field(
+        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
+    )
+
+    cliping: ClipingConfig = Field(default_factory=ClipingConfig)
+
 
 def load_train_config(
     path: PathLike,
diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py
index 4f76051..9c0ea82 100644
--- a/batdetect2/train/dataset.py
+++ b/batdetect2/train/dataset.py
@@ -4,15 +4,10 @@ from typing import List, Optional, Sequence, Tuple
 import numpy as np
 import torch
 import xarray as xr
-from pydantic import Field
 from soundevent import data
 from torch.utils.data import Dataset
 
-from batdetect2.configs import BaseConfig
-from batdetect2.train.augmentations import (
-    Augmentation,
-    AugmentationsConfig,
-)
+from batdetect2.train.augmentations import Augmentation
 from batdetect2.train.types import ClipperProtocol, TrainExample
 
 __all__ = [
@@ -20,19 +15,6 @@ __all__ = [
 ]
 
 
-class SubclipConfig(BaseConfig):
-    duration: Optional[float] = None
-    width: int = 512
-    random: bool = False
-
-
-class DatasetConfig(BaseConfig):
-    subclip: SubclipConfig = Field(default_factory=SubclipConfig)
-    augmentation: AugmentationsConfig = Field(
-        default_factory=AugmentationsConfig
-    )
-
-
 class LabeledDataset(Dataset):
     def __init__(
         self,
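
Note (not part of the diff): a minimal sketch of how the restructured config might be consumed after this change. It uses only the names introduced in the hunks above (TrainingConfig, ClipingConfig, build_clipper, the cliping field); everything else about the surrounding batdetect2 API is assumed rather than verified.

```python
# Sketch only; field and function names are taken from this diff,
# the rest of the batdetect2 API is assumed.
from batdetect2.train.clips import ClipingConfig, build_clipper
from batdetect2.train.config import TrainingConfig

# Clipping and augmentation settings now live on TrainingConfig instead of
# the removed DatasetConfig/SubclipConfig in train/dataset.py.
config = TrainingConfig(
    batch_size=16,
    cliping=ClipingConfig(duration=0.513, random=True, max_empty=0.1),
)

# build_clipper accepts the nested sub-config directly.
clipper = build_clipper(config.cliping)
```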