mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
api_v2
This commit is contained in:
parent
981e37c346
commit
2f48c58de1
@ -8,13 +8,18 @@ from soundevent.audio.files import get_audio_files
|
|||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.evaluate import build_evaluator, evaluate
|
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||||
from batdetect2.inference import process_file_list, run_batch_inference
|
from batdetect2.inference import process_file_list, run_batch_inference
|
||||||
|
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model, build_model
|
||||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
from batdetect2.train import load_model_from_checkpoint, train
|
from batdetect2.train import (
|
||||||
|
DEFAULT_CHECKPOINT_DIR,
|
||||||
|
load_model_from_checkpoint,
|
||||||
|
train,
|
||||||
|
)
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
AudioLoader,
|
AudioLoader,
|
||||||
BatDetect2Prediction,
|
BatDetect2Prediction,
|
||||||
@ -53,8 +58,8 @@ class BatDetect2API:
|
|||||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||||
train_workers: Optional[int] = None,
|
train_workers: Optional[int] = None,
|
||||||
val_workers: Optional[int] = None,
|
val_workers: Optional[int] = None,
|
||||||
checkpoint_dir: Optional[Path] = None,
|
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
@ -80,7 +85,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
output_dir: data.PathLike = ".",
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||||
from batdetect2.evaluate.evaluate import evaluate
|
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
|
||||||
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||||
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
||||||
|
|
||||||
@ -11,4 +11,5 @@ __all__ = [
|
|||||||
"build_task",
|
"build_task",
|
||||||
"evaluate",
|
"evaluate",
|
||||||
"load_evaluation_config",
|
"load_evaluation_config",
|
||||||
|
"DEFAULT_EVAL_DIR",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
|||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
|
||||||
DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"
|
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
def evaluate(
|
||||||
@ -32,7 +32,7 @@ def evaluate(
|
|||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -43,6 +43,9 @@ def average_precision(
|
|||||||
y_score,
|
y_score,
|
||||||
num_positives: Optional[int] = None,
|
num_positives: Optional[int] = None,
|
||||||
) -> float:
|
) -> float:
|
||||||
|
if num_positives == 0:
|
||||||
|
return np.nan
|
||||||
|
|
||||||
precision, recall, _ = compute_precision_recall(
|
precision, recall, _ = compute_precision_recall(
|
||||||
y_true,
|
y_true,
|
||||||
y_score,
|
y_score,
|
||||||
|
|||||||
@ -51,7 +51,7 @@ def run_batch_inference(
|
|||||||
)
|
)
|
||||||
|
|
||||||
module = InferenceModule(model)
|
module = InferenceModule(model)
|
||||||
trainer = Trainer(enable_checkpointing=False)
|
trainer = Trainer(enable_checkpointing=False, logger=False)
|
||||||
outputs = trainer.predict(module, loader)
|
outputs = trainer.predict(module, loader)
|
||||||
return [
|
return [
|
||||||
clip_prediction
|
clip_prediction
|
||||||
|
|||||||
@ -221,20 +221,16 @@ def build_logger(
|
|||||||
experiment_name: Optional[str] = None,
|
experiment_name: Optional[str] = None,
|
||||||
run_name: Optional[str] = None,
|
run_name: Optional[str] = None,
|
||||||
) -> Logger:
|
) -> Logger:
|
||||||
"""
|
|
||||||
Creates a logger instance from a validated Pydantic config object.
|
|
||||||
"""
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building logger with config: \n{}",
|
"Building logger with config: \n{}",
|
||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(),
|
||||||
)
|
)
|
||||||
logger_type = config.name
|
|
||||||
|
|
||||||
|
logger_type = config.name
|
||||||
if logger_type not in LOGGER_FACTORY:
|
if logger_type not in LOGGER_FACTORY:
|
||||||
raise ValueError(f"Unknown logger type: {logger_type}")
|
raise ValueError(f"Unknown logger type: {logger_type}")
|
||||||
|
|
||||||
creation_func = LOGGER_FACTORY[logger_type]
|
creation_func = LOGGER_FACTORY[logger_type]
|
||||||
|
|
||||||
return creation_func(
|
return creation_func(
|
||||||
config,
|
config,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
|
|||||||
@ -1,81 +1,19 @@
|
|||||||
from batdetect2.train.augmentations import (
|
|
||||||
AddEchoConfig,
|
|
||||||
AugmentationsConfig,
|
|
||||||
MaskFrequencyConfig,
|
|
||||||
MaskTimeConfig,
|
|
||||||
RandomAudioSource,
|
|
||||||
ScaleVolumeConfig,
|
|
||||||
WarpConfig,
|
|
||||||
add_echo,
|
|
||||||
build_augmentations,
|
|
||||||
mask_frequency,
|
|
||||||
mask_time,
|
|
||||||
mix_audio,
|
|
||||||
scale_volume,
|
|
||||||
warp_spectrogram,
|
|
||||||
)
|
|
||||||
from batdetect2.train.config import (
|
from batdetect2.train.config import (
|
||||||
PLTrainerConfig,
|
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
load_train_config,
|
load_train_config,
|
||||||
)
|
)
|
||||||
from batdetect2.train.dataset import (
|
|
||||||
TrainingDataset,
|
|
||||||
ValidationDataset,
|
|
||||||
build_train_dataset,
|
|
||||||
build_train_loader,
|
|
||||||
build_val_dataset,
|
|
||||||
build_val_loader,
|
|
||||||
)
|
|
||||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
|
||||||
from batdetect2.train.lightning import (
|
from batdetect2.train.lightning import (
|
||||||
TrainingModule,
|
TrainingModule,
|
||||||
load_model_from_checkpoint,
|
load_model_from_checkpoint,
|
||||||
)
|
)
|
||||||
from batdetect2.train.losses import (
|
from batdetect2.train.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train
|
||||||
ClassificationLossConfig,
|
|
||||||
DetectionLossConfig,
|
|
||||||
LossConfig,
|
|
||||||
LossFunction,
|
|
||||||
SizeLossConfig,
|
|
||||||
build_loss,
|
|
||||||
)
|
|
||||||
from batdetect2.train.train import build_trainer, train
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AugmentationsConfig",
|
"DEFAULT_CHECKPOINT_DIR",
|
||||||
"ClassificationLossConfig",
|
|
||||||
"DetectionLossConfig",
|
|
||||||
"AddEchoConfig",
|
|
||||||
"MaskFrequencyConfig",
|
|
||||||
"LossConfig",
|
|
||||||
"LossFunction",
|
|
||||||
"PLTrainerConfig",
|
|
||||||
"RandomAudioSource",
|
|
||||||
"SizeLossConfig",
|
|
||||||
"MaskTimeConfig",
|
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"TrainingDataset",
|
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
"ValidationDataset",
|
|
||||||
"ScaleVolumeConfig",
|
|
||||||
"WarpConfig",
|
|
||||||
"add_echo",
|
|
||||||
"build_augmentations",
|
|
||||||
"build_clip_labeler",
|
|
||||||
"build_loss",
|
|
||||||
"build_train_dataset",
|
|
||||||
"build_train_loader",
|
|
||||||
"build_trainer",
|
"build_trainer",
|
||||||
"build_val_dataset",
|
|
||||||
"build_val_loader",
|
|
||||||
"load_label_config",
|
|
||||||
"load_train_config",
|
|
||||||
"mask_frequency",
|
|
||||||
"mask_time",
|
|
||||||
"mix_audio",
|
|
||||||
"scale_volume",
|
|
||||||
"train",
|
|
||||||
"warp_spectrogram",
|
|
||||||
"load_model_from_checkpoint",
|
"load_model_from_checkpoint",
|
||||||
|
"load_train_config",
|
||||||
|
"train",
|
||||||
]
|
]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user