Mirror of https://github.com/macaodha/batdetect2.git

Compare commits: 24dcc47e73...2d796394f6

No commits in common. "24dcc47e73b1c578d4957bc60534c80707e1f0e5" and "2d796394f69e5ada5fa36ae5a1a867c538447786" have entirely different histories.
batdetect2/train/__init__.py

@@ -1,4 +1,3 @@
-from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
 from batdetect2.train.config import (
     TrainingConfig,
     load_train_config,
@@ -7,7 +6,7 @@ from batdetect2.train.lightning import (
     TrainingModule,
     load_model_from_checkpoint,
 )
-from batdetect2.train.train import build_trainer, train
+from batdetect2.train.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train

 __all__ = [
     "DEFAULT_CHECKPOINT_DIR",
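Note: the symbol stays in the package's __all__ on both sides; only its defining module changes, so package-level imports are unaffected. A minimal check (a sketch, not part of the diff):

# Works on either side of this compare: the name is re-exported from
# the batdetect2.train package regardless of which module defines it.
from batdetect2.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train

print(DEFAULT_CHECKPOINT_DIR)  # outputs/checkpoints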
batdetect2/train/checkpoints.py (deleted)

@@ -1,47 +0,0 @@
-from pathlib import Path
-from typing import Optional
-
-from lightning.pytorch.callbacks import Callback, ModelCheckpoint
-
-from batdetect2.core import BaseConfig
-
-__all__ = [
-    "CheckpointConfig",
-    "build_checkpoint_callback",
-]
-
-DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
-
-
-class CheckpointConfig(BaseConfig):
-    checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
-    monitor: str = "classification/mean_average_precision"
-    mode: str = "max"
-    save_top_k: int = 1
-    filename: Optional[str] = None
-
-
-def build_checkpoint_callback(
-    config: Optional[CheckpointConfig] = None,
-    checkpoint_dir: Optional[Path] = None,
-    experiment_name: Optional[str] = None,
-    run_name: Optional[str] = None,
-) -> Callback:
-    config = config or CheckpointConfig()
-
-    if checkpoint_dir is None:
-        checkpoint_dir = config.checkpoint_dir
-
-    if experiment_name is not None:
-        checkpoint_dir = checkpoint_dir / experiment_name
-
-    if run_name is not None:
-        checkpoint_dir = checkpoint_dir / run_name
-
-    return ModelCheckpoint(
-        dirpath=str(checkpoint_dir),
-        save_top_k=config.save_top_k,
-        monitor=config.monitor,
-        mode=config.mode,
-        filename=config.filename,
-    )
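For reference, the deleted helper resolved the checkpoint directory (base dir, then optional experiment and run subdirectories) before delegating to Lightning's ModelCheckpoint. A minimal usage sketch, available only on the 24dcc47e73 side; the experiment and run names are illustrative:

from batdetect2.train.checkpoints import (
    CheckpointConfig,
    build_checkpoint_callback,
)

# "my_experiment" and "run_001" are hypothetical labels.
callback = build_checkpoint_callback(
    config=CheckpointConfig(save_top_k=3),
    experiment_name="my_experiment",
    run_name="run_001",
)
# Checkpoints land in outputs/checkpoints/my_experiment/run_001,
# keeping the 3 best by classification/mean_average_precision (mode="max").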
batdetect2/train/config.py

@@ -6,7 +6,6 @@ from soundevent import data
 from batdetect2.core.configs import BaseConfig, load_config
 from batdetect2.evaluate.config import EvaluationConfig
 from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
-from batdetect2.train.checkpoints import CheckpointConfig
 from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
 from batdetect2.train.labels import LabelConfig
 from batdetect2.train.losses import LossConfig
@@ -52,7 +51,6 @@ class TrainingConfig(BaseConfig):
     logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
     validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
-    checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)


 def load_train_config(
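With the nested checkpoints field (24dcc47e73 side), checkpointing is configured alongside the rest of training. A hedged sketch, assuming the remaining TrainingConfig fields can fall back to their default_factory values:

from batdetect2.train.checkpoints import CheckpointConfig
from batdetect2.train.config import TrainingConfig

# Override only the checkpoint settings; save_top_k=5 is illustrative.
config = TrainingConfig(
    checkpoints=CheckpointConfig(save_top_k=5),
)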
batdetect2/train/train.py

@@ -1,8 +1,9 @@
 from collections.abc import Sequence
 from pathlib import Path
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional

 from lightning import Trainer, seed_everything
+from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data

@@ -12,7 +13,6 @@ from batdetect2.logging import build_logger
 from batdetect2.preprocess import build_preprocessor
 from batdetect2.targets import build_targets
 from batdetect2.train.callbacks import ValidationMetrics
-from batdetect2.train.checkpoints import build_checkpoint_callback
 from batdetect2.train.dataset import build_train_loader, build_val_loader
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import build_training_module
@@ -32,6 +32,8 @@ __all__ = [
     "train",
 ]

+DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
+

 def train(
     train_annotations: Sequence[data.ClipAnnotation],
@@ -102,6 +104,7 @@ def train(

     trainer = trainer or build_trainer(
         config,
+        targets=targets,
         evaluator=build_evaluator(
             config.train.validation,
             targets=targets,
@@ -121,29 +124,58 @@ def train(
     logger.info("Training complete.")


+def build_trainer_callbacks(
+    targets: "TargetProtocol",
+    evaluator: Optional["EvaluatorProtocol"] = None,
+    checkpoint_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+) -> List[Callback]:
+    if checkpoint_dir is None:
+        checkpoint_dir = DEFAULT_CHECKPOINT_DIR
+
+    if experiment_name is not None:
+        checkpoint_dir = checkpoint_dir / experiment_name
+
+    if run_name is not None:
+        checkpoint_dir = checkpoint_dir / run_name
+
+    evaluator = evaluator or build_evaluator(targets=targets)
+
+    return [
+        ModelCheckpoint(
+            dirpath=str(checkpoint_dir),
+            save_top_k=1,
+            monitor="classification/mean_average_precision",
+            mode="max",
+        ),
+        ValidationMetrics(evaluator),
+    ]
+
+
 def build_trainer(
-    config: "BatDetect2Config",
-    evaluator: "EvaluatorProtocol",
+    conf: "BatDetect2Config",
+    targets: "TargetProtocol",
+    evaluator: Optional["EvaluatorProtocol"] = None,
     checkpoint_dir: Optional[Path] = None,
     log_dir: Optional[Path] = None,
     experiment_name: Optional[str] = None,
     run_name: Optional[str] = None,
 ) -> Trainer:
-    trainer_conf = config.train.trainer
+    trainer_conf = conf.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )

     train_logger = build_logger(
-        config.train.logger,
+        conf.train.logger,
         log_dir=log_dir,
         experiment_name=experiment_name,
         run_name=run_name,
     )

     train_logger.log_hyperparams(
-        config.model_dump(
+        conf.model_dump(
             mode="json",
             exclude_none=True,
         )
@@ -152,13 +184,11 @@ def build_trainer(
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         logger=train_logger,
-        callbacks=[
-            build_checkpoint_callback(
-                config=config.train.checkpoints,
-                checkpoint_dir=checkpoint_dir,
-                experiment_name=experiment_name,
-                run_name=run_name,
-            ),
-            ValidationMetrics(evaluator),
-        ],
+        callbacks=build_trainer_callbacks(
+            targets,
+            evaluator=evaluator,
+            checkpoint_dir=checkpoint_dir,
+            experiment_name=experiment_name,
+            run_name=run_name,
+        ),
     )
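After this change, callers hand build_trainer the targets and may omit the evaluator, since build_trainer_callbacks falls back to build_evaluator(targets=targets). A sketch of the 2d796394f6-side call, assuming config and targets objects like those assembled inside train(); the run name is illustrative:

# config: BatDetect2Config and targets: TargetProtocol, built earlier in train().
trainer = build_trainer(
    config,
    targets=targets,
    run_name="run_001",  # checkpoints nest under DEFAULT_CHECKPOINT_DIR
)

# Or assemble just the callback list:
callbacks = build_trainer_callbacks(targets, run_name="run_001")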