mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Wrote the main train function
This commit is contained in:
parent
4b4c3ecdf5
commit
899f74efd5
@ -69,12 +69,15 @@ class Clipper(ClipperProtocol):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_clipper(config: Optional[ClipingConfig] = None) -> ClipperProtocol:
|
def build_clipper(
|
||||||
|
config: Optional[ClipingConfig] = None,
|
||||||
|
random: Optional[bool] = None,
|
||||||
|
) -> ClipperProtocol:
|
||||||
config = config or ClipingConfig()
|
config = config or ClipingConfig()
|
||||||
return Clipper(
|
return Clipper(
|
||||||
duration=config.duration,
|
duration=config.duration,
|
||||||
max_empty=config.max_empty,
|
max_empty=config.max_empty,
|
||||||
random=config.random,
|
random=config.random if random else False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
from typing import Optional
|
from typing import Optional, Union
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
@ -23,8 +23,29 @@ class OptimizerConfig(BaseConfig):
|
|||||||
t_max: int = 100
|
t_max: int = 100
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerConfig(BaseConfig):
|
||||||
|
accelerator: str = "auto"
|
||||||
|
accumulate_grad_batches: int = 1
|
||||||
|
deterministic: bool = True
|
||||||
|
check_val_every_n_epoch: int = 1
|
||||||
|
devices: Union[str, int] = "auto"
|
||||||
|
enable_checkpointing: bool = True
|
||||||
|
gradient_clip_val: Optional[float] = None
|
||||||
|
limit_train_batches: Optional[Union[int, float]] = None
|
||||||
|
limit_test_batches: Optional[Union[int, float]] = None
|
||||||
|
limit_val_batches: Optional[Union[int, float]] = None
|
||||||
|
log_every_n_steps: Optional[int] = None
|
||||||
|
max_epochs: Optional[int] = 200
|
||||||
|
min_epochs: Optional[int] = None
|
||||||
|
max_steps: Optional[int] = None
|
||||||
|
min_steps: Optional[int] = None
|
||||||
|
max_time: Optional[str] = None
|
||||||
|
precision: Optional[str] = None
|
||||||
|
val_check_interval: Optional[Union[int, float]] = None
|
||||||
|
|
||||||
|
|
||||||
class TrainingConfig(BaseConfig):
|
class TrainingConfig(BaseConfig):
|
||||||
batch_size: int = 32
|
batch_size: int = 8
|
||||||
|
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
|
|
||||||
@ -36,9 +57,11 @@ class TrainingConfig(BaseConfig):
|
|||||||
|
|
||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
|
|
||||||
|
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_train_config(
|
def load_train_config(
|
||||||
path: PathLike,
|
path: data.PathLike,
|
||||||
field: Optional[str] = None,
|
field: Optional[str] = None,
|
||||||
) -> TrainingConfig:
|
) -> TrainingConfig:
|
||||||
return load_config(path, schema=TrainingConfig, field=field)
|
return load_config(path, schema=TrainingConfig, field=field)
|
||||||
|
@ -81,7 +81,10 @@ class LabeledDataset(Dataset):
|
|||||||
array: xr.DataArray,
|
array: xr.DataArray,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return torch.tensor(array.values.astype(dtype))
|
return torch.nan_to_num(
|
||||||
|
torch.tensor(array.values.astype(dtype)),
|
||||||
|
nan=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def list_preprocessed_files(
|
def list_preprocessed_files(
|
||||||
@ -91,7 +94,11 @@ def list_preprocessed_files(
|
|||||||
|
|
||||||
|
|
||||||
class RandomExampleSource:
|
class RandomExampleSource:
|
||||||
def __init__(self, filenames: List[str], clipper: ClipperProtocol):
|
def __init__(
|
||||||
|
self,
|
||||||
|
filenames: List[data.PathLike],
|
||||||
|
clipper: ClipperProtocol,
|
||||||
|
):
|
||||||
self.filenames = filenames
|
self.filenames = filenames
|
||||||
self.clipper = clipper
|
self.clipper = clipper
|
||||||
|
|
||||||
|
@ -40,7 +40,9 @@ class TrainingModule(L.LightningModule):
|
|||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
|
|
||||||
self.save_hyperparameters()
|
# NOTE: Ignore detector and loss from hyperparameter saving
|
||||||
|
# as they are nn.Module and should be saved regardless.
|
||||||
|
self.save_hyperparameters(ignore=["detector", "loss"])
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
||||||
return self.detector(spec)
|
return self.detector(spec)
|
||||||
|
@ -1,68 +1,112 @@
|
|||||||
from typing import Optional, Union
|
from typing import List, Optional
|
||||||
|
|
||||||
from lightning import LightningModule
|
from lightning import Trainer
|
||||||
from lightning.pytorch import Trainer
|
from soundevent import data
|
||||||
from soundevent.data import PathLike
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.models.types import DetectionModel
|
||||||
from batdetect2.train.dataset import LabeledDataset
|
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
from batdetect2.train.augmentations import (
|
||||||
|
build_augmentations,
|
||||||
|
)
|
||||||
|
from batdetect2.train.clips import build_clipper
|
||||||
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
from batdetect2.train.dataset import LabeledDataset, RandomExampleSource
|
||||||
|
from batdetect2.train.lightning import TrainingModule
|
||||||
|
from batdetect2.train.losses import build_loss
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"train",
|
"train",
|
||||||
"TrainerConfig",
|
|
||||||
"load_trainer_config",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TrainerConfig(BaseConfig):
|
|
||||||
accelerator: str = "auto"
|
|
||||||
accumulate_grad_batches: int = 1
|
|
||||||
deterministic: bool = True
|
|
||||||
check_val_every_n_epoch: int = 1
|
|
||||||
devices: Union[str, int] = "auto"
|
|
||||||
enable_checkpointing: bool = True
|
|
||||||
gradient_clip_val: Optional[float] = None
|
|
||||||
limit_train_batches: Optional[Union[int, float]] = None
|
|
||||||
limit_test_batches: Optional[Union[int, float]] = None
|
|
||||||
limit_val_batches: Optional[Union[int, float]] = None
|
|
||||||
log_every_n_steps: Optional[int] = None
|
|
||||||
max_epochs: Optional[int] = None
|
|
||||||
min_epochs: Optional[int] = 100
|
|
||||||
max_steps: Optional[int] = None
|
|
||||||
min_steps: Optional[int] = None
|
|
||||||
max_time: Optional[str] = None
|
|
||||||
precision: Optional[str] = None
|
|
||||||
reload_dataloaders_every_n_epochs: Optional[int] = None
|
|
||||||
val_check_interval: Optional[Union[int, float]] = None
|
|
||||||
|
|
||||||
|
|
||||||
def load_trainer_config(path: PathLike, field: Optional[str] = None):
|
|
||||||
return load_config(path, schema=TrainerConfig, field=field)
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
module: LightningModule,
|
detector: DetectionModel,
|
||||||
train_dataset: LabeledDataset,
|
targets: TargetProtocol,
|
||||||
trainer_config: Optional[TrainerConfig] = None,
|
preprocessor: PreprocessorProtocol,
|
||||||
dev_run: bool = False,
|
postprocessor: PostprocessorProtocol,
|
||||||
overfit_batches: bool = False,
|
train_examples: List[data.PathLike],
|
||||||
profiler: Optional[str] = None,
|
val_examples: Optional[List[data.PathLike]] = None,
|
||||||
):
|
config: Optional[TrainingConfig] = None,
|
||||||
trainer_config = trainer_config or TrainerConfig()
|
) -> None:
|
||||||
trainer = Trainer(
|
config = config or TrainingConfig()
|
||||||
**trainer_config.model_dump(
|
|
||||||
exclude_unset=True,
|
train_dataset = build_dataset(
|
||||||
exclude_none=True,
|
train_examples,
|
||||||
),
|
preprocessor,
|
||||||
fast_dev_run=dev_run,
|
config=config,
|
||||||
overfit_batches=overfit_batches,
|
train=True,
|
||||||
profiler=profiler,
|
|
||||||
)
|
)
|
||||||
train_loader = DataLoader(
|
|
||||||
|
loss = build_loss(config.loss)
|
||||||
|
|
||||||
|
module = TrainingModule(
|
||||||
|
detector=detector,
|
||||||
|
loss=loss,
|
||||||
|
targets=targets,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
learning_rate=config.optimizer.learning_rate,
|
||||||
|
t_max=config.optimizer.t_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer = Trainer(**config.trainer.model_dump())
|
||||||
|
|
||||||
|
train_dataloader = DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=module.config.train.batch_size,
|
batch_size=config.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
num_workers=7,
|
|
||||||
)
|
)
|
||||||
trainer.fit(module, train_dataloaders=train_loader)
|
|
||||||
|
val_dataloader = None
|
||||||
|
if val_examples:
|
||||||
|
val_dataset = build_dataset(
|
||||||
|
val_examples,
|
||||||
|
preprocessor,
|
||||||
|
config=config,
|
||||||
|
train=False,
|
||||||
|
)
|
||||||
|
val_dataloader = DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer.fit(
|
||||||
|
module,
|
||||||
|
train_dataloaders=train_dataloader,
|
||||||
|
val_dataloaders=val_dataloader,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_dataset(
|
||||||
|
examples: List[data.PathLike],
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
config: Optional[TrainingConfig] = None,
|
||||||
|
train: bool = True,
|
||||||
|
):
|
||||||
|
config = config or TrainingConfig()
|
||||||
|
|
||||||
|
clipper = build_clipper(config.cliping, random=train)
|
||||||
|
|
||||||
|
augmentations = None
|
||||||
|
|
||||||
|
if train:
|
||||||
|
random_example_source = RandomExampleSource(
|
||||||
|
examples,
|
||||||
|
clipper=clipper,
|
||||||
|
)
|
||||||
|
|
||||||
|
augmentations = build_augmentations(
|
||||||
|
preprocessor,
|
||||||
|
config=config.augmentations,
|
||||||
|
example_source=random_example_source,
|
||||||
|
)
|
||||||
|
|
||||||
|
return LabeledDataset(
|
||||||
|
examples,
|
||||||
|
clipper=clipper,
|
||||||
|
augmentation=augmentations,
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user