Remove num_workers from config

This commit is contained in:
mbsantiago 2026-03-18 00:22:59 +00:00
parent 573d8e38d6
commit d56b9f02ae
9 changed files with 43 additions and 79 deletions

View File

@ -15,7 +15,7 @@ from batdetect2.data import (
load_dataset_from_config,
)
from batdetect2.data.datasets import Dataset
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
@ -81,8 +81,8 @@ class BatDetect2API:
self,
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None,
train_workers: int | None = None,
val_workers: int | None = None,
train_workers: int = 0,
val_workers: int = 0,
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
log_dir: Path | None = DEFAULT_LOGS_DIR,
experiment_name: str | None = None,
@ -113,19 +113,21 @@ class BatDetect2API:
def evaluate(
self,
test_annotations: Sequence[data.ClipAnnotation],
num_workers: int | None = None,
num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: str | None = None,
run_name: str | None = None,
save_predictions: bool = True,
) -> tuple[dict[str, float], list[list[Detection]]]:
return evaluate(
return run_evaluate(
self.model,
test_annotations,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
audio_config=self.config.audio,
evaluation_config=self.config.evaluation,
output_config=self.config.outputs,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
@ -235,7 +237,7 @@ class BatDetect2API:
def process_files(
self,
audio_files: Sequence[data.PathLike],
num_workers: int | None = None,
num_workers: int = 0,
) -> list[ClipDetections]:
return process_file_list(
self.model,
@ -251,7 +253,7 @@ class BatDetect2API:
self,
clips: Sequence[data.Clip],
batch_size: int | None = None,
num_workers: int | None = None,
num_workers: int = 0,
) -> list[ClipDetections]:
return run_batch_inference(
self.model,

View File

@ -1,5 +1,5 @@
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
from batdetect2.evaluate.tasks import TaskConfig, build_task
@ -9,7 +9,7 @@ __all__ = [
"TaskConfig",
"build_evaluator",
"build_task",
"evaluate",
"run_evaluate",
"load_evaluation_config",
"DEFAULT_EVAL_DIR",
]

View File

@ -67,7 +67,6 @@ class TestDataset(Dataset[TestExample]):
class TestLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
@ -78,7 +77,7 @@ def build_test_loader(
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: TestLoaderConfig | None = None,
num_workers: int | None = None,
num_workers: int = 0,
) -> DataLoader[TestExample]:
logger.info("Building test data loader...")
config = config or TestLoaderConfig()
@ -94,7 +93,6 @@ def build_test_loader(
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
test_dataset,
batch_size=1,

View File

@ -1,46 +1,47 @@
from pathlib import Path
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
from typing import Sequence
from lightning import Trainer
from soundevent import data
from batdetect2.audio import build_audio_loader
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import EvaluationConfig
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.models import Model
from batdetect2.outputs import build_output_transform
from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import Detection
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets.types import TargetProtocol
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def evaluate(
def run_evaluate(
model: Model,
test_annotations: Sequence[data.ClipAnnotation],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
formatter: Optional["OutputFormatterProtocol"] = None,
num_workers: int | None = None,
targets: TargetProtocol | None = None,
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
output_config: OutputsConfig | None = None,
formatter: OutputFormatterProtocol | None = None,
num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
experiment_name: str | None = None,
run_name: str | None = None,
) -> Tuple[Dict[str, float], List[List[Detection]]]:
from batdetect2.config import BatDetect2Config
) -> tuple[dict[str, float], list[list[Detection]]]:
config = config or BatDetect2Config()
audio_config = audio_config or AudioConfig()
evaluation_config = evaluation_config or EvaluationConfig()
output_config = output_config or OutputsConfig()
audio_loader = audio_loader or build_audio_loader(config=config.audio)
audio_loader = audio_loader or build_audio_loader(config=audio_config)
preprocessor = preprocessor or model.preprocessor
targets = targets or model.targets
@ -52,15 +53,15 @@ def evaluate(
num_workers=num_workers,
)
evaluator = build_evaluator(config=config.evaluation, targets=targets)
evaluator = build_evaluator(config=evaluation_config, targets=targets)
logger = build_logger(
config.evaluation.logger,
evaluation_config.logger,
log_dir=Path(output_dir),
experiment_name=experiment_name,
run_name=run_name,
)
output_transform = build_output_transform(config=config.outputs.transform)
output_transform = build_output_transform(config=output_config.transform)
module = EvaluationModule(
model,
evaluator,

View File

@ -28,7 +28,7 @@ def run_batch_inference(
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
output_transform: Optional[OutputTransformProtocol] = None,
num_workers: int | None = None,
num_workers: int = 1,
batch_size: int | None = None,
) -> List[ClipDetections]:
from batdetect2.config import BatDetect2Config
@ -75,7 +75,7 @@ def process_file_list(
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: int | None = None,
num_workers: int = 0,
) -> List[ClipDetections]:
clip_config = config.inference.clipping
clips = get_clips_from_files(

View File

@ -61,7 +61,6 @@ class InferenceDataset(Dataset[DatasetItem]):
class InferenceLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
@ -70,7 +69,7 @@ def build_inference_loader(
audio_loader: AudioLoader | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: InferenceLoaderConfig | None = None,
num_workers: int | None = None,
num_workers: int = 0,
batch_size: int | None = None,
) -> DataLoader[DatasetItem]:
logger.info("Building inference data loader...")
@ -84,12 +83,11 @@ def build_inference_loader(
batch_size = batch_size or config.batch_size
num_workers = num_workers or config.num_workers
return DataLoader(
inference_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=config.num_workers,
num_workers=num_workers,
collate_fn=_collate_fn,
)

View File

@ -143,8 +143,6 @@ class ValidationDataset(Dataset):
class TrainLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
shuffle: bool = False
@ -164,7 +162,7 @@ def build_train_loader(
labeller: ClipLabeller | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: TrainLoaderConfig | None = None,
num_workers: int | None = None,
num_workers: int = 0,
) -> DataLoader:
config = config or TrainLoaderConfig()
@ -182,7 +180,6 @@ def build_train_loader(
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
train_dataset,
batch_size=config.batch_size,
@ -193,8 +190,6 @@ def build_train_loader(
class ValLoaderConfig(BaseConfig):
num_workers: int = 0
clipping_strategy: ClipConfig = Field(
default_factory=lambda: PaddedClipConfig()
)
@ -206,7 +201,7 @@ def build_val_loader(
labeller: ClipLabeller | None = None,
preprocessor: PreprocessorProtocol | None = None,
config: ValLoaderConfig | None = None,
num_workers: int | None = None,
num_workers: int = 0,
):
logger.info("Building validation data loader...")
config = config or ValLoaderConfig()
@ -223,7 +218,6 @@ def build_val_loader(
config=config,
)
num_workers = num_workers or config.num_workers
return DataLoader(
val_dataset,
batch_size=1,

View File

@ -12,7 +12,7 @@ import torch
from loguru import logger
from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
from batdetect2.targets import build_targets, iterate_encoded_sound_events
from batdetect2.targets.types import TargetProtocol
@ -22,7 +22,6 @@ __all__ = [
"LabelConfig",
"build_clip_labeler",
"generate_heatmaps",
"load_label_config",
]
@ -150,31 +149,3 @@ def generate_heatmaps(
classes=class_heatmap,
size=size_heatmap,
)
def load_label_config(
path: data.PathLike, field: str | None = None
) -> LabelConfig:
"""Load the heatmap label generation configuration from a file.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML or JSON).
field : str, optional
If the label configuration is nested under a specific key in the
file, specify the key here. Defaults to None.
Returns
-------
LabelConfig
The loaded and validated label configuration object.
Raises
------
FileNotFoundError
If the config file path does not exist.
pydantic.ValidationError
If the config file structure does not match the LabelConfig schema.
"""
return load_config(path, schema=LabelConfig, field=field)

View File

@ -41,8 +41,8 @@ def run_train(
model_config: Optional[ModelConfig] = None,
train_config: Optional[TrainingConfig] = None,
trainer: Trainer | None = None,
train_workers: int | None = None,
val_workers: int | None = None,
train_workers: int = 0,
val_workers: int = 0,
checkpoint_dir: Path | None = None,
log_dir: Path | None = None,
experiment_name: str | None = None,