mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Add targets to train cli
This commit is contained in:
parent
115084fd2b
commit
16a0fa7b75
@ -7,6 +7,7 @@ from loguru import logger
|
|||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.targets import load_target_config
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
FullTrainingConfig,
|
FullTrainingConfig,
|
||||||
load_full_training_config,
|
load_full_training_config,
|
||||||
@ -20,6 +21,7 @@ __all__ = ["train_command"]
|
|||||||
@click.argument("train_dataset", type=click.Path(exists=True))
|
@click.argument("train_dataset", type=click.Path(exists=True))
|
||||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||||
@click.option("--model-path", type=click.Path(exists=True))
|
@click.option("--model-path", type=click.Path(exists=True))
|
||||||
|
@click.option("--targets", type=click.Path(exists=True))
|
||||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||||
@click.option("--log-dir", type=click.Path(exists=True))
|
@click.option("--log-dir", type=click.Path(exists=True))
|
||||||
@click.option("--config", type=click.Path(exists=True))
|
@click.option("--config", type=click.Path(exists=True))
|
||||||
@ -42,6 +44,7 @@ def train_command(
|
|||||||
ckpt_dir: Optional[Path] = None,
|
ckpt_dir: Optional[Path] = None,
|
||||||
log_dir: Optional[Path] = None,
|
log_dir: Optional[Path] = None,
|
||||||
config: Optional[Path] = None,
|
config: Optional[Path] = None,
|
||||||
|
targets: Optional[Path] = None,
|
||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
seed: Optional[int] = None,
|
seed: Optional[int] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
@ -62,12 +65,18 @@ def train_command(
|
|||||||
logger.info("Initiating training process...")
|
logger.info("Initiating training process...")
|
||||||
|
|
||||||
logger.info("Loading training configuration...")
|
logger.info("Loading training configuration...")
|
||||||
|
|
||||||
conf = (
|
conf = (
|
||||||
load_full_training_config(config, field=config_field)
|
load_full_training_config(config, field=config_field)
|
||||||
if config is not None
|
if config is not None
|
||||||
else FullTrainingConfig()
|
else FullTrainingConfig()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if targets is not None:
|
||||||
|
logger.info("Loading targets configuration...")
|
||||||
|
targets_config = load_target_config(targets)
|
||||||
|
conf = conf.model_copy(update=dict(targets=targets_config))
|
||||||
|
|
||||||
logger.info("Loading training dataset...")
|
logger.info("Loading training dataset...")
|
||||||
train_annotations = load_dataset_from_config(train_dataset)
|
train_annotations = load_dataset_from_config(train_dataset)
|
||||||
logger.debug(
|
logger.debug(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user