mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
No commits in common. "6d0a73dda62b2806e8059904f1ae97b83142d49f" and "24dcc47e73b1c578d4957bc60534c80707e1f0e5" have entirely different histories.
6d0a73dda6
...
24dcc47e73
@ -30,7 +30,6 @@ dependencies = [
|
||||
"hydra-core>=1.3.2",
|
||||
"numba>=0.60",
|
||||
"loguru>=0.7.3",
|
||||
"deepmerge>=2.0",
|
||||
]
|
||||
requires-python = ">=3.9,<3.13"
|
||||
readme = "README.md"
|
||||
|
||||
@ -8,7 +8,6 @@ from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.evaluate import DEFAULT_EVAL_DIR, build_evaluator, evaluate
|
||||
from batdetect2.inference import process_file_list, run_batch_inference
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
@ -62,7 +61,6 @@ class BatDetect2API:
|
||||
checkpoint_dir: Optional[Path] = DEFAULT_CHECKPOINT_DIR,
|
||||
log_dir: Optional[Path] = DEFAULT_LOGS_DIR,
|
||||
experiment_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
@ -77,7 +75,6 @@ class BatDetect2API:
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
log_dir=log_dir,
|
||||
num_epochs=num_epochs,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
@ -246,9 +243,7 @@ class BatDetect2API:
|
||||
):
|
||||
model, stored_config = load_model_from_checkpoint(path)
|
||||
|
||||
config = (
|
||||
merge_configs(config, stored_config) if config else stored_config
|
||||
)
|
||||
config = config or stored_config
|
||||
|
||||
targets = build_targets(config=config.targets)
|
||||
|
||||
|
||||
@ -1,7 +1,5 @@
|
||||
"""BatDetect2 command line interface."""
|
||||
|
||||
import sys
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
|
||||
@ -21,28 +19,8 @@ BatDetect2 - Detection and Classification
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def cli(
|
||||
verbose: int = 0,
|
||||
):
|
||||
def cli():
|
||||
"""BatDetect2 - Bat Call Detection and Classification."""
|
||||
click.echo(INFO_STR)
|
||||
|
||||
logger.remove()
|
||||
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
logger.enable("batdetect2")
|
||||
# click.echo(BATDETECT_ASCII_ART)
|
||||
|
||||
@ -35,11 +35,9 @@ def summary(
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
base_dir = base_dir or Path.cwd()
|
||||
|
||||
dataset = load_dataset_from_config(
|
||||
dataset_config,
|
||||
field=field,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
print(f"Number of annotated clips: {len(dataset)}")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -21,6 +22,12 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
@click.option("--experiment-name", type=str)
|
||||
@click.option("--run-name", type=str)
|
||||
@click.option("--workers", "num_workers", type=int)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def evaluate_command(
|
||||
model_path: Path,
|
||||
test_dataset: Path,
|
||||
@ -30,18 +37,27 @@ def evaluate_command(
|
||||
num_workers: Optional[int] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import load_full_config
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
logger.info("Initiating evaluation process...")
|
||||
|
||||
test_annotations = load_dataset_from_config(
|
||||
test_dataset,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
"Loaded {num_annotations} test examples",
|
||||
num_annotations=len(test_annotations),
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -20,10 +21,15 @@ __all__ = ["train_command"]
|
||||
@click.option("--config-field", type=str)
|
||||
@click.option("--train-workers", type=int)
|
||||
@click.option("--val-workers", type=int)
|
||||
@click.option("--num-epochs", type=int)
|
||||
@click.option("--experiment-name", type=str)
|
||||
@click.option("--run-name", type=str)
|
||||
@click.option("--seed", type=int)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def train_command(
|
||||
train_dataset: Path,
|
||||
val_dataset: Optional[Path] = None,
|
||||
@ -34,11 +40,11 @@ def train_command(
|
||||
targets_config: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
verbose: int = 0,
|
||||
):
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import (
|
||||
@ -48,6 +54,14 @@ def train_command(
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.targets import load_target_config
|
||||
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
logger.info("Initiating training process...")
|
||||
|
||||
logger.info("Loading configuration...")
|
||||
@ -85,10 +99,7 @@ def train_command(
|
||||
if model_path is None:
|
||||
api = BatDetect2API.from_config(conf)
|
||||
else:
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
config=conf if config is not None else None,
|
||||
)
|
||||
api = BatDetect2API.from_checkpoint(model_path)
|
||||
|
||||
return api.train(
|
||||
train_annotations=train_annotations,
|
||||
@ -97,7 +108,6 @@ def train_command(
|
||||
val_workers=val_workers,
|
||||
checkpoint_dir=ckpt_dir,
|
||||
log_dir=log_dir,
|
||||
num_epochs=num_epochs,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
seed=seed,
|
||||
|
||||
@ -1,9 +1,8 @@
|
||||
from batdetect2.core.configs import BaseConfig, load_config, merge_configs
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.core.registries import Registry
|
||||
|
||||
__all__ = [
|
||||
"BaseConfig",
|
||||
"load_config",
|
||||
"Registry",
|
||||
"merge_configs",
|
||||
]
|
||||
|
||||
@ -11,14 +11,12 @@ configuration sections.
|
||||
from typing import Any, Optional, Type, TypeVar
|
||||
|
||||
import yaml
|
||||
from deepmerge.merger import Merger
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from soundevent.data import PathLike
|
||||
|
||||
__all__ = [
|
||||
"BaseConfig",
|
||||
"load_config",
|
||||
"merge_configs",
|
||||
]
|
||||
|
||||
|
||||
@ -180,19 +178,3 @@ def load_config(
|
||||
config = get_object_field(config, field)
|
||||
|
||||
return schema.model_validate(config or {})
|
||||
|
||||
|
||||
default_merger = Merger(
|
||||
[],
|
||||
["override"],
|
||||
["override"],
|
||||
)
|
||||
|
||||
|
||||
def merge_configs(config1: T, config2: T) -> T:
|
||||
"""Merge two configuration objects."""
|
||||
model = type(config1)
|
||||
dict1 = config1.model_dump()
|
||||
dict2 = config2.model_dump()
|
||||
merged = default_merger.merge(dict1, dict2)
|
||||
return model.model_validate(merged)
|
||||
|
||||
@ -47,7 +47,6 @@ def train(
|
||||
checkpoint_dir: Optional[Path] = None,
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
run_name: Optional[str] = None,
|
||||
seed: Optional[int] = None,
|
||||
):
|
||||
@ -108,7 +107,6 @@ def train(
|
||||
targets=targets,
|
||||
),
|
||||
checkpoint_dir=checkpoint_dir,
|
||||
num_epochs=num_epochs,
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
@ -130,7 +128,6 @@ def build_trainer(
|
||||
log_dir: Optional[Path] = None,
|
||||
experiment_name: Optional[str] = None,
|
||||
run_name: Optional[str] = None,
|
||||
num_epochs: Optional[int] = None,
|
||||
) -> Trainer:
|
||||
trainer_conf = config.train.trainer
|
||||
logger.opt(lazy=True).debug(
|
||||
@ -152,13 +149,8 @@ def build_trainer(
|
||||
)
|
||||
)
|
||||
|
||||
train_config = trainer_conf.model_dump(exclude_none=True)
|
||||
|
||||
if num_epochs is not None:
|
||||
train_config["max_epochs"] = num_epochs
|
||||
|
||||
return Trainer(
|
||||
**train_config,
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
logger=train_logger,
|
||||
callbacks=[
|
||||
build_checkpoint_callback(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user