Change train to use full config

mbsantiago 2025-06-26 16:02:41 -06:00
parent 6d91153a56
commit 587742b41e
5 changed files with 213 additions and 307 deletions
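This change collapses the per-component configuration files and CLI flags (training, targets, model, preprocessing, postprocessing) into a single nested FullTrainingConfig loaded in one step. As a rough sketch of the new scheme, assuming a config file that only overrides the training section (the top-level keys mirror FullTrainingConfig in the diff below; the values are illustrative):

# Sketch of the consolidated configuration flow. The top-level keys
# (train, targets, model, preprocess, postprocess) mirror FullTrainingConfig;
# any section omitted from the file falls back to its default_factory.
from pathlib import Path

from batdetect2.train import load_full_training_config

Path("config.yaml").write_text(
    """
train:
  batch_size: 8
  learning_rate: 0.001
  t_max: 100
"""
)

conf = load_full_training_config("config.yaml")
print(conf.train.batch_size)  # 8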

batdetect2/cli/__init__.py

@@ -2,13 +2,13 @@ from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
 from batdetect2.cli.preprocess import preprocess
-from batdetect2.cli.train import train
+from batdetect2.cli.train import train_detector

 __all__ = [
     "cli",
     "detect",
     "data",
-    "train",
+    "train_detector",
     "preprocess",
 ]

batdetect2/cli/train.py

@@ -5,236 +5,53 @@ import click
 from loguru import logger

 from batdetect2.cli.base import cli
-from batdetect2.evaluate.metrics import (
-    ClassificationAccuracy,
-    ClassificationMeanAveragePrecision,
-    DetectionAveragePrecision,
-)
-from batdetect2.models import build_model
-from batdetect2.models.backbones import load_backbone_config
-from batdetect2.postprocess import build_postprocessor, load_postprocess_config
-from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
-from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import train
-from batdetect2.train.callbacks import ValidationMetrics
-from batdetect2.train.config import TrainingConfig, load_train_config
+from batdetect2.train import (
+    FullTrainingConfig,
+    load_full_training_config,
+    train,
+)
 from batdetect2.train.dataset import list_preprocessed_files

 __all__ = [
     "train_command",
 ]

-DEFAULT_CONFIG_FILE = Path("config.yaml")
-

 @cli.command(name="train")
-@click.option(
-    "--train-examples",
-    type=click.Path(exists=True),
-    required=True,
-)
-@click.option("--val-examples", type=click.Path(exists=True))
-@click.option(
-    "--model-path",
-    type=click.Path(exists=True),
-)
-@click.option(
-    "--train-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--train-config-field",
-    type=str,
-    default="train",
-)
-@click.option(
-    "--preprocess-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the preprocessing configuration file. This file tells "
-        "the program how to prepare your audio data before training, such "
-        "as resampling or applying filters."
-    ),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--preprocess-config-field",
-    type=str,
-    help=(
-        "If the preprocessing settings are inside a nested dictionary "
-        "within the preprocessing configuration file, specify the key "
-        "here to access them. If the preprocessing settings are at the "
-        "top level, you don't need to specify this."
-    ),
-    default="preprocess",
-)
-@click.option(
-    "--target-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the training target configuration file. This file "
-        "specifies what sounds the model should learn to predict."
-    ),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--target-config-field",
-    type=str,
-    help=(
-        "If the target settings are inside a nested dictionary "
-        "within the target configuration file, specify the key here. "
-        "If the settings are at the top level, you don't need to specify this."
-    ),
-    default="targets",
-)
-@click.option(
-    "--postprocess-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--postprocess-config-field",
-    type=str,
-    default="postprocess",
-)
-@click.option(
-    "--model-config",
-    type=click.Path(exists=True),
-    default=DEFAULT_CONFIG_FILE,
-)
-@click.option(
-    "--model-config-field",
-    type=str,
-    default="model",
-)
-@click.option(
-    "--train-workers",
-    type=int,
-    default=0,
-)
-@click.option(
-    "--val-workers",
-    type=int,
-    default=0,
-)
+@click.option("--train-dir", type=click.Path(exists=True), required=True)
+@click.option("--val-dir", type=click.Path(exists=True))
+@click.option("--model-path", type=click.Path(exists=True))
+@click.option("--config", type=click.Path(exists=True))
+@click.option("--config-field", type=str)
+@click.option("--train-workers", type=int, default=0)
+@click.option("--val-workers", type=int, default=0)
 def train_command(
-    train_examples: Path,
-    val_examples: Optional[Path] = None,
+    train_dir: Path,
+    val_dir: Optional[Path] = None,
     model_path: Optional[Path] = None,
-    train_config: Path = DEFAULT_CONFIG_FILE,
-    train_config_field: str = "train",
-    preprocess_config: Path = DEFAULT_CONFIG_FILE,
-    preprocess_config_field: str = "preprocess",
-    target_config: Path = DEFAULT_CONFIG_FILE,
-    target_config_field: str = "targets",
-    postprocess_config: Path = DEFAULT_CONFIG_FILE,
-    postprocess_config_field: str = "postprocess",
-    model_config: Path = DEFAULT_CONFIG_FILE,
-    model_config_field: str = "model",
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
     train_workers: int = 0,
     val_workers: int = 0,
 ):
     logger.info("Starting training!")
-    try:
-        target_config_loaded = load_target_config(
-            path=target_config,
-            field=target_config_field,
-        )
-        targets = build_targets(config=target_config_loaded)
-        logger.debug(
-            "Loaded targets info from config file {path}", path=target_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load target info from config file, using default"
-        )
-        targets = build_targets()
-    try:
-        preprocess_config_loaded = load_preprocessing_config(
-            path=preprocess_config,
-            field=preprocess_config_field,
-        )
-        preprocessor = build_preprocessor(preprocess_config_loaded)
-        logger.debug(
-            "Loaded preprocessor from config file {path}", path=target_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load preprocessor from config file, using default"
-        )
-        preprocessor = build_preprocessor()
-    try:
-        model_config_loaded = load_backbone_config(
-            path=model_config, field=model_config_field
-        )
-        model = build_model(
-            num_classes=len(targets.class_names),
-            config=model_config_loaded,
-        )
-    except IOError:
-        model = build_model(num_classes=len(targets.class_names))
-    try:
-        postprocess_config_loaded = load_postprocess_config(
-            path=postprocess_config,
-            field=postprocess_config_field,
-        )
-        postprocessor = build_postprocessor(
-            targets=targets,
-            config=postprocess_config_loaded,
-        )
-        logger.debug(
-            "Loaded postprocessor from file {path}", path=postprocess_config
-        )
-    except IOError:
-        logger.debug(
-            "Could not load postprocessor config from file. Using default"
-        )
-        postprocessor = build_postprocessor(targets=targets)
-    try:
-        train_config_loaded = load_train_config(
-            path=train_config, field=train_config_field
-        )
-        logger.debug(
-            "Loaded training config from file {path}",
-            path=train_config,
-        )
-    except IOError:
-        train_config_loaded = TrainingConfig()
-        logger.debug("Could not load training config from file. Using default")
-    train_files = list_preprocessed_files(train_examples)
-    val_files = (
-        None if val_examples is None else list_preprocessed_files(val_examples)
+    conf = (
+        load_full_training_config(config, field=config_field)
+        if config is not None
+        else FullTrainingConfig()
     )
-    return train(
-        detector=model,
-        train_examples=train_files,  # type: ignore
-        val_examples=val_files,  # type: ignore
+    train_examples = list_preprocessed_files(train_dir)
+    val_examples = (
+        list_preprocessed_files(val_dir) if val_dir is not None else None
+    )
+    train(
+        train_examples=train_examples,
+        val_examples=val_examples,
+        config=conf,
         model_path=model_path,
-        preprocessor=preprocessor,
-        postprocessor=postprocessor,
-        targets=targets,
-        config=train_config_loaded,
-        callbacks=[
-            ValidationMetrics(
-                metrics=[
-                    DetectionAveragePrecision(),
-                    ClassificationMeanAveragePrecision(
-                        class_names=targets.class_names,
-                    ),
-                    ClassificationAccuracy(class_names=targets.class_names),
-                ]
-            )
-        ],
         train_workers=train_workers,
         val_workers=val_workers,
     )
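For reference, a minimal sketch of exercising the reworked command in-process with click's test runner (the directory and file paths, and the nested config key, are assumptions for illustration):

# Invoke the new single-config "train" command via click's CliRunner.
from click.testing import CliRunner

from batdetect2.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "train",
        "--train-dir", "data/preprocessed/train",
        "--val-dir", "data/preprocessed/val",
        "--config", "config.yaml",
        "--config-field", "batdetect2",  # hypothetical key; omit for top-level configs
        "--train-workers", "2",
    ],
)
print(result.exit_code, result.output)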

batdetect2/train/__init__.py

@@ -15,7 +15,7 @@ from batdetect2.train.augmentations import (
 )
 from batdetect2.train.clips import build_clipper, select_subclip
 from batdetect2.train.config import (
-    TrainerConfig,
+    PLTrainerConfig,
     TrainingConfig,
     load_train_config,
 )

@@ -39,8 +39,14 @@ from batdetect2.train.preprocess import (
     preprocess_annotations,
 )
 from batdetect2.train.train import (
+    FullTrainingConfig,
     build_train_dataset,
+    build_train_loader,
+    build_trainer,
+    build_training_module,
     build_val_dataset,
+    build_val_loader,
+    load_full_training_config,
     train,
 )

@@ -50,14 +56,15 @@ __all__ = [
     "DetectionLossConfig",
     "EchoAugmentationConfig",
     "FrequencyMaskAugmentationConfig",
+    "FullTrainingConfig",
     "LabeledDataset",
     "LossConfig",
     "LossFunction",
+    "PLTrainerConfig",
     "RandomExampleSource",
     "SizeLossConfig",
     "TimeMaskAugmentationConfig",
     "TrainExample",
-    "TrainerConfig",
     "TrainingConfig",
     "VolumeAugmentationConfig",
     "WarpAugmentationConfig",

@@ -67,9 +74,14 @@ __all__ = [
     "build_clipper",
     "build_loss",
     "build_train_dataset",
+    "build_train_loader",
+    "build_trainer",
+    "build_training_module",
     "build_val_dataset",
+    "build_val_loader",
     "generate_train_example",
     "list_preprocessed_files",
+    "load_full_training_config",
     "load_label_config",
     "load_train_config",
     "mask_frequency",

@@ -79,6 +91,5 @@ __all__ = [
     "scale_volume",
     "select_subclip",
     "train",
-    "train",
     "warp_spectrogram",
 ]
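With the expanded re-exports, a full-config run needs only a single import site; a small sketch of the new surface:

# The full-config entry points are now importable from batdetect2.train
# directly, per the updated __all__ above.
from batdetect2.train import (
    FullTrainingConfig,
    PLTrainerConfig,
    load_full_training_config,
    train,
)

conf = FullTrainingConfig()  # every section falls back to its defaults
assert isinstance(conf.train, PLTrainerConfig)  # TrainingConfig subclasses it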

batdetect2/train/config.py

@@ -13,18 +13,12 @@ from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
 from batdetect2.train.losses import LossConfig

 __all__ = [
-    "OptimizerConfig",
     "TrainingConfig",
     "load_train_config",
 ]


-class OptimizerConfig(BaseConfig):
-    learning_rate: float = 1e-3
-    t_max: int = 100
-
-
-class TrainerConfig(BaseConfig):
+class PLTrainerConfig(BaseConfig):
     accelerator: str = "auto"
     accumulate_grad_batches: int = 1
     deterministic: bool = True

@@ -45,15 +39,16 @@ class TrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None


-class TrainingConfig(BaseConfig):
+class TrainingConfig(PLTrainerConfig):
     batch_size: int = 8
+    learning_rate: float = 1e-3
+    t_max: int = 100
     loss: LossConfig = Field(default_factory=LossConfig)
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
     augmentations: AugmentationsConfig = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
     )
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
-    trainer: TrainerConfig = Field(default_factory=TrainerConfig)
+    trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
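The OptimizerConfig indirection is gone: learning_rate and t_max now sit directly on TrainingConfig, which also inherits the Lightning trainer fields from PLTrainerConfig. A short sketch of the flattened access paths:

# Before: config.optimizer.learning_rate; after: config.learning_rate.
from batdetect2.train.config import PLTrainerConfig, TrainingConfig

conf = TrainingConfig(batch_size=16, learning_rate=3e-4)
assert isinstance(conf, PLTrainerConfig)  # trainer fields are inherited
print(conf.learning_rate, conf.t_max, conf.accelerator)  # 0.0003 100 auto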

batdetect2/train/train.py

@@ -1,20 +1,28 @@
+from collections.abc import Sequence
 from typing import List, Optional

 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback
+from pydantic import Field
 from soundevent import data
 from torch.utils.data import DataLoader

-from batdetect2.models.types import DetectionModel
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.postprocess.types import PostprocessorProtocol
-from batdetect2.preprocess import build_preprocessor
-from batdetect2.preprocess.types import PreprocessorProtocol
-from batdetect2.targets import build_targets
-from batdetect2.targets.types import TargetProtocol
-from batdetect2.train.augmentations import (
-    build_augmentations,
-)
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.evaluate.metrics import (
+    ClassificationAccuracy,
+    ClassificationMeanAveragePrecision,
+    DetectionAveragePrecision,
+)
+from batdetect2.models import BackboneConfig, build_model
+from batdetect2.postprocess import PostprocessConfig, build_postprocessor
+from batdetect2.preprocess import (
+    PreprocessingConfig,
+    PreprocessorProtocol,
+    build_preprocessor,
+)
+from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
+from batdetect2.train.augmentations import build_augmentations
+from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
 from batdetect2.train.config import TrainingConfig
 from batdetect2.train.dataset import (

@@ -27,94 +35,71 @@ from batdetect2.train.logging import build_logger
 from batdetect2.train.losses import build_loss

 __all__ = [
-    "train",
-    "build_val_dataset",
+    "FullTrainingConfig",
     "build_train_dataset",
+    "build_train_loader",
+    "build_trainer",
+    "build_training_module",
+    "build_val_dataset",
+    "build_val_loader",
+    "load_full_training_config",
+    "train",
 ]


+class FullTrainingConfig(BaseConfig):
+    """Full training configuration."""
+
+    train: TrainingConfig = Field(default_factory=TrainingConfig)
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+    model: BackboneConfig = Field(default_factory=BackboneConfig)
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
+
+
+def load_full_training_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> FullTrainingConfig:
+    """Load the full training configuration."""
+    return load_config(path, schema=FullTrainingConfig, field=field)
+
+
 def train(
-    detector: DetectionModel,
-    train_examples: List[data.PathLike],
-    targets: Optional[TargetProtocol] = None,
-    preprocessor: Optional[PreprocessorProtocol] = None,
-    postprocessor: Optional[PostprocessorProtocol] = None,
-    val_examples: Optional[List[data.PathLike]] = None,
-    config: Optional[TrainingConfig] = None,
-    callbacks: Optional[List[Callback]] = None,
+    train_examples: Sequence[data.PathLike],
+    val_examples: Optional[Sequence[data.PathLike]] = None,
+    config: Optional[FullTrainingConfig] = None,
     model_path: Optional[data.PathLike] = None,
     train_workers: int = 0,
     val_workers: int = 0,
-    **trainer_kwargs,
-) -> None:
-    config = config or TrainingConfig()
-    if model_path is None:
-        if preprocessor is None:
-            preprocessor = build_preprocessor()
-        if targets is None:
-            targets = build_targets()
-        if postprocessor is None:
-            postprocessor = build_postprocessor(
-                targets,
-                min_freq=preprocessor.min_freq,
-                max_freq=preprocessor.max_freq,
-            )
-        loss = build_loss(config.loss)
-        module = TrainingModule(
-            detector=detector,
-            loss=loss,
-            targets=targets,
-            preprocessor=preprocessor,
-            postprocessor=postprocessor,
-            learning_rate=config.optimizer.learning_rate,
-            t_max=config.optimizer.t_max,
-        )
-    else:
-        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
-    train_dataset = build_train_dataset(
-        train_examples,
-        preprocessor=module.preprocessor,
-        config=config,
-    )
-    logger = build_logger(config.logger)
-    if logger and hasattr(logger, 'log_hyperparams'):
-        logger.log_hyperparams(config.model_dump(exclude_none=True))
-    trainer = Trainer(
-        **config.trainer.model_dump(exclude_none=True, exclude={"logger"}),
-        callbacks=callbacks,
-        logger=logger,
-        **trainer_kwargs,
-    )
-    train_dataloader = DataLoader(
-        train_dataset,
-        batch_size=config.batch_size,
-        shuffle=True,
-        num_workers=train_workers,
-        collate_fn=collate_fn,
-    )
-    val_dataloader = None
-    if val_examples:
-        val_dataset = build_val_dataset(
-            val_examples,
-            config=config,
-        )
-        val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=config.batch_size,
-            shuffle=False,
-            num_workers=val_workers,
-            collate_fn=collate_fn,
-        )
+):
+    conf = config or FullTrainingConfig()
+
+    if model_path is not None:
+        module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
+    else:
+        module = build_training_module(conf)
+
+    trainer = build_trainer(conf, targets=module.targets)
+
+    train_dataloader = build_train_loader(
+        train_examples,
+        preprocessor=module.preprocessor,
+        config=conf.train,
+        num_workers=train_workers,
+    )
+
+    val_dataloader = (
+        build_val_loader(
+            val_examples,
+            config=conf.train,
+            num_workers=val_workers,
+        )
+        if val_examples is not None
+        else None
+    )
+
     trainer.fit(
         module,

@@ -123,8 +108,106 @@ def train(
     )


+def build_training_module(conf: FullTrainingConfig) -> TrainingModule:
+    preprocessor = build_preprocessor(conf.preprocess)
+    targets = build_targets(conf.targets)
+    postprocessor = build_postprocessor(
+        targets,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
+    )
+    model = build_model(
+        num_classes=len(targets.class_names),
+        config=conf.model,
+    )
+    loss = build_loss(conf.train.loss)
+    return TrainingModule(
+        detector=model,
+        loss=loss,
+        targets=targets,
+        preprocessor=preprocessor,
+        postprocessor=postprocessor,
+        learning_rate=conf.train.learning_rate,
+        t_max=conf.train.t_max,
+    )
+
+
+def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
+    return [
+        ValidationMetrics(
+            metrics=[
+                DetectionAveragePrecision(),
+                ClassificationMeanAveragePrecision(
+                    class_names=targets.class_names
+                ),
+                ClassificationAccuracy(class_names=targets.class_names),
+            ]
+        )
+    ]
+
+
+def build_trainer(
+    conf: FullTrainingConfig,
+    targets: TargetProtocol,
+) -> Trainer:
+    logger = build_logger(conf.train.logger)
+
+    if logger and hasattr(logger, "log_hyperparams"):
+        logger.log_hyperparams(conf.model_dump(exclude_none=True))
+
+    return Trainer(
+        accelerator=conf.train.accelerator,
+        logger=logger,
+        callbacks=build_trainer_callbacks(targets),
+    )
+
+
+def build_train_loader(
+    train_examples: Sequence[data.PathLike],
+    preprocessor: PreprocessorProtocol,
+    config: TrainingConfig,
+    num_workers: Optional[int] = None,
+) -> DataLoader:
+    train_dataset = build_train_dataset(
+        train_examples,
+        preprocessor=preprocessor,
+        config=config,
+    )
+    return DataLoader(
+        train_dataset,
+        batch_size=config.batch_size,
+        shuffle=True,
+        num_workers=num_workers or 0,
+        collate_fn=collate_fn,
+    )
+
+
+def build_val_loader(
+    val_examples: Sequence[data.PathLike],
+    config: TrainingConfig,
+    num_workers: Optional[int] = None,
+):
+    val_dataset = build_val_dataset(
+        val_examples,
+        config=config,
+    )
+    return DataLoader(
+        val_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        num_workers=num_workers or 0,
+        collate_fn=collate_fn,
+    )
+
+
 def build_train_dataset(
-    examples: List[data.PathLike],
+    examples: Sequence[data.PathLike],
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
 ) -> LabeledDataset:

@@ -133,7 +216,7 @@ def build_train_dataset(
     clipper = build_clipper(config.cliping, random=True)

     random_example_source = RandomExampleSource(
-        examples,
+        list(examples),
         clipper=clipper,
     )

@@ -151,7 +234,7 @@ def build_train_dataset(
 def build_val_dataset(
-    examples: List[data.PathLike],
+    examples: Sequence[data.PathLike],
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
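Taken together, train() is now a thin wrapper over the builders above. A hedged sketch of driving them directly (the data directory is an assumption; all functions appear in this diff):

# Reproduce train()'s wiring step by step with the new builders.
from batdetect2.train import (
    FullTrainingConfig,
    build_train_loader,
    build_trainer,
    build_training_module,
)
from batdetect2.train.dataset import list_preprocessed_files

conf = FullTrainingConfig()
module = build_training_module(conf)  # model, loss, pre/postprocessors, targets
trainer = build_trainer(conf, targets=module.targets)
loader = build_train_loader(
    list_preprocessed_files("data/preprocessed/train"),
    preprocessor=module.preprocessor,
    config=conf.train,
    num_workers=2,
)
trainer.fit(module, train_dataloaders=loader)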