mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Polish evaluate and train CLI
This commit is contained in:
parent
f9056eb19a
commit
bf5b88016a
@ -8,7 +8,6 @@ from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.core import merge_configs
|
||||
from batdetect2.data import Dataset, load_dataset_from_config
|
||||
from batdetect2.evaluate import (
|
||||
DEFAULT_EVAL_DIR,
|
||||
@ -47,7 +46,7 @@ from batdetect2.postprocess import (
|
||||
build_postprocessor,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import TargetProtocol, build_targets
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
TrainingConfig,
|
||||
@ -110,6 +109,7 @@ class BatDetect2API:
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
):
|
||||
@ -118,7 +118,7 @@ class BatDetect2API:
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
model_config=self.model_config,
|
||||
model_config=model_config or self.model_config,
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
train_workers=train_workers,
|
||||
@ -149,6 +149,7 @@ class BatDetect2API:
|
||||
num_epochs: int | None = None,
|
||||
run_name: str | None = None,
|
||||
seed: int | None = None,
|
||||
model_config: ModelConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
@ -161,7 +162,7 @@ class BatDetect2API:
|
||||
val_annotations=val_annotations,
|
||||
model=self.model,
|
||||
targets=self.targets,
|
||||
model_config=self.model_config,
|
||||
model_config=model_config or self.model_config,
|
||||
preprocessor=self.preprocessor,
|
||||
audio_loader=self.audio_loader,
|
||||
train_workers=train_workers,
|
||||
@ -499,76 +500,77 @@ class BatDetect2API:
|
||||
def from_checkpoint(
|
||||
cls,
|
||||
path: data.PathLike,
|
||||
config: BatDetect2Config | None = None,
|
||||
targets: TargetProtocol | None = None,
|
||||
targets_config: TargetConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
from batdetect2.audio import AudioConfig
|
||||
|
||||
model, model_config = load_model_from_checkpoint(path)
|
||||
|
||||
# Reconstruct a full BatDetect2Config from the checkpoint's
|
||||
# ModelConfig, then overlay any caller-supplied overrides.
|
||||
base = BatDetect2Config(
|
||||
model=model_config,
|
||||
audio=AudioConfig(samplerate=model_config.samplerate),
|
||||
audio_config = audio_config or AudioConfig(
|
||||
samplerate=model_config.samplerate,
|
||||
)
|
||||
config = merge_configs(base, config) if config else base
|
||||
train_config = train_config or TrainingConfig()
|
||||
evaluation_config = evaluation_config or EvaluationConfig()
|
||||
inference_config = inference_config or InferenceConfig()
|
||||
outputs_config = outputs_config or OutputsConfig()
|
||||
|
||||
if targets is None:
|
||||
targets = build_targets(config=config.model.targets)
|
||||
else:
|
||||
target_config = getattr(targets, "config", None)
|
||||
if target_config is not None:
|
||||
config.model.targets = target_config
|
||||
if (
|
||||
targets_config is not None
|
||||
and targets_config != model_config.targets
|
||||
):
|
||||
targets = build_targets(config=targets_config)
|
||||
model = build_model_with_new_targets(
|
||||
model=model,
|
||||
targets=targets,
|
||||
)
|
||||
model_config = model_config.model_copy(
|
||||
update={"targets": targets_config}
|
||||
)
|
||||
|
||||
audio_loader = build_audio_loader(config=config.audio)
|
||||
targets = build_targets(config=model_config.targets)
|
||||
|
||||
audio_loader = build_audio_loader(config=audio_config)
|
||||
|
||||
preprocessor = build_preprocessor(
|
||||
input_samplerate=audio_loader.samplerate,
|
||||
config=config.model.preprocess,
|
||||
config=model_config.preprocess,
|
||||
)
|
||||
|
||||
postprocessor = build_postprocessor(
|
||||
preprocessor,
|
||||
config=config.model.postprocess,
|
||||
config=model_config.postprocess,
|
||||
)
|
||||
|
||||
formatter = build_output_formatter(
|
||||
targets,
|
||||
config=config.outputs.format,
|
||||
config=outputs_config.format,
|
||||
)
|
||||
|
||||
output_transform = build_output_transform(
|
||||
config=config.outputs.transform,
|
||||
config=outputs_config.transform,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
evaluator = build_evaluator(
|
||||
config=config.evaluation,
|
||||
config=evaluation_config,
|
||||
targets=targets,
|
||||
transform=output_transform,
|
||||
)
|
||||
|
||||
targets_changed = targets is not None or (
|
||||
config.model.targets.model_dump(mode="json")
|
||||
!= model_config.targets.model_dump(mode="json")
|
||||
)
|
||||
if targets_changed:
|
||||
model = build_model_with_new_targets(
|
||||
model=model,
|
||||
targets=targets,
|
||||
)
|
||||
|
||||
model.preprocessor = preprocessor
|
||||
model.postprocessor = postprocessor
|
||||
model.targets = targets
|
||||
|
||||
return cls(
|
||||
model_config=config.model,
|
||||
audio_config=config.audio,
|
||||
train_config=config.train,
|
||||
evaluation_config=config.evaluation,
|
||||
inference_config=config.inference,
|
||||
outputs_config=config.outputs,
|
||||
model_config=model_config,
|
||||
audio_config=audio_config,
|
||||
train_config=train_config,
|
||||
evaluation_config=evaluation_config,
|
||||
inference_config=inference_config,
|
||||
outputs_config=outputs_config,
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@ -12,9 +12,13 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
|
||||
|
||||
@cli.command(name="evaluate")
|
||||
@click.argument("model-path", type=click.Path(exists=True))
|
||||
@click.argument("model_path", type=click.Path(exists=True))
|
||||
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||
@click.option("--config", "config_path", type=click.Path())
|
||||
@click.option("--targets", "targets_config", type=click.Path(exists=True))
|
||||
@click.option("--audio-config", type=click.Path(exists=True))
|
||||
@click.option("--evaluation-config", type=click.Path(exists=True))
|
||||
@click.option("--inference-config", type=click.Path(exists=True))
|
||||
@click.option("--outputs-config", type=click.Path(exists=True))
|
||||
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
|
||||
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
|
||||
@click.option("--experiment-name", type=str)
|
||||
@ -24,15 +28,23 @@ def evaluate_command(
|
||||
model_path: Path,
|
||||
test_dataset: Path,
|
||||
base_dir: Path,
|
||||
config_path: Path | None,
|
||||
targets_config: Path | None,
|
||||
audio_config: Path | None,
|
||||
evaluation_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
outputs_config: Path | None,
|
||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||
num_workers: int = 0,
|
||||
experiment_name: str | None = None,
|
||||
run_name: str | None = None,
|
||||
):
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import load_full_config
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.evaluate import load_evaluation_config
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import load_target_config
|
||||
|
||||
logger.info("Initiating evaluation process...")
|
||||
|
||||
@ -46,11 +58,38 @@ def evaluate_command(
|
||||
num_annotations=len(test_annotations),
|
||||
)
|
||||
|
||||
config = None
|
||||
if config_path is not None:
|
||||
config = load_full_config(config_path)
|
||||
target_conf = (
|
||||
load_target_config(targets_config)
|
||||
if targets_config is not None
|
||||
else None
|
||||
)
|
||||
audio_conf = (
|
||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||
)
|
||||
eval_conf = (
|
||||
load_evaluation_config(evaluation_config)
|
||||
if evaluation_config is not None
|
||||
else None
|
||||
)
|
||||
inference_conf = (
|
||||
InferenceConfig.load(inference_config)
|
||||
if inference_config is not None
|
||||
else None
|
||||
)
|
||||
outputs_conf = (
|
||||
OutputsConfig.load(outputs_config)
|
||||
if outputs_config is not None
|
||||
else None
|
||||
)
|
||||
|
||||
api = BatDetect2API.from_checkpoint(model_path, config=config)
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
targets_config=target_conf,
|
||||
audio_config=audio_conf,
|
||||
evaluation_config=eval_conf,
|
||||
inference_config=inference_conf,
|
||||
outputs_config=outputs_conf,
|
||||
)
|
||||
|
||||
api.evaluate(
|
||||
test_annotations,
|
||||
|
||||
@ -13,10 +13,14 @@ __all__ = ["train_command"]
|
||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||
@click.option("--model", "model_path", type=click.Path(exists=True))
|
||||
@click.option("--targets", "targets_config", type=click.Path(exists=True))
|
||||
@click.option("--model-config", type=click.Path(exists=True))
|
||||
@click.option("--training-config", type=click.Path(exists=True))
|
||||
@click.option("--audio-config", type=click.Path(exists=True))
|
||||
@click.option("--evaluation-config", type=click.Path(exists=True))
|
||||
@click.option("--inference-config", type=click.Path(exists=True))
|
||||
@click.option("--outputs-config", type=click.Path(exists=True))
|
||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||
@click.option("--log-dir", type=click.Path(exists=True))
|
||||
@click.option("--config", type=click.Path(exists=True))
|
||||
@click.option("--config-field", type=str)
|
||||
@click.option("--train-workers", type=int)
|
||||
@click.option("--val-workers", type=int)
|
||||
@click.option("--num-epochs", type=int)
|
||||
@ -29,9 +33,13 @@ def train_command(
|
||||
model_path: Path | None = None,
|
||||
ckpt_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
config: Path | None = None,
|
||||
targets_config: Path | None = None,
|
||||
config_field: str | None = None,
|
||||
model_config: Path | None = None,
|
||||
training_config: Path | None = None,
|
||||
audio_config: Path | None = None,
|
||||
evaluation_config: Path | None = None,
|
||||
inference_config: Path | None = None,
|
||||
outputs_config: Path | None = None,
|
||||
seed: int | None = None,
|
||||
num_epochs: int | None = None,
|
||||
train_workers: int = 0,
|
||||
@ -40,27 +48,56 @@ def train_command(
|
||||
run_name: str | None = None,
|
||||
):
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.config import (
|
||||
BatDetect2Config,
|
||||
load_full_config,
|
||||
)
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.evaluate import load_evaluation_config
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import load_target_config
|
||||
from batdetect2.train import load_train_config
|
||||
|
||||
logger.info("Initiating training process...")
|
||||
|
||||
logger.info("Loading configuration...")
|
||||
conf = (
|
||||
load_full_config(config, field=config_field)
|
||||
if config is not None
|
||||
else BatDetect2Config()
|
||||
target_conf = (
|
||||
load_target_config(targets_config)
|
||||
if targets_config is not None
|
||||
else None
|
||||
)
|
||||
model_conf = (
|
||||
ModelConfig.load(model_config) if model_config is not None else None
|
||||
)
|
||||
train_conf = (
|
||||
load_train_config(training_config)
|
||||
if training_config is not None
|
||||
else None
|
||||
)
|
||||
audio_conf = (
|
||||
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||
)
|
||||
eval_conf = (
|
||||
load_evaluation_config(evaluation_config)
|
||||
if evaluation_config is not None
|
||||
else None
|
||||
)
|
||||
inference_conf = (
|
||||
InferenceConfig.load(inference_config)
|
||||
if inference_config is not None
|
||||
else None
|
||||
)
|
||||
outputs_conf = (
|
||||
OutputsConfig.load(outputs_config)
|
||||
if outputs_config is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if targets_config is not None:
|
||||
logger.info("Loading targets configuration...")
|
||||
conf = conf.model_copy(
|
||||
update=dict(targets=load_target_config(targets_config))
|
||||
)
|
||||
if target_conf is not None:
|
||||
logger.info("Loaded targets configuration.")
|
||||
|
||||
if model_conf is not None and target_conf is not None:
|
||||
model_conf = model_conf.model_copy(update={"targets": target_conf})
|
||||
|
||||
logger.info("Loading training dataset...")
|
||||
train_annotations = load_dataset_from_config(train_dataset)
|
||||
@ -81,12 +118,40 @@ def train_command(
|
||||
|
||||
logger.info("Configuration and data loaded. Starting training...")
|
||||
|
||||
if model_path is not None and model_conf is not None:
|
||||
raise click.UsageError(
|
||||
"--model-config cannot be used with --model. "
|
||||
"Checkpoint model configuration is loaded from the checkpoint."
|
||||
)
|
||||
|
||||
if model_path is None:
|
||||
conf = BatDetect2Config()
|
||||
if model_conf is not None:
|
||||
conf.model = model_conf
|
||||
elif target_conf is not None:
|
||||
conf.model = conf.model.model_copy(update={"targets": target_conf})
|
||||
|
||||
if train_conf is not None:
|
||||
conf.train = train_conf
|
||||
if audio_conf is not None:
|
||||
conf.audio = audio_conf
|
||||
if eval_conf is not None:
|
||||
conf.evaluation = eval_conf
|
||||
if inference_conf is not None:
|
||||
conf.inference = inference_conf
|
||||
if outputs_conf is not None:
|
||||
conf.outputs = outputs_conf
|
||||
|
||||
api = BatDetect2API.from_config(conf)
|
||||
else:
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
config=conf if config is not None else None,
|
||||
targets_config=target_conf,
|
||||
train_config=train_conf,
|
||||
audio_config=audio_conf,
|
||||
evaluation_config=eval_conf,
|
||||
inference_config=inference_conf,
|
||||
outputs_config=outputs_conf,
|
||||
)
|
||||
|
||||
return api.train(
|
||||
|
||||
@ -8,7 +8,6 @@ configuration data from files, with optional support for accessing nested
|
||||
configuration sections.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from typing import Any, Type, TypeVar, overload
|
||||
|
||||
import yaml
|
||||
@ -16,17 +15,14 @@ from deepmerge.merger import Merger
|
||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||
from soundevent.data import PathLike
|
||||
|
||||
if sys.version_info < (3, 11):
|
||||
from typing_extensions import Self
|
||||
else:
|
||||
from typing import Self
|
||||
|
||||
__all__ = [
|
||||
"BaseConfig",
|
||||
"load_config",
|
||||
"merge_configs",
|
||||
]
|
||||
|
||||
C = TypeVar("C", bound="BaseConfig")
|
||||
|
||||
|
||||
class BaseConfig(BaseModel):
|
||||
"""Base class for all configuration models in BatDetect2.
|
||||
@ -73,8 +69,8 @@ class BaseConfig(BaseModel):
|
||||
return cls.model_validate(yaml.safe_load(yaml_str))
|
||||
|
||||
@classmethod
|
||||
def load(cls: Self, path: PathLike, field: str | None = None) -> Self:
|
||||
return load_config(path, schema=cls, field=field) # type: ignore
|
||||
def load(cls: Type[C], path: PathLike, field: str | None = None) -> C:
|
||||
return load_config(path, schema=cls, field=field)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@ -201,7 +201,10 @@ def test_user_can_load_checkpoint_and_finetune(
|
||||
config.train.train_loader.batch_size = 1
|
||||
config.train.train_loader.augmentations.enabled = False
|
||||
|
||||
api = BatDetect2API.from_checkpoint(checkpoint_path, config=config)
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
checkpoint_path,
|
||||
train_config=config.train,
|
||||
)
|
||||
finetune_dir = tmp_path / "finetuned"
|
||||
|
||||
api.train(
|
||||
@ -234,13 +237,13 @@ def test_user_can_load_checkpoint_with_new_targets(
|
||||
source_model, _ = load_model_from_checkpoint(checkpoint_path)
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
checkpoint_path,
|
||||
targets=sample_targets,
|
||||
targets_config=sample_targets.config,
|
||||
)
|
||||
source_detector = cast(Detector, source_model.detector)
|
||||
detector = cast(Detector, api.model.detector)
|
||||
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
||||
|
||||
assert api.targets is sample_targets
|
||||
assert api.targets.config == sample_targets.config
|
||||
assert detector.num_classes == len(sample_targets.class_names)
|
||||
assert (
|
||||
classifier_head.classifier.out_channels
|
||||
@ -255,6 +258,43 @@ def test_user_can_load_checkpoint_with_new_targets(
|
||||
torch.testing.assert_close(target_backbone[key], value)
|
||||
|
||||
|
||||
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""User story: same targets config does not rebuild prediction heads."""
|
||||
|
||||
module = build_training_module(model_config=BatDetect2Config().model)
|
||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||
checkpoint_path = tmp_path / "same_targets.ckpt"
|
||||
trainer.strategy.connect(module)
|
||||
trainer.save_checkpoint(checkpoint_path)
|
||||
|
||||
source_model, source_model_config = load_model_from_checkpoint(
|
||||
checkpoint_path
|
||||
)
|
||||
source_detector = cast(Detector, source_model.detector)
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
checkpoint_path,
|
||||
targets_config=source_model_config.targets,
|
||||
)
|
||||
detector = cast(Detector, api.model.detector)
|
||||
|
||||
for key, value in source_detector.classifier_head.state_dict().items():
|
||||
assert key in detector.classifier_head.state_dict()
|
||||
torch.testing.assert_close(
|
||||
detector.classifier_head.state_dict()[key],
|
||||
value,
|
||||
)
|
||||
|
||||
for key, value in source_detector.bbox_head.state_dict().items():
|
||||
assert key in detector.bbox_head.state_dict()
|
||||
torch.testing.assert_close(
|
||||
detector.bbox_head.state_dict()[key],
|
||||
value,
|
||||
)
|
||||
|
||||
|
||||
def test_user_can_finetune_only_heads(
|
||||
tmp_path: Path,
|
||||
example_annotations,
|
||||
|
||||
@ -176,10 +176,10 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
|
||||
|
||||
api = BatDetect2API.from_checkpoint(path)
|
||||
|
||||
assert api.config.model.model_dump(
|
||||
assert api.model_config.model_dump(
|
||||
mode="json"
|
||||
) == module.model_config.model_dump(mode="json")
|
||||
assert api.config.audio.samplerate == module.model_config.samplerate
|
||||
assert api.audio_config.samplerate == module.model_config.samplerate
|
||||
|
||||
|
||||
def test_train_smoke_produces_loadable_checkpoint(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user