From 3b5623ddcadc7b11687884d3325298b7fb9ed005 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 30 Apr 2025 23:27:38 +0100 Subject: [PATCH] create train cli command --- batdetect2/cli/preprocess.py | 196 ++++++++++++++++++++++++ batdetect2/cli/train.py | 281 ++++++++++++++++++----------------- batdetect2/train/dataset.py | 2 +- batdetect2/train/train.py | 47 +++--- 4 files changed, 370 insertions(+), 156 deletions(-) create mode 100644 batdetect2/cli/preprocess.py diff --git a/batdetect2/cli/preprocess.py b/batdetect2/cli/preprocess.py new file mode 100644 index 0000000..553bcca --- /dev/null +++ b/batdetect2/cli/preprocess.py @@ -0,0 +1,196 @@ +from pathlib import Path +from typing import Optional + +import click +from loguru import logger + +from batdetect2.cli.base import cli +from batdetect2.data import load_dataset_from_config +from batdetect2.preprocess import build_preprocessor, load_preprocessing_config +from batdetect2.targets import build_targets, load_target_config +from batdetect2.train import load_label_config, preprocess_annotations +from batdetect2.train.labels import build_clip_labeler + +__all__ = ["preprocess"] + + +@cli.command() +@click.argument( + "dataset_config", + type=click.Path(exists=True), +) +@click.argument( + "output", + type=click.Path(), +) +@click.option( + "--dataset-field", + type=str, + help=( + "Specifies the key to access the dataset information within the " + "dataset configuration file, if the information is nested inside a " + "dictionary. If the dataset information is at the top level of the " + "config file, you don't need to specify this." + ), +) +@click.option( + "--base-dir", + type=click.Path(exists=True), + help=( + "The main directory where your audio recordings and annotation " + "files are stored. This helps the program find your data, " + "especially if the paths in your dataset configuration file " + "are relative." + ), +) +@click.option( + "--preprocess-config", + type=click.Path(exists=True), + help=( + "Path to the preprocessing configuration file. This file tells " + "the program how to prepare your audio data before training, such " + "as resampling or applying filters." + ), +) +@click.option( + "--preprocess-config-field", + type=str, + help=( + "If the preprocessing settings are inside a nested dictionary " + "within the preprocessing configuration file, specify the key " + "here to access them. If the preprocessing settings are at the " + "top level, you don't need to specify this." + ), +) +@click.option( + "--label-config", + type=click.Path(exists=True), + help=( + "Path to the label generation configuration file. This file " + "contains settings for how to create labels from your " + "annotations, which the model uses to learn." + ), +) +@click.option( + "--label-config-field", + type=str, + help=( + "If the label generation settings are inside a nested dictionary " + "within the label configuration file, specify the key here. If " + "the settings are at the top level, leave this blank." + ), +) +@click.option( + "--target-config", + type=click.Path(exists=True), + help=( + "Path to the training target configuration file. This file " + "specifies what sounds the model should learn to predict." + ), +) +@click.option( + "--target-config-field", + type=str, + help=( + "If the target settings are inside a nested dictionary " + "within the target configuration file, specify the key here. " + "If the settings are at the top level, you don't need to specify this." 
+ ), +) +@click.option( + "--force", + is_flag=True, + help=( + "If a preprocessed file already exists, this option tells the " + "program to overwrite it with the new preprocessed data. Use " + "this if you want to re-do the preprocessing even if the files " + "already exist." + ), +) +@click.option( + "--num-workers", + type=int, + help=( + "The maximum number of computer cores to use when processing " + "your audio data. Using more cores can speed up the preprocessing, " + "but don't use more than your computer has available. By default, " + "the program will use all available cores." + ), +) +def preprocess( + dataset_config: Path, + output: Path, + target_config: Optional[Path] = None, + base_dir: Optional[Path] = None, + preprocess_config: Optional[Path] = None, + label_config: Optional[Path] = None, + force: bool = False, + num_workers: Optional[int] = None, + target_config_field: Optional[str] = None, + preprocess_config_field: Optional[str] = None, + label_config_field: Optional[str] = None, + dataset_field: Optional[str] = None, +): + logger.info("Starting preprocessing.") + + output = Path(output) + logger.info("Will save outputs to {output}", output=output) + + base_dir = base_dir or Path.cwd() + logger.debug("Current working directory: {base_dir}", base_dir=base_dir) + + preprocess = ( + load_preprocessing_config( + preprocess_config, + field=preprocess_config_field, + ) + if preprocess_config + else None + ) + + target = ( + load_target_config( + target_config, + field=target_config_field, + ) + if target_config + else None + ) + + label = ( + load_label_config( + label_config, + field=label_config_field, + ) + if label_config + else None + ) + + dataset = load_dataset_from_config( + dataset_config, + field=dataset_field, + base_dir=base_dir, + ) + + logger.info( + "Loaded {num_examples} annotated clips from the configured dataset", + num_examples=len(dataset), + ) + + targets = build_targets(config=target) + preprocessor = build_preprocessor(config=preprocess) + labeller = build_clip_labeler(targets, config=label) + + if not output.exists(): + logger.debug("Creating directory {directory}", directory=output) + output.mkdir(parents=True) + + logger.info("Will start preprocessing") + preprocess_annotations( + dataset, + output_dir=output, + preprocessor=preprocessor, + labeller=labeller, + replace=force, + max_workers=num_workers, + ) diff --git a/batdetect2/cli/train.py b/batdetect2/cli/train.py index 2c7ebd3..cf6954a 100644 --- a/batdetect2/cli/train.py +++ b/batdetect2/cli/train.py @@ -5,47 +5,48 @@ import click from loguru import logger from batdetect2.cli.base import cli -from batdetect2.data import load_dataset_from_config +from batdetect2.evaluate.metrics import ( + ClassificationAccuracy, + ClassificationMeanAveragePrecision, + DetectionAveragePrecision, +) +from batdetect2.models import build_model +from batdetect2.models.backbones import load_backbone_config +from batdetect2.postprocess import build_postprocessor, load_postprocess_config from batdetect2.preprocess import build_preprocessor, load_preprocessing_config from batdetect2.targets import build_targets, load_target_config -from batdetect2.train import load_label_config, preprocess_annotations -from batdetect2.train.labels import build_clip_labeler +from batdetect2.train import train +from batdetect2.train.callbacks import ValidationMetrics +from batdetect2.train.config import TrainingConfig, load_train_config +from batdetect2.train.dataset import list_preprocessed_files -__all__ = ["train"] +__all__ = [ + 
"train_command", +] + +DEFAULT_CONFIG_FILE = Path("config.yaml") -@cli.group() -def train(): ... - - -@train.command() -@click.argument( - "dataset_config", +@cli.command(name="train") +@click.option( + "--train-examples", + type=click.Path(exists=True), + required=True, +) +@click.option("--val-examples", type=click.Path(exists=True)) +@click.option( + "--model-path", type=click.Path(exists=True), ) -@click.argument( - "output", - type=click.Path(), +@click.option( + "--train-config", + type=click.Path(exists=True), + default=DEFAULT_CONFIG_FILE, ) @click.option( - "--dataset-field", + "--train-field", type=str, - help=( - "Specifies the key to access the dataset information within the " - "dataset configuration file, if the information is nested inside a " - "dictionary. If the dataset information is at the top level of the " - "config file, you don't need to specify this." - ), -) -@click.option( - "--base-dir", - type=click.Path(exists=True), - help=( - "The main directory where your audio recordings and annotation " - "files are stored. This helps the program find your data, " - "especially if the paths in your dataset configuration file " - "are relative." - ), + default="train", ) @click.option( "--preprocess-config", @@ -55,6 +56,7 @@ def train(): ... "the program how to prepare your audio data before training, such " "as resampling or applying filters." ), + default=DEFAULT_CONFIG_FILE, ) @click.option( "--preprocess-config-field", @@ -65,24 +67,7 @@ def train(): ... "here to access them. If the preprocessing settings are at the " "top level, you don't need to specify this." ), -) -@click.option( - "--label-config", - type=click.Path(exists=True), - help=( - "Path to the label generation configuration file. This file " - "contains settings for how to create labels from your " - "annotations, which the model uses to learn." - ), -) -@click.option( - "--label-config-field", - type=str, - help=( - "If the label generation settings are inside a nested dictionary " - "within the label configuration file, specify the key here. If " - "the settings are at the top level, leave this blank." - ), + default="preprocess", ) @click.option( "--target-config", @@ -91,6 +76,7 @@ def train(): ... "Path to the training target configuration file. This file " "specifies what sounds the model should learn to predict." ), + default=DEFAULT_CONFIG_FILE, ) @click.option( "--target-config-field", @@ -100,101 +86,130 @@ def train(): ... "within the target configuration file, specify the key here. " "If the settings are at the top level, you don't need to specify this." ), + default="targets", ) @click.option( - "--force", - is_flag=True, - help=( - "If a preprocessed file already exists, this option tells the " - "program to overwrite it with the new preprocessed data. Use " - "this if you want to re-do the preprocessing even if the files " - "already exist." - ), + "--postprocess-config", + type=click.Path(exists=True), + default=DEFAULT_CONFIG_FILE, ) @click.option( - "--num-workers", - type=int, - help=( - "The maximum number of computer cores to use when processing " - "your audio data. Using more cores can speed up the preprocessing, " - "but don't use more than your computer has available. By default, " - "the program will use all available cores." 
-    ),
+    "--postprocess-config-field",
+    type=str,
+    default="postprocess",
 )
-def preprocess(
-    dataset_config: Path,
-    output: Path,
-    target_config: Optional[Path] = None,
-    base_dir: Optional[Path] = None,
-    preprocess_config: Optional[Path] = None,
-    label_config: Optional[Path] = None,
-    force: bool = False,
-    num_workers: Optional[int] = None,
-    target_config_field: Optional[str] = None,
-    preprocess_config_field: Optional[str] = None,
-    label_config_field: Optional[str] = None,
-    dataset_field: Optional[str] = None,
+@click.option(
+    "--model-config",
+    type=click.Path(exists=True),
+    default=DEFAULT_CONFIG_FILE,
+)
+@click.option(
+    "--model-config-field",
+    type=str,
+    default="model",
+)
+def train_command(
+    train_examples: Path,
+    val_examples: Optional[Path] = None,
+    model_path: Optional[Path] = None,
+    train_config: Path = DEFAULT_CONFIG_FILE,
+    train_field: str = "train",
+    preprocess_config: Path = DEFAULT_CONFIG_FILE,
+    preprocess_config_field: str = "preprocess",
+    target_config: Path = DEFAULT_CONFIG_FILE,
+    target_config_field: str = "targets",
+    postprocess_config: Path = DEFAULT_CONFIG_FILE,
+    postprocess_config_field: str = "postprocess",
+    model_config: Path = DEFAULT_CONFIG_FILE,
+    model_config_field: str = "model",
 ):
-    logger.info("Starting preprocessing.")
+    logger.info("Starting training!")
 
-    output = Path(output)
-    logger.info("Will save outputs to {output}", output=output)
-
-    base_dir = base_dir or Path.cwd()
-    logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
-
-    preprocess = (
-        load_preprocessing_config(
-            preprocess_config,
-            field=preprocess_config_field,
-        )
-        if preprocess_config
-        else None
-    )
-
-    target = (
-        load_target_config(
-            target_config,
+    try:
+        target_config_loaded = load_target_config(
+            path=target_config,
             field=target_config_field,
         )
-        if target_config
-        else None
-    )
-
-    label = (
-        load_label_config(
-            label_config,
-            field=label_config_field,
+        targets = build_targets(config=target_config_loaded)
+        logger.debug(
+            "Loaded targets info from config file {path}", path=target_config
         )
-        if label_config
-        else None
+    except IOError:
+        logger.debug(
+            "Could not load target info from config file, using default"
+        )
+        targets = build_targets()
+
+    try:
+        preprocess_config_loaded = load_preprocessing_config(
+            path=preprocess_config,
+            field=preprocess_config_field,
+        )
+        preprocessor = build_preprocessor(preprocess_config_loaded)
+        logger.debug(
+            "Loaded preprocessor from config file {path}", path=preprocess_config
+        )
+
+    except IOError:
+        logger.debug(
+            "Could not load preprocessor from config file, using default"
+        )
+        preprocessor = build_preprocessor()
+
+    try:
+        model_config_loaded = load_backbone_config(
+            path=model_config, field=model_config_field
+        )
+        model = build_model(
+            num_classes=len(targets.class_names),
+            config=model_config_loaded,
+        )
+    except IOError:
+        model = build_model(num_classes=len(targets.class_names))
+
+    try:
+        postprocess_config_loaded = load_postprocess_config(
+            path=postprocess_config,
+            field=postprocess_config_field,
+        )
+        postprocessor = build_postprocessor(
+            targets=targets,
+            config=postprocess_config_loaded,
+        )
+    except IOError:
+        postprocessor = build_postprocessor(targets=targets)
+
+    try:
+        train_config_loaded = load_train_config(
+            path=train_config, field=train_field
+        )
+    except IOError:
+        train_config_loaded = TrainingConfig()
+
+    train_files = list_preprocessed_files(train_examples)
+
+    val_files = (
+        None if val_examples is None else 
list_preprocessed_files(val_examples) ) - dataset = load_dataset_from_config( - dataset_config, - field=dataset_field, - base_dir=base_dir, - ) - - logger.info( - "Loaded {num_examples} annotated clips from the configured dataset", - num_examples=len(dataset), - ) - - targets = build_targets(config=target) - preprocessor = build_preprocessor(config=preprocess) - labeller = build_clip_labeler(targets, config=label) - - if not output.exists(): - logger.debug("Creating directory {directory}", directory=output) - output.mkdir(parents=True) - - logger.info("Will start preprocessing") - preprocess_annotations( - dataset, - output_dir=output, + return train( + detector=model, + train_examples=train_files, # type: ignore + val_examples=val_files, # type: ignore + model_path=model_path, preprocessor=preprocessor, - labeller=labeller, - replace=force, - max_workers=num_workers, + postprocessor=postprocessor, + targets=targets, + config=train_config_loaded, + callbacks=[ + ValidationMetrics( + metrics=[ + DetectionAveragePrecision(), + ClassificationMeanAveragePrecision( + class_names=targets.class_names, + ), + ClassificationAccuracy(class_names=targets.class_names), + ] + ) + ], ) diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index f64c8f9..e586429 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -89,7 +89,7 @@ class LabeledDataset(Dataset): def list_preprocessed_files( directory: data.PathLike, extension: str = ".nc" -) -> Sequence[Path]: +) -> List[Path]: return list(Path(directory).glob(f"*{extension}")) diff --git a/batdetect2/train/train.py b/batdetect2/train/train.py index 757ec99..c0c29a9 100644 --- a/batdetect2/train/train.py +++ b/batdetect2/train/train.py @@ -37,41 +37,44 @@ def train( val_examples: Optional[List[data.PathLike]] = None, config: Optional[TrainingConfig] = None, callbacks: Optional[List[Callback]] = None, + model_path: Optional[data.PathLike] = None, **trainer_kwargs, ) -> None: config = config or TrainingConfig() + if model_path is None: + if preprocessor is None: + preprocessor = build_preprocessor() - if preprocessor is None: - preprocessor = build_preprocessor() + if targets is None: + targets = build_targets() - if targets is None: - targets = build_targets() + if postprocessor is None: + postprocessor = build_postprocessor( + targets, + min_freq=preprocessor.min_freq, + max_freq=preprocessor.max_freq, + ) - if postprocessor is None: - postprocessor = build_postprocessor( - targets, - min_freq=preprocessor.min_freq, - max_freq=preprocessor.max_freq, + loss = build_loss(config.loss) + + module = TrainingModule( + detector=detector, + loss=loss, + targets=targets, + preprocessor=preprocessor, + postprocessor=postprocessor, + learning_rate=config.optimizer.learning_rate, + t_max=config.optimizer.t_max, ) + else: + module = TrainingModule.load_from_checkpoint(model_path) # type: ignore train_dataset = build_train_dataset( train_examples, - preprocessor, + preprocessor=module.preprocessor, config=config, ) - loss = build_loss(config.loss) - - module = TrainingModule( - detector=detector, - loss=loss, - targets=targets, - preprocessor=preprocessor, - postprocessor=postprocessor, - learning_rate=config.optimizer.learning_rate, - t_max=config.optimizer.t_max, - ) - trainer = Trainer( **config.trainer.model_dump(exclude_none=True), callbacks=callbacks,
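For reference, below is a minimal usage sketch of the training entry points added by this patch. It is illustrative only: the `batdetect2` console-script name, the directory layout, and the checkpoint path are assumptions, not part of the diff.

# Assumed CLI invocation of the new command defined in batdetect2/cli/train.py
# (hypothetical paths; assumes the package exposes the click group as a
# `batdetect2` console script):
#
#   batdetect2 train --train-examples preprocessed/train --val-examples preprocessed/val
#
# Roughly equivalent programmatic use of batdetect2.train.train, exercising the
# new model_path argument to resume from a checkpoint (paths are placeholders):
from batdetect2.models import build_model
from batdetect2.targets import build_targets
from batdetect2.train import train
from batdetect2.train.dataset import list_preprocessed_files

targets = build_targets()
model = build_model(num_classes=len(targets.class_names))
train(
    detector=model,
    train_examples=list_preprocessed_files("preprocessed/train"),
    val_examples=list_preprocessed_files("preprocessed/val"),
    # When model_path is given, train() restores the TrainingModule from this
    # checkpoint instead of assembling a fresh module from the detector.
    model_path="checkpoints/last.ckpt",
    targets=targets,
)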