mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Refine training config
This commit is contained in:
parent
d51e3f8bbd
commit
7dd35d6e3e
@ -11,7 +11,7 @@ DEFAULT_TRAIN_CLIP_DURATION = 0.513
|
||||
DEFAULT_MAX_EMPTY_CLIP = 0.1
|
||||
|
||||
|
||||
class ClipperConfig(BaseConfig):
|
||||
class ClipingConfig(BaseConfig):
|
||||
duration: float = DEFAULT_TRAIN_CLIP_DURATION
|
||||
random: bool = True
|
||||
max_empty: float = DEFAULT_MAX_EMPTY_CLIP
|
||||
@ -69,8 +69,8 @@ class Clipper(ClipperProtocol):
|
||||
)
|
||||
|
||||
|
||||
def build_clipper(config: Optional[ClipperConfig] = None) -> ClipperProtocol:
|
||||
config = config or ClipperConfig()
|
||||
def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
|
||||
config = config or ClipingConfig()
|
||||
return Clipper(
|
||||
duration=config.duration,
|
||||
max_empty=config.max_empty,
|
||||
|
@ -4,6 +4,11 @@ from pydantic import Field
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.train.augmentations import (
|
||||
DEFAULT_AUGMENTATION_CONFIG,
|
||||
AugmentationsConfig,
|
||||
)
|
||||
from batdetect2.train.clips import ClipingConfig
|
||||
from batdetect2.train.losses import LossConfig
|
||||
|
||||
__all__ = [
|
||||
@ -20,9 +25,17 @@ class OptimizerConfig(BaseConfig):
|
||||
|
||||
class TrainingConfig(BaseConfig):
|
||||
batch_size: int = 32
|
||||
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
|
||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||
|
||||
augmentations: AugmentationsConfig = Field(
|
||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||
)
|
||||
|
||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||
|
||||
|
||||
def load_train_config(
|
||||
path: PathLike,
|
||||
|
@ -4,15 +4,10 @@ from typing import List, Optional, Sequence, Tuple
|
||||
import numpy as np
|
||||
import torch
|
||||
import xarray as xr
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.train.augmentations import (
|
||||
Augmentation,
|
||||
AugmentationsConfig,
|
||||
)
|
||||
from batdetect2.train.augmentations import Augmentation
|
||||
from batdetect2.train.types import ClipperProtocol, TrainExample
|
||||
|
||||
__all__ = [
|
||||
@ -20,19 +15,6 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class SubclipConfig(BaseConfig):
|
||||
duration: Optional[float] = None
|
||||
width: int = 512
|
||||
random: bool = False
|
||||
|
||||
|
||||
class DatasetConfig(BaseConfig):
|
||||
subclip: SubclipConfig = Field(default_factory=SubclipConfig)
|
||||
augmentation: AugmentationsConfig = Field(
|
||||
default_factory=AugmentationsConfig
|
||||
)
|
||||
|
||||
|
||||
class LabeledDataset(Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user