Refine training config

This commit is contained in:
mbsantiago 2025-04-23 23:35:42 +01:00
parent d51e3f8bbd
commit 7dd35d6e3e
3 changed files with 17 additions and 22 deletions

View File

@@ -11,7 +11,7 @@ DEFAULT_TRAIN_CLIP_DURATION = 0.513
DEFAULT_MAX_EMPTY_CLIP = 0.1 DEFAULT_MAX_EMPTY_CLIP = 0.1
class ClipperConfig(BaseConfig): class ClipingConfig(BaseConfig):
duration: float = DEFAULT_TRAIN_CLIP_DURATION duration: float = DEFAULT_TRAIN_CLIP_DURATION
random: bool = True random: bool = True
max_empty: float = DEFAULT_MAX_EMPTY_CLIP max_empty: float = DEFAULT_MAX_EMPTY_CLIP
@@ -69,8 +69,8 @@ class Clipper(ClipperProtocol):
) )
def build_clipper(config: Optional[ClipperConfig] = None) -> ClipperProtocol: def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
config = config or ClipperConfig() config = config or ClipingConfig()
return Clipper( return Clipper(
duration=config.duration, duration=config.duration,
max_empty=config.max_empty, max_empty=config.max_empty,

View File

@@ -4,6 +4,11 @@ from pydantic import Field
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
)
from batdetect2.train.clips import ClipingConfig
from batdetect2.train.losses import LossConfig from batdetect2.train.losses import LossConfig
__all__ = [ __all__ = [
@@ -20,9 +25,17 @@ class OptimizerConfig(BaseConfig):
class TrainingConfig(BaseConfig): class TrainingConfig(BaseConfig):
batch_size: int = 32 batch_size: int = 32
loss: LossConfig = Field(default_factory=LossConfig) loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
def load_train_config( def load_train_config(
path: PathLike, path: PathLike,

View File

@@ -4,15 +4,10 @@ from typing import List, Optional, Sequence, Tuple
import numpy as np import numpy as np
import torch import torch
import xarray as xr import xarray as xr
from pydantic import Field
from soundevent import data from soundevent import data
from torch.utils.data import Dataset from torch.utils.data import Dataset
from batdetect2.configs import BaseConfig from batdetect2.train.augmentations import Augmentation
from batdetect2.train.augmentations import (
Augmentation,
AugmentationsConfig,
)
from batdetect2.train.types import ClipperProtocol, TrainExample from batdetect2.train.types import ClipperProtocol, TrainExample
__all__ = [ __all__ = [
@@ -20,19 +15,6 @@ __all__ = [
] ]
class SubclipConfig(BaseConfig):
duration: Optional[float] = None
width: int = 512
random: bool = False
class DatasetConfig(BaseConfig):
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
augmentation: AugmentationsConfig = Field(
default_factory=AugmentationsConfig
)
class LabeledDataset(Dataset): class LabeledDataset(Dataset):
def __init__( def __init__(
self, self,