Add merge config option

This commit is contained in:
mbsantiago 2025-10-14 18:05:12 +01:00
parent 24dcc47e73
commit 3913d2d350
7 changed files with 54 additions and 37 deletions

View File

@ -8,6 +8,7 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.logging import DEFAULT_LOGS_DIR
@ -243,7 +244,9 @@ class BatDetect2API:
): ):
model, stored_config = load_model_from_checkpoint(path) model, stored_config = load_model_from_checkpoint(path)
config = config or stored_config config = (
merge_configs(config, stored_config) if config else stored_config
)
targets = build_targets(config=config.targets) targets = build_targets(config=config.targets)

View File

@ -1,5 +1,7 @@
"""BatDetect2 command line interface.""" """BatDetect2 command line interface."""
import sys
import click import click
from loguru import logger from loguru import logger
@ -19,8 +21,28 @@ BatDetect2 - Detection and Classification
@click.group() @click.group()
def cli(): @click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def cli(
verbose: int = 0,
):
"""BatDetect2 - Bat Call Detection and Classification.""" """BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR) click.echo(INFO_STR)
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.enable("batdetect2") logger.enable("batdetect2")
# click.echo(BATDETECT_ASCII_ART) # click.echo(BATDETECT_ASCII_ART)

View File

@ -35,9 +35,11 @@ def summary(
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
base_dir = base_dir or Path.cwd() base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config( dataset = load_dataset_from_config(
dataset_config, dataset_config,
field=field, field=field,
base_dir=base_dir, base_dir=base_dir,
) )
print(f"Number of annotated clips: {len(dataset)}") print(f"Number of annotated clips: {len(dataset)}")

View File

@ -1,4 +1,3 @@
import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -22,12 +21,6 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@click.option("--run-name", type=str) @click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int) @click.option("--workers", "num_workers", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command( def evaluate_command(
model_path: Path, model_path: Path,
test_dataset: Path, test_dataset: Path,
@ -37,27 +30,18 @@ def evaluate_command(
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
verbose: int = 0,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating evaluation process...") logger.info("Initiating evaluation process...")
test_annotations = load_dataset_from_config( test_annotations = load_dataset_from_config(
test_dataset, test_dataset,
base_dir=base_dir, base_dir=base_dir,
) )
logger.debug( logger.debug(
"Loaded {num_annotations} test examples", "Loaded {num_annotations} test examples",
num_annotations=len(test_annotations), num_annotations=len(test_annotations),

View File

@ -1,4 +1,3 @@
import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -24,12 +23,6 @@ __all__ = ["train_command"]
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@click.option("--run-name", type=str) @click.option("--run-name", type=str)
@click.option("--seed", type=int) @click.option("--seed", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command( def train_command(
train_dataset: Path, train_dataset: Path,
val_dataset: Optional[Path] = None, val_dataset: Optional[Path] = None,
@ -44,7 +37,6 @@ def train_command(
val_workers: int = 0, val_workers: int = 0,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
verbose: int = 0,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import ( from batdetect2.config import (
@ -54,14 +46,6 @@ def train_command(
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config from batdetect2.targets import load_target_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating training process...") logger.info("Initiating training process...")
logger.info("Loading configuration...") logger.info("Loading configuration...")
@ -99,7 +83,10 @@ def train_command(
if model_path is None: if model_path is None:
api = BatDetect2API.from_config(conf) api = BatDetect2API.from_config(conf)
else: else:
api = BatDetect2API.from_checkpoint(model_path) api = BatDetect2API.from_checkpoint(
model_path,
config=conf if config is not None else None,
)
return api.train( return api.train(
train_annotations=train_annotations, train_annotations=train_annotations,

View File

@ -1,8 +1,9 @@
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config, merge_configs
from batdetect2.core.registries import Registry from batdetect2.core.registries import Registry
__all__ = [ __all__ = [
"BaseConfig", "BaseConfig",
"load_config", "load_config",
"Registry", "Registry",
"merge_configs",
] ]

View File

@ -11,12 +11,14 @@ configuration sections.
from typing import Any, Optional, Type, TypeVar from typing import Any, Optional, Type, TypeVar
import yaml import yaml
from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike from soundevent.data import PathLike
__all__ = [ __all__ = [
"BaseConfig", "BaseConfig",
"load_config", "load_config",
"merge_configs",
] ]
@ -178,3 +180,19 @@ def load_config(
config = get_object_field(config, field) config = get_object_field(config, field)
return schema.model_validate(config or {}) return schema.model_validate(config or {})
# Strategy table used to combine dumped configuration dictionaries:
#   * two dicts are merged key by key (recursively), so nested config
#     sections are combined instead of being replaced wholesale;
#   * any other pair of values (scalars, lists, ...) falls back to
#     "override": the value from the second operand wins;
#   * values of mismatched types are also resolved by "override".
# Without the dict strategy, the fallback applies at the top level and the
# merge result is simply the second dictionary.
default_merger = Merger(
    [(dict, ["merge"])],
    ["override"],
    ["override"],
)


def merge_configs(config1: T, config2: T) -> T:
    """Merge two configuration objects.

    ``config1`` provides the base values and ``config2`` is layered on top
    of it. Only fields that were explicitly set on ``config2`` are applied,
    so defaults that ``config2`` merely inherited do not clobber values
    present in ``config1``. Nested sections are merged recursively; any
    non-dict value (including lists) from ``config2`` replaces the
    corresponding value from ``config1``.

    Parameters
    ----------
    config1 : T
        Base configuration. Its type determines the type of the result.
    config2 : T
        Configuration whose explicitly-set fields take precedence.

    Returns
    -------
    T
        A new, validated instance of ``type(config1)`` built from the
        merged data. Neither input object is modified.
    """
    model = type(config1)

    # Full dump of the base so every field has a value to fall back on.
    base = config1.model_dump()

    # Only the fields that were actually provided on the overriding config;
    # a full dump would contain every key and make the merge a plain
    # replacement by ``config2``.
    overrides = config2.model_dump(exclude_unset=True)

    # ``merge`` updates ``base`` in place and returns it; ``base`` is a
    # fresh dictionary, so the input models are left untouched.
    merged = default_merger.merge(base, overrides)
    return model.model_validate(merged)