mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Add checkpoint config
This commit is contained in:
parent
2d796394f6
commit
160cb6ae30
47
src/batdetect2/train/checkpoints.py
Normal file
47
src/batdetect2/train/checkpoints.py
Normal file
@ -0,0 +1,47 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
|
||||
from batdetect2.core import BaseConfig
|
||||
|
||||
__all__ = [
|
||||
"CheckpointConfig",
|
||||
"build_checkpoint_callback",
|
||||
]
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
|
||||
|
||||
class CheckpointConfig(BaseConfig):
    """Settings used to build the model checkpoint callback.

    Each field is read by ``build_checkpoint_callback`` and handed to
    Lightning's ``ModelCheckpoint`` under the keyword noted next to it.
    """

    # Base directory for checkpoint files. Used only when the caller of
    # ``build_checkpoint_callback`` does not pass an explicit directory.
    checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR

    # Name of the logged metric, forwarded as ``monitor``.
    monitor: str = "classification/mean_average_precision"

    # Forwarded as ``mode``; the default "max" goes with the default
    # metric above (presumably higher is better -- confirm for other metrics).
    mode: str = "max"

    # Forwarded as ``save_top_k``.
    save_top_k: int = 1

    # Forwarded as ``filename``; ``None`` is passed through unchanged.
    filename: Optional[str] = None
|
||||
|
||||
|
||||
def build_checkpoint_callback(
    config: Optional[CheckpointConfig] = None,
    checkpoint_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Callback:
    """Create the ``ModelCheckpoint`` callback described by ``config``.

    Parameters
    ----------
    config
        Checkpoint settings. A default ``CheckpointConfig`` is created when
        no (truthy) config is supplied.
    checkpoint_dir
        Base directory for checkpoints. When ``None`` the directory stored
        in the config is used instead.
    experiment_name
        Optional sub-directory appended to the base directory.
    run_name
        Optional sub-directory appended after ``experiment_name``.

    Returns
    -------
    Callback
        A ``ModelCheckpoint`` whose ``dirpath`` is the resolved directory
        and whose remaining options are copied from the config.
    """
    settings = config or CheckpointConfig()

    # An explicit directory takes precedence over the configured one.
    target_dir = (
        settings.checkpoint_dir if checkpoint_dir is None else checkpoint_dir
    )

    # Nest by experiment first, then by run; ``None`` components are skipped.
    for component in (experiment_name, run_name):
        if component is not None:
            target_dir = target_dir / component

    return ModelCheckpoint(
        dirpath=str(target_dir),
        save_top_k=settings.save_top_k,
        monitor=settings.monitor,
        mode=settings.mode,
        filename=settings.filename,
    )
|
||||
@ -6,6 +6,7 @@ from soundevent import data
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
|
||||
from batdetect2.train.checkpoints import CheckpointConfig
|
||||
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
|
||||
from batdetect2.train.labels import LabelConfig
|
||||
from batdetect2.train.losses import LossConfig
|
||||
@ -51,6 +52,7 @@ class TrainingConfig(BaseConfig):
|
||||
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
|
||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||
checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)
|
||||
|
||||
|
||||
def load_train_config(
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from collections.abc import Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from lightning import Trainer, seed_everything
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
@ -13,6 +12,7 @@ from batdetect2.logging import build_logger
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.checkpoints import build_checkpoint_callback
|
||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
@ -32,8 +32,6 @@ __all__ = [
|
||||
"train",
|
||||
]
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
|
||||
|
||||
def train(
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
@ -104,7 +102,6 @@ def train(
|
||||
|
||||
trainer = trainer or build_trainer(
|
||||
config,
|
||||
targets=targets,
|
||||
evaluator=build_evaluator(
|
||||
config.train.validation,
|
||||
targets=targets,
|
||||
@ -124,58 +121,29 @@ def train(
|
||||
logger.info("Training complete.")
|
||||
|
||||
|
||||
def build_trainer_callbacks(
    targets: "TargetProtocol",
    evaluator: Optional["EvaluatorProtocol"] = None,
    checkpoint_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> List[Callback]:
    """Assemble the callbacks passed to the Lightning trainer.

    Parameters
    ----------
    targets
        Used to build a default evaluator when ``evaluator`` is not given.
    evaluator
        Evaluator wrapped by the ``ValidationMetrics`` callback. Built via
        ``build_evaluator(targets=targets)`` when omitted (or falsy).
    checkpoint_dir
        Base directory for checkpoints; ``DEFAULT_CHECKPOINT_DIR`` is used
        when ``None``.
    experiment_name
        Optional sub-directory appended to the base directory.
    run_name
        Optional sub-directory appended after ``experiment_name``.

    Returns
    -------
    List[Callback]
        A ``ModelCheckpoint`` followed by a ``ValidationMetrics`` callback.
    """
    save_dir = (
        DEFAULT_CHECKPOINT_DIR if checkpoint_dir is None else checkpoint_dir
    )

    # Nest by experiment first, then by run; ``None`` components are skipped.
    for component in (experiment_name, run_name):
        if component is not None:
            save_dir = save_dir / component

    evaluator = evaluator or build_evaluator(targets=targets)

    # Keep only the single best checkpoint according to the monitored metric.
    checkpoint_callback = ModelCheckpoint(
        dirpath=str(save_dir),
        save_top_k=1,
        monitor="classification/mean_average_precision",
        mode="max",
    )
    metrics_callback = ValidationMetrics(evaluator)

    return [checkpoint_callback, metrics_callback]
|
||||
|
||||
|
||||
def build_trainer(
|
||||
conf: "BatDetect2Config",
|
||||
targets: "TargetProtocol",
|
||||
evaluator: Optional["EvaluatorProtocol"] = None,
|
||||
config: "BatDetect2Config",
|
||||
evaluator: "EvaluatorProtocol",
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
) -> Trainer:
|
||||
trainer_conf = conf.train.trainer
|
||||
trainer_conf = config.train.trainer
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building trainer with config: \n{config}",
|
||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
|
||||
train_logger = build_logger(
|
||||
conf.train.logger,
|
||||
config.train.logger,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
|
||||
train_logger.log_hyperparams(
|
||||
conf.model_dump(
|
||||
config.model_dump(
|
||||
mode="json",
|
||||
exclude_none=True,
|
||||
)
|
||||
@ -184,11 +152,13 @@ def build_trainer(
|
||||
return Trainer(
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
logger=train_logger,
|
||||
callbacks=build_trainer_callbacks(
|
||||
targets,
|
||||
evaluator=evaluator,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
),
|
||||
callbacks=[
|
||||
build_checkpoint_callback(
|
||||
config=config.train.checkpoints,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
),
|
||||
ValidationMetrics(evaluator),
|
||||
],
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user