diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py
index 330c0c1..c735576 100644
--- a/src/batdetect2/cli/train.py
+++ b/src/batdetect2/cli/train.py
@@ -13,9 +13,7 @@ from batdetect2.train import (
 )
 from batdetect2.train.dataset import list_preprocessed_files
 
-__all__ = [
-    "train_command",
-]
+__all__ = ["train_command"]
 
 
 @cli.command(name="train")
@@ -51,19 +49,35 @@ def train_command(
         log_level = "DEBUG"
 
     logger.add(sys.stderr, level=log_level)
-    logger.info("Starting training!")
+    logger.info("Initiating training process...")
 
+    logger.info("Loading training configuration...")
     conf = (
         load_full_training_config(config, field=config_field)
         if config is not None
         else FullTrainingConfig()
     )
 
+    logger.info("Scanning for training and validation data...")
     train_examples = list_preprocessed_files(train_dir)
-    val_examples = (
-        list_preprocessed_files(val_dir) if val_dir is not None else None
+    logger.debug(
+        "Found {num_files} training examples in {path}",
+        num_files=len(train_examples),
+        path=train_dir,
     )
+    val_examples = None
+    if val_dir is not None:
+        val_examples = list_preprocessed_files(val_dir)
+        logger.debug(
+            "Found {num_files} validation examples in {path}",
+            num_files=len(val_examples),
+            path=val_dir,
+        )
+    else:
+        logger.debug("No validation directory provided.")
+
+    logger.info("Configuration and data loaded. Starting training...")
 
     train(
         train_examples=train_examples,
         val_examples=val_examples,
diff --git a/src/batdetect2/configs.py b/src/batdetect2/configs.py
index 857170e..7399d6e 100644
--- a/src/batdetect2/configs.py
+++ b/src/batdetect2/configs.py
@@ -29,6 +29,36 @@ class BaseConfig(BaseModel):
 
     model_config = ConfigDict(extra="ignore")
 
+    def to_yaml_string(
+        self,
+        exclude_none: bool = False,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+    ) -> str:
+        """Converts the Pydantic model instance to a YAML string.
+
+        Parameters
+        ----------
+        exclude_none : bool, default=False
+            Whether to exclude fields whose value is `None`.
+        exclude_unset : bool, default=False
+            Whether to exclude fields that were not explicitly set.
+        exclude_defaults : bool, default=False
+            Whether to exclude fields whose value is the default value.
+
+        Returns
+        -------
+        str
+            A YAML string representation of the model.
+        """
+        return yaml.dump(
+            self.model_dump(
+                exclude_none=exclude_none,
+                exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
+            )
+        )
+
 
 
 T = TypeVar("T", bound=BaseModel)
diff --git a/src/batdetect2/models/__init__.py b/src/batdetect2/models/__init__.py
index daa5ad6..7d6fb9f 100644
--- a/src/batdetect2/models/__init__.py
+++ b/src/batdetect2/models/__init__.py
@@ -28,6 +28,8 @@ provided here.
 
 from typing import Optional
 
+from loguru import logger
+
 from batdetect2.models.backbones import (
     Backbone,
     BackboneConfig,
@@ -131,5 +133,10 @@ def build_model(
         construction of the backbone or detector components (e.g.,
         incompatible configurations, invalid parameters).
     """
-    backbone = build_backbone(config or BackboneConfig())
+    config = config or BackboneConfig()
+    logger.opt(lazy=True).debug(
+        "Building model with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+    backbone = build_backbone(config)
     return build_detector(num_classes, backbone)
diff --git a/src/batdetect2/models/decoder.py b/src/batdetect2/models/decoder.py
index c3deed7..270fc9d 100644
--- a/src/batdetect2/models/decoder.py
+++ b/src/batdetect2/models/decoder.py
@@ -26,9 +26,9 @@ from torch import nn
 
 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvUpConfig,
+    LayerGroupConfig,
     StandardConvUpConfig,
     build_layer_from_config,
 )
diff --git a/src/batdetect2/models/encoder.py b/src/batdetect2/models/encoder.py
index 91087cb..9bb13e5 100644
--- a/src/batdetect2/models/encoder.py
+++ b/src/batdetect2/models/encoder.py
@@ -28,9 +28,9 @@ from torch import nn
 
 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvDownConfig,
+    LayerGroupConfig,
     StandardConvDownConfig,
     build_layer_from_config,
 )
diff --git a/src/batdetect2/postprocess/__init__.py b/src/batdetect2/postprocess/__init__.py
index f6f84ac..cf93b0c 100644
--- a/src/batdetect2/postprocess/__init__.py
+++ b/src/batdetect2/postprocess/__init__.py
@@ -31,6 +31,7 @@ It also re-exports key components from submodules for convenience.
 from typing import List, Optional
 
 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import data
 
@@ -203,9 +204,14 @@ def build_postprocessor(
     PostprocessorProtocol
         An initialized `Postprocessor` instance ready to process model outputs.
     """
+    config = config or PostprocessConfig()
+    logger.opt(lazy=True).debug(
+        "Building postprocessor with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return Postprocessor(
         targets=targets,
-        config=config or PostprocessConfig(),
+        config=config,
         min_freq=min_freq,
         max_freq=max_freq,
     )
diff --git a/src/batdetect2/preprocess/__init__.py b/src/batdetect2/preprocess/__init__.py
index f27b591..fe872c0 100644
--- a/src/batdetect2/preprocess/__init__.py
+++ b/src/batdetect2/preprocess/__init__.py
@@ -32,6 +32,7 @@ from typing import Optional, Union
 
 import numpy as np
 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import data
 
@@ -429,6 +430,10 @@ def build_preprocessor(
         according to the configuration.
     """
     config = config or PreprocessingConfig()
+    logger.opt(lazy=True).debug(
+        "Building preprocessor with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
 
     default_samplerate = (
         config.audio.resample.samplerate
diff --git a/src/batdetect2/targets/__init__.py b/src/batdetect2/targets/__init__.py
index 6f9cfec..1c50690 100644
--- a/src/batdetect2/targets/__init__.py
+++ b/src/batdetect2/targets/__init__.py
@@ -573,6 +573,10 @@ def build_targets(
         If dynamic import of a derivation function fails (when configured).
     """
     config = config or DEFAULT_TARGET_CONFIG
+    logger.opt(lazy=True).debug(
+        "Building targets with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
 
     filter_fn = (
         build_sound_event_filter(
diff --git a/src/batdetect2/targets/filtering.py b/src/batdetect2/targets/filtering.py
index 4f30dc8..28e9c43 100644
--- a/src/batdetect2/targets/filtering.py
+++ b/src/batdetect2/targets/filtering.py
@@ -9,8 +9,8 @@ from batdetect2.configs import BaseConfig, load_config
 from batdetect2.targets.terms import (
     TagInfo,
     TermRegistry,
-    get_tag_from_info,
     default_term_registry,
+    get_tag_from_info,
 )
 
 __all__ = [
diff --git a/src/batdetect2/train/augmentations.py b/src/batdetect2/train/augmentations.py
index 6930c97..6c08f31 100644
--- a/src/batdetect2/train/augmentations.py
+++ b/src/batdetect2/train/augmentations.py
@@ -28,6 +28,7 @@ from typing import Annotated, Callable, List, Literal, Optional, Union
 
 import numpy as np
 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import arrays, data
 
@@ -851,6 +852,11 @@ def build_augmentations(
     """
     config = config or DEFAULT_AUGMENTATION_CONFIG
 
+    logger.opt(lazy=True).debug(
+        "Building augmentations with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+
     augmentations = []
 
     for step_config in config.steps:
diff --git a/src/batdetect2/train/clips.py b/src/batdetect2/train/clips.py
index 7debd13..87ee673 100644
--- a/src/batdetect2/train/clips.py
+++ b/src/batdetect2/train/clips.py
@@ -2,6 +2,7 @@ from typing import Optional, Tuple, Union
 
 import numpy as np
 import xarray as xr
+from loguru import logger
 from soundevent import arrays
 
 from batdetect2.configs import BaseConfig
@@ -74,6 +75,10 @@ def build_clipper(
     random: Optional[bool] = None,
 ) -> ClipperProtocol:
     config = config or ClipingConfig()
+    logger.opt(lazy=True).debug(
+        "Building clipper with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return Clipper(
         duration=config.duration,
         max_empty=config.max_empty,
diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py
index 48a1383..642e71e 100644
--- a/src/batdetect2/train/labels.py
+++ b/src/batdetect2/train/labels.py
@@ -93,10 +93,15 @@ def build_clip_labeler(
         A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
         (spectrogram) and returns the generated `Heatmaps`.
     """
+    config = config or LabelConfig()
+    logger.opt(lazy=True).debug(
+        "Building clip labeler with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return partial(
         generate_clip_label,
         targets=targets,
-        config=config or LabelConfig(),
+        config=config,
     )
 
 
diff --git a/src/batdetect2/train/logging.py b/src/batdetect2/train/logging.py
index 26fa64d..a7658a0 100644
--- a/src/batdetect2/train/logging.py
+++ b/src/batdetect2/train/logging.py
@@ -1,6 +1,7 @@
 from typing import Annotated, Any, Literal, Optional, Union
 
 from lightning.pytorch.loggers import Logger
+from loguru import logger
 from pydantic import Field
 
 from batdetect2.configs import BaseConfig
@@ -129,6 +130,10 @@ def build_logger(config: LoggerConfig) -> Logger:
     """
     Creates a logger instance from a validated Pydantic config object.
     """
+    logger.opt(lazy=True).debug(
+        "Building logger with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     logger_type = config.logger_type
 
     if logger_type not in LOGGER_FACTORY:
@@ -137,4 +142,3 @@ def build_logger(config: LoggerConfig) -> Logger:
     creation_func = LOGGER_FACTORY[logger_type]
 
     return creation_func(config)
-
diff --git a/src/batdetect2/train/losses.py b/src/batdetect2/train/losses.py
index b4083bb..a93b4ea 100644
--- a/src/batdetect2/train/losses.py
+++ b/src/batdetect2/train/losses.py
@@ -23,6 +23,7 @@ from typing import Optional
 
 import numpy as np
 import torch
 import torch.nn.functional as F
+from loguru import logger
 from pydantic import Field
 from torch import nn
@@ -451,6 +452,10 @@ def build_loss(
         An initialized `LossFunction` module ready for training.
     """
     config = config or LossConfig()
+    logger.opt(lazy=True).debug(
+        "Building loss with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
 
     class_weights_tensor = (
         torch.tensor(class_weights) if class_weights else None
diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py
index dce0221..01cb4d9 100644
--- a/src/batdetect2/train/train.py
+++ b/src/batdetect2/train/train.py
@@ -1,8 +1,10 @@
 from collections.abc import Sequence
 from typing import List, Optional
 
+import yaml
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback
+from loguru import logger
 from soundevent import data
 from torch.utils.data import DataLoader
 
@@ -52,6 +54,7 @@ def train(
     conf = config or FullTrainingConfig()
 
     if model_path is not None:
+        logger.debug("Loading model from: {path}", path=model_path)
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
         module = TrainingModule(conf)
@@ -75,11 +78,13 @@ def train(
         else None
     )
 
+    logger.info("Starting main training loop...")
     trainer.fit(
         module,
         train_dataloaders=train_dataloader,
         val_dataloaders=val_dataloader,
     )
+    logger.info("Training complete.")
 
 
 def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
@@ -103,6 +108,10 @@ def build_trainer(
     trainer_conf = PLTrainerConfig.model_validate(
         conf.train.model_dump(mode="python")
     )
+    logger.opt(lazy=True).debug(
+        "Building trainer with config: \n{config}",
+        config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
+    )
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         val_check_interval=conf.train.val_check_interval,
@@ -117,12 +126,23 @@ def build_train_loader(
     config: TrainingConfig,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
+    logger.info("Building training data loader...")
     train_dataset = build_train_dataset(
         train_examples,
         preprocessor=preprocessor,
         config=config,
     )
-
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{}",
+        lambda: yaml.dump(
+            {
+                "batch_size": config.batch_size,
+                "shuffle": True,
+                "num_workers": num_workers or 0,
+                "collate_fn": str(collate_fn),
+            }
+        ),
+    )
     return DataLoader(
         train_dataset,
         batch_size=config.batch_size,
@@ -137,10 +157,22 @@ def build_val_loader(
     config: TrainingConfig,
     num_workers: Optional[int] = None,
 ):
+    logger.info("Building validation data loader...")
     val_dataset = build_val_dataset(
         val_examples,
         config=config,
     )
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{}",
+        lambda: yaml.dump(
+            {
+                "batch_size": config.batch_size,
+                "shuffle": False,
+                "num_workers": num_workers or 0,
+                "collate_fn": str(collate_fn),
+            }
+        ),
+    )
     return DataLoader(
         val_dataset,
         batch_size=config.batch_size,
@@ -155,6 +187,7 @@ def build_train_dataset(
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
 ) -> LabeledDataset:
+    logger.info("Building training dataset...")
     config = config or TrainingConfig()
     clipper = build_clipper(config.cliping, random=True)
@@ -164,18 +197,15 @@ def build_train_dataset(
         examples,
         clipper=clipper,
     )
-    logger.debug(
-        "Augmentations config: {}.", config.augmentations
-    )
-    augmentations = (
-        build_augmentations(
+    if config.augmentations and config.augmentations.steps:
+        augmentations = build_augmentations(
             preprocessor,
             config=config.augmentations,
             example_source=random_example_source,
         )
-        if config.augmentations
-        else None
-    )
+    else:
+        logger.debug("No augmentations configured for training dataset.")
+        augmentations = None
 
     return LabeledDataset(
         examples,
@@ -189,6 +219,7 @@ def build_val_dataset(
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
+    logger.info("Building validation dataset...")
    config = config or TrainingConfig()
    clipper = build_clipper(config.cliping, random=train)
    return LabeledDataset(examples, clipper=clipper)
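
Note for reviewers: every `build_*` call site above uses the same `logger.opt(lazy=True)` idiom, which defers serialization until loguru actually emits the record, so `to_yaml_string()` is never called when the active sink level is above DEBUG (this assumes `configs.py` already imports `yaml`, which `to_yaml_string` relies on). A minimal, self-contained sketch of the pattern follows; `DemoConfig` is a hypothetical stand-in for a `BaseConfig` subclass, not part of batdetect2:

# Sketch of the lazy-logging pattern introduced in this diff.
import sys

import yaml
from loguru import logger
from pydantic import BaseModel


class DemoConfig(BaseModel):
    # Hypothetical config model; mirrors the to_yaml_string helper
    # added to BaseConfig in configs.py above.
    batch_size: int = 32
    learning_rate: float = 1e-3

    def to_yaml_string(self) -> str:
        return yaml.dump(self.model_dump())


logger.remove()
logger.add(sys.stderr, level="INFO")  # DEBUG disabled for this sink

config = DemoConfig()

# With lazy=True, callable arguments are invoked only if the record is
# actually emitted, so the YAML dump costs nothing at INFO level.
logger.opt(lazy=True).debug(
    "Building with config: \n{}",
    lambda: config.to_yaml_string(),
)

The same deferral applies to the dict-based dumps in build_train_loader and build_val_loader, since the lambda wraps the yaml.dump call itself.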