Add better logging

mbsantiago 2025-06-28 14:45:18 -06:00
parent 19e873dd0b
commit 8368aad178
15 changed files with 144 additions and 22 deletions

View File

@@ -13,9 +13,7 @@ from batdetect2.train import (
 )
 from batdetect2.train.dataset import list_preprocessed_files

-__all__ = [
-    "train_command",
-]
+__all__ = ["train_command"]


 @cli.command(name="train")
@@ -51,19 +49,35 @@ def train_command(
         log_level = "DEBUG"
     logger.add(sys.stderr, level=log_level)

-    logger.info("Starting training!")
+    logger.info("Initiating training process...")

+    logger.info("Loading training configuration...")
     conf = (
         load_full_training_config(config, field=config_field)
         if config is not None
         else FullTrainingConfig()
     )

+    logger.info("Scanning for training and validation data...")
     train_examples = list_preprocessed_files(train_dir)
-    val_examples = (
-        list_preprocessed_files(val_dir) if val_dir is not None else None
+    logger.debug(
+        "Found {num_files} training examples in {path}",
+        num_files=len(train_examples),
+        path=train_dir,
     )

+    val_examples = None
+    if val_dir is not None:
+        val_examples = list_preprocessed_files(val_dir)
+        logger.debug(
+            "Found {num_files} validation examples in {path}",
+            num_files=len(val_examples),
+            path=val_dir,
+        )
+    else:
+        logger.debug("No validation directory provided.")
+
+    logger.info("Configuration and data loaded. Starting training...")
     train(
         train_examples=train_examples,
         val_examples=val_examples,
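Note (illustration, not part of the commit): the new logger.debug calls use loguru's str.format-style message templates, where keyword arguments passed to the call are interpolated into the {placeholders}. A minimal sketch with a hypothetical file list and directory path:

# Minimal sketch of loguru's keyword-based message formatting (hypothetical values).
from loguru import logger

train_examples = ["a.npz", "b.npz"]  # hypothetical list of preprocessed files
logger.debug(
    "Found {num_files} training examples in {path}",
    num_files=len(train_examples),
    path="path/to/train_dir",  # hypothetical directory
)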

View File

@@ -29,6 +29,36 @@ class BaseConfig(BaseModel):

     model_config = ConfigDict(extra="ignore")

+    def to_yaml_string(
+        self,
+        exclude_none: bool = False,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+    ) -> str:
+        """Converts the Pydantic model instance to a YAML string.
+
+        Parameters
+        ----------
+        exclude_none : bool, default=False
+            Whether to exclude fields whose value is `None`.
+        exclude_unset : bool, default=False
+            Whether to exclude fields that were not explicitly set.
+        exclude_defaults : bool, default=False
+            Whether to exclude fields whose value is the default value.
+
+        Returns
+        -------
+        str
+            A YAML string representation of the model.
+        """
+        return yaml.dump(
+            self.model_dump(
+                exclude_none=exclude_none,
+                exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
+            )
+        )
+

 T = TypeVar("T", bound=BaseModel)
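Note (illustration, not part of the commit): to_yaml_string is what the logger.opt(lazy=True).debug(...) calls added throughout this commit serialize; with lazy=True, loguru only invokes the supplied callables when a sink actually records the DEBUG message, so the YAML dump is skipped at higher log levels. A minimal sketch, with ExampleConfig as a hypothetical stand-in for a BaseConfig subclass:

# Minimal sketch: pair a Pydantic model's YAML dump with loguru's lazy evaluation.
import yaml
from loguru import logger
from pydantic import BaseModel


class ExampleConfig(BaseModel):  # hypothetical stand-in for a BaseConfig subclass
    batch_size: int = 32
    learning_rate: float = 1e-3

    def to_yaml_string(self) -> str:
        return yaml.dump(self.model_dump())


config = ExampleConfig()
logger.opt(lazy=True).debug(
    "Building component with config: \n{}",
    lambda: config.to_yaml_string(),  # only called if a DEBUG sink is active
)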

View File

@@ -28,6 +28,8 @@ provided here.

 from typing import Optional

+from loguru import logger
+
 from batdetect2.models.backbones import (
     Backbone,
     BackboneConfig,
@@ -131,5 +133,10 @@ def build_model(
         construction of the backbone or detector components (e.g., incompatible
         configurations, invalid parameters).
     """
-    backbone = build_backbone(config or BackboneConfig())
+    config = config or BackboneConfig()
+    logger.opt(lazy=True).debug(
+        "Building model with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+    backbone = build_backbone(config)
     return build_detector(num_classes, backbone)

View File

@@ -26,9 +26,9 @@ from torch import nn

 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvUpConfig,
+    LayerGroupConfig,
     StandardConvUpConfig,
     build_layer_from_config,
 )

View File

@@ -28,9 +28,9 @@ from torch import nn

 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvDownConfig,
+    LayerGroupConfig,
     StandardConvDownConfig,
     build_layer_from_config,
 )

View File

@@ -31,6 +31,7 @@ It also re-exports key components from submodules for convenience.

 from typing import List, Optional

 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import data
@@ -203,9 +204,14 @@ def build_postprocessor(
     PostprocessorProtocol
         An initialized `Postprocessor` instance ready to process model outputs.
     """
+    config = config or PostprocessConfig()
+    logger.opt(lazy=True).debug(
+        "Building postprocessor with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return Postprocessor(
         targets=targets,
-        config=config or PostprocessConfig(),
+        config=config,
         min_freq=min_freq,
         max_freq=max_freq,
     )

View File

@@ -32,6 +32,7 @@ from typing import Optional, Union

 import numpy as np
 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import data
@@ -429,6 +430,10 @@ def build_preprocessor(
         according to the configuration.
     """
     config = config or PreprocessingConfig()
+    logger.opt(lazy=True).debug(
+        "Building preprocessor with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     default_samplerate = (
         config.audio.resample.samplerate

View File

@@ -573,6 +573,10 @@ def build_targets(
         If dynamic import of a derivation function fails (when configured).
     """
     config = config or DEFAULT_TARGET_CONFIG
+    logger.opt(lazy=True).debug(
+        "Building targets with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     filter_fn = (
         build_sound_event_filter(

View File

@@ -9,8 +9,8 @@ from batdetect2.configs import BaseConfig, load_config
 from batdetect2.targets.terms import (
     TagInfo,
     TermRegistry,
-    get_tag_from_info,
     default_term_registry,
+    get_tag_from_info,
 )

 __all__ = [

View File

@@ -28,6 +28,7 @@ from typing import Annotated, Callable, List, Literal, Optional, Union

 import numpy as np
 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import arrays, data
@@ -851,6 +852,11 @@ def build_augmentations(
     """
     config = config or DEFAULT_AUGMENTATION_CONFIG

+    logger.opt(lazy=True).debug(
+        "Building augmentations with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+
     augmentations = []
     for step_config in config.steps:

View File

@@ -2,6 +2,7 @@ from typing import Optional, Tuple, Union

 import numpy as np
 import xarray as xr
+from loguru import logger
 from soundevent import arrays

 from batdetect2.configs import BaseConfig
@@ -74,6 +75,10 @@ def build_clipper(
     random: Optional[bool] = None,
 ) -> ClipperProtocol:
     config = config or ClipingConfig()
+    logger.opt(lazy=True).debug(
+        "Building clipper with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return Clipper(
         duration=config.duration,
         max_empty=config.max_empty,

View File

@@ -93,10 +93,15 @@ def build_clip_labeler(
         A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
         (spectrogram) and returns the generated `Heatmaps`.
     """
+    config = config or LabelConfig()
+    logger.opt(lazy=True).debug(
+        "Building clip labeler with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return partial(
         generate_clip_label,
         targets=targets,
-        config=config or LabelConfig(),
+        config=config,
     )

View File

@@ -1,6 +1,7 @@
 from typing import Annotated, Any, Literal, Optional, Union

 from lightning.pytorch.loggers import Logger
+from loguru import logger
 from pydantic import Field

 from batdetect2.configs import BaseConfig
@@ -129,6 +130,10 @@ def build_logger(config: LoggerConfig) -> Logger:
     """
     Creates a logger instance from a validated Pydantic config object.
     """
+    logger.opt(lazy=True).debug(
+        "Building logger with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     logger_type = config.logger_type

     if logger_type not in LOGGER_FACTORY:
@@ -137,4 +142,3 @@ def build_logger(config: LoggerConfig) -> Logger:

     creation_func = LOGGER_FACTORY[logger_type]
     return creation_func(config)
-

View File

@@ -23,6 +23,7 @@ from typing import Optional

 import numpy as np
 import torch
 import torch.nn.functional as F
+from loguru import logger
 from pydantic import Field
 from torch import nn
@@ -451,6 +452,10 @@ def build_loss(
         An initialized `LossFunction` module ready for training.
     """
     config = config or LossConfig()
+    logger.opt(lazy=True).debug(
+        "Building loss with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     class_weights_tensor = (
         torch.tensor(class_weights) if class_weights else None

View File

@@ -1,8 +1,10 @@
 from collections.abc import Sequence
 from typing import List, Optional

+import yaml
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback
+from loguru import logger
 from soundevent import data
 from torch.utils.data import DataLoader
@@ -52,6 +54,7 @@ def train(
     conf = config or FullTrainingConfig()

     if model_path is not None:
+        logger.debug("Loading model from: {path}", path=model_path)
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
         module = TrainingModule(conf)
@@ -75,11 +78,13 @@ def train(
         else None
     )

+    logger.info("Starting main training loop...")
     trainer.fit(
         module,
         train_dataloaders=train_dataloader,
         val_dataloaders=val_dataloader,
     )
+    logger.info("Training complete.")


 def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
@@ -103,6 +108,10 @@ def build_trainer(
     trainer_conf = PLTrainerConfig.model_validate(
         conf.train.model_dump(mode="python")
     )
+    logger.opt(lazy=True).debug(
+        "Building trainer with config: \n{config}",
+        config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
+    )
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         val_check_interval=conf.train.val_check_interval,
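Note (illustration, not part of the commit): both the lazy debug dump (to_yaml_string(exclude_none=True)) and the Trainer construction (model_dump(exclude_none=True)) rely on Pydantic's exclude_none flag, which drops every field whose value is None. A minimal sketch with a hypothetical MiniTrainerConfig:

# Minimal sketch of Pydantic's exclude_none behaviour (hypothetical config model).
from typing import Optional

from pydantic import BaseModel


class MiniTrainerConfig(BaseModel):  # hypothetical stand-in for PLTrainerConfig
    max_epochs: Optional[int] = None
    accelerator: str = "auto"


conf = MiniTrainerConfig(accelerator="gpu")
print(conf.model_dump())                   # {'max_epochs': None, 'accelerator': 'gpu'}
print(conf.model_dump(exclude_none=True))  # {'accelerator': 'gpu'}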
@@ -117,12 +126,23 @@ def build_train_loader(
     config: TrainingConfig,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
+    logger.info("Building training data loader...")
     train_dataset = build_train_dataset(
         train_examples,
         preprocessor=preprocessor,
         config=config,
     )
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{}",
+        lambda: yaml.dump(
+            {
+                "batch_size": config.batch_size,
+                "shuffle": True,
+                "num_workers": num_workers or 0,
+                "collate_fn": str(collate_fn),
+            }
+        ),
+    )
     return DataLoader(
         train_dataset,
         batch_size=config.batch_size,
@@ -137,10 +157,22 @@ def build_val_loader(
     config: TrainingConfig,
     num_workers: Optional[int] = None,
 ):
+    logger.info("Building validation data loader...")
     val_dataset = build_val_dataset(
         val_examples,
         config=config,
     )
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{}",
+        lambda: yaml.dump(
+            {
+                "batch_size": config.batch_size,
+                "shuffle": False,
+                "num_workers": num_workers or 0,
+                "collate_fn": str(collate_fn),
+            }
+        ),
+    )
     return DataLoader(
         val_dataset,
         batch_size=config.batch_size,
@@ -155,6 +187,7 @@ def build_train_dataset(
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
 ) -> LabeledDataset:
+    logger.info("Building training dataset...")
    config = config or TrainingConfig()
     clipper = build_clipper(config.cliping, random=True)
@@ -164,18 +197,15 @@ def build_train_dataset(
         clipper=clipper,
     )

-    logger.debug(
-        "Augmentations config: {}.", config.augmentations
-    )
-    augmentations = (
-        build_augmentations(
+    if config.augmentations and config.augmentations.steps:
+        augmentations = build_augmentations(
             preprocessor,
             config=config.augmentations,
             example_source=random_example_source,
         )
-        if config.augmentations
-        else None
-    )
+    else:
+        logger.debug("No augmentations configured for training dataset.")
+        augmentations = None

     return LabeledDataset(
         examples,
@@ -189,6 +219,7 @@ def build_val_dataset(
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
+    logger.info("Building validation dataset...")
     config = config or TrainingConfig()
     clipper = build_clipper(config.cliping, random=train)
     return LabeledDataset(examples, clipper=clipper)