Add checkpoint config

mbsantiago 2025-10-04 15:08:34 +01:00
parent 2d796394f6
commit 160cb6ae30
3 changed files with 66 additions and 47 deletions


@@ -0,0 +1,47 @@
+from pathlib import Path
+from typing import Optional
+
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+
+from batdetect2.core import BaseConfig
+
+__all__ = [
+    "CheckpointConfig",
+    "build_checkpoint_callback",
+]
+
+DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
+
+
+class CheckpointConfig(BaseConfig):
+    checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
+    monitor: str = "classification/mean_average_precision"
+    mode: str = "max"
+    save_top_k: int = 1
+    filename: Optional[str] = None
+
+
+def build_checkpoint_callback(
+    config: Optional[CheckpointConfig] = None,
+    checkpoint_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+) -> Callback:
+    config = config or CheckpointConfig()
+
+    if checkpoint_dir is None:
+        checkpoint_dir = config.checkpoint_dir
+
+    if experiment_name is not None:
+        checkpoint_dir = checkpoint_dir / experiment_name
+
+    if run_name is not None:
+        checkpoint_dir = checkpoint_dir / run_name
+
+    return ModelCheckpoint(
+        dirpath=str(checkpoint_dir),
+        save_top_k=config.save_top_k,
+        monitor=config.monitor,
+        mode=config.mode,
+        filename=config.filename,
+    )
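The new module centralizes checkpoint behaviour in a reusable config. A minimal standalone sketch of how the helper composes the output directory, assuming the module path shown in the imports later in this commit (batdetect2.train.checkpoints) and hypothetical experiment/run names:

    from batdetect2.train.checkpoints import (
        CheckpointConfig,
        build_checkpoint_callback,
    )

    # Keep the three best checkpoints instead of the default one.
    config = CheckpointConfig(save_top_k=3)

    callback = build_checkpoint_callback(
        config=config,
        experiment_name="my-experiment",  # hypothetical name
        run_name="run-001",               # hypothetical name
    )
    # Checkpoints are written under outputs/checkpoints/my-experiment/run-001,
    # ranked by classification/mean_average_precision with mode="max".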


@@ -6,6 +6,7 @@ from soundevent import data
 from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
+from batdetect2.train.checkpoints import CheckpointConfig
 from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.losses import LossConfig
@@ -51,6 +52,7 @@ class TrainingConfig(BaseConfig):
     logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
     validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
+    checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)


 def load_train_config(
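Because every field on CheckpointConfig carries a default, the new checkpoints section is backwards compatible: existing training configs load unchanged without it. A quick sketch of the defaults (same import path as above):

    from pathlib import Path

    from batdetect2.train.checkpoints import CheckpointConfig

    # An empty section falls back to the defaults defined in the new module.
    cfg = CheckpointConfig()
    assert cfg.checkpoint_dir == Path("outputs") / "checkpoints"
    assert cfg.monitor == "classification/mean_average_precision"
    assert cfg.mode == "max" and cfg.save_top_k == 1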


@@ -1,9 +1,8 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from lightning import Trainer, seed_everything
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data

@@ -13,6 +12,7 @@ from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.train.callbacks import ValidationMetrics
+from batdetect2.train.checkpoints import build_checkpoint_callback
 from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
@@ -32,8 +32,6 @@ __all__ = [
     "train",
 ]

-DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
-

 def train(
     train_annotations: Sequence[data.ClipAnnotation],
@@ -104,7 +102,6 @@ def train(

     trainer = trainer or build_trainer(
         config,
-        targets=targets,
         evaluator=build_evaluator(
             config.train.validation,
             targets=targets,
@@ -124,58 +121,29 @@ def train(
     logger.info("Training complete.")


-def build_trainer_callbacks(
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
-    checkpoint_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> List[Callback]:
-    if checkpoint_dir is None:
-        checkpoint_dir = DEFAULT_CHECKPOINT_DIR
-
-    if experiment_name is not None:
-        checkpoint_dir = checkpoint_dir / experiment_name
-
-    if run_name is not None:
-        checkpoint_dir = checkpoint_dir / run_name
-
-    evaluator = evaluator or build_evaluator(targets=targets)
-
-    return [
-        ModelCheckpoint(
-            dirpath=str(checkpoint_dir),
-            save_top_k=1,
-            monitor="classification/mean_average_precision",
-            mode="max",
-        ),
-        ValidationMetrics(evaluator),
-    ]
-
-
 def build_trainer(
-    conf: "BatDetect2Config",
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
+    config: "BatDetect2Config",
+    evaluator: "EvaluatorProtocol",
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
 ) -> Trainer:
-    trainer_conf = conf.train.trainer
+    trainer_conf = config.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )

     train_logger = build_logger(
-        conf.train.logger,
+        config.train.logger,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
     )

     train_logger.log_hyperparams(
-        conf.model_dump(
+        config.model_dump(
             mode="json",
             exclude_none=True,
         )
@@ -184,11 +152,13 @@ def build_trainer(
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         logger=train_logger,
-        callbacks=build_trainer_callbacks(
-            targets,
-            evaluator=evaluator,
-            checkpoint_dir=checkpoint_dir,
-            experiment_name=experiment_name,
-            run_name=run_name,
-        ),
+        callbacks=[
+            build_checkpoint_callback(
+                config=config.train.checkpoints,
+                checkpoint_dir=checkpoint_dir,
+                experiment_name=experiment_name,
+                run_name=run_name,
+            ),
+            ValidationMetrics(evaluator),
+        ],
     )
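With build_trainer_callbacks removed, build_trainer takes the evaluator directly (now required rather than Optional) and assembles the callback list inline. A hedged sketch of the new call shape; anything not shown in the diff above is an assumption:

    # config is a BatDetect2Config; evaluator comes from build_evaluator(...)
    trainer = build_trainer(
        config,
        evaluator=evaluator,
        experiment_name="my-experiment",  # hypothetical name
        run_name="run-001",               # hypothetical name
    )
    # Internally this is equivalent to passing:
    #   callbacks=[
    #       build_checkpoint_callback(config=config.train.checkpoints, ...),
    #       ValidationMetrics(evaluator),
    #   ]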