This commit is contained in:
mbsantiago 2025-09-30 13:56:25 +01:00
parent 981e37c346
commit 2f48c58de1
7 changed files with 23 additions and 80 deletions

View File

@ -8,13 +8,18 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import build_evaluator, evaluate
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets
from batdetect2.train import load_model_from_checkpoint, train
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
load_model_from_checkpoint,
train,
)
from batdetect2.typing import (
AudioLoader,
BatDetect2Prediction,
@ -53,8 +58,8 @@ class BatDetect2API:
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
train_workers: Optional[int] = None,
val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
@ -80,7 +85,7 @@ class BatDetect2API:
self,
test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None,
output_dir: data.PathLike = ".",
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):

View File

@ -1,5 +1,5 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import evaluate
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.tasks import TaskConfig, build_task
@ -11,4 +11,5 @@ __all__ = [
"build_task",
"evaluate",
"load_evaluation_config",
"DEFAULT_EVAL_DIR",
]

View File

@ -21,7 +21,7 @@ if TYPE_CHECKING:
TargetProtocol,
)
DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations"
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def evaluate(
@ -32,7 +32,7 @@ def evaluate(
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
output_dir: data.PathLike = DEFAULT_OUTPUT_DIR,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
):

View File

@ -43,6 +43,9 @@ def average_precision(
y_score,
num_positives: Optional[int] = None,
) -> float:
if num_positives == 0:
return np.nan
precision, recall, _ = compute_precision_recall(
y_true,
y_score,

View File

@ -51,7 +51,7 @@ def run_batch_inference(
)
module = InferenceModule(model)
trainer = Trainer(enable_checkpointing=False)
trainer = Trainer(enable_checkpointing=False, logger=False)
outputs = trainer.predict(module, loader)
return [
clip_prediction

View File

@ -221,20 +221,16 @@ def build_logger(
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
) -> Logger:
"""
Creates a logger instance from a validated Pydantic config object.
"""
logger.opt(lazy=True).debug(
"Building logger with config: \n{}",
lambda: config.to_yaml_string(),
)
logger_type = config.name
logger_type = config.name
if logger_type not in LOGGER_FACTORY:
raise ValueError(f"Unknown logger type: {logger_type}")
creation_func = LOGGER_FACTORY[logger_type]
return creation_func(
config,
log_dir=log_dir,

View File

@ -1,81 +1,19 @@
from batdetect2.train.augmentations import (
AddEchoConfig,
AugmentationsConfig,
MaskFrequencyConfig,
MaskTimeConfig,
RandomAudioSource,
ScaleVolumeConfig,
WarpConfig,
add_echo,
build_augmentations,
mask_frequency,
mask_time,
mix_audio,
scale_volume,
warp_spectrogram,
)
from batdetect2.train.config import (
PLTrainerConfig,
TrainingConfig,
load_train_config,
)
from batdetect2.train.dataset import (
TrainingDataset,
ValidationDataset,
build_train_dataset,
build_train_loader,
build_val_dataset,
build_val_loader,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import (
TrainingModule,
load_model_from_checkpoint,
)
from batdetect2.train.losses import (
ClassificationLossConfig,
DetectionLossConfig,
LossConfig,
LossFunction,
SizeLossConfig,
build_loss,
)
from batdetect2.train.train import build_trainer, train
from batdetect2.train.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train
__all__ = [
"AugmentationsConfig",
"ClassificationLossConfig",
"DetectionLossConfig",
"AddEchoConfig",
"MaskFrequencyConfig",
"LossConfig",
"LossFunction",
"PLTrainerConfig",
"RandomAudioSource",
"SizeLossConfig",
"MaskTimeConfig",
"DEFAULT_CHECKPOINT_DIR",
"TrainingConfig",
"TrainingDataset",
"TrainingModule",
"ValidationDataset",
"ScaleVolumeConfig",
"WarpConfig",
"add_echo",
"build_augmentations",
"build_clip_labeler",
"build_loss",
"build_train_dataset",
"build_train_loader",
"build_trainer",
"build_val_dataset",
"build_val_loader",
"load_label_config",
"load_train_config",
"mask_frequency",
"mask_time",
"mix_audio",
"scale_volume",
"train",
"warp_spectrogram",
"load_model_from_checkpoint",
"load_train_config",
"train",
]