Starting new API

This commit is contained in:
mbsantiago 2025-09-16 19:39:30 +01:00
parent bbb96b33a2
commit 957c0735d2
5 changed files with 166 additions and 23 deletions

148
src/batdetect2/api/base.py Normal file
View File

@ -0,0 +1,148 @@
from pathlib import Path
from typing import Optional, Sequence
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import Evaluator, build_evaluator
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train import train
from batdetect2.train.lightning import load_model_from_checkpoint
from batdetect2.typing import (
AudioLoader,
PostprocessorProtocol,
PreprocessorProtocol,
TargetProtocol,
)
class BatDetect2API:
    """Bundle a BatDetect2 configuration with the components built from it.

    An instance holds the configuration together with the targets, audio
    loader, preprocessor, postprocessor, evaluator and model. Use
    :meth:`from_config` to build everything from a configuration, or
    :meth:`from_checkpoint` to take the model and configuration from a
    saved checkpoint.
    """

    def __init__(
        self,
        config: BatDetect2Config,
        targets: TargetProtocol,
        audio_loader: AudioLoader,
        preprocessor: PreprocessorProtocol,
        postprocessor: PostprocessorProtocol,
        evaluator: Evaluator,
        model: Model,
    ):
        """Store the given components and call ``model.eval()``.

        Parameters
        ----------
        config
            Full configuration the other components were built from.
        targets
            Target definition object.
        audio_loader
            Audio loader.
        preprocessor
            Preprocessor kept on the API object.
        postprocessor
            Postprocessor kept on the API object.
        evaluator
            Evaluator.
        model
            Model; ``eval()`` is called on it at construction time.
        """
        self.config = config
        self.targets = targets
        self.audio_loader = audio_loader
        self.preprocessor = preprocessor
        self.postprocessor = postprocessor
        self.evaluator = evaluator
        self.model = model

        # The constructor always calls eval() on the model it is given.
        self.model.eval()

    def train(
        self,
        train_annotations: Sequence[data.ClipAnnotation],
        val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
        train_workers: Optional[int] = None,
        val_workers: Optional[int] = None,
        checkpoint_dir: Optional[Path] = None,
        log_dir: Optional[Path] = None,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
        seed: Optional[int] = None,
    ) -> "BatDetect2API":
        """Run the module-level ``train`` function with this API's config.

        All arguments are forwarded unchanged to ``batdetect2.train.train``
        together with ``config=self.config``.

        NOTE(review): only ``self.config`` is forwarded; ``self.model`` is
        not passed to the training function and is not reassigned here, so
        whether training affects ``self.model`` cannot be determined from
        this class alone -- confirm against ``batdetect2.train.train``.

        Returns
        -------
        BatDetect2API
            ``self``, to allow call chaining.
        """
        # Inside a method body the bare name ``train`` resolves to the
        # function imported from ``batdetect2.train``, not to this method.
        train(
            train_annotations=train_annotations,
            val_annotations=val_annotations,
            config=self.config,
            train_workers=train_workers,
            val_workers=val_workers,
            checkpoint_dir=checkpoint_dir,
            log_dir=log_dir,
            experiment_name=experiment_name,
            run_name=run_name,
            seed=seed,
        )
        return self

    @staticmethod
    def _build_components(config: BatDetect2Config) -> dict:
        """Build every non-model component from ``config``.

        Returns a dict whose keys match the ``__init__`` keyword arguments
        ``targets``, ``audio_loader``, ``preprocessor``, ``postprocessor``
        and ``evaluator``.
        """
        targets = build_targets(config=config.targets)
        audio_loader = build_audio_loader(config=config.audio)
        preprocessor = build_preprocessor(
            input_samplerate=audio_loader.samplerate,
            config=config.preprocess,
        )
        postprocessor = build_postprocessor(
            preprocessor,
            config=config.postprocess,
        )
        evaluator = build_evaluator(
            config=config.evaluation,
            targets=targets,
        )
        return {
            "targets": targets,
            "audio_loader": audio_loader,
            "preprocessor": preprocessor,
            "postprocessor": postprocessor,
            "evaluator": evaluator,
        }

    @classmethod
    def from_config(cls, config: BatDetect2Config) -> "BatDetect2API":
        """Build all components, including a new model, from ``config``."""
        components = cls._build_components(config)

        # NOTE: Better to have a separate instance of
        # preprocessor and postprocessor as these may be moved
        # to another device. The model's postprocessor is therefore
        # derived from the model's own preprocessor rather than from
        # the instance kept on the API object.
        model_preprocessor = build_preprocessor(
            input_samplerate=components["audio_loader"].samplerate,
            config=config.preprocess,
        )
        model_postprocessor = build_postprocessor(
            model_preprocessor,
            config=config.postprocess,
        )
        model = build_model(
            config=config.model,
            targets=components["targets"],
            preprocessor=model_preprocessor,
            postprocessor=model_postprocessor,
        )
        return cls(config=config, model=model, **components)

    @classmethod
    def from_checkpoint(cls, path: data.PathLike) -> "BatDetect2API":
        """Load model and config from a checkpoint and build the rest.

        The model and configuration are whatever
        ``load_model_from_checkpoint(path)`` returns; the remaining
        components are built from that configuration.
        """
        model, config = load_model_from_checkpoint(path)
        components = cls._build_components(config)
        return cls(config=config, model=model, **components)

View File

@ -46,13 +46,13 @@ def train_command(
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api.base import BatDetect2API
from batdetect2.config import (
BatDetect2Config,
load_full_config,
)
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
from batdetect2.train import train
logger.remove()
if verbose == 0:
@ -62,11 +62,9 @@ def train_command(
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating training process...")
logger.info("Loading training configuration...")
logger.info("Loading configuration...")
conf = (
load_full_config(config, field=config_field)
if config is not None
@ -98,16 +96,19 @@ def train_command(
logger.info("Configuration and data loaded. Starting training...")
train(
if model_path is None:
api = BatDetect2API.from_config(conf)
else:
api = BatDetect2API.from_checkpoint(model_path)
return api.train(
train_annotations=train_annotations,
val_annotations=val_annotations,
config=conf,
model_path=model_path,
train_workers=train_workers,
val_workers=val_workers,
experiment_name=experiment_name,
log_dir=log_dir,
checkpoint_dir=ckpt_dir,
seed=seed,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
)

View File

View File

@ -15,7 +15,7 @@ from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import build_train_loader, build_val_loader
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import TrainingModule, build_training_module
from batdetect2.train.lightning import build_training_module
from batdetect2.train.logging import build_logger
if TYPE_CHECKING:
@ -38,14 +38,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
def train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
evaluator: Optional[Evaluator] = None,
trainer: Optional[Trainer] = None,
targets: Optional["TargetProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
labeller: Optional["ClipLabeller"] = None,
config: Optional["BatDetect2Config"] = None,
model_path: Optional[data.PathLike] = None,
trainer: Optional[Trainer] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = None,
@ -99,19 +97,15 @@ def train(
else None
)
if model_path is not None:
logger.debug("Loading model from: {path}", path=model_path)
module = TrainingModule.load_from_checkpoint(Path(model_path))
else:
module = build_training_module(
config,
t_max=config.train.optimizer.t_max * len(train_dataloader),
)
module = build_training_module(
config,
t_max=config.train.optimizer.t_max * len(train_dataloader),
)
trainer = trainer or build_trainer(
config.train,
targets=targets,
evaluator=evaluator,
evaluator=build_evaluator(config.evaluation, targets=targets),
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
experiment_name=experiment_name,