diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 7d7f35b..398adac 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -8,6 +8,7 @@ from soundevent.audio.files import get_audio_files from batdetect2.audio import build_audio_loader from batdetect2.config import BatDetect2Config +from batdetect2.core import merge_configs from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.logging import DEFAULT_LOGS_DIR @@ -243,7 +244,9 @@ class BatDetect2API: ): model, stored_config = load_model_from_checkpoint(path) - config = config or stored_config + config = ( + merge_configs(config, stored_config) if config else stored_config + ) targets = build_targets(config=config.targets) diff --git a/src/batdetect2/cli/base.py b/src/batdetect2/cli/base.py index 60ec376..8846685 100644 --- a/src/batdetect2/cli/base.py +++ b/src/batdetect2/cli/base.py @@ -1,5 +1,7 @@ """BatDetect2 command line interface.""" +import sys + import click from loguru import logger @@ -19,8 +21,28 @@ BatDetect2 - Detection and Classification @click.group() -def cli(): +@click.option( + "-v", + "--verbose", + count=True, + help="Increase verbosity. -v for INFO, -vv for DEBUG.", +) +def cli( + verbose: int = 0, +): """BatDetect2 - Bat Call Detection and Classification.""" click.echo(INFO_STR) + + logger.remove() + + if verbose == 0: + log_level = "WARNING" + elif verbose == 1: + log_level = "INFO" + else: + log_level = "DEBUG" + + logger.add(sys.stderr, level=log_level) + logger.enable("batdetect2") # click.echo(BATDETECT_ASCII_ART) diff --git a/src/batdetect2/cli/data.py b/src/batdetect2/cli/data.py index f824211..64f3757 100644 --- a/src/batdetect2/cli/data.py +++ b/src/batdetect2/cli/data.py @@ -35,9 +35,11 @@ def summary( from batdetect2.data import load_dataset_from_config base_dir = base_dir or Path.cwd() + dataset = load_dataset_from_config( dataset_config, field=field, base_dir=base_dir, ) + print(f"Number of annotated clips: {len(dataset)}") diff --git a/src/batdetect2/cli/evaluate.py b/src/batdetect2/cli/evaluate.py index 7fa2631..28c771f 100644 --- a/src/batdetect2/cli/evaluate.py +++ b/src/batdetect2/cli/evaluate.py @@ -1,4 +1,3 @@ -import sys from pathlib import Path from typing import Optional @@ -22,12 +21,6 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation" @click.option("--experiment-name", type=str) @click.option("--run-name", type=str) @click.option("--workers", "num_workers", type=int) -@click.option( - "-v", - "--verbose", - count=True, - help="Increase verbosity. -v for INFO, -vv for DEBUG.", -) def evaluate_command( model_path: Path, test_dataset: Path, @@ -37,27 +30,18 @@ def evaluate_command( num_workers: Optional[int] = None, experiment_name: Optional[str] = None, run_name: Optional[str] = None, - verbose: int = 0, ): from batdetect2.api_v2 import BatDetect2API from batdetect2.config import load_full_config from batdetect2.data import load_dataset_from_config - logger.remove() - if verbose == 0: - log_level = "WARNING" - elif verbose == 1: - log_level = "INFO" - else: - log_level = "DEBUG" - logger.add(sys.stderr, level=log_level) - logger.info("Initiating evaluation process...") test_annotations = load_dataset_from_config( test_dataset, base_dir=base_dir, ) + logger.debug( "Loaded {num_annotations} test examples", num_annotations=len(test_annotations), diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index f0168a3..76105f9 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -1,4 +1,3 @@ -import sys from pathlib import Path from typing import Optional @@ -24,12 +23,6 @@ __all__ = ["train_command"] @click.option("--experiment-name", type=str) @click.option("--run-name", type=str) @click.option("--seed", type=int) -@click.option( - "-v", - "--verbose", - count=True, - help="Increase verbosity. -v for INFO, -vv for DEBUG.", -) def train_command( train_dataset: Path, val_dataset: Optional[Path] = None, @@ -44,7 +37,6 @@ def train_command( val_workers: int = 0, experiment_name: Optional[str] = None, run_name: Optional[str] = None, - verbose: int = 0, ): from batdetect2.api_v2 import BatDetect2API from batdetect2.config import ( @@ -54,14 +46,6 @@ def train_command( from batdetect2.data import load_dataset_from_config from batdetect2.targets import load_target_config - logger.remove() - if verbose == 0: - log_level = "WARNING" - elif verbose == 1: - log_level = "INFO" - else: - log_level = "DEBUG" - logger.add(sys.stderr, level=log_level) logger.info("Initiating training process...") logger.info("Loading configuration...") @@ -99,7 +83,10 @@ def train_command( if model_path is None: api = BatDetect2API.from_config(conf) else: - api = BatDetect2API.from_checkpoint(model_path) + api = BatDetect2API.from_checkpoint( + model_path, + config=conf if config is not None else None, + ) return api.train( train_annotations=train_annotations, diff --git a/src/batdetect2/core/__init__.py b/src/batdetect2/core/__init__.py index 62730e8..19acaca 100644 --- a/src/batdetect2/core/__init__.py +++ b/src/batdetect2/core/__init__.py @@ -1,8 +1,9 @@ -from batdetect2.core.configs import BaseConfig, load_config +from batdetect2.core.configs import BaseConfig, load_config, merge_configs from batdetect2.core.registries import Registry __all__ = [ "BaseConfig", "load_config", "Registry", + "merge_configs", ] diff --git a/src/batdetect2/core/configs.py b/src/batdetect2/core/configs.py index 7513d73..e188935 100644 --- a/src/batdetect2/core/configs.py +++ b/src/batdetect2/core/configs.py @@ -11,12 +11,14 @@ configuration sections. from typing import Any, Optional, Type, TypeVar import yaml +from deepmerge.merger import Merger from pydantic import BaseModel, ConfigDict from soundevent.data import PathLike __all__ = [ "BaseConfig", "load_config", + "merge_configs", ] @@ -178,3 +180,19 @@ def load_config( config = get_object_field(config, field) return schema.model_validate(config or {}) + + +default_merger = Merger( + [], + ["override"], + ["override"], +) + + +def merge_configs(config1: T, config2: T) -> T: + """Merge two configuration objects.""" + model = type(config1) + dict1 = config1.model_dump() + dict2 = config2.model_dump() + merged = default_merger.merge(dict1, dict2) + return model.model_validate(merged)