mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Compare commits
3 Commits
24dcc47e73
...
6d0a73dda6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d0a73dda6 | ||
|
|
5736421023 | ||
|
|
3913d2d350 |
@ -30,6 +30,7 @@ dependencies = [
|
||||
"hydra-core>=1.3.2",
|
||||
"numba>=0.60",
|
||||
"loguru>=0.7.3",
|
||||
"deepmerge>=2.0",
|
||||
]
|
||||
requires-python = ">=3.9,<3.13"
|
||||
readme = "README.md"
|
||||
|
||||
@ -8,6 +8,7 @@ from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
@ -61,6 +62,7 @@ class BatDetect2API:
|
||||
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
||||
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
@ -75,6 +77,7 @@ class BatDetect2API:
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
num_epochs=num_epochs,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
@ -243,7 +246,9 @@ class BatDetect2API:
|
||||
):
|
||||
model, stored_config = load_model_from_checkpoint(path)
|
||||
|
||||
config = config or stored_config
|
||||
config = (
|
||||
merge_configs(config, stored_config) if config else stored_config
|
||||
)
|
||||
|
||||
targets = build_targets(config=config.targets)
|
||||
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
"""BatDetect2 command line interface."""
|
||||
|
||||
import sys
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
|
||||
@ -19,8 +21,28 @@ BatDetect2 - Detection and Classification
|
||||
|
||||
|
||||
@click.group()
|
||||
def cli():
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def cli(
|
||||
verbose: int = 0,
|
||||
):
|
||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||
click.echo(INFO_STR)
|
||||
|
||||
logger.remove()
|
||||
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
logger.enable("batdetect2")
|
||||
# click.echo(BATDETECT_ASCII_ART)
|
||||
|
||||
@ -35,9 +35,11 @@ def summary(
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
base_dir = base_dir or Path.cwd()
|
||||
|
||||
dataset = load_dataset_from_config(
|
||||
dataset_config,
|
||||
field=field,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
print(f"Number of annotated clips: {len(dataset)}")
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -22,12 +21,6 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
@click.option("--experiment-name", type=str)
|
||||
@click.option("--run-name", type=str)
|
||||
@click.option("--workers", "num_workers", type=int)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def evaluate_command(
|
||||
model_path: Path,
|
||||
test_dataset: Path,
|
||||
@ -37,27 +30,18 @@ def evaluate_command(
|
||||
num_workers: Optional[int] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import load_full_config
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
logger.info("Initiating evaluation process...")
|
||||
|
||||
test_annotations = load_dataset_from_config(
|
||||
test_dataset,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Loaded {num_annotations} test examples",
|
||||
num_annotations=len(test_annotations),
|
||||
|
||||
@ -1,4 +1,3 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -21,15 +20,10 @@ __all__ = ["train_command"]
|
||||
@click.option("--config-field", type=str)
|
||||
@click.option("--train-workers", type=int)
|
||||
@click.option("--val-workers", type=int)
|
||||
@click.option("--num-epochs", type=int)
|
||||
@click.option("--experiment-name", type=str)
|
||||
@click.option("--run-name", type=str)
|
||||
@click.option("--seed", type=int)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def train_command(
|
||||
train_dataset: Path,
|
||||
val_dataset: Optional[Path] = None,
|
||||
@ -40,11 +34,11 @@ def train_command(
|
||||
targets_config: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import (
|
||||
@ -54,14 +48,6 @@ def train_command(
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.targets import load_target_config
|
||||
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
logger.info("Initiating training process...")
|
||||
|
||||
logger.info("Loading configuration...")
|
||||
@ -99,7 +85,10 @@ def train_command(
|
||||
if model_path is None:
|
||||
api = BatDetect2API.from_config(conf)
|
||||
else:
|
||||
api = BatDetect2API.from_checkpoint(model_path)
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
config=conf if config is not None else None,
|
||||
)
|
||||
|
||||
return api.train(
|
||||
train_annotations=train_annotations,
|
||||
@ -108,6 +97,7 @@ def train_command(
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=ckpt_dir,
|
||||
log_dir=log_dir,
|
||||
num_epochs=num_epochs,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.configs import BaseConfig, load_config, merge_configs
|
||||
from batdetect2.core.registries import Registry
|
||||
|
||||
__all__ = [
|
||||
"BaseConfig",
|
||||
"load_config",
|
||||
"Registry",
|
||||
"merge_configs",
|
||||
]
|
||||
|
||||
@ -11,12 +11,14 @@ configuration sections.
|
||||
from typing import Any, Optional, Type, TypeVar
|
||||
|
||||
import yaml
|
||||
from deepmerge.merger import Merger
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from soundevent.data import PathLike
|
||||
|
||||
__all__ = [
|
||||
"BaseConfig",
|
||||
"load_config",
|
||||
"merge_configs",
|
||||
]
|
||||
|
||||
|
||||
@ -178,3 +180,19 @@ def load_config(
|
||||
config = get_object_field(config, field)
|
||||
|
||||
return schema.model_validate(config or {})
|
||||
|
||||
|
||||
default_merger = Merger(
|
||||
[],
|
||||
["override"],
|
||||
["override"],
|
||||
)
|
||||
|
||||
|
||||
def merge_configs(config1: T, config2: T) -> T:
|
||||
"""Merge two configuration objects."""
|
||||
model = type(config1)
|
||||
dict1 = config1.model_dump()
|
||||
dict2 = config2.model_dump()
|
||||
merged = default_merger.merge(dict1, dict2)
|
||||
return model.model_validate(merged)
|
||||
|
||||
@ -47,6 +47,7 @@ def train(
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
@ -107,6 +108,7 @@ def train(
|
||||
targets=targets,
|
||||
),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
num_epochs=num_epochs,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
@ -128,6 +130,7 @@ def build_trainer(
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
) -> Trainer:
|
||||
trainer_conf = config.train.trainer
|
||||
logger.opt(lazy=True).debug(
|
||||
@ -149,8 +152,13 @@ def build_trainer(
|
||||
)
|
||||
)
|
||||
|
||||
train_config = trainer_conf.model_dump(exclude_none=True)
|
||||
|
||||
if num_epochs is not None:
|
||||
train_config["max_epochs"] = num_epochs
|
||||
|
||||
return Trainer(
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
**train_config,
|
||||
logger=train_logger,
|
||||
callbacks=[
|
||||
build_checkpoint_callback(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user