mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Remove num_worker from config
This commit is contained in:
parent
573d8e38d6
commit
d56b9f02ae
@ -15,7 +15,7 @@ from batdetect2.data import (
|
|||||||
load_dataset_from_config,
|
load_dataset_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.data.datasets import Dataset
|
from batdetect2.data.datasets import Dataset
|
||||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, run_evaluate
|
||||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.inference import process_file_list, run_batch_inference
|
from batdetect2.inference import process_file_list, run_batch_inference
|
||||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||||
@ -81,8 +81,8 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
train_workers: int | None = None,
|
train_workers: int = 0,
|
||||||
val_workers: int | None = None,
|
val_workers: int = 0,
|
||||||
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
|
checkpoint_dir: Path | None = DEFAULT_CHECKPOINT_DIR,
|
||||||
log_dir: Path | None = DEFAULT_LOGS_DIR,
|
log_dir: Path | None = DEFAULT_LOGS_DIR,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
@ -113,19 +113,21 @@ class BatDetect2API:
|
|||||||
def evaluate(
|
def evaluate(
|
||||||
self,
|
self,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
save_predictions: bool = True,
|
save_predictions: bool = True,
|
||||||
) -> tuple[dict[str, float], list[list[Detection]]]:
|
) -> tuple[dict[str, float], list[list[Detection]]]:
|
||||||
return evaluate(
|
return run_evaluate(
|
||||||
self.model,
|
self.model,
|
||||||
test_annotations,
|
test_annotations,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
config=self.config,
|
audio_config=self.config.audio,
|
||||||
|
evaluation_config=self.config.evaluation,
|
||||||
|
output_config=self.config.outputs,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
output_dir=output_dir,
|
output_dir=output_dir,
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
@ -235,7 +237,7 @@ class BatDetect2API:
|
|||||||
def process_files(
|
def process_files(
|
||||||
self,
|
self,
|
||||||
audio_files: Sequence[data.PathLike],
|
audio_files: Sequence[data.PathLike],
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
return process_file_list(
|
return process_file_list(
|
||||||
self.model,
|
self.model,
|
||||||
@ -251,7 +253,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
self.model,
|
self.model,
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
from batdetect2.evaluate.config import EvaluationConfig, load_evaluation_config
|
||||||
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, evaluate
|
from batdetect2.evaluate.evaluate import DEFAULT_EVAL_DIR, run_evaluate
|
||||||
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
from batdetect2.evaluate.evaluator import Evaluator, build_evaluator
|
||||||
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
from batdetect2.evaluate.tasks import TaskConfig, build_task
|
||||||
|
|
||||||
@ -9,7 +9,7 @@ __all__ = [
|
|||||||
"TaskConfig",
|
"TaskConfig",
|
||||||
"build_evaluator",
|
"build_evaluator",
|
||||||
"build_task",
|
"build_task",
|
||||||
"evaluate",
|
"run_evaluate",
|
||||||
"load_evaluation_config",
|
"load_evaluation_config",
|
||||||
"DEFAULT_EVAL_DIR",
|
"DEFAULT_EVAL_DIR",
|
||||||
]
|
]
|
||||||
|
|||||||
@ -67,7 +67,6 @@ class TestDataset(Dataset[TestExample]):
|
|||||||
|
|
||||||
|
|
||||||
class TestLoaderConfig(BaseConfig):
|
class TestLoaderConfig(BaseConfig):
|
||||||
num_workers: int = 0
|
|
||||||
clipping_strategy: ClipConfig = Field(
|
clipping_strategy: ClipConfig = Field(
|
||||||
default_factory=lambda: PaddedClipConfig()
|
default_factory=lambda: PaddedClipConfig()
|
||||||
)
|
)
|
||||||
@ -78,7 +77,7 @@ def build_test_loader(
|
|||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: TestLoaderConfig | None = None,
|
config: TestLoaderConfig | None = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
) -> DataLoader[TestExample]:
|
) -> DataLoader[TestExample]:
|
||||||
logger.info("Building test data loader...")
|
logger.info("Building test data loader...")
|
||||||
config = config or TestLoaderConfig()
|
config = config or TestLoaderConfig()
|
||||||
@ -94,7 +93,6 @@ def build_test_loader(
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_workers = num_workers or config.num_workers
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
test_dataset,
|
test_dataset,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
|
|||||||
@ -1,46 +1,47 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple
|
from typing import Sequence
|
||||||
|
|
||||||
from lightning import Trainer
|
from lightning import Trainer
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
from batdetect2.audio.types import AudioLoader
|
from batdetect2.audio.types import AudioLoader
|
||||||
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.evaluate.dataset import build_test_loader
|
from batdetect2.evaluate.dataset import build_test_loader
|
||||||
from batdetect2.evaluate.evaluator import build_evaluator
|
from batdetect2.evaluate.evaluator import build_evaluator
|
||||||
from batdetect2.evaluate.lightning import EvaluationModule
|
from batdetect2.evaluate.lightning import EvaluationModule
|
||||||
from batdetect2.logging import build_logger
|
from batdetect2.logging import build_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import build_output_transform
|
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
from batdetect2.postprocess.types import Detection
|
from batdetect2.postprocess.types import Detection
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from batdetect2.config import BatDetect2Config
|
|
||||||
|
|
||||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||||
|
|
||||||
|
|
||||||
def evaluate(
|
def run_evaluate(
|
||||||
model: Model,
|
model: Model,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
targets: Optional["TargetProtocol"] = None,
|
targets: TargetProtocol | None = None,
|
||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
audio_config: AudioConfig | None = None,
|
||||||
formatter: Optional["OutputFormatterProtocol"] = None,
|
evaluation_config: EvaluationConfig | None = None,
|
||||||
num_workers: int | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
|
formatter: OutputFormatterProtocol | None = None,
|
||||||
|
num_workers: int = 0,
|
||||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
) -> Tuple[Dict[str, float], List[List[Detection]]]:
|
) -> tuple[dict[str, float], list[list[Detection]]]:
|
||||||
from batdetect2.config import BatDetect2Config
|
|
||||||
|
|
||||||
config = config or BatDetect2Config()
|
audio_config = audio_config or AudioConfig()
|
||||||
|
evaluation_config = evaluation_config or EvaluationConfig()
|
||||||
|
output_config = output_config or OutputsConfig()
|
||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
preprocessor = preprocessor or model.preprocessor
|
preprocessor = preprocessor or model.preprocessor
|
||||||
targets = targets or model.targets
|
targets = targets or model.targets
|
||||||
@ -52,15 +53,15 @@ def evaluate(
|
|||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
evaluator = build_evaluator(config=evaluation_config, targets=targets)
|
||||||
|
|
||||||
logger = build_logger(
|
logger = build_logger(
|
||||||
config.evaluation.logger,
|
evaluation_config.logger,
|
||||||
log_dir=Path(output_dir),
|
log_dir=Path(output_dir),
|
||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
output_transform = build_output_transform(config=config.outputs.transform)
|
output_transform = build_output_transform(config=output_config.transform)
|
||||||
module = EvaluationModule(
|
module = EvaluationModule(
|
||||||
model,
|
model,
|
||||||
evaluator,
|
evaluator,
|
||||||
|
|||||||
@ -28,7 +28,7 @@ def run_batch_inference(
|
|||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
output_transform: Optional[OutputTransformProtocol] = None,
|
output_transform: Optional[OutputTransformProtocol] = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int = 1,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
) -> List[ClipDetections]:
|
) -> List[ClipDetections]:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
@ -75,7 +75,7 @@ def process_file_list(
|
|||||||
targets: Optional["TargetProtocol"] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
audio_loader: Optional["AudioLoader"] = None,
|
audio_loader: Optional["AudioLoader"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
) -> List[ClipDetections]:
|
) -> List[ClipDetections]:
|
||||||
clip_config = config.inference.clipping
|
clip_config = config.inference.clipping
|
||||||
clips = get_clips_from_files(
|
clips = get_clips_from_files(
|
||||||
|
|||||||
@ -61,7 +61,6 @@ class InferenceDataset(Dataset[DatasetItem]):
|
|||||||
|
|
||||||
|
|
||||||
class InferenceLoaderConfig(BaseConfig):
|
class InferenceLoaderConfig(BaseConfig):
|
||||||
num_workers: int = 0
|
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
|
|
||||||
|
|
||||||
@ -70,7 +69,7 @@ def build_inference_loader(
|
|||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: InferenceLoaderConfig | None = None,
|
config: InferenceLoaderConfig | None = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
) -> DataLoader[DatasetItem]:
|
) -> DataLoader[DatasetItem]:
|
||||||
logger.info("Building inference data loader...")
|
logger.info("Building inference data loader...")
|
||||||
@ -84,12 +83,11 @@ def build_inference_loader(
|
|||||||
|
|
||||||
batch_size = batch_size or config.batch_size
|
batch_size = batch_size or config.batch_size
|
||||||
|
|
||||||
num_workers = num_workers or config.num_workers
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
inference_dataset,
|
inference_dataset,
|
||||||
batch_size=batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=config.num_workers,
|
num_workers=num_workers,
|
||||||
collate_fn=_collate_fn,
|
collate_fn=_collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -143,8 +143,6 @@ class ValidationDataset(Dataset):
|
|||||||
|
|
||||||
|
|
||||||
class TrainLoaderConfig(BaseConfig):
|
class TrainLoaderConfig(BaseConfig):
|
||||||
num_workers: int = 0
|
|
||||||
|
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
|
|
||||||
shuffle: bool = False
|
shuffle: bool = False
|
||||||
@ -164,7 +162,7 @@ def build_train_loader(
|
|||||||
labeller: ClipLabeller | None = None,
|
labeller: ClipLabeller | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: TrainLoaderConfig | None = None,
|
config: TrainLoaderConfig | None = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
) -> DataLoader:
|
) -> DataLoader:
|
||||||
config = config or TrainLoaderConfig()
|
config = config or TrainLoaderConfig()
|
||||||
|
|
||||||
@ -182,7 +180,6 @@ def build_train_loader(
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_workers = num_workers or config.num_workers
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
@ -193,8 +190,6 @@ def build_train_loader(
|
|||||||
|
|
||||||
|
|
||||||
class ValLoaderConfig(BaseConfig):
|
class ValLoaderConfig(BaseConfig):
|
||||||
num_workers: int = 0
|
|
||||||
|
|
||||||
clipping_strategy: ClipConfig = Field(
|
clipping_strategy: ClipConfig = Field(
|
||||||
default_factory=lambda: PaddedClipConfig()
|
default_factory=lambda: PaddedClipConfig()
|
||||||
)
|
)
|
||||||
@ -206,7 +201,7 @@ def build_val_loader(
|
|||||||
labeller: ClipLabeller | None = None,
|
labeller: ClipLabeller | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
config: ValLoaderConfig | None = None,
|
config: ValLoaderConfig | None = None,
|
||||||
num_workers: int | None = None,
|
num_workers: int = 0,
|
||||||
):
|
):
|
||||||
logger.info("Building validation data loader...")
|
logger.info("Building validation data loader...")
|
||||||
config = config or ValLoaderConfig()
|
config = config or ValLoaderConfig()
|
||||||
@ -223,7 +218,6 @@ def build_val_loader(
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
num_workers = num_workers or config.num_workers
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=1,
|
batch_size=1,
|
||||||
|
|||||||
@ -12,7 +12,7 @@ import torch
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig, load_config
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
from batdetect2.preprocess import MAX_FREQ, MIN_FREQ
|
||||||
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
from batdetect2.targets import build_targets, iterate_encoded_sound_events
|
||||||
from batdetect2.targets.types import TargetProtocol
|
from batdetect2.targets.types import TargetProtocol
|
||||||
@ -22,7 +22,6 @@ __all__ = [
|
|||||||
"LabelConfig",
|
"LabelConfig",
|
||||||
"build_clip_labeler",
|
"build_clip_labeler",
|
||||||
"generate_heatmaps",
|
"generate_heatmaps",
|
||||||
"load_label_config",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -150,31 +149,3 @@ def generate_heatmaps(
|
|||||||
classes=class_heatmap,
|
classes=class_heatmap,
|
||||||
size=size_heatmap,
|
size=size_heatmap,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_label_config(
|
|
||||||
path: data.PathLike, field: str | None = None
|
|
||||||
) -> LabelConfig:
|
|
||||||
"""Load the heatmap label generation configuration from a file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
path : data.PathLike
|
|
||||||
Path to the configuration file (e.g., YAML or JSON).
|
|
||||||
field : str, optional
|
|
||||||
If the label configuration is nested under a specific key in the
|
|
||||||
file, specify the key here. Defaults to None.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
LabelConfig
|
|
||||||
The loaded and validated label configuration object.
|
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
FileNotFoundError
|
|
||||||
If the config file path does not exist.
|
|
||||||
pydantic.ValidationError
|
|
||||||
If the config file structure does not match the LabelConfig schema.
|
|
||||||
"""
|
|
||||||
return load_config(path, schema=LabelConfig, field=field)
|
|
||||||
|
|||||||
@ -41,8 +41,8 @@ def run_train(
|
|||||||
model_config: Optional[ModelConfig] = None,
|
model_config: Optional[ModelConfig] = None,
|
||||||
train_config: Optional[TrainingConfig] = None,
|
train_config: Optional[TrainingConfig] = None,
|
||||||
trainer: Trainer | None = None,
|
trainer: Trainer | None = None,
|
||||||
train_workers: int | None = None,
|
train_workers: int = 0,
|
||||||
val_workers: int | None = None,
|
val_workers: int = 0,
|
||||||
checkpoint_dir: Path | None = None,
|
checkpoint_dir: Path | None = None,
|
||||||
log_dir: Path | None = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user