mirror of https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00

Commit cd4955d4f3 (parent c73984b213): Eval

This commit is contained in:
Training configuration (YAML):

@@ -108,6 +108,9 @@ train:
   labels:
     sigma: 3
 
+  trainer:
+    max_epochs: 40
+
   dataloaders:
     train:
       batch_size: 8
@@ -115,7 +118,7 @@ train:
       shuffle: True
 
     val:
-      batch_size: 8
+      batch_size: 1
       num_workers: 2
 
   loss:
@@ -133,7 +136,7 @@ train:
       weight: 0.1
 
   logger:
-    logger_type: tensorboard
+    logger_type: csv
     # save_dir: outputs/log/
     # name: logs
 
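For orientation, a minimal sketch of how a nested train section like the one above can be parsed with pydantic v2 and PyYAML. The class and field names below are illustrative stand-ins, not batdetect2's actual config classes:

import yaml
from pydantic import BaseModel, Field


class TrainerSection(BaseModel):
    max_epochs: int = 200


class LoggerSection(BaseModel):
    logger_type: str = "csv"


class TrainSection(BaseModel):
    trainer: TrainerSection = Field(default_factory=TrainerSection)
    logger: LoggerSection = Field(default_factory=LoggerSection)


raw = yaml.safe_load("""
train:
  trainer:
    max_epochs: 40
  logger:
    logger_type: csv
""")

config = TrainSection.model_validate(raw["train"])
assert config.trainer.max_epochs == 40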
src/batdetect2/cli/__init__.py:

@@ -1,6 +1,7 @@
 from batdetect2.cli.base import cli
 from batdetect2.cli.compat import detect
 from batdetect2.cli.data import data
+from batdetect2.cli.evaluate import evaluate_command
 from batdetect2.cli.train import train_command
 
 __all__ = [
@@ -8,6 +9,7 @@ __all__ = [
     "detect",
     "data",
     "train_command",
+    "evaluate_command",
 ]
src/batdetect2/cli/evaluate.py (new file, 63 lines):

@@ -0,0 +1,63 @@
import sys
from pathlib import Path
from typing import Optional

import click
from loguru import logger

from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint

__all__ = ["evaluate_command"]


@cli.command(name="evaluate")
@click.argument("model-path", type=click.Path(exists=True))
@click.argument("test_dataset", type=click.Path(exists=True))
@click.option("--output-dir", type=click.Path())
@click.option("--workers", type=int)
@click.option(
    "-v",
    "--verbose",
    count=True,
    help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
    model_path: Path,
    test_dataset: Path,
    output_dir: Optional[Path] = None,
    workers: Optional[int] = None,
    verbose: int = 0,
):
    logger.remove()
    if verbose == 0:
        log_level = "WARNING"
    elif verbose == 1:
        log_level = "INFO"
    else:
        log_level = "DEBUG"
    logger.add(sys.stderr, level=log_level)

    logger.info("Initiating evaluation process...")

    test_annotations = load_dataset_from_config(test_dataset)
    logger.debug(
        "Loaded {num_annotations} test examples",
        num_annotations=len(test_annotations),
    )

    model, train_config = load_model_from_checkpoint(model_path)

    df, results = evaluate(
        model,
        test_annotations,
        config=train_config,
        num_workers=workers,
    )

    print(results)

    if output_dir:
        # click.Path() yields a plain string unless path_type=Path is set,
        # so coerce before joining.
        df.to_csv(Path(output_dir) / "results.csv")
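A sketch of exercising the new subcommand with click's test runner. The checkpoint and dataset paths below are placeholders and must exist (both arguments are declared with exists=True):

from click.testing import CliRunner

from batdetect2.cli.evaluate import evaluate_command

runner = CliRunner()
result = runner.invoke(
    evaluate_command,
    [
        "model.ckpt",          # MODEL-PATH argument (placeholder)
        "test_dataset.yaml",   # TEST_DATASET argument (placeholder)
        "--output-dir", "outputs/eval",
        "--workers", "2",
        "-vv",                 # two -v flags -> DEBUG-level logging
    ],
)
print(result.output)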
src/batdetect2/cli/train.py:

@@ -20,6 +20,8 @@ __all__ = ["train_command"]
 @click.argument("train_dataset", type=click.Path(exists=True))
 @click.option("--val-dataset", type=click.Path(exists=True))
 @click.option("--model-path", type=click.Path(exists=True))
+@click.option("--ckpt-dir", type=click.Path(exists=True))
+@click.option("--log-dir", type=click.Path(exists=True))
 @click.option("--config", type=click.Path(exists=True))
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
@@ -34,6 +36,8 @@ def train_command(
     train_dataset: Path,
     val_dataset: Optional[Path] = None,
     model_path: Optional[Path] = None,
+    ckpt_dir: Optional[Path] = None,
+    log_dir: Optional[Path] = None,
     config: Optional[Path] = None,
     config_field: Optional[str] = None,
     train_workers: int = 0,
@@ -83,4 +87,6 @@ def train_command(
         model_path=model_path,
         train_workers=train_workers,
         val_workers=val_workers,
+        log_dir=log_dir,
+        checkpoint_dir=ckpt_dir,
     )
batdetect2.configs:

@@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
     and serialization capabilities.
     """
 
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="ignore")
 
     def to_yaml_string(
         self,
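The switch from extra="forbid" to extra="ignore" means unknown keys in a config payload are silently dropped instead of raising. A standalone pydantic v2 demonstration of the difference (these toy classes are not batdetect2's BaseConfig):

from pydantic import BaseModel, ConfigDict, ValidationError


class Forbid(BaseModel):
    model_config = ConfigDict(extra="forbid")
    sigma: int = 3


class Ignore(BaseModel):
    model_config = ConfigDict(extra="ignore")
    sigma: int = 3


payload = {"sigma": 5, "not_a_field": True}

print(Ignore.model_validate(payload))  # sigma=5, extra key silently dropped

try:
    Forbid.model_validate(payload)
except ValidationError as err:
    print(err)  # "Extra inputs are not permitted"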
src/batdetect2/evaluate/dataframe.py (new file, 62 lines):

@@ -0,0 +1,62 @@
from typing import List

import pandas as pd
from soundevent.geometry import compute_bounds

from batdetect2.typing.evaluate import MatchEvaluation


def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
    data = []

    for match in matches:
        gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
        pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None

        sound_event_annotation = match.sound_event_annotation

        if sound_event_annotation is not None:
            geometry = sound_event_annotation.sound_event.geometry
            assert geometry is not None
            gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
                compute_bounds(geometry)
            )

        if match.pred_geometry is not None:
            pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
                compute_bounds(match.pred_geometry)
            )

        data.append(
            {
                ("recording", "uuid"): match.clip.recording.uuid,
                ("clip", "uuid"): match.clip.uuid,
                ("clip", "start_time"): match.clip.start_time,
                ("clip", "end_time"): match.clip.end_time,
                ("gt", "uuid"): sound_event_annotation.uuid
                if sound_event_annotation is not None
                else None,
                ("gt", "class"): match.gt_class,
                ("gt", "det"): match.gt_det,
                ("gt", "start_time"): gt_start_time,
                ("gt", "end_time"): gt_end_time,
                ("gt", "low_freq"): gt_low_freq,
                ("gt", "high_freq"): gt_high_freq,
                ("pred", "score"): match.pred_score,
                ("pred", "class"): match.pred_class,
                ("pred", "class_score"): match.pred_class_score,
                ("pred", "start_time"): pred_start_time,
                ("pred", "end_time"): pred_end_time,
                ("pred", "low_freq"): pred_low_freq,
                ("pred", "high_freq"): pred_high_freq,
                ("match", "affinity"): match.affinity,
                **{
                    ("pred_class_score", key): value
                    for key, value in match.pred_class_scores.items()
                },
            }
        )

    df = pd.DataFrame(data)
    df.columns = pd.MultiIndex.from_tuples(df.columns)  # type: ignore
    return df
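The tuple keys above become a two-level column index, so related columns can be selected as a group. A toy illustration of the resulting layout (the values are made up):

import pandas as pd

df = pd.DataFrame(
    [
        {("gt", "class"): "Myotis", ("pred", "class"): "Myotis",
         ("pred", "score"): 0.91},
        {("gt", "class"): "Pipistrellus", ("pred", "class"): None,
         ("pred", "score"): 0.12},
    ]
)
df.columns = pd.MultiIndex.from_tuples(df.columns)

print(df["pred"])             # all prediction columns as a sub-frame
print(df[("pred", "score")])  # a single column via its (group, name) tuple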
src/batdetect2/evaluate/evaluate.py (new file, 100 lines):

@@ -0,0 +1,100 @@
from typing import List, Optional, Tuple

import pandas as pd
from soundevent import data

from batdetect2.evaluate.dataframe import extract_matches_dataframe
from batdetect2.evaluate.match import match_all_predictions
from batdetect2.evaluate.metrics import (
    ClassificationAccuracy,
    ClassificationMeanAveragePrecision,
    DetectionAveragePrecision,
)
from batdetect2.models import Model
from batdetect2.plotting.clips import build_audio_loader
from batdetect2.postprocess import get_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.dataset import ValidationDataset
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.train import build_val_loader


def evaluate(
    model: Model,
    test_annotations: List[data.ClipAnnotation],
    config: Optional[FullTrainingConfig] = None,
    num_workers: Optional[int] = None,
) -> Tuple[pd.DataFrame, dict]:
    config = config or FullTrainingConfig()

    audio_loader = build_audio_loader(config.preprocess.audio)

    preprocessor = build_preprocessor(config.preprocess)

    targets = build_targets(config.targets)

    labeller = build_clip_labeler(
        targets,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
        config=config.train.labels,
    )

    loader = build_val_loader(
        test_annotations,
        audio_loader=audio_loader,
        labeller=labeller,
        preprocessor=preprocessor,
        config=config.train,
        num_workers=num_workers,
    )

    dataset: ValidationDataset = loader.dataset  # type: ignore

    clip_annotations = []
    predictions = []

    for batch in loader:
        outputs = model.detector(batch.spec)

        # Collect this batch's results under batch-local names so the
        # accumulator lists above are extended, not overwritten.
        batch_annotations = [
            dataset.clip_annotations[int(example_idx)]
            for example_idx in batch.idx
        ]

        batch_predictions = get_raw_predictions(
            outputs,
            clips=[
                clip_annotation.clip for clip_annotation in batch_annotations
            ],
            targets=targets,
            postprocessor=model.postprocessor,
        )

        clip_annotations.extend(batch_annotations)
        predictions.extend(batch_predictions)

    matches = match_all_predictions(
        clip_annotations,
        predictions,
        targets=targets,
        config=config.evaluation.match,
    )

    df = extract_matches_dataframe(matches)

    metrics = [
        DetectionAveragePrecision(),
        ClassificationMeanAveragePrecision(class_names=targets.class_names),
        ClassificationAccuracy(class_names=targets.class_names),
    ]

    results = {
        name: value
        for metric in metrics
        for name, value in metric(matches).items()
    }

    return df, results
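How the pieces added in this commit fit together outside the CLI, as a sketch; the checkpoint and dataset paths are placeholders:

from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.train.lightning import load_model_from_checkpoint

model, train_config = load_model_from_checkpoint("model.ckpt")
test_annotations = load_dataset_from_config("test_dataset.yaml")

df, results = evaluate(
    model,
    test_annotations,
    config=train_config,
    num_workers=2,
)

print(results)  # dict mapping metric name -> value
df.to_csv("outputs/eval/results.csv")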
batdetect2.models:

@@ -29,7 +29,6 @@ provided here.
 from typing import List, Optional
 
 import torch
-from lightning import LightningModule
 from pydantic import Field
 from soundevent.data import PathLike
 
@@ -105,7 +104,16 @@ __all__ = [
 ]
 
 
-class Model(LightningModule):
+class ModelConfig(BaseConfig):
+    model: BackboneConfig = Field(default_factory=BackboneConfig)
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+
+
+class Model(torch.nn.Module):
     detector: DetectionModel
     preprocessor: PreprocessorProtocol
     postprocessor: PostprocessorProtocol
@@ -117,13 +125,14 @@ class Model(LightningModule):
         preprocessor: PreprocessorProtocol,
         postprocessor: PostprocessorProtocol,
         targets: TargetProtocol,
+        config: ModelConfig,
     ):
         super().__init__()
         self.detector = detector
         self.preprocessor = preprocessor
         self.postprocessor = postprocessor
         self.targets = targets
-        self.save_hyperparameters()
+        self.config = config
 
     def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
         spec = self.preprocessor(wav)
@@ -131,29 +140,24 @@ class Model(LightningModule):
         return self.postprocessor(outputs)
 
 
-class ModelConfig(BaseConfig):
-    model: BackboneConfig = Field(default_factory=BackboneConfig)
-    preprocess: PreprocessingConfig = Field(
-        default_factory=PreprocessingConfig
-    )
-    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
-    targets: TargetConfig = Field(default_factory=TargetConfig)
-
-
 def build_model(config: Optional[ModelConfig] = None):
     config = config or ModelConfig()
 
     targets = build_targets(config=config.targets)
 
     preprocessor = build_preprocessor(config=config.preprocess)
 
     postprocessor = build_postprocessor(
         preprocessor=preprocessor,
         config=config.postprocess,
     )
 
     detector = build_detector(
         num_classes=len(targets.class_names),
         config=config.model,
     )
     return Model(
+        config=config,
         detector=detector,
         postprocessor=postprocessor,
         preprocessor=preprocessor,
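With Model now a plain torch.nn.Module built from a ModelConfig, inference no longer goes through Lightning. A minimal sketch; the waveform shape below is an assumption for illustration:

import torch

from batdetect2.models import ModelConfig, build_model

model = build_model(ModelConfig())
model.eval()

wav = torch.randn(1, 256_000)  # placeholder batch of raw audio
with torch.no_grad():
    detections = model(wav)  # preprocess -> detector -> postprocess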
batdetect2.train.config:

@@ -6,7 +6,6 @@ from soundevent import data
 from batdetect2.configs import BaseConfig, load_config
 from batdetect2.evaluate import EvaluationConfig
 from batdetect2.models import ModelConfig
-from batdetect2.targets import TargetConfig
 from batdetect2.train.augmentations import (
     DEFAULT_AUGMENTATION_CONFIG,
     AugmentationsConfig,
@@ -75,7 +74,6 @@ class TrainingConfig(BaseConfig):
     cliping: ClipingConfig = Field(default_factory=ClipingConfig)
     trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
-    targets: TargetConfig = Field(default_factory=TargetConfig)
     labels: LabelConfig = Field(default_factory=LabelConfig)
batdetect2.train.lightning:

@@ -1,9 +1,14 @@
+from typing import Optional, Tuple
+
 import lightning as L
 import torch
+from soundevent.data import PathLike
 from torch.optim.adam import Adam
 from torch.optim.lr_scheduler import CosineAnnealingLR
 
-from batdetect2.models import Model
+from batdetect2.models import Model, build_model
 from batdetect2.train.config import FullTrainingConfig
 from batdetect2.train.losses import build_loss
 from batdetect2.typing import ModelOutput, TrainExample
 
 __all__ = [
@@ -16,22 +21,28 @@ class TrainingModule(L.LightningModule):
 
     def __init__(
         self,
-        model: Model,
-        loss: torch.nn.Module,
         config: FullTrainingConfig,
         learning_rate: float = 0.001,
         t_max: int = 100,
+        model: Optional[Model] = None,
+        loss: Optional[torch.nn.Module] = None,
     ):
         super().__init__()
 
+        self.save_hyperparameters(logger=False)
+
         self.config = config
         self.learning_rate = learning_rate
         self.t_max = t_max
 
+        if loss is None:
+            loss = build_loss(self.config.train.loss)
+
+        if model is None:
+            model = build_model(self.config)
+
         self.loss = loss
         self.model = model
-        self.save_hyperparameters(logger=False)
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:
         return self.model.detector(spec)
 
     def training_step(self, batch: TrainExample):
         outputs = self.model.detector(batch.spec)
@@ -59,3 +70,10 @@
         optimizer = Adam(self.parameters(), lr=self.learning_rate)
         scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
         return [optimizer], [scheduler]
+
+
+def load_model_from_checkpoint(
+    path: PathLike,
+) -> Tuple[Model, FullTrainingConfig]:
+    module = TrainingModule.load_from_checkpoint(path)  # type: ignore
+    return module.model, module.config
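Making model and loss optional and saving only constructor hyperparameters means a checkpoint carries enough information to rebuild both. A sketch of the resulting round trip; the checkpoint path is a placeholder:

from batdetect2.train.config import FullTrainingConfig
from batdetect2.train.lightning import (
    TrainingModule,
    load_model_from_checkpoint,
)

# model and loss are rebuilt from the config when not supplied.
module = TrainingModule(config=FullTrainingConfig())

# After training, the new helper recovers the bare model and its config:
model, config = load_model_from_checkpoint("outputs/checkpoints/best.ckpt")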
batdetect2.train.logging:

@@ -5,10 +5,11 @@ import numpy as np
 from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
 from loguru import logger
 from pydantic import Field
+from soundevent import data
 
 from batdetect2.configs import BaseConfig
 
-DEFAULT_LOGS_DIR: str = "logs"
+DEFAULT_LOGS_DIR: str = "outputs"
 
 
 class DVCLiveConfig(BaseConfig):
@@ -31,7 +32,7 @@ class CSVLoggerConfig(BaseConfig):
 class TensorBoardLoggerConfig(BaseConfig):
     logger_type: Literal["tensorboard"] = "tensorboard"
     save_dir: str = DEFAULT_LOGS_DIR
-    name: Optional[str] = "default"
+    name: Optional[str] = "logs"
     version: Optional[str] = None
     log_graph: bool = False
@@ -57,7 +58,10 @@ LoggerConfig = Annotated[
 ]
 
 
-def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
+def create_dvclive_logger(
+    config: DVCLiveConfig,
+    log_dir: Optional[data.PathLike] = None,
+) -> Logger:
     try:
         from dvclive.lightning import DVCLiveLogger  # type: ignore
     except ImportError as error:
@@ -68,7 +72,7 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
         ) from error
 
     return DVCLiveLogger(
-        dir=config.dir,
+        dir=log_dir if log_dir is not None else config.dir,
         run_name=config.run_name,
         prefix=config.prefix,
         log_model=config.log_model,
@@ -76,29 +80,38 @@
     )
 
 
-def create_csv_logger(config: CSVLoggerConfig) -> Logger:
+def create_csv_logger(
+    config: CSVLoggerConfig,
+    log_dir: Optional[data.PathLike] = None,
+) -> Logger:
     from lightning.pytorch.loggers import CSVLogger
 
     return CSVLogger(
-        save_dir=config.save_dir,
+        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
         name=config.name,
         version=config.version,
         flush_logs_every_n_steps=config.flush_logs_every_n_steps,
     )
 
 
-def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
+def create_tensorboard_logger(
+    config: TensorBoardLoggerConfig,
+    log_dir: Optional[data.PathLike] = None,
+) -> Logger:
     from lightning.pytorch.loggers import TensorBoardLogger
 
     return TensorBoardLogger(
-        save_dir=config.save_dir,
+        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
        name=config.name,
         version=config.version,
         log_graph=config.log_graph,
     )
 
 
-def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
+def create_mlflow_logger(
+    config: MLFlowLoggerConfig,
+    log_dir: Optional[data.PathLike] = None,
+) -> Logger:
     try:
         from lightning.pytorch.loggers import MLFlowLogger
     except ImportError as error:
@@ -111,7 +124,7 @@ def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
     return MLFlowLogger(
         experiment_name=config.experiment_name,
         run_name=config.run_name,
-        save_dir=config.save_dir,
+        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
         tracking_uri=config.tracking_uri,
         tags=config.tags,
         log_model=config.log_model,
@@ -126,7 +139,10 @@ LOGGER_FACTORY = {
 }
 
 
-def build_logger(config: LoggerConfig) -> Logger:
+def build_logger(
+    config: LoggerConfig,
+    log_dir: Optional[data.PathLike] = None,
+) -> Logger:
     """
     Creates a logger instance from a validated Pydantic config object.
     """
@@ -141,7 +157,7 @@ def build_logger(config: LoggerConfig) -> Logger:
 
     creation_func = LOGGER_FACTORY[logger_type]
 
-    return creation_func(config)
+    return creation_func(config, log_dir=log_dir)
 
 
 def get_image_plotter(logger: Logger):
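Every factory now shares the signature (config, log_dir=None), so build_logger can thread a directory from the command line through to whichever backend is configured. A sketch; the directory value is a placeholder, and fields not shown keep their config defaults:

from batdetect2.train.logging import CSVLoggerConfig, build_logger

config = CSVLoggerConfig()  # the csv backend

# Without an override, the config's save_dir wins:
logger = build_logger(config)

# With a --log-dir value, the command-line directory takes precedence:
logger = build_logger(config, log_dir="outputs/run_01")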
batdetect2.train.train:

@@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import (
     ClassificationMeanAveragePrecision,
     DetectionAveragePrecision,
 )
-from batdetect2.models import Model, build_model
 from batdetect2.plotting.clips import AudioLoader, build_audio_loader
 from batdetect2.preprocess import build_preprocessor
+from batdetect2.targets import build_targets
 from batdetect2.train.augmentations import (
     RandomAudioSource,
     build_augmentations,
@@ -28,7 +28,6 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset
 from batdetect2.train.labels import build_clip_labeler
 from batdetect2.train.lightning import TrainingModule
 from batdetect2.train.logging import build_logger
-from batdetect2.train.losses import build_loss
 from batdetect2.typing import (
     PreprocessorProtocol,
     TargetProtocol,
@@ -54,19 +53,21 @@ def train(
     model_path: Optional[data.PathLike] = None,
     train_workers: Optional[int] = None,
     val_workers: Optional[int] = None,
+    checkpoint_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[data.PathLike] = None,
 ):
     config = config or FullTrainingConfig()
 
-    model = build_model(config=config)
+    targets = build_targets(config.targets)
 
-    trainer = build_trainer(config, targets=model.targets)
+    preprocessor = build_preprocessor(config.preprocess)
 
     audio_loader = build_audio_loader(config=config.preprocess.audio)
 
     labeller = build_clip_labeler(
-        model.targets,
-        min_freq=model.preprocessor.min_freq,
-        max_freq=model.preprocessor.max_freq,
+        targets,
+        min_freq=preprocessor.min_freq,
+        max_freq=preprocessor.max_freq,
         config=config.train.labels,
     )
@@ -74,7 +75,7 @@ def train(
         train_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
-        preprocessor=build_preprocessor(config.preprocess),
+        preprocessor=preprocessor,
         config=config.train,
         num_workers=train_workers,
     )
@@ -84,7 +85,7 @@
         val_annotations,
         audio_loader=audio_loader,
         labeller=labeller,
-        preprocessor=build_preprocessor(config.preprocess),
+        preprocessor=preprocessor,
         config=config.train,
         num_workers=val_workers,
     )
@@ -97,11 +98,17 @@
         module = TrainingModule.load_from_checkpoint(model_path)  # type: ignore
     else:
         module = build_training_module(
-            model,
             config,
             t_max=config.train.t_max * len(train_dataloader),
         )
 
+    trainer = build_trainer(
+        config,
+        targets=targets,
+        checkpoint_dir=checkpoint_dir,
+        log_dir=log_dir,
+    )
+
     logger.info("Starting main training loop...")
     trainer.fit(
         module,
@@ -112,15 +119,12 @@ def train(
 
 
 def build_training_module(
-    model: Model,
     config: Optional[FullTrainingConfig] = None,
     t_max: int = 200,
 ) -> TrainingModule:
     config = config or FullTrainingConfig()
-    loss = build_loss(config=config.train.loss)
     return TrainingModule(
-        model=model,
-        loss=loss,
         config=config,
         learning_rate=config.train.learning_rate,
         t_max=t_max,
     )
@@ -130,10 +134,14 @@ def build_trainer_callbacks(
     targets: TargetProtocol,
     preprocessor: PreprocessorProtocol,
     config: EvaluationConfig,
+    checkpoint_dir: Optional[data.PathLike] = None,
 ) -> List[Callback]:
+    if checkpoint_dir is None:
+        checkpoint_dir = "outputs/checkpoints"
+
     return [
         ModelCheckpoint(
-            dirpath="outputs/checkpoints",
+            dirpath=str(checkpoint_dir),
             save_top_k=1,
             monitor="total_loss/val",
         ),
@@ -154,15 +162,22 @@
 def build_trainer(
     conf: FullTrainingConfig,
     targets: TargetProtocol,
+    checkpoint_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[data.PathLike] = None,
 ) -> Trainer:
     trainer_conf = conf.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
-    train_logger = build_logger(conf.train.logger)
+    train_logger = build_logger(conf.train.logger, log_dir=log_dir)
 
-    train_logger.log_hyperparams(conf.model_dump(mode="json"))
+    train_logger.log_hyperparams(
+        conf.model_dump(
+            mode="json",
+            exclude_none=True,
+        )
+    )
 
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
@@ -171,6 +186,7 @@ def build_trainer(
             targets,
             config=conf.evaluation,
             preprocessor=build_preprocessor(conf.preprocess),
+            checkpoint_dir=checkpoint_dir,
         ),
     )
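The two new directory flags are wired from the CLI down through build_trainer and build_trainer_callbacks. A sketch of exercising them with click's test runner; the dataset paths and directories are placeholders, and both directories must already exist since the options are declared with exists=True:

from click.testing import CliRunner

from batdetect2.cli.train import train_command

runner = CliRunner()
result = runner.invoke(
    train_command,
    [
        "train_dataset.yaml",
        "--val-dataset", "val_dataset.yaml",
        "--ckpt-dir", "outputs/checkpoints/run_01",
        "--log-dir", "outputs/logs/run_01",
    ],
)
print(result.output)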