Compare commits


8 Commits

Author | SHA1 | Message | Date
mbsantiago | 839a632aa2 | Remove non-existing option in tensorboard logger | 2025-07-29 00:17:46 +01:00
mbsantiago | 829c07fb12 | Remove unnecessary note | 2025-07-29 00:03:20 +01:00
mbsantiago | 2dd2c0156b | Update gitignore | 2025-06-28 15:52:56 -06:00
mbsantiago | 6732394f50 | Add model checkpoint callback | 2025-06-28 15:46:50 -06:00
mbsantiago | 8368aad178 | Add better logging | 2025-06-28 14:45:18 -06:00
mbsantiago | 19e873dd0b | Improve augmentations config and logging | 2025-06-28 11:36:18 -06:00
mbsantiago | ed67d8ceec | Add mlruns to gitignore | 2025-06-28 11:23:54 -06:00
mbsantiago | bafb9a3622 | Add mlflow logger | 2025-06-28 11:08:19 -06:00
21 changed files with 230 additions and 4042 deletions

.gitignore (vendored): 12 changes
View File

@ -99,6 +99,11 @@ plots/*
# Model experiments
experiments/*
DvcLiveLogger/checkpoints
logs/
mlruns/
outputs/
notebooks/lightning_logs
# Jupyter notebooks
.virtual_documents
@ -111,8 +116,7 @@ experiments/*
!tests/data/*.wav
!notebooks/*.ipynb
!tests/data/**/*.wav
notebooks/lightning_logs
example_data/preprocessed
.aider*
DvcLiveLogger/checkpoints
logs
# Intermediate artifacts
example_data/preprocessed

View File

@ -117,7 +117,13 @@ train:
size:
weight: 0.1
logger:
logger_type: dvclive
logger_type: mlflow
experiment_name: batdetect2
tracking_uri: http://localhost:5000
log_model: true
save_dir: outputs/log/
artifact_location: outputs/artifacts/
checkpoint_path_prefix: outputs/checkpoints/
augmentations:
steps:
- augmentation_type: mix_audio

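The new `logger` block points the training run at an MLflow tracking server on `http://localhost:5000` (for example one started with `mlflow server`, which defaults to that port). Below is a minimal sketch, not part of this diff, for confirming the server is reachable before a long run; it assumes the optional `mlflow` dependency added in `pyproject.toml` further down is installed.

```python
# Sketch: check that the tracking server configured above is reachable.
# Assumes the optional `mlflow` extra is installed.
from mlflow import MlflowClient

client = MlflowClient(tracking_uri="http://localhost:5000")

try:
    # Any lightweight read works; search_experiments() just lists experiments.
    client.search_experiments()
    print("MLflow tracking server is reachable.")
except Exception as error:
    print(f"Could not reach MLflow tracking server: {error}")
```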
View File

@ -21,7 +21,6 @@ dependencies = [
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
"pytorch-lightning>=2.2.2",
"cf-xarray>=0.9.0",
"onnx>=1.16.0",
"lightning[extra]>=2.2.2",
@ -91,6 +90,9 @@ dev = [
dvclive = [
"dvclive>=3.48.2",
]
mlflow = [
"mlflow>=3.1.1",
]
[tool.ruff]
line-length = 79

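`mlflow` joins `dvclive` as an optional dependency group, so it is not pulled in by default; the logger factory further down guards the import accordingly. To use the MLflow logger, install the extra (for example `pip install "batdetect2[mlflow]"` or `uv sync --extra mlflow`, depending on how the project is installed). A quick runtime availability check, shown here only as a sketch:

```python
# Sketch: check whether the optional mlflow dependency is available at runtime.
import importlib.util

mlflow_available = importlib.util.find_spec("mlflow") is not None
print(f"mlflow installed: {mlflow_available}")
```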
View File

@ -1,3 +1,4 @@
import sys
from pathlib import Path
from typing import Optional
@ -12,9 +13,7 @@ from batdetect2.train import (
)
from batdetect2.train.dataset import list_preprocessed_files
__all__ = [
"train_command",
]
__all__ = ["train_command"]
@cli.command(name="train")
@ -25,6 +24,12 @@ __all__ = [
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int, default=0)
@click.option("--val-workers", type=int, default=0)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command(
train_dir: Path,
val_dir: Optional[Path] = None,
@ -33,20 +38,46 @@ def train_command(
config_field: Optional[str] = None,
train_workers: int = 0,
val_workers: int = 0,
verbose: int = 0,
):
logger.info("Starting training!")
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating training process...")
logger.info("Loading training configuration...")
conf = (
load_full_training_config(config, field=config_field)
if config is not None
else FullTrainingConfig()
)
logger.info("Scanning for training and validation data...")
train_examples = list_preprocessed_files(train_dir)
val_examples = (
list_preprocessed_files(val_dir) if val_dir is not None else None
logger.debug(
"Found {num_files} training examples in {path}",
num_files=len(train_examples),
path=train_dir,
)
val_examples = None
if val_dir is not None:
val_examples = list_preprocessed_files(val_dir)
logger.debug(
"Found {num_files} validation examples in {path}",
num_files=len(val_examples),
path=val_dir,
)
else:
logger.debug("No validation directory provided.")
logger.info("Configuration and data loaded. Starting training...")
train(
train_examples=train_examples,
val_examples=val_examples,

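The CLI now removes loguru's default handler and re-adds a stderr sink whose level is chosen by how many times `-v` is passed (WARNING by default, INFO for `-v`, DEBUG for `-vv`). Callers who use the Python API directly rather than the CLI can opt into the new debug-level config dumps with the equivalent setup; a sketch, not part of this diff:

```python
# Equivalent of the CLI's -vv flag when calling the training API directly.
import sys

from loguru import logger

logger.remove()                        # drop loguru's default stderr handler
logger.add(sys.stderr, level="DEBUG")  # surface the lazy config dumps added in this PR
```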
View File

@ -29,6 +29,36 @@ class BaseConfig(BaseModel):
model_config = ConfigDict(extra="ignore")
def to_yaml_string(
self,
exclude_none: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
) -> str:
"""Converts the Pydantic model instance to a YAML string.
Parameters
----------
exclude_none : bool, default=False
Whether to exclude fields whose value is `None`.
exclude_unset : bool, default=False
Whether to exclude fields that were not explicitly set.
exclude_defaults : bool, default=False
Whether to exclude fields whose value is the default value.
Returns
-------
str
A YAML string representation of the model.
"""
return yaml.dump(
self.model_dump(
exclude_none=exclude_none,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
)
)
T = TypeVar("T", bound=BaseModel)

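`to_yaml_string` is the hook that the new `logger.opt(lazy=True).debug(...)` calls throughout this PR rely on: the lambda defers the YAML dump until the DEBUG level is actually enabled. A self-contained sketch of the pattern with a toy model (`ExampleConfig` is illustrative, not from the codebase):

```python
# Self-contained sketch of the lazy debug-dump pattern used across this PR.
import sys

import yaml
from loguru import logger
from pydantic import BaseModel


class ExampleConfig(BaseModel):
    learning_rate: float = 1e-3
    batch_size: int = 32

    def to_yaml_string(self) -> str:
        return yaml.dump(self.model_dump())


logger.remove()
logger.add(sys.stderr, level="DEBUG")

config = ExampleConfig()
# The lambda is only evaluated if the DEBUG level is enabled on some sink.
logger.opt(lazy=True).debug(
    "Building with config:\n{}", lambda: config.to_yaml_string()
)
```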
View File

@ -28,6 +28,8 @@ provided here.
from typing import Optional
from loguru import logger
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
@ -131,5 +133,10 @@ def build_model(
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
backbone = build_backbone(config or BackboneConfig())
config = config or BackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
)
backbone = build_backbone(config)
return build_detector(num_classes, backbone)

View File

@ -26,9 +26,9 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
LayerGroupConfig,
ConvConfig,
FreqCoordConvUpConfig,
LayerGroupConfig,
StandardConvUpConfig,
build_layer_from_config,
)

View File

@ -28,9 +28,9 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
LayerGroupConfig,
ConvConfig,
FreqCoordConvDownConfig,
LayerGroupConfig,
StandardConvDownConfig,
build_layer_from_config,
)

View File

@ -31,6 +31,7 @@ It also re-exports key components from submodules for convenience.
from typing import List, Optional
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data
@ -203,9 +204,14 @@ def build_postprocessor(
PostprocessorProtocol
An initialized `Postprocessor` instance ready to process model outputs.
"""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
targets=targets,
config=config or PostprocessConfig(),
config=config,
min_freq=min_freq,
max_freq=max_freq,
)

View File

@ -32,6 +32,7 @@ from typing import Optional, Union
import numpy as np
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data
@ -429,6 +430,10 @@ def build_preprocessor(
according to the configuration.
"""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
default_samplerate = (
config.audio.resample.samplerate

View File

@ -537,21 +537,13 @@ def apply_pcen(
-------
xr.DataArray
PCEN-scaled spectrogram.
Notes
-----
- The input spectrogram magnitude `spec` is multiplied by `2**31` before
being passed to `audio.pcen`. This suggests the underlying implementation
might expect values in a range typical of 16-bit or 32-bit signed integers,
even though the input here might be float. This scaling factor should be
verified against the specific `soundevent.audio.pcen` implementation
details.
"""
samplerate = 1 / spec.time.attrs["step"]
hop_size = spec.attrs["hop_size"]
hop_length = int(hop_size * samplerate)
t_frames = time_constant * samplerate / (float(hop_length) * 10)
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
return audio.pcen(
spec * (2**31),

View File

@ -573,6 +573,10 @@ def build_targets(
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),
)
filter_fn = (
build_sound_event_filter(

View File

@ -9,8 +9,8 @@ from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
get_tag_from_info,
default_term_registry,
get_tag_from_info,
)
__all__ = [

View File

@ -28,6 +28,7 @@ from typing import Annotated, Callable, List, Literal, Optional, Union
import numpy as np
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import arrays, data
@ -851,6 +852,11 @@ def build_augmentations(
"""
config = config or DEFAULT_AUGMENTATION_CONFIG
logger.opt(lazy=True).debug(
"Building augmentations with config: \n{}",
lambda: config.to_yaml_string(),
)
augmentations = []
for step_config in config.steps:

View File

@ -2,6 +2,7 @@ from typing import Optional, Tuple, Union
import numpy as np
import xarray as xr
from loguru import logger
from soundevent import arrays
from batdetect2.configs import BaseConfig
@ -74,6 +75,10 @@ def build_clipper(
random: Optional[bool] = None,
) -> ClipperProtocol:
config = config or ClipingConfig()
logger.opt(lazy=True).debug(
"Building clipper with config: \n{}",
lambda: config.to_yaml_string(),
)
return Clipper(
duration=config.duration,
max_empty=config.max_empty,

View File

@ -50,7 +50,7 @@ class TrainingConfig(PLTrainerConfig):
learning_rate: float = 1e-3
t_max: int = 100
loss: LossConfig = Field(default_factory=LossConfig)
augmentations: AugmentationsConfig = Field(
augmentations: Optional[AugmentationsConfig] = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)

View File

@ -93,10 +93,15 @@ def build_clip_labeler(
A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
(spectrogram) and returns the generated `Heatmaps`.
"""
config = config or LabelConfig()
logger.opt(lazy=True).debug(
"Building clip labeler with config: \n{}",
lambda: config.to_yaml_string(),
)
return partial(
generate_clip_label,
targets=targets,
config=config or LabelConfig(),
config=config,
)

View File

@ -1,6 +1,7 @@
from typing import Annotated, Literal, Optional, Union
from typing import Annotated, Any, Literal, Optional, Union
from lightning.pytorch.loggers import Logger
from loguru import logger
from pydantic import Field
from batdetect2.configs import BaseConfig
@ -31,11 +32,25 @@ class TensorBoardLoggerConfig(BaseConfig):
name: Optional[str] = "default"
version: Optional[str] = None
log_graph: bool = False
flush_logs_every_n_steps: Optional[int] = None
class MLFlowLoggerConfig(BaseConfig):
logger_type: Literal["mlflow"] = "mlflow"
experiment_name: str = "default"
run_name: Optional[str] = None
save_dir: Optional[str] = "./mlruns"
tracking_uri: Optional[str] = None
tags: Optional[dict[str, Any]] = None
log_model: bool = False
LoggerConfig = Annotated[
Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig],
Union[
DVCLiveConfig,
CSVLoggerConfig,
TensorBoardLoggerConfig,
MLFlowLoggerConfig,
],
Field(discriminator="logger_type"),
]
@ -82,10 +97,31 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
)
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
except ImportError as error:
raise ValueError(
"MLFlow is not installed and cannot be used for logging. "
"Make sure you have it installed by running `pip install mlflow` "
"or `uv add mlflow`"
) from error
return MLFlowLogger(
experiment_name=config.experiment_name,
run_name=config.run_name,
save_dir=config.save_dir,
tracking_uri=config.tracking_uri,
tags=config.tags,
log_model=config.log_model,
)
LOGGER_FACTORY = {
"dvclive": create_dvclive_logger,
"csv": create_csv_logger,
"tensorboard": create_tensorboard_logger,
"mlflow": create_mlflow_logger,
}
@ -93,6 +129,10 @@ def build_logger(config: LoggerConfig) -> Logger:
"""
Creates a logger instance from a validated Pydantic config object.
"""
logger.opt(lazy=True).debug(
"Building logger with config: \n{}",
lambda: config.to_yaml_string(),
)
logger_type = config.logger_type
if logger_type not in LOGGER_FACTORY:

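`MLFlowLoggerConfig` slots into the `LoggerConfig` discriminated union, so a `logger_type: mlflow` mapping from the YAML config validates directly into the right model before `build_logger` dispatches through `LOGGER_FACTORY`. A sketch of that flow; the import path is an assumption, since the file name is not shown in this view:

```python
# Sketch of how a `logger:` mapping resolves through the discriminated union.
from pydantic import TypeAdapter

# Assumed module path; not shown in this diff view.
from batdetect2.train.logging import LoggerConfig, build_logger

raw = {
    "logger_type": "mlflow",
    "experiment_name": "batdetect2",
    "tracking_uri": "http://localhost:5000",
    "log_model": True,
}

config = TypeAdapter(LoggerConfig).validate_python(raw)
print(type(config).__name__)  # MLFlowLoggerConfig

# Requires the mlflow extra; otherwise create_mlflow_logger raises the
# ValueError defined above.
pl_logger = build_logger(config)
```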
View File

@ -23,6 +23,7 @@ from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from pydantic import Field
from torch import nn
@ -451,6 +452,10 @@ def build_loss(
An initialized `LossFunction` module ready for training.
"""
config = config or LossConfig()
logger.opt(lazy=True).debug(
"Building loss with config: \n{}",
lambda: config.to_yaml_string(),
)
class_weights_tensor = (
torch.tensor(class_weights) if class_weights else None

View File

@ -1,8 +1,10 @@
from collections.abc import Sequence
from typing import List, Optional
import yaml
from lightning import Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader
@ -52,6 +54,7 @@ def train(
conf = config or FullTrainingConfig()
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = TrainingModule(conf)
@ -75,15 +78,22 @@ def train(
else None
)
logger.info("Starting main training loop...")
trainer.fit(
module,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
logger.info("Training complete.")
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
return [
ModelCheckpoint(
dirpath="outputs/checkpoints",
save_top_k=1,
monitor="total_loss/val",
),
ValidationMetrics(
metrics=[
DetectionAveragePrecision(),
@ -92,7 +102,7 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
),
ClassificationAccuracy(class_names=targets.class_names),
]
)
),
]
@ -103,6 +113,10 @@ def build_trainer(
trainer_conf = PLTrainerConfig.model_validate(
conf.train.model_dump(mode="python")
)
logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
val_check_interval=conf.train.val_check_interval,
@ -117,12 +131,23 @@ def build_train_loader(
config: TrainingConfig,
num_workers: Optional[int] = None,
) -> DataLoader:
logger.info("Building training data loader...")
train_dataset = build_train_dataset(
train_examples,
preprocessor=preprocessor,
config=config,
)
logger.opt(lazy=True).debug(
"Training data loader config: \n{}",
lambda: yaml.dump(
{
"batch_size": config.batch_size,
"shuffle": True,
"num_workers": num_workers or 0,
"collate_fn": str(collate_fn),
}
),
)
return DataLoader(
train_dataset,
batch_size=config.batch_size,
@ -137,10 +162,22 @@ def build_val_loader(
config: TrainingConfig,
num_workers: Optional[int] = None,
):
logger.info("Building validation data loader...")
val_dataset = build_val_dataset(
val_examples,
config=config,
)
logger.opt(lazy=True).debug(
"Validation data loader config: \n{}",
lambda: yaml.dump(
{
"batch_size": config.batch_size,
"shuffle": False,
"num_workers": num_workers or 0,
"collate_fn": str(collate_fn),
}
),
)
return DataLoader(
val_dataset,
batch_size=config.batch_size,
@ -155,6 +192,7 @@ def build_train_dataset(
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
) -> LabeledDataset:
logger.info("Building training dataset...")
config = config or TrainingConfig()
clipper = build_clipper(config.cliping, random=True)
@ -164,11 +202,15 @@ def build_train_dataset(
clipper=clipper,
)
augmentations = build_augmentations(
preprocessor,
config=config.augmentations,
example_source=random_example_source,
)
if config.augmentations and config.augmentations.steps:
augmentations = build_augmentations(
preprocessor,
config=config.augmentations,
example_source=random_example_source,
)
else:
logger.debug("No augmentations configured for training dataset.")
augmentations = None
return LabeledDataset(
examples,
@ -182,6 +224,7 @@ def build_val_dataset(
config: Optional[TrainingConfig] = None,
train: bool = True,
) -> LabeledDataset:
logger.info("Building validation dataset...")
config = config or TrainingConfig()
clipper = build_clipper(config.cliping, random=train)
return LabeledDataset(examples, clipper=clipper)

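The new `ModelCheckpoint` callback keeps only the single best checkpoint (lowest `total_loss/val`) under `outputs/checkpoints`, which pairs with the existing `model_path` branch in `train()` that reloads a `TrainingModule` from a checkpoint. A sketch of restoring the best weights after a run; the `TrainingModule` import path is an assumption:

```python
# Sketch (not part of this change): reload the best checkpoint kept by the new
# ModelCheckpoint callback. The TrainingModule import path is an assumption.
from pathlib import Path

from batdetect2.train.lightning import TrainingModule  # assumed module path

checkpoint_dir = Path("outputs/checkpoints")
# With save_top_k=1 there is a single .ckpt file; the exact name depends on the run.
best_checkpoint = next(checkpoint_dir.glob("*.ckpt"))

module = TrainingModule.load_from_checkpoint(best_checkpoint)
module.eval()
```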
uv.lock (generated): 4003 changes

File diff suppressed because it is too large.