mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 00:59:34 +01:00
Add targets to train cli
This commit is contained in:
parent
115084fd2b
commit
16a0fa7b75
@ -7,6 +7,7 @@ from loguru import logger
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.targets import load_target_config
|
||||
from batdetect2.train import (
|
||||
FullTrainingConfig,
|
||||
load_full_training_config,
|
||||
@ -20,6 +21,7 @@ __all__ = ["train_command"]
|
||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||
@click.option("--model-path", type=click.Path(exists=True))
|
||||
@click.option("--targets", type=click.Path(exists=True))
|
||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||
@click.option("--log-dir", type=click.Path(exists=True))
|
||||
@click.option("--config", type=click.Path(exists=True))
|
||||
@ -42,6 +44,7 @@ def train_command(
|
||||
ckpt_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
config: Optional[Path] = None,
|
||||
targets: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
train_workers: int = 0,
|
||||
@ -62,12 +65,18 @@ def train_command(
|
||||
logger.info("Initiating training process...")
|
||||
|
||||
logger.info("Loading training configuration...")
|
||||
|
||||
conf = (
|
||||
load_full_training_config(config, field=config_field)
|
||||
if config is not None
|
||||
else FullTrainingConfig()
|
||||
)
|
||||
|
||||
if targets is not None:
|
||||
logger.info("Loading targets configuration...")
|
||||
targets_config = load_target_config(targets)
|
||||
conf = conf.model_copy(update=dict(targets=targets_config))
|
||||
|
||||
logger.info("Loading training dataset...")
|
||||
train_annotations = load_dataset_from_config(train_dataset)
|
||||
logger.debug(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user