mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
create train cli command
This commit is contained in:
parent
2913fa59a4
commit
3b5623ddca
196
batdetect2/cli/preprocess.py
Normal file
196
batdetect2/cli/preprocess.py
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import click
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
from batdetect2.cli.base import cli
|
||||||
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||||
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
|
from batdetect2.train import load_label_config, preprocess_annotations
|
||||||
|
from batdetect2.train.labels import build_clip_labeler
|
||||||
|
|
||||||
|
__all__ = ["preprocess"]
|
||||||
|
|
||||||
|
|
||||||
|
@cli.command()
|
||||||
|
@click.argument(
|
||||||
|
"dataset_config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
)
|
||||||
|
@click.argument(
|
||||||
|
"output",
|
||||||
|
type=click.Path(),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--dataset-field",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"Specifies the key to access the dataset information within the "
|
||||||
|
"dataset configuration file, if the information is nested inside a "
|
||||||
|
"dictionary. If the dataset information is at the top level of the "
|
||||||
|
"config file, you don't need to specify this."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--base-dir",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"The main directory where your audio recordings and annotation "
|
||||||
|
"files are stored. This helps the program find your data, "
|
||||||
|
"especially if the paths in your dataset configuration file "
|
||||||
|
"are relative."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--preprocess-config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"Path to the preprocessing configuration file. This file tells "
|
||||||
|
"the program how to prepare your audio data before training, such "
|
||||||
|
"as resampling or applying filters."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--preprocess-config-field",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"If the preprocessing settings are inside a nested dictionary "
|
||||||
|
"within the preprocessing configuration file, specify the key "
|
||||||
|
"here to access them. If the preprocessing settings are at the "
|
||||||
|
"top level, you don't need to specify this."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--label-config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"Path to the label generation configuration file. This file "
|
||||||
|
"contains settings for how to create labels from your "
|
||||||
|
"annotations, which the model uses to learn."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--label-config-field",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"If the label generation settings are inside a nested dictionary "
|
||||||
|
"within the label configuration file, specify the key here. If "
|
||||||
|
"the settings are at the top level, leave this blank."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--target-config",
|
||||||
|
type=click.Path(exists=True),
|
||||||
|
help=(
|
||||||
|
"Path to the training target configuration file. This file "
|
||||||
|
"specifies what sounds the model should learn to predict."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--target-config-field",
|
||||||
|
type=str,
|
||||||
|
help=(
|
||||||
|
"If the target settings are inside a nested dictionary "
|
||||||
|
"within the target configuration file, specify the key here. "
|
||||||
|
"If the settings are at the top level, you don't need to specify this."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--force",
|
||||||
|
is_flag=True,
|
||||||
|
help=(
|
||||||
|
"If a preprocessed file already exists, this option tells the "
|
||||||
|
"program to overwrite it with the new preprocessed data. Use "
|
||||||
|
"this if you want to re-do the preprocessing even if the files "
|
||||||
|
"already exist."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--num-workers",
|
||||||
|
type=int,
|
||||||
|
help=(
|
||||||
|
"The maximum number of computer cores to use when processing "
|
||||||
|
"your audio data. Using more cores can speed up the preprocessing, "
|
||||||
|
"but don't use more than your computer has available. By default, "
|
||||||
|
"the program will use all available cores."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
def preprocess(
|
||||||
|
dataset_config: Path,
|
||||||
|
output: Path,
|
||||||
|
target_config: Optional[Path] = None,
|
||||||
|
base_dir: Optional[Path] = None,
|
||||||
|
preprocess_config: Optional[Path] = None,
|
||||||
|
label_config: Optional[Path] = None,
|
||||||
|
force: bool = False,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
target_config_field: Optional[str] = None,
|
||||||
|
preprocess_config_field: Optional[str] = None,
|
||||||
|
label_config_field: Optional[str] = None,
|
||||||
|
dataset_field: Optional[str] = None,
|
||||||
|
):
|
||||||
|
logger.info("Starting preprocessing.")
|
||||||
|
|
||||||
|
output = Path(output)
|
||||||
|
logger.info("Will save outputs to {output}", output=output)
|
||||||
|
|
||||||
|
base_dir = base_dir or Path.cwd()
|
||||||
|
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
||||||
|
|
||||||
|
preprocess = (
|
||||||
|
load_preprocessing_config(
|
||||||
|
preprocess_config,
|
||||||
|
field=preprocess_config_field,
|
||||||
|
)
|
||||||
|
if preprocess_config
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
target = (
|
||||||
|
load_target_config(
|
||||||
|
target_config,
|
||||||
|
field=target_config_field,
|
||||||
|
)
|
||||||
|
if target_config
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
label = (
|
||||||
|
load_label_config(
|
||||||
|
label_config,
|
||||||
|
field=label_config_field,
|
||||||
|
)
|
||||||
|
if label_config
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset = load_dataset_from_config(
|
||||||
|
dataset_config,
|
||||||
|
field=dataset_field,
|
||||||
|
base_dir=base_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"Loaded {num_examples} annotated clips from the configured dataset",
|
||||||
|
num_examples=len(dataset),
|
||||||
|
)
|
||||||
|
|
||||||
|
targets = build_targets(config=target)
|
||||||
|
preprocessor = build_preprocessor(config=preprocess)
|
||||||
|
labeller = build_clip_labeler(targets, config=label)
|
||||||
|
|
||||||
|
if not output.exists():
|
||||||
|
logger.debug("Creating directory {directory}", directory=output)
|
||||||
|
output.mkdir(parents=True)
|
||||||
|
|
||||||
|
logger.info("Will start preprocessing")
|
||||||
|
preprocess_annotations(
|
||||||
|
dataset,
|
||||||
|
output_dir=output,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
labeller=labeller,
|
||||||
|
replace=force,
|
||||||
|
max_workers=num_workers,
|
||||||
|
)
|
@ -5,47 +5,48 @@ import click
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.evaluate.metrics import (
|
||||||
|
ClassificationAccuracy,
|
||||||
|
ClassificationMeanAveragePrecision,
|
||||||
|
DetectionAveragePrecision,
|
||||||
|
)
|
||||||
|
from batdetect2.models import build_model
|
||||||
|
from batdetect2.models.backbones import load_backbone_config
|
||||||
|
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
from batdetect2.targets import build_targets, load_target_config
|
||||||
from batdetect2.train import load_label_config, preprocess_annotations
|
from batdetect2.train import train
|
||||||
from batdetect2.train.labels import build_clip_labeler
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
|
from batdetect2.train.config import TrainingConfig, load_train_config
|
||||||
|
from batdetect2.train.dataset import list_preprocessed_files
|
||||||
|
|
||||||
__all__ = ["train"]
|
__all__ = [
|
||||||
|
"train_command",
|
||||||
|
]
|
||||||
|
|
||||||
|
DEFAULT_CONFIG_FILE = Path("config.yaml")
|
||||||
|
|
||||||
|
|
||||||
@cli.group()
|
@cli.command(name="train")
|
||||||
def train(): ...
|
@click.option(
|
||||||
|
"--train-examples",
|
||||||
|
type=click.Path(exists=True),
|
||||||
@train.command()
|
required=True,
|
||||||
@click.argument(
|
)
|
||||||
"dataset_config",
|
@click.option("--val-examples", type=click.Path(exists=True))
|
||||||
|
@click.option(
|
||||||
|
"--model-path",
|
||||||
type=click.Path(exists=True),
|
type=click.Path(exists=True),
|
||||||
)
|
)
|
||||||
@click.argument(
|
@click.option(
|
||||||
"output",
|
"--train-config",
|
||||||
type=click.Path(),
|
type=click.Path(exists=True),
|
||||||
|
default=DEFAULT_CONFIG_FILE,
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--dataset-field",
|
"--train-field",
|
||||||
type=str,
|
type=str,
|
||||||
help=(
|
default="train",
|
||||||
"Specifies the key to access the dataset information within the "
|
|
||||||
"dataset configuration file, if the information is nested inside a "
|
|
||||||
"dictionary. If the dataset information is at the top level of the "
|
|
||||||
"config file, you don't need to specify this."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--base-dir",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"The main directory where your audio recordings and annotation "
|
|
||||||
"files are stored. This helps the program find your data, "
|
|
||||||
"especially if the paths in your dataset configuration file "
|
|
||||||
"are relative."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--preprocess-config",
|
"--preprocess-config",
|
||||||
@ -55,6 +56,7 @@ def train(): ...
|
|||||||
"the program how to prepare your audio data before training, such "
|
"the program how to prepare your audio data before training, such "
|
||||||
"as resampling or applying filters."
|
"as resampling or applying filters."
|
||||||
),
|
),
|
||||||
|
default=DEFAULT_CONFIG_FILE,
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--preprocess-config-field",
|
"--preprocess-config-field",
|
||||||
@ -65,24 +67,7 @@ def train(): ...
|
|||||||
"here to access them. If the preprocessing settings are at the "
|
"here to access them. If the preprocessing settings are at the "
|
||||||
"top level, you don't need to specify this."
|
"top level, you don't need to specify this."
|
||||||
),
|
),
|
||||||
)
|
default="preprocess",
|
||||||
@click.option(
|
|
||||||
"--label-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the label generation configuration file. This file "
|
|
||||||
"contains settings for how to create labels from your "
|
|
||||||
"annotations, which the model uses to learn."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--label-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the label generation settings are inside a nested dictionary "
|
|
||||||
"within the label configuration file, specify the key here. If "
|
|
||||||
"the settings are at the top level, leave this blank."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--target-config",
|
"--target-config",
|
||||||
@ -91,6 +76,7 @@ def train(): ...
|
|||||||
"Path to the training target configuration file. This file "
|
"Path to the training target configuration file. This file "
|
||||||
"specifies what sounds the model should learn to predict."
|
"specifies what sounds the model should learn to predict."
|
||||||
),
|
),
|
||||||
|
default=DEFAULT_CONFIG_FILE,
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--target-config-field",
|
"--target-config-field",
|
||||||
@ -100,101 +86,130 @@ def train(): ...
|
|||||||
"within the target configuration file, specify the key here. "
|
"within the target configuration file, specify the key here. "
|
||||||
"If the settings are at the top level, you don't need to specify this."
|
"If the settings are at the top level, you don't need to specify this."
|
||||||
),
|
),
|
||||||
|
default="targets",
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--force",
|
"--postprocess-config",
|
||||||
is_flag=True,
|
type=click.Path(exists=True),
|
||||||
help=(
|
default=DEFAULT_CONFIG_FILE,
|
||||||
"If a preprocessed file already exists, this option tells the "
|
|
||||||
"program to overwrite it with the new preprocessed data. Use "
|
|
||||||
"this if you want to re-do the preprocessing even if the files "
|
|
||||||
"already exist."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--num-workers",
|
"--postprocess-config-field",
|
||||||
type=int,
|
type=str,
|
||||||
help=(
|
default="postprocess",
|
||||||
"The maximum number of computer cores to use when processing "
|
|
||||||
"your audio data. Using more cores can speed up the preprocessing, "
|
|
||||||
"but don't use more than your computer has available. By default, "
|
|
||||||
"the program will use all available cores."
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
def preprocess(
|
@click.option(
|
||||||
dataset_config: Path,
|
"--model-config",
|
||||||
output: Path,
|
type=click.Path(exists=True),
|
||||||
target_config: Optional[Path] = None,
|
default=DEFAULT_CONFIG_FILE,
|
||||||
base_dir: Optional[Path] = None,
|
)
|
||||||
preprocess_config: Optional[Path] = None,
|
@click.option(
|
||||||
label_config: Optional[Path] = None,
|
"--model-config-field",
|
||||||
force: bool = False,
|
type=str,
|
||||||
num_workers: Optional[int] = None,
|
default="model",
|
||||||
target_config_field: Optional[str] = None,
|
)
|
||||||
preprocess_config_field: Optional[str] = None,
|
def train_command(
|
||||||
label_config_field: Optional[str] = None,
|
train_examples: Path,
|
||||||
dataset_field: Optional[str] = None,
|
val_examples: Optional[Path] = None,
|
||||||
|
model_path: Optional[Path] = None,
|
||||||
|
train_config: Path = DEFAULT_CONFIG_FILE,
|
||||||
|
train_config_field: str = "train",
|
||||||
|
preprocess_config: Path = DEFAULT_CONFIG_FILE,
|
||||||
|
preprocess_config_field: str = "preprocess",
|
||||||
|
target_config: Path = DEFAULT_CONFIG_FILE,
|
||||||
|
target_config_field: str = "targets",
|
||||||
|
postprocess_config: Path = DEFAULT_CONFIG_FILE,
|
||||||
|
postprocess_config_field: str = "postprocess",
|
||||||
|
model_config: Path = DEFAULT_CONFIG_FILE,
|
||||||
|
model_config_field: str = "model",
|
||||||
):
|
):
|
||||||
logger.info("Starting preprocessing.")
|
logger.info("Starting training!")
|
||||||
|
|
||||||
output = Path(output)
|
try:
|
||||||
logger.info("Will save outputs to {output}", output=output)
|
target_config_loaded = load_target_config(
|
||||||
|
path=target_config,
|
||||||
base_dir = base_dir or Path.cwd()
|
|
||||||
logger.debug("Current working directory: {base_dir}", base_dir=base_dir)
|
|
||||||
|
|
||||||
preprocess = (
|
|
||||||
load_preprocessing_config(
|
|
||||||
preprocess_config,
|
|
||||||
field=preprocess_config_field,
|
|
||||||
)
|
|
||||||
if preprocess_config
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
target = (
|
|
||||||
load_target_config(
|
|
||||||
target_config,
|
|
||||||
field=target_config_field,
|
field=target_config_field,
|
||||||
)
|
)
|
||||||
if target_config
|
targets = build_targets(config=target_config_loaded)
|
||||||
else None
|
logger.debug(
|
||||||
)
|
"Loaded targets info from config file {path}", path=target_config
|
||||||
|
|
||||||
label = (
|
|
||||||
load_label_config(
|
|
||||||
label_config,
|
|
||||||
field=label_config_field,
|
|
||||||
)
|
)
|
||||||
if label_config
|
except IOError:
|
||||||
else None
|
logger.debug(
|
||||||
|
"Could not load target info from config file, using default"
|
||||||
|
)
|
||||||
|
targets = build_targets()
|
||||||
|
|
||||||
|
try:
|
||||||
|
preprocess_config_loaded = load_preprocessing_config(
|
||||||
|
path=preprocess_config,
|
||||||
|
field=preprocess_config_field,
|
||||||
|
)
|
||||||
|
preprocessor = build_preprocessor(preprocess_config_loaded)
|
||||||
|
logger.debug(
|
||||||
|
"Loaded preprocessor from config file {path}", path=target_config
|
||||||
|
)
|
||||||
|
|
||||||
|
except IOError:
|
||||||
|
logger.debug(
|
||||||
|
"Could not load preprocessor from config file, using default"
|
||||||
|
)
|
||||||
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
|
try:
|
||||||
|
model_config_loaded = load_backbone_config(
|
||||||
|
path=model_config, field=model_config_field
|
||||||
|
)
|
||||||
|
model = build_model(
|
||||||
|
num_classes=len(targets.class_names),
|
||||||
|
config=model_config_loaded,
|
||||||
|
)
|
||||||
|
except IOError:
|
||||||
|
model = build_model(num_classes=len(targets.class_names))
|
||||||
|
|
||||||
|
try:
|
||||||
|
postprocess_config_loaded = load_postprocess_config(
|
||||||
|
path=postprocess_config,
|
||||||
|
field=postprocess_config_field,
|
||||||
|
)
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
targets=targets,
|
||||||
|
config=postprocess_config_loaded,
|
||||||
|
)
|
||||||
|
except IOError:
|
||||||
|
postprocessor = build_postprocessor(targets=targets)
|
||||||
|
|
||||||
|
try:
|
||||||
|
train_config_loaded = load_train_config(
|
||||||
|
path=train_config, field=train_config_field
|
||||||
|
)
|
||||||
|
except IOError:
|
||||||
|
train_config_loaded = TrainingConfig()
|
||||||
|
|
||||||
|
train_files = list_preprocessed_files(train_examples)
|
||||||
|
|
||||||
|
val_files = (
|
||||||
|
None if val_examples is None else list_preprocessed_files(val_examples)
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = load_dataset_from_config(
|
return train(
|
||||||
dataset_config,
|
detector=model,
|
||||||
field=dataset_field,
|
train_examples=train_files, # type: ignore
|
||||||
base_dir=base_dir,
|
val_examples=val_files, # type: ignore
|
||||||
)
|
model_path=model_path,
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Loaded {num_examples} annotated clips from the configured dataset",
|
|
||||||
num_examples=len(dataset),
|
|
||||||
)
|
|
||||||
|
|
||||||
targets = build_targets(config=target)
|
|
||||||
preprocessor = build_preprocessor(config=preprocess)
|
|
||||||
labeller = build_clip_labeler(targets, config=label)
|
|
||||||
|
|
||||||
if not output.exists():
|
|
||||||
logger.debug("Creating directory {directory}", directory=output)
|
|
||||||
output.mkdir(parents=True)
|
|
||||||
|
|
||||||
logger.info("Will start preprocessing")
|
|
||||||
preprocess_annotations(
|
|
||||||
dataset,
|
|
||||||
output_dir=output,
|
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
labeller=labeller,
|
postprocessor=postprocessor,
|
||||||
replace=force,
|
targets=targets,
|
||||||
max_workers=num_workers,
|
config=train_config_loaded,
|
||||||
|
callbacks=[
|
||||||
|
ValidationMetrics(
|
||||||
|
metrics=[
|
||||||
|
DetectionAveragePrecision(),
|
||||||
|
ClassificationMeanAveragePrecision(
|
||||||
|
class_names=targets.class_names,
|
||||||
|
),
|
||||||
|
ClassificationAccuracy(class_names=targets.class_names),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
],
|
||||||
)
|
)
|
||||||
|
@ -89,7 +89,7 @@ class LabeledDataset(Dataset):
|
|||||||
|
|
||||||
def list_preprocessed_files(
|
def list_preprocessed_files(
|
||||||
directory: data.PathLike, extension: str = ".nc"
|
directory: data.PathLike, extension: str = ".nc"
|
||||||
) -> Sequence[Path]:
|
) -> List[Path]:
|
||||||
return list(Path(directory).glob(f"*{extension}"))
|
return list(Path(directory).glob(f"*{extension}"))
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,41 +37,44 @@ def train(
|
|||||||
val_examples: Optional[List[data.PathLike]] = None,
|
val_examples: Optional[List[data.PathLike]] = None,
|
||||||
config: Optional[TrainingConfig] = None,
|
config: Optional[TrainingConfig] = None,
|
||||||
callbacks: Optional[List[Callback]] = None,
|
callbacks: Optional[List[Callback]] = None,
|
||||||
|
model_path: Optional[data.PathLike] = None,
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
config = config or TrainingConfig()
|
config = config or TrainingConfig()
|
||||||
|
if model_path is None:
|
||||||
|
if preprocessor is None:
|
||||||
|
preprocessor = build_preprocessor()
|
||||||
|
|
||||||
if preprocessor is None:
|
if targets is None:
|
||||||
preprocessor = build_preprocessor()
|
targets = build_targets()
|
||||||
|
|
||||||
if targets is None:
|
if postprocessor is None:
|
||||||
targets = build_targets()
|
postprocessor = build_postprocessor(
|
||||||
|
targets,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
if postprocessor is None:
|
loss = build_loss(config.loss)
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
targets,
|
module = TrainingModule(
|
||||||
min_freq=preprocessor.min_freq,
|
detector=detector,
|
||||||
max_freq=preprocessor.max_freq,
|
loss=loss,
|
||||||
|
targets=targets,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
learning_rate=config.optimizer.learning_rate,
|
||||||
|
t_max=config.optimizer.t_max,
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||||
|
|
||||||
train_dataset = build_train_dataset(
|
train_dataset = build_train_dataset(
|
||||||
train_examples,
|
train_examples,
|
||||||
preprocessor,
|
preprocessor=module.preprocessor,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = build_loss(config.loss)
|
|
||||||
|
|
||||||
module = TrainingModule(
|
|
||||||
detector=detector,
|
|
||||||
loss=loss,
|
|
||||||
targets=targets,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
learning_rate=config.optimizer.learning_rate,
|
|
||||||
t_max=config.optimizer.t_max,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
**config.trainer.model_dump(exclude_none=True),
|
**config.trainer.model_dump(exclude_none=True),
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
Loading…
Reference in New Issue
Block a user