From 1384c549f7b48d528b6146a58a0f361c2cac2af1 Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Thu, 26 Jun 2025 12:30:16 -0600
Subject: [PATCH] Create TrainPreprocessConfig

---
 Makefile                           |   7 +-
 src/batdetect2/cli/preprocess.py   | 109 +++++------------------------
 src/batdetect2/configs.py          |   4 ++--
 src/batdetect2/train/config.py     |   6 --
 src/batdetect2/train/preprocess.py |  53 +++++++++++++++
 5 files changed, 76 insertions(+), 103 deletions(-)

diff --git a/Makefile b/Makefile
index 50b715e..539d992 100644
--- a/Makefile
+++ b/Makefile
@@ -97,12 +97,7 @@ example-preprocess:
 	batdetect2 preprocess \
 		--base-dir . \
 		--dataset-field datasets.train \
-		--preprocess-config config.yaml \
-		--preprocess-config-field preprocessing \
-		--label-config config.yaml \
-		--label-config-field labels \
-		--target-config config.yaml \
-		--target-config-field targets \
+		--config config.yaml \
 		config.yaml example_data/preprocessed

 example-train:
diff --git a/src/batdetect2/cli/preprocess.py b/src/batdetect2/cli/preprocess.py
index 553bcca..15a7681 100644
--- a/src/batdetect2/cli/preprocess.py
+++ b/src/batdetect2/cli/preprocess.py
@@ -6,10 +6,11 @@ from loguru import logger

 from batdetect2.cli.base import cli
 from batdetect2.data import load_dataset_from_config
-from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
-from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import load_label_config, preprocess_annotations
-from batdetect2.train.labels import build_clip_labeler
+from batdetect2.train.preprocess import (
+    TrainPreprocessConfig,
+    load_train_preprocessing_config,
+    preprocess_dataset,
+)

 __all__ = ["preprocess"]

@@ -44,16 +45,16 @@ __all__ = ["preprocess"]
     ),
 )
 @click.option(
-    "--preprocess-config",
+    "--config",
     type=click.Path(exists=True),
     help=(
-        "Path to the preprocessing configuration file. This file tells "
-        "the program how to prepare your audio data before training, such "
-        "as resampling or applying filters."
+        "Path to the configuration file. This file tells the program "
+        "how to prepare your data before training: audio preprocessing "
+        "(such as resampling or filtering), targets, and label generation."
     ),
 )
 @click.option(
-    "--preprocess-config-field",
+    "--config-field",
     type=str,
     help=(
         "If the preprocessing settings are inside a nested dictionary "
@@ -62,41 +63,6 @@ __all__ = ["preprocess"]
         "within the preprocessing configuration file, specify the key "
         "here. If the settings are at the "
         "top level, you don't need to specify this."
     ),
 )
-@click.option(
-    "--label-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the label generation configuration file. This file "
-        "contains settings for how to create labels from your "
-        "annotations, which the model uses to learn."
-    ),
-)
-@click.option(
-    "--label-config-field",
-    type=str,
-    help=(
-        "If the label generation settings are inside a nested dictionary "
-        "within the label configuration file, specify the key here. If "
-        "the settings are at the top level, leave this blank."
-    ),
-)
-@click.option(
-    "--target-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the training target configuration file. This file "
-        "specifies what sounds the model should learn to predict."
-    ),
-)
-@click.option(
-    "--target-config-field",
-    type=str,
-    help=(
-        "If the target settings are inside a nested dictionary "
-        "within the target configuration file, specify the key here. "
-        "If the settings are at the top level, you don't need to specify this."
-    ),
-)
 @click.option(
     "--force",
     is_flag=True,
@@ -120,15 +86,11 @@ __all__ = ["preprocess"]
 def preprocess(
     dataset_config: Path,
     output: Path,
-    target_config: Optional[Path] = None,
     base_dir: Optional[Path] = None,
-    preprocess_config: Optional[Path] = None,
-    label_config: Optional[Path] = None,
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
     force: bool = False,
     num_workers: Optional[int] = None,
-    target_config_field: Optional[str] = None,
-    preprocess_config_field: Optional[str] = None,
-    label_config_field: Optional[str] = None,
     dataset_field: Optional[str] = None,
 ):
     logger.info("Starting preprocessing.")
@@ -139,31 +101,10 @@ def preprocess(
     base_dir = base_dir or Path.cwd()
     logger.debug("Current working directory: {base_dir}", base_dir=base_dir)

-    preprocess = (
-        load_preprocessing_config(
-            preprocess_config,
-            field=preprocess_config_field,
-        )
-        if preprocess_config
-        else None
-    )
-
-    target = (
-        load_target_config(
-            target_config,
-            field=target_config_field,
-        )
-        if target_config
-        else None
-    )
-
-    label = (
-        load_label_config(
-            label_config,
-            field=label_config_field,
-        )
-        if label_config
-        else None
+    conf = (
+        load_train_preprocessing_config(config, field=config_field)
+        if config is not None
+        else TrainPreprocessConfig()
     )

     dataset = load_dataset_from_config(
@@ -177,20 +118,10 @@ def preprocess(
         num_examples=len(dataset),
     )

-    targets = build_targets(config=target)
-    preprocessor = build_preprocessor(config=preprocess)
-    labeller = build_clip_labeler(targets, config=label)
-
-    if not output.exists():
-        logger.debug("Creating directory {directory}", directory=output)
-        output.mkdir(parents=True)
-
-    logger.info("Will start preprocessing")
-    preprocess_annotations(
+    preprocess_dataset(
         dataset,
-        output_dir=output,
-        preprocessor=preprocessor,
-        labeller=labeller,
-        replace=force,
+        conf,
+        output=output,
+        force=force,
         max_workers=num_workers,
     )
diff --git a/src/batdetect2/configs.py b/src/batdetect2/configs.py
index f252dd2..43764f5 100644
--- a/src/batdetect2/configs.py
+++ b/src/batdetect2/configs.py
@@ -38,7 +38,7 @@ class BaseConfig(BaseModel):
-        Pydantic model configuration dictionary. Set to forbid extra fields.
+        Pydantic model configuration dictionary. Set to allow extra fields.
""" - model_config = ConfigDict(extra="forbid") + model_config = ConfigDict(extra="allow") T = TypeVar("T", bound=BaseModel) diff --git a/src/batdetect2/train/config.py b/src/batdetect2/train/config.py index d854d2b..2ca4338 100644 --- a/src/batdetect2/train/config.py +++ b/src/batdetect2/train/config.py @@ -47,19 +47,13 @@ class TrainerConfig(BaseConfig): class TrainingConfig(BaseConfig): batch_size: int = 8 - loss: LossConfig = Field(default_factory=LossConfig) - optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) - augmentations: AugmentationsConfig = Field( default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG ) - cliping: ClipingConfig = Field(default_factory=ClipingConfig) - trainer: TrainerConfig = Field(default_factory=TrainerConfig) - logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) diff --git a/src/batdetect2/train/preprocess.py b/src/batdetect2/train/preprocess.py index 2abed50..11c9972 100644 --- a/src/batdetect2/train/preprocess.py +++ b/src/batdetect2/train/preprocess.py @@ -27,22 +27,71 @@ from typing import Callable, Optional, Sequence import xarray as xr from loguru import logger +from pydantic import Field from soundevent import data from tqdm.auto import tqdm +from batdetect2.configs import BaseConfig, load_config +from batdetect2.data.datasets import Dataset +from batdetect2.preprocess import PreprocessingConfig, build_preprocessor from batdetect2.preprocess.types import PreprocessorProtocol +from batdetect2.targets import TargetConfig, build_targets +from batdetect2.train.labels import LabelConfig, build_clip_labeler from batdetect2.train.types import ClipLabeller __all__ = [ "preprocess_annotations", "preprocess_single_annotation", "generate_train_example", + "preprocess_dataset", + "TrainPreprocessConfig", + "load_train_preprocessing_config", ] FilenameFn = Callable[[data.ClipAnnotation], str] """Type alias for a function that generates an output filename.""" +class TrainPreprocessConfig(BaseConfig): + preprocess: PreprocessingConfig = Field( + default_factory=PreprocessingConfig + ) + targets: TargetConfig = Field(default_factory=TargetConfig) + labels: LabelConfig = Field(default_factory=LabelConfig) + + +def load_train_preprocessing_config( + path: data.PathLike, + field: Optional[str] = None, +) -> TrainPreprocessConfig: + return load_config(path=path, schema=TrainPreprocessConfig, field=field) + + +def preprocess_dataset( + dataset: Dataset, + config: TrainPreprocessConfig, + output: Path, + force: bool = False, + max_workers: Optional[int] = None, +) -> None: + targets = build_targets(config=config.targets) + preprocessor = build_preprocessor(config=config.preprocess) + labeller = build_clip_labeler(targets, config=config.labels) + + if not output.exists(): + logger.debug("Creating directory {directory}", directory=output) + output.mkdir(parents=True) + + preprocess_annotations( + dataset, + output_dir=output, + preprocessor=preprocessor, + labeller=labeller, + replace=force, + max_workers=max_workers, + ) + + def generate_train_example( clip_annotation: data.ClipAnnotation, preprocessor: PreprocessorProtocol,