Compare commits

...

3 Commits

Author SHA1 Message Date
mbsantiago
6d0a73dda6 Add deepmerge dependency 2025-10-14 18:21:07 +01:00
mbsantiago
5736421023 Add arguments to train cli 2025-10-14 18:19:33 +01:00
mbsantiago
3913d2d350 Add merge config option 2025-10-14 18:05:12 +01:00
9 changed files with 69 additions and 38 deletions

View File

@ -30,6 +30,7 @@ dependencies = [
"hydra-core>=1.3.2",
"numba>=0.60",
"loguru>=0.7.3",
"deepmerge>=2.0",
]
requires-python = ">=3.9,<3.13"
readme = "README.md"

View File

@ -8,6 +8,7 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
@ -61,6 +62,7 @@ class BatDetect2API:
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
):
@ -75,6 +77,7 @@ class BatDetect2API:
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
num_epochs=num_epochs,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
@ -243,7 +246,9 @@ class BatDetect2API:
):
model, stored_config = load_model_from_checkpoint(path)
config = config or stored_config
config = (
merge_configs(config, stored_config) if config else stored_config
)
targets = build_targets(config=config.targets)

View File

@ -1,5 +1,7 @@
"""BatDetect2 command line interface."""
import sys
import click
from loguru import logger
@ -19,8 +21,28 @@ BatDetect2 - Detection and Classification
@click.group()
def cli():
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def cli(
verbose: int = 0,
):
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.enable("batdetect2")
# click.echo(BATDETECT_ASCII_ART)

View File

@ -35,9 +35,11 @@ def summary(
from batdetect2.data import load_dataset_from_config
base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config(
dataset_config,
field=field,
base_dir=base_dir,
)
print(f"Number of annotated clips: {len(dataset)}")

View File

@ -1,4 +1,3 @@
import sys
from pathlib import Path
from typing import Optional
@ -22,12 +21,6 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
model_path: Path,
test_dataset: Path,
@ -37,27 +30,18 @@ def evaluate_command(
num_workers: Optional[int] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating evaluation process...")
test_annotations = load_dataset_from_config(
test_dataset,
base_dir=base_dir,
)
logger.debug(
"Loaded {num_annotations} test examples",
num_annotations=len(test_annotations),

View File

@ -1,4 +1,3 @@
import sys
from pathlib import Path
from typing import Optional
@ -21,15 +20,10 @@ __all__ = ["train_command"]
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int)
@click.option("--val-workers", type=int)
@click.option("--num-epochs", type=int)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--seed", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command(
train_dataset: Path,
val_dataset: Optional[Path] = None,
@ -40,11 +34,11 @@ def train_command(
targets_config: Optional[Path] = None,
config_field: Optional[str] = None,
seed: Optional[int] = None,
num_epochs: Optional[int] = None,
train_workers: int = 0,
val_workers: int = 0,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import (
@ -54,14 +48,6 @@ def train_command(
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating training process...")
logger.info("Loading configuration...")
@ -99,7 +85,10 @@ def train_command(
if model_path is None:
api = BatDetect2API.from_config(conf)
else:
api = BatDetect2API.from_checkpoint(model_path)
api = BatDetect2API.from_checkpoint(
model_path,
config=conf if config is not None else None,
)
return api.train(
train_annotations=train_annotations,
@ -108,6 +97,7 @@ def train_command(
val_workers=val_workers,
checkpoint_dir=ckpt_dir,
log_dir=log_dir,
num_epochs=num_epochs,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,

View File

@ -1,8 +1,9 @@
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.configs import BaseConfig, load_config, merge_configs
from batdetect2.core.registries import Registry
__all__ = [
"BaseConfig",
"load_config",
"Registry",
"merge_configs",
]

View File

@ -11,12 +11,14 @@ configuration sections.
from typing import Any, Optional, Type, TypeVar
import yaml
from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike
__all__ = [
"BaseConfig",
"load_config",
"merge_configs",
]
@ -178,3 +180,19 @@ def load_config(
config = get_object_field(config, field)
return schema.model_validate(config or {})
# Merge policy shared by every configuration merge:
# - two dicts are merged recursively, key by key;
# - any other pair of same-typed values (scalars, lists, ...) is resolved by
#   taking the value from the second operand;
# - on a type mismatch the second operand's value also wins.
# Without the explicit dict strategy the "override" fallback would apply to
# the top-level dicts themselves and the merge would just return the second
# operand unchanged.
default_merger = Merger(
    [(dict, ["merge"])],
    ["override"],
    ["override"],
)


def merge_configs(config1: T, config2: T) -> T:
    """Merge two configuration objects.

    ``config1`` provides the base values and ``config2`` provides the
    overrides. Only the fields that were explicitly set on ``config2`` take
    part in the merge, so fields left at their defaults in ``config2`` do
    not clobber values present in ``config1``. Nested mappings are merged
    recursively; every other value (including lists) is replaced wholesale.

    Parameters
    ----------
    config1
        Base configuration. Its type determines the type of the result.
    config2
        Configuration whose explicitly set fields override ``config1``.

    Returns
    -------
    T
        A new, validated instance of ``type(config1)``. Neither input is
        modified.

    Notes
    -----
    The merged dictionary is re-validated with ``model_validate``, so a
    combination of fields that is invalid for the schema raises the usual
    pydantic validation error.
    """
    model = type(config1)
    # ``merge`` mutates its first argument in place; ``base`` is a fresh dump
    # so the original ``config1`` object is left untouched.
    base = config1.model_dump()
    # Dump only explicitly set fields, otherwise every default of ``config2``
    # would be treated as an override and the result would equal ``config2``.
    overrides = config2.model_dump(exclude_unset=True)
    merged = default_merger.merge(base, overrides)
    return model.model_validate(merged)

View File

@ -47,6 +47,7 @@ def train(
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
):
@ -107,6 +108,7 @@ def train(
targets=targets,
),
checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
@ -128,6 +130,7 @@ def build_trainer(
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
num_epochs: Optional[int] = None,
) -> Trainer:
trainer_conf = config.train.trainer
logger.opt(lazy=True).debug(
@ -149,8 +152,13 @@ def build_trainer(
)
)
train_config = trainer_conf.model_dump(exclude_none=True)
if num_epochs is not None:
train_config["max_epochs"] = num_epochs
return Trainer(
**trainer_conf.model_dump(exclude_none=True),
**train_config,
logger=train_logger,
callbacks=[
build_checkpoint_callback(