Compare commits


2 Commits

Author      SHA1        Message                               Date
mbsantiago  24dcc47e73  Fix import of DEFAULT_CHECKPOINT_DIR  2025-10-04 15:16:11 +01:00
mbsantiago  160cb6ae30  Add checkpoint config                 2025-10-04 15:08:34 +01:00
4 changed files with 68 additions and 48 deletions

View File

@@ -1,3 +1,4 @@
+from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
 from batdetect2.train.config import (
     TrainingConfig,
     load_train_config,
@@ -6,7 +7,7 @@ from batdetect2.train.lightning import (
     TrainingModule,
     load_model_from_checkpoint,
 )
-from batdetect2.train.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train
+from batdetect2.train.train import build_trainer, train

 __all__ = [
     "DEFAULT_CHECKPOINT_DIR",

View File

@@ -0,0 +1,47 @@
+from pathlib import Path
+from typing import Optional
+
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+
+from batdetect2.core import BaseConfig
+
+__all__ = [
+    "CheckpointConfig",
+    "build_checkpoint_callback",
+]
+
+DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
+
+
+class CheckpointConfig(BaseConfig):
+    checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
+    monitor: str = "classification/mean_average_precision"
+    mode: str = "max"
+    save_top_k: int = 1
+    filename: Optional[str] = None
+
+
+def build_checkpoint_callback(
+    config: Optional[CheckpointConfig] = None,
+    checkpoint_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+) -> Callback:
+    config = config or CheckpointConfig()
+
+    if checkpoint_dir is None:
+        checkpoint_dir = config.checkpoint_dir
+
+    if experiment_name is not None:
+        checkpoint_dir = checkpoint_dir / experiment_name
+
+    if run_name is not None:
+        checkpoint_dir = checkpoint_dir / run_name
+
+    return ModelCheckpoint(
+        dirpath=str(checkpoint_dir),
+        save_top_k=config.save_top_k,
+        monitor=config.monitor,
+        mode=config.mode,
+        filename=config.filename,
+    )
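
This new module folds the checkpoint settings that `build_trainer_callbacks` used to hard-code into a reusable config plus builder. A minimal usage sketch, assuming `CheckpointConfig` accepts keyword overrides like a standard Pydantic model (the experiment and run names here are illustrative, not from the diff):

    from pathlib import Path

    from batdetect2.train.checkpoints import CheckpointConfig, build_checkpoint_callback

    # Override the defaults: keep the 3 best checkpoints under a custom root.
    config = CheckpointConfig(save_top_k=3, checkpoint_dir=Path("runs"))

    # The builder nests the directory, so checkpoints land in
    # runs/my-experiment/trial-1/
    callback = build_checkpoint_callback(
        config=config,
        experiment_name="my-experiment",
        run_name="trial-1",
    )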

View File

@@ -6,6 +6,7 @@ from soundevent import data
 from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
+from batdetect2.train.checkpoints import CheckpointConfig
 from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.losses import LossConfig
@@ -51,6 +52,7 @@ class TrainingConfig(BaseConfig):
     logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
     validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
+    checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)


 def load_train_config(
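
Because the new `checkpoints` field has a `default_factory`, existing training config files should load unchanged; the section only needs to appear when overriding the defaults. A hedged sketch, assuming the remaining `TrainingConfig` fields are defaulted like the ones visible above:

    from batdetect2.train.config import TrainingConfig

    config = TrainingConfig()
    # The defaults come straight from CheckpointConfig:
    assert config.checkpoints.monitor == "classification/mean_average_precision"
    assert config.checkpoints.mode == "max"
    assert config.checkpoints.save_top_k == 1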

View File

@@ -1,9 +1,8 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from lightning import Trainer, seed_everything
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data
@@ -13,6 +12,7 @@ from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.train.callbacks import ValidationMetrics
+from batdetect2.train.checkpoints import build_checkpoint_callback
 from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
@@ -32,8 +32,6 @@ __all__ = [
     "train",
 ]

-DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
-

 def train(
     train_annotations: Sequence[data.ClipAnnotation],
@@ -104,7 +102,6 @@ def train(
     trainer = trainer or build_trainer(
         config,
-        targets=targets,
         evaluator=build_evaluator(
             config.train.validation,
             targets=targets,
@@ -124,58 +121,29 @@ def train(
     logger.info("Training complete.")


-def build_trainer_callbacks(
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
-    checkpoint_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> List[Callback]:
-    if checkpoint_dir is None:
-        checkpoint_dir = DEFAULT_CHECKPOINT_DIR
-
-    if experiment_name is not None:
-        checkpoint_dir = checkpoint_dir / experiment_name
-
-    if run_name is not None:
-        checkpoint_dir = checkpoint_dir / run_name
-
-    evaluator = evaluator or build_evaluator(targets=targets)
-
-    return [
-        ModelCheckpoint(
-            dirpath=str(checkpoint_dir),
-            save_top_k=1,
-            monitor="classification/mean_average_precision",
-            mode="max",
-        ),
-        ValidationMetrics(evaluator),
-    ]
-
-
 def build_trainer(
-    conf: "BatDetect2Config",
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
+    config: "BatDetect2Config",
+    evaluator: "EvaluatorProtocol",
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
 ) -> Trainer:
-    trainer_conf = conf.train.trainer
+    trainer_conf = config.train.trainer

     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )

     train_logger = build_logger(
-        conf.train.logger,
+        config.train.logger,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
     )

     train_logger.log_hyperparams(
-        conf.model_dump(
+        config.model_dump(
             mode="json",
             exclude_none=True,
         )
@@ -184,11 +152,13 @@ def build_trainer(
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         logger=train_logger,
-        callbacks=build_trainer_callbacks(
-            targets,
-            evaluator=evaluator,
-            checkpoint_dir=checkpoint_dir,
-            experiment_name=experiment_name,
-            run_name=run_name,
-        ),
+        callbacks=[
+            build_checkpoint_callback(
+                config=config.train.checkpoints,
+                checkpoint_dir=checkpoint_dir,
+                experiment_name=experiment_name,
+                run_name=run_name,
+            ),
+            ValidationMetrics(evaluator),
+        ],
     )
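
Net effect for callers: `build_trainer` no longer takes `targets`, the evaluator is now a required argument, and checkpoint behaviour is read from `config.train.checkpoints` instead of being hard-coded. A hedged call-site sketch, where `config` and `evaluator` stand in for a real `BatDetect2Config` and the result of `build_evaluator(...)` as used in `train()`:

    from batdetect2.train.train import build_trainer

    def make_trainer(config, evaluator):
        # `config` is a BatDetect2Config; `evaluator` would come from
        # build_evaluator(config.train.validation, targets=...) as in train().
        return build_trainer(
            config,
            evaluator=evaluator,               # required now; previously Optional
            experiment_name="my-experiment",   # illustrative name
            run_name="trial-1",                # illustrative name
        )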