Refine training config

This commit is contained in:
mbsantiago 2025-04-23 23:35:42 +01:00
parent d51e3f8bbd
commit 7dd35d6e3e
3 changed files with 17 additions and 22 deletions

View File

@ -11,7 +11,7 @@ DEFAULT_TRAIN_CLIP_DURATION = 0.513
DEFAULT_MAX_EMPTY_CLIP = 0.1
class ClipperConfig(BaseConfig):
class ClipingConfig(BaseConfig):
duration: float = DEFAULT_TRAIN_CLIP_DURATION
random: bool = True
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
@ -69,8 +69,8 @@ class Clipper(ClipperProtocol):
)
def build_clipper(config: Optional[ClipperConfig] = None) -> ClipperProtocol:
config = config or ClipperConfig()
def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
config = config or ClipingConfig()
return Clipper(
duration=config.duration,
max_empty=config.max_empty,

View File

@ -4,6 +4,11 @@ from pydantic import Field
from soundevent.data import PathLike
from batdetect2.configs import BaseConfig, load_config
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
)
from batdetect2.train.clips import ClipingConfig
from batdetect2.train.losses import LossConfig
__all__ = [
@ -20,9 +25,17 @@ class OptimizerConfig(BaseConfig):
class TrainingConfig(BaseConfig):
batch_size: int = 32
loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
def load_train_config(
path: PathLike,

View File

@ -4,15 +4,10 @@ from typing import List, Optional, Sequence, Tuple
import numpy as np
import torch
import xarray as xr
from pydantic import Field
from soundevent import data
from torch.utils.data import Dataset
from batdetect2.configs import BaseConfig
from batdetect2.train.augmentations import (
Augmentation,
AugmentationsConfig,
)
from batdetect2.train.augmentations import Augmentation
from batdetect2.train.types import ClipperProtocol, TrainExample
__all__ = [
@ -20,19 +15,6 @@ __all__ = [
]
class SubclipConfig(BaseConfig):
duration: Optional[float] = None
width: int = 512
random: bool = False
class DatasetConfig(BaseConfig):
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
augmentation: AugmentationsConfig = Field(
default_factory=AugmentationsConfig
)
class LabeledDataset(Dataset):
def __init__(
self,