Compare commits

No commits in common. "839a632aa261c237faabd73de5cc4184e61dfe71" and "e8db1d40508b74b389586b71550d41ac0a807c3f" have entirely different histories.

21 changed files with 4042 additions and 230 deletions

.gitignore (vendored, 12 changed lines)
View File

@@ -99,11 +99,6 @@ plots/*
# Model experiments
experiments/*
DvcLiveLogger/checkpoints
logs/
mlruns/
outputs/
notebooks/lightning_logs
# Jupiter notebooks
.virtual_documents
@@ -116,7 +111,8 @@ notebooks/lightning_logs
!tests/data/*.wav
!notebooks/*.ipynb
!tests/data/**/*.wav
.aider*
# Intermediate artifacts
notebooks/lightning_logs
example_data/preprocessed
.aider*
DvcLiveLogger/checkpoints
logs

View File

@@ -117,13 +117,7 @@ train:
size:
weight: 0.1
logger:
logger_type: mlflow
experiment_name: batdetect2
tracking_uri: http://localhost:5000
log_model: true
save_dir: outputs/log/
artifact_location: outputs/artifacts/
checkpoint_path_prefix: outputs/checkpoints/
logger_type: dvclive
augmentations:
steps:
- augmentation_type: mix_audio

View File

@@ -21,6 +21,7 @@ dependencies = [
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
"pytorch-lightning>=2.2.2",
"cf-xarray>=0.9.0",
"onnx>=1.16.0",
"lightning[extra]>=2.2.2",
@@ -90,9 +91,6 @@ dev = [
dvclive = [
"dvclive>=3.48.2",
]
mlflow = [
"mlflow>=3.1.1",
]
[tool.ruff]
line-length = 79

View File

@@ -1,4 +1,3 @@
import sys
from pathlib import Path
from typing import Optional
@@ -13,7 +12,9 @@ from batdetect2.train import (
)
from batdetect2.train.dataset import list_preprocessed_files
__all__ = ["train_command"]
__all__ = [
"train_command",
]
@cli.command(name="train")
@@ -24,12 +25,6 @@ __all__ = ["train_command"]
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int, default=0)
@click.option("--val-workers", type=int, default=0)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command(
train_dir: Path,
val_dir: Optional[Path] = None,
@@ -38,46 +33,20 @@ def train_command(
config_field: Optional[str] = None,
train_workers: int = 0,
val_workers: int = 0,
verbose: int = 0,
):
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Starting training!")
logger.info("Initiating training process...")
logger.info("Loading training configuration...")
conf = (
load_full_training_config(config, field=config_field)
if config is not None
else FullTrainingConfig()
)
logger.info("Scanning for training and validation data...")
train_examples = list_preprocessed_files(train_dir)
logger.debug(
"Found {num_files} training examples in {path}",
num_files=len(train_examples),
path=train_dir,
val_examples = (
list_preprocessed_files(val_dir) if val_dir is not None else None
)
val_examples = None
if val_dir is not None:
val_examples = list_preprocessed_files(val_dir)
logger.debug(
"Found {num_files} validation examples in {path}",
num_files=len(val_examples),
path=val_dir,
)
else:
logger.debug("No validation directory provided.")
logger.info("Configuration and data loaded. Starting training...")
train(
train_examples=train_examples,
val_examples=val_examples,
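The trimmed command keeps the same Click shape: one path argument plus a few typed options, with the verbosity flag gone. Below is a stripped-down, runnable sketch of that pattern; the body is only a stub and stands in for the real config loading and training call.

    from pathlib import Path
    from typing import Optional

    import click

    @click.command(name="train")
    @click.argument("train_dir", type=click.Path(exists=True, path_type=Path))
    @click.option("--val-dir", type=click.Path(exists=True, path_type=Path))
    @click.option("--config", type=click.Path(exists=True, path_type=Path))
    @click.option("--config-field", type=str)
    @click.option("--train-workers", type=int, default=0)
    @click.option("--val-workers", type=int, default=0)
    def train_command(
        train_dir: Path,
        val_dir: Optional[Path],
        config: Optional[Path],
        config_field: Optional[str],
        train_workers: int,
        val_workers: int,
    ):
        # Stub body: the real command loads the config, lists the preprocessed
        # files, and hands everything to the training entry point.
        click.echo(f"training on {train_dir}, validating on {val_dir}")

    if __name__ == "__main__":
        train_command()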

View File

@@ -29,36 +29,6 @@ class BaseConfig(BaseModel):
model_config = ConfigDict(extra="ignore")
def to_yaml_string(
self,
exclude_none: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
) -> str:
"""Converts the Pydantic model instance to a YAML string.
Parameters
----------
exclude_none : bool, default=False
Whether to exclude fields whose value is `None`.
exclude_unset : bool, default=False
Whether to exclude fields that were not explicitly set.
exclude_defaults : bool, default=False
Whether to exclude fields whose value is the default value.
Returns
-------
str
A YAML string representation of the model.
"""
return yaml.dump(
self.model_dump(
exclude_none=exclude_none,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
)
)
T = TypeVar("T", bound=BaseModel)
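For call sites that still want a YAML dump after the helper's removal, the same result reduces to yaml.dump over Pydantic's model_dump. A minimal sketch, assuming Pydantic v2 and PyYAML; ExampleConfig is a hypothetical stand-in for a BaseConfig subclass.

    import yaml
    from pydantic import BaseModel, ConfigDict

    class ExampleConfig(BaseModel):
        # Hypothetical stand-in for a BaseConfig subclass.
        model_config = ConfigDict(extra="ignore")
        learning_rate: float = 1e-3
        batch_size: int = 32

    def to_yaml_string(model: BaseModel, exclude_none: bool = False) -> str:
        # Equivalent of the deleted BaseConfig.to_yaml_string: serialize the
        # validated model to a plain dict, then render that dict as YAML.
        return yaml.dump(model.model_dump(exclude_none=exclude_none))

    print(to_yaml_string(ExampleConfig()))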

View File

@@ -28,8 +28,6 @@ provided here.
from typing import Optional
from loguru import logger
from batdetect2.models.backbones import (
Backbone,
BackboneConfig,
@@ -133,10 +131,5 @@ def build_model(
construction of the backbone or detector components (e.g., incompatible
configurations, invalid parameters).
"""
config = config or BackboneConfig()
logger.opt(lazy=True).debug(
"Building model with config: \n{}",
lambda: config.to_yaml_string(),
)
backbone = build_backbone(config)
backbone = build_backbone(config or BackboneConfig())
return build_detector(num_classes, backbone)

View File

@@ -26,9 +26,9 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
LayerGroupConfig,
ConvConfig,
FreqCoordConvUpConfig,
LayerGroupConfig,
StandardConvUpConfig,
build_layer_from_config,
)

View File

@@ -28,9 +28,9 @@ from torch import nn
from batdetect2.configs import BaseConfig
from batdetect2.models.blocks import (
LayerGroupConfig,
ConvConfig,
FreqCoordConvDownConfig,
LayerGroupConfig,
StandardConvDownConfig,
build_layer_from_config,
)

View File

@@ -31,7 +31,6 @@ It also re-exports key components from submodules for convenience.
from typing import List, Optional
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data
@@ -204,14 +203,9 @@ def build_postprocessor(
PostprocessorProtocol
An initialized `Postprocessor` instance ready to process model outputs.
"""
config = config or PostprocessConfig()
logger.opt(lazy=True).debug(
"Building postprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
return Postprocessor(
targets=targets,
config=config,
config=config or PostprocessConfig(),
min_freq=min_freq,
max_freq=max_freq,
)

View File

@@ -32,7 +32,6 @@ from typing import Optional, Union
import numpy as np
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import data
@@ -430,10 +429,6 @@ def build_preprocessor(
according to the configuration.
"""
config = config or PreprocessingConfig()
logger.opt(lazy=True).debug(
"Building preprocessor with config: \n{}",
lambda: config.to_yaml_string(),
)
default_samplerate = (
config.audio.resample.samplerate

View File

@@ -537,13 +537,21 @@ def apply_pcen(
-------
xr.DataArray
PCEN-scaled spectrogram.
Notes
-----
- The input spectrogram magnitude `spec` is multiplied by `2**31` before
being passed to `audio.pcen`. This suggests the underlying implementation
might expect values in a range typical of 16-bit or 32-bit signed integers,
even though the input here might be float. This scaling factor should be
verified against the specific `soundevent.audio.pcen` implementation
details.
"""
samplerate = 1 / spec.time.attrs["step"]
hop_size = spec.attrs["hop_size"]
hop_length = int(hop_size * samplerate)
t_frames = time_constant * samplerate / (float(hop_length) * 10)
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
return audio.pcen(
spec * (2**31),
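The hunk above derives the PCEN smoothing constant from the spectrogram metadata. Below is a standalone sketch of that arithmetic; the input values are assumptions chosen only for illustration (in the original they come from spec.time.attrs["step"], spec.attrs["hop_size"], and the time_constant argument).

    import numpy as np

    # Illustrative inputs (assumed, not taken from the repository).
    samplerate = 1 / 0.002   # 1 / time-axis step
    hop_size = 0.002         # hop size attribute
    time_constant = 0.4      # PCEN time constant, in seconds

    # Same arithmetic as the hunk above.
    hop_length = int(hop_size * samplerate)
    t_frames = time_constant * samplerate / (float(hop_length) * 10)
    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
    print(hop_length, t_frames, smoothing_constant)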

View File

@@ -573,10 +573,6 @@ def build_targets(
If dynamic import of a derivation function fails (when configured).
"""
config = config or DEFAULT_TARGET_CONFIG
logger.opt(lazy=True).debug(
"Building targets with config: \n{}",
lambda: config.to_yaml_string(),
)
filter_fn = (
build_sound_event_filter(

View File

@@ -9,8 +9,8 @@ from batdetect2.configs import BaseConfig, load_config
from batdetect2.targets.terms import (
TagInfo,
TermRegistry,
default_term_registry,
get_tag_from_info,
default_term_registry,
)
__all__ = [

View File

@@ -28,7 +28,6 @@ from typing import Annotated, Callable, List, Literal, Optional, Union
import numpy as np
import xarray as xr
from loguru import logger
from pydantic import Field
from soundevent import arrays, data
@@ -852,11 +851,6 @@ def build_augmentations(
"""
config = config or DEFAULT_AUGMENTATION_CONFIG
logger.opt(lazy=True).debug(
"Building augmentations with config: \n{}",
lambda: config.to_yaml_string(),
)
augmentations = []
for step_config in config.steps:

View File

@@ -2,7 +2,6 @@ from typing import Optional, Tuple, Union
import numpy as np
import xarray as xr
from loguru import logger
from soundevent import arrays
from batdetect2.configs import BaseConfig
@@ -75,10 +74,6 @@ def build_clipper(
random: Optional[bool] = None,
) -> ClipperProtocol:
config = config or ClipingConfig()
logger.opt(lazy=True).debug(
"Building clipper with config: \n{}",
lambda: config.to_yaml_string(),
)
return Clipper(
duration=config.duration,
max_empty=config.max_empty,

View File

@@ -50,7 +50,7 @@ class TrainingConfig(PLTrainerConfig):
learning_rate: float = 1e-3
t_max: int = 100
loss: LossConfig = Field(default_factory=LossConfig)
augmentations: Optional[AugmentationsConfig] = Field(
augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
)
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
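Making augmentations a non-optional field with a default_factory means the config is always populated, even when the user's YAML omits it. A small sketch of that Pydantic behaviour using hypothetical, simplified models:

    from typing import List
    from pydantic import BaseModel, Field

    class AugmentationsConfig(BaseModel):
        # Hypothetical shape; the real config carries typed augmentation steps.
        steps: List[str] = []

    DEFAULT_AUGMENTATION_CONFIG = AugmentationsConfig(steps=["mix_audio"])

    class TrainingConfig(BaseModel):
        learning_rate: float = 1e-3
        # Non-optional field: omitting it in the input falls back to the
        # default factory instead of leaving it as None.
        augmentations: AugmentationsConfig = Field(
            default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
        )

    print(TrainingConfig().augmentations.steps)              # ['mix_audio']
    print(TrainingConfig(learning_rate=3e-4).augmentations)  # default applied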

View File

@@ -93,15 +93,10 @@ def build_clip_labeler(
A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
(spectrogram) and returns the generated `Heatmaps`.
"""
config = config or LabelConfig()
logger.opt(lazy=True).debug(
"Building clip labeler with config: \n{}",
lambda: config.to_yaml_string(),
)
return partial(
generate_clip_label,
targets=targets,
config=config,
config=config or LabelConfig(),
)
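build_clip_labeler now simply pre-binds its fixed arguments with functools.partial. A small sketch of that pattern with placeholder functions and values:

    from functools import partial

    def generate_clip_label(clip, spectrogram, *, targets, config):
        # Placeholder: the real function rasterizes annotations into heatmaps.
        return {"clip": clip, "spectrogram": spectrogram,
                "targets": targets, "config": config}

    def build_clip_labeler(targets, config=None):
        # Pre-bind the fixed arguments so callers only pass (clip, spectrogram).
        return partial(generate_clip_label, targets=targets,
                       config=config or {"heatmap": "default"})

    labeler = build_clip_labeler(targets=["Myotis"])
    print(labeler("clip-001", "spec-001"))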

View File

@@ -1,7 +1,6 @@
from typing import Annotated, Any, Literal, Optional, Union
from typing import Annotated, Literal, Optional, Union
from lightning.pytorch.loggers import Logger
from loguru import logger
from pydantic import Field
from batdetect2.configs import BaseConfig
@@ -32,25 +31,11 @@ class TensorBoardLoggerConfig(BaseConfig):
name: Optional[str] = "default"
version: Optional[str] = None
log_graph: bool = False
class MLFlowLoggerConfig(BaseConfig):
logger_type: Literal["mlflow"] = "mlflow"
experiment_name: str = "default"
run_name: Optional[str] = None
save_dir: Optional[str] = "./mlruns"
tracking_uri: Optional[str] = None
tags: Optional[dict[str, Any]] = None
log_model: bool = False
flush_logs_every_n_steps: Optional[int] = None
LoggerConfig = Annotated[
Union[
DVCLiveConfig,
CSVLoggerConfig,
TensorBoardLoggerConfig,
MLFlowLoggerConfig,
],
Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig],
Field(discriminator="logger_type"),
]
@@ -97,31 +82,10 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
)
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
except ImportError as error:
raise ValueError(
"MLFlow is not installed and cannot be used for logging. "
"Make sure you have it installed by running `pip install mlflow` "
"or `uv add mlflow`"
) from error
return MLFlowLogger(
experiment_name=config.experiment_name,
run_name=config.run_name,
save_dir=config.save_dir,
tracking_uri=config.tracking_uri,
tags=config.tags,
log_model=config.log_model,
)
LOGGER_FACTORY = {
"dvclive": create_dvclive_logger,
"csv": create_csv_logger,
"tensorboard": create_tensorboard_logger,
"mlflow": create_mlflow_logger,
}
@@ -129,10 +93,6 @@ def build_logger(config: LoggerConfig) -> Logger:
"""
Creates a logger instance from a validated Pydantic config object.
"""
logger.opt(lazy=True).debug(
"Building logger with config: \n{}",
lambda: config.to_yaml_string(),
)
logger_type = config.logger_type
if logger_type not in LOGGER_FACTORY:
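The remaining logger configs form a Pydantic discriminated union keyed on logger_type and are dispatched through LOGGER_FACTORY. A self-contained sketch of that pattern, with illustrative config classes and a toy factory rather than the project's own:

    from typing import Annotated, Literal, Union
    from pydantic import BaseModel, Field, TypeAdapter

    class CSVLoggerConfig(BaseModel):
        logger_type: Literal["csv"] = "csv"
        save_dir: str = "logs"

    class DVCLiveConfig(BaseModel):
        logger_type: Literal["dvclive"] = "dvclive"
        dir: str = "dvclive"

    LoggerConfig = Annotated[
        Union[CSVLoggerConfig, DVCLiveConfig],
        Field(discriminator="logger_type"),
    ]

    # A factory keyed on the discriminator, mirroring LOGGER_FACTORY above.
    FACTORY = {
        "csv": lambda cfg: f"csv logger in {cfg.save_dir}",
        "dvclive": lambda cfg: f"dvclive logger in {cfg.dir}",
    }

    config = TypeAdapter(LoggerConfig).validate_python({"logger_type": "dvclive"})
    print(FACTORY[config.logger_type](config))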

View File

@@ -23,7 +23,6 @@ from typing import Optional
import numpy as np
import torch
import torch.nn.functional as F
from loguru import logger
from pydantic import Field
from torch import nn
@@ -452,10 +451,6 @@ def build_loss(
An initialized `LossFunction` module ready for training.
"""
config = config or LossConfig()
logger.opt(lazy=True).debug(
"Building loss with config: \n{}",
lambda: config.to_yaml_string(),
)
class_weights_tensor = (
torch.tensor(class_weights) if class_weights else None

View File

@@ -1,10 +1,8 @@
from collections.abc import Sequence
from typing import List, Optional
import yaml
from lightning import Trainer
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
from loguru import logger
from lightning.pytorch.callbacks import Callback
from soundevent import data
from torch.utils.data import DataLoader
@@ -54,7 +52,6 @@ def train(
conf = config or FullTrainingConfig()
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = TrainingModule(conf)
@@ -78,22 +75,15 @@
else None
)
logger.info("Starting main training loop...")
trainer.fit(
module,
train_dataloaders=train_dataloader,
val_dataloaders=val_dataloader,
)
logger.info("Training complete.")
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
return [
ModelCheckpoint(
dirpath="outputs/checkpoints",
save_top_k=1,
monitor="total_loss/val",
),
ValidationMetrics(
metrics=[
DetectionAveragePrecision(),
@@ -102,7 +92,7 @@ def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
),
ClassificationAccuracy(class_names=targets.class_names),
]
),
)
]
@@ -113,10 +103,6 @@ def build_trainer(
trainer_conf = PLTrainerConfig.model_validate(
conf.train.model_dump(mode="python")
)
logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
val_check_interval=conf.train.val_check_interval,
@@ -131,23 +117,12 @@ def build_train_loader(
config: TrainingConfig,
num_workers: Optional[int] = None,
) -> DataLoader:
logger.info("Building training data loader...")
train_dataset = build_train_dataset(
train_examples,
preprocessor=preprocessor,
config=config,
)
logger.opt(lazy=True).debug(
"Training data loader config: \n{}",
lambda: yaml.dump(
{
"batch_size": config.batch_size,
"shuffle": True,
"num_workers": num_workers or 0,
"collate_fn": str(collate_fn),
}
),
)
return DataLoader(
train_dataset,
batch_size=config.batch_size,
@@ -162,22 +137,10 @@ def build_val_loader(
config: TrainingConfig,
num_workers: Optional[int] = None,
):
logger.info("Building validation data loader...")
val_dataset = build_val_dataset(
val_examples,
config=config,
)
logger.opt(lazy=True).debug(
"Validation data loader config: \n{}",
lambda: yaml.dump(
{
"batch_size": config.batch_size,
"shuffle": False,
"num_workers": num_workers or 0,
"collate_fn": str(collate_fn),
}
),
)
return DataLoader(
val_dataset,
batch_size=config.batch_size,
@@ -192,7 +155,6 @@ def build_train_dataset(
preprocessor: PreprocessorProtocol,
config: Optional[TrainingConfig] = None,
) -> LabeledDataset:
logger.info("Building training dataset...")
config = config or TrainingConfig()
clipper = build_clipper(config.cliping, random=True)
@@ -202,15 +164,11 @@ def build_train_dataset(
clipper=clipper,
)
if config.augmentations and config.augmentations.steps:
augmentations = build_augmentations(
preprocessor,
config=config.augmentations,
example_source=random_example_source,
)
else:
logger.debug("No augmentations configured for training dataset.")
augmentations = None
augmentations = build_augmentations(
preprocessor,
config=config.augmentations,
example_source=random_example_source,
)
return LabeledDataset(
examples,
@@ -224,7 +182,6 @@ def build_val_dataset(
config: Optional[TrainingConfig] = None,
train: bool = True,
) -> LabeledDataset:
logger.info("Building validation dataset...")
config = config or TrainingConfig()
clipper = build_clipper(config.cliping, random=train)
return LabeledDataset(examples, clipper=clipper)
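The loader builders above wrap the labeled datasets in standard PyTorch DataLoaders. A minimal sketch of that pattern with a toy dataset and placeholder collate function (sizes and field names are assumptions):

    import torch
    from torch.utils.data import DataLoader, Dataset

    class ToyDataset(Dataset):
        # Placeholder standing in for LabeledDataset.
        def __len__(self):
            return 8

        def __getitem__(self, idx):
            return {"spec": torch.zeros(1, 4, 4), "idx": idx}

    def collate_fn(batch):
        # Stack per-example tensors into a batch, as a custom collate would.
        return {
            "spec": torch.stack([item["spec"] for item in batch]),
            "idx": torch.tensor([item["idx"] for item in batch]),
        }

    loader = DataLoader(
        ToyDataset(),
        batch_size=4,
        shuffle=True,       # False for the validation loader
        num_workers=0,
        collate_fn=collate_fn,
    )
    for batch in loader:
        print(batch["spec"].shape, batch["idx"])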

uv.lock (generated, 4003 changed lines)

File diff suppressed because it is too large.