This commit is contained in:
mbsantiago 2025-09-30 13:56:25 +01:00
parent 981e37c346
commit 2f48c58de1
7 changed files with 23 additions and 80 deletions

View File

@ -8,13 +8,18 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import build_evaluator, evaluate from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.models import Model, build_model from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor, to_raw_predictions from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_targets from batdetect2.targets import build_targets
from batdetect2.train import load_model_from_checkpoint, train from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
load_model_from_checkpoint,
train,
)
from batdetect2.typing import ( from batdetect2.typing import (
AudioLoader, AudioLoader,
BatDetect2Prediction, BatDetect2Prediction,
@ -53,8 +58,8 @@ class BatDetect2API:
val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, val_annotations: Optional[Sequence[data.ClipAnnotation]] = None,
train_workers: Optional[int] = None, train_workers: Optional[int] = None,
val_workers: Optional[int] = None, val_workers: Optional[int] = None,
checkpoint_dir: Optional[Path] = None, checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = None, log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -80,7 +85,7 @@ class BatDetect2API:
self, self,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
output_dir: data.PathLike = ".", output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
): ):

View File

@ -1,5 +1,5 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import evaluate from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.tasks import TaskConfig, build_task from batdetect2.evaluate.tasks import TaskConfig, build_task
@ -11,4 +11,5 @@ __all__ = [
"build_task", "build_task",
"evaluate", "evaluate",
"load_evaluation_config", "load_evaluation_config",
"DEFAULT_EVAL_DIR",
] ]

View File

@ -21,7 +21,7 @@ if TYPE_CHECKING:
TargetProtocol, TargetProtocol,
) )
DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations" DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def evaluate( def evaluate(
@ -32,7 +32,7 @@ def evaluate(
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None, config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
output_dir: data.PathLike = DEFAULT_OUTPUT_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
): ):

View File

@ -43,6 +43,9 @@ def average_precision(
y_score, y_score,
num_positives: Optional[int] = None, num_positives: Optional[int] = None,
) -> float: ) -> float:
if num_positives == 0:
return np.nan
precision, recall, _ = compute_precision_recall( precision, recall, _ = compute_precision_recall(
y_true, y_true,
y_score, y_score,

View File

@ -51,7 +51,7 @@ def run_batch_inference(
) )
module = InferenceModule(model) module = InferenceModule(model)
trainer = Trainer(enable_checkpointing=False) trainer = Trainer(enable_checkpointing=False, logger=False)
outputs = trainer.predict(module, loader) outputs = trainer.predict(module, loader)
return [ return [
clip_prediction clip_prediction

View File

@ -221,20 +221,16 @@ def build_logger(
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
) -> Logger: ) -> Logger:
"""
Creates a logger instance from a validated Pydantic config object.
"""
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
"Building logger with config: \n{}", "Building logger with config: \n{}",
lambda: config.to_yaml_string(), lambda: config.to_yaml_string(),
) )
logger_type = config.name
logger_type = config.name
if logger_type not in LOGGER_FACTORY: if logger_type not in LOGGER_FACTORY:
raise ValueError(f"Unknown logger type: {logger_type}") raise ValueError(f"Unknown logger type: {logger_type}")
creation_func = LOGGER_FACTORY[logger_type] creation_func = LOGGER_FACTORY[logger_type]
return creation_func( return creation_func(
config, config,
log_dir=log_dir, log_dir=log_dir,

View File

@ -1,81 +1,19 @@
from batdetect2.train.augmentations import (
AddEchoConfig,
AugmentationsConfig,
MaskFrequencyConfig,
MaskTimeConfig,
RandomAudioSource,
ScaleVolumeConfig,
WarpConfig,
add_echo,
build_augmentations,
mask_frequency,
mask_time,
mix_audio,
scale_volume,
warp_spectrogram,
)
from batdetect2.train.config import ( from batdetect2.train.config import (
PLTrainerConfig,
TrainingConfig, TrainingConfig,
load_train_config, load_train_config,
) )
from batdetect2.train.dataset import (
TrainingDataset,
ValidationDataset,
build_train_dataset,
build_train_loader,
build_val_dataset,
build_val_loader,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import ( from batdetect2.train.lightning import (
TrainingModule, TrainingModule,
load_model_from_checkpoint, load_model_from_checkpoint,
) )
from batdetect2.train.losses import ( from batdetect2.train.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train
ClassificationLossConfig,
DetectionLossConfig,
LossConfig,
LossFunction,
SizeLossConfig,
build_loss,
)
from batdetect2.train.train import build_trainer, train
__all__ = [ __all__ = [
"AugmentationsConfig", "DEFAULT_CHECKPOINT_DIR",
"ClassificationLossConfig",
"DetectionLossConfig",
"AddEchoConfig",
"MaskFrequencyConfig",
"LossConfig",
"LossFunction",
"PLTrainerConfig",
"RandomAudioSource",
"SizeLossConfig",
"MaskTimeConfig",
"TrainingConfig", "TrainingConfig",
"TrainingDataset",
"TrainingModule", "TrainingModule",
"ValidationDataset",
"ScaleVolumeConfig",
"WarpConfig",
"add_echo",
"build_augmentations",
"build_clip_labeler",
"build_loss",
"build_train_dataset",
"build_train_loader",
"build_trainer", "build_trainer",
"build_val_dataset",
"build_val_loader",
"load_label_config",
"load_train_config",
"mask_frequency",
"mask_time",
"mix_audio",
"scale_volume",
"train",
"warp_spectrogram",
"load_model_from_checkpoint", "load_model_from_checkpoint",
"load_train_config",
"train",
] ]