Change train to use full config

mbsantiago 2025-06-26 16:02:41 -06:00
parent 6d91153a56
commit 587742b41e
5 changed files with 213 additions and 307 deletions

View File

@@ -2,13 +2,13 @@ from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.preprocess import preprocess
from batdetect2.cli.train import train
from batdetect2.cli.train import train_detector
__all__ = [
"cli",
"detect",
"data",
"train",
"train_detector",
"preprocess",
]

View File

@@ -5,236 +5,53 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
from batdetect2.train import (
FullTrainingConfig,
load_full_training_config,
train,
)
from batdetect2.models import build_model
from batdetect2.models.backbones import load_backbone_config
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import train
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import list_preprocessed_files
__all__ = [
"train_command",
]
DEFAULT_CONFIG_FILE = Path("config.yaml")
@cli.command(name="train")
@click.option(
"--train-examples",
type=click.Path(exists=True),
required=True,
)
@click.option("--val-examples", type=click.Path(exists=True))
@click.option(
"--model-path",
type=click.Path(exists=True),
)
@click.option(
"--train-config",
type=click.Path(exists=True),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--train-config-field",
type=str,
default="train",
)
@click.option(
"--preprocess-config",
type=click.Path(exists=True),
help=(
"Path to the preprocessing configuration file. This file tells "
"the program how to prepare your audio data before training, such "
"as resampling or applying filters."
),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--preprocess-config-field",
type=str,
help=(
"If the preprocessing settings are inside a nested dictionary "
"within the preprocessing configuration file, specify the key "
"here to access them. If the preprocessing settings are at the "
"top level, you don't need to specify this."
),
default="preprocess",
)
@click.option(
"--target-config",
type=click.Path(exists=True),
help=(
"Path to the training target configuration file. This file "
"specifies what sounds the model should learn to predict."
),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--target-config-field",
type=str,
help=(
"If the target settings are inside a nested dictionary "
"within the target configuration file, specify the key here. "
"If the settings are at the top level, you don't need to specify this."
),
default="targets",
)
@click.option(
"--postprocess-config",
type=click.Path(exists=True),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--postprocess-config-field",
type=str,
default="postprocess",
)
@click.option(
"--model-config",
type=click.Path(exists=True),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--model-config-field",
type=str,
default="model",
)
@click.option(
"--train-workers",
type=int,
default=0,
)
@click.option(
"--val-workers",
type=int,
default=0,
)
@click.option("--train-dir", type=click.Path(exists=True), required=True)
@click.option("--val-dir", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int, default=0)
@click.option("--val-workers", type=int, default=0)
def train_command(
train_examples: Path,
val_examples: Optional[Path] = None,
train_dir: Path,
val_dir: Optional[Path] = None,
model_path: Optional[Path] = None,
train_config: Path = DEFAULT_CONFIG_FILE,
train_config_field: str = "train",
preprocess_config: Path = DEFAULT_CONFIG_FILE,
preprocess_config_field: str = "preprocess",
target_config: Path = DEFAULT_CONFIG_FILE,
target_config_field: str = "targets",
postprocess_config: Path = DEFAULT_CONFIG_FILE,
postprocess_config_field: str = "postprocess",
model_config: Path = DEFAULT_CONFIG_FILE,
model_config_field: str = "model",
config: Optional[Path] = None,
config_field: Optional[str] = None,
train_workers: int = 0,
val_workers: int = 0,
):
logger.info("Starting training!")
try:
target_config_loaded = load_target_config(
path=target_config,
field=target_config_field,
)
targets = build_targets(config=target_config_loaded)
logger.debug(
"Loaded targets info from config file {path}", path=target_config
)
except IOError:
logger.debug(
"Could not load target info from config file, using default"
)
targets = build_targets()
try:
preprocess_config_loaded = load_preprocessing_config(
path=preprocess_config,
field=preprocess_config_field,
)
preprocessor = build_preprocessor(preprocess_config_loaded)
logger.debug(
"Loaded preprocessor from config file {path}", path=target_config
conf = (
load_full_training_config(config, field=config_field)
if config is not None
else FullTrainingConfig()
)
except IOError:
logger.debug(
"Could not load preprocessor from config file, using default"
)
preprocessor = build_preprocessor()
try:
model_config_loaded = load_backbone_config(
path=model_config, field=model_config_field
)
model = build_model(
num_classes=len(targets.class_names),
config=model_config_loaded,
)
except IOError:
model = build_model(num_classes=len(targets.class_names))
try:
postprocess_config_loaded = load_postprocess_config(
path=postprocess_config,
field=postprocess_config_field,
)
postprocessor = build_postprocessor(
targets=targets,
config=postprocess_config_loaded,
)
logger.debug(
"Loaded postprocessor from file {path}", path=postprocess_config
)
except IOError:
logger.debug(
"Could not load postprocessor config from file. Using default"
)
postprocessor = build_postprocessor(targets=targets)
try:
train_config_loaded = load_train_config(
path=train_config, field=train_config_field
)
logger.debug(
"Loaded training config from file {path}",
path=train_config,
)
except IOError:
train_config_loaded = TrainingConfig()
logger.debug("Could not load training config from file. Using default")
train_files = list_preprocessed_files(train_examples)
val_files = (
None if val_examples is None else list_preprocessed_files(val_examples)
train_examples = list_preprocessed_files(train_dir)
val_examples = (
list_preprocessed_files(val_dir) if val_dir is not None else None
)
return train(
detector=model,
train_examples=train_files, # type: ignore
val_examples=val_files, # type: ignore
train(
train_examples=train_examples,
val_examples=val_examples,
config=conf,
model_path=model_path,
preprocessor=preprocessor,
postprocessor=postprocessor,
targets=targets,
config=train_config_loaded,
callbacks=[
ValidationMetrics(
metrics=[
DetectionAveragePrecision(),
ClassificationMeanAveragePrecision(
class_names=targets.class_names,
),
ClassificationAccuracy(class_names=targets.class_names),
]
)
],
train_workers=train_workers,
val_workers=val_workers,
)
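
The rewritten command collapses the per-component options (--train-config, --preprocess-config, --target-config, --postprocess-config, --model-config and their *-field companions) into a single --config/--config-field pair. A minimal sketch of the new invocation, assuming the package installs a batdetect2 entry point and using placeholder directories:

batdetect2 train \
    --train-dir preprocessed/train \
    --val-dir preprocessed/val \
    --config config.yaml \
    --train-workers 4 \
    --val-workers 2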

View File

@@ -15,7 +15,7 @@ from batdetect2.train.augmentations import (
)
from batdetect2.train.clips import build_clipper, select_subclip
from batdetect2.train.config import (
TrainerConfig,
PLTrainerConfig,
TrainingConfig,
load_train_config,
)
@@ -39,8 +39,14 @@ from batdetect2.train.preprocess import (
preprocess_annotations,
)
from batdetect2.train.train import (
FullTrainingConfig,
build_train_dataset,
build_train_loader,
build_trainer,
build_training_module,
build_val_dataset,
build_val_loader,
load_full_training_config,
train,
)
@@ -50,14 +56,15 @@ __all__ = [
"DetectionLossConfig",
"EchoAugmentationConfig",
"FrequencyMaskAugmentationConfig",
"FullTrainingConfig",
"LabeledDataset",
"LossConfig",
"LossFunction",
"PLTrainerConfig",
"RandomExampleSource",
"SizeLossConfig",
"TimeMaskAugmentationConfig",
"TrainExample",
"TrainerConfig",
"TrainingConfig",
"VolumeAugmentationConfig",
"WarpAugmentationConfig",
@@ -67,9 +74,14 @@ __all__ = [
"build_clipper",
"build_loss",
"build_train_dataset",
"build_train_loader",
"build_trainer",
"build_training_module",
"build_val_dataset",
"build_val_loader",
"generate_train_example",
"list_preprocessed_files",
"load_full_training_config",
"load_label_config",
"load_train_config",
"mask_frequency",
@@ -79,6 +91,5 @@ __all__ = [
"scale_volume",
"select_subclip",
"train",
"train",
"warp_spectrogram",
]
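
After this change the package root re-exports the full-config API and renames TrainerConfig to PLTrainerConfig; downstream imports would change accordingly, e.g.:

from batdetect2.train import (
    FullTrainingConfig,
    PLTrainerConfig,
    load_full_training_config,
    train,
)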

View File

@@ -13,18 +13,12 @@ from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
from batdetect2.train.losses import LossConfig
__all__ = [
"OptimizerConfig",
"TrainingConfig",
"load_train_config",
]
class OptimizerConfig(BaseConfig):
learning_rate: float = 1e-3
t_max: int = 100
class TrainerConfig(BaseConfig):
class PLTrainerConfig(BaseConfig):
accelerator: str = "auto"
accumulate_grad_batches: int = 1
deterministic: bool = True
@@ -45,15 +39,16 @@ class TrainerConfig(BaseConfig):
val_check_interval: Optional[Union[int, float]] = None
class TrainingConfig(BaseConfig):
class TrainingConfig(PLTrainerConfig):
batch_size: int = 8
learning_rate: float = 1e-3
t_max: int = 100
loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
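
Since TrainingConfig now subclasses PLTrainerConfig, the Lightning trainer settings sit alongside the training hyperparameters on a single object. A minimal sketch with hypothetical values (the field names come from the diff above):

from batdetect2.train.config import TrainingConfig

conf = TrainingConfig(
    batch_size=16,       # defined on TrainingConfig
    learning_rate=5e-4,  # moved from OptimizerConfig onto TrainingConfig
    accelerator="gpu",   # inherited from PLTrainerConfig
)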

View File

@@ -1,20 +1,28 @@
from collections.abc import Sequence
from typing import List, Optional
from lightning import Trainer
from lightning.pytorch.callbacks import Callback
from pydantic import Field
from soundevent import data
from torch.utils.data import DataLoader
from batdetect2.models.types import DetectionModel
from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import build_targets
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.augmentations import (
build_augmentations,
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import BackboneConfig, build_model
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
from batdetect2.preprocess import (
PreprocessingConfig,
PreprocessorProtocol,
build_preprocessor,
)
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
from batdetect2.train.augmentations import build_augmentations
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.clips import build_clipper
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import (
@@ -27,93 +35,70 @@ from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
__all__ = [
"train",
"build_val_dataset",
"FullTrainingConfig",
"build_train_dataset",
"build_train_loader",
"build_trainer",
"build_training_module",
"build_val_dataset",
"build_val_loader",
"load_full_training_config",
"train",
]
class FullTrainingConfig(BaseConfig):
"""Full training configuration."""
train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
def load_full_training_config(
path: data.PathLike,
field: Optional[str] = None,
) -> FullTrainingConfig:
"""Load the full training configuration."""
return load_config(path, schema=FullTrainingConfig, field=field)
def train(
detector: DetectionModel,
train_examples: List[data.PathLike],
targets: Optional[TargetProtocol] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
postprocessor: Optional[PostprocessorProtocol] = None,
val_examples: Optional[List[data.PathLike]] = None,
config: Optional[TrainingConfig] = None,
callbacks: Optional[List[Callback]] = None,
train_examples: Sequence[data.PathLike],
val_examples: Optional[Sequence[data.PathLike]] = None,
config: Optional[FullTrainingConfig] = None,
model_path: Optional[data.PathLike] = None,
train_workers: int = 0,
val_workers: int = 0,
**trainer_kwargs,
) -> None:
config = config or TrainingConfig()
):
conf = config or FullTrainingConfig()
if model_path is None:
if preprocessor is None:
preprocessor = build_preprocessor()
if targets is None:
targets = build_targets()
if postprocessor is None:
postprocessor = build_postprocessor(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
loss = build_loss(config.loss)
module = TrainingModule(
detector=detector,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
learning_rate=config.optimizer.learning_rate,
t_max=config.optimizer.t_max,
)
else:
if model_path is not None:
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(conf)
train_dataset = build_train_dataset(
trainer = build_trainer(conf, targets=module.targets)
train_dataloader = build_train_loader(
train_examples,
preprocessor=module.preprocessor,
config=config,
)
logger = build_logger(config.logger)
if logger and hasattr(logger, 'log_hyperparams'):
logger.log_hyperparams(config.model_dump(exclude_none=True))
trainer = Trainer(
**config.trainer.model_dump(exclude_none=True, exclude={"logger"}),
callbacks=callbacks,
logger=logger,
**trainer_kwargs,
)
train_dataloader = DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
config=conf.train,
num_workers=train_workers,
collate_fn=collate_fn,
)
val_dataloader = None
if val_examples:
val_dataset = build_val_dataset(
val_dataloader = (
build_val_loader(
val_examples,
config=config,
)
val_dataloader = DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
config=conf.train,
num_workers=val_workers,
collate_fn=collate_fn,
)
if val_examples is not None
else None
)
trainer.fit(
@@ -123,8 +108,106 @@
)
def build_training_module(conf: FullTrainingConfig) -> TrainingModule:
preprocessor = build_preprocessor(conf.preprocess)
targets = build_targets(conf.targets)
postprocessor = build_postprocessor(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
model = build_model(
num_classes=len(targets.class_names),
config=conf.model,
)
loss = build_loss(conf.train.loss)
return TrainingModule(
detector=model,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
learning_rate=conf.train.learning_rate,
t_max=conf.train.t_max,
)
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
return [
ValidationMetrics(
metrics=[
DetectionAveragePrecision(),
ClassificationMeanAveragePrecision(
class_names=targets.class_names
),
ClassificationAccuracy(class_names=targets.class_names),
]
)
]
def build_trainer(
conf: FullTrainingConfig,
targets: TargetProtocol,
) -> Trainer:
logger = build_logger(conf.train.logger)
if logger and hasattr(logger, "log_hyperparams"):
logger.log_hyperparams(conf.model_dump(exclude_none=True))
return Trainer(
accelerator=conf.train.accelerator,
logger=logger,
callbacks=build_trainer_callbacks(targets),
)
def build_train_loader(
train_examples: Sequence[data.PathLike],
preprocessor: PreprocessorProtocol,
config: TrainingConfig,
num_workers: Optional[int] = None,
) -> DataLoader:
train_dataset = build_train_dataset(
train_examples,
preprocessor=preprocessor,
config=config,
)
return DataLoader(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=num_workers or 0,
collate_fn=collate_fn,
)
def build_val_loader(
val_examples: Sequence[data.PathLike],
config: TrainingConfig,
num_workers: Optional[int] = None,
):
val_dataset = build_val_dataset(
val_examples,
config=config,
)
return DataLoader(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=num_workers or 0,
collate_fn=collate_fn,
)
def build_train_dataset(
examples: List[data.PathLike],
examples: Sequence[data.PathLike],
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
) -> LabeledDataset:
@@ -133,7 +216,7 @@ def build_train_dataset(
clipper = build_clipper(config.cliping, random=True)
random_example_source = RandomExampleSource(
examples,
list(examples),
clipper=clipper,
)
@@ -151,7 +234,7 @@ def build_val_dataset(
def build_val_dataset(
examples: List[data.PathLike],
examples: Sequence[data.PathLike],
config: Optional[TrainingConfig] = None,
train: bool = True,
) -> LabeledDataset:
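
Taken together, programmatic training now needs only the preprocessed example paths and one config object. A hedged end-to-end sketch (config.yaml and the directory names are hypothetical; the top-level keys of the file mirror FullTrainingConfig: train, targets, model, preprocess, postprocess):

from batdetect2.train import (
    list_preprocessed_files,
    load_full_training_config,
    train,
)

conf = load_full_training_config("config.yaml")

train(
    train_examples=list_preprocessed_files("preprocessed/train"),
    val_examples=list_preprocessed_files("preprocessed/val"),
    config=conf,
    train_workers=4,
    val_workers=2,
)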