From 5a974711b09d7983e788db039a258cc834dc7646 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 5 May 2026 14:09:53 +0100 Subject: [PATCH] feat: log training provenance artifacts --- src/batdetect2/api_v2.py | 11 +- src/batdetect2/cli/finetune.py | 41 ++++++-- src/batdetect2/cli/train.py | 24 ++++- src/batdetect2/logging.py | 85 ++++++++++++++- src/batdetect2/train/__init__.py | 14 +++ src/batdetect2/train/logging.py | 164 +++++++++++++++++++++++++++++ src/batdetect2/train/train.py | 59 ++++++++--- tests/test_train/test_lightning.py | 57 ++++++++++ 8 files changed, 426 insertions(+), 29 deletions(-) create mode 100644 src/batdetect2/train/logging.py diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index ce24d95..143dfe9 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -15,7 +15,11 @@ if TYPE_CHECKING: from batdetect2.data import Dataset from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol from batdetect2.inference import InferenceConfig - from batdetect2.logging import AppLoggingConfig, LoggerConfig + from batdetect2.logging import ( + AppLoggingConfig, + LoggerConfig, + LoggingCallback, + ) from batdetect2.models import Model, ModelConfig from batdetect2.outputs import ( OutputFormatConfig, @@ -35,6 +39,7 @@ if TYPE_CHECKING: TargetProtocol, ) from batdetect2.train import TrainingConfig + from batdetect2.train.logging import TrainLoggingContext DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" @@ -106,6 +111,7 @@ class BatDetect2API: audio_config: AudioConfig | None = None, train_config: TrainingConfig | None = None, logger_config: LoggerConfig | None = None, + logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (), ): from batdetect2.train import run_train @@ -130,6 +136,7 @@ class BatDetect2API: train_config=train_config or self.train_config, audio_config=audio_config or self.audio_config, logger_config=logger_config or self.logging_config.train, + logging_callbacks=logging_callbacks, ) self.model.eval() return self @@ -153,6 +160,7 @@ class BatDetect2API: audio_config: AudioConfig | None = None, train_config: TrainingConfig | None = None, logger_config: LoggerConfig | None = None, + logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (), ) -> "BatDetect2API": """Fine-tune from a checkpoint using a new target definition.""" from batdetect2.evaluate import build_evaluator @@ -231,6 +239,7 @@ class BatDetect2API: audio_config=api.audio_config, train_config=api.train_config, logger_config=logger_config or api.logging_config.train, + logging_callbacks=logging_callbacks, ) api.model.eval() return api diff --git a/src/batdetect2/cli/finetune.py b/src/batdetect2/cli/finetune.py index 45b6efb..467b91d 100644 --- a/src/batdetect2/cli/finetune.py +++ b/src/batdetect2/cli/finetune.py @@ -126,10 +126,14 @@ def finetune_command( """Fine-tune a BatDetect2 checkpoint on a new target definition.""" from batdetect2.api_v2 import BatDetect2API from batdetect2.audio import AudioConfig - from batdetect2.data import load_dataset_from_config + from batdetect2.data import load_dataset, load_dataset_config from batdetect2.logging import AppLoggingConfig from batdetect2.targets import TargetConfig from batdetect2.train import TrainingConfig + from batdetect2.train.logging import ( + DatasetConfigArtifact, + DatasetConfigArtifactLogging, + ) logger.info("Initiating fine-tuning process...") @@ -148,16 +152,34 @@ def finetune_command( else None ) - train_annotations = load_dataset_from_config( - train_dataset, - base_dir=base_dir, + train_dataset_conf = load_dataset_config(train_dataset) + train_annotations = load_dataset(train_dataset_conf, base_dir=base_dir) + + val_dataset_conf = ( + load_dataset_config(val_dataset) if val_dataset else None ) - val_annotations = None - if val_dataset is not None: - val_annotations = load_dataset_from_config( - val_dataset, - base_dir=base_dir, + val_annotations = ( + load_dataset(val_dataset_conf, base_dir=base_dir) + if val_dataset_conf + else None + ) + + logging_callbacks = [ + DatasetConfigArtifactLogging( + train_dataset_config=DatasetConfigArtifact( + filename="train_dataset.yaml", + config=train_dataset_conf, + ), + val_dataset_config=( + DatasetConfigArtifact( + filename="val_dataset.yaml", + config=val_dataset_conf, + ) + if val_dataset_conf + else None + ), ) + ] api = BatDetect2API.from_checkpoint( model_path, @@ -185,4 +207,5 @@ def finetune_command( train_config=train_conf, audio_config=audio_conf, logger_config=logging_conf.train if logging_conf is not None else None, + logging_callbacks=logging_callbacks, ) diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 1b428b7..a23687e 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -145,7 +145,7 @@ def train_command( """ from batdetect2.api_v2 import BatDetect2API from batdetect2.audio import AudioConfig - from batdetect2.data import load_dataset_from_config + from batdetect2.data import load_dataset_config, load_dataset_from_config from batdetect2.evaluate import EvaluationConfig from batdetect2.inference import InferenceConfig from batdetect2.logging import AppLoggingConfig @@ -153,6 +153,10 @@ def train_command( from batdetect2.outputs import OutputsConfig from batdetect2.targets import TargetConfig from batdetect2.train import TrainingConfig + from batdetect2.train.logging import ( + DatasetConfigArtifact, + DatasetConfigArtifactLogging, + ) logger.info("Initiating training process...") @@ -222,6 +226,23 @@ def train_command( logger.info("Configuration and data loaded. Starting training...") + logging_callbacks = [ + DatasetConfigArtifactLogging( + train_dataset_config=DatasetConfigArtifact( + filename="train_dataset.yaml", + config=load_dataset_config(train_dataset), + ), + val_dataset_config=( + DatasetConfigArtifact( + filename="val_dataset.yaml", + config=load_dataset_config(val_dataset), + ) + if val_dataset is not None + else None + ), + ) + ] + if model_path is not None and model_conf is not None: raise click.UsageError( "--model-config cannot be used with --model. " @@ -267,4 +288,5 @@ def train_command( experiment_name=experiment_name, run_name=run_name, seed=seed, + logging_callbacks=logging_callbacks, ) diff --git a/src/batdetect2/logging.py b/src/batdetect2/logging.py index 6d14980..6376ae7 100644 --- a/src/batdetect2/logging.py +++ b/src/batdetect2/logging.py @@ -24,9 +24,7 @@ from batdetect2.core.configs import BaseConfig if TYPE_CHECKING: import numpy as np import pandas as pd - from lightning.pytorch.loggers import ( - Logger, - ) + from lightning.pytorch.loggers import Logger from matplotlib.figure import Figure from soundevent import data @@ -40,11 +38,15 @@ __all__ = [ "DVCLiveConfig", "LoggerConfig", "MLFlowLoggerConfig", + "LoggingCallback", "TensorBoardLoggerConfig", "build_logger", "enable_logging", "get_image_logger", "get_table_logger", + "log_artifact_file", + "log_config_artifact", + "log_csv_artifact", ] @@ -120,6 +122,18 @@ class LoggerBuilder(Protocol, Generic[T]): ) -> Logger: ... +LoggingContext = TypeVar("LoggingContext", contravariant=True) + + +class LoggingCallback(Protocol, Generic[LoggingContext]): + def run( + self, + logger: Logger, + artifact_path: Path, + context: LoggingContext, + ) -> None: ... + + def create_dvclive_logger( config: DVCLiveConfig, log_dir: Path | None = None, @@ -273,6 +287,71 @@ def build_logger( ) +def log_artifact_file( + runtime_logger: Logger, + path: Path, + artifact_path: str = "artifacts", +) -> None: + from lightning.pytorch.loggers import ( + CSVLogger, + MLFlowLogger, + TensorBoardLogger, + ) + + if isinstance(runtime_logger, MLFlowLogger): + runtime_logger.experiment.log_artifact( # type: ignore[call-arg] + local_path=str(path), + artifact_path=artifact_path, + run_id=runtime_logger.run_id, + ) + return + + experiment = getattr(runtime_logger, "experiment", None) + if experiment is not None and hasattr(experiment, "log_artifact"): + experiment.log_artifact(path=path, name=path.name, copy=True) + return + + if isinstance(runtime_logger, (CSVLogger, TensorBoardLogger)): + return + + logger.warning( + "Skipping artifact logging for unsupported logger type {logger_type}", + logger_type=type(runtime_logger).__name__, + ) + + +def log_config_artifact( + logger: Logger, + config: BaseConfig, + filename: str, + artifact_path: Path, +) -> None: + artifact_path.mkdir(parents=True, exist_ok=True) + path = artifact_path / filename + path.write_text(config.to_yaml_string()) + log_artifact_file( + logger, + path, + artifact_path=artifact_path.name, + ) + + +def log_csv_artifact( + logger: Logger, + df: pd.DataFrame, + filename: str, + artifact_path: Path, +) -> None: + artifact_path.mkdir(parents=True, exist_ok=True) + path = artifact_path / filename + df.to_csv(path, index=False) + log_artifact_file( + logger, + path, + artifact_path=artifact_path.name, + ) + + PlotLogger = Callable[[str, "Figure", int], None] diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index 72f3e77..ee0731e 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -4,10 +4,24 @@ from batdetect2.train.lightning import ( TrainingModule, load_model_from_checkpoint, ) +from batdetect2.train.logging import ( + ConfigHyperparameterLogging, + DatasetConfigArtifact, + DatasetConfigArtifactLogging, + DataSummaryArtifactLogging, + TargetConfigArtifactLogging, + TrainLoggingContext, +) from batdetect2.train.train import build_trainer, run_train __all__ = [ + "ConfigHyperparameterLogging", + "DataSummaryArtifactLogging", "DEFAULT_CHECKPOINT_DIR", + "DatasetConfigArtifact", + "DatasetConfigArtifactLogging", + "TargetConfigArtifactLogging", + "TrainLoggingContext", "TrainingConfig", "TrainingModule", "build_trainer", diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py new file mode 100644 index 0000000..75829f4 --- /dev/null +++ b/src/batdetect2/train/logging.py @@ -0,0 +1,164 @@ +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass +from pathlib import Path + +import pandas as pd +from lightning.pytorch.loggers import Logger +from soundevent import data + +from batdetect2.audio import AudioConfig +from batdetect2.core.configs import BaseConfig +from batdetect2.data import Dataset, compute_class_summary +from batdetect2.logging import log_config_artifact, log_csv_artifact +from batdetect2.models import ModelConfig +from batdetect2.targets import TargetConfig, TargetProtocol +from batdetect2.train.config import TrainingConfig + +__all__ = [ + "ConfigHyperparameterLogging", + "DataSummaryArtifactLogging", + "DatasetConfigArtifact", + "DatasetConfigArtifactLogging", + "TargetConfigArtifactLogging", + "TrainLoggingContext", +] + + +@dataclass(frozen=True) +class TrainLoggingContext: + model_config: ModelConfig + train_config: TrainingConfig + audio_config: AudioConfig + targets: TargetProtocol + train_dataset: Dataset + val_dataset: Dataset | None + + +@dataclass(frozen=True) +class DatasetConfigArtifact: + filename: str + config: BaseConfig + + +class ConfigHyperparameterLogging: + def run( + self, + logger: Logger, + artifact_path: Path, + context: TrainLoggingContext, + ) -> None: + logger.log_hyperparams( + { + "model": context.model_config.model_dump( + mode="json", + exclude_none=True, + ), + "training": context.train_config.model_dump( + mode="json", + exclude_none=True, + ), + "audio": context.audio_config.model_dump( + mode="json", + exclude_none=True, + ), + } + ) + + +class TargetConfigArtifactLogging: + def run( + self, + logger: Logger, + artifact_path: Path, + context: TrainLoggingContext, + ) -> None: + targets_config = TargetConfig.model_validate( + context.targets.get_config() + ) + log_config_artifact( + logger, + targets_config, + filename="targets.yaml", + artifact_path=artifact_path / "training_artifacts", + ) + + +class DatasetConfigArtifactLogging: + def __init__( + self, + train_dataset_config: DatasetConfigArtifact, + val_dataset_config: DatasetConfigArtifact | None = None, + ): + self.train_dataset_config = train_dataset_config + self.val_dataset_config = val_dataset_config + + def run( + self, + logger: Logger, + artifact_path: Path, + context: TrainLoggingContext, + ) -> None: + training_artifact_path = artifact_path / "training_artifacts" + + log_config_artifact( + logger, + self.train_dataset_config.config, + filename=self.train_dataset_config.filename, + artifact_path=training_artifact_path, + ) + + if self.val_dataset_config is not None: + log_config_artifact( + logger, + self.val_dataset_config.config, + filename=self.val_dataset_config.filename, + artifact_path=training_artifact_path, + ) + + +class DataSummaryArtifactLogging: + def run( + self, + logger: Logger, + artifact_path: Path, + context: TrainLoggingContext, + ) -> None: + training_artifact_path = artifact_path / "training_artifacts" + + log_csv_artifact( + logger, + _compute_class_summary_or_empty( + context.train_dataset, + context.targets, + ), + filename="train_class_summary.csv", + artifact_path=training_artifact_path, + ) + + if context.val_dataset is not None: + log_csv_artifact( + logger, + _compute_class_summary_or_empty( + context.val_dataset, + context.targets, + ), + filename="val_class_summary.csv", + artifact_path=training_artifact_path, + ) + + +def _compute_class_summary_or_empty( + dataset: Sequence[data.ClipAnnotation], + targets: TargetProtocol, +) -> pd.DataFrame: + try: + return compute_class_summary(dataset, targets) + except KeyError as error: + if error.args != ("class_name",): + raise + + return pd.DataFrame( + columns=["num calls", "num recordings", "duration", "call_rate"] + ) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index dcbcbb8..0387e80 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import Optional from lightning import Trainer, seed_everything +from lightning.pytorch.loggers import Logger from loguru import logger from soundevent import data @@ -10,6 +11,7 @@ from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader from batdetect2.evaluate import EvaluatorProtocol, build_evaluator from batdetect2.logging import ( LoggerConfig, + LoggingCallback, TensorBoardLoggerConfig, build_logger, ) @@ -28,6 +30,12 @@ from batdetect2.train.config import TrainingConfig from batdetect2.train.dataset import build_train_loader, build_val_loader from batdetect2.train.labels import build_clip_labeler from batdetect2.train.lightning import build_training_module +from batdetect2.train.logging import ( + ConfigHyperparameterLogging, + DataSummaryArtifactLogging, + TargetConfigArtifactLogging, + TrainLoggingContext, +) from batdetect2.train.types import ClipLabeller __all__ = [ @@ -36,6 +44,9 @@ __all__ = [ ] +DEFAULT_LOG_DIR = Path("outputs") / "logs" + + def run_train( train_annotations: Sequence[data.ClipAnnotation], val_annotations: Sequence[data.ClipAnnotation] | None = None, @@ -59,6 +70,7 @@ def run_train( num_epochs: int | None = None, run_name: str | None = None, seed: int | None = None, + logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (), ): if seed is not None: seed_everything(seed) @@ -148,15 +160,44 @@ def run_train( roi_mapper=roi_mapper, ) + train_logger = build_logger( + logger_config or TensorBoardLoggerConfig(), + log_dir=log_dir, + experiment_name=experiment_name, + run_name=run_name, + ) + root_artifact_path = ( + Path(log_dir) if log_dir is not None else DEFAULT_LOG_DIR + ) + root_artifact_path.mkdir(parents=True, exist_ok=True) + + logging_context = TrainLoggingContext( + model_config=model_config, + train_config=train_config, + audio_config=audio_config, + targets=targets, + train_dataset=train_annotations, + val_dataset=val_annotations, + ) + + resolved_logging_callbacks = ( + ConfigHyperparameterLogging(), + TargetConfigArtifactLogging(), + DataSummaryArtifactLogging(), + *logging_callbacks, + ) + + for callback in resolved_logging_callbacks: + callback.run(train_logger, root_artifact_path, logging_context) + trainer = trainer or build_trainer( train_config, - logger_config=logger_config, + train_logger=train_logger, evaluator=evaluator, targets=targets, roi_mapper=roi_mapper, checkpoint_dir=checkpoint_dir, num_epochs=num_epochs, - log_dir=log_dir, experiment_name=experiment_name, run_name=run_name, ) @@ -223,12 +264,11 @@ def _validate_model_compatibility( def build_trainer( config: TrainingConfig, - logger_config: LoggerConfig | None, + train_logger: Logger, evaluator: "EvaluatorProtocol", targets: "TargetProtocol", roi_mapper: "ROIMapperProtocol", checkpoint_dir: Path | None = None, - log_dir: Path | None = None, experiment_name: str | None = None, run_name: str | None = None, num_epochs: int | None = None, @@ -239,20 +279,9 @@ def build_trainer( config=lambda: trainer_conf.to_yaml_string(exclude_none=True), ) - train_logger = build_logger( - logger_config or TensorBoardLoggerConfig(), - log_dir=log_dir, - experiment_name=experiment_name, - run_name=run_name, - ) - if num_epochs is not None: trainer_conf.max_epochs = num_epochs - train_logger.log_hyperparams( - config.model_dump(mode="json", exclude_none=True) - ) - train_config = trainer_conf.model_dump(exclude_none=True) return Trainer( diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index bac0b50..c6a9ccd 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -22,6 +22,10 @@ from batdetect2.train import ( load_model_from_checkpoint, run_train, ) +from batdetect2.train.logging import ( + DatasetConfigArtifact, + DatasetConfigArtifactLogging, +) from batdetect2.train.optimizers import AdamOptimizerConfig from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig from batdetect2.train.train import build_training_module @@ -369,6 +373,59 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> ( assert rebuilt_model.dimension_names == ["width", "height"] +@pytest.mark.slow +def test_run_train_logs_training_artifacts( + tmp_path: Path, + example_annotations: list[data.ClipAnnotation], + example_dataset, +) -> None: + train_config = TrainingConfig.model_validate( + { + "trainer": { + "limit_train_batches": 1, + "limit_val_batches": 1, + "log_every_n_steps": 1, + }, + "train_loader": { + "batch_size": 1, + "augmentations": {"enabled": False}, + }, + } + ) + + run_train( + train_annotations=example_annotations[:1], + val_annotations=example_annotations[:1], + train_config=train_config, + num_epochs=1, + train_workers=0, + val_workers=0, + checkpoint_dir=tmp_path / "checkpoints", + log_dir=tmp_path / "logs", + seed=0, + logging_callbacks=[ + DatasetConfigArtifactLogging( + train_dataset_config=DatasetConfigArtifact( + filename="train_dataset.yaml", + config=example_dataset, + ), + val_dataset_config=DatasetConfigArtifact( + filename="val_dataset.yaml", + config=example_dataset, + ), + ) + ], + ) + + artifact_root = next((tmp_path / "logs").rglob("training_artifacts")) + + assert (artifact_root / "targets.yaml").exists() + assert (artifact_root / "train_dataset.yaml").exists() + assert (artifact_root / "val_dataset.yaml").exists() + assert (artifact_root / "train_class_summary.csv").exists() + assert (artifact_root / "val_class_summary.csv").exists() + + def test_run_train_rejects_incompatible_model_config( example_annotations: list[data.ClipAnnotation], ) -> None: