Compare commits

...

3 Commits

Author SHA1 Message Date
mbsantiago
6d0a73dda6 Add deepmerge dependency 2025-10-14 18:21:07 +01:00
mbsantiago
5736421023 Add arguments to train cli 2025-10-14 18:19:33 +01:00
mbsantiago
3913d2d350 Add merge config option 2025-10-14 18:05:12 +01:00
9 changed files with 69 additions and 38 deletions

View File

@ -30,6 +30,7 @@ dependencies = [
"hydra-core>=1.3.2", "hydra-core>=1.3.2",
"numba>=0.60", "numba>=0.60",
"loguru>=0.7.3", "loguru>=0.7.3",
"deepmerge>=2.0",
] ]
requires-python = ">=3.9,<3.13" requires-python = ">=3.9,<3.13"
readme = "README.md" readme = "README.md"

View File

@ -8,6 +8,7 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.logging import DEFAULT_LOGS_DIR
@ -61,6 +62,7 @@ class BatDetect2API:
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR, checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR, log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
): ):
@ -75,6 +77,7 @@ class BatDetect2API:
val_workers=val_workers, val_workers=val_workers,
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
log_dir=log_dir, log_dir=log_dir,
num_epochs=num_epochs,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
seed=seed, seed=seed,
@ -243,7 +246,9 @@ class BatDetect2API:
): ):
model, stored_config = load_model_from_checkpoint(path) model, stored_config = load_model_from_checkpoint(path)
config = config or stored_config config = (
merge_configs(config, stored_config) if config else stored_config
)
targets = build_targets(config=config.targets) targets = build_targets(config=config.targets)

View File

@ -1,5 +1,7 @@
"""BatDetect2 command line interface.""" """BatDetect2 command line interface."""
import sys
import click import click
from loguru import logger from loguru import logger
@ -19,8 +21,28 @@ BatDetect2 - Detection and Classification
@click.group() @click.group()
def cli(): @click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def cli(
verbose: int = 0,
):
"""BatDetect2 - Bat Call Detection and Classification.""" """BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR) click.echo(INFO_STR)
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.enable("batdetect2") logger.enable("batdetect2")
# click.echo(BATDETECT_ASCII_ART) # click.echo(BATDETECT_ASCII_ART)

View File

@ -35,9 +35,11 @@ def summary(
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
base_dir = base_dir or Path.cwd() base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config( dataset = load_dataset_from_config(
dataset_config, dataset_config,
field=field, field=field,
base_dir=base_dir, base_dir=base_dir,
) )
print(f"Number of annotated clips: {len(dataset)}") print(f"Number of annotated clips: {len(dataset)}")

View File

@ -1,4 +1,3 @@
import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -22,12 +21,6 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@click.option("--run-name", type=str) @click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int) @click.option("--workers", "num_workers", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command( def evaluate_command(
model_path: Path, model_path: Path,
test_dataset: Path, test_dataset: Path,
@ -37,27 +30,18 @@ def evaluate_command(
num_workers: Optional[int] = None, num_workers: Optional[int] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
verbose: int = 0,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating evaluation process...") logger.info("Initiating evaluation process...")
test_annotations = load_dataset_from_config( test_annotations = load_dataset_from_config(
test_dataset, test_dataset,
base_dir=base_dir, base_dir=base_dir,
) )
logger.debug( logger.debug(
"Loaded {num_annotations} test examples", "Loaded {num_annotations} test examples",
num_annotations=len(test_annotations), num_annotations=len(test_annotations),

View File

@ -1,4 +1,3 @@
import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
@ -21,15 +20,10 @@ __all__ = ["train_command"]
@click.option("--config-field", type=str) @click.option("--config-field", type=str)
@click.option("--train-workers", type=int) @click.option("--train-workers", type=int)
@click.option("--val-workers", type=int) @click.option("--val-workers", type=int)
@click.option("--num-epochs", type=int)
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@click.option("--run-name", type=str) @click.option("--run-name", type=str)
@click.option("--seed", type=int) @click.option("--seed", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command( def train_command(
train_dataset: Path, train_dataset: Path,
val_dataset: Optional[Path] = None, val_dataset: Optional[Path] = None,
@ -40,11 +34,11 @@ def train_command(
targets_config: Optional[Path] = None, targets_config: Optional[Path] = None,
config_field: Optional[str] = None, config_field: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
num_epochs: Optional[int] = None,
train_workers: int = 0, train_workers: int = 0,
val_workers: int = 0, val_workers: int = 0,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
verbose: int = 0,
): ):
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import ( from batdetect2.config import (
@ -54,14 +48,6 @@ def train_command(
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config from batdetect2.targets import load_target_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating training process...") logger.info("Initiating training process...")
logger.info("Loading configuration...") logger.info("Loading configuration...")
@ -99,7 +85,10 @@ def train_command(
if model_path is None: if model_path is None:
api = BatDetect2API.from_config(conf) api = BatDetect2API.from_config(conf)
else: else:
api = BatDetect2API.from_checkpoint(model_path) api = BatDetect2API.from_checkpoint(
model_path,
config=conf if config is not None else None,
)
return api.train( return api.train(
train_annotations=train_annotations, train_annotations=train_annotations,
@ -108,6 +97,7 @@ def train_command(
val_workers=val_workers, val_workers=val_workers,
checkpoint_dir=ckpt_dir, checkpoint_dir=ckpt_dir,
log_dir=log_dir, log_dir=log_dir,
num_epochs=num_epochs,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
seed=seed, seed=seed,

View File

@ -1,8 +1,9 @@
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config, merge_configs
from batdetect2.core.registries import Registry from batdetect2.core.registries import Registry
__all__ = [ __all__ = [
"BaseConfig", "BaseConfig",
"load_config", "load_config",
"Registry", "Registry",
"merge_configs",
] ]

View File

@ -11,12 +11,14 @@ configuration sections.
from typing import Any, Optional, Type, TypeVar from typing import Any, Optional, Type, TypeVar
import yaml import yaml
from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike from soundevent.data import PathLike
__all__ = [ __all__ = [
"BaseConfig", "BaseConfig",
"load_config", "load_config",
"merge_configs",
] ]
@ -178,3 +180,19 @@ def load_config(
config = get_object_field(config, field) config = get_object_field(config, field)
return schema.model_validate(config or {}) return schema.model_validate(config or {})
default_merger = Merger(
[],
["override"],
["override"],
)
def merge_configs(config1: T, config2: T) -> T:
"""Merge two configuration objects."""
model = type(config1)
dict1 = config1.model_dump()
dict2 = config2.model_dump()
merged = default_merger.merge(dict1, dict2)
return model.model_validate(merged)

View File

@ -47,6 +47,7 @@ def train(
checkpoint_dir: Optional[Path] = None, checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
): ):
@ -107,6 +108,7 @@ def train(
targets=targets, targets=targets,
), ),
checkpoint_dir=checkpoint_dir, checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs,
log_dir=log_dir, log_dir=log_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,
@ -128,6 +130,7 @@ def build_trainer(
log_dir: Optional[Path] = None, log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None, experiment_name: Optional[str] = None,
run_name: Optional[str] = None, run_name: Optional[str] = None,
num_epochs: Optional[int] = None,
) -> Trainer: ) -> Trainer:
trainer_conf = config.train.trainer trainer_conf = config.train.trainer
logger.opt(lazy=True).debug( logger.opt(lazy=True).debug(
@ -149,8 +152,13 @@ def build_trainer(
) )
) )
train_config = trainer_conf.model_dump(exclude_none=True)
if num_epochs is not None:
train_config["max_epochs"] = num_epochs
return Trainer( return Trainer(
**trainer_conf.model_dump(exclude_none=True), **train_config,
logger=train_logger, logger=train_logger,
callbacks=[ callbacks=[
build_checkpoint_callback( build_checkpoint_callback(