mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Remove num_worker from config
This commit is contained in:
parent
573d8e38d6
commit
d56b9f02ae
@ -15,7 +15,7 @@ from batdetect2.data import (
|
||||
load_dataset_from_config,
|
||||
)
|
||||
from batdetect2.data.datasets import Dataset
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
@ -81,8 +81,8 @@ class BatDetect2API:
|
||||
self,
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
train_workers: int | None = None,
|
||||
val_workers: int | None = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
|
||||
log_dir: Path | None = DEFAULT_LOGS_DIR,
|
||||
experiment_name: str | None = None,
|
||||
@ -113,19 +113,21 @@ class BatDetect2API:
|
||||
def evaluate(
|
||||
self,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
save_predictions: bool = True,
|
||||
) -> tuple[dict[str, float], list[list[Detection]]]:
|
||||
return evaluate(
|
||||
return run_evaluate(
|
||||
self.model,
|
||||
test_annotations,
|
||||
targets=self.targets,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
config=self.config,
|
||||
audio_config=self.config.audio,
|
||||
evaluation_config=self.config.evaluation,
|
||||
output_config=self.config.outputs,
|
||||
num_workers=num_workers,
|
||||
output_dir=output_dir,
|
||||
experiment_name=experiment_name,
|
||||
@ -235,7 +237,7 @@ class BatDetect2API:
|
||||
def process_files(
|
||||
self,
|
||||
audio_files: Sequence[data.PathLike],
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
) -> list[ClipDetections]:
|
||||
return process_file_list(
|
||||
self.model,
|
||||
@ -251,7 +253,7 @@ class BatDetect2API:
|
||||
self,
|
||||
clips: Sequence[data.Clip],
|
||||
batch_size: int | None = None,
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
) -> list[ClipDetections]:
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
|
||||
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
|
||||
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
||||
|
||||
@ -9,7 +9,7 @@ __all__ = [
|
||||
"TaskConfig",
|
||||
"build_evaluator",
|
||||
"build_task",
|
||||
"evaluate",
|
||||
"run_evaluate",
|
||||
"load_evaluation_config",
|
||||
"DEFAULT_EVAL_DIR",
|
||||
]
|
||||
|
||||
@ -67,7 +67,6 @@ class TestDataset(Dataset[TestExample]):
|
||||
|
||||
|
||||
class TestLoaderConfig(BaseConfig):
|
||||
num_workers: int = 0
|
||||
clipping_strategy: ClipConfig = Field(
|
||||
default_factory=lambda: PaddedClipConfig()
|
||||
)
|
||||
@ -78,7 +77,7 @@ def build_test_loader(
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: TestLoaderConfig | None = None,
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
) -> DataLoader[TestExample]:
|
||||
logger.info("Building test data loader...")
|
||||
config = config or TestLoaderConfig()
|
||||
@ -94,7 +93,6 @@ def build_test_loader(
|
||||
config=config,
|
||||
)
|
||||
|
||||
num_workers = num_workers or config.num_workers
|
||||
return DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
|
||||
@ -1,46 +1,47 @@
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
|
||||
from typing import Sequence
|
||||
|
||||
from lightning import Trainer
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.evaluate.dataset import build_test_loader
|
||||
from batdetect2.evaluate.evaluator import build_evaluator
|
||||
from batdetect2.evaluate.lightning import EvaluationModule
|
||||
from batdetect2.logging import build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import build_output_transform
|
||||
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import Detection
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
|
||||
def evaluate(
|
||||
def run_evaluate(
|
||||
model: Model,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
formatter: Optional["OutputFormatterProtocol"] = None,
|
||||
num_workers: int | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
formatter: OutputFormatterProtocol | None = None,
|
||||
num_workers: int = 0,
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
) -> Tuple[Dict[str, float], List[List[Detection]]]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
) -> tuple[dict[str, float], list[list[Detection]]]:
|
||||
|
||||
config = config or BatDetect2Config()
|
||||
audio_config = audio_config or AudioConfig()
|
||||
evaluation_config = evaluation_config or EvaluationConfig()
|
||||
output_config = output_config or OutputsConfig()
|
||||
|
||||
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||
|
||||
preprocessor = preprocessor or model.preprocessor
|
||||
targets = targets or model.targets
|
||||
@ -52,15 +53,15 @@ def evaluate(
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||
evaluator = build_evaluator(config=evaluation_config, targets=targets)
|
||||
|
||||
logger = build_logger(
|
||||
config.evaluation.logger,
|
||||
evaluation_config.logger,
|
||||
log_dir=Path(output_dir),
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
)
|
||||
output_transform = build_output_transform(config=config.outputs.transform)
|
||||
output_transform = build_output_transform(config=output_config.transform)
|
||||
module = EvaluationModule(
|
||||
model,
|
||||
evaluator,
|
||||
|
||||
@ -28,7 +28,7 @@ def run_batch_inference(
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
output_transform: Optional[OutputTransformProtocol] = None,
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 1,
|
||||
batch_size: int | None = None,
|
||||
) -> List[ClipDetections]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
@ -75,7 +75,7 @@ def process_file_list(
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
audio_loader: Optional["AudioLoader"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
) -> List[ClipDetections]:
|
||||
clip_config = config.inference.clipping
|
||||
clips = get_clips_from_files(
|
||||
|
||||
@ -61,7 +61,6 @@ class InferenceDataset(Dataset[DatasetItem]):
|
||||
|
||||
|
||||
class InferenceLoaderConfig(BaseConfig):
|
||||
num_workers: int = 0
|
||||
batch_size: int = 8
|
||||
|
||||
|
||||
@ -70,7 +69,7 @@ def build_inference_loader(
|
||||
audio_loader: AudioLoader | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: InferenceLoaderConfig | None = None,
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
batch_size: int | None = None,
|
||||
) -> DataLoader[DatasetItem]:
|
||||
logger.info("Building inference data loader...")
|
||||
@ -84,12 +83,11 @@ def build_inference_loader(
|
||||
|
||||
batch_size = batch_size or config.batch_size
|
||||
|
||||
num_workers = num_workers or config.num_workers
|
||||
return DataLoader(
|
||||
inference_dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=config.num_workers,
|
||||
num_workers=num_workers,
|
||||
collate_fn=_collate_fn,
|
||||
)
|
||||
|
||||
|
||||
@ -143,8 +143,6 @@ class ValidationDataset(Dataset):
|
||||
|
||||
|
||||
class TrainLoaderConfig(BaseConfig):
|
||||
num_workers: int = 0
|
||||
|
||||
batch_size: int = 8
|
||||
|
||||
shuffle: bool = False
|
||||
@ -164,7 +162,7 @@ def build_train_loader(
|
||||
labeller: ClipLabeller | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: TrainLoaderConfig | None = None,
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
) -> DataLoader:
|
||||
config = config or TrainLoaderConfig()
|
||||
|
||||
@ -182,7 +180,6 @@ def build_train_loader(
|
||||
config=config,
|
||||
)
|
||||
|
||||
num_workers = num_workers or config.num_workers
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
@ -193,8 +190,6 @@ def build_train_loader(
|
||||
|
||||
|
||||
class ValLoaderConfig(BaseConfig):
|
||||
num_workers: int = 0
|
||||
|
||||
clipping_strategy: ClipConfig = Field(
|
||||
default_factory=lambda: PaddedClipConfig()
|
||||
)
|
||||
@ -206,7 +201,7 @@ def build_val_loader(
|
||||
labeller: ClipLabeller | None = None,
|
||||
preprocessor: PreprocessorProtocol | None = None,
|
||||
config: ValLoaderConfig | None = None,
|
||||
num_workers: int | None = None,
|
||||
num_workers: int = 0,
|
||||
):
|
||||
logger.info("Building validation data loader...")
|
||||
config = config or ValLoaderConfig()
|
||||
@ -223,7 +218,6 @@ def build_val_loader(
|
||||
config=config,
|
||||
)
|
||||
|
||||
num_workers = num_workers or config.num_workers
|
||||
return DataLoader(
|
||||
val_dataset,
|
||||
batch_size=1,
|
||||
|
||||
@ -12,7 +12,7 @@ import torch
|
||||
from loguru import logger
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
@ -22,7 +22,6 @@ __all__ = [
|
||||
"LabelConfig",
|
||||
"build_clip_labeler",
|
||||
"generate_heatmaps",
|
||||
"load_label_config",
|
||||
]
|
||||
|
||||
|
||||
@ -150,31 +149,3 @@ def generate_heatmaps(
|
||||
classes=class_heatmap,
|
||||
size=size_heatmap,
|
||||
)
|
||||
|
||||
|
||||
def load_label_config(
|
||||
path: data.PathLike, field: str | None = None
|
||||
) -> LabelConfig:
|
||||
"""Load the heatmap label generation configuration from a file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML or JSON).
|
||||
field : str, optional
|
||||
If the label configuration is nested under a specific key in the
|
||||
file, specify the key here. Defaults to None.
|
||||
|
||||
Returns
|
||||
-------
|
||||
LabelConfig
|
||||
The loaded and validated label configuration object.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file path does not exist.
|
||||
pydantic.ValidationError
|
||||
If the config file structure does not match the LabelConfig schema.
|
||||
"""
|
||||
return load_config(path, schema=LabelConfig, field=field)
|
||||
|
||||
@ -41,8 +41,8 @@ def run_train(
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
train_config: Optional[TrainingConfig] = None,
|
||||
trainer: Trainer | None = None,
|
||||
train_workers: int | None = None,
|
||||
val_workers: int | None = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
checkpoint_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
experiment_name: str | None = None,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user