feat: log training provenance artifacts

This commit is contained in:
mbsantiago 2026-05-05 14:09:53 +01:00
parent 7b2699786f
commit 5a974711b0
8 changed files with 426 additions and 29 deletions

View File

@ -15,7 +15,11 @@ if TYPE_CHECKING:
from batdetect2.data import Dataset
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig, LoggerConfig
from batdetect2.logging import (
AppLoggingConfig,
LoggerConfig,
LoggingCallback,
)
from batdetect2.models import Model, ModelConfig
from batdetect2.outputs import (
OutputFormatConfig,
@ -35,6 +39,7 @@ if TYPE_CHECKING:
TargetProtocol,
)
from batdetect2.train import TrainingConfig
from batdetect2.train.logging import TrainLoggingContext
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
@ -106,6 +111,7 @@ class BatDetect2API:
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
):
from batdetect2.train import run_train
@ -130,6 +136,7 @@ class BatDetect2API:
train_config=train_config or self.train_config,
audio_config=audio_config or self.audio_config,
logger_config=logger_config or self.logging_config.train,
logging_callbacks=logging_callbacks,
)
self.model.eval()
return self
@ -153,6 +160,7 @@ class BatDetect2API:
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
) -> "BatDetect2API":
"""Fine-tune from a checkpoint using a new target definition."""
from batdetect2.evaluate import build_evaluator
@ -231,6 +239,7 @@ class BatDetect2API:
audio_config=api.audio_config,
train_config=api.train_config,
logger_config=logger_config or api.logging_config.train,
logging_callbacks=logging_callbacks,
)
api.model.eval()
return api

View File

@ -126,10 +126,14 @@ def finetune_command(
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset_from_config
from batdetect2.data import load_dataset, load_dataset_config
from batdetect2.logging import AppLoggingConfig
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig
from batdetect2.train.logging import (
DatasetConfigArtifact,
DatasetConfigArtifactLogging,
)
logger.info("Initiating fine-tuning process...")
@ -148,16 +152,34 @@ def finetune_command(
else None
)
train_annotations = load_dataset_from_config(
train_dataset,
base_dir=base_dir,
train_dataset_conf = load_dataset_config(train_dataset)
train_annotations = load_dataset(train_dataset_conf, base_dir=base_dir)
val_dataset_conf = (
load_dataset_config(val_dataset) if val_dataset else None
)
val_annotations = None
if val_dataset is not None:
val_annotations = load_dataset_from_config(
val_dataset,
base_dir=base_dir,
val_annotations = (
load_dataset(val_dataset_conf, base_dir=base_dir)
if val_dataset_conf
else None
)
logging_callbacks = [
DatasetConfigArtifactLogging(
train_dataset_config=DatasetConfigArtifact(
filename="train_dataset.yaml",
config=train_dataset_conf,
),
val_dataset_config=(
DatasetConfigArtifact(
filename="val_dataset.yaml",
config=val_dataset_conf,
)
if val_dataset_conf
else None
),
)
]
api = BatDetect2API.from_checkpoint(
model_path,
@ -185,4 +207,5 @@ def finetune_command(
train_config=train_conf,
audio_config=audio_conf,
logger_config=logging_conf.train if logging_conf is not None else None,
logging_callbacks=logging_callbacks,
)

View File

@ -145,7 +145,7 @@ def train_command(
"""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.data import load_dataset_from_config
from batdetect2.data import load_dataset_config, load_dataset_from_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
@ -153,6 +153,10 @@ def train_command(
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import TargetConfig
from batdetect2.train import TrainingConfig
from batdetect2.train.logging import (
DatasetConfigArtifact,
DatasetConfigArtifactLogging,
)
logger.info("Initiating training process...")
@ -222,6 +226,23 @@ def train_command(
logger.info("Configuration and data loaded. Starting training...")
logging_callbacks = [
DatasetConfigArtifactLogging(
train_dataset_config=DatasetConfigArtifact(
filename="train_dataset.yaml",
config=load_dataset_config(train_dataset),
),
val_dataset_config=(
DatasetConfigArtifact(
filename="val_dataset.yaml",
config=load_dataset_config(val_dataset),
)
if val_dataset is not None
else None
),
)
]
if model_path is not None and model_conf is not None:
raise click.UsageError(
"--model-config cannot be used with --model. "
@ -267,4 +288,5 @@ def train_command(
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
logging_callbacks=logging_callbacks,
)

View File

@ -24,9 +24,7 @@ from batdetect2.core.configs import BaseConfig
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
Logger,
)
from lightning.pytorch.loggers import Logger
from matplotlib.figure import Figure
from soundevent import data
@ -40,11 +38,15 @@ __all__ = [
"DVCLiveConfig",
"LoggerConfig",
"MLFlowLoggerConfig",
"LoggingCallback",
"TensorBoardLoggerConfig",
"build_logger",
"enable_logging",
"get_image_logger",
"get_table_logger",
"log_artifact_file",
"log_config_artifact",
"log_csv_artifact",
]
@ -120,6 +122,18 @@ class LoggerBuilder(Protocol, Generic[T]):
) -> Logger: ...
LoggingContext = TypeVar("LoggingContext", contravariant=True)
class LoggingCallback(Protocol, Generic[LoggingContext]):
def run(
self,
logger: Logger,
artifact_path: Path,
context: LoggingContext,
) -> None: ...
def create_dvclive_logger(
config: DVCLiveConfig,
log_dir: Path | None = None,
@ -273,6 +287,71 @@ def build_logger(
)
def log_artifact_file(
runtime_logger: Logger,
path: Path,
artifact_path: str = "artifacts",
) -> None:
from lightning.pytorch.loggers import (
CSVLogger,
MLFlowLogger,
TensorBoardLogger,
)
if isinstance(runtime_logger, MLFlowLogger):
runtime_logger.experiment.log_artifact( # type: ignore[call-arg]
local_path=str(path),
artifact_path=artifact_path,
run_id=runtime_logger.run_id,
)
return
experiment = getattr(runtime_logger, "experiment", None)
if experiment is not None and hasattr(experiment, "log_artifact"):
experiment.log_artifact(path=path, name=path.name, copy=True)
return
if isinstance(runtime_logger, (CSVLogger, TensorBoardLogger)):
return
logger.warning(
"Skipping artifact logging for unsupported logger type {logger_type}",
logger_type=type(runtime_logger).__name__,
)
def log_config_artifact(
logger: Logger,
config: BaseConfig,
filename: str,
artifact_path: Path,
) -> None:
artifact_path.mkdir(parents=True, exist_ok=True)
path = artifact_path / filename
path.write_text(config.to_yaml_string())
log_artifact_file(
logger,
path,
artifact_path=artifact_path.name,
)
def log_csv_artifact(
logger: Logger,
df: pd.DataFrame,
filename: str,
artifact_path: Path,
) -> None:
artifact_path.mkdir(parents=True, exist_ok=True)
path = artifact_path / filename
df.to_csv(path, index=False)
log_artifact_file(
logger,
path,
artifact_path=artifact_path.name,
)
PlotLogger = Callable[[str, "Figure", int], None]

View File

@ -4,10 +4,24 @@ from batdetect2.train.lightning import (
TrainingModule,
load_model_from_checkpoint,
)
from batdetect2.train.logging import (
ConfigHyperparameterLogging,
DatasetConfigArtifact,
DatasetConfigArtifactLogging,
DataSummaryArtifactLogging,
TargetConfigArtifactLogging,
TrainLoggingContext,
)
from batdetect2.train.train import build_trainer, run_train
__all__ = [
"ConfigHyperparameterLogging",
"DataSummaryArtifactLogging",
"DEFAULT_CHECKPOINT_DIR",
"DatasetConfigArtifact",
"DatasetConfigArtifactLogging",
"TargetConfigArtifactLogging",
"TrainLoggingContext",
"TrainingConfig",
"TrainingModule",
"build_trainer",

View File

@ -0,0 +1,164 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from pathlib import Path
import pandas as pd
from lightning.pytorch.loggers import Logger
from soundevent import data
from batdetect2.audio import AudioConfig
from batdetect2.core.configs import BaseConfig
from batdetect2.data import Dataset, compute_class_summary
from batdetect2.logging import log_config_artifact, log_csv_artifact
from batdetect2.models import ModelConfig
from batdetect2.targets import TargetConfig, TargetProtocol
from batdetect2.train.config import TrainingConfig
__all__ = [
"ConfigHyperparameterLogging",
"DataSummaryArtifactLogging",
"DatasetConfigArtifact",
"DatasetConfigArtifactLogging",
"TargetConfigArtifactLogging",
"TrainLoggingContext",
]
@dataclass(frozen=True)
class TrainLoggingContext:
model_config: ModelConfig
train_config: TrainingConfig
audio_config: AudioConfig
targets: TargetProtocol
train_dataset: Dataset
val_dataset: Dataset | None
@dataclass(frozen=True)
class DatasetConfigArtifact:
filename: str
config: BaseConfig
class ConfigHyperparameterLogging:
def run(
self,
logger: Logger,
artifact_path: Path,
context: TrainLoggingContext,
) -> None:
logger.log_hyperparams(
{
"model": context.model_config.model_dump(
mode="json",
exclude_none=True,
),
"training": context.train_config.model_dump(
mode="json",
exclude_none=True,
),
"audio": context.audio_config.model_dump(
mode="json",
exclude_none=True,
),
}
)
class TargetConfigArtifactLogging:
def run(
self,
logger: Logger,
artifact_path: Path,
context: TrainLoggingContext,
) -> None:
targets_config = TargetConfig.model_validate(
context.targets.get_config()
)
log_config_artifact(
logger,
targets_config,
filename="targets.yaml",
artifact_path=artifact_path / "training_artifacts",
)
class DatasetConfigArtifactLogging:
def __init__(
self,
train_dataset_config: DatasetConfigArtifact,
val_dataset_config: DatasetConfigArtifact | None = None,
):
self.train_dataset_config = train_dataset_config
self.val_dataset_config = val_dataset_config
def run(
self,
logger: Logger,
artifact_path: Path,
context: TrainLoggingContext,
) -> None:
training_artifact_path = artifact_path / "training_artifacts"
log_config_artifact(
logger,
self.train_dataset_config.config,
filename=self.train_dataset_config.filename,
artifact_path=training_artifact_path,
)
if self.val_dataset_config is not None:
log_config_artifact(
logger,
self.val_dataset_config.config,
filename=self.val_dataset_config.filename,
artifact_path=training_artifact_path,
)
class DataSummaryArtifactLogging:
def run(
self,
logger: Logger,
artifact_path: Path,
context: TrainLoggingContext,
) -> None:
training_artifact_path = artifact_path / "training_artifacts"
log_csv_artifact(
logger,
_compute_class_summary_or_empty(
context.train_dataset,
context.targets,
),
filename="train_class_summary.csv",
artifact_path=training_artifact_path,
)
if context.val_dataset is not None:
log_csv_artifact(
logger,
_compute_class_summary_or_empty(
context.val_dataset,
context.targets,
),
filename="val_class_summary.csv",
artifact_path=training_artifact_path,
)
def _compute_class_summary_or_empty(
dataset: Sequence[data.ClipAnnotation],
targets: TargetProtocol,
) -> pd.DataFrame:
try:
return compute_class_summary(dataset, targets)
except KeyError as error:
if error.args != ("class_name",):
raise
return pd.DataFrame(
columns=["num calls", "num recordings", "duration", "call_rate"]
)

View File

@ -3,6 +3,7 @@ from pathlib import Path
from typing import Optional
from lightning import Trainer, seed_everything
from lightning.pytorch.loggers import Logger
from loguru import logger
from soundevent import data
@ -10,6 +11,7 @@ from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
from batdetect2.logging import (
LoggerConfig,
LoggingCallback,
TensorBoardLoggerConfig,
build_logger,
)
@ -28,6 +30,12 @@ from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
from batdetect2.train.logging import (
ConfigHyperparameterLogging,
DataSummaryArtifactLogging,
TargetConfigArtifactLogging,
TrainLoggingContext,
)
from batdetect2.train.types import ClipLabeller
__all__ = [
@ -36,6 +44,9 @@ __all__ = [
]
DEFAULT_LOG_DIR = Path("outputs") / "logs"
def run_train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None,
@ -59,6 +70,7 @@ def run_train(
num_epochs: int | None = None,
run_name: str | None = None,
seed: int | None = None,
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
):
if seed is not None:
seed_everything(seed)
@ -148,15 +160,44 @@ def run_train(
roi_mapper=roi_mapper,
)
train_logger = build_logger(
logger_config or TensorBoardLoggerConfig(),
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
)
root_artifact_path = (
Path(log_dir) if log_dir is not None else DEFAULT_LOG_DIR
)
root_artifact_path.mkdir(parents=True, exist_ok=True)
logging_context = TrainLoggingContext(
model_config=model_config,
train_config=train_config,
audio_config=audio_config,
targets=targets,
train_dataset=train_annotations,
val_dataset=val_annotations,
)
resolved_logging_callbacks = (
ConfigHyperparameterLogging(),
TargetConfigArtifactLogging(),
DataSummaryArtifactLogging(),
*logging_callbacks,
)
for callback in resolved_logging_callbacks:
callback.run(train_logger, root_artifact_path, logging_context)
trainer = trainer or build_trainer(
train_config,
logger_config=logger_config,
train_logger=train_logger,
evaluator=evaluator,
targets=targets,
roi_mapper=roi_mapper,
checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
)
@ -223,12 +264,11 @@ def _validate_model_compatibility(
def build_trainer(
config: TrainingConfig,
logger_config: LoggerConfig | None,
train_logger: Logger,
evaluator: "EvaluatorProtocol",
targets: "TargetProtocol",
roi_mapper: "ROIMapperProtocol",
checkpoint_dir: Path | None = None,
log_dir: Path | None = None,
experiment_name: str | None = None,
run_name: str | None = None,
num_epochs: int | None = None,
@ -239,20 +279,9 @@ def build_trainer(
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
train_logger = build_logger(
logger_config or TensorBoardLoggerConfig(),
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
)
if num_epochs is not None:
trainer_conf.max_epochs = num_epochs
train_logger.log_hyperparams(
config.model_dump(mode="json", exclude_none=True)
)
train_config = trainer_conf.model_dump(exclude_none=True)
return Trainer(

View File

@ -22,6 +22,10 @@ from batdetect2.train import (
load_model_from_checkpoint,
run_train,
)
from batdetect2.train.logging import (
DatasetConfigArtifact,
DatasetConfigArtifactLogging,
)
from batdetect2.train.optimizers import AdamOptimizerConfig
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
from batdetect2.train.train import build_training_module
@ -369,6 +373,59 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
assert rebuilt_model.dimension_names == ["width", "height"]
@pytest.mark.slow
def test_run_train_logs_training_artifacts(
tmp_path: Path,
example_annotations: list[data.ClipAnnotation],
example_dataset,
) -> None:
train_config = TrainingConfig.model_validate(
{
"trainer": {
"limit_train_batches": 1,
"limit_val_batches": 1,
"log_every_n_steps": 1,
},
"train_loader": {
"batch_size": 1,
"augmentations": {"enabled": False},
},
}
)
run_train(
train_annotations=example_annotations[:1],
val_annotations=example_annotations[:1],
train_config=train_config,
num_epochs=1,
train_workers=0,
val_workers=0,
checkpoint_dir=tmp_path / "checkpoints",
log_dir=tmp_path / "logs",
seed=0,
logging_callbacks=[
DatasetConfigArtifactLogging(
train_dataset_config=DatasetConfigArtifact(
filename="train_dataset.yaml",
config=example_dataset,
),
val_dataset_config=DatasetConfigArtifact(
filename="val_dataset.yaml",
config=example_dataset,
),
)
],
)
artifact_root = next((tmp_path / "logs").rglob("training_artifacts"))
assert (artifact_root / "targets.yaml").exists()
assert (artifact_root / "train_dataset.yaml").exists()
assert (artifact_root / "val_dataset.yaml").exists()
assert (artifact_root / "train_class_summary.csv").exists()
assert (artifact_root / "val_class_summary.csv").exists()
def test_run_train_rejects_incompatible_model_config(
example_annotations: list[data.ClipAnnotation],
) -> None: