This commit is contained in:
mbsantiago 2025-09-08 22:04:30 +01:00
parent c73984b213
commit cd4955d4f3
12 changed files with 341 additions and 53 deletions

View File

@ -108,6 +108,9 @@ train:
labels:
sigma: 3
trainer:
max_epochs: 40
dataloaders:
train:
batch_size: 8
@ -115,7 +118,7 @@ train:
shuffle: True
val:
batch_size: 8
batch_size: 1
num_workers: 2
loss:
@ -133,7 +136,7 @@ train:
weight: 0.1
logger:
logger_type: tensorboard
logger_type: csv
# save_dir: outputs/log/
# name: logs

View File

@ -1,6 +1,7 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
from batdetect2.cli.data import data
from batdetect2.cli.evaluate import evaluate_command
from batdetect2.cli.train import train_command
__all__ = [
@ -8,6 +9,7 @@ __all__ = [
"detect",
"data",
"train_command",
"evaluate_command",
]

View File

@ -0,0 +1,63 @@
import sys
from pathlib import Path
from typing import Optional
import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint
__all__ = ["evaluate_command"]
@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--output-dir", type=click.Path())
@click.option("--workers", type=int)
@click.option(
    "-v",
    "--verbose",
    count=True,
    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
    model_path: Path,
    test_dataset: Path,
    output_dir: Optional[Path] = None,
    workers: Optional[int] = None,
    verbose: int = 0,
):
    """Evaluate a trained checkpoint against a test dataset.

    Loads the model from MODEL_PATH, runs it over the annotations described
    by TEST_DATASET, prints the metric dict, and — when ``--output-dir`` is
    given — writes the per-match dataframe to ``<output-dir>/results.csv``.
    """
    # Map the -v/-vv count onto a loguru level (default: WARNING only).
    logger.remove()
    if verbose == 0:
        log_level = "WARNING"
    elif verbose == 1:
        log_level = "INFO"
    else:
        log_level = "DEBUG"
    logger.add(sys.stderr, level=log_level)

    logger.info("Initiating evaluation process...")

    test_annotations = load_dataset_from_config(test_dataset)
    logger.debug(
        "Loaded {num_annotations} test examples",
        num_annotations=len(test_annotations),
    )

    model, train_config = load_model_from_checkpoint(model_path)

    df, results = evaluate(
        model,
        test_annotations,
        config=train_config,
        num_workers=workers,
    )

    print(results)

    if output_dir is not None:
        # BUG FIX: click.Path() yields a plain str, so `output_dir /
        # "results.csv"` raised TypeError. Coerce to pathlib.Path and make
        # sure the directory exists before writing.
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)
        df.to_csv(output_dir / "results.csv")

View File

@ -20,6 +20,8 @@ __all__ = ["train_command"]
@click.argument("train_dataset", type=click.Path(exists=True))
@click.option("--val-dataset", type=click.Path(exists=True))
@click.option("--model-path", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--config", type=click.Path(exists=True))
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int)
@ -34,6 +36,8 @@ def train_command(
train_dataset: Path,
val_dataset: Optional[Path] = None,
model_path: Optional[Path] = None,
ckpt_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
config: Optional[Path] = None,
config_field: Optional[str] = None,
train_workers: int = 0,
@ -83,4 +87,6 @@ def train_command(
model_path=model_path,
train_workers=train_workers,
val_workers=val_workers,
log_dir=log_dir,
checkpoint_dir=ckpt_dir,
)

View File

@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
and serialization capabilities.
"""
model_config = ConfigDict(extra="forbid")
model_config = ConfigDict(extra="ignore")
def to_yaml_string(
self,

View File

@ -0,0 +1,62 @@
from typing import List
import pandas as pd
from soundevent.geometry import compute_bounds
from batdetect2.typing.evaluate import MatchEvaluation
def extract_matches_dataframe(
    matches: "List[MatchEvaluation]",
) -> pd.DataFrame:
    """Flatten match evaluations into a DataFrame with two-level columns.

    Each row describes one ground-truth/prediction match: recording/clip
    identifiers, ground-truth bounds and class, predicted bounds, scores,
    the match affinity, and one ``("pred_class_score", <class>)`` column per
    entry in ``match.pred_class_scores``.

    Parameters
    ----------
    matches : List[MatchEvaluation]
        Match evaluations to tabulate.

    Returns
    -------
    pd.DataFrame
        One row per match; empty DataFrame when ``matches`` is empty.
    """
    if not matches:
        # MultiIndex.from_tuples cannot infer levels from an empty list,
        # so short-circuit instead of crashing on an empty input.
        return pd.DataFrame()

    rows = []
    for match in matches:
        gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
        pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None

        sound_event_annotation = match.sound_event_annotation
        if sound_event_annotation is not None:
            geometry = sound_event_annotation.sound_event.geometry
            assert geometry is not None
            gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
                compute_bounds(geometry)
            )

        if match.pred_geometry is not None:
            pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
                compute_bounds(match.pred_geometry)
            )

        rows.append(
            {
                ("recording", "uuid"): match.clip.recording.uuid,
                ("clip", "uuid"): match.clip.uuid,
                ("clip", "start_time"): match.clip.start_time,
                ("clip", "end_time"): match.clip.end_time,
                # Reuse the local binding instead of re-reading the attribute.
                ("gt", "uuid"): sound_event_annotation.uuid
                if sound_event_annotation is not None
                else None,
                ("gt", "class"): match.gt_class,
                ("gt", "det"): match.gt_det,
                ("gt", "start_time"): gt_start_time,
                ("gt", "end_time"): gt_end_time,
                ("gt", "low_freq"): gt_low_freq,
                ("gt", "high_freq"): gt_high_freq,
                ("pred", "score"): match.pred_score,
                ("pred", "class"): match.pred_class,
                ("pred", "class_score"): match.pred_class_score,
                ("pred", "start_time"): pred_start_time,
                ("pred", "end_time"): pred_end_time,
                ("pred", "low_freq"): pred_low_freq,
                ("pred", "high_freq"): pred_high_freq,
                ("match", "affinity"): match.affinity,
                **{
                    ("pred_class_score", key): value
                    for key, value in match.pred_class_scores.items()
                },
            }
        )

    df = pd.DataFrame(rows)
    df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
    return df

View File

@ -0,0 +1,100 @@
from typing import List, Optional, Tuple
import pandas as pd
from soundevent import data
from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.match import match_all_predictions
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import Model
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_val_loader
def evaluate(
    model: Model,
    test_annotations: List[data.ClipAnnotation],
    config: Optional[FullTrainingConfig] = None,
    num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
    """Run ``model`` over ``test_annotations`` and compute evaluation metrics.

    Builds the preprocessing/labelling pipeline from ``config`` (a fresh
    ``FullTrainingConfig`` when omitted), iterates the validation loader,
    matches predictions against ground truth, and returns a tuple of
    (per-match dataframe, dict of metric name -> value).
    """
    config = config or FullTrainingConfig()

    audio_loader = build_audio_loader(config.preprocess.audio)
    preprocessor = build_preprocessor(config.preprocess)
    targets = build_targets(config.targets)
    labeller = build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.train.labels,
    )

    loader = build_val_loader(
        test_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        config=config.train,
        num_workers=num_workers,
    )

    dataset: ValidationDataset = loader.dataset  # type: ignore

    all_clip_annotations = []
    all_predictions = []
    for batch in loader:
        # NOTE(review): assumes the caller has put the model in inference
        # mode (eval/no_grad) — confirm upstream.
        outputs = model.detector(batch.spec)

        batch_clip_annotations = [
            dataset.clip_annotations[int(example_idx)]
            for example_idx in batch.idx
        ]

        batch_predictions = get_raw_predictions(
            outputs,
            clips=[
                clip_annotation.clip
                for clip_annotation in batch_clip_annotations
            ],
            targets=targets,
            postprocessor=model.postprocessor,
        )

        # BUG FIX: the original rebound the accumulators to the per-batch
        # lists and then called `x.extend(x)`, extending each list with
        # itself — so only the final batch (duplicated) reached matching.
        # Accumulate into distinctly named lists instead.
        all_clip_annotations.extend(batch_clip_annotations)
        all_predictions.extend(batch_predictions)

    matches = match_all_predictions(
        all_clip_annotations,
        all_predictions,
        targets=targets,
        config=config.evaluation.match,
    )

    df = extract_matches_dataframe(matches)

    metrics = [
        DetectionAveragePrecision(),
        ClassificationMeanAveragePrecision(class_names=targets.class_names),
        ClassificationAccuracy(class_names=targets.class_names),
    ]

    # Merge every metric's name->value mapping into one flat results dict.
    results = {
        name: value
        for metric in metrics
        for name, value in metric(matches).items()
    }

    return df, results

View File

@ -29,7 +29,6 @@ provided here.
from typing import List, Optional
import torch
from lightning import LightningModule
from pydantic import Field
from soundevent.data import PathLike
@ -105,7 +104,16 @@ __all__ = [
]
class Model(LightningModule):
class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
class Model(torch.nn.Module):
detector: DetectionModel
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
@ -117,13 +125,14 @@ class Model(LightningModule):
preprocessor: PreprocessorProtocol,
postprocessor: PostprocessorProtocol,
targets: TargetProtocol,
config: ModelConfig,
):
super().__init__()
self.detector = detector
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.targets = targets
self.save_hyperparameters()
self.config = config
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
spec = self.preprocessor(wav)
@ -131,29 +140,24 @@ class Model(LightningModule):
return self.postprocessor(outputs)
class ModelConfig(BaseConfig):
model: BackboneConfig = Field(default_factory=BackboneConfig)
preprocess: PreprocessingConfig = Field(
default_factory=PreprocessingConfig
)
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
def build_model(config: Optional[ModelConfig] = None):
config = config or ModelConfig()
targets = build_targets(config=config.targets)
preprocessor = build_preprocessor(config=config.preprocess)
postprocessor = build_postprocessor(
preprocessor=preprocessor,
config=config.postprocess,
)
detector = build_detector(
num_classes=len(targets.class_names),
config=config.model,
)
return Model(
config=config,
detector=detector,
postprocessor=postprocessor,
preprocessor=preprocessor,

View File

@ -6,7 +6,6 @@ from soundevent import data
from batdetect2.configs import BaseConfig, load_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.models import ModelConfig
from batdetect2.targets import TargetConfig
from batdetect2.train.augmentations import (
DEFAULT_AUGMENTATION_CONFIG,
AugmentationsConfig,
@ -75,7 +74,6 @@ class TrainingConfig(BaseConfig):
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)

View File

@ -1,9 +1,14 @@
from typing import Optional, Tuple
import lightning as L
import torch
from soundevent.data import PathLike
from torch.optim.adam import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from batdetect2.models import Model
from batdetect2.models import Model, build_model
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.losses import build_loss
from batdetect2.typing import ModelOutput, TrainExample
__all__ = [
@ -16,22 +21,28 @@ class TrainingModule(L.LightningModule):
def __init__(
self,
model: Model,
loss: torch.nn.Module,
config: FullTrainingConfig,
learning_rate: float = 0.001,
t_max: int = 100,
model: Optional[Model] = None,
loss: Optional[torch.nn.Module] = None,
):
super().__init__()
self.save_hyperparameters(logger=False)
self.config = config
self.learning_rate = learning_rate
self.t_max = t_max
if loss is None:
loss = build_loss(self.config.train.loss)
if model is None:
model = build_model(self.config)
self.loss = loss
self.model = model
self.save_hyperparameters(logger=False)
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.model.detector(spec)
def training_step(self, batch: TrainExample):
outputs = self.model.detector(batch.spec)
@ -59,3 +70,10 @@ class TrainingModule(L.LightningModule):
optimizer = Adam(self.parameters(), lr=self.learning_rate)
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
return [optimizer], [scheduler]
def load_model_from_checkpoint(
path: PathLike,
) -> Tuple[Model, FullTrainingConfig]:
module = TrainingModule.load_from_checkpoint(path) # type: ignore
return module.model, module.config

View File

@ -5,10 +5,11 @@ import numpy as np
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
from loguru import logger
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig
DEFAULT_LOGS_DIR: str = "logs"
DEFAULT_LOGS_DIR: str = "outputs"
class DVCLiveConfig(BaseConfig):
@ -31,7 +32,7 @@ class CSVLoggerConfig(BaseConfig):
class TensorBoardLoggerConfig(BaseConfig):
logger_type: Literal["tensorboard"] = "tensorboard"
save_dir: str = DEFAULT_LOGS_DIR
name: Optional[str] = "default"
name: Optional[str] = "logs"
version: Optional[str] = None
log_graph: bool = False
@ -57,7 +58,10 @@ LoggerConfig = Annotated[
]
def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
def create_dvclive_logger(
config: DVCLiveConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
try:
from dvclive.lightning import DVCLiveLogger # type: ignore
except ImportError as error:
@ -68,7 +72,7 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
) from error
return DVCLiveLogger(
dir=config.dir,
dir=log_dir if log_dir is not None else config.dir,
run_name=config.run_name,
prefix=config.prefix,
log_model=config.log_model,
@ -76,29 +80,38 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
)
def create_csv_logger(config: CSVLoggerConfig) -> Logger:
def create_csv_logger(
config: CSVLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
from lightning.pytorch.loggers import CSVLogger
return CSVLogger(
save_dir=config.save_dir,
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
name=config.name,
version=config.version,
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
)
def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
def create_tensorboard_logger(
config: TensorBoardLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
from lightning.pytorch.loggers import TensorBoardLogger
return TensorBoardLogger(
save_dir=config.save_dir,
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
name=config.name,
version=config.version,
log_graph=config.log_graph,
)
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
def create_mlflow_logger(
config: MLFlowLoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
try:
from lightning.pytorch.loggers import MLFlowLogger
except ImportError as error:
@ -111,7 +124,7 @@ def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
return MLFlowLogger(
experiment_name=config.experiment_name,
run_name=config.run_name,
save_dir=config.save_dir,
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
tracking_uri=config.tracking_uri,
tags=config.tags,
log_model=config.log_model,
@ -126,7 +139,10 @@ LOGGER_FACTORY = {
}
def build_logger(config: LoggerConfig) -> Logger:
def build_logger(
config: LoggerConfig,
log_dir: Optional[data.PathLike] = None,
) -> Logger:
"""
Creates a logger instance from a validated Pydantic config object.
"""
@ -141,7 +157,7 @@ def build_logger(config: LoggerConfig) -> Logger:
creation_func = LOGGER_FACTORY[logger_type]
return creation_func(config)
return creation_func(config, log_dir=log_dir)
def get_image_plotter(logger: Logger):

View File

@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import (
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import Model, build_model
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.augmentations import (
RandomAudioSource,
build_augmentations,
@ -28,7 +28,6 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.logging import build_logger
from batdetect2.train.losses import build_loss
from batdetect2.typing import (
PreprocessorProtocol,
TargetProtocol,
@ -54,19 +53,21 @@ def train(
model_path: Optional[data.PathLike] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[data.PathLike] = None,
log_dir: Optional[data.PathLike] = None,
):
config = config or FullTrainingConfig()
model = build_model(config=config)
targets = build_targets(config.targets)
trainer = build_trainer(config, targets=model.targets)
preprocessor = build_preprocessor(config.preprocess)
audio_loader = build_audio_loader(config=config.preprocess.audio)
labeller = build_clip_labeler(
model.targets,
min_freq=model.preprocessor.min_freq,
max_freq=model.preprocessor.max_freq,
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
config=config.train.labels,
)
@ -74,7 +75,7 @@ def train(
train_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=build_preprocessor(config.preprocess),
preprocessor=preprocessor,
config=config.train,
num_workers=train_workers,
)
@ -84,7 +85,7 @@ def train(
val_annotations,
audio_loader=audio_loader,
labeller=labeller,
preprocessor=build_preprocessor(config.preprocess),
preprocessor=preprocessor,
config=config.train,
num_workers=val_workers,
)
@ -97,11 +98,17 @@ def train(
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
else:
module = build_training_module(
model,
config,
t_max=config.train.t_max * len(train_dataloader),
)
trainer = build_trainer(
config,
targets=targets,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
)
logger.info("Starting main training loop...")
trainer.fit(
module,
@ -112,15 +119,12 @@ def train(
def build_training_module(
model: Model,
config: Optional[FullTrainingConfig] = None,
t_max: int = 200,
) -> TrainingModule:
config = config or FullTrainingConfig()
loss = build_loss(config=config.train.loss)
return TrainingModule(
model=model,
loss=loss,
config=config,
learning_rate=config.train.learning_rate,
t_max=t_max,
)
@ -130,10 +134,14 @@ def build_trainer_callbacks(
targets: TargetProtocol,
preprocessor: PreprocessorProtocol,
config: EvaluationConfig,
checkpoint_dir: Optional[data.PathLike] = None,
) -> List[Callback]:
if checkpoint_dir is None:
checkpoint_dir = "outputs/checkpoints"
return [
ModelCheckpoint(
dirpath="outputs/checkpoints",
dirpath=str(checkpoint_dir),
save_top_k=1,
monitor="total_loss/val",
),
@ -154,15 +162,22 @@ def build_trainer_callbacks(
def build_trainer(
conf: FullTrainingConfig,
targets: TargetProtocol,
checkpoint_dir: Optional[data.PathLike] = None,
log_dir: Optional[data.PathLike] = None,
) -> Trainer:
trainer_conf = conf.train.trainer
logger.opt(lazy=True).debug(
"Building trainer with config: \n{config}",
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
)
train_logger = build_logger(conf.train.logger)
train_logger = build_logger(conf.train.logger, log_dir=log_dir)
train_logger.log_hyperparams(conf.model_dump(mode="json"))
train_logger.log_hyperparams(
conf.model_dump(
mode="json",
exclude_none=True,
)
)
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
@ -171,6 +186,7 @@ def build_trainer(
targets,
config=conf.evaluation,
preprocessor=build_preprocessor(conf.preprocess),
checkpoint_dir=checkpoint_dir,
),
)