Add better logging

mbsantiago 2025-06-28 14:45:18 -06:00
parent 19e873dd0b
commit 8368aad178
15 changed files with 144 additions and 22 deletions

View File

@@ -13,9 +13,7 @@ from batdetect2.train import (
)
from batdetect2.train.dataset import list_preprocessed_files
__all__ = [
"train_command",
]
__all__ = ["train_command"]
@cli.command(name="train")
@@ -51,19 +49,35 @@ def train_command(
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Starting training!")
logger.info("Initiating training process...")
logger.info("Loading training configuration...")
conf = (
load_full_training_config(config, field=config_field)
if config is not None
else FullTrainingConfig()
)
logger.info("Scanning for training and validation data...")
train_examples = list_preprocessed_files(train_dir)
val_examples = (
list_preprocessed_files(val_dir) if val_dir is not None else None
logger.debug(
"Found {num_files} training examples in {path}",
num_files=len(train_examples),
path=train_dir,
)
val_examples = None
if val_dir is not None:
val_examples = list_preprocessed_files(val_dir)
logger.debug(
"Found {num_files} validation examples in {path}",
num_files=len(val_examples),
path=val_dir,
)
else:
logger.debug("No validation directory provided.")
logger.info("Configuration and data loaded. Starting training...")
train(
train_examples=train_examples,
val_examples=val_examples,
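
A note on the snippet above: the command now resolves a log level and attaches a stderr sink through loguru before reporting progress. Below is a minimal, standalone sketch of that pattern; the boolean verbose parameter is an assumption for illustration, since the actual CLI option is defined outside this hunk.

import sys

from loguru import logger


def setup_logging(verbose: bool = False) -> None:
    # Hypothetical helper mirroring the change above: derive the level from a
    # verbosity switch, then emit records at (or above) that level to stderr.
    log_level = "DEBUG" if verbose else "INFO"
    logger.add(sys.stderr, level=log_level)
    logger.info("Initiating training process...")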

View File

@@ -29,6 +29,36 @@ class BaseConfig(BaseModel):
model_config = ConfigDict(extra="ignore")
def to_yaml_string(
self,
exclude_none: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
) -> str:
"""Converts the Pydantic model instance to a YAML string.
Parameters
----------
exclude_none : bool, default=False
Whether to exclude fields whose value is `None`.
exclude_unset : bool, default=False
Whether to exclude fields that were not explicitly set.
exclude_defaults : bool, default=False
Whether to exclude fields whose value is the default value.
Returns
-------
str
A YAML string representation of the model.
"""
return yaml.dump(
self.model_dump(
exclude_none=exclude_none,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
)
)
T = TypeVar("T", bound=BaseModel)
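
Usage note for the new to_yaml_string helper: it simply runs yaml.dump over model_dump(), so every BaseConfig subclass gains a YAML view of itself. A quick sketch with a made-up config class (ExampleConfig and its fields are hypothetical, not part of batdetect2):

from typing import Optional

from batdetect2.configs import BaseConfig


class ExampleConfig(BaseConfig):
    # Illustrative fields only.
    samplerate: int = 256_000
    device: Optional[str] = None


print(ExampleConfig().to_yaml_string())
# device: null
# samplerate: 256000

print(ExampleConfig().to_yaml_string(exclude_none=True))
# samplerate: 256000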

View File

@@ -28,6 +28,8 @@ provided here.
from typing import Optional
from loguru import logger
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
@@ -131,5 +133,10 @@ def build_model(
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
backbone = build_backbone(config or BackboneConfig())
config = config or BackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
)
backbone = build_backbone(config)
return build_detector(num_classes, backbone)
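
The logger.opt(lazy=True) call above (repeated across the builders in this commit) defers the format arguments: loguru only invokes the lambda when some sink actually accepts DEBUG records, so the YAML serialization of the config is skipped entirely when debug logging is off. A minimal sketch of the idea:

from loguru import logger


def expensive_summary() -> str:
    # Stand-in for config.to_yaml_string(): anything costly to compute.
    return "...large YAML document..."


# The lambda runs only if a handler is listening at DEBUG level.
logger.opt(lazy=True).debug(
    "Building model with config: \n{}",
    lambda: expensive_summary(),
)

# Eager equivalent; expensive_summary() would run even with DEBUG disabled:
# logger.debug("Building model with config: \n{}", expensive_summary())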

View File

@@ -26,9 +26,9 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
LayerGroupConfig,
ConvConfig,
FreqCoordConvUpConfig,
LayerGroupConfig,
StandardConvUpConfig,
build_layer_from_config,
)

View File

@@ -28,9 +28,9 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
LayerGroupConfig,
ConvConfig,
FreqCoordConvDownConfig,
LayerGroupConfig,
StandardConvDownConfig,
build_layer_from_config,
)

View File

@@ -31,6 +31,7 @@ It also re-exports key components from submodules for convenience.
from typing import List, Optional
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data
@@ -203,9 +204,14 @@ def build_postprocessor(
PostprocessorProtocol
An initialized `Postprocessor` instance ready to process model outputs.
"""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
targets=targets,
config=config or PostprocessConfig(),
config=config,
min_freq=min_freq,
max_freq=max_freq,
)

View File

@@ -32,6 +32,7 @@ from typing import Optional, Union
import numpy as np
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data
@@ -429,6 +430,10 @@ def build_preprocessor(
according to the configuration.
"""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
default_samplerate = (
config.audio.resample.samplerate

View File

@@ -573,6 +573,10 @@ def build_targets(
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),
)
filter_fn = (
build_sound_event_filter(

View File

@@ -9,8 +9,8 @@ from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
default_term_registry,
get_tag_from_info,
)
__all__ = [

View File

@@ -28,6 +28,7 @@ from typing import Annotated, Callable, List, Literal, Optional, Union
import numpy as np
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import arrays, data
@@ -851,6 +852,11 @@ def build_augmentations(
"""
config = config or DEFAULT_AUGMENTATION_CONFIG
logger.opt(lazy=True).debug(
"Building augmentations with config: \n{}",
lambda: config.to_yaml_string(),
)
augmentations = []
for step_config in config.steps:

View File

@@ -2,6 +2,7 @@ from typing import Optional, Tuple, Union
import numpy as np
import xarray as xr
from loguru import logger
from soundevent import arrays
from batdetect2.configs import BaseConfig
@@ -74,6 +75,10 @@ def build_clipper(
random: Optional[bool] = None,
) -> ClipperProtocol:
config = config or ClipingConfig()
logger.opt(lazy=True).debug(
"Building clipper with config: \n{}",
lambda: config.to_yaml_string(),
)
return Clipper(
duration=config.duration,
max_empty=config.max_empty,

View File

@@ -93,10 +93,15 @@ def build_clip_labeler(
A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
(spectrogram) and returns the generated `Heatmaps`.
"""
config = config or LabelConfig()
logger.opt(lazy=True).debug(
"Building clip labeler with config: \n{}",
lambda: config.to_yaml_string(),
)
return partial(
generate_clip_label,
targets=targets,
config=config or LabelConfig(),
config=config,
)

View File

@@ -1,6 +1,7 @@
from typing import Annotated, Any, Literal, Optional, Union
from lightning.pytorch.loggers import Logger
from loguru import logger
from pydantic import Field
from batdetect2.configs import BaseConfig
@@ -129,6 +130,10 @@ def build_logger(config: LoggerConfig) -> Logger:
"""
Creates a logger instance from a validated Pydantic config object.
"""
logger.opt(lazy=True).debug(
"Building logger with config: \n{}",
lambda: config.to_yaml_string(),
)
logger_type = config.logger_type
if logger_type not in LOGGER_FACTORY:
@@ -137,4 +142,3 @@ def build_logger(config: LoggerConfig) -> Logger:
creation_func = LOGGER_FACTORY[logger_type]
return creation_func(config)

View File

@@ -23,6 +23,7 @@ from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from pydantic import Field
from torch import nn
@@ -451,6 +452,10 @@ def build_loss(
An initialized `LossFunction` module ready for training.
"""
config = config or LossConfig()
logger.opt(lazy=True).debug(
"Building loss with config: \n{}",
lambda: config.to_yaml_string(),
)
class_weights_tensor = (
torch.tensor(class_weights) if class_weights else None

View File

@@ -1,8 +1,10 @@
from collections.abc import Sequence
from typing import List, Optional
import yaml
from lightning import Trainer
from lightning.pytorch.callbacks import Callback
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
@@ -52,6 +54,7 @@ def train(
conf = config or FullTrainingConfig()
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = TrainingModule(conf)
@@ -75,11 +78,13 @@ def train(
else None
)
logger.info("Starting main training loop...")
trainer.fit(
module,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
logger.info("Training complete.")
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
@@ -103,6 +108,10 @@ def build_trainer(
trainer_conf = PLTrainerConfig.model_validate(
conf.train.model_dump(mode="python")
)
logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
val_check_interval=conf.train.val_check_interval,
@@ -117,12 +126,23 @@ def build_train_loader(
config: TrainingConfig,
num_workers: Optional[int] = None,
) -> DataLoader:
logger.info("Building training data loader...")
train_dataset = build_train_dataset(
train_examples,
preprocessor=preprocessor,
config=config,
)
logger.opt(lazy=True).debug(
"Training data loader config: \n{}",
lambda: yaml.dump(
{
"batch_size": config.batch_size,
"shuffle": True,
"num_workers": num_workers or 0,
"collate_fn": str(collate_fn),
}
),
)
return DataLoader(
train_dataset,
batch_size=config.batch_size,
@@ -137,10 +157,22 @@ def build_val_loader(
config: TrainingConfig,
num_workers: Optional[int] = None,
):
logger.info("Building validation data loader...")
val_dataset = build_val_dataset(
val_examples,
config=config,
)
logger.opt(lazy=True).debug(
"Validation data loader config: \n{}",
lambda: yaml.dump(
{
"batch_size": config.batch_size,
"shuffle": False,
"num_workers": num_workers or 0,
"collate_fn": str(collate_fn),
}
),
)
return DataLoader(
val_dataset,
batch_size=config.batch_size,
@@ -155,6 +187,7 @@ def build_train_dataset(
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
) -> LabeledDataset:
logger.info("Building training dataset...")
config = config or TrainingConfig()
clipper = build_clipper(config.cliping, random=True)
@@ -164,18 +197,15 @@
clipper=clipper,
)
logger.debug(
"Augmentations config: {}.", config.augmentations
)
augmentations = (
build_augmentations(
if config.augmentations and config.augmentations.steps:
augmentations = build_augmentations(
preprocessor,
config=config.augmentations,
example_source=random_example_source,
)
if config.augmentations
else None
)
else:
logger.debug("No augmentations configured for training dataset.")
augmentations = None
return LabeledDataset(
examples,
@@ -189,6 +219,7 @@ def build_val_dataset(
config: Optional[TrainingConfig] = None,
train: bool = True,
) -> LabeledDataset:
logger.info("Building validation dataset...")
config = config or TrainingConfig()
clipper = build_clipper(config.cliping, random=train)
return LabeledDataset(examples, clipper=clipper)