# Mirror of https://github.com/macaodha/batdetect2.git
# (synced 2025-06-30 07:02:01 +02:00; 154 lines, 4.2 KiB, Python)
from typing import List, Optional
|
|
|
|
from lightning import Trainer
|
|
from lightning.pytorch.callbacks import Callback
|
|
from soundevent import data
|
|
from torch.utils.data import DataLoader
|
|
|
|
from batdetect2.models.types import DetectionModel
|
|
from batdetect2.postprocess import build_postprocessor
|
|
from batdetect2.postprocess.types import PostprocessorProtocol
|
|
from batdetect2.preprocess import build_preprocessor
|
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
|
from batdetect2.targets import build_targets
|
|
from batdetect2.targets.types import TargetProtocol
|
|
from batdetect2.train.augmentations import (
|
|
build_augmentations,
|
|
)
|
|
from batdetect2.train.clips import build_clipper
|
|
from batdetect2.train.config import TrainingConfig
|
|
from batdetect2.train.dataset import (
|
|
LabeledDataset,
|
|
RandomExampleSource,
|
|
collate_fn,
|
|
)
|
|
from batdetect2.train.lightning import TrainingModule
|
|
from batdetect2.train.losses import build_loss
|
|
|
|
# Public API of this module: the top-level training entry point plus the
# dataset builders it delegates to.
__all__ = [
    "train",
    "build_val_dataset",
    "build_train_dataset",
]
|
|
|
|
|
|
def train(
    detector: DetectionModel,
    train_examples: List[data.PathLike],
    targets: Optional[TargetProtocol] = None,
    preprocessor: Optional[PreprocessorProtocol] = None,
    postprocessor: Optional[PostprocessorProtocol] = None,
    val_examples: Optional[List[data.PathLike]] = None,
    config: Optional[TrainingConfig] = None,
    callbacks: Optional[List[Callback]] = None,
    model_path: Optional[data.PathLike] = None,
    train_workers: int = 0,
    val_workers: int = 0,
    **trainer_kwargs,
) -> None:
    """Train a detection model with PyTorch Lightning.

    Builds (or restores) a :class:`TrainingModule`, wraps the example files
    in dataloaders, and runs ``Trainer.fit``.

    Parameters
    ----------
    detector
        The model to train (ignored when resuming from ``model_path``).
    train_examples
        Paths to the preprocessed training examples.
    targets, preprocessor, postprocessor
        Optional pipeline components; defaults are built when omitted.
        All three are ignored when ``model_path`` is given, since the
        checkpoint carries its own state.
    val_examples
        Optional paths to validation examples; when omitted no validation
        loop is run.
    config
        Full training configuration; defaults to ``TrainingConfig()``.
    callbacks
        Extra Lightning callbacks to attach to the trainer.
    model_path
        Checkpoint to resume from instead of building a fresh module.
    train_workers, val_workers
        ``num_workers`` for the respective dataloaders.
    **trainer_kwargs
        Forwarded verbatim to ``lightning.Trainer``.
        NOTE(review): a key that also appears in ``config.trainer`` raises
        ``TypeError`` (duplicate keyword) — confirm intended.
    """
    if config is None:
        config = TrainingConfig()

    if model_path is not None:
        # Resume: the checkpoint supplies targets/pre/postprocessor, so the
        # corresponding arguments are silently ignored here.
        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
    else:
        if preprocessor is None:
            preprocessor = build_preprocessor()

        if targets is None:
            targets = build_targets()

        if postprocessor is None:
            # Frequency bounds must agree with the preprocessor's band.
            postprocessor = build_postprocessor(
                targets,
                min_freq=preprocessor.min_freq,
                max_freq=preprocessor.max_freq,
            )

        module = TrainingModule(
            detector=detector,
            loss=build_loss(config.loss),
            targets=targets,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            learning_rate=config.optimizer.learning_rate,
            t_max=config.optimizer.t_max,
        )

    # Use the module's preprocessor (not the local one) so resumed runs
    # preprocess exactly as they did when the checkpoint was written.
    dataset = build_train_dataset(
        train_examples,
        preprocessor=module.preprocessor,
        config=config,
    )

    trainer = Trainer(
        **config.trainer.model_dump(exclude_none=True),
        callbacks=callbacks,
        **trainer_kwargs,
    )

    fit_loaders = {
        "train_dataloaders": DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=train_workers,
            collate_fn=collate_fn,
        ),
        "val_dataloaders": None,
    }

    if val_examples:
        fit_loaders["val_dataloaders"] = DataLoader(
            build_val_dataset(val_examples, config=config),
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=val_workers,
            collate_fn=collate_fn,
        )

    trainer.fit(module, **fit_loaders)
|
|
|
|
|
|
def build_train_dataset(
    examples: List[data.PathLike],
    preprocessor: PreprocessorProtocol,
    config: Optional[TrainingConfig] = None,
) -> LabeledDataset:
    """Build the augmented training dataset.

    Parameters
    ----------
    examples
        Paths to preprocessed training examples.
    preprocessor
        Preprocessor the augmentations operate with.
    config
        Training configuration; defaults to ``TrainingConfig()``.

    Returns
    -------
    LabeledDataset
        Dataset that randomly clips each example and applies the configured
        augmentation pipeline.
    """
    cfg = config if config is not None else TrainingConfig()

    # Random clipping so each epoch sees a different crop of every example.
    # NOTE(review): `cliping` is the attribute name as declared on
    # TrainingConfig (apparent typo of "clipping") — keep in sync with it.
    clip = build_clipper(cfg.cliping, random=True)

    # Mix-in style augmentations draw extra examples from the same pool,
    # clipped the same way as the primary examples.
    pipeline = build_augmentations(
        preprocessor,
        config=cfg.augmentations,
        example_source=RandomExampleSource(examples, clipper=clip),
    )

    return LabeledDataset(examples, clipper=clip, augmentation=pipeline)
|
|
|
|
|
|
def build_val_dataset(
    examples: List[data.PathLike],
    config: Optional[TrainingConfig] = None,
    train: bool = True,
) -> LabeledDataset:
    """Build a validation dataset (clipping only, no augmentations).

    Parameters
    ----------
    examples
        Paths to preprocessed validation examples.
    config
        Training configuration; defaults to ``TrainingConfig()``.
    train
        Passed through as ``random=`` to the clipper.
        NOTE(review): defaulting to True makes validation clips random and
        metrics non-deterministic across runs — confirm this is intended.

    Returns
    -------
    LabeledDataset
        Dataset yielding clipped, un-augmented examples.
    """
    cfg = TrainingConfig() if config is None else config
    return LabeledDataset(
        examples,
        clipper=build_clipper(cfg.cliping, random=train),
    )
|