mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 09:29:33 +01:00
Compare commits
No commits in common. "24dcc47e73b1c578d4957bc60534c80707e1f0e5" and "2d796394f69e5ada5fa36ae5a1a867c538447786" have entirely different histories.
24dcc47e73
...
2d796394f6
@ -1,4 +1,3 @@
|
|||||||
from batdetect2.train.checkpoints import DEFAULT_CHECKPOINT_DIR
|
|
||||||
from batdetect2.train.config import (
|
from batdetect2.train.config import (
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
load_train_config,
|
load_train_config,
|
||||||
@ -7,7 +6,7 @@ from batdetect2.train.lightning import (
|
|||||||
TrainingModule,
|
TrainingModule,
|
||||||
load_model_from_checkpoint,
|
load_model_from_checkpoint,
|
||||||
)
|
)
|
||||||
from batdetect2.train.train import build_trainer, train
|
from batdetect2.train.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DEFAULT_CHECKPOINT_DIR",
|
"DEFAULT_CHECKPOINT_DIR",
|
||||||
|
|||||||
@ -1,47 +0,0 @@
|
|||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
|
||||||
|
|
||||||
from batdetect2.core import BaseConfig
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"CheckpointConfig",
|
|
||||||
"build_checkpoint_callback",
|
|
||||||
]
|
|
||||||
|
|
||||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
|
||||||
|
|
||||||
|
|
||||||
class CheckpointConfig(BaseConfig):
|
|
||||||
checkpoint_dir: Path = DEFAULT_CHECKPOINT_DIR
|
|
||||||
monitor: str = "classification/mean_average_precision"
|
|
||||||
mode: str = "max"
|
|
||||||
save_top_k: int = 1
|
|
||||||
filename: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
def build_checkpoint_callback(
|
|
||||||
config: Optional[CheckpointConfig] = None,
|
|
||||||
checkpoint_dir: Optional[Path] = None,
|
|
||||||
experiment_name: Optional[str] = None,
|
|
||||||
run_name: Optional[str] = None,
|
|
||||||
) -> Callback:
|
|
||||||
config = config or CheckpointConfig()
|
|
||||||
|
|
||||||
if checkpoint_dir is None:
|
|
||||||
checkpoint_dir = config.checkpoint_dir
|
|
||||||
|
|
||||||
if experiment_name is not None:
|
|
||||||
checkpoint_dir = checkpoint_dir / experiment_name
|
|
||||||
|
|
||||||
if run_name is not None:
|
|
||||||
checkpoint_dir = checkpoint_dir / run_name
|
|
||||||
|
|
||||||
return ModelCheckpoint(
|
|
||||||
dirpath=str(checkpoint_dir),
|
|
||||||
save_top_k=config.save_top_k,
|
|
||||||
monitor=config.monitor,
|
|
||||||
mode=config.mode,
|
|
||||||
filename=config.filename,
|
|
||||||
)
|
|
||||||
@ -6,7 +6,6 @@ from soundevent import data
|
|||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate.config import EvaluationConfig
|
from batdetect2.evaluate.config import EvaluationConfig
|
||||||
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
|
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
|
||||||
from batdetect2.train.checkpoints import CheckpointConfig
|
|
||||||
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
|
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
|
||||||
from batdetect2.train.labels import LabelConfig
|
from batdetect2.train.labels import LabelConfig
|
||||||
from batdetect2.train.losses import LossConfig
|
from batdetect2.train.losses import LossConfig
|
||||||
@ -52,7 +51,6 @@ class TrainingConfig(BaseConfig):
|
|||||||
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
|
||||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||||
checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def load_train_config(
|
def load_train_config(
|
||||||
|
|||||||
@ -1,8 +1,9 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Optional
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
from lightning import Trainer, seed_everything
|
from lightning import Trainer, seed_everything
|
||||||
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -12,7 +13,6 @@ from batdetect2.logging import build_logger
|
|||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.checkpoints import build_checkpoint_callback
|
|
||||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
@ -32,6 +32,8 @@ __all__ = [
|
|||||||
"train",
|
"train",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
@ -102,6 +104,7 @@ def train(
|
|||||||
|
|
||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
config,
|
config,
|
||||||
|
targets=targets,
|
||||||
evaluator=build_evaluator(
|
evaluator=build_evaluator(
|
||||||
config.train.validation,
|
config.train.validation,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
@ -121,29 +124,58 @@ def train(
|
|||||||
logger.info("Training complete.")
|
logger.info("Training complete.")
|
||||||
|
|
||||||
|
|
||||||
|
def build_trainer_callbacks(
|
||||||
|
targets: "TargetProtocol",
|
||||||
|
evaluator: Optional["EvaluatorProtocol"] = None,
|
||||||
|
checkpoint_dir: Optional[Path] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
|
) -> List[Callback]:
|
||||||
|
if checkpoint_dir is None:
|
||||||
|
checkpoint_dir = DEFAULT_CHECKPOINT_DIR
|
||||||
|
|
||||||
|
if experiment_name is not None:
|
||||||
|
checkpoint_dir = checkpoint_dir / experiment_name
|
||||||
|
|
||||||
|
if run_name is not None:
|
||||||
|
checkpoint_dir = checkpoint_dir / run_name
|
||||||
|
|
||||||
|
evaluator = evaluator or build_evaluator(targets=targets)
|
||||||
|
|
||||||
|
return [
|
||||||
|
ModelCheckpoint(
|
||||||
|
dirpath=str(checkpoint_dir),
|
||||||
|
save_top_k=1,
|
||||||
|
monitor="classification/mean_average_precision",
|
||||||
|
mode="max",
|
||||||
|
),
|
||||||
|
ValidationMetrics(evaluator),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_trainer(
|
def build_trainer(
|
||||||
config: "BatDetect2Config",
|
conf: "BatDetect2Config",
|
||||||
evaluator: "EvaluatorProtocol",
|
targets: "TargetProtocol",
|
||||||
|
evaluator: Optional["EvaluatorProtocol"] = None,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = None,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = config.train.trainer
|
trainer_conf = conf.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
train_logger = build_logger(
|
train_logger = build_logger(
|
||||||
config.train.logger,
|
conf.train.logger,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
train_logger.log_hyperparams(
|
train_logger.log_hyperparams(
|
||||||
config.model_dump(
|
conf.model_dump(
|
||||||
mode="json",
|
mode="json",
|
||||||
exclude_none=True,
|
exclude_none=True,
|
||||||
)
|
)
|
||||||
@ -152,13 +184,11 @@ def build_trainer(
|
|||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
logger=train_logger,
|
logger=train_logger,
|
||||||
callbacks=[
|
callbacks=build_trainer_callbacks(
|
||||||
build_checkpoint_callback(
|
targets,
|
||||||
config=config.train.checkpoints,
|
evaluator=evaluator,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
),
|
),
|
||||||
ValidationMetrics(evaluator),
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user