From 957c0735d21ab4bf196ff35ae876c84b0c15a858 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Tue, 16 Sep 2025 19:39:30 +0100 Subject: [PATCH] Starting new API --- src/batdetect2/{api.py => api/__init__.py} | 0 src/batdetect2/api/base.py | 148 +++++++++++++++++++++ src/batdetect2/cli/train.py | 21 +-- src/batdetect2/inference/__init__.py | 0 src/batdetect2/train/train.py | 20 +-- 5 files changed, 166 insertions(+), 23 deletions(-) rename src/batdetect2/{api.py => api/__init__.py} (100%) create mode 100644 src/batdetect2/api/base.py create mode 100644 src/batdetect2/inference/__init__.py diff --git a/src/batdetect2/api.py b/src/batdetect2/api/__init__.py similarity index 100% rename from src/batdetect2/api.py rename to src/batdetect2/api/__init__.py diff --git a/src/batdetect2/api/base.py b/src/batdetect2/api/base.py new file mode 100644 index 0000000..df0cae0 --- /dev/null +++ b/src/batdetect2/api/base.py @@ -0,0 +1,148 @@ +from pathlib import Path +from typing import Optional, Sequence + +from soundevent import data + +from batdetect2.audio import build_audio_loader +from batdetect2.config import BatDetect2Config +from batdetect2.evaluate import Evaluator, build_evaluator +from batdetect2.models import Model, build_model +from batdetect2.postprocess import build_postprocessor +from batdetect2.preprocess import build_preprocessor +from batdetect2.targets.targets import build_targets +from batdetect2.train import train +from batdetect2.train.lightning import load_model_from_checkpoint +from batdetect2.typing import ( + AudioLoader, + PostprocessorProtocol, + PreprocessorProtocol, + TargetProtocol, +) + + +class BatDetect2API: + def __init__( + self, + config: BatDetect2Config, + targets: TargetProtocol, + audio_loader: AudioLoader, + preprocessor: PreprocessorProtocol, + postprocessor: PostprocessorProtocol, + evaluator: Evaluator, + model: Model, + ): + self.config = config + self.targets = targets + self.audio_loader = audio_loader + self.preprocessor = preprocessor + self.postprocessor = postprocessor + self.evaluator = evaluator + self.model = model + + self.model.eval() + + def train( + self, + train_annotations: Sequence[data.ClipAnnotation], + val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, + train_workers: Optional[int] = None, + val_workers: Optional[int] = None, + checkpoint_dir: Optional[Path] = None, + log_dir: Optional[Path] = None, + experiment_name: Optional[str] = None, + run_name: Optional[str] = None, + seed: Optional[int] = None, + ): + train( + train_annotations=train_annotations, + val_annotations=val_annotations, + config=self.config, + train_workers=train_workers, + val_workers=val_workers, + checkpoint_dir=checkpoint_dir, + log_dir=log_dir, + experiment_name=experiment_name, + run_name=run_name, + seed=seed, + ) + return self + + @classmethod + def from_config(cls, config: BatDetect2Config): + targets = build_targets(config=config.targets) + + audio_loader = build_audio_loader(config=config.audio) + + preprocessor = build_preprocessor( + input_samplerate=audio_loader.samplerate, + config=config.preprocess, + ) + + postprocessor = build_postprocessor( + preprocessor, + config=config.postprocess, + ) + + evaluator = build_evaluator( + config=config.evaluation, + targets=targets, + ) + + # NOTE: Better to have a separate instance of + # preprocessor and postprocessor as these may be moved + # to another device. + model = build_model( + config=config.model, + targets=targets, + preprocessor=build_preprocessor( + input_samplerate=audio_loader.samplerate, + config=config.preprocess, + ), + postprocessor=build_postprocessor( + preprocessor, + config=config.postprocess, + ), + ) + + return cls( + config=config, + targets=targets, + audio_loader=audio_loader, + preprocessor=preprocessor, + postprocessor=postprocessor, + evaluator=evaluator, + model=model, + ) + + @classmethod + def from_checkpoint(cls, path: data.PathLike): + model, config = load_model_from_checkpoint(path) + + targets = build_targets(config=config.targets) + + audio_loader = build_audio_loader(config=config.audio) + + preprocessor = build_preprocessor( + input_samplerate=audio_loader.samplerate, + config=config.preprocess, + ) + + postprocessor = build_postprocessor( + preprocessor, + config=config.postprocess, + ) + + evaluator = build_evaluator( + config=config.evaluation, + targets=targets, + ) + + return cls( + config=config, + targets=targets, + audio_loader=audio_loader, + preprocessor=preprocessor, + postprocessor=postprocessor, + evaluator=evaluator, + model=model, + ) diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index 77e17df..64cf28b 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -46,13 +46,13 @@ def train_command( run_name: Optional[str] = None, verbose: int = 0, ): + from batdetect2.api.base import BatDetect2API from batdetect2.config import ( BatDetect2Config, load_full_config, ) from batdetect2.data import load_dataset_from_config from batdetect2.targets import load_target_config - from batdetect2.train import train logger.remove() if verbose == 0: @@ -62,11 +62,9 @@ def train_command( else: log_level = "DEBUG" logger.add(sys.stderr, level=log_level) - logger.info("Initiating training process...") - logger.info("Loading training configuration...") - + logger.info("Loading configuration...") conf = ( load_full_config(config, field=config_field) if config is not None @@ -98,16 +96,19 @@ def train_command( logger.info("Configuration and data loaded. Starting training...") - train( + if model_path is None: + api = BatDetect2API.from_config(conf) + else: + api = BatDetect2API.from_checkpoint(model_path) + + return api.train( train_annotations=train_annotations, val_annotations=val_annotations, - config=conf, - model_path=model_path, train_workers=train_workers, val_workers=val_workers, - experiment_name=experiment_name, - log_dir=log_dir, checkpoint_dir=ckpt_dir, - seed=seed, + log_dir=log_dir, + experiment_name=experiment_name, run_name=run_name, + seed=seed, ) diff --git a/src/batdetect2/inference/__init__.py b/src/batdetect2/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index dcdfff0..4247b2b 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -15,7 +15,7 @@ from batdetect2.train.callbacks import ValidationMetrics from batdetect2.train.config import TrainingConfig from batdetect2.train.dataset import build_train_loader, build_val_loader from batdetect2.train.labels import build_clip_labeler -from batdetect2.train.lightning import TrainingModule, build_training_module +from batdetect2.train.lightning import build_training_module from batdetect2.train.logging import build_logger if TYPE_CHECKING: @@ -38,14 +38,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints" def train( train_annotations: Sequence[data.ClipAnnotation], val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, - evaluator: Optional[Evaluator] = None, - trainer: Optional[Trainer] = None, targets: Optional["TargetProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, audio_loader: Optional["AudioLoader"] = None, labeller: Optional["ClipLabeller"] = None, config: Optional["BatDetect2Config"] = None, - model_path: Optional[data.PathLike] = None, + trainer: Optional[Trainer] = None, train_workers: Optional[int] = None, val_workers: Optional[int] = None, checkpoint_dir: Optional[Path] = None, @@ -99,19 +97,15 @@ def train( else None ) - if model_path is not None: - logger.debug("Loading model from: {path}", path=model_path) - module = TrainingModule.load_from_checkpoint(Path(model_path)) - else: - module = build_training_module( - config, - t_max=config.train.optimizer.t_max * len(train_dataloader), - ) + module = build_training_module( + config, + t_max=config.train.optimizer.t_max * len(train_dataloader), + ) trainer = trainer or build_trainer( config.train, targets=targets, - evaluator=evaluator, + evaluator=build_evaluator(config.evaluation, targets=targets), checkpoint_dir=checkpoint_dir, log_dir=log_dir, experiment_name=experiment_name,