diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index ca14288..e395786 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -15,7 +15,7 @@ from batdetect2.data import ( load_dataset_from_config, ) from batdetect2.data.datasets import Dataset -from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate +from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.logging import DEFAULT_LOGS_DIR @@ -81,8 +81,8 @@ class BatDetect2API: self, train_annotations: Sequence[data.ClipAnnotation], val_annotations: Sequence[data.ClipAnnotation] | None = None, - train_workers: int | None = None, - val_workers: int | None = None, + train_workers: int = 0, + val_workers: int = 0, checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR, log_dir: Path | None = DEFAULT_LOGS_DIR, experiment_name: str | None = None, @@ -113,19 +113,21 @@ class BatDetect2API: def evaluate( self, test_annotations: Sequence[data.ClipAnnotation], - num_workers: int | None = None, + num_workers: int = 0, output_dir: data.PathLike = DEFAULT_EVAL_DIR, experiment_name: str | None = None, run_name: str | None = None, save_predictions: bool = True, ) -> tuple[dict[str, float], list[list[Detection]]]: - return evaluate( + return run_evaluate( self.model, test_annotations, targets=self.targets, audio_loader=self.audio_loader, preprocessor=self.preprocessor, - config=self.config, + audio_config=self.config.audio, + evaluation_config=self.config.evaluation, + output_config=self.config.outputs, num_workers=num_workers, output_dir=output_dir, experiment_name=experiment_name, @@ -235,7 +237,7 @@ class BatDetect2API: def process_files( self, audio_files: Sequence[data.PathLike], - num_workers: int | None = None, + num_workers: int = 0, ) -> list[ClipDetections]: return process_file_list( self.model, @@ -251,7 +253,7 @@ class BatDetect2API: self, clips: Sequence[data.Clip], batch_size: int | None = None, - num_workers: int | None = None, + num_workers: int = 0, ) -> list[ClipDetections]: return run_batch_inference( self.model, diff --git a/src/batdetect2/evaluate/__init__.py b/src/batdetect2/evaluate/__init__.py index aefec05..c35d851 100644 --- a/src/batdetect2/evaluate/__init__.py +++ b/src/batdetect2/evaluate/__init__.py @@ -1,5 +1,5 @@ from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config -from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate +from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.evaluate.tasks import TaskConfig, build_task @@ -9,7 +9,7 @@ __all__ = [ "TaskConfig", "build_evaluator", "build_task", - "evaluate", + "run_evaluate", "load_evaluation_config", "DEFAULT_EVAL_DIR", ] diff --git a/src/batdetect2/evaluate/dataset.py b/src/batdetect2/evaluate/dataset.py index ebc7431..5fe43eb 100644 --- a/src/batdetect2/evaluate/dataset.py +++ b/src/batdetect2/evaluate/dataset.py @@ -67,7 +67,6 @@ class TestDataset(Dataset[TestExample]): class TestLoaderConfig(BaseConfig): - num_workers: int = 0 clipping_strategy: ClipConfig = Field( default_factory=lambda: PaddedClipConfig() ) @@ -78,7 +77,7 @@ def build_test_loader( audio_loader: AudioLoader | None = None, preprocessor: PreprocessorProtocol | None = None, config: TestLoaderConfig | None = None, - num_workers: int | None = None, + num_workers: int = 0, ) -> DataLoader[TestExample]: logger.info("Building test data loader...") config = config or TestLoaderConfig() @@ -94,7 +93,6 @@ def build_test_loader( config=config, ) - num_workers = num_workers or config.num_workers return DataLoader( test_dataset, batch_size=1, diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 99d6219..92cea1e 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -1,46 +1,47 @@ from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple +from typing import Sequence from lightning import Trainer from soundevent import data -from batdetect2.audio import build_audio_loader +from batdetect2.audio import AudioConfig, build_audio_loader from batdetect2.audio.types import AudioLoader +from batdetect2.evaluate import EvaluationConfig from batdetect2.evaluate.dataset import build_test_loader from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.logging import build_logger from batdetect2.models import Model -from batdetect2.outputs import build_output_transform +from batdetect2.outputs import OutputsConfig, build_output_transform from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.postprocess.types import Detection from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.targets.types import TargetProtocol -if TYPE_CHECKING: - from batdetect2.config import BatDetect2Config - DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" -def evaluate( +def run_evaluate( model: Model, test_annotations: Sequence[data.ClipAnnotation], - targets: Optional["TargetProtocol"] = None, - audio_loader: Optional["AudioLoader"] = None, - preprocessor: Optional["PreprocessorProtocol"] = None, - config: Optional["BatDetect2Config"] = None, - formatter: Optional["OutputFormatterProtocol"] = None, - num_workers: int | None = None, + targets: TargetProtocol | None = None, + audio_loader: AudioLoader | None = None, + preprocessor: PreprocessorProtocol | None = None, + audio_config: AudioConfig | None = None, + evaluation_config: EvaluationConfig | None = None, + output_config: OutputsConfig | None = None, + formatter: OutputFormatterProtocol | None = None, + num_workers: int = 0, output_dir: data.PathLike = DEFAULT_EVAL_DIR, experiment_name: str | None = None, run_name: str | None = None, -) -> Tuple[Dict[str, float], List[List[Detection]]]: - from batdetect2.config import BatDetect2Config +) -> tuple[dict[str, float], list[list[Detection]]]: - config = config or BatDetect2Config() + audio_config = audio_config or AudioConfig() + evaluation_config = evaluation_config or EvaluationConfig() + output_config = output_config or OutputsConfig() - audio_loader = audio_loader or build_audio_loader(config=config.audio) + audio_loader = audio_loader or build_audio_loader(config=audio_config) preprocessor = preprocessor or model.preprocessor targets = targets or model.targets @@ -52,15 +53,15 @@ def evaluate( num_workers=num_workers, ) - evaluator = build_evaluator(config=config.evaluation, targets=targets) + evaluator = build_evaluator(config=evaluation_config, targets=targets) logger = build_logger( - config.evaluation.logger, + evaluation_config.logger, log_dir=Path(output_dir), experiment_name=experiment_name, run_name=run_name, ) - output_transform = build_output_transform(config=config.outputs.transform) + output_transform = build_output_transform(config=output_config.transform) module = EvaluationModule( model, evaluator, diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index 987bb74..da56712 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -28,7 +28,7 @@ def run_batch_inference( preprocessor: Optional["PreprocessorProtocol"] = None, config: Optional["BatDetect2Config"] = None, output_transform: Optional[OutputTransformProtocol] = None, - num_workers: int | None = None, + num_workers: int = 1, batch_size: int | None = None, ) -> List[ClipDetections]: from batdetect2.config import BatDetect2Config @@ -75,7 +75,7 @@ def process_file_list( targets: Optional["TargetProtocol"] = None, audio_loader: Optional["AudioLoader"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, - num_workers: int | None = None, + num_workers: int = 0, ) -> List[ClipDetections]: clip_config = config.inference.clipping clips = get_clips_from_files( diff --git a/src/batdetect2/inference/dataset.py b/src/batdetect2/inference/dataset.py index 0b1aa50..aaeae7f 100644 --- a/src/batdetect2/inference/dataset.py +++ b/src/batdetect2/inference/dataset.py @@ -61,7 +61,6 @@ class InferenceDataset(Dataset[DatasetItem]): class InferenceLoaderConfig(BaseConfig): - num_workers: int = 0 batch_size: int = 8 @@ -70,7 +69,7 @@ def build_inference_loader( audio_loader: AudioLoader | None = None, preprocessor: PreprocessorProtocol | None = None, config: InferenceLoaderConfig | None = None, - num_workers: int | None = None, + num_workers: int = 0, batch_size: int | None = None, ) -> DataLoader[DatasetItem]: logger.info("Building inference data loader...") @@ -84,12 +83,11 @@ def build_inference_loader( batch_size = batch_size or config.batch_size - num_workers = num_workers or config.num_workers return DataLoader( inference_dataset, batch_size=batch_size, shuffle=False, - num_workers=config.num_workers, + num_workers=num_workers, collate_fn=_collate_fn, ) diff --git a/src/batdetect2/train/dataset.py b/src/batdetect2/train/dataset.py index 3e22b8e..e9e1883 100644 --- a/src/batdetect2/train/dataset.py +++ b/src/batdetect2/train/dataset.py @@ -143,8 +143,6 @@ class ValidationDataset(Dataset): class TrainLoaderConfig(BaseConfig): - num_workers: int = 0 - batch_size: int = 8 shuffle: bool = False @@ -164,7 +162,7 @@ def build_train_loader( labeller: ClipLabeller | None = None, preprocessor: PreprocessorProtocol | None = None, config: TrainLoaderConfig | None = None, - num_workers: int | None = None, + num_workers: int = 0, ) -> DataLoader: config = config or TrainLoaderConfig() @@ -182,7 +180,6 @@ def build_train_loader( config=config, ) - num_workers = num_workers or config.num_workers return DataLoader( train_dataset, batch_size=config.batch_size, @@ -193,8 +190,6 @@ def build_train_loader( class ValLoaderConfig(BaseConfig): - num_workers: int = 0 - clipping_strategy: ClipConfig = Field( default_factory=lambda: PaddedClipConfig() ) @@ -206,7 +201,7 @@ def build_val_loader( labeller: ClipLabeller | None = None, preprocessor: PreprocessorProtocol | None = None, config: ValLoaderConfig | None = None, - num_workers: int | None = None, + num_workers: int = 0, ): logger.info("Building validation data loader...") config = config or ValLoaderConfig() @@ -223,7 +218,6 @@ def build_val_loader( config=config, ) - num_workers = num_workers or config.num_workers return DataLoader( val_dataset, batch_size=1, diff --git a/src/batdetect2/train/labels.py b/src/batdetect2/train/labels.py index 96e8d44..ce055d3 100644 --- a/src/batdetect2/train/labels.py +++ b/src/batdetect2/train/labels.py @@ -12,7 +12,7 @@ import torch from loguru import logger from soundevent import data -from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.targets import build_targets, iterate_encoded_sound_events from batdetect2.targets.types import TargetProtocol @@ -22,7 +22,6 @@ __all__ = [ "LabelConfig", "build_clip_labeler", "generate_heatmaps", - "load_label_config", ] @@ -150,31 +149,3 @@ def generate_heatmaps( classes=class_heatmap, size=size_heatmap, ) - - -def load_label_config( - path: data.PathLike, field: str | None = None -) -> LabelConfig: - """Load the heatmap label generation configuration from a file. - - Parameters - ---------- - path : data.PathLike - Path to the configuration file (e.g., YAML or JSON). - field : str, optional - If the label configuration is nested under a specific key in the - file, specify the key here. Defaults to None. - - Returns - ------- - LabelConfig - The loaded and validated label configuration object. - - Raises - ------ - FileNotFoundError - If the config file path does not exist. - pydantic.ValidationError - If the config file structure does not match the LabelConfig schema. - """ - return load_config(path, schema=LabelConfig, field=field) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 3f61654..7b11c1b 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -41,8 +41,8 @@ def run_train( model_config: Optional[ModelConfig] = None, train_config: Optional[TrainingConfig] = None, trainer: Trainer | None = None, - train_workers: int | None = None, - val_workers: int | None = None, + train_workers: int = 0, + val_workers: int = 0, checkpoint_dir: Path | None = None, log_dir: Path | None = None, experiment_name: str | None = None,