This commit is contained in:
mbsantiago 2025-09-08 22:04:30 +01:00
parent c73984b213
commit cd4955d4f3
12 changed files with 341 additions and 53 deletions

View File

@ -108,6 +108,9 @@ train:
labels: labels:
sigma: 3 sigma: 3
trainer:
max_epochs: 40
dataloaders: dataloaders:
train: train:
batch_size: 8 batch_size: 8
@ -115,7 +118,7 @@ train:
shuffle: True shuffle: True
val: val:
batch_size: 8 batch_size: 1
num_workers: 2 num_workers: 2
loss: loss:
@ -133,7 +136,7 @@ train:
weight: 0.1 weight: 0.1
logger: logger:
logger_type: tensorboard logger_type: csv
# save_dir: outputs/log/ # save_dir: outputs/log/
# name: logs # name: logs

View File

@ -1,6 +1,7 @@
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect from batdetect2.cli.compat import detect
from batdetect2.cli.data import data from batdetect2.cli.data import data
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.train import train_command from batdetect2.cli.train import train_command
__all__ = [ __all__ = [
@ -8,6 +9,7 @@ __all__ = [
"detect", "detect",
"data", "data",
"train_command", "train_command",
"evaluate_command",
] ]

View File

@ -0,0 +1,63 @@
import sys
from pathlib import Path
from typing import Optional
import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
__all__ = ["evaluate_command"]
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--output-dir", type=click.Path())
@click.option("--workers", type=int)
@click.option(
    "-v",
    "--verbose",
    count=True,
    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
    model_path: Path,
    test_dataset: Path,
    output_dir: Optional[Path] = None,
    workers: Optional[int] = None,
    verbose: int = 0,
):
    """Evaluate a model checkpoint on a test dataset.

    Loads the dataset described by TEST_DATASET and the model stored at
    MODEL_PATH, runs the evaluation, prints the resulting metrics and, when
    --output-dir is given, writes the per-match table to ``results.csv``
    inside that directory.
    """
    # Replace loguru's default sink with one that honours the -v flags.
    logger.remove()
    if verbose == 0:
        log_level = "WARNING"
    elif verbose == 1:
        log_level = "INFO"
    else:
        log_level = "DEBUG"
    logger.add(sys.stderr, level=log_level)

    logger.info("Initiating evaluation process...")

    test_annotations = load_dataset_from_config(test_dataset)
    logger.debug(
        "Loaded {num_annotations} test examples",
        num_annotations=len(test_annotations),
    )

    model, train_config = load_model_from_checkpoint(model_path)

    df, results = evaluate(
        model,
        test_annotations,
        config=train_config,
        num_workers=workers,
    )

    print(results)

    if output_dir:
        # click.Path() hands over a plain string unless ``path_type`` is
        # set, so convert before joining with ``/``. The directory is not
        # required to exist beforehand, hence the explicit mkdir.
        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)
        df.to_csv(output_path / "results.csv")

View File

@ -20,6 +20,8 @@ __all__ = ["train_command"]
@click.argument("train_dataset", type=click.Path(exists=True)) @click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True)) @click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True)) @click.option("--model-path", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True)) @click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str) @click.option("--config-field", type=str)
@click.option("--train-workers", type=int) @click.option("--train-workers", type=int)
@ -34,6 +36,8 @@ def train_command(
train_dataset: Path, train_dataset: Path,
val_dataset: Optional[Path] = None, val_dataset: Optional[Path] = None,
model_path: Optional[Path] = None, model_path: Optional[Path] = None,
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None, config: Optional[Path] = None,
config_field: Optional[str] = None, config_field: Optional[str] = None,
train_workers: int = 0, train_workers: int = 0,
@ -83,4 +87,6 @@ def train_command(
model_path=model_path, model_path=model_path,
train_workers=train_workers, train_workers=train_workers,
val_workers=val_workers, val_workers=val_workers,
log_dir=log_dir,
checkpoint_dir=ckpt_dir,
) )

View File

@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
and serialization capabilities. and serialization capabilities.
""" """
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="ignore")
def to_yaml_string( def to_yaml_string(
self, self,

View File

@ -0,0 +1,62 @@
from typing import List
import pandas as pd
from soundevent.geometry import compute_bounds
from batdetect2.typing.evaluate import MatchEvaluation
def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
    """Flatten match evaluations into a table with two-level columns.

    Each match becomes one row. Columns are ``(group, field)`` pairs under
    the groups ``recording``, ``clip``, ``gt``, ``pred``, ``match`` and
    ``pred_class_score`` (one column per key found in
    ``match.pred_class_scores``).

    Parameters
    ----------
    matches
        Match evaluations to tabulate.

    Returns
    -------
    pd.DataFrame
        One row per match with ``pd.MultiIndex`` columns. An empty input
        yields an empty frame whose columns still have two levels.
    """
    if not matches:
        # ``MultiIndex.from_tuples`` cannot infer the number of levels from
        # an empty column list, so build the empty two-level header directly.
        return pd.DataFrame(
            columns=pd.MultiIndex(levels=[[], []], codes=[[], []])
        )

    rows = []
    for match in matches:
        # Bounds stay None on whichever side of the match is missing.
        gt_uuid = None
        gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
        pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None

        annotation = match.sound_event_annotation
        if annotation is not None:
            gt_uuid = annotation.uuid
            geometry = annotation.sound_event.geometry
            assert geometry is not None
            gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
                compute_bounds(geometry)
            )

        if match.pred_geometry is not None:
            pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
                compute_bounds(match.pred_geometry)
            )

        rows.append(
            {
                ("recording", "uuid"): match.clip.recording.uuid,
                ("clip", "uuid"): match.clip.uuid,
                ("clip", "start_time"): match.clip.start_time,
                ("clip", "end_time"): match.clip.end_time,
                ("gt", "uuid"): gt_uuid,
                ("gt", "class"): match.gt_class,
                ("gt", "det"): match.gt_det,
                ("gt", "start_time"): gt_start_time,
                ("gt", "end_time"): gt_end_time,
                ("gt", "low_freq"): gt_low_freq,
                ("gt", "high_freq"): gt_high_freq,
                ("pred", "score"): match.pred_score,
                ("pred", "class"): match.pred_class,
                ("pred", "class_score"): match.pred_class_score,
                ("pred", "start_time"): pred_start_time,
                ("pred", "end_time"): pred_end_time,
                ("pred", "low_freq"): pred_low_freq,
                ("pred", "high_freq"): pred_high_freq,
                ("match", "affinity"): match.affinity,
                **{
                    ("pred_class_score", key): value
                    for key, value in match.pred_class_scores.items()
                },
            }
        )

    df = pd.DataFrame(rows)
    # Make the tuple keys an explicit two-level column index.
    df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
    return df

View File

@ -0,0 +1,100 @@
from typing import List, Optional, Tuple
import pandas as pd
from soundevent import data
from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.match import match_all_predictions
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import Model
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_val_loader
def evaluate(
    model: Model,
    test_annotations: List[data.ClipAnnotation],
    config: Optional[FullTrainingConfig] = None,
    num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
    """Run a model over test annotations and compute evaluation metrics.

    Parameters
    ----------
    model
        Model whose ``detector`` and ``postprocessor`` are used for
        inference.
    test_annotations
        Clip annotations to evaluate against.
    config
        Full training configuration; a default ``FullTrainingConfig`` is
        used when omitted.
    num_workers
        Forwarded to ``build_val_loader``.

    Returns
    -------
    Tuple[pd.DataFrame, dict]
        The per-match table built by ``extract_matches_dataframe`` and a
        dict merging the name/value pairs returned by every metric.
    """
    # Local import so this block does not rely on a module-level torch import.
    import torch

    config = config or FullTrainingConfig()

    audio_loader = build_audio_loader(config.preprocess.audio)
    preprocessor = build_preprocessor(config.preprocess)
    targets = build_targets(config.targets)

    labeller = build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.train.labels,
    )

    loader = build_val_loader(
        test_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        config=config.train,
        num_workers=num_workers,
    )

    dataset: ValidationDataset = loader.dataset  # type: ignore

    # Accumulators across all batches. The per-batch results use their own
    # names so they never overwrite (or self-extend) these lists.
    clip_annotations = []
    predictions = []

    # Evaluate in inference mode without autograd bookkeeping, then put the
    # module back into whatever mode the caller had it in.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                outputs = model.detector(batch.spec)

                batch_annotations = [
                    dataset.clip_annotations[int(example_idx)]
                    for example_idx in batch.idx
                ]

                batch_predictions = get_raw_predictions(
                    outputs,
                    clips=[
                        clip_annotation.clip
                        for clip_annotation in batch_annotations
                    ],
                    targets=targets,
                    postprocessor=model.postprocessor,
                )

                clip_annotations.extend(batch_annotations)
                predictions.extend(batch_predictions)
    finally:
        model.train(was_training)

    matches = match_all_predictions(
        clip_annotations,
        predictions,
        targets=targets,
        config=config.evaluation.match,
    )

    df = extract_matches_dataframe(matches)

    metrics = [
        DetectionAveragePrecision(),
        ClassificationMeanAveragePrecision(class_names=targets.class_names),
        ClassificationAccuracy(class_names=targets.class_names),
    ]

    # Every metric returns a mapping; merge them into one flat dict.
    results = {
        name: value
        for metric in metrics
        for name, value in metric(matches).items()
    }

    return df, results

View File

@ -29,7 +29,6 @@ provided here.
from typing import List, Optional from typing import List, Optional
import torch import torch
from lightning import LightningModule
from pydantic import Field from pydantic import Field
from soundevent.data import PathLike from soundevent.data import PathLike
@ -105,7 +104,16 @@ __all__ = [
] ]
class Model(LightningModule): class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
class Model(torch.nn.Module):
detector: DetectionModel detector: DetectionModel
preprocessor: PreprocessorProtocol preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol postprocessor: PostprocessorProtocol
@ -117,13 +125,14 @@ class Model(LightningModule):
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol, postprocessor: PostprocessorProtocol,
targets: TargetProtocol, targets: TargetProtocol,
config: ModelConfig,
): ):
super().__init__() super().__init__()
self.detector = detector self.detector = detector
self.preprocessor = preprocessor self.preprocessor = preprocessor
self.postprocessor = postprocessor self.postprocessor = postprocessor
self.targets = targets self.targets = targets
self.save_hyperparameters() self.config = config
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]: def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
spec = self.preprocessor(wav) spec = self.preprocessor(wav)
@ -131,29 +140,24 @@ class Model(LightningModule):
return self.postprocessor(outputs) return self.postprocessor(outputs)
class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
def build_model(config: Optional[ModelConfig] = None): def build_model(config: Optional[ModelConfig] = None):
config = config or ModelConfig() config = config or ModelConfig()
targets = build_targets(config=config.targets) targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess) preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor( postprocessor = build_postprocessor(
preprocessor=preprocessor, preprocessor=preprocessor,
config=config.postprocess, config=config.postprocess,
) )
detector = build_detector( detector = build_detector(
num_classes=len(targets.class_names), num_classes=len(targets.class_names),
config=config.model, config=config.model,
) )
return Model( return Model(
config=config,
detector=detector, detector=detector,
postprocessor=postprocessor, postprocessor=postprocessor,
preprocessor=preprocessor, preprocessor=preprocessor,

View File

@ -6,7 +6,6 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig from batdetect2.models import ModelConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG, DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig, AugmentationsConfig,
@ -75,7 +74,6 @@ class TrainingConfig(BaseConfig):
cliping: ClipingConfig = Field(default_factory=ClipingConfig) cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig) labels: LabelConfig = Field(default_factory=LabelConfig)

View File

@ -1,9 +1,14 @@
from typing import Optional, Tuple
import lightning as L import lightning as L
import torch import torch
from soundevent.data import PathLike
from torch.optim.adam import Adam from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import Model from batdetect2.models import Model, build_model
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample from batdetect2.typing import ModelOutput, TrainExample
__all__ = [ __all__ = [
@ -16,22 +21,28 @@ class TrainingModule(L.LightningModule):
def __init__( def __init__(
self, self,
model: Model, config: FullTrainingConfig,
loss: torch.nn.Module,
learning_rate: float = 0.001, learning_rate: float = 0.001,
t_max: int = 100, t_max: int = 100,
model: Optional[Model] = None,
loss: Optional[torch.nn.Module] = None,
): ):
super().__init__() super().__init__()
self.save_hyperparameters(logger=False)
self.config = config
self.learning_rate = learning_rate self.learning_rate = learning_rate
self.t_max = t_max self.t_max = t_max
if loss is None:
loss = build_loss(self.config.train.loss)
if model is None:
model = build_model(self.config)
self.loss = loss self.loss = loss
self.model = model self.model = model
self.save_hyperparameters(logger=False)
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.model.detector(spec)
def training_step(self, batch: TrainExample): def training_step(self, batch: TrainExample):
outputs = self.model.detector(batch.spec) outputs = self.model.detector(batch.spec)
@ -59,3 +70,10 @@ class TrainingModule(L.LightningModule):
optimizer = Adam(self.parameters(), lr=self.learning_rate) optimizer = Adam(self.parameters(), lr=self.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max) scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
return [optimizer], [scheduler] return [optimizer], [scheduler]
def load_model_from_checkpoint(
    path: PathLike,
) -> Tuple[Model, FullTrainingConfig]:
    """Restore a ``TrainingModule`` checkpoint and unpack its contents.

    Parameters
    ----------
    path
        Location of the checkpoint passed to
        ``TrainingModule.load_from_checkpoint``.

    Returns
    -------
    Tuple[Model, FullTrainingConfig]
        The ``model`` and ``config`` attributes of the restored module.
    """
    restored = TrainingModule.load_from_checkpoint(path)  # type: ignore
    model, config = restored.model, restored.config
    return model, config

View File

@ -5,10 +5,11 @@ import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
from loguru import logger from loguru import logger
from pydantic import Field from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
DEFAULT_LOGS_DIR: str = "logs" DEFAULT_LOGS_DIR: str = "outputs"
class DVCLiveConfig(BaseConfig): class DVCLiveConfig(BaseConfig):
@ -31,7 +32,7 @@ class CSVLoggerConfig(BaseConfig):
class TensorBoardLoggerConfig(BaseConfig): class TensorBoardLoggerConfig(BaseConfig):
logger_type: Literal["tensorboard"] = "tensorboard" logger_type: Literal["tensorboard"] = "tensorboard"
save_dir: str = DEFAULT_LOGS_DIR save_dir: str = DEFAULT_LOGS_DIR
name: Optional[str] = "default" name: Optional[str] = "logs"
version: Optional[str] = None version: Optional[str] = None
log_graph: bool = False log_graph: bool = False
@ -57,7 +58,10 @@ LoggerConfig = Annotated[
] ]
def create_dvclive_logger(config: DVCLiveConfig) -> Logger: def create_dvclive_logger(
config: DVCLiveConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
try: try:
from dvclive.lightning import DVCLiveLogger # type: ignore from dvclive.lightning import DVCLiveLogger # type: ignore
except ImportError as error: except ImportError as error:
@ -68,7 +72,7 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
) from error ) from error
return DVCLiveLogger( return DVCLiveLogger(
dir=config.dir, dir=log_dir if log_dir is not None else config.dir,
run_name=config.run_name, run_name=config.run_name,
prefix=config.prefix, prefix=config.prefix,
log_model=config.log_model, log_model=config.log_model,
@ -76,29 +80,38 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
) )
def create_csv_logger(config: CSVLoggerConfig) -> Logger: def create_csv_logger(
config: CSVLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
from lightning.pytorch.loggers import CSVLogger from lightning.pytorch.loggers import CSVLogger
return CSVLogger( return CSVLogger(
save_dir=config.save_dir, save_dir=str(log_dir) if log_dir is not None else config.save_dir,
name=config.name, name=config.name,
version=config.version, version=config.version,
flush_logs_every_n_steps=config.flush_logs_every_n_steps, flush_logs_every_n_steps=config.flush_logs_every_n_steps,
) )
def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger: def create_tensorboard_logger(
config: TensorBoardLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger from lightning.pytorch.loggers import TensorBoardLogger
return TensorBoardLogger( return TensorBoardLogger(
save_dir=config.save_dir, save_dir=str(log_dir) if log_dir is not None else config.save_dir,
name=config.name, name=config.name,
version=config.version, version=config.version,
log_graph=config.log_graph, log_graph=config.log_graph,
) )
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger: def create_mlflow_logger(
config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
try: try:
from lightning.pytorch.loggers import MLFlowLogger from lightning.pytorch.loggers import MLFlowLogger
except ImportError as error: except ImportError as error:
@ -111,7 +124,7 @@ def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
return MLFlowLogger( return MLFlowLogger(
experiment_name=config.experiment_name, experiment_name=config.experiment_name,
run_name=config.run_name, run_name=config.run_name,
save_dir=config.save_dir, save_dir=str(log_dir) if log_dir is not None else config.save_dir,
tracking_uri=config.tracking_uri, tracking_uri=config.tracking_uri,
tags=config.tags, tags=config.tags,
log_model=config.log_model, log_model=config.log_model,
@ -126,7 +139,10 @@ LOGGER_FACTORY = {
} }
def build_logger(config: LoggerConfig) -> Logger: def build_logger(
config: LoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
""" """
Creates a logger instance from a validated Pydantic config object. Creates a logger instance from a validated Pydantic config object.
""" """
@ -141,7 +157,7 @@ def build_logger(config: LoggerConfig) -> Logger:
creation_func = LOGGER_FACTORY[logger_type] creation_func = LOGGER_FACTORY[logger_type]
return creation_func(config) return creation_func(config, log_dir=log_dir)
def get_image_plotter(logger: Logger): def get_image_plotter(logger: Logger):

View File

@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import (
ClassificationMeanAveragePrecision, ClassificationMeanAveragePrecision,
DetectionAveragePrecision, DetectionAveragePrecision,
) )
from batdetect2.models import Model, build_model
from batdetect2.plotting.clips import AudioLoader, build_audio_loader from batdetect2.plotting.clips import AudioLoader, build_audio_loader
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
RandomAudioSource, RandomAudioSource,
build_augmentations, build_augmentations,
@ -28,7 +28,6 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset
from batdetect2.train.labels import build_clip_labeler from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
from batdetect2.typing import ( from batdetect2.typing import (
PreprocessorProtocol, PreprocessorProtocol,
TargetProtocol, TargetProtocol,
@ -54,19 +53,21 @@ def train(
model_path: Optional[data.PathLike] = None, model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None, train_workers: Optional[int] = None,
val_workers: Optional[int] = None, val_workers: Optional[int] = None,
checkpoint_dir: Optional[data.PathLike] = None,
log_dir: Optional[data.PathLike] = None,
): ):
config = config or FullTrainingConfig() config = config or FullTrainingConfig()
model = build_model(config=config) targets = build_targets(config.targets)
trainer = build_trainer(config, targets=model.targets) preprocessor = build_preprocessor(config.preprocess)
audio_loader = build_audio_loader(config=config.preprocess.audio) audio_loader = build_audio_loader(config=config.preprocess.audio)
labeller = build_clip_labeler( labeller = build_clip_labeler(
model.targets, targets,
min_freq=model.preprocessor.min_freq, min_freq=preprocessor.min_freq,
max_freq=model.preprocessor.max_freq, max_freq=preprocessor.max_freq,
config=config.train.labels, config=config.train.labels,
) )
@ -74,7 +75,7 @@ def train(
train_annotations, train_annotations,
audio_loader=audio_loader, audio_loader=audio_loader,
labeller=labeller, labeller=labeller,
preprocessor=build_preprocessor(config.preprocess), preprocessor=preprocessor,
config=config.train, config=config.train,
num_workers=train_workers, num_workers=train_workers,
) )
@ -84,7 +85,7 @@ def train(
val_annotations, val_annotations,
audio_loader=audio_loader, audio_loader=audio_loader,
labeller=labeller, labeller=labeller,
preprocessor=build_preprocessor(config.preprocess), preprocessor=preprocessor,
config=config.train, config=config.train,
num_workers=val_workers, num_workers=val_workers,
) )
@ -97,11 +98,17 @@ def train(
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else: else:
module = build_training_module( module = build_training_module(
model,
config, config,
t_max=config.train.t_max * len(train_dataloader), t_max=config.train.t_max * len(train_dataloader),
) )
trainer = build_trainer(
config,
targets=targets,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
)
logger.info("Starting main training loop...") logger.info("Starting main training loop...")
trainer.fit( trainer.fit(
module, module,
@ -112,15 +119,12 @@ def train(
def build_training_module( def build_training_module(
model: Model,
config: Optional[FullTrainingConfig] = None, config: Optional[FullTrainingConfig] = None,
t_max: int = 200, t_max: int = 200,
) -> TrainingModule: ) -> TrainingModule:
config = config or FullTrainingConfig() config = config or FullTrainingConfig()
loss = build_loss(config=config.train.loss)
return TrainingModule( return TrainingModule(
model=model, config=config,
loss=loss,
learning_rate=config.train.learning_rate, learning_rate=config.train.learning_rate,
t_max=t_max, t_max=t_max,
) )
@ -130,10 +134,14 @@ def build_trainer_callbacks(
targets: TargetProtocol, targets: TargetProtocol,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
config: EvaluationConfig, config: EvaluationConfig,
checkpoint_dir: Optional[data.PathLike] = None,
) -> List[Callback]: ) -> List[Callback]:
if checkpoint_dir is None:
checkpoint_dir = "outputs/checkpoints"
return [ return [
ModelCheckpoint( ModelCheckpoint(
dirpath="outputs/checkpoints", dirpath=str(checkpoint_dir),
save_top_k=1, save_top_k=1,
monitor="total_loss/val", monitor="total_loss/val",
), ),
@ -154,15 +162,22 @@ def build_trainer_callbacks(
def build_trainer( def build_trainer(
conf: FullTrainingConfig, conf: FullTrainingConfig,
targets: TargetProtocol, targets: TargetProtocol,
checkpoint_dir: Optional[data.PathLike] = None,
log_dir: Optional[data.PathLike] = None,
) -> Trainer: ) -> Trainer:
trainer_conf = conf.train.trainer trainer_conf = conf.train.trainer
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}", "Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True), config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
) )
train_logger = build_logger(conf.train.logger) train_logger = build_logger(conf.train.logger, log_dir=log_dir)
train_logger.log_hyperparams(conf.model_dump(mode="json")) train_logger.log_hyperparams(
conf.model_dump(
mode="json",
exclude_none=True,
)
)
return Trainer( return Trainer(
**trainer_conf.model_dump(exclude_none=True), **trainer_conf.model_dump(exclude_none=True),
@ -171,6 +186,7 @@ def build_trainer(
targets, targets,
config=conf.evaluation, config=conf.evaluation,
preprocessor=build_preprocessor(conf.preprocess), preprocessor=build_preprocessor(conf.preprocess),
checkpoint_dir=checkpoint_dir,
), ),
) )