Add better logging

parent 19e873dd0b
commit 8368aad178
@@ -13,9 +13,7 @@ from batdetect2.train import (
 )
 from batdetect2.train.dataset import list_preprocessed_files

-__all__ = [
-    "train_command",
-]
+__all__ = ["train_command"]


 @cli.command(name="train")
@@ -51,19 +49,35 @@ def train_command(
         log_level = "DEBUG"
     logger.add(sys.stderr, level=log_level)

-    logger.info("Starting training!")
+    logger.info("Initiating training process...")

+    logger.info("Loading training configuration...")
     conf = (
         load_full_training_config(config, field=config_field)
         if config is not None
         else FullTrainingConfig()
     )

+    logger.info("Scanning for training and validation data...")
     train_examples = list_preprocessed_files(train_dir)
-    val_examples = (
-        list_preprocessed_files(val_dir) if val_dir is not None else None
+    logger.debug(
+        "Found {num_files} training examples in {path}",
+        num_files=len(train_examples),
+        path=train_dir,
     )
+
+    val_examples = None
+    if val_dir is not None:
+        val_examples = list_preprocessed_files(val_dir)
+        logger.debug(
+            "Found {num_files} validation examples in {path}",
+            num_files=len(val_examples),
+            path=val_dir,
+        )
+    else:
+        logger.debug("No validation directory provided.")

+    logger.info("Configuration and data loaded. Starting training...")
     train(
         train_examples=train_examples,
         val_examples=val_examples,
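A note on the logging calls above: loguru interpolates extra keyword arguments into the brace placeholders of the message, so the num_files=/path= pairs serve both as format values and as structured record data. A minimal self-contained sketch of the pattern (the file names and paths here are illustrative, not from this repo):

# Sketch of the loguru patterns used in train_command; values are illustrative.
import sys

from loguru import logger

# Mirror the CLI setup: replace the default sink with one at the chosen level.
logger.remove()
log_level = "DEBUG"  # the CLI raises this from INFO under some verbosity flag
logger.add(sys.stderr, level=log_level)

# Keyword arguments are interpolated into the brace placeholders at log time.
train_examples = ["a.npz", "b.npz", "c.npz"]  # stand-in for the real file list
logger.debug(
    "Found {num_files} training examples in {path}",
    num_files=len(train_examples),
    path="data/train",
)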
@@ -29,6 +29,36 @@ class BaseConfig(BaseModel):

     model_config = ConfigDict(extra="ignore")

+    def to_yaml_string(
+        self,
+        exclude_none: bool = False,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+    ) -> str:
+        """Converts the Pydantic model instance to a YAML string.
+
+        Parameters
+        ----------
+        exclude_none : bool, default=False
+            Whether to exclude fields whose value is `None`.
+        exclude_unset : bool, default=False
+            Whether to exclude fields that were not explicitly set.
+        exclude_defaults : bool, default=False
+            Whether to exclude fields whose value is the default value.
+
+        Returns
+        -------
+        str
+            A YAML string representation of the model.
+        """
+        return yaml.dump(
+            self.model_dump(
+                exclude_none=exclude_none,
+                exclude_unset=exclude_unset,
+                exclude_defaults=exclude_defaults,
+            )
+        )
+

 T = TypeVar("T", bound=BaseModel)

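Since to_yaml_string simply routes model_dump() through yaml.dump, every BaseConfig subclass now serializes to readable YAML. A usage sketch: ExampleConfig and its fields are hypothetical, while the method body is copied from the hunk above.

# Hypothetical usage sketch of BaseConfig.to_yaml_string.
import yaml
from pydantic import BaseModel, ConfigDict


class BaseConfig(BaseModel):
    """Stand-in for batdetect2.configs.BaseConfig, as patched above."""

    model_config = ConfigDict(extra="ignore")

    def to_yaml_string(
        self,
        exclude_none: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
    ) -> str:
        return yaml.dump(
            self.model_dump(
                exclude_none=exclude_none,
                exclude_unset=exclude_unset,
                exclude_defaults=exclude_defaults,
            )
        )


class ExampleConfig(BaseConfig):  # hypothetical subclass for illustration
    batch_size: int = 32
    learning_rate: float = 1e-3


print(ExampleConfig(batch_size=64).to_yaml_string())
# batch_size: 64
# learning_rate: 0.001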
@@ -28,6 +28,8 @@ provided here.

 from typing import Optional

+from loguru import logger
+
 from batdetect2.models.backbones import (
     Backbone,
     BackboneConfig,
@@ -131,5 +133,10 @@ def build_model(
         construction of the backbone or detector components (e.g., incompatible
         configurations, invalid parameters).
     """
-    backbone = build_backbone(config or BackboneConfig())
+    config = config or BackboneConfig()
+    logger.opt(lazy=True).debug(
+        "Building model with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+    backbone = build_backbone(config)
     return build_detector(num_classes, backbone)
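logger.opt(lazy=True) is what makes these config dumps affordable: positional arguments that are callables are only invoked if a sink actually accepts the DEBUG record, so the YAML serialization is skipped entirely at the default INFO level. A minimal sketch of that behavior:

# Minimal sketch of loguru's lazy argument evaluation.
import sys

from loguru import logger

logger.remove()
logger.add(sys.stderr, level="INFO")  # DEBUG messages are filtered out


def expensive_dump() -> str:
    print("serializing config...")  # side effect to show (non-)evaluation
    return "fake: yaml"


# With lazy=True the callable is never invoked: the DEBUG record is
# discarded before its arguments are formatted.
logger.opt(lazy=True).debug("Building model with config: \n{}", expensive_dump)

# Without opt(lazy=True), the argument is evaluated eagerly at the call site,
# even though the message itself is still filtered out.
logger.debug("Building model with config: \n{}", expensive_dump())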
@@ -26,9 +26,9 @@ from torch import nn

 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvUpConfig,
+    LayerGroupConfig,
     StandardConvUpConfig,
     build_layer_from_config,
 )
@@ -28,9 +28,9 @@ from torch import nn

 from batdetect2.configs import BaseConfig
 from batdetect2.models.blocks import (
-    LayerGroupConfig,
     ConvConfig,
     FreqCoordConvDownConfig,
+    LayerGroupConfig,
     StandardConvDownConfig,
     build_layer_from_config,
 )
@@ -31,6 +31,7 @@ It also re-exports key components from submodules for convenience.
 from typing import List, Optional

 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import data

@@ -203,9 +204,14 @@ def build_postprocessor(
     PostprocessorProtocol
         An initialized `Postprocessor` instance ready to process model outputs.
     """
+    config = config or PostprocessConfig()
+    logger.opt(lazy=True).debug(
+        "Building postprocessor with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return Postprocessor(
         targets=targets,
-        config=config or PostprocessConfig(),
+        config=config,
         min_freq=min_freq,
         max_freq=max_freq,
     )
@@ -32,6 +32,7 @@ from typing import Optional, Union

 import numpy as np
 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import data

@@ -429,6 +430,10 @@ def build_preprocessor(
         according to the configuration.
     """
     config = config or PreprocessingConfig()
+    logger.opt(lazy=True).debug(
+        "Building preprocessor with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )

     default_samplerate = (
         config.audio.resample.samplerate
@@ -573,6 +573,10 @@ def build_targets(
         If dynamic import of a derivation function fails (when configured).
     """
     config = config or DEFAULT_TARGET_CONFIG
+    logger.opt(lazy=True).debug(
+        "Building targets with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )

     filter_fn = (
         build_sound_event_filter(
@@ -9,8 +9,8 @@ from batdetect2.configs import BaseConfig, load_config
 from batdetect2.targets.terms import (
     TagInfo,
     TermRegistry,
-    get_tag_from_info,
     default_term_registry,
+    get_tag_from_info,
 )

 __all__ = [
@@ -28,6 +28,7 @@ from typing import Annotated, Callable, List, Literal, Optional, Union

 import numpy as np
 import xarray as xr
+from loguru import logger
 from pydantic import Field
 from soundevent import arrays, data

@@ -851,6 +852,11 @@ def build_augmentations(
     """
     config = config or DEFAULT_AUGMENTATION_CONFIG

+    logger.opt(lazy=True).debug(
+        "Building augmentations with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
+
     augmentations = []

     for step_config in config.steps:
@@ -2,6 +2,7 @@ from typing import Optional, Tuple, Union

 import numpy as np
 import xarray as xr
+from loguru import logger
 from soundevent import arrays

 from batdetect2.configs import BaseConfig
@@ -74,6 +75,10 @@ def build_clipper(
     random: Optional[bool] = None,
 ) -> ClipperProtocol:
     config = config or ClipingConfig()
+    logger.opt(lazy=True).debug(
+        "Building clipper with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return Clipper(
         duration=config.duration,
         max_empty=config.max_empty,
@@ -93,10 +93,15 @@ def build_clip_labeler(
         A function that accepts a `data.ClipAnnotation` and `xr.DataArray`
         (spectrogram) and returns the generated `Heatmaps`.
     """
+    config = config or LabelConfig()
+    logger.opt(lazy=True).debug(
+        "Building clip labeler with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     return partial(
         generate_clip_label,
         targets=targets,
-        config=config or LabelConfig(),
+        config=config,
     )


@@ -1,6 +1,7 @@
 from typing import Annotated, Any, Literal, Optional, Union

 from lightning.pytorch.loggers import Logger
+from loguru import logger
 from pydantic import Field

 from batdetect2.configs import BaseConfig
@@ -129,6 +130,10 @@ def build_logger(config: LoggerConfig) -> Logger:
     """
     Creates a logger instance from a validated Pydantic config object.
     """
+    logger.opt(lazy=True).debug(
+        "Building logger with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )
     logger_type = config.logger_type

     if logger_type not in LOGGER_FACTORY:
@@ -137,4 +142,3 @@ def build_logger(config: LoggerConfig) -> Logger:
     creation_func = LOGGER_FACTORY[logger_type]

     return creation_func(config)
-
@@ -23,6 +23,7 @@ from typing import Optional
 import numpy as np
 import torch
 import torch.nn.functional as F
+from loguru import logger
 from pydantic import Field
 from torch import nn

@@ -451,6 +452,10 @@ def build_loss(
         An initialized `LossFunction` module ready for training.
     """
     config = config or LossConfig()
+    logger.opt(lazy=True).debug(
+        "Building loss with config: \n{}",
+        lambda: config.to_yaml_string(),
+    )

     class_weights_tensor = (
         torch.tensor(class_weights) if class_weights else None
@@ -1,8 +1,10 @@
 from collections.abc import Sequence
 from typing import List, Optional

+import yaml
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback
+from loguru import logger
 from soundevent import data
 from torch.utils.data import DataLoader

@@ -52,6 +54,7 @@ def train(
     conf = config or FullTrainingConfig()

     if model_path is not None:
+        logger.debug("Loading model from: {path}", path=model_path)
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
         module = TrainingModule(conf)
@@ -75,11 +78,13 @@ def train(
         else None
     )

+    logger.info("Starting main training loop...")
     trainer.fit(
         module,
         train_dataloaders=train_dataloader,
         val_dataloaders=val_dataloader,
     )
+    logger.info("Training complete.")


 def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
@@ -103,6 +108,10 @@ def build_trainer(
     trainer_conf = PLTrainerConfig.model_validate(
         conf.train.model_dump(mode="python")
     )
+    logger.opt(lazy=True).debug(
+        "Building trainer with config: \n{config}",
+        config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
+    )
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
         val_check_interval=conf.train.val_check_interval,
@@ -117,12 +126,23 @@ def build_train_loader(
     config: TrainingConfig,
     num_workers: Optional[int] = None,
 ) -> DataLoader:
+    logger.info("Building training data loader...")
     train_dataset = build_train_dataset(
         train_examples,
         preprocessor=preprocessor,
         config=config,
     )
+    logger.opt(lazy=True).debug(
+        "Training data loader config: \n{}",
+        lambda: yaml.dump(
+            {
+                "batch_size": config.batch_size,
+                "shuffle": True,
+                "num_workers": num_workers or 0,
+                "collate_fn": str(collate_fn),
+            }
+        ),
+    )
     return DataLoader(
         train_dataset,
         batch_size=config.batch_size,
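DataLoader is not a Pydantic model, so the loader hunks instead build a plain dict and defer yaml.dump behind the same lazy lambda. Roughly what the deferred call renders, with illustrative values (PyYAML sorts keys and lower-cases booleans):

# Sketch of the deferred yaml.dump call above; values are illustrative.
import yaml

print(
    yaml.dump(
        {
            "batch_size": 8,
            "shuffle": True,
            "num_workers": 0,
            "collate_fn": "<function collate_fn at 0x...>",
        }
    )
)
# batch_size: 8
# collate_fn: <function collate_fn at 0x...>
# num_workers: 0
# shuffle: true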
@@ -137,10 +157,22 @@ def build_val_loader(
     config: TrainingConfig,
     num_workers: Optional[int] = None,
 ):
+    logger.info("Building validation data loader...")
     val_dataset = build_val_dataset(
         val_examples,
         config=config,
     )
+    logger.opt(lazy=True).debug(
+        "Validation data loader config: \n{}",
+        lambda: yaml.dump(
+            {
+                "batch_size": config.batch_size,
+                "shuffle": False,
+                "num_workers": num_workers or 0,
+                "collate_fn": str(collate_fn),
+            }
+        ),
+    )
     return DataLoader(
         val_dataset,
         batch_size=config.batch_size,
@@ -155,6 +187,7 @@ def build_train_dataset(
     preprocessor: PreprocessorProtocol,
     config: Optional[TrainingConfig] = None,
 ) -> LabeledDataset:
+    logger.info("Building training dataset...")
     config = config or TrainingConfig()

     clipper = build_clipper(config.cliping, random=True)
@@ -164,18 +197,15 @@ def build_train_dataset(
         clipper=clipper,
     )

-    logger.debug(
-        "Augmentations config: {}.", config.augmentations
-    )
-    augmentations = (
-        build_augmentations(
+    if config.augmentations and config.augmentations.steps:
+        augmentations = build_augmentations(
             preprocessor,
             config=config.augmentations,
             example_source=random_example_source,
         )
-        if config.augmentations
-        else None
-    )
+    else:
+        logger.debug("No augmentations configured for training dataset.")
+        augmentations = None

     return LabeledDataset(
         examples,
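Note this hunk changes behavior slightly as well as logging: augmentation is now skipped when a config is present but its steps list is empty, not only when the config is None. A stand-in sketch of the guard, using a hypothetical dataclass in place of the real config:

# Stand-in sketch of the new guard; AugmentationsConfig here is hypothetical.
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class AugmentationsConfig:
    steps: List[str] = field(default_factory=list)


def resolve_augmentations(config: Optional[AugmentationsConfig]) -> Optional[str]:
    # Both a missing config and an empty steps list skip the pipeline.
    if config and config.steps:
        return f"pipeline({len(config.steps)} steps)"
    return None


assert resolve_augmentations(None) is None
assert resolve_augmentations(AugmentationsConfig()) is None
assert resolve_augmentations(AugmentationsConfig(steps=["mix"])) is not None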
@@ -189,6 +219,7 @@ def build_val_dataset(
     config: Optional[TrainingConfig] = None,
     train: bool = True,
 ) -> LabeledDataset:
+    logger.info("Building validation dataset...")
     config = config or TrainingConfig()
     clipper = build_clipper(config.cliping, random=train)
     return LabeledDataset(examples, clipper=clipper)