mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Create TrainPreprocessConfig
This commit is contained in:
parent
4b6acd5e6e
commit
1384c549f7
7
Makefile
7
Makefile
@ -97,12 +97,7 @@ example-preprocess:
|
|||||||
batdetect2 preprocess \
|
batdetect2 preprocess \
|
||||||
--base-dir . \
|
--base-dir . \
|
||||||
--dataset-field datasets.train \
|
--dataset-field datasets.train \
|
||||||
--preprocess-config config.yaml \
|
--config config.yaml \
|
||||||
--preprocess-config-field preprocessing \
|
|
||||||
--label-config config.yaml \
|
|
||||||
--label-config-field labels \
|
|
||||||
--target-config config.yaml \
|
|
||||||
--target-config-field targets \
|
|
||||||
config.yaml example_data/preprocessed
|
config.yaml example_data/preprocessed
|
||||||
|
|
||||||
example-train:
|
example-train:
|
||||||
|
@ -6,10 +6,11 @@ from loguru import logger
|
|||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_from_config
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
from batdetect2.train.preprocess import (
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
TrainPreprocessConfig,
|
||||||
from batdetect2.train import load_label_config, preprocess_annotations
|
load_train_preprocessing_config,
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
preprocess_dataset,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = ["preprocess"]
|
__all__ = ["preprocess"]
|
||||||
|
|
||||||
@ -44,16 +45,16 @@ __all__ = ["preprocess"]
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--preprocess-config",
|
"--config",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
help=(
|
help=(
|
||||||
"Path to the preprocessing configuration file. This file tells "
|
"Path to the configuration file. This file tells "
|
||||||
"the program how to prepare your audio data before training, such "
|
"the program how to prepare your audio data before training, such "
|
||||||
"as resampling or applying filters."
|
"as resampling or applying filters."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--preprocess-config-field",
|
"--config-field",
|
||||||
type=str,
|
type=str,
|
||||||
help=(
|
help=(
|
||||||
"If the preprocessing settings are inside a nested dictionary "
|
"If the preprocessing settings are inside a nested dictionary "
|
||||||
@ -62,41 +63,6 @@ __all__ = ["preprocess"]
|
|||||||
"top level, you don't need to specify this."
|
"top level, you don't need to specify this."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
@click.option(
|
|
||||||
"--label-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the label generation configuration file. This file "
|
|
||||||
"contains settings for how to create labels from your "
|
|
||||||
"annotations, which the model uses to learn."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--label-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the label generation settings are inside a nested dictionary "
|
|
||||||
"within the label configuration file, specify the key here. If "
|
|
||||||
"the settings are at the top level, leave this blank."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--target-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the training target configuration file. This file "
|
|
||||||
"specifies what sounds the model should learn to predict."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--target-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the target settings are inside a nested dictionary "
|
|
||||||
"within the target configuration file, specify the key here. "
|
|
||||||
"If the settings are at the top level, you don't need to specify this."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
@click.option(
|
||||||
"--force",
|
"--force",
|
||||||
is_flag=True,
|
is_flag=True,
|
||||||
@ -120,15 +86,11 @@ __all__ = ["preprocess"]
|
|||||||
def preprocess(
|
def preprocess(
|
||||||
dataset_config: Path,
|
dataset_config: Path,
|
||||||
output: Path,
|
output: Path,
|
||||||
target_config: Optional[Path] = None,
|
|
||||||
base_dir: Optional[Path] = None,
|
base_dir: Optional[Path] = None,
|
||||||
preprocess_config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
label_config: Optional[Path] = None,
|
config_field: Optional[str] = None,
|
||||||
force: bool = False,
|
force: bool = False,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
target_config_field: Optional[str] = None,
|
|
||||||
preprocess_config_field: Optional[str] = None,
|
|
||||||
label_config_field: Optional[str] = None,
|
|
||||||
dataset_field: Optional[str] = None,
|
dataset_field: Optional[str] = None,
|
||||||
):
|
):
|
||||||
logger.info("Starting preprocessing.")
|
logger.info("Starting preprocessing.")
|
||||||
@ -139,31 +101,10 @@ def preprocess(
|
|||||||
base_dir = base_dir or Path.cwd()
|
base_dir = base_dir or Path.cwd()
|
||||||
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
||||||
|
|
||||||
preprocess = (
|
conf = (
|
||||||
load_preprocessing_config(
|
load_train_preprocessing_config(config, field=config_field)
|
||||||
preprocess_config,
|
if config is not None
|
||||||
field=preprocess_config_field,
|
else TrainPreprocessConfig()
|
||||||
)
|
|
||||||
if preprocess_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
target = (
|
|
||||||
load_target_config(
|
|
||||||
target_config,
|
|
||||||
field=target_config_field,
|
|
||||||
)
|
|
||||||
if target_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
label = (
|
|
||||||
load_label_config(
|
|
||||||
label_config,
|
|
||||||
field=label_config_field,
|
|
||||||
)
|
|
||||||
if label_config
|
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = load_dataset_from_config(
|
dataset = load_dataset_from_config(
|
||||||
@ -177,20 +118,10 @@ def preprocess(
|
|||||||
num_examples=len(dataset),
|
num_examples=len(dataset),
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = build_targets(config=target)
|
preprocess_dataset(
|
||||||
preprocessor = build_preprocessor(config=preprocess)
|
|
||||||
labeller = build_clip_labeler(targets, config=label)
|
|
||||||
|
|
||||||
if not output.exists():
|
|
||||||
logger.debug("Creating directory {directory}", directory=output)
|
|
||||||
output.mkdir(parents=True)
|
|
||||||
|
|
||||||
logger.info("Will start preprocessing")
|
|
||||||
preprocess_annotations(
|
|
||||||
dataset,
|
dataset,
|
||||||
output_dir=output,
|
conf,
|
||||||
preprocessor=preprocessor,
|
output=output,
|
||||||
labeller=labeller,
|
force=force,
|
||||||
replace=force,
|
|
||||||
max_workers=num_workers,
|
max_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
@ -38,7 +38,7 @@ class BaseConfig(BaseModel):
|
|||||||
Pydantic model configuration dictionary. Set to forbid extra fields.
|
Pydantic model configuration dictionary. Set to forbid extra fields.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T", bound=BaseModel)
|
T = TypeVar("T", bound=BaseModel)
|
||||||
|
@ -47,19 +47,13 @@ class TrainerConfig(BaseConfig):
|
|||||||
|
|
||||||
class TrainingConfig(BaseConfig):
|
class TrainingConfig(BaseConfig):
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
|
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
|
|
||||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||||
|
|
||||||
augmentations: AugmentationsConfig = Field(
|
augmentations: AugmentationsConfig = Field(
|
||||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||||
)
|
)
|
||||||
|
|
||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
|
|
||||||
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
|
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
|
||||||
|
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,22 +27,71 @@ from typing import Callable, Optional, Sequence
|
|||||||
|
|
||||||
import xarray as xr
|
import xarray as xr
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
from batdetect2.data.datasets import Dataset
|
||||||
|
from batdetect2.preprocess import PreprocessingConfig, build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
from batdetect2.targets import TargetConfig, build_targets
|
||||||
|
from batdetect2.train.labels import LabelConfig, build_clip_labeler
|
||||||
from batdetect2.train.types import ClipLabeller
|
from batdetect2.train.types import ClipLabeller
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"preprocess_annotations",
|
"preprocess_annotations",
|
||||||
"preprocess_single_annotation",
|
"preprocess_single_annotation",
|
||||||
"generate_train_example",
|
"generate_train_example",
|
||||||
|
"preprocess_dataset",
|
||||||
|
"TrainPreprocessConfig",
|
||||||
|
"load_train_preprocessing_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
FilenameFn = Callable[[data.ClipAnnotation], str]
|
FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||||
"""Type alias for a function that generates an output filename."""
|
"""Type alias for a function that generates an output filename."""
|
||||||
|
|
||||||
|
|
||||||
|
class TrainPreprocessConfig(BaseConfig):
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_train_preprocessing_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> TrainPreprocessConfig:
|
||||||
|
return load_config(path=path, schema=TrainPreprocessConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_dataset(
|
||||||
|
dataset: Dataset,
|
||||||
|
config: TrainPreprocessConfig,
|
||||||
|
output: Path,
|
||||||
|
force: bool = False,
|
||||||
|
max_workers: Optional[int] = None,
|
||||||
|
) -> None:
|
||||||
|
targets = build_targets(config=config.targets)
|
||||||
|
preprocessor = build_preprocessor(config=config.preprocess)
|
||||||
|
labeller = build_clip_labeler(targets, config=config.labels)
|
||||||
|
|
||||||
|
if not output.exists():
|
||||||
|
logger.debug("Creating directory {directory}", directory=output)
|
||||||
|
output.mkdir(parents=True)
|
||||||
|
|
||||||
|
preprocess_annotations(
|
||||||
|
dataset,
|
||||||
|
output_dir=output,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
labeller=labeller,
|
||||||
|
replace=force,
|
||||||
|
max_workers=max_workers,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def generate_train_example(
|
def generate_train_example(
|
||||||
clip_annotation: data.ClipAnnotation,
|
clip_annotation: data.ClipAnnotation,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
|
Loading…
Reference in New Issue
Block a user