diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 3b31bb7..7d7f35b 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -8,13 +8,18 @@ from soundevent.audio.files import get_audio_files from batdetect2.audio import build_audio_loader from batdetect2.config import BatDetect2Config -from batdetect2.evaluate import build_evaluator, evaluate +from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate from batdetect2.inference import process_file_list, run_batch_inference +from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.models import Model, build_model from batdetect2.postprocess import build_postprocessor, to_raw_predictions from batdetect2.preprocess import build_preprocessor from batdetect2.targets import build_targets -from batdetect2.train import load_model_from_checkpoint, train +from batdetect2.train import ( + DEFAULT_CHECKPOINT_DIR, + load_model_from_checkpoint, + train, +) from batdetect2.typing import ( AudioLoader, BatDetect2Prediction, @@ -53,8 +58,8 @@ class BatDetect2API: val_annotations: Optional[Sequence[data.ClipAnnotation]] = None, train_workers: Optional[int] = None, val_workers: Optional[int] = None, - checkpoint_dir: Optional[Path] = None, - log_dir: Optional[Path] = None, + checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR, + log_dir: Optional[Path] = DEFAULT_LOGS_DIR, experiment_name: Optional[str] = None, run_name: Optional[str] = None, seed: Optional[int] = None, @@ -80,7 +85,7 @@ class BatDetect2API: self, test_annotations: Sequence[data.ClipAnnotation], num_workers: Optional[int] = None, - output_dir: data.PathLike = ".", + output_dir: data.PathLike = DEFAULT_EVAL_DIR, experiment_name: Optional[str] = None, run_name: Optional[str] = None, ): diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index 07fa19e..aefec05 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,5 +1,5 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config -from batdetect2.evaluate.evaluate import evaluate +from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.evaluate.tasks import TaskConfig, build_task @@ -11,4 +11,5 @@ __all__ = [ "build_task", "evaluate", "load_evaluation_config", + "DEFAULT_EVAL_DIR", ] diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index cbc0cd1..7312639 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -21,7 +21,7 @@ if TYPE_CHECKING: TargetProtocol, ) -DEFAULT_OUTPUT_DIR: Path = Path("outputs") / "evaluations" +DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" def evaluate( @@ -32,7 +32,7 @@ def evaluate( preprocessor: Optional["PreprocessorProtocol"] = None, config: Optional["BatDetect2Config"] = None, num_workers: Optional[int] = None, - output_dir: data.PathLike = DEFAULT_OUTPUT_DIR, + output_dir: data.PathLike = DEFAULT_EVAL_DIR, experiment_name: Optional[str] = None, run_name: Optional[str] = None, ): diff --git a/src/batdetect2/evaluate/metrics/common.py b/src/batdetect2/evaluate/metrics/common.py index 7c2925a..0aa632d 100644 --- a/src/batdetect2/evaluate/metrics/common.py +++ b/src/batdetect2/evaluate/metrics/common.py @@ -43,6 +43,9 @@ def average_precision( y_score, num_positives: Optional[int] = None, ) -> float: + if num_positives == 0: + return np.nan + precision, recall, _ = compute_precision_recall( y_true, y_score, diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index b0d878a..06e67b4 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -51,7 +51,7 @@ def run_batch_inference( ) module = InferenceModule(model) - trainer = Trainer(enable_checkpointing=False) + trainer = Trainer(enable_checkpointing=False, logger=False) outputs = trainer.predict(module, loader) return [ clip_prediction diff --git a/src/batdetect2/logging.py b/src/batdetect2/logging.py index eb96d44..67bf11d 100644 --- a/src/batdetect2/logging.py +++ b/src/batdetect2/logging.py @@ -221,20 +221,16 @@ def build_logger( experiment_name: Optional[str] = None, run_name: Optional[str] = None, ) -> Logger: - """ - Creates a logger instance from a validated Pydantic config object. - """ logger.opt(lazy=True).debug( "Building logger with config: \n{}", lambda: config.to_yaml_string(), ) - logger_type = config.name + logger_type = config.name if logger_type not in LOGGER_FACTORY: raise ValueError(f"Unknown logger type: {logger_type}") creation_func = LOGGER_FACTORY[logger_type] - return creation_func( config, log_dir=log_dir, diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py index 5581e56..09d0a88 100644 --- a/src/batdetect2/train/__init__.py +++ b/src/batdetect2/train/__init__.py @@ -1,81 +1,19 @@ -from batdetect2.train.augmentations import ( - AddEchoConfig, - AugmentationsConfig, - MaskFrequencyConfig, - MaskTimeConfig, - RandomAudioSource, - ScaleVolumeConfig, - WarpConfig, - add_echo, - build_augmentations, - mask_frequency, - mask_time, - mix_audio, - scale_volume, - warp_spectrogram, -) from batdetect2.train.config import ( - PLTrainerConfig, TrainingConfig, load_train_config, ) -from batdetect2.train.dataset import ( - TrainingDataset, - ValidationDataset, - build_train_dataset, - build_train_loader, - build_val_dataset, - build_val_loader, -) -from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.lightning import ( TrainingModule, load_model_from_checkpoint, ) -from batdetect2.train.losses import ( - ClassificationLossConfig, - DetectionLossConfig, - LossConfig, - LossFunction, - SizeLossConfig, - build_loss, -) -from batdetect2.train.train import build_trainer, train +from batdetect2.train.train import DEFAULT_CHECKPOINT_DIR, build_trainer, train __all__ = [ - "AugmentationsConfig", - "ClassificationLossConfig", - "DetectionLossConfig", - "AddEchoConfig", - "MaskFrequencyConfig", - "LossConfig", - "LossFunction", - "PLTrainerConfig", - "RandomAudioSource", - "SizeLossConfig", - "MaskTimeConfig", + "DEFAULT_CHECKPOINT_DIR", "TrainingConfig", - "TrainingDataset", "TrainingModule", - "ValidationDataset", - "ScaleVolumeConfig", - "WarpConfig", - "add_echo", - "build_augmentations", - "build_clip_labeler", - "build_loss", - "build_train_dataset", - "build_train_loader", "build_trainer", - "build_val_dataset", - "build_val_loader", - "load_label_config", - "load_train_config", - "mask_frequency", - "mask_time", - "mix_audio", - "scale_volume", - "train", - "warp_spectrogram", "load_model_from_checkpoint", + "load_train_config", + "train", ]