mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Starting new API
This commit is contained in:
parent
bbb96b33a2
commit
957c0735d2
148
src/batdetect2/api/base.py
Normal file
148
src/batdetect2/api/base.py
Normal file
@ -0,0 +1,148 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.evaluate import Evaluator, build_evaluator
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets.targets import build_targets
|
||||
from batdetect2.train import train
|
||||
from batdetect2.train.lightning import load_model_from_checkpoint
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
PostprocessorProtocol,
|
||||
PreprocessorProtocol,
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
|
||||
class BatDetect2API:
|
||||
def __init__(
|
||||
self,
|
||||
config: BatDetect2Config,
|
||||
targets: TargetProtocol,
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
postprocessor: PostprocessorProtocol,
|
||||
evaluator: Evaluator,
|
||||
model: Model,
|
||||
):
|
||||
self.config = config
|
||||
self.targets = targets
|
||||
self.audio_loader = audio_loader
|
||||
self.preprocessor = preprocessor
|
||||
self.postprocessor = postprocessor
|
||||
self.evaluator = evaluator
|
||||
self.model = model
|
||||
|
||||
self.model.eval()
|
||||
|
||||
def train(
|
||||
self,
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||
train_workers: Optional[int] = None,
|
||||
val_workers: Optional[int] = None,
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
config=self.config,
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, config: BatDetect2Config):
|
||||
targets = build_targets(config=config.targets)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
|
||||
preprocessor = build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.preprocess,
|
||||
)
|
||||
|
||||
postprocessor = build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.postprocess,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
config=config.evaluation,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
# NOTE: Better to have a separate instance of
|
||||
# preprocessor and postprocessor as these may be moved
|
||||
# to another device.
|
||||
model = build_model(
|
||||
config=config.model,
|
||||
targets=targets,
|
||||
preprocessor=build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.preprocess,
|
||||
),
|
||||
postprocessor=build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.postprocess,
|
||||
),
|
||||
)
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
evaluator=evaluator,
|
||||
model=model,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, path: data.PathLike):
|
||||
model, config = load_model_from_checkpoint(path)
|
||||
|
||||
targets = build_targets(config=config.targets)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
|
||||
preprocessor = build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.preprocess,
|
||||
)
|
||||
|
||||
postprocessor = build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.postprocess,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
config=config.evaluation,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
return cls(
|
||||
config=config,
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
evaluator=evaluator,
|
||||
model=model,
|
||||
)
|
||||
@ -46,13 +46,13 @@ def train_command(
|
||||
run_name: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
from batdetect2.api.base import BatDetect2API
|
||||
from batdetect2.config import (
|
||||
BatDetect2Config,
|
||||
load_full_config,
|
||||
)
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.targets import load_target_config
|
||||
from batdetect2.train import train
|
||||
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
@ -62,11 +62,9 @@ def train_command(
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
logger.info("Initiating training process...")
|
||||
|
||||
logger.info("Loading training configuration...")
|
||||
|
||||
logger.info("Loading configuration...")
|
||||
conf = (
|
||||
load_full_config(config, field=config_field)
|
||||
if config is not None
|
||||
@ -98,16 +96,19 @@ def train_command(
|
||||
|
||||
logger.info("Configuration and data loaded. Starting training...")
|
||||
|
||||
train(
|
||||
if model_path is None:
|
||||
api = BatDetect2API.from_config(conf)
|
||||
else:
|
||||
api = BatDetect2API.from_checkpoint(model_path)
|
||||
|
||||
return api.train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
config=conf,
|
||||
model_path=model_path,
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
experiment_name=experiment_name,
|
||||
log_dir=log_dir,
|
||||
checkpoint_dir=ckpt_dir,
|
||||
seed=seed,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
)
|
||||
|
||||
0
src/batdetect2/inference/__init__.py
Normal file
0
src/batdetect2/inference/__init__.py
Normal file
@ -15,7 +15,7 @@ from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import build_train_loader, build_val_loader
|
||||
from batdetect2.train.labels import build_clip_labeler
|
||||
from batdetect2.train.lightning import TrainingModule, build_training_module
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
from batdetect2.train.logging import build_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -38,14 +38,12 @@ DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
def train(
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||
evaluator: Optional[Evaluator] = None,
|
||||
trainer: Optional[Trainer] = None,
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
labeller: Optional["ClipLabeller"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
model_path: Optional[data.PathLike] = None,
|
||||
trainer: Optional[Trainer] = None,
|
||||
train_workers: Optional[int] = None,
|
||||
val_workers: Optional[int] = None,
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
@ -99,10 +97,6 @@ def train(
|
||||
else None
|
||||
)
|
||||
|
||||
if model_path is not None:
|
||||
logger.debug("Loading model from: {path}", path=model_path)
|
||||
module = TrainingModule.load_from_checkpoint(Path(model_path))
|
||||
else:
|
||||
module = build_training_module(
|
||||
config,
|
||||
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
||||
@ -111,7 +105,7 @@ def train(
|
||||
trainer = trainer or build_trainer(
|
||||
config.train,
|
||||
targets=targets,
|
||||
evaluator=evaluator,
|
||||
evaluator=build_evaluator(config.evaluation, targets=targets),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user