Create TrainPreprocessConfig

mbsantiago 2025-06-26 12:30:16 -06:00
parent 4b6acd5e6e
commit 1384c549f7
5 changed files with 69 additions and 100 deletions

View File

@@ -97,12 +97,7 @@ example-preprocess:
 	batdetect2 preprocess \
 		--base-dir . \
 		--dataset-field datasets.train \
-		--preprocess-config config.yaml \
-		--preprocess-config-field preprocessing \
-		--label-config config.yaml \
-		--label-config-field labels \
-		--target-config config.yaml \
-		--target-config-field targets \
+		--config config.yaml \
 		config.yaml example_data/preprocessed
 
 example-train:

View File

@@ -6,10 +6,11 @@ from loguru import logger
 from batdetect2.cli.base import cli
 from batdetect2.data import load_dataset_from_config
-from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
-from batdetect2.targets import build_targets, load_target_config
-from batdetect2.train import load_label_config, preprocess_annotations
-from batdetect2.train.labels import build_clip_labeler
+from batdetect2.train.preprocess import (
+    TrainPreprocessConfig,
+    load_train_preprocessing_config,
+    preprocess_dataset,
+)
 
 __all__ = ["preprocess"]
@@ -44,16 +45,16 @@ __all__ = ["preprocess"]
     ),
 )
 @click.option(
-    "--preprocess-config",
+    "--config",
     type=click.Path(exists=True),
     help=(
-        "Path to the preprocessing configuration file. This file tells "
+        "Path to the configuration file. This file tells "
         "the program how to prepare your audio data before training, such "
        "as resampling or applying filters."
     ),
 )
 @click.option(
-    "--preprocess-config-field",
+    "--config-field",
     type=str,
     help=(
         "If the preprocessing settings are inside a nested dictionary "
@@ -62,41 +63,6 @@ __all__ = ["preprocess"]
         "top level, you don't need to specify this."
     ),
 )
-@click.option(
-    "--label-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the label generation configuration file. This file "
-        "contains settings for how to create labels from your "
-        "annotations, which the model uses to learn."
-    ),
-)
-@click.option(
-    "--label-config-field",
-    type=str,
-    help=(
-        "If the label generation settings are inside a nested dictionary "
-        "within the label configuration file, specify the key here. If "
-        "the settings are at the top level, leave this blank."
-    ),
-)
-@click.option(
-    "--target-config",
-    type=click.Path(exists=True),
-    help=(
-        "Path to the training target configuration file. This file "
-        "specifies what sounds the model should learn to predict."
-    ),
-)
-@click.option(
-    "--target-config-field",
-    type=str,
-    help=(
-        "If the target settings are inside a nested dictionary "
-        "within the target configuration file, specify the key here. "
-        "If the settings are at the top level, you don't need to specify this."
-    ),
-)
 @click.option(
     "--force",
     is_flag=True,
@@ -120,15 +86,11 @@
 def preprocess(
     dataset_config: Path,
     output: Path,
-    target_config: Optional[Path] = None,
     base_dir: Optional[Path] = None,
-    preprocess_config: Optional[Path] = None,
-    label_config: Optional[Path] = None,
+    config: Optional[Path] = None,
+    config_field: Optional[str] = None,
     force: bool = False,
     num_workers: Optional[int] = None,
-    target_config_field: Optional[str] = None,
-    preprocess_config_field: Optional[str] = None,
-    label_config_field: Optional[str] = None,
     dataset_field: Optional[str] = None,
 ):
     logger.info("Starting preprocessing.")
@@ -139,31 +101,10 @@ def preprocess(
     base_dir = base_dir or Path.cwd()
     logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
 
-    preprocess = (
-        load_preprocessing_config(
-            preprocess_config,
-            field=preprocess_config_field,
-        )
-        if preprocess_config
-        else None
-    )
-
-    target = (
-        load_target_config(
-            target_config,
-            field=target_config_field,
-        )
-        if target_config
-        else None
-    )
-
-    label = (
-        load_label_config(
-            label_config,
-            field=label_config_field,
-        )
-        if label_config
-        else None
+    conf = (
+        load_train_preprocessing_config(config, field=config_field)
+        if config is not None
+        else TrainPreprocessConfig()
     )
 
     dataset = load_dataset_from_config(
@@ -177,20 +118,10 @@
         num_examples=len(dataset),
     )
 
-    targets = build_targets(config=target)
-    preprocessor = build_preprocessor(config=preprocess)
-    labeller = build_clip_labeler(targets, config=label)
-
-    if not output.exists():
-        logger.debug("Creating directory {directory}", directory=output)
-        output.mkdir(parents=True)
-
     logger.info("Will start preprocessing")
-    preprocess_annotations(
+    preprocess_dataset(
         dataset,
-        output_dir=output,
-        preprocessor=preprocessor,
-        labeller=labeller,
-        replace=force,
+        conf,
+        output=output,
+        force=force,
         max_workers=num_workers,
     )
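
Taken together, the six per-topic flags collapse into a single --config/--config-field pair that loads one TrainPreprocessConfig (defined in batdetect2.train.preprocess, below). A minimal sketch of how the consolidated file might be laid out and loaded; the nesting key "preprocessing" and the empty sections are illustrative assumptions, workable only if each sub-config supplies defaults, as their Field(default_factory=...) declarations suggest:

# Sketch only: the section names mirror the TrainPreprocessConfig fields
# (preprocess, targets, labels); the "preprocessing" key and the empty
# sections are assumptions, not values taken from this commit.
from batdetect2.train.preprocess import load_train_preprocessing_config

with open("config.yaml", "w") as f:
    f.write(
        "preprocessing:\n"
        "  preprocess: {}\n"  # -> PreprocessingConfig
        "  targets: {}\n"     # -> TargetConfig
        "  labels: {}\n"      # -> LabelConfig
    )

# --config config.yaml --config-field preprocessing is then equivalent to:
conf = load_train_preprocessing_config("config.yaml", field="preprocessing")
print(type(conf.preprocess).__name__)  # PreprocessingConfig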

View File

@@ -38,7 +38,7 @@ class BaseConfig(BaseModel):
        Pydantic model configuration dictionary. Set to forbid extra fields.
     """
 
-    model_config = ConfigDict(extra="forbid")
+    model_config = ConfigDict(extra="allow")
 
 
 T = TypeVar("T", bound=BaseModel)
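
The flip from extra="forbid" to extra="allow" on the shared BaseConfig is what makes the single-file approach workable: a YAML document can now carry sections that any one schema does not declare without failing validation. A minimal sketch of the difference in plain pydantic v2 (the class names here are illustrative); the trade-off is that misspelled option names now also pass silently instead of raising:

from pydantic import BaseModel, ConfigDict, ValidationError


class Strict(BaseModel):
    model_config = ConfigDict(extra="forbid")
    batch_size: int = 8


class Lenient(BaseModel):
    model_config = ConfigDict(extra="allow")
    batch_size: int = 8


payload = {"batch_size": 4, "targets": {"classes": []}}

try:
    Strict.model_validate(payload)  # extra="forbid": unknown key rejected
except ValidationError as err:
    print(err.error_count())        # 1 (the unexpected "targets" key)

cfg = Lenient.model_validate(payload)  # extra="allow": unknown key retained
print(cfg.batch_size)                  # 4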

View File

@@ -47,19 +47,13 @@ class TrainerConfig(BaseConfig):
 class TrainingConfig(BaseConfig):
     batch_size: int = 8
 
     loss: LossConfig = Field(default_factory=LossConfig)
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
-    augmentations: AugmentationsConfig = Field(
-        default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
-    )
-    cliping: ClipingConfig = Field(default_factory=ClipingConfig)
-
     trainer: TrainerConfig = Field(default_factory=TrainerConfig)
     logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)

View File

@@ -27,22 +27,71 @@ from typing import Callable, Optional, Sequence
 
 import xarray as xr
 from loguru import logger
+from pydantic import Field
 from soundevent import data
 from tqdm.auto import tqdm
 
+from batdetect2.configs import BaseConfig, load_config
+from batdetect2.data.datasets import Dataset
+from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
 from batdetect2.preprocess.types import PreprocessorProtocol
+from batdetect2.targets import TargetConfig, build_targets
+from batdetect2.train.labels import LabelConfig, build_clip_labeler
 from batdetect2.train.types import ClipLabeller
 
 __all__ = [
     "preprocess_annotations",
     "preprocess_single_annotation",
     "generate_train_example",
+    "preprocess_dataset",
+    "TrainPreprocessConfig",
+    "load_train_preprocessing_config",
 ]
 
 FilenameFn = Callable[[data.ClipAnnotation], str]
 """Type alias for a function that generates an output filename."""
 
 
+class TrainPreprocessConfig(BaseConfig):
+    preprocess: PreprocessingConfig = Field(
+        default_factory=PreprocessingConfig
+    )
+    targets: TargetConfig = Field(default_factory=TargetConfig)
+    labels: LabelConfig = Field(default_factory=LabelConfig)
+
+
+def load_train_preprocessing_config(
+    path: data.PathLike,
+    field: Optional[str] = None,
+) -> TrainPreprocessConfig:
+    return load_config(path=path, schema=TrainPreprocessConfig, field=field)
+
+
+def preprocess_dataset(
+    dataset: Dataset,
+    config: TrainPreprocessConfig,
+    output: Path,
+    force: bool = False,
+    max_workers: Optional[int] = None,
+) -> None:
+    targets = build_targets(config=config.targets)
+    preprocessor = build_preprocessor(config=config.preprocess)
+    labeller = build_clip_labeler(targets, config=config.labels)
+
+    if not output.exists():
+        logger.debug("Creating directory {directory}", directory=output)
+        output.mkdir(parents=True)
+
+    preprocess_annotations(
+        dataset,
+        output_dir=output,
+        preprocessor=preprocessor,
+        labeller=labeller,
+        replace=force,
+        max_workers=max_workers,
+    )
+
+
 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
     preprocessor: PreprocessorProtocol,
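
For completeness, a sketch of driving the new module-level API directly, mirroring what the updated CLI does; the argument names passed to load_dataset_from_config are assumptions inferred from the CLI flags, not signatures confirmed by this diff:

from pathlib import Path

from batdetect2.data import load_dataset_from_config
from batdetect2.train.preprocess import (
    TrainPreprocessConfig,
    load_train_preprocessing_config,
    preprocess_dataset,
)

config_path = Path("config.yaml")

# Same fallback the CLI uses: all defaults when no --config is given.
conf = (
    load_train_preprocessing_config(config_path, field="preprocessing")
    if config_path.exists()
    else TrainPreprocessConfig()
)

# "datasets.train" mirrors --dataset-field in the Makefile example; the
# keyword name `field` is an assumption based on the other loaders.
dataset = load_dataset_from_config(config_path, field="datasets.train")

# Builds the targets, preprocessor, and labeller from the three config
# sections, then hands off to preprocess_annotations.
preprocess_dataset(
    dataset,
    conf,
    output=Path("example_data/preprocessed"),
    force=False,
    max_workers=4,
)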