Remove num_workers from config

This commit is contained in:
mbsantiago 2026-03-18 00:22:59 +00:00
parent 573d8e38d6
commit d56b9f02ae
9 changed files with 43 additions and 79 deletions

View File

@ -15,7 +15,7 @@ from batdetect2.data import (
load_dataset_from_config, load_dataset_from_config,
) )
from batdetect2.data.datasets import Dataset from batdetect2.data.datasets import Dataset
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate
from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.logging import DEFAULT_LOGS_DIR
@ -81,8 +81,8 @@ class BatDetect2API:
self, self,
train_annotations: Sequence[data.ClipAnnotation], train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None, val_annotations: Sequence[data.ClipAnnotation] | None = None,
train_workers: int | None = None, train_workers: int = 0,
val_workers: int | None = None, val_workers: int = 0,
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR, checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Path | None = DEFAULT_LOGS_DIR, log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: str | None = None, experiment_name: str | None = None,
@ -113,19 +113,21 @@ class BatDetect2API:
def evaluate( def evaluate(
self, self,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
num_workers: int | None = None, num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
save_predictions: bool = True, save_predictions: bool = True,
) -> tuple[dict[str, float], list[list[Detection]]]: ) -> tuple[dict[str, float], list[list[Detection]]]:
return evaluate( return run_evaluate(
self.model, self.model,
test_annotations, test_annotations,
targets=self.targets, targets=self.targets,
audio_loader=self.audio_loader, audio_loader=self.audio_loader,
preprocessor=self.preprocessor, preprocessor=self.preprocessor,
config=self.config, audio_config=self.config.audio,
evaluation_config=self.config.evaluation,
output_config=self.config.outputs,
num_workers=num_workers, num_workers=num_workers,
output_dir=output_dir, output_dir=output_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
@ -235,7 +237,7 @@ class BatDetect2API:
def process_files( def process_files(
self, self,
audio_files: Sequence[data.PathLike], audio_files: Sequence[data.PathLike],
num_workers: int | None = None, num_workers: int = 0,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
return process_file_list( return process_file_list(
self.model, self.model,
@ -251,7 +253,7 @@ class BatDetect2API:
self, self,
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
batch_size: int | None = None, batch_size: int | None = None,
num_workers: int | None = None, num_workers: int = 0,
) -> list[ClipDetections]: ) -> list[ClipDetections]:
return run_batch_inference( return run_batch_inference(
self.model, self.model,

View File

@ -1,5 +1,5 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.tasks import TaskConfig, build_task from batdetect2.evaluate.tasks import TaskConfig, build_task
@ -9,7 +9,7 @@ __all__ = [
"TaskConfig", "TaskConfig",
"build_evaluator", "build_evaluator",
"build_task", "build_task",
"evaluate", "run_evaluate",
"load_evaluation_config", "load_evaluation_config",
"DEFAULT_EVAL_DIR", "DEFAULT_EVAL_DIR",
] ]

View File

@ -67,7 +67,6 @@ class TestDataset(Dataset[TestExample]):
class TestLoaderConfig(BaseConfig): class TestLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field( clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig() default_factory=lambda: PaddedClipConfig()
) )
@ -78,7 +77,7 @@ def build_test_loader(
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
config: TestLoaderConfig | None = None, config: TestLoaderConfig | None = None,
num_workers: int | None = None, num_workers: int = 0,
) -> DataLoader[TestExample]: ) -> DataLoader[TestExample]:
logger.info("Building test data loader...") logger.info("Building test data loader...")
config = config or TestLoaderConfig() config = config or TestLoaderConfig()
@ -94,7 +93,6 @@ def build_test_loader(
config=config, config=config,
) )
num_workers = num_workers or config.num_workers
return DataLoader( return DataLoader(
test_dataset, test_dataset,
batch_size=1, batch_size=1,

View File

@ -1,46 +1,47 @@
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple from typing import Sequence
from lightning import Trainer from lightning import Trainer
from soundevent import data from soundevent import data
from batdetect2.audio import build_audio_loader from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import EvaluationConfig
from batdetect2.evaluate.dataset import build_test_loader from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger from batdetect2.logging import build_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import build_output_transform from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import Detection from batdetect2.postprocess.types import Detection
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def evaluate( def run_evaluate(
model: Model, model: Model,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
targets: Optional["TargetProtocol"] = None, targets: TargetProtocol | None = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: AudioLoader | None = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: PreprocessorProtocol | None = None,
config: Optional["BatDetect2Config"] = None, audio_config: AudioConfig | None = None,
formatter: Optional["OutputFormatterProtocol"] = None, evaluation_config: EvaluationConfig | None = None,
num_workers: int | None = None, output_config: OutputsConfig | None = None,
formatter: OutputFormatterProtocol | None = None,
num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: str | None = None, experiment_name: str | None = None,
run_name: str | None = None, run_name: str | None = None,
) -> Tuple[Dict[str, float], List[List[Detection]]]: ) -> tuple[dict[str, float], list[list[Detection]]]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config() audio_config = audio_config or AudioConfig()
evaluation_config = evaluation_config or EvaluationConfig()
output_config = output_config or OutputsConfig()
audio_loader = audio_loader or build_audio_loader(config=config.audio) audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets targets = targets or model.targets
@ -52,15 +53,15 @@ def evaluate(
num_workers=num_workers, num_workers=num_workers,
) )
evaluator = build_evaluator(config=config.evaluation, targets=targets) evaluator = build_evaluator(config=evaluation_config, targets=targets)
logger = build_logger( logger = build_logger(
config.evaluation.logger, evaluation_config.logger,
log_dir=Path(output_dir), log_dir=Path(output_dir),
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
) )
output_transform = build_output_transform(config=config.outputs.transform) output_transform = build_output_transform(config=output_config.transform)
module = EvaluationModule( module = EvaluationModule(
model, model,
evaluator, evaluator,

View File

@ -28,7 +28,7 @@ def run_batch_inference(
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None, config: Optional["BatDetect2Config"] = None,
output_transform: Optional[OutputTransformProtocol] = None, output_transform: Optional[OutputTransformProtocol] = None,
num_workers: int | None = None, num_workers: int = 1,
batch_size: int | None = None, batch_size: int | None = None,
) -> List[ClipDetections]: ) -> List[ClipDetections]:
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
@ -75,7 +75,7 @@ def process_file_list(
targets: Optional["TargetProtocol"] = None, targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None, audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: int | None = None, num_workers: int = 0,
) -> List[ClipDetections]: ) -> List[ClipDetections]:
clip_config = config.inference.clipping clip_config = config.inference.clipping
clips = get_clips_from_files( clips = get_clips_from_files(

View File

@ -61,7 +61,6 @@ class InferenceDataset(Dataset[DatasetItem]):
class InferenceLoaderConfig(BaseConfig): class InferenceLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8 batch_size: int = 8
@ -70,7 +69,7 @@ def build_inference_loader(
audio_loader: AudioLoader | None = None, audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
config: InferenceLoaderConfig | None = None, config: InferenceLoaderConfig | None = None,
num_workers: int | None = None, num_workers: int = 0,
batch_size: int | None = None, batch_size: int | None = None,
) -> DataLoader[DatasetItem]: ) -> DataLoader[DatasetItem]:
logger.info("Building inference data loader...") logger.info("Building inference data loader...")
@ -84,12 +83,11 @@ def build_inference_loader(
batch_size = batch_size or config.batch_size batch_size = batch_size or config.batch_size
num_workers = num_workers or config.num_workers
return DataLoader( return DataLoader(
inference_dataset, inference_dataset,
batch_size=batch_size, batch_size=batch_size,
shuffle=False, shuffle=False,
num_workers=config.num_workers, num_workers=num_workers,
collate_fn=_collate_fn, collate_fn=_collate_fn,
) )

View File

@ -143,8 +143,6 @@ class ValidationDataset(Dataset):
class TrainLoaderConfig(BaseConfig): class TrainLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8 batch_size: int = 8
shuffle: bool = False shuffle: bool = False
@ -164,7 +162,7 @@ def build_train_loader(
labeller: ClipLabeller | None = None, labeller: ClipLabeller | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
config: TrainLoaderConfig | None = None, config: TrainLoaderConfig | None = None,
num_workers: int | None = None, num_workers: int = 0,
) -> DataLoader: ) -> DataLoader:
config = config or TrainLoaderConfig() config = config or TrainLoaderConfig()
@ -182,7 +180,6 @@ def build_train_loader(
config=config, config=config,
) )
num_workers = num_workers or config.num_workers
return DataLoader( return DataLoader(
train_dataset, train_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
@ -193,8 +190,6 @@ def build_train_loader(
class ValLoaderConfig(BaseConfig): class ValLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field( clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig() default_factory=lambda: PaddedClipConfig()
) )
@ -206,7 +201,7 @@ def build_val_loader(
labeller: ClipLabeller | None = None, labeller: ClipLabeller | None = None,
preprocessor: PreprocessorProtocol | None = None, preprocessor: PreprocessorProtocol | None = None,
config: ValLoaderConfig | None = None, config: ValLoaderConfig | None = None,
num_workers: int | None = None, num_workers: int = 0,
): ):
logger.info("Building validation data loader...") logger.info("Building validation data loader...")
config = config or ValLoaderConfig() config = config or ValLoaderConfig()
@ -223,7 +218,6 @@ def build_val_loader(
config=config, config=config,
) )
num_workers = num_workers or config.num_workers
return DataLoader( return DataLoader(
val_dataset, val_dataset,
batch_size=1, batch_size=1,

View File

@ -12,7 +12,7 @@ import torch
from loguru import logger from loguru import logger
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events from batdetect2.targets import build_targets, iterate_encoded_sound_events
from batdetect2.targets.types import TargetProtocol from batdetect2.targets.types import TargetProtocol
@ -22,7 +22,6 @@ __all__ = [
"LabelConfig", "LabelConfig",
"build_clip_labeler", "build_clip_labeler",
"generate_heatmaps", "generate_heatmaps",
"load_label_config",
] ]
@ -150,31 +149,3 @@ def generate_heatmaps(
classes=class_heatmap, classes=class_heatmap,
size=size_heatmap, size=size_heatmap,
) )
def load_label_config(
    path: data.PathLike, field: str | None = None
) -> LabelConfig:
    """Read a heatmap label generation configuration from a file.

    Parameters
    ----------
    path : data.PathLike
        Location of the configuration file (e.g., YAML or JSON).
    field : str, optional
        Key under which the label configuration is nested inside the
        file, if any. Defaults to None.

    Returns
    -------
    LabelConfig
        The label configuration loaded from ``path`` and validated
        against the ``LabelConfig`` schema.

    Raises
    ------
    FileNotFoundError
        If the config file path does not exist.
    pydantic.ValidationError
        If the config file structure does not match the LabelConfig schema.
    """
    # Delegate to load_config, using LabelConfig as the schema.
    label_config = load_config(path, schema=LabelConfig, field=field)
    return label_config

View File

@ -41,8 +41,8 @@ def run_train(
model_config: Optional[ModelConfig] = None, model_config: Optional[ModelConfig] = None,
train_config: Optional[TrainingConfig] = None, train_config: Optional[TrainingConfig] = None,
trainer: Trainer | None = None, trainer: Trainer | None = None,
train_workers: int | None = None, train_workers: int = 0,
val_workers: int | None = None, val_workers: int = 0,
checkpoint_dir: Path | None = None, checkpoint_dir: Path | None = None,
log_dir: Path | None = None, log_dir: Path | None = None,
experiment_name: str | None = None, experiment_name: str | None = None,