mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
feat: log training provenance artifacts
This commit is contained in:
parent
7b2699786f
commit
5a974711b0
@ -15,7 +15,11 @@ if TYPE_CHECKING:
|
||||
from batdetect2.data import Dataset
|
||||
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig, LoggerConfig
|
||||
from batdetect2.logging import (
|
||||
AppLoggingConfig,
|
||||
LoggerConfig,
|
||||
LoggingCallback,
|
||||
)
|
||||
from batdetect2.models import Model, ModelConfig
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
@ -35,6 +39,7 @@ if TYPE_CHECKING:
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.train import TrainingConfig
|
||||
from batdetect2.train.logging import TrainLoggingContext
|
||||
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
@ -106,6 +111,7 @@ class BatDetect2API:
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||
):
|
||||
from batdetect2.train import run_train
|
||||
|
||||
@ -130,6 +136,7 @@ class BatDetect2API:
|
||||
train_config=train_config or self.train_config,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
logger_config=logger_config or self.logging_config.train,
|
||||
logging_callbacks=logging_callbacks,
|
||||
)
|
||||
self.model.eval()
|
||||
return self
|
||||
@ -153,6 +160,7 @@ class BatDetect2API:
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||
) -> "BatDetect2API":
|
||||
"""Fine-tune from a checkpoint using a new target definition."""
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
@ -231,6 +239,7 @@ class BatDetect2API:
|
||||
audio_config=api.audio_config,
|
||||
train_config=api.train_config,
|
||||
logger_config=logger_config or api.logging_config.train,
|
||||
logging_callbacks=logging_callbacks,
|
||||
)
|
||||
api.model.eval()
|
||||
return api
|
||||
|
||||
@ -126,10 +126,14 @@ def finetune_command(
|
||||
"""Fine-tune a BatDetect2 checkpoint on a new target definition."""
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.data import load_dataset, load_dataset_config
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train import TrainingConfig
|
||||
from batdetect2.train.logging import (
|
||||
DatasetConfigArtifact,
|
||||
DatasetConfigArtifactLogging,
|
||||
)
|
||||
|
||||
logger.info("Initiating fine-tuning process...")
|
||||
|
||||
@ -148,16 +152,34 @@ def finetune_command(
|
||||
else None
|
||||
)
|
||||
|
||||
train_annotations = load_dataset_from_config(
|
||||
train_dataset,
|
||||
base_dir=base_dir,
|
||||
train_dataset_conf = load_dataset_config(train_dataset)
|
||||
train_annotations = load_dataset(train_dataset_conf, base_dir=base_dir)
|
||||
|
||||
val_dataset_conf = (
|
||||
load_dataset_config(val_dataset) if val_dataset else None
|
||||
)
|
||||
val_annotations = None
|
||||
if val_dataset is not None:
|
||||
val_annotations = load_dataset_from_config(
|
||||
val_dataset,
|
||||
base_dir=base_dir,
|
||||
val_annotations = (
|
||||
load_dataset(val_dataset_conf, base_dir=base_dir)
|
||||
if val_dataset_conf
|
||||
else None
|
||||
)
|
||||
|
||||
logging_callbacks = [
|
||||
DatasetConfigArtifactLogging(
|
||||
train_dataset_config=DatasetConfigArtifact(
|
||||
filename="train_dataset.yaml",
|
||||
config=train_dataset_conf,
|
||||
),
|
||||
val_dataset_config=(
|
||||
DatasetConfigArtifact(
|
||||
filename="val_dataset.yaml",
|
||||
config=val_dataset_conf,
|
||||
)
|
||||
if val_dataset_conf
|
||||
else None
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
@ -185,4 +207,5 @@ def finetune_command(
|
||||
train_config=train_conf,
|
||||
audio_config=audio_conf,
|
||||
logger_config=logging_conf.train if logging_conf is not None else None,
|
||||
logging_callbacks=logging_callbacks,
|
||||
)
|
||||
|
||||
@ -145,7 +145,7 @@ def train_command(
|
||||
"""
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.data import load_dataset_config, load_dataset_from_config
|
||||
from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
@ -153,6 +153,10 @@ def train_command(
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train import TrainingConfig
|
||||
from batdetect2.train.logging import (
|
||||
DatasetConfigArtifact,
|
||||
DatasetConfigArtifactLogging,
|
||||
)
|
||||
|
||||
logger.info("Initiating training process...")
|
||||
|
||||
@ -222,6 +226,23 @@ def train_command(
|
||||
|
||||
logger.info("Configuration and data loaded. Starting training...")
|
||||
|
||||
logging_callbacks = [
|
||||
DatasetConfigArtifactLogging(
|
||||
train_dataset_config=DatasetConfigArtifact(
|
||||
filename="train_dataset.yaml",
|
||||
config=load_dataset_config(train_dataset),
|
||||
),
|
||||
val_dataset_config=(
|
||||
DatasetConfigArtifact(
|
||||
filename="val_dataset.yaml",
|
||||
config=load_dataset_config(val_dataset),
|
||||
)
|
||||
if val_dataset is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
if model_path is not None and model_conf is not None:
|
||||
raise click.UsageError(
|
||||
"--model-config cannot be used with --model. "
|
||||
@ -267,4 +288,5 @@ def train_command(
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
logging_callbacks=logging_callbacks,
|
||||
)
|
||||
|
||||
@ -24,9 +24,7 @@ from batdetect2.core.configs import BaseConfig
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lightning.pytorch.loggers import (
|
||||
Logger,
|
||||
)
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from matplotlib.figure import Figure
|
||||
from soundevent import data
|
||||
|
||||
@ -40,11 +38,15 @@ __all__ = [
|
||||
"DVCLiveConfig",
|
||||
"LoggerConfig",
|
||||
"MLFlowLoggerConfig",
|
||||
"LoggingCallback",
|
||||
"TensorBoardLoggerConfig",
|
||||
"build_logger",
|
||||
"enable_logging",
|
||||
"get_image_logger",
|
||||
"get_table_logger",
|
||||
"log_artifact_file",
|
||||
"log_config_artifact",
|
||||
"log_csv_artifact",
|
||||
]
|
||||
|
||||
|
||||
@ -120,6 +122,18 @@ class LoggerBuilder(Protocol, Generic[T]):
|
||||
) -> Logger: ...
|
||||
|
||||
|
||||
LoggingContext = TypeVar("LoggingContext", contravariant=True)
|
||||
|
||||
|
||||
class LoggingCallback(Protocol, Generic[LoggingContext]):
    """Structural interface for a one-shot logging hook.

    Any object exposing a compatible ``run`` method satisfies this
    protocol; no inheritance is required. The type parameter describes
    the context object the hook expects to receive.
    """

    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: LoggingContext,
    ) -> None:
        """Log whatever this hook is responsible for.

        Parameters
        ----------
        logger
            Lightning logger the hook should write to.
        artifact_path
            Local directory under which the hook may create files.
        context
            Caller-supplied data the hook needs to do its work.
        """
        ...
|
||||
|
||||
|
||||
def create_dvclive_logger(
|
||||
config: DVCLiveConfig,
|
||||
log_dir: Path | None = None,
|
||||
@ -273,6 +287,71 @@ def build_logger(
|
||||
)
|
||||
|
||||
|
||||
def log_artifact_file(
    runtime_logger: Logger,
    path: Path,
    artifact_path: str = "artifacts",
) -> None:
    """Attach a local file to the run tracked by ``runtime_logger``.

    Dispatch depends on the logger backend:

    * ``MLFlowLogger``: the file is uploaded through the MLflow client
      under ``artifact_path`` for the logger's current run id.
    * ``CSVLogger`` / ``TensorBoardLogger``: nothing is uploaded; the
      call is a silent no-op.
    * Any other logger whose ``experiment`` exposes ``log_artifact``:
      that method is called with ``path``, ``name=path.name`` and
      ``copy=True``. ``artifact_path`` is not forwarded in this branch.
      NOTE(review): these keyword names look like DVCLive's
      ``Live.log_artifact``; confirm no other backend reaches this branch.
    * Anything else: a warning is logged and the file is skipped.

    Parameters
    ----------
    runtime_logger
        Lightning logger owning the run.
    path
        Local file to attach.
    artifact_path
        Destination folder inside the MLflow artifact store.
    """
    from lightning.pytorch.loggers import (
        CSVLogger,
        MLFlowLogger,
        TensorBoardLogger,
    )

    if isinstance(runtime_logger, MLFlowLogger):
        runtime_logger.experiment.log_artifact(  # type: ignore[call-arg]
            local_path=str(path),
            artifact_path=artifact_path,
            run_id=runtime_logger.run_id,
        )
        return

    # Bail out for these backends *before* reading `.experiment`: the
    # property builds the underlying writer lazily on first access, and
    # there is no reason to trigger that just to discover it cannot
    # store artifacts.
    if isinstance(runtime_logger, (CSVLogger, TensorBoardLogger)):
        return

    experiment = getattr(runtime_logger, "experiment", None)
    if experiment is not None and hasattr(experiment, "log_artifact"):
        experiment.log_artifact(path=path, name=path.name, copy=True)
        return

    logger.warning(
        "Skipping artifact logging for unsupported logger type {logger_type}",
        logger_type=type(runtime_logger).__name__,
    )
|
||||
|
||||
|
||||
def log_config_artifact(
    logger: Logger,
    config: BaseConfig,
    filename: str,
    artifact_path: Path,
) -> None:
    """Write ``config`` as YAML and attach the file to the run.

    The YAML produced by ``config.to_yaml_string()`` is written to
    ``artifact_path / filename`` (the directory is created if missing)
    and then handed to ``log_artifact_file``. Only the last component of
    ``artifact_path`` is passed on as the logger-side folder name.

    Parameters
    ----------
    logger
        Lightning logger owning the run.
    config
        Configuration object to serialise.
    filename
        Name of the YAML file to create inside ``artifact_path``.
    artifact_path
        Local directory that receives the file.
    """
    artifact_path.mkdir(parents=True, exist_ok=True)
    path = artifact_path / filename

    # Pin the encoding: without it `write_text` uses the locale's
    # preferred encoding, so the same config could produce different
    # bytes (or fail on non-ASCII values) depending on the machine.
    path.write_text(config.to_yaml_string(), encoding="utf-8")

    log_artifact_file(
        logger,
        path,
        artifact_path=artifact_path.name,
    )
|
||||
|
||||
|
||||
def log_csv_artifact(
    logger: Logger,
    df: pd.DataFrame,
    filename: str,
    artifact_path: Path,
) -> None:
    """Dump ``df`` to CSV and attach the file to the run.

    The frame is written to ``artifact_path / filename`` (creating the
    directory when needed) and then passed to ``log_artifact_file``,
    using the last component of ``artifact_path`` as the logger-side
    folder name.

    NOTE(review): the frame is written with ``index=False``, so anything
    stored in the index is not part of the CSV. Confirm callers do not
    keep row labels (e.g. class names) there.

    Parameters
    ----------
    logger
        Lightning logger owning the run.
    df
        Table to serialise.
    filename
        Name of the CSV file to create inside ``artifact_path``.
    artifact_path
        Local directory that receives the file.
    """
    artifact_path.mkdir(parents=True, exist_ok=True)

    csv_path = artifact_path / filename
    df.to_csv(csv_path, index=False)

    remote_folder = artifact_path.name
    log_artifact_file(logger, csv_path, artifact_path=remote_folder)
|
||||
|
||||
|
||||
PlotLogger = Callable[[str, "Figure", int], None]
|
||||
|
||||
|
||||
|
||||
@ -4,10 +4,24 @@ from batdetect2.train.lightning import (
|
||||
TrainingModule,
|
||||
load_model_from_checkpoint,
|
||||
)
|
||||
from batdetect2.train.logging import (
|
||||
ConfigHyperparameterLogging,
|
||||
DatasetConfigArtifact,
|
||||
DatasetConfigArtifactLogging,
|
||||
DataSummaryArtifactLogging,
|
||||
TargetConfigArtifactLogging,
|
||||
TrainLoggingContext,
|
||||
)
|
||||
from batdetect2.train.train import build_trainer, run_train
|
||||
|
||||
__all__ = [
|
||||
"ConfigHyperparameterLogging",
|
||||
"DataSummaryArtifactLogging",
|
||||
"DEFAULT_CHECKPOINT_DIR",
|
||||
"DatasetConfigArtifact",
|
||||
"DatasetConfigArtifactLogging",
|
||||
"TargetConfigArtifactLogging",
|
||||
"TrainLoggingContext",
|
||||
"TrainingConfig",
|
||||
"TrainingModule",
|
||||
"build_trainer",
|
||||
|
||||
164
src/batdetect2/train/logging.py
Normal file
164
src/batdetect2/train/logging.py
Normal file
@ -0,0 +1,164 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.data import Dataset, compute_class_summary
|
||||
from batdetect2.logging import log_config_artifact, log_csv_artifact
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
|
||||
__all__ = [
|
||||
"ConfigHyperparameterLogging",
|
||||
"DataSummaryArtifactLogging",
|
||||
"DatasetConfigArtifact",
|
||||
"DatasetConfigArtifactLogging",
|
||||
"TargetConfigArtifactLogging",
|
||||
"TrainLoggingContext",
|
||||
]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class TrainLoggingContext:
    """Immutable bundle of run information handed to logging callbacks.

    Instances are frozen, so callbacks can read but not reassign fields.
    """

    # Configuration objects describing the run.
    model_config: ModelConfig
    train_config: TrainingConfig
    audio_config: AudioConfig

    # Target definition used for the run.
    targets: TargetProtocol

    # Annotated clips used for fitting, and optionally for validation.
    train_dataset: Dataset
    val_dataset: Dataset | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DatasetConfigArtifact:
    """Pairs a dataset configuration with the file name to store it under.

    Instances are frozen and therefore read-only after construction.
    """

    # Name of the file the configuration is written to.
    filename: str

    # Configuration object to be serialised.
    config: BaseConfig
|
||||
|
||||
|
||||
class ConfigHyperparameterLogging:
    """Logging callback that records the run configs as hyperparameters."""

    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: TrainLoggingContext,
    ) -> None:
        """Send the model, training and audio configs to ``logger``.

        Each config is dumped in JSON mode with ``None`` values removed
        and stored under the keys ``"model"``, ``"training"`` and
        ``"audio"`` of a single hyperparameter mapping.

        Parameters
        ----------
        logger
            Lightning logger receiving the hyperparameters.
        artifact_path
            Unused by this callback; nothing is written to disk here.
        context
            Provides the three configuration objects.
        """
        sections = (
            ("model", context.model_config),
            ("training", context.train_config),
            ("audio", context.audio_config),
        )
        hyperparams = {
            section: config.model_dump(mode="json", exclude_none=True)
            for section, config in sections
        }
        logger.log_hyperparams(hyperparams)
|
||||
|
||||
|
||||
class TargetConfigArtifactLogging:
    """Logging callback that stores the target definition as YAML."""

    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: TrainLoggingContext,
    ) -> None:
        """Write ``targets.yaml`` and attach it to the run.

        The configuration returned by ``context.targets.get_config()`` is
        validated into a ``TargetConfig`` and logged under the
        ``training_artifacts`` sub-directory of ``artifact_path``.

        Parameters
        ----------
        logger
            Lightning logger owning the run.
        artifact_path
            Root directory for locally written artifacts.
        context
            Provides the targets whose configuration is exported.
        """
        raw_config = context.targets.get_config()
        targets_config = TargetConfig.model_validate(raw_config)

        output_dir = artifact_path / "training_artifacts"
        log_config_artifact(
            logger,
            targets_config,
            filename="targets.yaml",
            artifact_path=output_dir,
        )
|
||||
|
||||
|
||||
class DatasetConfigArtifactLogging:
    """Logging callback that stores the dataset configs used for a run."""

    def __init__(
        self,
        train_dataset_config: DatasetConfigArtifact,
        val_dataset_config: DatasetConfigArtifact | None = None,
    ):
        """Remember which dataset configs to export.

        Parameters
        ----------
        train_dataset_config
            Training dataset config and its target file name.
        val_dataset_config
            Validation counterpart; omit when there is no validation set.
        """
        self.train_dataset_config = train_dataset_config
        self.val_dataset_config = val_dataset_config

    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: TrainLoggingContext,
    ) -> None:
        """Log the stored dataset configs as YAML artifacts.

        The training config is always logged; the validation config only
        when one was supplied. Files go to the ``training_artifacts``
        sub-directory of ``artifact_path``.

        Parameters
        ----------
        logger
            Lightning logger owning the run.
        artifact_path
            Root directory for locally written artifacts.
        context
            Unused by this callback.
        """
        output_dir = artifact_path / "training_artifacts"

        # Training first, then validation if present.
        pending = [self.train_dataset_config]
        if self.val_dataset_config is not None:
            pending.append(self.val_dataset_config)

        for artifact in pending:
            log_config_artifact(
                logger,
                artifact.config,
                filename=artifact.filename,
                artifact_path=output_dir,
            )
|
||||
|
||||
|
||||
class DataSummaryArtifactLogging:
    """Logging callback that stores per-class dataset summaries as CSV."""

    def run(
        self,
        logger: Logger,
        artifact_path: Path,
        context: TrainLoggingContext,
    ) -> None:
        """Log class summaries for the training and validation data.

        ``train_class_summary.csv`` is always produced;
        ``val_class_summary.csv`` only when the context carries a
        validation dataset. Files go to the ``training_artifacts``
        sub-directory of ``artifact_path``.

        Parameters
        ----------
        logger
            Lightning logger owning the run.
        artifact_path
            Root directory for locally written artifacts.
        context
            Provides the datasets and the targets used to summarise them.
        """
        output_dir = artifact_path / "training_artifacts"

        # Training first, then validation if present.
        splits = [("train_class_summary.csv", context.train_dataset)]
        if context.val_dataset is not None:
            splits.append(("val_class_summary.csv", context.val_dataset))

        for filename, dataset in splits:
            summary = _compute_class_summary_or_empty(dataset, context.targets)
            log_csv_artifact(
                logger,
                summary,
                filename=filename,
                artifact_path=output_dir,
            )
|
||||
|
||||
|
||||
def _compute_class_summary_or_empty(
    dataset: Sequence[data.ClipAnnotation],
    targets: TargetProtocol,
) -> pd.DataFrame:
    """Return the class summary, or an empty table if it cannot be built.

    ``compute_class_summary`` is called as-is. If it fails with
    ``KeyError("class_name")`` an empty frame carrying the summary column
    names is returned instead; any other ``KeyError`` propagates.

    NOTE(review): presumably that specific error means no annotation in
    ``dataset`` produced a class; confirm against
    ``compute_class_summary``.

    Parameters
    ----------
    dataset
        Annotated clips to summarise.
    targets
        Target definition used to assign classes.

    Returns
    -------
    pd.DataFrame
        The computed summary, or an empty frame with the columns
        ``num calls``, ``num recordings``, ``duration`` and ``call_rate``.
    """
    try:
        summary = compute_class_summary(dataset, targets)
    except KeyError as error:
        if error.args == ("class_name",):
            return pd.DataFrame(
                columns=["num calls", "num recordings", "duration", "call_rate"]
            )
        raise

    return summary
|
||||
@ -3,6 +3,7 @@ from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from lightning import Trainer, seed_everything
|
||||
from lightning.pytorch.loggers import Logger
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
@ -10,6 +11,7 @@ from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||
from batdetect2.evaluate import EvaluatorProtocol, build_evaluator
|
||||
from batdetect2.logging import (
|
||||
LoggerConfig,
|
||||
LoggingCallback,
|
||||
TensorBoardLoggerConfig,
|
||||
build_logger,
|
||||
)
|
||||
@ -28,6 +30,12 @@ from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
from batdetect2.train.logging import (
|
||||
ConfigHyperparameterLogging,
|
||||
DataSummaryArtifactLogging,
|
||||
TargetConfigArtifactLogging,
|
||||
TrainLoggingContext,
|
||||
)
|
||||
from batdetect2.train.types import ClipLabeller
|
||||
|
||||
__all__ = [
|
||||
@ -36,6 +44,9 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_LOG_DIR = Path("outputs") / "logs"
|
||||
|
||||
|
||||
def run_train(
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
@ -59,6 +70,7 @@ def run_train(
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
logging_callbacks: Sequence[LoggingCallback[TrainLoggingContext]] = (),
|
||||
):
|
||||
if seed is not None:
|
||||
seed_everything(seed)
|
||||
@ -148,15 +160,44 @@ def run_train(
|
||||
roi_mapper=roi_mapper,
|
||||
)
|
||||
|
||||
train_logger = build_logger(
|
||||
logger_config or TensorBoardLoggerConfig(),
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
root_artifact_path = (
|
||||
Path(log_dir) if log_dir is not None else DEFAULT_LOG_DIR
|
||||
)
|
||||
root_artifact_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
logging_context = TrainLoggingContext(
|
||||
model_config=model_config,
|
||||
train_config=train_config,
|
||||
audio_config=audio_config,
|
||||
targets=targets,
|
||||
train_dataset=train_annotations,
|
||||
val_dataset=val_annotations,
|
||||
)
|
||||
|
||||
resolved_logging_callbacks = (
|
||||
ConfigHyperparameterLogging(),
|
||||
TargetConfigArtifactLogging(),
|
||||
DataSummaryArtifactLogging(),
|
||||
*logging_callbacks,
|
||||
)
|
||||
|
||||
for callback in resolved_logging_callbacks:
|
||||
callback.run(train_logger, root_artifact_path, logging_context)
|
||||
|
||||
trainer = trainer or build_trainer(
|
||||
train_config,
|
||||
logger_config=logger_config,
|
||||
train_logger=train_logger,
|
||||
evaluator=evaluator,
|
||||
targets=targets,
|
||||
roi_mapper=roi_mapper,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
num_epochs=num_epochs,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
@ -223,12 +264,11 @@ def _validate_model_compatibility(
|
||||
|
||||
def build_trainer(
|
||||
config: TrainingConfig,
|
||||
logger_config: LoggerConfig | None,
|
||||
train_logger: Logger,
|
||||
evaluator: "EvaluatorProtocol",
|
||||
targets: "TargetProtocol",
|
||||
roi_mapper: "ROIMapperProtocol",
|
||||
checkpoint_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
num_epochs: int | None = None,
|
||||
@ -239,20 +279,9 @@ def build_trainer(
|
||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
|
||||
train_logger = build_logger(
|
||||
logger_config or TensorBoardLoggerConfig(),
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
|
||||
if num_epochs is not None:
|
||||
trainer_conf.max_epochs = num_epochs
|
||||
|
||||
train_logger.log_hyperparams(
|
||||
config.model_dump(mode="json", exclude_none=True)
|
||||
)
|
||||
|
||||
train_config = trainer_conf.model_dump(exclude_none=True)
|
||||
|
||||
return Trainer(
|
||||
|
||||
@ -22,6 +22,10 @@ from batdetect2.train import (
|
||||
load_model_from_checkpoint,
|
||||
run_train,
|
||||
)
|
||||
from batdetect2.train.logging import (
|
||||
DatasetConfigArtifact,
|
||||
DatasetConfigArtifactLogging,
|
||||
)
|
||||
from batdetect2.train.optimizers import AdamOptimizerConfig
|
||||
from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
|
||||
from batdetect2.train.train import build_training_module
|
||||
@ -369,6 +373,59 @@ def test_build_model_with_new_targets_reuses_backbone_and_rebuilds_heads() -> (
|
||||
assert rebuilt_model.dimension_names == ["width", "height"]
|
||||
|
||||
|
||||
@pytest.mark.slow
def test_run_train_logs_training_artifacts(
    tmp_path: Path,
    example_annotations: list[data.ClipAnnotation],
    example_dataset,
) -> None:
    """A short training run must leave all provenance artifacts on disk."""
    # Keep the run as small as possible: one batch, one epoch, no
    # augmentation.
    trainer_overrides = {
        "limit_train_batches": 1,
        "limit_val_batches": 1,
        "log_every_n_steps": 1,
    }
    loader_overrides = {
        "batch_size": 1,
        "augmentations": {"enabled": False},
    }
    train_config = TrainingConfig.model_validate(
        {
            "trainer": trainer_overrides,
            "train_loader": loader_overrides,
        }
    )

    log_dir = tmp_path / "logs"

    dataset_logging = DatasetConfigArtifactLogging(
        train_dataset_config=DatasetConfigArtifact(
            filename="train_dataset.yaml",
            config=example_dataset,
        ),
        val_dataset_config=DatasetConfigArtifact(
            filename="val_dataset.yaml",
            config=example_dataset,
        ),
    )

    run_train(
        train_annotations=example_annotations[:1],
        val_annotations=example_annotations[:1],
        train_config=train_config,
        num_epochs=1,
        train_workers=0,
        val_workers=0,
        checkpoint_dir=tmp_path / "checkpoints",
        log_dir=log_dir,
        seed=0,
        logging_callbacks=[dataset_logging],
    )

    artifact_root = next(log_dir.rglob("training_artifacts"))

    expected_files = (
        "targets.yaml",
        "train_dataset.yaml",
        "val_dataset.yaml",
        "train_class_summary.csv",
        "val_class_summary.csv",
    )
    for filename in expected_files:
        assert (artifact_root / filename).exists()
|
||||
|
||||
|
||||
def test_run_train_rejects_incompatible_model_config(
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
) -> None:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user