Mirror of https://github.com/macaodha/batdetect2.git
Add checkpoint config

parent 2d796394f6
commit 160cb6ae30
src/batdetect2/train/checkpoints.py (new file, 47 lines added)
@@ -0,0 +1,47 @@
+from pathlib import Path
+from typing import Optional
+
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
+
+from batdetect2.core import BaseConfig
+
+__all__ = [
+    "CheckpointConfig",
+    "build_checkpoint_callback",
+]
+
+DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
+
+
+class CheckpointConfig(BaseConfig):
+    checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
+    monitor: str = "classification/mean_average_precision"
+    mode: str = "max"
+    save_top_k: int = 1
+    filename: Optional[str] = None
+
+
+def build_checkpoint_callback(
+    config: Optional[CheckpointConfig] = None,
+    checkpoint_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+) -> Callback:
+    config = config or CheckpointConfig()
+
+    if checkpoint_dir is None:
+        checkpoint_dir = config.checkpoint_dir
+
+    if experiment_name is not None:
+        checkpoint_dir = checkpoint_dir / experiment_name
+
+    if run_name is not None:
+        checkpoint_dir = checkpoint_dir / run_name
+
+    return ModelCheckpoint(
+        dirpath=str(checkpoint_dir),
+        save_top_k=config.save_top_k,
+        monitor=config.monitor,
+        mode=config.mode,
+        filename=config.filename,
+    )
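For reference, a minimal sketch of how the new helper could be used on its own (not part of this commit; the experiment and run names are purely illustrative):

    from batdetect2.train.checkpoints import (
        CheckpointConfig,
        build_checkpoint_callback,
    )

    # The defaults already monitor "classification/mean_average_precision"
    # with mode="max"; here only save_top_k is raised.
    config = CheckpointConfig(save_top_k=3)

    # Checkpoints land under outputs/checkpoints/my-experiment/run-01/.
    callback = build_checkpoint_callback(
        config=config,
        experiment_name="my-experiment",  # illustrative
        run_name="run-01",  # illustrative
    )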
@@ -6,6 +6,7 @@ from soundevent import data
 from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
+from batdetect2.train.checkpoints import CheckpointConfig
 from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.losses import LossConfig
@@ -51,6 +52,7 @@ class TrainingConfig(BaseConfig):
     logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
     validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
+    checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)


 def load_train_config(
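With the new checkpoints field, checkpoint behaviour is driven by the training config alongside the existing logger, labels, and validation sections. A minimal sketch, assuming TrainingConfig is importable from batdetect2.train.config (import path assumed) and that its remaining fields carry defaults like the visible ones do; values are illustrative:

    from batdetect2.train.checkpoints import CheckpointConfig
    from batdetect2.train.config import TrainingConfig  # assumed import path

    # Override only the fields of interest; everything else falls back to
    # the default_factory values declared on TrainingConfig.
    training = TrainingConfig(
        checkpoints=CheckpointConfig(
            save_top_k=1,
            filename="best-{epoch:02d}",  # illustrative ModelCheckpoint template
        ),
    )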
@@ -1,9 +1,8 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional

 from lightning import Trainer, seed_everything
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data

@@ -13,6 +12,7 @@ from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.train.callbacks import ValidationMetrics
+from batdetect2.train.checkpoints import build_checkpoint_callback
 from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
@@ -32,8 +32,6 @@ __all__ = [
     "train",
 ]

-DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
-

 def train(
     train_annotations: Sequence[data.ClipAnnotation],
@@ -104,7 +102,6 @@ def train(

    trainer = trainer or build_trainer(
        config,
-        targets=targets,
        evaluator=build_evaluator(
            config.train.validation,
            targets=targets,
@@ -124,58 +121,29 @@ def train(
     logger.info("Training complete.")


-def build_trainer_callbacks(
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
-    checkpoint_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> List[Callback]:
-    if checkpoint_dir is None:
-        checkpoint_dir = DEFAULT_CHECKPOINT_DIR
-
-    if experiment_name is not None:
-        checkpoint_dir = checkpoint_dir / experiment_name
-
-    if run_name is not None:
-        checkpoint_dir = checkpoint_dir / run_name
-
-    evaluator = evaluator or build_evaluator(targets=targets)
-
-    return [
-        ModelCheckpoint(
-            dirpath=str(checkpoint_dir),
-            save_top_k=1,
-            monitor="classification/mean_average_precision",
-            mode="max",
-        ),
-        ValidationMetrics(evaluator),
-    ]
-
-
 def build_trainer(
-    conf: "BatDetect2Config",
-    targets: "TargetProtocol",
-    evaluator: Optional["EvaluatorProtocol"] = None,
+    config: "BatDetect2Config",
+    evaluator: "EvaluatorProtocol",
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
 ) -> Trainer:
-    trainer_conf = conf.train.trainer
+    trainer_conf = config.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )

     train_logger = build_logger(
-        conf.train.logger,
+        config.train.logger,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
     )

     train_logger.log_hyperparams(
-        conf.model_dump(
+        config.model_dump(
             mode="json",
             exclude_none=True,
         )
@@ -184,11 +152,13 @@ def build_trainer(
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         logger=train_logger,
-        callbacks=build_trainer_callbacks(
-            targets,
-            evaluator=evaluator,
-            checkpoint_dir=checkpoint_dir,
-            experiment_name=experiment_name,
-            run_name=run_name,
-        ),
+        callbacks=[
+            build_checkpoint_callback(
+                config=config.train.checkpoints,
+                checkpoint_dir=checkpoint_dir,
+                experiment_name=experiment_name,
+                run_name=run_name,
+            ),
+            ValidationMetrics(evaluator),
+        ],
     )
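The trainer builder itself no longer takes targets and now requires the evaluator explicitly. A hedged sketch of a call with the new signature, where config is a BatDetect2Config and evaluator an EvaluatorProtocol built earlier in the calling code (for example via build_evaluator, as in the train() hunk above); the directory and names are illustrative:

    from pathlib import Path

    # config and evaluator are assumed to exist in the calling code, as in
    # train() above; only the keyword overrides below are new here.
    trainer = build_trainer(
        config,
        evaluator=evaluator,
        checkpoint_dir=Path("outputs") / "checkpoints",  # optional override
        experiment_name="my-experiment",  # illustrative
        run_name="baseline",  # illustrative
    )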