mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 23:30:21 +02:00
Polish evaluate and train CLI
This commit is contained in:
parent
f9056eb19a
commit
bf5b88016a
@ -8,7 +8,6 @@ from soundevent.audio.files import get_audio_files
|
|||||||
|
|
||||||
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.core import merge_configs
|
|
||||||
from batdetect2.data import Dataset, load_dataset_from_config
|
from batdetect2.data import Dataset, load_dataset_from_config
|
||||||
from batdetect2.evaluate import (
|
from batdetect2.evaluate import (
|
||||||
DEFAULT_EVAL_DIR,
|
DEFAULT_EVAL_DIR,
|
||||||
@ -47,7 +46,7 @@ from batdetect2.postprocess import (
|
|||||||
build_postprocessor,
|
build_postprocessor,
|
||||||
)
|
)
|
||||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
from batdetect2.targets import TargetProtocol, build_targets
|
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
DEFAULT_CHECKPOINT_DIR,
|
DEFAULT_CHECKPOINT_DIR,
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
@ -110,6 +109,7 @@ class BatDetect2API:
|
|||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
|
model_config: ModelConfig | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
):
|
):
|
||||||
@ -118,7 +118,7 @@ class BatDetect2API:
|
|||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
model_config=self.model_config,
|
model_config=model_config or self.model_config,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
@ -149,6 +149,7 @@ class BatDetect2API:
|
|||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
|
model_config: ModelConfig | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
@ -161,7 +162,7 @@ class BatDetect2API:
|
|||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
targets=self.targets,
|
targets=self.targets,
|
||||||
model_config=self.model_config,
|
model_config=model_config or self.model_config,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
@ -499,76 +500,77 @@ class BatDetect2API:
|
|||||||
def from_checkpoint(
|
def from_checkpoint(
|
||||||
cls,
|
cls,
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
config: BatDetect2Config | None = None,
|
targets_config: TargetConfig | None = None,
|
||||||
targets: TargetProtocol | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
|
train_config: TrainingConfig | None = None,
|
||||||
|
evaluation_config: EvaluationConfig | None = None,
|
||||||
|
inference_config: InferenceConfig | None = None,
|
||||||
|
outputs_config: OutputsConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
from batdetect2.audio import AudioConfig
|
|
||||||
|
|
||||||
model, model_config = load_model_from_checkpoint(path)
|
model, model_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
# Reconstruct a full BatDetect2Config from the checkpoint's
|
audio_config = audio_config or AudioConfig(
|
||||||
# ModelConfig, then overlay any caller-supplied overrides.
|
samplerate=model_config.samplerate,
|
||||||
base = BatDetect2Config(
|
|
||||||
model=model_config,
|
|
||||||
audio=AudioConfig(samplerate=model_config.samplerate),
|
|
||||||
)
|
)
|
||||||
config = merge_configs(base, config) if config else base
|
train_config = train_config or TrainingConfig()
|
||||||
|
evaluation_config = evaluation_config or EvaluationConfig()
|
||||||
|
inference_config = inference_config or InferenceConfig()
|
||||||
|
outputs_config = outputs_config or OutputsConfig()
|
||||||
|
|
||||||
if targets is None:
|
if (
|
||||||
targets = build_targets(config=config.model.targets)
|
targets_config is not None
|
||||||
else:
|
and targets_config != model_config.targets
|
||||||
target_config = getattr(targets, "config", None)
|
):
|
||||||
if target_config is not None:
|
targets = build_targets(config=targets_config)
|
||||||
config.model.targets = target_config
|
model = build_model_with_new_targets(
|
||||||
|
model=model,
|
||||||
|
targets=targets,
|
||||||
|
)
|
||||||
|
model_config = model_config.model_copy(
|
||||||
|
update={"targets": targets_config}
|
||||||
|
)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
targets = build_targets(config=model_config.targets)
|
||||||
|
|
||||||
|
audio_loader = build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(
|
preprocessor = build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=config.model.preprocess,
|
config=model_config.preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
preprocessor,
|
preprocessor,
|
||||||
config=config.model.postprocess,
|
config=model_config.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
formatter = build_output_formatter(
|
formatter = build_output_formatter(
|
||||||
targets,
|
targets,
|
||||||
config=config.outputs.format,
|
config=outputs_config.format,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_transform = build_output_transform(
|
output_transform = build_output_transform(
|
||||||
config=config.outputs.transform,
|
config=outputs_config.transform,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(
|
evaluator = build_evaluator(
|
||||||
config=config.evaluation,
|
config=evaluation_config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
transform=output_transform,
|
transform=output_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
targets_changed = targets is not None or (
|
|
||||||
config.model.targets.model_dump(mode="json")
|
|
||||||
!= model_config.targets.model_dump(mode="json")
|
|
||||||
)
|
|
||||||
if targets_changed:
|
|
||||||
model = build_model_with_new_targets(
|
|
||||||
model=model,
|
|
||||||
targets=targets,
|
|
||||||
)
|
|
||||||
|
|
||||||
model.preprocessor = preprocessor
|
model.preprocessor = preprocessor
|
||||||
model.postprocessor = postprocessor
|
model.postprocessor = postprocessor
|
||||||
model.targets = targets
|
model.targets = targets
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
model_config=config.model,
|
model_config=model_config,
|
||||||
audio_config=config.audio,
|
audio_config=audio_config,
|
||||||
train_config=config.train,
|
train_config=train_config,
|
||||||
evaluation_config=config.evaluation,
|
evaluation_config=evaluation_config,
|
||||||
inference_config=config.inference,
|
inference_config=inference_config,
|
||||||
outputs_config=config.outputs,
|
outputs_config=outputs_config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
|
|||||||
@ -12,9 +12,13 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
|||||||
|
|
||||||
|
|
||||||
@cli.command(name="evaluate")
|
@cli.command(name="evaluate")
|
||||||
@click.argument("model-path", type=click.Path(exists=True))
|
@click.argument("model_path", type=click.Path(exists=True))
|
||||||
@click.argument("test_dataset", type=click.Path(exists=True))
|
@click.argument("test_dataset", type=click.Path(exists=True))
|
||||||
@click.option("--config", "config_path", type=click.Path())
|
@click.option("--targets", "targets_config", type=click.Path(exists=True))
|
||||||
|
@click.option("--audio-config", type=click.Path(exists=True))
|
||||||
|
@click.option("--evaluation-config", type=click.Path(exists=True))
|
||||||
|
@click.option("--inference-config", type=click.Path(exists=True))
|
||||||
|
@click.option("--outputs-config", type=click.Path(exists=True))
|
||||||
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
|
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
|
||||||
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
|
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
|
||||||
@click.option("--experiment-name", type=str)
|
@click.option("--experiment-name", type=str)
|
||||||
@ -24,15 +28,23 @@ def evaluate_command(
|
|||||||
model_path: Path,
|
model_path: Path,
|
||||||
test_dataset: Path,
|
test_dataset: Path,
|
||||||
base_dir: Path,
|
base_dir: Path,
|
||||||
config_path: Path | None,
|
targets_config: Path | None,
|
||||||
|
audio_config: Path | None,
|
||||||
|
evaluation_config: Path | None,
|
||||||
|
inference_config: Path | None,
|
||||||
|
outputs_config: Path | None,
|
||||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||||
num_workers: int = 0,
|
num_workers: int = 0,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import load_full_config
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.evaluate import load_evaluation_config
|
||||||
|
from batdetect2.inference import InferenceConfig
|
||||||
|
from batdetect2.outputs import OutputsConfig
|
||||||
|
from batdetect2.targets import load_target_config
|
||||||
|
|
||||||
logger.info("Initiating evaluation process...")
|
logger.info("Initiating evaluation process...")
|
||||||
|
|
||||||
@ -46,11 +58,38 @@ def evaluate_command(
|
|||||||
num_annotations=len(test_annotations),
|
num_annotations=len(test_annotations),
|
||||||
)
|
)
|
||||||
|
|
||||||
config = None
|
target_conf = (
|
||||||
if config_path is not None:
|
load_target_config(targets_config)
|
||||||
config = load_full_config(config_path)
|
if targets_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
audio_conf = (
|
||||||
|
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||||
|
)
|
||||||
|
eval_conf = (
|
||||||
|
load_evaluation_config(evaluation_config)
|
||||||
|
if evaluation_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
inference_conf = (
|
||||||
|
InferenceConfig.load(inference_config)
|
||||||
|
if inference_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
outputs_conf = (
|
||||||
|
OutputsConfig.load(outputs_config)
|
||||||
|
if outputs_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(model_path, config=config)
|
api = BatDetect2API.from_checkpoint(
|
||||||
|
model_path,
|
||||||
|
targets_config=target_conf,
|
||||||
|
audio_config=audio_conf,
|
||||||
|
evaluation_config=eval_conf,
|
||||||
|
inference_config=inference_conf,
|
||||||
|
outputs_config=outputs_conf,
|
||||||
|
)
|
||||||
|
|
||||||
api.evaluate(
|
api.evaluate(
|
||||||
test_annotations,
|
test_annotations,
|
||||||
|
|||||||
@ -13,10 +13,14 @@ __all__ = ["train_command"]
|
|||||||
@click.option("--val-dataset", type=click.Path(exists=True))
|
@click.option("--val-dataset", type=click.Path(exists=True))
|
||||||
@click.option("--model", "model_path", type=click.Path(exists=True))
|
@click.option("--model", "model_path", type=click.Path(exists=True))
|
||||||
@click.option("--targets", "targets_config", type=click.Path(exists=True))
|
@click.option("--targets", "targets_config", type=click.Path(exists=True))
|
||||||
|
@click.option("--model-config", type=click.Path(exists=True))
|
||||||
|
@click.option("--training-config", type=click.Path(exists=True))
|
||||||
|
@click.option("--audio-config", type=click.Path(exists=True))
|
||||||
|
@click.option("--evaluation-config", type=click.Path(exists=True))
|
||||||
|
@click.option("--inference-config", type=click.Path(exists=True))
|
||||||
|
@click.option("--outputs-config", type=click.Path(exists=True))
|
||||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||||
@click.option("--log-dir", type=click.Path(exists=True))
|
@click.option("--log-dir", type=click.Path(exists=True))
|
||||||
@click.option("--config", type=click.Path(exists=True))
|
|
||||||
@click.option("--config-field", type=str)
|
|
||||||
@click.option("--train-workers", type=int)
|
@click.option("--train-workers", type=int)
|
||||||
@click.option("--val-workers", type=int)
|
@click.option("--val-workers", type=int)
|
||||||
@click.option("--num-epochs", type=int)
|
@click.option("--num-epochs", type=int)
|
||||||
@ -29,9 +33,13 @@ def train_command(
|
|||||||
model_path: Path | None = None,
|
model_path: Path | None = None,
|
||||||
ckpt_dir: Path | None = None,
|
ckpt_dir: Path | None = None,
|
||||||
log_dir: Path | None = None,
|
log_dir: Path | None = None,
|
||||||
config: Path | None = None,
|
|
||||||
targets_config: Path | None = None,
|
targets_config: Path | None = None,
|
||||||
config_field: str | None = None,
|
model_config: Path | None = None,
|
||||||
|
training_config: Path | None = None,
|
||||||
|
audio_config: Path | None = None,
|
||||||
|
evaluation_config: Path | None = None,
|
||||||
|
inference_config: Path | None = None,
|
||||||
|
outputs_config: Path | None = None,
|
||||||
seed: int | None = None,
|
seed: int | None = None,
|
||||||
num_epochs: int | None = None,
|
num_epochs: int | None = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
@ -40,27 +48,56 @@ def train_command(
|
|||||||
run_name: str | None = None,
|
run_name: str | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import (
|
from batdetect2.audio import AudioConfig
|
||||||
BatDetect2Config,
|
from batdetect2.config import BatDetect2Config
|
||||||
load_full_config,
|
|
||||||
)
|
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_from_config
|
||||||
|
from batdetect2.evaluate import load_evaluation_config
|
||||||
|
from batdetect2.inference import InferenceConfig
|
||||||
|
from batdetect2.models import ModelConfig
|
||||||
|
from batdetect2.outputs import OutputsConfig
|
||||||
from batdetect2.targets import load_target_config
|
from batdetect2.targets import load_target_config
|
||||||
|
from batdetect2.train import load_train_config
|
||||||
|
|
||||||
logger.info("Initiating training process...")
|
logger.info("Initiating training process...")
|
||||||
|
|
||||||
logger.info("Loading configuration...")
|
logger.info("Loading configuration...")
|
||||||
conf = (
|
target_conf = (
|
||||||
load_full_config(config, field=config_field)
|
load_target_config(targets_config)
|
||||||
if config is not None
|
if targets_config is not None
|
||||||
else BatDetect2Config()
|
else None
|
||||||
|
)
|
||||||
|
model_conf = (
|
||||||
|
ModelConfig.load(model_config) if model_config is not None else None
|
||||||
|
)
|
||||||
|
train_conf = (
|
||||||
|
load_train_config(training_config)
|
||||||
|
if training_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
audio_conf = (
|
||||||
|
AudioConfig.load(audio_config) if audio_config is not None else None
|
||||||
|
)
|
||||||
|
eval_conf = (
|
||||||
|
load_evaluation_config(evaluation_config)
|
||||||
|
if evaluation_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
inference_conf = (
|
||||||
|
InferenceConfig.load(inference_config)
|
||||||
|
if inference_config is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
outputs_conf = (
|
||||||
|
OutputsConfig.load(outputs_config)
|
||||||
|
if outputs_config is not None
|
||||||
|
else None
|
||||||
)
|
)
|
||||||
|
|
||||||
if targets_config is not None:
|
if target_conf is not None:
|
||||||
logger.info("Loading targets configuration...")
|
logger.info("Loaded targets configuration.")
|
||||||
conf = conf.model_copy(
|
|
||||||
update=dict(targets=load_target_config(targets_config))
|
if model_conf is not None and target_conf is not None:
|
||||||
)
|
model_conf = model_conf.model_copy(update={"targets": target_conf})
|
||||||
|
|
||||||
logger.info("Loading training dataset...")
|
logger.info("Loading training dataset...")
|
||||||
train_annotations = load_dataset_from_config(train_dataset)
|
train_annotations = load_dataset_from_config(train_dataset)
|
||||||
@ -81,12 +118,40 @@ def train_command(
|
|||||||
|
|
||||||
logger.info("Configuration and data loaded. Starting training...")
|
logger.info("Configuration and data loaded. Starting training...")
|
||||||
|
|
||||||
|
if model_path is not None and model_conf is not None:
|
||||||
|
raise click.UsageError(
|
||||||
|
"--model-config cannot be used with --model. "
|
||||||
|
"Checkpoint model configuration is loaded from the checkpoint."
|
||||||
|
)
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
|
conf = BatDetect2Config()
|
||||||
|
if model_conf is not None:
|
||||||
|
conf.model = model_conf
|
||||||
|
elif target_conf is not None:
|
||||||
|
conf.model = conf.model.model_copy(update={"targets": target_conf})
|
||||||
|
|
||||||
|
if train_conf is not None:
|
||||||
|
conf.train = train_conf
|
||||||
|
if audio_conf is not None:
|
||||||
|
conf.audio = audio_conf
|
||||||
|
if eval_conf is not None:
|
||||||
|
conf.evaluation = eval_conf
|
||||||
|
if inference_conf is not None:
|
||||||
|
conf.inference = inference_conf
|
||||||
|
if outputs_conf is not None:
|
||||||
|
conf.outputs = outputs_conf
|
||||||
|
|
||||||
api = BatDetect2API.from_config(conf)
|
api = BatDetect2API.from_config(conf)
|
||||||
else:
|
else:
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
config=conf if config is not None else None,
|
targets_config=target_conf,
|
||||||
|
train_config=train_conf,
|
||||||
|
audio_config=audio_conf,
|
||||||
|
evaluation_config=eval_conf,
|
||||||
|
inference_config=inference_conf,
|
||||||
|
outputs_config=outputs_conf,
|
||||||
)
|
)
|
||||||
|
|
||||||
return api.train(
|
return api.train(
|
||||||
|
|||||||
@ -8,7 +8,6 @@ configuration data from files, with optional support for accessing nested
|
|||||||
configuration sections.
|
configuration sections.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import sys
|
|
||||||
from typing import Any, Type, TypeVar, overload
|
from typing import Any, Type, TypeVar, overload
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
@ -16,17 +15,14 @@ from deepmerge.merger import Merger
|
|||||||
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
from pydantic import BaseModel, ConfigDict, TypeAdapter
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
if sys.version_info < (3, 11):
|
|
||||||
from typing_extensions import Self
|
|
||||||
else:
|
|
||||||
from typing import Self
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseConfig",
|
"BaseConfig",
|
||||||
"load_config",
|
"load_config",
|
||||||
"merge_configs",
|
"merge_configs",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
C = TypeVar("C", bound="BaseConfig")
|
||||||
|
|
||||||
|
|
||||||
class BaseConfig(BaseModel):
|
class BaseConfig(BaseModel):
|
||||||
"""Base class for all configuration models in BatDetect2.
|
"""Base class for all configuration models in BatDetect2.
|
||||||
@ -73,8 +69,8 @@ class BaseConfig(BaseModel):
|
|||||||
return cls.model_validate(yaml.safe_load(yaml_str))
|
return cls.model_validate(yaml.safe_load(yaml_str))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls: Self, path: PathLike, field: str | None = None) -> Self:
|
def load(cls: Type[C], path: PathLike, field: str | None = None) -> C:
|
||||||
return load_config(path, schema=cls, field=field) # type: ignore
|
return load_config(path, schema=cls, field=field)
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|||||||
@ -201,7 +201,10 @@ def test_user_can_load_checkpoint_and_finetune(
|
|||||||
config.train.train_loader.batch_size = 1
|
config.train.train_loader.batch_size = 1
|
||||||
config.train.train_loader.augmentations.enabled = False
|
config.train.train_loader.augmentations.enabled = False
|
||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(checkpoint_path, config=config)
|
api = BatDetect2API.from_checkpoint(
|
||||||
|
checkpoint_path,
|
||||||
|
train_config=config.train,
|
||||||
|
)
|
||||||
finetune_dir = tmp_path / "finetuned"
|
finetune_dir = tmp_path / "finetuned"
|
||||||
|
|
||||||
api.train(
|
api.train(
|
||||||
@ -234,13 +237,13 @@ def test_user_can_load_checkpoint_with_new_targets(
|
|||||||
source_model, _ = load_model_from_checkpoint(checkpoint_path)
|
source_model, _ = load_model_from_checkpoint(checkpoint_path)
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
targets=sample_targets,
|
targets_config=sample_targets.config,
|
||||||
)
|
)
|
||||||
source_detector = cast(Detector, source_model.detector)
|
source_detector = cast(Detector, source_model.detector)
|
||||||
detector = cast(Detector, api.model.detector)
|
detector = cast(Detector, api.model.detector)
|
||||||
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
||||||
|
|
||||||
assert api.targets is sample_targets
|
assert api.targets.config == sample_targets.config
|
||||||
assert detector.num_classes == len(sample_targets.class_names)
|
assert detector.num_classes == len(sample_targets.class_names)
|
||||||
assert (
|
assert (
|
||||||
classifier_head.classifier.out_channels
|
classifier_head.classifier.out_channels
|
||||||
@ -255,6 +258,43 @@ def test_user_can_load_checkpoint_with_new_targets(
|
|||||||
torch.testing.assert_close(target_backbone[key], value)
|
torch.testing.assert_close(target_backbone[key], value)
|
||||||
|
|
||||||
|
|
||||||
|
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> None:
|
||||||
|
"""User story: same targets config does not rebuild prediction heads."""
|
||||||
|
|
||||||
|
module = build_training_module(model_config=BatDetect2Config().model)
|
||||||
|
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||||
|
checkpoint_path = tmp_path / "same_targets.ckpt"
|
||||||
|
trainer.strategy.connect(module)
|
||||||
|
trainer.save_checkpoint(checkpoint_path)
|
||||||
|
|
||||||
|
source_model, source_model_config = load_model_from_checkpoint(
|
||||||
|
checkpoint_path
|
||||||
|
)
|
||||||
|
source_detector = cast(Detector, source_model.detector)
|
||||||
|
|
||||||
|
api = BatDetect2API.from_checkpoint(
|
||||||
|
checkpoint_path,
|
||||||
|
targets_config=source_model_config.targets,
|
||||||
|
)
|
||||||
|
detector = cast(Detector, api.model.detector)
|
||||||
|
|
||||||
|
for key, value in source_detector.classifier_head.state_dict().items():
|
||||||
|
assert key in detector.classifier_head.state_dict()
|
||||||
|
torch.testing.assert_close(
|
||||||
|
detector.classifier_head.state_dict()[key],
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
|
for key, value in source_detector.bbox_head.state_dict().items():
|
||||||
|
assert key in detector.bbox_head.state_dict()
|
||||||
|
torch.testing.assert_close(
|
||||||
|
detector.bbox_head.state_dict()[key],
|
||||||
|
value,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_user_can_finetune_only_heads(
|
def test_user_can_finetune_only_heads(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
example_annotations,
|
example_annotations,
|
||||||
|
|||||||
@ -176,10 +176,10 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
|
|||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(path)
|
api = BatDetect2API.from_checkpoint(path)
|
||||||
|
|
||||||
assert api.config.model.model_dump(
|
assert api.model_config.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == module.model_config.model_dump(mode="json")
|
) == module.model_config.model_dump(mode="json")
|
||||||
assert api.config.audio.samplerate == module.model_config.samplerate
|
assert api.audio_config.samplerate == module.model_config.samplerate
|
||||||
|
|
||||||
|
|
||||||
def test_train_smoke_produces_loadable_checkpoint(
|
def test_train_smoke_produces_loadable_checkpoint(
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user