Create TrainPreprocessConfig

This commit is contained in:
mbsantiago 2025-06-26 12:30:16 -06:00
parent 4b6acd5e6e
commit 1384c549f7
5 changed files with 69 additions and 100 deletions

View File

@ -97,12 +97,7 @@ example-preprocess:
batdetect2 preprocess \ batdetect2 preprocess \
--base-dir . \ --base-dir . \
--dataset-field datasets.train \ --dataset-field datasets.train \
--preprocess-config config.yaml \ --config config.yaml \
--preprocess-config-field preprocessing \
--label-config config.yaml \
--label-config-field labels \
--target-config config.yaml \
--target-config-field targets \
config.yaml example_data/preprocessed config.yaml example_data/preprocessed
example-train: example-train:

View File

@ -6,10 +6,11 @@ from loguru import logger
from batdetect2.cli.base import cli from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config from batdetect2.train.preprocess import (
from batdetect2.targets import build_targets, load_target_config TrainPreprocessConfig,
from batdetect2.train import load_label_config, preprocess_annotations load_train_preprocessing_config,
from batdetect2.train.labels import build_clip_labeler preprocess_dataset,
)
__all__ = ["preprocess"] __all__ = ["preprocess"]
@ -44,16 +45,16 @@ __all__ = ["preprocess"]
), ),
) )
@click.option( @click.option(
"--preprocess-config", "--config",
type=click.Path(exists=True), type=click.Path(exists=True),
help=( help=(
"Path to the preprocessing configuration file. This file tells " "Path to the configuration file. This file tells "
"the program how to prepare your audio data before training, such " "the program how to prepare your audio data before training, such "
"as resampling or applying filters." "as resampling or applying filters."
), ),
) )
@click.option( @click.option(
"--preprocess-config-field", "--config-field",
type=str, type=str,
help=( help=(
"If the preprocessing settings are inside a nested dictionary " "If the preprocessing settings are inside a nested dictionary "
@ -62,41 +63,6 @@ __all__ = ["preprocess"]
"top level, you don't need to specify this." "top level, you don't need to specify this."
), ),
) )
@click.option(
"--label-config",
type=click.Path(exists=True),
help=(
"Path to the label generation configuration file. This file "
"contains settings for how to create labels from your "
"annotations, which the model uses to learn."
),
)
@click.option(
"--label-config-field",
type=str,
help=(
"If the label generation settings are inside a nested dictionary "
"within the label configuration file, specify the key here. If "
"the settings are at the top level, leave this blank."
),
)
@click.option(
"--target-config",
type=click.Path(exists=True),
help=(
"Path to the training target configuration file. This file "
"specifies what sounds the model should learn to predict."
),
)
@click.option(
"--target-config-field",
type=str,
help=(
"If the target settings are inside a nested dictionary "
"within the target configuration file, specify the key here. "
"If the settings are at the top level, you don't need to specify this."
),
)
@click.option( @click.option(
"--force", "--force",
is_flag=True, is_flag=True,
@ -120,15 +86,11 @@ __all__ = ["preprocess"]
def preprocess( def preprocess(
dataset_config: Path, dataset_config: Path,
output: Path, output: Path,
target_config: Optional[Path] = None,
base_dir: Optional[Path] = None, base_dir: Optional[Path] = None,
preprocess_config: Optional[Path] = None, config: Optional[Path] = None,
label_config: Optional[Path] = None, config_field: Optional[str] = None,
force: bool = False, force: bool = False,
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
target_config_field: Optional[str] = None,
preprocess_config_field: Optional[str] = None,
label_config_field: Optional[str] = None,
dataset_field: Optional[str] = None, dataset_field: Optional[str] = None,
): ):
logger.info("Starting preprocessing.") logger.info("Starting preprocessing.")
@ -139,31 +101,10 @@ def preprocess(
base_dir = base_dir or Path.cwd() base_dir = base_dir or Path.cwd()
logger.debug("Current working directory: {base_dir}", base_dir=base_dir) logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
preprocess = ( conf = (
load_preprocessing_config( load_train_preprocessing_config(config, field=config_field)
preprocess_config, if config is not None
field=preprocess_config_field, else TrainPreprocessConfig()
)
if preprocess_config
else None
)
target = (
load_target_config(
target_config,
field=target_config_field,
)
if target_config
else None
)
label = (
load_label_config(
label_config,
field=label_config_field,
)
if label_config
else None
) )
dataset = load_dataset_from_config( dataset = load_dataset_from_config(
@ -177,20 +118,10 @@ def preprocess(
num_examples=len(dataset), num_examples=len(dataset),
) )
targets = build_targets(config=target) preprocess_dataset(
preprocessor = build_preprocessor(config=preprocess)
labeller = build_clip_labeler(targets, config=label)
if not output.exists():
logger.debug("Creating directory {directory}", directory=output)
output.mkdir(parents=True)
logger.info("Will start preprocessing")
preprocess_annotations(
dataset, dataset,
output_dir=output, conf,
preprocessor=preprocessor, output=output,
labeller=labeller, force=force,
replace=force,
max_workers=num_workers, max_workers=num_workers,
) )

View File

@ -38,7 +38,7 @@ class BaseConfig(BaseModel):
Pydantic model configuration dictionary. Set to forbid extra fields. Pydantic model configuration dictionary. Set to forbid extra fields.
""" """
model_config = ConfigDict(extra="forbid") model_config = ConfigDict(extra="allow")
T = TypeVar("T", bound=BaseModel) T = TypeVar("T", bound=BaseModel)

View File

@ -47,19 +47,13 @@ class TrainerConfig(BaseConfig):
class TrainingConfig(BaseConfig): class TrainingConfig(BaseConfig):
batch_size: int = 8 batch_size: int = 8
loss: LossConfig = Field(default_factory=LossConfig) loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
augmentations: AugmentationsConfig = Field( augmentations: AugmentationsConfig = Field(
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
) )
cliping: ClipingConfig = Field(default_factory=ClipingConfig) cliping: ClipingConfig = Field(default_factory=ClipingConfig)
trainer: TrainerConfig = Field(default_factory=TrainerConfig) trainer: TrainerConfig = Field(default_factory=TrainerConfig)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig) logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)

View File

@ -27,22 +27,71 @@ from typing import Callable, Optional, Sequence
import xarray as xr import xarray as xr
from loguru import logger from loguru import logger
from pydantic import Field
from soundevent import data from soundevent import data
from tqdm.auto import tqdm from tqdm.auto import tqdm
from batdetect2.configs import BaseConfig, load_config
from batdetect2.data.datasets import Dataset
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import TargetConfig, build_targets
from batdetect2.train.labels import LabelConfig, build_clip_labeler
from batdetect2.train.types import ClipLabeller from batdetect2.train.types import ClipLabeller
__all__ = [ __all__ = [
"preprocess_annotations", "preprocess_annotations",
"preprocess_single_annotation", "preprocess_single_annotation",
"generate_train_example", "generate_train_example",
"preprocess_dataset",
"TrainPreprocessConfig",
"load_train_preprocessing_config",
] ]
FilenameFn = Callable[[data.ClipAnnotation], str] FilenameFn = Callable[[data.ClipAnnotation], str]
"""Type alias for a function that generates an output filename.""" """Type alias for a function that generates an output filename."""
class TrainPreprocessConfig(BaseConfig):
    """Combined configuration for preprocessing a training dataset.

    Bundles the three sub-configurations that ``preprocess_dataset`` reads.
    Any section that is not provided is created from the defaults of its
    own config class.
    """

    # Handed to ``build_preprocessor``.
    preprocess: PreprocessingConfig = Field(
        default_factory=PreprocessingConfig,
    )

    # Handed to ``build_targets``.
    targets: TargetConfig = Field(
        default_factory=TargetConfig,
    )

    # Handed to ``build_clip_labeler`` together with the built targets.
    labels: LabelConfig = Field(
        default_factory=LabelConfig,
    )
def load_train_preprocessing_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> TrainPreprocessConfig:
    """Load a ``TrainPreprocessConfig`` from a configuration file.

    Parameters
    ----------
    path
        Location of the configuration file.
    field
        Optional key forwarded unchanged to ``load_config``. Presumably it
        selects a nested section of the file -- confirm against
        ``load_config``.

    Returns
    -------
    TrainPreprocessConfig
        Whatever ``load_config`` produces for the ``TrainPreprocessConfig``
        schema.
    """
    config = load_config(path=path, schema=TrainPreprocessConfig, field=field)
    return config
def preprocess_dataset(
    dataset: Dataset,
    config: TrainPreprocessConfig,
    output: Path,
    force: bool = False,
    max_workers: Optional[int] = None,
) -> None:
    """Preprocess a dataset into an output directory.

    Builds the targets, the preprocessor and the clip labeller from
    ``config``, makes sure ``output`` exists as a directory, and then
    delegates the actual work to ``preprocess_annotations``.

    Parameters
    ----------
    dataset
        The dataset forwarded to ``preprocess_annotations``.
    config
        Source of the ``targets``, ``preprocess`` and ``labels`` settings.
    output
        Directory handed to ``preprocess_annotations`` as ``output_dir``.
        It is created, including missing parents, when it does not exist.
    force
        Forwarded to ``preprocess_annotations`` as ``replace``.
    max_workers
        Forwarded to ``preprocess_annotations`` unchanged.

    Raises
    ------
    FileExistsError
        If ``output`` already exists but is not a directory.
    """
    targets = build_targets(config=config.targets)
    preprocessor = build_preprocessor(config=config.preprocess)
    labeller = build_clip_labeler(targets, config=config.labels)

    if not output.exists():
        logger.debug("Creating directory {directory}", directory=output)

    # ``exist_ok=True`` closes the gap between the existence check and the
    # creation: another process creating the directory in between no longer
    # raises. A non-directory at this path still fails here, up front.
    output.mkdir(parents=True, exist_ok=True)

    preprocess_annotations(
        dataset,
        output_dir=output,
        preprocessor=preprocessor,
        labeller=labeller,
        replace=force,
        max_workers=max_workers,
    )
def generate_train_example( def generate_train_example(
clip_annotation: data.ClipAnnotation, clip_annotation: data.ClipAnnotation,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,