mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: log training provenance artifacts
This commit is contained in:
parent
7b2699786f
commit
5a974711b0
@ -15,7 +15,11 @@ if TYPE_CHECKING:
|
|||||||
from batdetect2.data import Dataset
|
from batdetect2.data import Dataset
|
||||||
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.logging import AppLoggingConfig, LoggerConfig
|
from batdetect2.logging import (
|
||||||
|
AppLoggingConfig,
|
||||||
|
LoggerConfig,
|
||||||
|
LoggingCallback,
|
||||||
|
)
|
||||||
from batdetect2.models import Model, ModelConfig
|
from batdetect2.models import Model, ModelConfig
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
OutputFormatConfig,
|
OutputFormatConfig,
|
||||||
@ -35,6 +39,7 @@ if TYPE_CHECKING:
|
|||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
from batdetect2.train import TrainingConfig
|
from batdetect2.train import TrainingConfig
|
||||||
|
from batdetect2.train.logging import TrainLoggingContext
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||||
@ -106,6 +111,7 @@ class BatDetect2API:
|
|||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
|
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||||
):
|
):
|
||||||
from batdetect2.train import run_train
|
from batdetect2.train import run_train
|
||||||
|
|
||||||
@ -130,6 +136,7 @@ class BatDetect2API:
|
|||||||
train_config=train_config or self.train_config,
|
train_config=train_config or self.train_config,
|
||||||
audio_config=audio_config or self.audio_config,
|
audio_config=audio_config or self.audio_config,
|
||||||
logger_config=logger_config or self.logging_config.train,
|
logger_config=logger_config or self.logging_config.train,
|
||||||
|
logging_callbacks=logging_callbacks,
|
||||||
)
|
)
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
return self
|
return self
|
||||||
@ -153,6 +160,7 @@ class BatDetect2API:
|
|||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
|
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
"""Fine-tune from a checkpoint using a new target definition."""
|
"""Fine-tune from a checkpoint using a new target definition."""
|
||||||
from batdetect2.evaluate import build_evaluator
|
from batdetect2.evaluate import build_evaluator
|
||||||
@ -231,6 +239,7 @@ class BatDetect2API:
|
|||||||
audio_config=api.audio_config,
|
audio_config=api.audio_config,
|
||||||
train_config=api.train_config,
|
train_config=api.train_config,
|
||||||
logger_config=logger_config or api.logging_config.train,
|
logger_config=logger_config or api.logging_config.train,
|
||||||
|
logging_callbacks=logging_callbacks,
|
||||||
)
|
)
|
||||||
api.model.eval()
|
api.model.eval()
|
||||||
return api
|
return api
|
||||||
|
|||||||
@ -126,10 +126,14 @@ def finetune_command(
|
|||||||
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
|
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset, load_dataset_config
|
||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
from batdetect2.targets import TargetConfig
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train import TrainingConfig
|
from batdetect2.train import TrainingConfig
|
||||||
|
from batdetect2.train.logging import (
|
||||||
|
DatasetConfigArtifact,
|
||||||
|
DatasetConfigArtifactLogging,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Initiating fine-tuning process...")
|
logger.info("Initiating fine-tuning process...")
|
||||||
|
|
||||||
@ -148,16 +152,34 @@ def finetune_command(
|
|||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
train_annotations = load_dataset_from_config(
|
train_dataset_conf = load_dataset_config(train_dataset)
|
||||||
train_dataset,
|
train_annotations = load_dataset(train_dataset_conf, base_dir=base_dir)
|
||||||
base_dir=base_dir,
|
|
||||||
|
val_dataset_conf = (
|
||||||
|
load_dataset_config(val_dataset) if val_dataset else None
|
||||||
)
|
)
|
||||||
val_annotations = None
|
val_annotations = (
|
||||||
if val_dataset is not None:
|
load_dataset(val_dataset_conf, base_dir=base_dir)
|
||||||
val_annotations = load_dataset_from_config(
|
if val_dataset_conf
|
||||||
val_dataset,
|
else None
|
||||||
base_dir=base_dir,
|
)
|
||||||
|
|
||||||
|
logging_callbacks = [
|
||||||
|
DatasetConfigArtifactLogging(
|
||||||
|
train_dataset_config=DatasetConfigArtifact(
|
||||||
|
filename="train_dataset.yaml",
|
||||||
|
config=train_dataset_conf,
|
||||||
|
),
|
||||||
|
val_dataset_config=(
|
||||||
|
DatasetConfigArtifact(
|
||||||
|
filename="val_dataset.yaml",
|
||||||
|
config=val_dataset_conf,
|
||||||
|
)
|
||||||
|
if val_dataset_conf
|
||||||
|
else None
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
]
|
||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
@ -185,4 +207,5 @@ def finetune_command(
|
|||||||
train_config=train_conf,
|
train_config=train_conf,
|
||||||
audio_config=audio_conf,
|
audio_config=audio_conf,
|
||||||
logger_config=logging_conf.train if logging_conf is not None else None,
|
logger_config=logging_conf.train if logging_conf is not None else None,
|
||||||
|
logging_callbacks=logging_callbacks,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -145,7 +145,7 @@ def train_command(
|
|||||||
"""
|
"""
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_config, load_dataset_from_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
@ -153,6 +153,10 @@ def train_command(
|
|||||||
from batdetect2.outputs import OutputsConfig
|
from batdetect2.outputs import OutputsConfig
|
||||||
from batdetect2.targets import TargetConfig
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train import TrainingConfig
|
from batdetect2.train import TrainingConfig
|
||||||
|
from batdetect2.train.logging import (
|
||||||
|
DatasetConfigArtifact,
|
||||||
|
DatasetConfigArtifactLogging,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Initiating training process...")
|
logger.info("Initiating training process...")
|
||||||
|
|
||||||
@ -222,6 +226,23 @@ def train_command(
|
|||||||
|
|
||||||
logger.info("Configuration and data loaded. Starting training...")
|
logger.info("Configuration and data loaded. Starting training...")
|
||||||
|
|
||||||
|
logging_callbacks = [
|
||||||
|
DatasetConfigArtifactLogging(
|
||||||
|
train_dataset_config=DatasetConfigArtifact(
|
||||||
|
filename="train_dataset.yaml",
|
||||||
|
config=load_dataset_config(train_dataset),
|
||||||
|
),
|
||||||
|
val_dataset_config=(
|
||||||
|
DatasetConfigArtifact(
|
||||||
|
filename="val_dataset.yaml",
|
||||||
|
config=load_dataset_config(val_dataset),
|
||||||
|
)
|
||||||
|
if val_dataset is not None
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
if model_path is not None and model_conf is not None:
|
if model_path is not None and model_conf is not None:
|
||||||
raise click.UsageError(
|
raise click.UsageError(
|
||||||
"--model-config cannot be used with --model. "
|
"--model-config cannot be used with --model. "
|
||||||
@ -267,4 +288,5 @@ def train_command(
|
|||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
seed=seed,
|
seed=seed,
|
||||||
|
logging_callbacks=logging_callbacks,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -24,9 +24,7 @@ from batdetect2.core.configs import BaseConfig
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from lightning.pytorch.loggers import (
|
from lightning.pytorch.loggers import Logger
|
||||||
Logger,
|
|
||||||
)
|
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -40,11 +38,15 @@ __all__ = [
|
|||||||
"DVCLiveConfig",
|
"DVCLiveConfig",
|
||||||
"LoggerConfig",
|
"LoggerConfig",
|
||||||
"MLFlowLoggerConfig",
|
"MLFlowLoggerConfig",
|
||||||
|
"LoggingCallback",
|
||||||
"TensorBoardLoggerConfig",
|
"TensorBoardLoggerConfig",
|
||||||
"build_logger",
|
"build_logger",
|
||||||
"enable_logging",
|
"enable_logging",
|
||||||
"get_image_logger",
|
"get_image_logger",
|
||||||
"get_table_logger",
|
"get_table_logger",
|
||||||
|
"log_artifact_file",
|
||||||
|
"log_config_artifact",
|
||||||
|
"log_csv_artifact",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -120,6 +122,18 @@ class LoggerBuilder(Protocol, Generic[T]):
|
|||||||
) -> Logger: ...
|
) -> Logger: ...
|
||||||
|
|
||||||
|
|
||||||
|
LoggingContext = TypeVar("LoggingContext", contravariant=True)
|
||||||
|
|
||||||
|
|
||||||
|
class LoggingCallback(Protocol, Generic[LoggingContext]):
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
logger: Logger,
|
||||||
|
artifact_path: Path,
|
||||||
|
context: LoggingContext,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
def create_dvclive_logger(
|
def create_dvclive_logger(
|
||||||
config: DVCLiveConfig,
|
config: DVCLiveConfig,
|
||||||
log_dir: Path | None = None,
|
log_dir: Path | None = None,
|
||||||
@ -273,6 +287,71 @@ def build_logger(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def log_artifact_file(
|
||||||
|
runtime_logger: Logger,
|
||||||
|
path: Path,
|
||||||
|
artifact_path: str = "artifacts",
|
||||||
|
) -> None:
|
||||||
|
from lightning.pytorch.loggers import (
|
||||||
|
CSVLogger,
|
||||||
|
MLFlowLogger,
|
||||||
|
TensorBoardLogger,
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(runtime_logger, MLFlowLogger):
|
||||||
|
runtime_logger.experiment.log_artifact( # type: ignore[call-arg]
|
||||||
|
local_path=str(path),
|
||||||
|
artifact_path=artifact_path,
|
||||||
|
run_id=runtime_logger.run_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
experiment = getattr(runtime_logger, "experiment", None)
|
||||||
|
if experiment is not None and hasattr(experiment, "log_artifact"):
|
||||||
|
experiment.log_artifact(path=path, name=path.name, copy=True)
|
||||||
|
return
|
||||||
|
|
||||||
|
if isinstance(runtime_logger, (CSVLogger, TensorBoardLogger)):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"Skipping artifact logging for unsupported logger type {logger_type}",
|
||||||
|
logger_type=type(runtime_logger).__name__,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def log_config_artifact(
|
||||||
|
logger: Logger,
|
||||||
|
config: BaseConfig,
|
||||||
|
filename: str,
|
||||||
|
artifact_path: Path,
|
||||||
|
) -> None:
|
||||||
|
artifact_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = artifact_path / filename
|
||||||
|
path.write_text(config.to_yaml_string())
|
||||||
|
log_artifact_file(
|
||||||
|
logger,
|
||||||
|
path,
|
||||||
|
artifact_path=artifact_path.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def log_csv_artifact(
|
||||||
|
logger: Logger,
|
||||||
|
df: pd.DataFrame,
|
||||||
|
filename: str,
|
||||||
|
artifact_path: Path,
|
||||||
|
) -> None:
|
||||||
|
artifact_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
path = artifact_path / filename
|
||||||
|
df.to_csv(path, index=False)
|
||||||
|
log_artifact_file(
|
||||||
|
logger,
|
||||||
|
path,
|
||||||
|
artifact_path=artifact_path.name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
PlotLogger = Callable[[str, "Figure", int], None]
|
PlotLogger = Callable[[str, "Figure", int], None]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -4,10 +4,24 @@ from batdetect2.train.lightning import (
|
|||||||
TrainingModule,
|
TrainingModule,
|
||||||
load_model_from_checkpoint,
|
load_model_from_checkpoint,
|
||||||
)
|
)
|
||||||
|
from batdetect2.train.logging import (
|
||||||
|
ConfigHyperparameterLogging,
|
||||||
|
DatasetConfigArtifact,
|
||||||
|
DatasetConfigArtifactLogging,
|
||||||
|
DataSummaryArtifactLogging,
|
||||||
|
TargetConfigArtifactLogging,
|
||||||
|
TrainLoggingContext,
|
||||||
|
)
|
||||||
from batdetect2.train.train import build_trainer, run_train
|
from batdetect2.train.train import build_trainer, run_train
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
"ConfigHyperparameterLogging",
|
||||||
|
"DataSummaryArtifactLogging",
|
||||||
"DEFAULT_CHECKPOINT_DIR",
|
"DEFAULT_CHECKPOINT_DIR",
|
||||||
|
"DatasetConfigArtifact",
|
||||||
|
"DatasetConfigArtifactLogging",
|
||||||
|
"TargetConfigArtifactLogging",
|
||||||
|
"TrainLoggingContext",
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
|
|||||||
164
src/batdetect2/train/logging.py
Normal file
164
src/batdetect2/train/logging.py
Normal file
@ -0,0 +1,164 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from lightning.pytorch.loggers import Logger
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.audio import AudioConfig
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.data import Dataset, compute_class_summary
|
||||||
|
from batdetect2.logging import log_config_artifact, log_csv_artifact
|
||||||
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.targets import TargetConfig, TargetProtocol
|
||||||
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ConfigHyperparameterLogging",
|
||||||
|
"DataSummaryArtifactLogging",
|
||||||
|
"DatasetConfigArtifact",
|
||||||
|
"DatasetConfigArtifactLogging",
|
||||||
|
"TargetConfigArtifactLogging",
|
||||||
|
"TrainLoggingContext",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TrainLoggingContext:
|
||||||
|
model_config: ModelConfig
|
||||||
|
train_config: TrainingConfig
|
||||||
|
audio_config: AudioConfig
|
||||||
|
targets: TargetProtocol
|
||||||
|
train_dataset: Dataset
|
||||||
|
val_dataset: Dataset | None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class DatasetConfigArtifact:
|
||||||
|
filename: str
|
||||||
|
config: BaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigHyperparameterLogging:
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
logger: Logger,
|
||||||
|
artifact_path: Path,
|
||||||
|
context: TrainLoggingContext,
|
||||||
|
) -> None:
|
||||||
|
logger.log_hyperparams(
|
||||||
|
{
|
||||||
|
"model": context.model_config.model_dump(
|
||||||
|
mode="json",
|
||||||
|
exclude_none=True,
|
||||||
|
),
|
||||||
|
"training": context.train_config.model_dump(
|
||||||
|
mode="json",
|
||||||
|
exclude_none=True,
|
||||||
|
),
|
||||||
|
"audio": context.audio_config.model_dump(
|
||||||
|
mode="json",
|
||||||
|
exclude_none=True,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TargetConfigArtifactLogging:
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
logger: Logger,
|
||||||
|
artifact_path: Path,
|
||||||
|
context: TrainLoggingContext,
|
||||||
|
) -> None:
|
||||||
|
targets_config = TargetConfig.model_validate(
|
||||||
|
context.targets.get_config()
|
||||||
|
)
|
||||||
|
log_config_artifact(
|
||||||
|
logger,
|
||||||
|
targets_config,
|
||||||
|
filename="targets.yaml",
|
||||||
|
artifact_path=artifact_path / "training_artifacts",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetConfigArtifactLogging:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
train_dataset_config: DatasetConfigArtifact,
|
||||||
|
val_dataset_config: DatasetConfigArtifact | None = None,
|
||||||
|
):
|
||||||
|
self.train_dataset_config = train_dataset_config
|
||||||
|
self.val_dataset_config = val_dataset_config
|
||||||
|
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
logger: Logger,
|
||||||
|
artifact_path: Path,
|
||||||
|
context: TrainLoggingContext,
|
||||||
|
) -> None:
|
||||||
|
training_artifact_path = artifact_path / "training_artifacts"
|
||||||
|
|
||||||
|
log_config_artifact(
|
||||||
|
logger,
|
||||||
|
self.train_dataset_config.config,
|
||||||
|
filename=self.train_dataset_config.filename,
|
||||||
|
artifact_path=training_artifact_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.val_dataset_config is not None:
|
||||||
|
log_config_artifact(
|
||||||
|
logger,
|
||||||
|
self.val_dataset_config.config,
|
||||||
|
filename=self.val_dataset_config.filename,
|
||||||
|
artifact_path=training_artifact_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class DataSummaryArtifactLogging:
|
||||||
|
def run(
|
||||||
|
self,
|
||||||
|
logger: Logger,
|
||||||
|
artifact_path: Path,
|
||||||
|
context: TrainLoggingContext,
|
||||||
|
) -> None:
|
||||||
|
training_artifact_path = artifact_path / "training_artifacts"
|
||||||
|
|
||||||
|
log_csv_artifact(
|
||||||
|
logger,
|
||||||
|
_compute_class_summary_or_empty(
|
||||||
|
context.train_dataset,
|
||||||
|
context.targets,
|
||||||
|
),
|
||||||
|
filename="train_class_summary.csv",
|
||||||
|
artifact_path=training_artifact_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
if context.val_dataset is not None:
|
||||||
|
log_csv_artifact(
|
||||||
|
logger,
|
||||||
|
_compute_class_summary_or_empty(
|
||||||
|
context.val_dataset,
|
||||||
|
context.targets,
|
||||||
|
),
|
||||||
|
filename="val_class_summary.csv",
|
||||||
|
artifact_path=training_artifact_path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_class_summary_or_empty(
|
||||||
|
dataset: Sequence[data.ClipAnnotation],
|
||||||
|
targets: TargetProtocol,
|
||||||
|
) -> pd.DataFrame:
|
||||||
|
try:
|
||||||
|
return compute_class_summary(dataset, targets)
|
||||||
|
except KeyError as error:
|
||||||
|
if error.args != ("class_name",):
|
||||||
|
raise
|
||||||
|
|
||||||
|
return pd.DataFrame(
|
||||||
|
columns=["num calls", "num recordings", "duration", "call_rate"]
|
||||||
|
)
|
||||||
@ -3,6 +3,7 @@ from pathlib import Path
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from lightning import Trainer, seed_everything
|
from lightning import Trainer, seed_everything
|
||||||
|
from lightning.pytorch.loggers import Logger
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -10,6 +11,7 @@ from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
|||||||
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
|
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
|
||||||
from batdetect2.logging import (
|
from batdetect2.logging import (
|
||||||
LoggerConfig,
|
LoggerConfig,
|
||||||
|
LoggingCallback,
|
||||||
TensorBoardLoggerConfig,
|
TensorBoardLoggerConfig,
|
||||||
build_logger,
|
build_logger,
|
||||||
)
|
)
|
||||||
@ -28,6 +30,12 @@ from batdetect2.train.config import TrainingConfig
|
|||||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
from batdetect2.train.logging import (
|
||||||
|
ConfigHyperparameterLogging,
|
||||||
|
DataSummaryArtifactLogging,
|
||||||
|
TargetConfigArtifactLogging,
|
||||||
|
TrainLoggingContext,
|
||||||
|
)
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.train.types import ClipLabeller
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -36,6 +44,9 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_LOG_DIR = Path("outputs") / "logs"
|
||||||
|
|
||||||
|
|
||||||
def run_train(
|
def run_train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
@ -59,6 +70,7 @@ def run_train(
|
|||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
|
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||||
):
|
):
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
@ -148,15 +160,44 @@ def run_train(
|
|||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
train_logger = build_logger(
|
||||||
|
logger_config or TensorBoardLoggerConfig(),
|
||||||
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
|
)
|
||||||
|
root_artifact_path = (
|
||||||
|
Path(log_dir) if log_dir is not None else DEFAULT_LOG_DIR
|
||||||
|
)
|
||||||
|
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
logging_context = TrainLoggingContext(
|
||||||
|
model_config=model_config,
|
||||||
|
train_config=train_config,
|
||||||
|
audio_config=audio_config,
|
||||||
|
targets=targets,
|
||||||
|
train_dataset=train_annotations,
|
||||||
|
val_dataset=val_annotations,
|
||||||
|
)
|
||||||
|
|
||||||
|
resolved_logging_callbacks = (
|
||||||
|
ConfigHyperparameterLogging(),
|
||||||
|
TargetConfigArtifactLogging(),
|
||||||
|
DataSummaryArtifactLogging(),
|
||||||
|
*logging_callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
for callback in resolved_logging_callbacks:
|
||||||
|
callback.run(train_logger, root_artifact_path, logging_context)
|
||||||
|
|
||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
train_config,
|
train_config,
|
||||||
logger_config=logger_config,
|
train_logger=train_logger,
|
||||||
evaluator=evaluator,
|
evaluator=evaluator,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
log_dir=log_dir,
|
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
@ -223,12 +264,11 @@ def _validate_model_compatibility(
|
|||||||
|
|
||||||
def build_trainer(
|
def build_trainer(
|
||||||
config: TrainingConfig,
|
config: TrainingConfig,
|
||||||
logger_config: LoggerConfig | None,
|
train_logger: Logger,
|
||||||
evaluator: "EvaluatorProtocol",
|
evaluator: "EvaluatorProtocol",
|
||||||
targets: "TargetProtocol",
|
targets: "TargetProtocol",
|
||||||
roi_mapper: "ROIMapperProtocol",
|
roi_mapper: "ROIMapperProtocol",
|
||||||
checkpoint_dir: Path | None = None,
|
checkpoint_dir: Path | None = None,
|
||||||
log_dir: Path | None = None,
|
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
@ -239,20 +279,9 @@ def build_trainer(
|
|||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
|
|
||||||
train_logger = build_logger(
|
|
||||||
logger_config or TensorBoardLoggerConfig(),
|
|
||||||
log_dir=log_dir,
|
|
||||||
experiment_name=experiment_name,
|
|
||||||
run_name=run_name,
|
|
||||||
)
|
|
||||||
|
|
||||||
if num_epochs is not None:
|
if num_epochs is not None:
|
||||||
trainer_conf.max_epochs = num_epochs
|
trainer_conf.max_epochs = num_epochs
|
||||||
|
|
||||||
train_logger.log_hyperparams(
|
|
||||||
config.model_dump(mode="json", exclude_none=True)
|
|
||||||
)
|
|
||||||
|
|
||||||
train_config = trainer_conf.model_dump(exclude_none=True)
|
train_config = trainer_conf.model_dump(exclude_none=True)
|
||||||
|
|
||||||
return Trainer(
|
return Trainer(
|
||||||
|
|||||||
@ -22,6 +22,10 @@ from batdetect2.train import (
|
|||||||
load_model_from_checkpoint,
|
load_model_from_checkpoint,
|
||||||
run_train,
|
run_train,
|
||||||
)
|
)
|
||||||
|
from batdetect2.train.logging import (
|
||||||
|
DatasetConfigArtifact,
|
||||||
|
DatasetConfigArtifactLogging,
|
||||||
|
)
|
||||||
from batdetect2.train.optimizers import AdamOptimizerConfig
|
from batdetect2.train.optimizers import AdamOptimizerConfig
|
||||||
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
|
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
|
||||||
from batdetect2.train.train import build_training_module
|
from batdetect2.train.train import build_training_module
|
||||||
@ -369,6 +373,59 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
|||||||
assert rebuilt_model.dimension_names == ["width", "height"]
|
assert rebuilt_model.dimension_names == ["width", "height"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.slow
|
||||||
|
def test_run_train_logs_training_artifacts(
|
||||||
|
tmp_path: Path,
|
||||||
|
example_annotations: list[data.ClipAnnotation],
|
||||||
|
example_dataset,
|
||||||
|
) -> None:
|
||||||
|
train_config = TrainingConfig.model_validate(
|
||||||
|
{
|
||||||
|
"trainer": {
|
||||||
|
"limit_train_batches": 1,
|
||||||
|
"limit_val_batches": 1,
|
||||||
|
"log_every_n_steps": 1,
|
||||||
|
},
|
||||||
|
"train_loader": {
|
||||||
|
"batch_size": 1,
|
||||||
|
"augmentations": {"enabled": False},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
run_train(
|
||||||
|
train_annotations=example_annotations[:1],
|
||||||
|
val_annotations=example_annotations[:1],
|
||||||
|
train_config=train_config,
|
||||||
|
num_epochs=1,
|
||||||
|
train_workers=0,
|
||||||
|
val_workers=0,
|
||||||
|
checkpoint_dir=tmp_path / "checkpoints",
|
||||||
|
log_dir=tmp_path / "logs",
|
||||||
|
seed=0,
|
||||||
|
logging_callbacks=[
|
||||||
|
DatasetConfigArtifactLogging(
|
||||||
|
train_dataset_config=DatasetConfigArtifact(
|
||||||
|
filename="train_dataset.yaml",
|
||||||
|
config=example_dataset,
|
||||||
|
),
|
||||||
|
val_dataset_config=DatasetConfigArtifact(
|
||||||
|
filename="val_dataset.yaml",
|
||||||
|
config=example_dataset,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
artifact_root = next((tmp_path / "logs").rglob("training_artifacts"))
|
||||||
|
|
||||||
|
assert (artifact_root / "targets.yaml").exists()
|
||||||
|
assert (artifact_root / "train_dataset.yaml").exists()
|
||||||
|
assert (artifact_root / "val_dataset.yaml").exists()
|
||||||
|
assert (artifact_root / "train_class_summary.csv").exists()
|
||||||
|
assert (artifact_root / "val_class_summary.csv").exists()
|
||||||
|
|
||||||
|
|
||||||
def test_run_train_rejects_incompatible_model_config(
|
def test_run_train_rejects_incompatible_model_config(
|
||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user