create train cli command

This commit is contained in:
mbsantiago 2025-04-30 23:27:38 +01:00
parent 2913fa59a4
commit 3b5623ddca
4 changed files with 370 additions and 156 deletions

View File

@ -0,0 +1,196 @@
from pathlib import Path
from typing import Optional
import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import load_label_config, preprocess_annotations
from batdetect2.train.labels import build_clip_labeler
__all__ = ["preprocess"]
def _load_if_given(loader, path, field):
    """Return ``loader(path, field=field)``, or ``None`` when no path is given."""
    if not path:
        return None
    return loader(path, field=field)


@cli.command()
@click.argument(
    "dataset_config",
    type=click.Path(exists=True),
)
@click.argument(
    "output",
    type=click.Path(),
)
@click.option(
    "--dataset-field",
    type=str,
    help=(
        "Specifies the key to access the dataset information within the "
        "dataset configuration file, if the information is nested inside a "
        "dictionary. If the dataset information is at the top level of the "
        "config file, you don't need to specify this."
    ),
)
@click.option(
    "--base-dir",
    type=click.Path(exists=True),
    help=(
        "The main directory where your audio recordings and annotation "
        "files are stored. This helps the program find your data, "
        "especially if the paths in your dataset configuration file "
        "are relative."
    ),
)
@click.option(
    "--preprocess-config",
    type=click.Path(exists=True),
    help=(
        "Path to the preprocessing configuration file. This file tells "
        "the program how to prepare your audio data before training, such "
        "as resampling or applying filters."
    ),
)
@click.option(
    "--preprocess-config-field",
    type=str,
    help=(
        "If the preprocessing settings are inside a nested dictionary "
        "within the preprocessing configuration file, specify the key "
        "here to access them. If the preprocessing settings are at the "
        "top level, you don't need to specify this."
    ),
)
@click.option(
    "--label-config",
    type=click.Path(exists=True),
    help=(
        "Path to the label generation configuration file. This file "
        "contains settings for how to create labels from your "
        "annotations, which the model uses to learn."
    ),
)
@click.option(
    "--label-config-field",
    type=str,
    help=(
        "If the label generation settings are inside a nested dictionary "
        "within the label configuration file, specify the key here. If "
        "the settings are at the top level, leave this blank."
    ),
)
@click.option(
    "--target-config",
    type=click.Path(exists=True),
    help=(
        "Path to the training target configuration file. This file "
        "specifies what sounds the model should learn to predict."
    ),
)
@click.option(
    "--target-config-field",
    type=str,
    help=(
        "If the target settings are inside a nested dictionary "
        "within the target configuration file, specify the key here. "
        "If the settings are at the top level, you don't need to specify this."
    ),
)
@click.option(
    "--force",
    is_flag=True,
    help=(
        "If a preprocessed file already exists, this option tells the "
        "program to overwrite it with the new preprocessed data. Use "
        "this if you want to re-do the preprocessing even if the files "
        "already exist."
    ),
)
@click.option(
    "--num-workers",
    type=int,
    help=(
        "The maximum number of computer cores to use when processing "
        "your audio data. Using more cores can speed up the preprocessing, "
        "but don't use more than your computer has available. By default, "
        "the program will use all available cores."
    ),
)
def preprocess(
    dataset_config: Path,
    output: Path,
    target_config: Optional[Path] = None,
    base_dir: Optional[Path] = None,
    preprocess_config: Optional[Path] = None,
    label_config: Optional[Path] = None,
    force: bool = False,
    num_workers: Optional[int] = None,
    target_config_field: Optional[str] = None,
    preprocess_config_field: Optional[str] = None,
    label_config_field: Optional[str] = None,
    dataset_field: Optional[str] = None,
) -> None:
    """Preprocess an annotated dataset and save the results to OUTPUT.

    DATASET_CONFIG is the path to the dataset configuration file to load.
    OUTPUT is the directory that receives the preprocessed data; it is
    created (including missing parent directories) when it does not exist.

    The preprocessing, target and label configuration files are optional.
    When one is not given, the corresponding component is built without a
    configuration.
    """
    # NOTE: click uses the docstring above as the command's --help text, so
    # it is written for end users rather than as an API reference.
    logger.info("Starting preprocessing.")

    # click hands paths over as plain strings; normalise before using the
    # pathlib API below.
    output = Path(output)
    logger.info("Will save outputs to {output}", output=output)

    # Fall back to the current working directory when --base-dir is omitted.
    base_dir = base_dir or Path.cwd()
    logger.debug("Using base directory: {base_dir}", base_dir=base_dir)

    preprocess_cfg = _load_if_given(
        load_preprocessing_config,
        preprocess_config,
        preprocess_config_field,
    )
    target_cfg = _load_if_given(
        load_target_config,
        target_config,
        target_config_field,
    )
    label_cfg = _load_if_given(
        load_label_config,
        label_config,
        label_config_field,
    )

    dataset = load_dataset_from_config(
        dataset_config,
        field=dataset_field,
        base_dir=base_dir,
    )
    logger.info(
        "Loaded {num_examples} annotated clips from the configured dataset",
        num_examples=len(dataset),
    )

    targets = build_targets(config=target_cfg)
    preprocessor = build_preprocessor(config=preprocess_cfg)
    labeller = build_clip_labeler(targets, config=label_cfg)

    if not output.exists():
        logger.debug("Creating directory {directory}", directory=output)
    # exist_ok avoids failing if the directory appears between the check
    # above and this call.
    output.mkdir(parents=True, exist_ok=True)

    logger.info("Will start preprocessing")
    preprocess_annotations(
        dataset,
        output_dir=output,
        preprocessor=preprocessor,
        labeller=labeller,
        replace=force,
        max_workers=num_workers,
    )

View File

@ -5,47 +5,48 @@ import click
from loguru import logger
from batdetect2.cli.base import cli
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate.metrics import (
ClassificationAccuracy,
ClassificationMeanAveragePrecision,
DetectionAveragePrecision,
)
from batdetect2.models import build_model
from batdetect2.models.backbones import load_backbone_config
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
from batdetect2.targets import build_targets, load_target_config
from batdetect2.train import load_label_config, preprocess_annotations
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train import train
from batdetect2.train.callbacks import ValidationMetrics
from batdetect2.train.config import TrainingConfig, load_train_config
from batdetect2.train.dataset import list_preprocessed_files
__all__ = ["train"]
__all__ = [
"train_command",
]
DEFAULT_CONFIG_FILE = Path("config.yaml")
@cli.group()
def train(): ...
@train.command()
@click.argument(
"dataset_config",
@cli.command(name="train")
@click.option(
"--train-examples",
type=click.Path(exists=True),
required=True,
)
@click.option("--val-examples", type=click.Path(exists=True))
@click.option(
"--model-path",
type=click.Path(exists=True),
)
@click.argument(
"output",
type=click.Path(),
@click.option(
"--train-config",
type=click.Path(exists=True),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--dataset-field",
"--train-field",
type=str,
help=(
"Specifies the key to access the dataset information within the "
"dataset configuration file, if the information is nested inside a "
"dictionary. If the dataset information is at the top level of the "
"config file, you don't need to specify this."
),
)
@click.option(
"--base-dir",
type=click.Path(exists=True),
help=(
"The main directory where your audio recordings and annotation "
"files are stored. This helps the program find your data, "
"especially if the paths in your dataset configuration file "
"are relative."
),
default="train",
)
@click.option(
"--preprocess-config",
@ -55,6 +56,7 @@ def train(): ...
"the program how to prepare your audio data before training, such "
"as resampling or applying filters."
),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--preprocess-config-field",
@ -65,24 +67,7 @@ def train(): ...
"here to access them. If the preprocessing settings are at the "
"top level, you don't need to specify this."
),
)
@click.option(
"--label-config",
type=click.Path(exists=True),
help=(
"Path to the label generation configuration file. This file "
"contains settings for how to create labels from your "
"annotations, which the model uses to learn."
),
)
@click.option(
"--label-config-field",
type=str,
help=(
"If the label generation settings are inside a nested dictionary "
"within the label configuration file, specify the key here. If "
"the settings are at the top level, leave this blank."
),
default="preprocess",
)
@click.option(
"--target-config",
@ -91,6 +76,7 @@ def train(): ...
"Path to the training target configuration file. This file "
"specifies what sounds the model should learn to predict."
),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--target-config-field",
@ -100,101 +86,130 @@ def train(): ...
"within the target configuration file, specify the key here. "
"If the settings are at the top level, you don't need to specify this."
),
default="targets",
)
@click.option(
"--force",
is_flag=True,
help=(
"If a preprocessed file already exists, this option tells the "
"program to overwrite it with the new preprocessed data. Use "
"this if you want to re-do the preprocessing even if the files "
"already exist."
),
"--postprocess-config",
type=click.Path(exists=True),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--num-workers",
type=int,
help=(
"The maximum number of computer cores to use when processing "
"your audio data. Using more cores can speed up the preprocessing, "
"but don't use more than your computer has available. By default, "
"the program will use all available cores."
),
"--postprocess-config-field",
type=str,
default="postprocess",
)
def preprocess(
dataset_config: Path,
output: Path,
target_config: Optional[Path] = None,
base_dir: Optional[Path] = None,
preprocess_config: Optional[Path] = None,
label_config: Optional[Path] = None,
force: bool = False,
num_workers: Optional[int] = None,
target_config_field: Optional[str] = None,
preprocess_config_field: Optional[str] = None,
label_config_field: Optional[str] = None,
dataset_field: Optional[str] = None,
@click.option(
"--model-config",
type=click.Path(exists=True),
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--model-config-field",
type=str,
default="model",
)
def train_command(
train_examples: Path,
val_examples: Optional[Path] = None,
model_path: Optional[Path] = None,
train_config: Path = DEFAULT_CONFIG_FILE,
train_config_field: str = "train",
preprocess_config: Path = DEFAULT_CONFIG_FILE,
preprocess_config_field: str = "preprocess",
target_config: Path = DEFAULT_CONFIG_FILE,
target_config_field: str = "targets",
postprocess_config: Path = DEFAULT_CONFIG_FILE,
postprocess_config_field: str = "postprocess",
model_config: Path = DEFAULT_CONFIG_FILE,
model_config_field: str = "model",
):
logger.info("Starting preprocessing.")
logger.info("Starting training!")
output = Path(output)
logger.info("Will save outputs to {output}", output=output)
base_dir = base_dir or Path.cwd()
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
preprocess = (
load_preprocessing_config(
preprocess_config,
field=preprocess_config_field,
)
if preprocess_config
else None
)
target = (
load_target_config(
target_config,
try:
target_config_loaded = load_target_config(
path=target_config,
field=target_config_field,
)
if target_config
else None
)
label = (
load_label_config(
label_config,
field=label_config_field,
targets = build_targets(config=target_config_loaded)
logger.debug(
"Loaded targets info from config file {path}", path=target_config
)
if label_config
else None
except IOError:
logger.debug(
"Could not load target info from config file, using default"
)
targets = build_targets()
try:
preprocess_config_loaded = load_preprocessing_config(
path=preprocess_config,
field=preprocess_config_field,
)
preprocessor = build_preprocessor(preprocess_config_loaded)
logger.debug(
"Loaded preprocessor from config file {path}", path=target_config
)
except IOError:
logger.debug(
"Could not load preprocessor from config file, using default"
)
preprocessor = build_preprocessor()
try:
model_config_loaded = load_backbone_config(
path=model_config, field=model_config_field
)
model = build_model(
num_classes=len(targets.class_names),
config=model_config_loaded,
)
except IOError:
model = build_model(num_classes=len(targets.class_names))
try:
postprocess_config_loaded = load_postprocess_config(
path=postprocess_config,
field=postprocess_config_field,
)
postprocessor = build_postprocessor(
targets=targets,
config=postprocess_config_loaded,
)
except IOError:
postprocessor = build_postprocessor(targets=targets)
try:
train_config_loaded = load_train_config(
path=train_config, field=train_config_field
)
except IOError:
train_config_loaded = TrainingConfig()
train_files = list_preprocessed_files(train_examples)
val_files = (
None if val_examples is None else list_preprocessed_files(val_examples)
)
dataset = load_dataset_from_config(
dataset_config,
field=dataset_field,
base_dir=base_dir,
)
logger.info(
"Loaded {num_examples} annotated clips from the configured dataset",
num_examples=len(dataset),
)
targets = build_targets(config=target)
preprocessor = build_preprocessor(config=preprocess)
labeller = build_clip_labeler(targets, config=label)
if not output.exists():
logger.debug("Creating directory {directory}", directory=output)
output.mkdir(parents=True)
logger.info("Will start preprocessing")
preprocess_annotations(
dataset,
output_dir=output,
return train(
detector=model,
train_examples=train_files, # type: ignore
val_examples=val_files, # type: ignore
model_path=model_path,
preprocessor=preprocessor,
labeller=labeller,
replace=force,
max_workers=num_workers,
postprocessor=postprocessor,
targets=targets,
config=train_config_loaded,
callbacks=[
ValidationMetrics(
metrics=[
DetectionAveragePrecision(),
ClassificationMeanAveragePrecision(
class_names=targets.class_names,
),
ClassificationAccuracy(class_names=targets.class_names),
]
)
],
)

View File

@ -89,7 +89,7 @@ class LabeledDataset(Dataset):
def list_preprocessed_files(
    directory: data.PathLike, extension: str = ".nc"
) -> List[Path]:
    """List the files in ``directory`` whose names end with ``extension``.

    Only direct children of ``directory`` are matched (the glob pattern is
    ``*<extension>``, not a recursive one).

    Parameters
    ----------
    directory
        Directory to search.
    extension
        Filename suffix to match, including the leading dot. Defaults to
        ``".nc"``.

    Returns
    -------
    List[Path]
        The matching paths, in the order ``Path.glob`` yields them (not
        sorted).
    """
    return list(Path(directory).glob(f"*{extension}"))

View File

@ -37,41 +37,44 @@ def train(
val_examples: Optional[List[data.PathLike]] = None,
config: Optional[TrainingConfig] = None,
callbacks: Optional[List[Callback]] = None,
model_path: Optional[data.PathLike] = None,
**trainer_kwargs,
) -> None:
config = config or TrainingConfig()
if model_path is None:
if preprocessor is None:
preprocessor = build_preprocessor()
if preprocessor is None:
preprocessor = build_preprocessor()
if targets is None:
targets = build_targets()
if targets is None:
targets = build_targets()
if postprocessor is None:
postprocessor = build_postprocessor(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
)
if postprocessor is None:
postprocessor = build_postprocessor(
targets,
min_freq=preprocessor.min_freq,
max_freq=preprocessor.max_freq,
loss = build_loss(config.loss)
module = TrainingModule(
detector=detector,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
learning_rate=config.optimizer.learning_rate,
t_max=config.optimizer.t_max,
)
else:
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
train_dataset = build_train_dataset(
train_examples,
preprocessor,
preprocessor=module.preprocessor,
config=config,
)
loss = build_loss(config.loss)
module = TrainingModule(
detector=detector,
loss=loss,
targets=targets,
preprocessor=preprocessor,
postprocessor=postprocessor,
learning_rate=config.optimizer.learning_rate,
t_max=config.optimizer.t_max,
)
trainer = Trainer(
**config.trainer.model_dump(exclude_none=True),
callbacks=callbacks,