Compare commits

..

No commits in common. "6d0a73dda62b2806e8059904f1ae97b83142d49f" and "24dcc47e73b1c578d4957bc60534c80707e1f0e5" have entirely different histories.

9 changed files with 38 additions and 69 deletions

View File

@ -30,7 +30,6 @@ dependencies = [
"hydra-core>=1.3.2",
"numba>=0.60",
"loguru>=0.7.3",
"deepmerge>=2.0",
]
requires-python = ">=3.9,<3.13"
readme = "README.md"

View File

@ -8,7 +8,6 @@ from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.core import merge_configs
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.logging import DEFAULT_LOGS_DIR
@ -62,7 +61,6 @@ class BatDetect2API:
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
):
@ -77,7 +75,6 @@ class BatDetect2API:
val_workers=val_workers,
checkpoint_dir=checkpoint_dir,
log_dir=log_dir,
num_epochs=num_epochs,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,
@ -246,9 +243,7 @@ class BatDetect2API:
):
model, stored_config = load_model_from_checkpoint(path)
config = (
merge_configs(config, stored_config) if config else stored_config
)
config = config or stored_config
targets = build_targets(config=config.targets)

View File

@ -1,7 +1,5 @@
"""BatDetect2 command line interface."""
import sys
import click
from loguru import logger
@ -21,28 +19,8 @@ BatDetect2 - Detection and Classification
@click.group()
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def cli(
verbose: int = 0,
):
def cli():
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.enable("batdetect2")
# click.echo(BATDETECT_ASCII_ART)

View File

@ -35,11 +35,9 @@ def summary(
from batdetect2.data import load_dataset_from_config
base_dir = base_dir or Path.cwd()
dataset = load_dataset_from_config(
dataset_config,
field=field,
base_dir=base_dir,
)
print(f"Number of annotated clips: {len(dataset)}")

View File

@ -1,3 +1,4 @@
import sys
from pathlib import Path
from typing import Optional
@ -21,6 +22,12 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--workers", "num_workers", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def evaluate_command(
model_path: Path,
test_dataset: Path,
@ -30,18 +37,27 @@ def evaluate_command(
num_workers: Optional[int] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import load_full_config
from batdetect2.data import load_dataset_from_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating evaluation process...")
test_annotations = load_dataset_from_config(
test_dataset,
base_dir=base_dir,
)
logger.debug(
"Loaded {num_annotations} test examples",
num_annotations=len(test_annotations),

View File

@ -1,3 +1,4 @@
import sys
from pathlib import Path
from typing import Optional
@ -20,10 +21,15 @@ __all__ = ["train_command"]
@click.option("--config-field", type=str)
@click.option("--train-workers", type=int)
@click.option("--val-workers", type=int)
@click.option("--num-epochs", type=int)
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--seed", type=int)
@click.option(
"-v",
"--verbose",
count=True,
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
)
def train_command(
train_dataset: Path,
val_dataset: Optional[Path] = None,
@ -34,11 +40,11 @@ def train_command(
targets_config: Optional[Path] = None,
config_field: Optional[str] = None,
seed: Optional[int] = None,
num_epochs: Optional[int] = None,
train_workers: int = 0,
val_workers: int = 0,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
verbose: int = 0,
):
from batdetect2.api_v2 import BatDetect2API
from batdetect2.config import (
@ -48,6 +54,14 @@ def train_command(
from batdetect2.data import load_dataset_from_config
from batdetect2.targets import load_target_config
logger.remove()
if verbose == 0:
log_level = "WARNING"
elif verbose == 1:
log_level = "INFO"
else:
log_level = "DEBUG"
logger.add(sys.stderr, level=log_level)
logger.info("Initiating training process...")
logger.info("Loading configuration...")
@ -85,10 +99,7 @@ def train_command(
if model_path is None:
api = BatDetect2API.from_config(conf)
else:
api = BatDetect2API.from_checkpoint(
model_path,
config=conf if config is not None else None,
)
api = BatDetect2API.from_checkpoint(model_path)
return api.train(
train_annotations=train_annotations,
@ -97,7 +108,6 @@ def train_command(
val_workers=val_workers,
checkpoint_dir=ckpt_dir,
log_dir=log_dir,
num_epochs=num_epochs,
experiment_name=experiment_name,
run_name=run_name,
seed=seed,

View File

@ -1,9 +1,8 @@
from batdetect2.core.configs import BaseConfig, load_config, merge_configs
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.core.registries import Registry
__all__ = [
"BaseConfig",
"load_config",
"Registry",
"merge_configs",
]

View File

@ -11,14 +11,12 @@ configuration sections.
from typing import Any, Optional, Type, TypeVar
import yaml
from deepmerge.merger import Merger
from pydantic import BaseModel, ConfigDict
from soundevent.data import PathLike
__all__ = [
"BaseConfig",
"load_config",
"merge_configs",
]
@ -180,19 +178,3 @@ def load_config(
config = get_object_field(config, field)
return schema.model_validate(config or {})
default_merger = Merger(
[],
["override"],
["override"],
)
def merge_configs(config1: T, config2: T) -> T:
"""Merge two configuration objects."""
model = type(config1)
dict1 = config1.model_dump()
dict2 = config2.model_dump()
merged = default_merger.merge(dict1, dict2)
return model.model_validate(merged)

View File

@ -47,7 +47,6 @@ def train(
checkpoint_dir: Optional[Path] = None,
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
num_epochs: Optional[int] = None,
run_name: Optional[str] = None,
seed: Optional[int] = None,
):
@ -108,7 +107,6 @@ def train(
targets=targets,
),
checkpoint_dir=checkpoint_dir,
num_epochs=num_epochs,
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,
@ -130,7 +128,6 @@ def build_trainer(
log_dir: Optional[Path] = None,
experiment_name: Optional[str] = None,
run_name: Optional[str] = None,
num_epochs: Optional[int] = None,
) -> Trainer:
trainer_conf = config.train.trainer
logger.opt(lazy=True).debug(
@ -152,13 +149,8 @@ def build_trainer(
)
)
train_config = trainer_conf.model_dump(exclude_none=True)
if num_epochs is not None:
train_config["max_epochs"] = num_epochs
return Trainer(
**train_config,
**trainer_conf.model_dump(exclude_none=True),
logger=train_logger,
callbacks=[
build_checkpoint_callback(