Mirror of https://github.com/macaodha/batdetect2.git (synced 2025-06-29 14:41:58 +02:00)

Commit 899f74efd5: Wrote the main train function
Parent: 4b4c3ecdf5
batdetect2/train/clips.py

@@ -69,12 +69,15 @@ class Clipper(ClipperProtocol):
         )


-def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
+def build_clipper(
+    config: Optional[ClipingConfig] = None,
+    random: Optional[bool] = None,
+) -> ClipperProtocol:
     config = config or ClipingConfig()
     return Clipper(
         duration=config.duration,
         max_empty=config.max_empty,
-        random=config.random,
+        random=config.random if random else False,
     )
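For context, a minimal usage sketch of the new factory signature (build_clipper's location follows from the train.py imports below; everything else is from the diff):

from batdetect2.train.clips import build_clipper

# Training clipper: random=True lets the configured config.random take effect.
train_clipper = build_clipper(random=True)

# Validation clipper: random=None/False forces deterministic clips
# regardless of what the config says.
val_clipper = build_clipper()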
batdetect2/train/config.py

@@ -1,7 +1,7 @@
-from typing import Optional
+from typing import Optional, Union

 from pydantic import Field
-from soundevent.data import PathLike
+from soundevent import data

 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.train.augmentations import (
@@ -23,8 +23,29 @@ class OptimizerConfig(BaseConfig):
     t_max: int = 100


+class TrainerConfig(BaseConfig):
+    accelerator: str = "auto"
+    accumulate_grad_batches: int = 1
+    deterministic: bool = True
+    check_val_every_n_epoch: int = 1
+    devices: Union[str, int] = "auto"
+    enable_checkpointing: bool = True
+    gradient_clip_val: Optional[float] = None
+    limit_train_batches: Optional[Union[int, float]] = None
+    limit_test_batches: Optional[Union[int, float]] = None
+    limit_val_batches: Optional[Union[int, float]] = None
+    log_every_n_steps: Optional[int] = None
+    max_epochs: Optional[int] = 200
+    min_epochs: Optional[int] = None
+    max_steps: Optional[int] = None
+    min_steps: Optional[int] = None
+    max_time: Optional[str] = None
+    precision: Optional[str] = None
+    val_check_interval: Optional[Union[int, float]] = None
+
+
 class TrainingConfig(BaseConfig):
-    batch_size: int = 32
+    batch_size: int = 8

     loss: LossConfig = Field(default_factory=LossConfig)
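A short sketch of overriding these defaults in code, using the trainer field added in the next hunk (field names are from the diff; the precision value is just one of the strings Lightning accepts):

from batdetect2.train.config import TrainerConfig, TrainingConfig

config = TrainingConfig(
    batch_size=8,
    trainer=TrainerConfig(
        devices=1,
        max_epochs=50,
        precision="16-mixed",
    ),
)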
@@ -36,9 +57,11 @@ class TrainingConfig(BaseConfig):

     cliping: ClipingConfig = Field(default_factory=ClipingConfig)

+    trainer: TrainerConfig = Field(default_factory=TrainerConfig)
+

 def load_train_config(
-    path: PathLike,
+    path: data.PathLike,
     field: Optional[str] = None,
 ) -> TrainingConfig:
     return load_config(path, schema=TrainingConfig, field=field)
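Loading the same settings from a file; load_train_config and its field argument come from the diff, while the idea that field selects a named sub-section of the config file is an assumption:

from batdetect2.train.config import load_train_config

# Assumed: a config file whose "train" section matches TrainingConfig.
config = load_train_config("config.yaml", field="train")
print(config.trainer.max_epochs)  # 200 unless overridden in the file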
batdetect2/train/dataset.py

@@ -81,7 +81,10 @@ class LabeledDataset(Dataset):
         array: xr.DataArray,
         dtype=np.float32,
     ) -> torch.Tensor:
-        return torch.tensor(array.values.astype(dtype))
+        return torch.nan_to_num(
+            torch.tensor(array.values.astype(dtype)),
+            nan=0,
+        )


 def list_preprocessed_files(
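The effect of the added torch.nan_to_num guard, shown in isolation (standard PyTorch and NumPy only):

import numpy as np
import torch

# NaNs in a preprocessed array become 0.0 instead of propagating into the loss.
arr = np.array([1.0, np.nan, 2.0], dtype=np.float32)
print(torch.nan_to_num(torch.tensor(arr), nan=0))  # tensor([1., 0., 2.])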
@@ -91,7 +94,11 @@ def list_preprocessed_files(


 class RandomExampleSource:
-    def __init__(self, filenames: List[str], clipper: ClipperProtocol):
+    def __init__(
+        self,
+        filenames: List[data.PathLike],
+        clipper: ClipperProtocol,
+    ):
         self.filenames = filenames
         self.clipper = clipper
batdetect2/train/lightning.py

@@ -40,7 +40,9 @@ class TrainingModule(L.LightningModule):
         self.learning_rate = learning_rate
         self.t_max = t_max

-        self.save_hyperparameters()
+        # NOTE: Exclude detector and loss from hyperparameter saving,
+        # as they are nn.Modules and are saved with the weights regardless.
+        self.save_hyperparameters(ignore=["detector", "loss"])

     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.detector(spec)
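Why the ignore list matters, as a standalone Lightning sketch (the class and argument names here are illustrative, not from the repo):

import lightning as L
from torch import nn

class Demo(L.LightningModule):
    def __init__(self, detector: nn.Module, loss: nn.Module, learning_rate: float = 1e-3):
        super().__init__()
        self.detector = detector
        self.loss = loss
        # Only simple values (learning_rate) are recorded in self.hparams and
        # checkpoint metadata; the nn.Modules are persisted via the state_dict.
        self.save_hyperparameters(ignore=["detector", "loss"])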
batdetect2/train/train.py

@@ -1,68 +1,112 @@
-from typing import Optional, Union
+from typing import List, Optional

-from lightning import LightningModule
-from lightning.pytorch import Trainer
-from soundevent.data import PathLike
+from lightning import Trainer
+from soundevent import data
 from torch.utils.data import DataLoader

-from batdetect2.configs import BaseConfig, load_config
-from batdetect2.train.dataset import LabeledDataset
+from batdetect2.models.types import DetectionModel
+from batdetect2.postprocess.types import PostprocessorProtocol
+from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets.types import TargetProtocol
+from batdetect2.train.augmentations import (
+    build_augmentations,
+)
+from batdetect2.train.clips import build_clipper
+from batdetect2.train.config import TrainingConfig
+from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
+from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.losses import build_loss

 __all__ = [
     "train",
-    "TrainerConfig",
-    "load_trainer_config",
 ]


-class TrainerConfig(BaseConfig):
-    accelerator: str = "auto"
-    accumulate_grad_batches: int = 1
-    deterministic: bool = True
-    check_val_every_n_epoch: int = 1
-    devices: Union[str, int] = "auto"
-    enable_checkpointing: bool = True
-    gradient_clip_val: Optional[float] = None
-    limit_train_batches: Optional[Union[int, float]] = None
-    limit_test_batches: Optional[Union[int, float]] = None
-    limit_val_batches: Optional[Union[int, float]] = None
-    log_every_n_steps: Optional[int] = None
-    max_epochs: Optional[int] = None
-    min_epochs: Optional[int] = 100
-    max_steps: Optional[int] = None
-    min_steps: Optional[int] = None
-    max_time: Optional[str] = None
-    precision: Optional[str] = None
-    reload_dataloaders_every_n_epochs: Optional[int] = None
-    val_check_interval: Optional[Union[int, float]] = None
-
-
-def load_trainer_config(path: PathLike, field: Optional[str] = None):
-    return load_config(path, schema=TrainerConfig, field=field)
-
-
 def train(
-    module: LightningModule,
-    train_dataset: LabeledDataset,
-    trainer_config: Optional[TrainerConfig] = None,
-    dev_run: bool = False,
-    overfit_batches: bool = False,
-    profiler: Optional[str] = None,
-):
-    trainer_config = trainer_config or TrainerConfig()
-    trainer = Trainer(
-        **trainer_config.model_dump(
-            exclude_unset=True,
-            exclude_none=True,
-        ),
-        fast_dev_run=dev_run,
-        overfit_batches=overfit_batches,
-        profiler=profiler,
-    )
-    train_loader = DataLoader(
+    detector: DetectionModel,
+    targets: TargetProtocol,
+    preprocessor: PreprocessorProtocol,
+    postprocessor: PostprocessorProtocol,
+    train_examples: List[data.PathLike],
+    val_examples: Optional[List[data.PathLike]] = None,
+    config: Optional[TrainingConfig] = None,
+) -> None:
+    config = config or TrainingConfig()
+
+    train_dataset = build_dataset(
+        train_examples,
+        preprocessor,
+        config=config,
+        train=True,
+    )
+
+    loss = build_loss(config.loss)
+
+    module = TrainingModule(
+        detector=detector,
+        loss=loss,
+        targets=targets,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        learning_rate=config.optimizer.learning_rate,
+        t_max=config.optimizer.t_max,
+    )
+
+    trainer = Trainer(**config.trainer.model_dump())
+
+    train_dataloader = DataLoader(
         train_dataset,
-        batch_size=module.config.train.batch_size,
+        batch_size=config.batch_size,
         shuffle=True,
         num_workers=7,
     )
-    trainer.fit(module, train_dataloaders=train_loader)
+
+    val_dataloader = None
+    if val_examples:
+        val_dataset = build_dataset(
+            val_examples,
+            preprocessor,
+            config=config,
+            train=False,
+        )
+        val_dataloader = DataLoader(
+            val_dataset,
+            batch_size=config.batch_size,
+            shuffle=False,
+        )
+
+    trainer.fit(
+        module,
+        train_dataloaders=train_dataloader,
+        val_dataloaders=val_dataloader,
+    )
+
+
+def build_dataset(
+    examples: List[data.PathLike],
+    preprocessor: PreprocessorProtocol,
+    config: Optional[TrainingConfig] = None,
+    train: bool = True,
+):
+    config = config or TrainingConfig()
+
+    clipper = build_clipper(config.cliping, random=train)
+
+    augmentations = None
+
+    if train:
+        random_example_source = RandomExampleSource(
+            examples,
+            clipper=clipper,
+        )
+
+        augmentations = build_augmentations(
+            preprocessor,
+            config=config.augmentations,
+            example_source=random_example_source,
+        )
+
+    return LabeledDataset(
+        examples,
+        clipper=clipper,
+        augmentation=augmentations,
+    )
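Putting the pieces together, a sketch of the new entry point. The four model-side objects are placeholders (how they are built is outside this commit), and the module path and the .nc file extension are assumptions; train, TrainingConfig, and the argument names come from the diff:

from pathlib import Path

from batdetect2.train.config import TrainingConfig
from batdetect2.train.train import train  # module path assumed

# Placeholders: constructing these is not part of this commit.
detector = ...        # DetectionModel
targets = ...         # TargetProtocol
preprocessor = ...    # PreprocessorProtocol
postprocessor = ...   # PostprocessorProtocol

train(
    detector,
    targets,
    preprocessor,
    postprocessor,
    train_examples=sorted(Path("train/").glob("*.nc")),  # extension assumed
    val_examples=sorted(Path("val/").glob("*.nc")),
    config=TrainingConfig(batch_size=8),
)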