mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
api_v2
This commit is contained in:
parent
981e37c346
commit
2f48c58de1
@ -8,13 +8,18 @@ from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.evaluate import build_evaluator, evaluate
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
from batdetect2.models import Model, build_model
|
||||
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.train import load_model_from_checkpoint, train
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
load_model_from_checkpoint,
|
||||
train,
|
||||
)
|
||||
from batdetect2.typing import (
|
||||
AudioLoader,
|
||||
BatDetect2Prediction,
|
||||
@ -53,8 +58,8 @@ class BatDetect2API:
|
||||
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
|
||||
train_workers: Optional[int] = None,
|
||||
val_workers: Optional[int] = None,
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
||||
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
@ -80,7 +85,7 @@ class BatDetect2API:
|
||||
self,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
num_workers: Optional[int] = None,
|
||||
output_dir: data.PathLike = ".",
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
):
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||
from batdetect2.evaluate.evaluate import evaluate
|
||||
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
|
||||
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
||||
|
||||
@ -11,4 +11,5 @@ __all__ = [
|
||||
"build_task",
|
||||
"evaluate",
|
||||
"load_evaluation_config",
|
||||
"DEFAULT_EVAL_DIR",
|
||||
]
|
||||
|
||||
@ -21,7 +21,7 @@ if TYPE_CHECKING:
|
||||
TargetProtocol,
|
||||
)
|
||||
|
||||
DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"
|
||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
|
||||
def evaluate(
|
||||
@ -32,7 +32,7 @@ def evaluate(
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
):
|
||||
|
||||
@ -43,6 +43,9 @@ def average_precision(
|
||||
y_score,
|
||||
num_positives: Optional[int] = None,
|
||||
) -> float:
|
||||
if num_positives == 0:
|
||||
return np.nan
|
||||
|
||||
precision, recall, _ = compute_precision_recall(
|
||||
y_true,
|
||||
y_score,
|
||||
|
||||
@ -51,7 +51,7 @@ def run_batch_inference(
|
||||
)
|
||||
|
||||
module = InferenceModule(model)
|
||||
trainer = Trainer(enable_checkpointing=False)
|
||||
trainer = Trainer(enable_checkpointing=False, logger=False)
|
||||
outputs = trainer.predict(module, loader)
|
||||
return [
|
||||
clip_prediction
|
||||
|
||||
@ -221,20 +221,16 @@ def build_logger(
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
) -> Logger:
|
||||
"""
|
||||
Creates a logger instance from a validated Pydantic config object.
|
||||
"""
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building logger with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
logger_type = config.name
|
||||
|
||||
logger_type = config.name
|
||||
if logger_type not in LOGGER_FACTORY:
|
||||
raise ValueError(f"Unknown logger type: {logger_type}")
|
||||
|
||||
creation_func = LOGGER_FACTORY[logger_type]
|
||||
|
||||
return creation_func(
|
||||
config,
|
||||
log_dir=log_dir,
|
||||
|
||||
@ -1,81 +1,19 @@
|
||||
from batdetect2.train.augmentations import (
|
||||
AddEchoConfig,
|
||||
AugmentationsConfig,
|
||||
MaskFrequencyConfig,
|
||||
MaskTimeConfig,
|
||||
RandomAudioSource,
|
||||
ScaleVolumeConfig,
|
||||
WarpConfig,
|
||||
add_echo,
|
||||
build_augmentations,
|
||||
mask_frequency,
|
||||
mask_time,
|
||||
mix_audio,
|
||||
scale_volume,
|
||||
warp_spectrogram,
|
||||
)
|
||||
from batdetect2.train.config import (
|
||||
PLTrainerConfig,
|
||||
TrainingConfig,
|
||||
load_train_config,
|
||||
)
|
||||
from batdetect2.train.dataset import (
|
||||
TrainingDataset,
|
||||
ValidationDataset,
|
||||
build_train_dataset,
|
||||
build_train_loader,
|
||||
build_val_dataset,
|
||||
build_val_loader,
|
||||
)
|
||||
from batdetect2.train.labels import build_clip_labeler, load_label_config
|
||||
from batdetect2.train.lightning import (
|
||||
TrainingModule,
|
||||
load_model_from_checkpoint,
|
||||
)
|
||||
from batdetect2.train.losses import (
|
||||
ClassificationLossConfig,
|
||||
DetectionLossConfig,
|
||||
LossConfig,
|
||||
LossFunction,
|
||||
SizeLossConfig,
|
||||
build_loss,
|
||||
)
|
||||
from batdetect2.train.train import build_trainer, train
|
||||
from batdetect2.train.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train
|
||||
|
||||
__all__ = [
|
||||
"AugmentationsConfig",
|
||||
"ClassificationLossConfig",
|
||||
"DetectionLossConfig",
|
||||
"AddEchoConfig",
|
||||
"MaskFrequencyConfig",
|
||||
"LossConfig",
|
||||
"LossFunction",
|
||||
"PLTrainerConfig",
|
||||
"RandomAudioSource",
|
||||
"SizeLossConfig",
|
||||
"MaskTimeConfig",
|
||||
"DEFAULT_CHECKPOINT_DIR",
|
||||
"TrainingConfig",
|
||||
"TrainingDataset",
|
||||
"TrainingModule",
|
||||
"ValidationDataset",
|
||||
"ScaleVolumeConfig",
|
||||
"WarpConfig",
|
||||
"add_echo",
|
||||
"build_augmentations",
|
||||
"build_clip_labeler",
|
||||
"build_loss",
|
||||
"build_train_dataset",
|
||||
"build_train_loader",
|
||||
"build_trainer",
|
||||
"build_val_dataset",
|
||||
"build_val_loader",
|
||||
"load_label_config",
|
||||
"load_train_config",
|
||||
"mask_frequency",
|
||||
"mask_time",
|
||||
"mix_audio",
|
||||
"scale_volume",
|
||||
"train",
|
||||
"warp_spectrogram",
|
||||
"load_model_from_checkpoint",
|
||||
"load_train_config",
|
||||
"train",
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user