mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Eval
This commit is contained in:
parent
c73984b213
commit
cd4955d4f3
@ -108,6 +108,9 @@ train:
|
|||||||
labels:
|
labels:
|
||||||
sigma: 3
|
sigma: 3
|
||||||
|
|
||||||
|
trainer:
|
||||||
|
max_epochs: 40
|
||||||
|
|
||||||
dataloaders:
|
dataloaders:
|
||||||
train:
|
train:
|
||||||
batch_size: 8
|
batch_size: 8
|
||||||
@ -115,7 +118,7 @@ train:
|
|||||||
shuffle: True
|
shuffle: True
|
||||||
|
|
||||||
val:
|
val:
|
||||||
batch_size: 8
|
batch_size: 1
|
||||||
num_workers: 2
|
num_workers: 2
|
||||||
|
|
||||||
loss:
|
loss:
|
||||||
@ -133,7 +136,7 @@ train:
|
|||||||
weight: 0.1
|
weight: 0.1
|
||||||
|
|
||||||
logger:
|
logger:
|
||||||
logger_type: tensorboard
|
logger_type: csv
|
||||||
# save_dir: outputs/log/
|
# save_dir: outputs/log/
|
||||||
# name: logs
|
# name: logs
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.cli.compat import detect
|
from batdetect2.cli.compat import detect
|
||||||
from batdetect2.cli.data import data
|
from batdetect2.cli.data import data
|
||||||
|
from batdetect2.cli.evaluate import evaluate_command
|
||||||
from batdetect2.cli.train import train_command
|
from batdetect2.cli.train import train_command
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -8,6 +9,7 @@ __all__ = [
|
|||||||
"detect",
|
"detect",
|
||||||
"data",
|
"data",
|
||||||
"train_command",
|
"train_command",
|
||||||
|
"evaluate_command",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
63
src/batdetect2/cli/evaluate.py
Normal file
63
src/batdetect2/cli/evaluate.py
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import click
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.cli.base import cli
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.evaluate.evaluate import evaluate
|
||||||
|
from batdetect2.train.lightning import load_model_from_checkpoint
|
||||||
|
|
||||||
|
__all__ = ["evaluate_command"]
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command(name="evaluate")
|
||||||
|
@click.argument("model-path", type=click.Path(exists=True))
|
||||||
|
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||||
|
@click.option("--output-dir", type=click.Path())
|
||||||
|
@click.option("--workers", type=int)
|
||||||
|
@click.option(
|
||||||
|
"-v",
|
||||||
|
"--verbose",
|
||||||
|
count=True,
|
||||||
|
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||||
|
)
|
||||||
|
def evaluate_command(
|
||||||
|
model_path: Path,
|
||||||
|
test_dataset: Path,
|
||||||
|
output_dir: Optional[Path] = None,
|
||||||
|
workers: Optional[int] = None,
|
||||||
|
verbose: int = 0,
|
||||||
|
):
|
||||||
|
logger.remove()
|
||||||
|
if verbose == 0:
|
||||||
|
log_level = "WARNING"
|
||||||
|
elif verbose == 1:
|
||||||
|
log_level = "INFO"
|
||||||
|
else:
|
||||||
|
log_level = "DEBUG"
|
||||||
|
logger.add(sys.stderr, level=log_level)
|
||||||
|
|
||||||
|
logger.info("Initiating evaluation process...")
|
||||||
|
|
||||||
|
test_annotations = load_dataset_from_config(test_dataset)
|
||||||
|
logger.debug(
|
||||||
|
"Loaded {num_annotations} test examples",
|
||||||
|
num_annotations=len(test_annotations),
|
||||||
|
)
|
||||||
|
|
||||||
|
model, train_config = load_model_from_checkpoint(model_path)
|
||||||
|
|
||||||
|
df, results = evaluate(
|
||||||
|
model,
|
||||||
|
test_annotations,
|
||||||
|
config=train_config,
|
||||||
|
num_workers=workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(results)
|
||||||
|
|
||||||
|
if output_dir:
|
||||||
|
df.to_csv(output_dir / "results.csv")
|
||||||
@ -20,6 +20,8 @@ __all__ = ["train_command"]
|
|||||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||||
@click.option("--model-path", type=click.Path(exists=True))
|
@click.option("--model-path", type=click.Path(exists=True))
|
||||||
|
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||||
|
@click.option("--log-dir", type=click.Path(exists=True))
|
||||||
@click.option("--config", type=click.Path(exists=True))
|
@click.option("--config", type=click.Path(exists=True))
|
||||||
@click.option("--config-field", type=str)
|
@click.option("--config-field", type=str)
|
||||||
@click.option("--train-workers", type=int)
|
@click.option("--train-workers", type=int)
|
||||||
@ -34,6 +36,8 @@ def train_command(
|
|||||||
train_dataset: Path,
|
train_dataset: Path,
|
||||||
val_dataset: Optional[Path] = None,
|
val_dataset: Optional[Path] = None,
|
||||||
model_path: Optional[Path] = None,
|
model_path: Optional[Path] = None,
|
||||||
|
ckpt_dir: Optional[Path] = None,
|
||||||
|
log_dir: Optional[Path] = None,
|
||||||
config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
@ -83,4 +87,6 @@ def train_command(
|
|||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
|
log_dir=log_dir,
|
||||||
|
checkpoint_dir=ckpt_dir,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -27,7 +27,7 @@ class BaseConfig(BaseModel):
|
|||||||
and serialization capabilities.
|
and serialization capabilities.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="ignore")
|
||||||
|
|
||||||
def to_yaml_string(
|
def to_yaml_string(
|
||||||
self,
|
self,
|
||||||
|
|||||||
62
src/batdetect2/evaluate/dataframe.py
Normal file
62
src/batdetect2/evaluate/dataframe.py
Normal file
@ -0,0 +1,62 @@
|
|||||||
|
from typing import List
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.typing.evaluate import MatchEvaluation
|
||||||
|
|
||||||
|
|
||||||
|
def extract_matches_dataframe(matches: List[MatchEvaluation]) -> pd.DataFrame:
|
||||||
|
data = []
|
||||||
|
|
||||||
|
for match in matches:
|
||||||
|
gt_start_time = gt_low_freq = gt_end_time = gt_high_freq = None
|
||||||
|
pred_start_time = pred_low_freq = pred_end_time = pred_high_freq = None
|
||||||
|
|
||||||
|
sound_event_annotation = match.sound_event_annotation
|
||||||
|
|
||||||
|
if sound_event_annotation is not None:
|
||||||
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
|
assert geometry is not None
|
||||||
|
gt_start_time, gt_low_freq, gt_end_time, gt_high_freq = (
|
||||||
|
compute_bounds(geometry)
|
||||||
|
)
|
||||||
|
|
||||||
|
if match.pred_geometry is not None:
|
||||||
|
pred_start_time, pred_low_freq, pred_end_time, pred_high_freq = (
|
||||||
|
compute_bounds(match.pred_geometry)
|
||||||
|
)
|
||||||
|
|
||||||
|
data.append(
|
||||||
|
{
|
||||||
|
("recording", "uuid"): match.clip.recording.uuid,
|
||||||
|
("clip", "uuid"): match.clip.uuid,
|
||||||
|
("clip", "start_time"): match.clip.start_time,
|
||||||
|
("clip", "end_time"): match.clip.end_time,
|
||||||
|
("gt", "uuid"): match.sound_event_annotation.uuid
|
||||||
|
if match.sound_event_annotation is not None
|
||||||
|
else None,
|
||||||
|
("gt", "class"): match.gt_class,
|
||||||
|
("gt", "det"): match.gt_det,
|
||||||
|
("gt", "start_time"): gt_start_time,
|
||||||
|
("gt", "end_time"): gt_end_time,
|
||||||
|
("gt", "low_freq"): gt_low_freq,
|
||||||
|
("gt", "high_freq"): gt_high_freq,
|
||||||
|
("pred", "score"): match.pred_score,
|
||||||
|
("pred", "class"): match.pred_class,
|
||||||
|
("pred", "class_score"): match.pred_class_score,
|
||||||
|
("pred", "start_time"): pred_start_time,
|
||||||
|
("pred", "end_time"): pred_end_time,
|
||||||
|
("pred", "low_freq"): pred_low_freq,
|
||||||
|
("pred", "high_freq"): pred_high_freq,
|
||||||
|
("match", "affinity"): match.affinity,
|
||||||
|
**{
|
||||||
|
("pred_class_score", key): value
|
||||||
|
for key, value in match.pred_class_scores.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
df.columns = pd.MultiIndex.from_tuples(df.columns) # type: ignore
|
||||||
|
return df
|
||||||
100
src/batdetect2/evaluate/evaluate.py
Normal file
100
src/batdetect2/evaluate/evaluate.py
Normal file
@ -0,0 +1,100 @@
|
|||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.evaluate.dataframe import extract_matches_dataframe
|
||||||
|
from batdetect2.evaluate.match import match_all_predictions
|
||||||
|
from batdetect2.evaluate.metrics import (
|
||||||
|
ClassificationAccuracy,
|
||||||
|
ClassificationMeanAveragePrecision,
|
||||||
|
DetectionAveragePrecision,
|
||||||
|
)
|
||||||
|
from batdetect2.models import Model
|
||||||
|
from batdetect2.plotting.clips import build_audio_loader
|
||||||
|
from batdetect2.postprocess import get_raw_predictions
|
||||||
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
|
from batdetect2.train.config import FullTrainingConfig
|
||||||
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
|
from batdetect2.train.train import build_val_loader
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate(
|
||||||
|
model: Model,
|
||||||
|
test_annotations: List[data.ClipAnnotation],
|
||||||
|
config: Optional[FullTrainingConfig] = None,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
) -> Tuple[pd.DataFrame, dict]:
|
||||||
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
|
audio_loader = build_audio_loader(config.preprocess.audio)
|
||||||
|
|
||||||
|
preprocessor = build_preprocessor(config.preprocess)
|
||||||
|
|
||||||
|
targets = build_targets(config.targets)
|
||||||
|
|
||||||
|
labeller = build_clip_labeler(
|
||||||
|
targets,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
config=config.train.labels,
|
||||||
|
)
|
||||||
|
|
||||||
|
loader = build_val_loader(
|
||||||
|
test_annotations,
|
||||||
|
audio_loader=audio_loader,
|
||||||
|
labeller=labeller,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
config=config.train,
|
||||||
|
num_workers=num_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset: ValidationDataset = loader.dataset # type: ignore
|
||||||
|
|
||||||
|
clip_annotations = []
|
||||||
|
predictions = []
|
||||||
|
|
||||||
|
for batch in loader:
|
||||||
|
outputs = model.detector(batch.spec)
|
||||||
|
|
||||||
|
clip_annotations = [
|
||||||
|
dataset.clip_annotations[int(example_idx)]
|
||||||
|
for example_idx in batch.idx
|
||||||
|
]
|
||||||
|
|
||||||
|
predictions = get_raw_predictions(
|
||||||
|
outputs,
|
||||||
|
clips=[
|
||||||
|
clip_annotation.clip for clip_annotation in clip_annotations
|
||||||
|
],
|
||||||
|
targets=targets,
|
||||||
|
postprocessor=model.postprocessor,
|
||||||
|
)
|
||||||
|
|
||||||
|
clip_annotations.extend(clip_annotations)
|
||||||
|
predictions.extend(predictions)
|
||||||
|
|
||||||
|
matches = match_all_predictions(
|
||||||
|
clip_annotations,
|
||||||
|
predictions,
|
||||||
|
targets=targets,
|
||||||
|
config=config.evaluation.match,
|
||||||
|
)
|
||||||
|
|
||||||
|
df = extract_matches_dataframe(matches)
|
||||||
|
|
||||||
|
metrics = [
|
||||||
|
DetectionAveragePrecision(),
|
||||||
|
ClassificationMeanAveragePrecision(class_names=targets.class_names),
|
||||||
|
ClassificationAccuracy(class_names=targets.class_names),
|
||||||
|
]
|
||||||
|
|
||||||
|
results = {
|
||||||
|
name: value
|
||||||
|
for metric in metrics
|
||||||
|
for name, value in metric(matches).items()
|
||||||
|
}
|
||||||
|
|
||||||
|
return df, results
|
||||||
@ -29,7 +29,6 @@ provided here.
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from lightning import LightningModule
|
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
@ -105,7 +104,16 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class Model(LightningModule):
|
class ModelConfig(BaseConfig):
|
||||||
|
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
|
||||||
|
|
||||||
|
class Model(torch.nn.Module):
|
||||||
detector: DetectionModel
|
detector: DetectionModel
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
@ -117,13 +125,14 @@ class Model(LightningModule):
|
|||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
config: ModelConfig,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.targets = targets
|
self.targets = targets
|
||||||
self.save_hyperparameters()
|
self.config = config
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> List[DetectionsTensor]:
|
||||||
spec = self.preprocessor(wav)
|
spec = self.preprocessor(wav)
|
||||||
@ -131,29 +140,24 @@ class Model(LightningModule):
|
|||||||
return self.postprocessor(outputs)
|
return self.postprocessor(outputs)
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(BaseConfig):
|
|
||||||
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
|
||||||
preprocess: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
|
|
||||||
|
|
||||||
def build_model(config: Optional[ModelConfig] = None):
|
def build_model(config: Optional[ModelConfig] = None):
|
||||||
config = config or ModelConfig()
|
config = config or ModelConfig()
|
||||||
|
|
||||||
targets = build_targets(config=config.targets)
|
targets = build_targets(config=config.targets)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(config=config.preprocess)
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(targets.class_names),
|
||||||
config=config.model,
|
config=config.model,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
|
config=config,
|
||||||
detector=detector,
|
detector=detector,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
|
|||||||
@ -6,7 +6,6 @@ from soundevent import data
|
|||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
from batdetect2.targets import TargetConfig
|
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
DEFAULT_AUGMENTATION_CONFIG,
|
DEFAULT_AUGMENTATION_CONFIG,
|
||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
@ -75,7 +74,6 @@ class TrainingConfig(BaseConfig):
|
|||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,9 +1,14 @@
|
|||||||
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import lightning as L
|
import lightning as L
|
||||||
import torch
|
import torch
|
||||||
|
from soundevent.data import PathLike
|
||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model, build_model
|
||||||
|
from batdetect2.train.config import FullTrainingConfig
|
||||||
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.typing import ModelOutput, TrainExample
|
from batdetect2.typing import ModelOutput, TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -16,22 +21,28 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
config: FullTrainingConfig,
|
||||||
loss: torch.nn.Module,
|
|
||||||
learning_rate: float = 0.001,
|
learning_rate: float = 0.001,
|
||||||
t_max: int = 100,
|
t_max: int = 100,
|
||||||
|
model: Optional[Model] = None,
|
||||||
|
loss: Optional[torch.nn.Module] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
self.save_hyperparameters(logger=False)
|
||||||
|
|
||||||
|
self.config = config
|
||||||
self.learning_rate = learning_rate
|
self.learning_rate = learning_rate
|
||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
|
|
||||||
|
if loss is None:
|
||||||
|
loss = build_loss(self.config.train.loss)
|
||||||
|
|
||||||
|
if model is None:
|
||||||
|
model = build_model(self.config)
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.model = model
|
self.model = model
|
||||||
self.save_hyperparameters(logger=False)
|
|
||||||
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput:
|
|
||||||
return self.model.detector(spec)
|
|
||||||
|
|
||||||
def training_step(self, batch: TrainExample):
|
def training_step(self, batch: TrainExample):
|
||||||
outputs = self.model.detector(batch.spec)
|
outputs = self.model.detector(batch.spec)
|
||||||
@ -59,3 +70,10 @@ class TrainingModule(L.LightningModule):
|
|||||||
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
optimizer = Adam(self.parameters(), lr=self.learning_rate)
|
||||||
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
scheduler = CosineAnnealingLR(optimizer, T_max=self.t_max)
|
||||||
return [optimizer], [scheduler]
|
return [optimizer], [scheduler]
|
||||||
|
|
||||||
|
|
||||||
|
def load_model_from_checkpoint(
|
||||||
|
path: PathLike,
|
||||||
|
) -> Tuple[Model, FullTrainingConfig]:
|
||||||
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
|
return module.model, module.config
|
||||||
|
|||||||
@ -5,10 +5,11 @@ import numpy as np
|
|||||||
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: str = "logs"
|
DEFAULT_LOGS_DIR: str = "outputs"
|
||||||
|
|
||||||
|
|
||||||
class DVCLiveConfig(BaseConfig):
|
class DVCLiveConfig(BaseConfig):
|
||||||
@ -31,7 +32,7 @@ class CSVLoggerConfig(BaseConfig):
|
|||||||
class TensorBoardLoggerConfig(BaseConfig):
|
class TensorBoardLoggerConfig(BaseConfig):
|
||||||
logger_type: Literal["tensorboard"] = "tensorboard"
|
logger_type: Literal["tensorboard"] = "tensorboard"
|
||||||
save_dir: str = DEFAULT_LOGS_DIR
|
save_dir: str = DEFAULT_LOGS_DIR
|
||||||
name: Optional[str] = "default"
|
name: Optional[str] = "logs"
|
||||||
version: Optional[str] = None
|
version: Optional[str] = None
|
||||||
log_graph: bool = False
|
log_graph: bool = False
|
||||||
|
|
||||||
@ -57,7 +58,10 @@ LoggerConfig = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
|
def create_dvclive_logger(
|
||||||
|
config: DVCLiveConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
try:
|
try:
|
||||||
from dvclive.lightning import DVCLiveLogger # type: ignore
|
from dvclive.lightning import DVCLiveLogger # type: ignore
|
||||||
except ImportError as error:
|
except ImportError as error:
|
||||||
@ -68,7 +72,7 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
|
|||||||
) from error
|
) from error
|
||||||
|
|
||||||
return DVCLiveLogger(
|
return DVCLiveLogger(
|
||||||
dir=config.dir,
|
dir=log_dir if log_dir is not None else config.dir,
|
||||||
run_name=config.run_name,
|
run_name=config.run_name,
|
||||||
prefix=config.prefix,
|
prefix=config.prefix,
|
||||||
log_model=config.log_model,
|
log_model=config.log_model,
|
||||||
@ -76,29 +80,38 @@ def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_csv_logger(config: CSVLoggerConfig) -> Logger:
|
def create_csv_logger(
|
||||||
|
config: CSVLoggerConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
from lightning.pytorch.loggers import CSVLogger
|
from lightning.pytorch.loggers import CSVLogger
|
||||||
|
|
||||||
return CSVLogger(
|
return CSVLogger(
|
||||||
save_dir=config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
name=config.name,
|
name=config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
|
def create_tensorboard_logger(
|
||||||
|
config: TensorBoardLoggerConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
from lightning.pytorch.loggers import TensorBoardLogger
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
return TensorBoardLogger(
|
return TensorBoardLogger(
|
||||||
save_dir=config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
name=config.name,
|
name=config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
log_graph=config.log_graph,
|
log_graph=config.log_graph,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
|
def create_mlflow_logger(
|
||||||
|
config: MLFlowLoggerConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
try:
|
try:
|
||||||
from lightning.pytorch.loggers import MLFlowLogger
|
from lightning.pytorch.loggers import MLFlowLogger
|
||||||
except ImportError as error:
|
except ImportError as error:
|
||||||
@ -111,7 +124,7 @@ def create_mlflow_logger(config: MLFlowLoggerConfig) -> Logger:
|
|||||||
return MLFlowLogger(
|
return MLFlowLogger(
|
||||||
experiment_name=config.experiment_name,
|
experiment_name=config.experiment_name,
|
||||||
run_name=config.run_name,
|
run_name=config.run_name,
|
||||||
save_dir=config.save_dir,
|
save_dir=str(log_dir) if log_dir is not None else config.save_dir,
|
||||||
tracking_uri=config.tracking_uri,
|
tracking_uri=config.tracking_uri,
|
||||||
tags=config.tags,
|
tags=config.tags,
|
||||||
log_model=config.log_model,
|
log_model=config.log_model,
|
||||||
@ -126,7 +139,10 @@ LOGGER_FACTORY = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_logger(config: LoggerConfig) -> Logger:
|
def build_logger(
|
||||||
|
config: LoggerConfig,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
|
) -> Logger:
|
||||||
"""
|
"""
|
||||||
Creates a logger instance from a validated Pydantic config object.
|
Creates a logger instance from a validated Pydantic config object.
|
||||||
"""
|
"""
|
||||||
@ -141,7 +157,7 @@ def build_logger(config: LoggerConfig) -> Logger:
|
|||||||
|
|
||||||
creation_func = LOGGER_FACTORY[logger_type]
|
creation_func = LOGGER_FACTORY[logger_type]
|
||||||
|
|
||||||
return creation_func(config)
|
return creation_func(config, log_dir=log_dir)
|
||||||
|
|
||||||
|
|
||||||
def get_image_plotter(logger: Logger):
|
def get_image_plotter(logger: Logger):
|
||||||
|
|||||||
@ -14,9 +14,9 @@ from batdetect2.evaluate.metrics import (
|
|||||||
ClassificationMeanAveragePrecision,
|
ClassificationMeanAveragePrecision,
|
||||||
DetectionAveragePrecision,
|
DetectionAveragePrecision,
|
||||||
)
|
)
|
||||||
from batdetect2.models import Model, build_model
|
|
||||||
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
|
from batdetect2.plotting.clips import AudioLoader, build_audio_loader
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train.augmentations import (
|
from batdetect2.train.augmentations import (
|
||||||
RandomAudioSource,
|
RandomAudioSource,
|
||||||
build_augmentations,
|
build_augmentations,
|
||||||
@ -28,7 +28,6 @@ from batdetect2.train.dataset import TrainingDataset, ValidationDataset
|
|||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.logging import build_logger
|
from batdetect2.train.logging import build_logger
|
||||||
from batdetect2.train.losses import build_loss
|
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
PreprocessorProtocol,
|
PreprocessorProtocol,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
@ -54,19 +53,21 @@ def train(
|
|||||||
model_path: Optional[data.PathLike] = None,
|
model_path: Optional[data.PathLike] = None,
|
||||||
train_workers: Optional[int] = None,
|
train_workers: Optional[int] = None,
|
||||||
val_workers: Optional[int] = None,
|
val_workers: Optional[int] = None,
|
||||||
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
):
|
):
|
||||||
config = config or FullTrainingConfig()
|
config = config or FullTrainingConfig()
|
||||||
|
|
||||||
model = build_model(config=config)
|
targets = build_targets(config.targets)
|
||||||
|
|
||||||
trainer = build_trainer(config, targets=model.targets)
|
preprocessor = build_preprocessor(config.preprocess)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
audio_loader = build_audio_loader(config=config.preprocess.audio)
|
||||||
|
|
||||||
labeller = build_clip_labeler(
|
labeller = build_clip_labeler(
|
||||||
model.targets,
|
targets,
|
||||||
min_freq=model.preprocessor.min_freq,
|
min_freq=preprocessor.min_freq,
|
||||||
max_freq=model.preprocessor.max_freq,
|
max_freq=preprocessor.max_freq,
|
||||||
config=config.train.labels,
|
config=config.train.labels,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -74,7 +75,7 @@ def train(
|
|||||||
train_annotations,
|
train_annotations,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
preprocessor=build_preprocessor(config.preprocess),
|
preprocessor=preprocessor,
|
||||||
config=config.train,
|
config=config.train,
|
||||||
num_workers=train_workers,
|
num_workers=train_workers,
|
||||||
)
|
)
|
||||||
@ -84,7 +85,7 @@ def train(
|
|||||||
val_annotations,
|
val_annotations,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
labeller=labeller,
|
labeller=labeller,
|
||||||
preprocessor=build_preprocessor(config.preprocess),
|
preprocessor=preprocessor,
|
||||||
config=config.train,
|
config=config.train,
|
||||||
num_workers=val_workers,
|
num_workers=val_workers,
|
||||||
)
|
)
|
||||||
@ -97,11 +98,17 @@ def train(
|
|||||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||||
else:
|
else:
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model,
|
|
||||||
config,
|
config,
|
||||||
t_max=config.train.t_max * len(train_dataloader),
|
t_max=config.train.t_max * len(train_dataloader),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
trainer = build_trainer(
|
||||||
|
config,
|
||||||
|
targets=targets,
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
log_dir=log_dir,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info("Starting main training loop...")
|
logger.info("Starting main training loop...")
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
module,
|
module,
|
||||||
@ -112,15 +119,12 @@ def train(
|
|||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
model: Model,
|
|
||||||
config: Optional[FullTrainingConfig] = None,
|
config: Optional[FullTrainingConfig] = None,
|
||||||
t_max: int = 200,
|
t_max: int = 200,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
config = config or FullTrainingConfig()
|
config = config or FullTrainingConfig()
|
||||||
loss = build_loss(config=config.train.loss)
|
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
model=model,
|
config=config,
|
||||||
loss=loss,
|
|
||||||
learning_rate=config.train.learning_rate,
|
learning_rate=config.train.learning_rate,
|
||||||
t_max=t_max,
|
t_max=t_max,
|
||||||
)
|
)
|
||||||
@ -130,10 +134,14 @@ def build_trainer_callbacks(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: EvaluationConfig,
|
config: EvaluationConfig,
|
||||||
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
) -> List[Callback]:
|
) -> List[Callback]:
|
||||||
|
if checkpoint_dir is None:
|
||||||
|
checkpoint_dir = "outputs/checkpoints"
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
ModelCheckpoint(
|
||||||
dirpath="outputs/checkpoints",
|
dirpath=str(checkpoint_dir),
|
||||||
save_top_k=1,
|
save_top_k=1,
|
||||||
monitor="total_loss/val",
|
monitor="total_loss/val",
|
||||||
),
|
),
|
||||||
@ -154,15 +162,22 @@ def build_trainer_callbacks(
|
|||||||
def build_trainer(
|
def build_trainer(
|
||||||
conf: FullTrainingConfig,
|
conf: FullTrainingConfig,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
|
checkpoint_dir: Optional[data.PathLike] = None,
|
||||||
|
log_dir: Optional[data.PathLike] = None,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = conf.train.trainer
|
trainer_conf = conf.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
train_logger = build_logger(conf.train.logger)
|
train_logger = build_logger(conf.train.logger, log_dir=log_dir)
|
||||||
|
|
||||||
train_logger.log_hyperparams(conf.model_dump(mode="json"))
|
train_logger.log_hyperparams(
|
||||||
|
conf.model_dump(
|
||||||
|
mode="json",
|
||||||
|
exclude_none=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
@ -171,6 +186,7 @@ def build_trainer(
|
|||||||
targets,
|
targets,
|
||||||
config=conf.evaluation,
|
config=conf.evaluation,
|
||||||
preprocessor=build_preprocessor(conf.preprocess),
|
preprocessor=build_preprocessor(conf.preprocess),
|
||||||
|
checkpoint_dir=checkpoint_dir,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user