mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
refactor: decouple model metadata from target configs
This commit is contained in:
parent
e33053614a
commit
57236fc82a
@ -12,7 +12,6 @@ if TYPE_CHECKING:
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from batdetect2.audio import AudioConfig, AudioLoader
|
from batdetect2.audio import AudioConfig, AudioLoader
|
||||||
from batdetect2.config import BatDetect2Config
|
|
||||||
from batdetect2.data import Dataset
|
from batdetect2.data import Dataset
|
||||||
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
@ -483,46 +482,70 @@ class BatDetect2API:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_config(
|
def from_config(
|
||||||
cls,
|
cls,
|
||||||
config: BatDetect2Config,
|
model_config: ModelConfig | None = None,
|
||||||
|
targets_config: TargetConfig | None = None,
|
||||||
|
audio_config: AudioConfig | None = None,
|
||||||
|
train_config: TrainingConfig | None = None,
|
||||||
|
evaluation_config: EvaluationConfig | None = None,
|
||||||
|
inference_config: InferenceConfig | None = None,
|
||||||
|
outputs_config: OutputsConfig | None = None,
|
||||||
|
logging_config: AppLoggingConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||||
from batdetect2.evaluate import build_evaluator
|
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||||
from batdetect2.models import build_model
|
from batdetect2.inference import InferenceConfig
|
||||||
|
from batdetect2.logging import AppLoggingConfig
|
||||||
|
from batdetect2.models import ModelConfig, build_model
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
|
OutputsConfig,
|
||||||
build_output_formatter,
|
build_output_formatter,
|
||||||
build_output_transform,
|
build_output_transform,
|
||||||
)
|
)
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_roi_mapping, build_targets
|
from batdetect2.targets import (
|
||||||
|
TargetConfig,
|
||||||
|
build_roi_mapping,
|
||||||
|
build_targets,
|
||||||
|
)
|
||||||
|
from batdetect2.train import TrainingConfig
|
||||||
|
|
||||||
targets = build_targets(config=config.model.targets)
|
model_config = model_config or ModelConfig()
|
||||||
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
targets_config = targets_config or TargetConfig()
|
||||||
|
audio_config = audio_config or AudioConfig()
|
||||||
|
train_config = train_config or TrainingConfig()
|
||||||
|
evaluation_config = evaluation_config or EvaluationConfig()
|
||||||
|
inference_config = inference_config or InferenceConfig()
|
||||||
|
outputs_config = outputs_config or OutputsConfig()
|
||||||
|
logging_config = logging_config or AppLoggingConfig()
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
targets = build_targets(config=targets_config)
|
||||||
|
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||||
|
|
||||||
|
audio_loader = build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(
|
preprocessor = build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=config.model.preprocess,
|
config=model_config.preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
preprocessor,
|
preprocessor,
|
||||||
config=config.model.postprocess,
|
config=model_config.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
formatter = build_output_formatter(
|
formatter = build_output_formatter(
|
||||||
targets,
|
targets,
|
||||||
config=config.outputs.format,
|
config=outputs_config.format,
|
||||||
)
|
)
|
||||||
output_transform = build_output_transform(
|
output_transform = build_output_transform(
|
||||||
config=config.outputs.transform,
|
config=outputs_config.transform,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(
|
evaluator = build_evaluator(
|
||||||
config=config.evaluation,
|
config=evaluation_config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
transform=output_transform,
|
transform=output_transform,
|
||||||
@ -531,27 +554,27 @@ class BatDetect2API:
|
|||||||
# NOTE: Build separate instances of preprocessor and postprocessor
|
# NOTE: Build separate instances of preprocessor and postprocessor
|
||||||
# to avoid device mismatch errors
|
# to avoid device mismatch errors
|
||||||
model = build_model(
|
model = build_model(
|
||||||
config=config.model,
|
config=model_config,
|
||||||
targets=targets,
|
class_names=targets.class_names,
|
||||||
roi_mapper=roi_mapper,
|
dimension_names=roi_mapper.dimension_names,
|
||||||
preprocessor=build_preprocessor(
|
preprocessor=build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=config.model.preprocess,
|
config=model_config.preprocess,
|
||||||
),
|
),
|
||||||
postprocessor=build_postprocessor(
|
postprocessor=build_postprocessor(
|
||||||
preprocessor,
|
preprocessor,
|
||||||
config=config.model.postprocess,
|
config=model_config.postprocess,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return cls(
|
return cls(
|
||||||
model_config=config.model,
|
model_config=model_config,
|
||||||
audio_config=config.audio,
|
audio_config=audio_config,
|
||||||
train_config=config.train,
|
train_config=train_config,
|
||||||
evaluation_config=config.evaluation,
|
evaluation_config=evaluation_config,
|
||||||
inference_config=config.inference,
|
inference_config=inference_config,
|
||||||
outputs_config=config.outputs,
|
outputs_config=outputs_config,
|
||||||
logging_config=config.logging,
|
logging_config=logging_config,
|
||||||
targets=targets,
|
targets=targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
audio_loader=audio_loader,
|
audio_loader=audio_loader,
|
||||||
@ -579,7 +602,6 @@ class BatDetect2API:
|
|||||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
from batdetect2.models import build_model_with_new_targets
|
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
OutputsConfig,
|
OutputsConfig,
|
||||||
build_output_formatter,
|
build_output_formatter,
|
||||||
@ -587,8 +609,16 @@ class BatDetect2API:
|
|||||||
)
|
)
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_roi_mapping, build_targets
|
from batdetect2.targets import (
|
||||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
build_default_target_config,
|
||||||
|
build_roi_mapping,
|
||||||
|
build_targets,
|
||||||
|
check_target_compatibility,
|
||||||
|
)
|
||||||
|
from batdetect2.train import (
|
||||||
|
TrainingConfig,
|
||||||
|
load_model_from_checkpoint,
|
||||||
|
)
|
||||||
|
|
||||||
model, model_config = load_model_from_checkpoint(path)
|
model, model_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
@ -600,24 +630,24 @@ class BatDetect2API:
|
|||||||
inference_config = inference_config or InferenceConfig()
|
inference_config = inference_config or InferenceConfig()
|
||||||
outputs_config = outputs_config or OutputsConfig()
|
outputs_config = outputs_config or OutputsConfig()
|
||||||
logging_config = logging_config or AppLoggingConfig()
|
logging_config = logging_config or AppLoggingConfig()
|
||||||
|
targets_config = targets_config or build_default_target_config(
|
||||||
|
class_names=model.class_names
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
|
||||||
targets_config is not None
|
|
||||||
and targets_config != model_config.targets
|
|
||||||
):
|
|
||||||
targets = build_targets(config=targets_config)
|
targets = build_targets(config=targets_config)
|
||||||
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
roi_mapper = build_roi_mapping(config=targets_config.roi)
|
||||||
model = build_model_with_new_targets(
|
|
||||||
model=model,
|
if not check_target_compatibility(targets, model.class_names):
|
||||||
targets=targets,
|
raise ValueError(
|
||||||
roi_mapper=roi_mapper,
|
"Provided targets_config is incompatible with the "
|
||||||
)
|
"checkpoint model: missing one or more model classes."
|
||||||
model_config = model_config.model_copy(
|
|
||||||
update={"targets": targets_config}
|
|
||||||
)
|
)
|
||||||
|
|
||||||
targets = build_targets(config=model_config.targets)
|
if model.dimension_names != roi_mapper.dimension_names:
|
||||||
roi_mapper = build_roi_mapping(config=model_config.targets.roi)
|
raise ValueError(
|
||||||
|
"Provided targets_config is incompatible with the "
|
||||||
|
"checkpoint model: mismatched dimension names."
|
||||||
|
)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=audio_config)
|
audio_loader = build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
|
|||||||
@ -12,6 +12,7 @@ from batdetect2.inference.config import InferenceConfig
|
|||||||
from batdetect2.logging import AppLoggingConfig
|
from batdetect2.logging import AppLoggingConfig
|
||||||
from batdetect2.models import ModelConfig
|
from batdetect2.models import ModelConfig
|
||||||
from batdetect2.outputs import OutputsConfig
|
from batdetect2.outputs import OutputsConfig
|
||||||
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
|
||||||
__all__ = ["BatDetect2Config"]
|
__all__ = ["BatDetect2Config"]
|
||||||
@ -25,6 +26,7 @@ class BatDetect2Config(BaseConfig):
|
|||||||
default_factory=get_default_eval_config
|
default_factory=get_default_eval_config
|
||||||
)
|
)
|
||||||
model: ModelConfig = Field(default_factory=ModelConfig)
|
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||||
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||||
|
|||||||
@ -24,8 +24,8 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
|||||||
def run_evaluate(
|
def run_evaluate(
|
||||||
model: Model,
|
model: Model,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol,
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
roi_mapper: ROIMapperProtocol,
|
||||||
audio_loader: AudioLoader | None = None,
|
audio_loader: AudioLoader | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
audio_config: AudioConfig | None = None,
|
audio_config: AudioConfig | None = None,
|
||||||
@ -46,8 +46,6 @@ def run_evaluate(
|
|||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
preprocessor = preprocessor or model.preprocessor
|
preprocessor = preprocessor or model.preprocessor
|
||||||
targets = targets or model.targets
|
|
||||||
roi_mapper = roi_mapper or model.roi_mapper
|
|
||||||
|
|
||||||
loader = build_test_loader(
|
loader = build_test_loader(
|
||||||
test_annotations,
|
test_annotations,
|
||||||
|
|||||||
@ -45,8 +45,16 @@ def run_batch_inference(
|
|||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
|
|
||||||
preprocessor = preprocessor or model.preprocessor
|
preprocessor = preprocessor or model.preprocessor
|
||||||
targets = targets or model.targets
|
|
||||||
roi_mapper = roi_mapper or model.roi_mapper
|
if targets is None:
|
||||||
|
raise ValueError(
|
||||||
|
"targets must be provided when running batch inference."
|
||||||
|
)
|
||||||
|
|
||||||
|
if roi_mapper is None:
|
||||||
|
raise ValueError(
|
||||||
|
"roi_mapper must be provided when running batch inference."
|
||||||
|
)
|
||||||
|
|
||||||
output_transform = output_transform or build_output_transform(
|
output_transform = output_transform or build_output_transform(
|
||||||
config=output_config.transform,
|
config=output_config.transform,
|
||||||
|
|||||||
@ -7,21 +7,37 @@ from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
|||||||
from batdetect2.models import Model
|
from batdetect2.models import Model
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
class InferenceModule(LightningModule):
|
class InferenceModule(LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
model: Model,
|
||||||
|
targets: TargetProtocol | None = None,
|
||||||
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model = model
|
self.model = model
|
||||||
self.detection_threshold = detection_threshold
|
self.detection_threshold = detection_threshold
|
||||||
|
|
||||||
|
if output_transform is None and targets is None:
|
||||||
|
raise ValueError(
|
||||||
|
"targets must be provided when building inference output "
|
||||||
|
"transforms."
|
||||||
|
)
|
||||||
|
|
||||||
|
if output_transform is None and roi_mapper is None:
|
||||||
|
raise ValueError(
|
||||||
|
"roi_mapper must be provided when building inference output "
|
||||||
|
"transforms."
|
||||||
|
)
|
||||||
|
|
||||||
self.output_transform = output_transform or build_output_transform(
|
self.output_transform = output_transform or build_output_transform(
|
||||||
targets=model.targets,
|
targets=targets,
|
||||||
roi_mapper=model.roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
def predict_step(
|
def predict_step(
|
||||||
|
|||||||
@ -26,11 +26,8 @@ The primary entry point for building a full, ready-to-use BatDetect2 model
|
|||||||
is the ``build_model`` factory function exported from this module.
|
is the ``build_model`` factory function exported from this module.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from soundevent.data import PathLike
|
|
||||||
|
|
||||||
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
@ -73,7 +70,6 @@ from batdetect2.postprocess.types import (
|
|||||||
)
|
)
|
||||||
from batdetect2.preprocess.config import PreprocessingConfig
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets.config import TargetConfig
|
|
||||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -131,10 +127,6 @@ class ModelConfig(BaseConfig):
|
|||||||
Parameters for converting raw model outputs into detections (NMS
|
Parameters for converting raw model outputs into detections (NMS
|
||||||
kernel, thresholds, top-k limit). Defaults to
|
kernel, thresholds, top-k limit). Defaults to
|
||||||
``PostprocessConfig()``.
|
``PostprocessConfig()``.
|
||||||
targets : TargetConfig
|
|
||||||
Detection and classification target definitions (class list,
|
|
||||||
detection target, bounding-box mapper). Defaults to
|
|
||||||
``TargetConfig()``.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||||
@ -143,23 +135,6 @@ class ModelConfig(BaseConfig):
|
|||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(
|
|
||||||
cls,
|
|
||||||
path: PathLike,
|
|
||||||
field: str | None = None,
|
|
||||||
extra: Literal["ignore", "allow", "forbid"] | None = None,
|
|
||||||
strict: bool | None = None,
|
|
||||||
targets: TargetConfig | None = None,
|
|
||||||
) -> "ModelConfig":
|
|
||||||
config = super().load(path, field, extra, strict)
|
|
||||||
|
|
||||||
if targets is None:
|
|
||||||
return config
|
|
||||||
|
|
||||||
return config.model_copy(update={"targets": targets})
|
|
||||||
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
@ -183,33 +158,32 @@ class Model(torch.nn.Module):
|
|||||||
postprocessor : PostprocessorProtocol
|
postprocessor : PostprocessorProtocol
|
||||||
Converts the raw ``ModelOutput`` from ``detector`` into a list of
|
Converts the raw ``ModelOutput`` from ``detector`` into a list of
|
||||||
per-clip detection tensors.
|
per-clip detection tensors.
|
||||||
targets : TargetProtocol
|
class_names : list[str]
|
||||||
Describes the set of target classes; used when building heads and
|
Class names corresponding to the model classification outputs.
|
||||||
during training target construction.
|
dimension_names : list[str]
|
||||||
roi_mapper : ROIMapperProtocol
|
Size-dimension names corresponding to the model size outputs.
|
||||||
Maps geometries to target-size channels and back.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
detector: DetectionModel
|
detector: DetectionModel
|
||||||
preprocessor: PreprocessorProtocol
|
preprocessor: PreprocessorProtocol
|
||||||
postprocessor: PostprocessorProtocol
|
postprocessor: PostprocessorProtocol
|
||||||
targets: TargetProtocol
|
class_names: list[str]
|
||||||
roi_mapper: ROIMapperProtocol
|
dimension_names: list[str]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
detector: DetectionModel,
|
detector: DetectionModel,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
postprocessor: PostprocessorProtocol,
|
postprocessor: PostprocessorProtocol,
|
||||||
targets: TargetProtocol,
|
class_names: list[str],
|
||||||
roi_mapper: ROIMapperProtocol,
|
dimension_names: list[str],
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.detector = detector
|
self.detector = detector
|
||||||
self.preprocessor = preprocessor
|
self.preprocessor = preprocessor
|
||||||
self.postprocessor = postprocessor
|
self.postprocessor = postprocessor
|
||||||
self.targets = targets
|
self.class_names = class_names
|
||||||
self.roi_mapper = roi_mapper
|
self.dimension_names = dimension_names
|
||||||
|
|
||||||
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
def forward(self, wav: torch.Tensor) -> list[ClipDetectionsTensor]:
|
||||||
"""Run the full detection pipeline on a waveform tensor.
|
"""Run the full detection pipeline on a waveform tensor.
|
||||||
@ -238,8 +212,8 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
def build_model(
|
def build_model(
|
||||||
config: ModelConfig | None = None,
|
config: ModelConfig | None = None,
|
||||||
targets: TargetProtocol | None = None,
|
class_names: list[str] | None = None,
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
postprocessor: PostprocessorProtocol | None = None,
|
postprocessor: PostprocessorProtocol | None = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
@ -254,11 +228,13 @@ def build_model(
|
|||||||
----------
|
----------
|
||||||
config : ModelConfig, optional
|
config : ModelConfig, optional
|
||||||
Full model configuration (samplerate, architecture, preprocessing,
|
Full model configuration (samplerate, architecture, preprocessing,
|
||||||
postprocessing, targets). Defaults to ``ModelConfig()`` if not
|
postprocessing). Defaults to ``ModelConfig()`` if not provided.
|
||||||
provided.
|
class_names : list[str], optional
|
||||||
targets : TargetProtocol, optional
|
Class names used to size the classifier head. Required when building
|
||||||
Pre-built targets object. If given, overrides
|
a new model.
|
||||||
``config.targets``.
|
dimension_names : list[str], optional
|
||||||
|
Dimension names used to size the bbox head. Required when building a
|
||||||
|
new model.
|
||||||
preprocessor : PreprocessorProtocol, optional
|
preprocessor : PreprocessorProtocol, optional
|
||||||
Pre-built preprocessor. If given, overrides
|
Pre-built preprocessor. If given, overrides
|
||||||
``config.preprocess`` and ``config.samplerate`` for the
|
``config.preprocess`` and ``config.samplerate`` for the
|
||||||
@ -278,19 +254,17 @@ def build_model(
|
|||||||
"""
|
"""
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_roi_mapping, build_targets
|
|
||||||
|
|
||||||
config = config or ModelConfig()
|
config = config or ModelConfig()
|
||||||
targets = targets or build_targets(config=config.targets)
|
|
||||||
|
|
||||||
targets_config = getattr(targets, "config", None)
|
if class_names is None:
|
||||||
roi_config = (
|
raise ValueError("class_names must be provided when building a model.")
|
||||||
targets_config.roi
|
|
||||||
if isinstance(targets_config, TargetConfig)
|
if dimension_names is None:
|
||||||
else config.targets.roi
|
raise ValueError(
|
||||||
|
"dimension_names must be provided when building a model."
|
||||||
)
|
)
|
||||||
|
|
||||||
roi_mapper = roi_mapper or build_roi_mapping(config=roi_config)
|
|
||||||
preprocessor = preprocessor or build_preprocessor(
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
config=config.preprocess,
|
config=config.preprocess,
|
||||||
input_samplerate=config.samplerate,
|
input_samplerate=config.samplerate,
|
||||||
@ -300,16 +274,16 @@ def build_model(
|
|||||||
config=config.postprocess,
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(class_names),
|
||||||
num_sizes=len(roi_mapper.dimension_names),
|
num_sizes=len(dimension_names),
|
||||||
config=config.architecture,
|
config=config.architecture,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
detector=detector,
|
detector=detector,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
targets=targets,
|
class_names=class_names,
|
||||||
roi_mapper=roi_mapper,
|
dimension_names=dimension_names,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -329,6 +303,6 @@ def build_model_with_new_targets(
|
|||||||
detector=detector,
|
detector=detector,
|
||||||
postprocessor=model.postprocessor,
|
postprocessor=model.postprocessor,
|
||||||
preprocessor=model.preprocessor,
|
preprocessor=model.preprocessor,
|
||||||
targets=targets,
|
class_names=targets.class_names,
|
||||||
roi_mapper=roi_mapper,
|
dimension_names=roi_mapper.dimension_names,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from batdetect2.targets.classes import (
|
|||||||
build_sound_event_encoder,
|
build_sound_event_encoder,
|
||||||
get_class_names_from_config,
|
get_class_names_from_config,
|
||||||
)
|
)
|
||||||
from batdetect2.targets.config import TargetConfig
|
from batdetect2.targets.config import TargetConfig, build_default_target_config
|
||||||
from batdetect2.targets.rois import (
|
from batdetect2.targets.rois import (
|
||||||
AnchorBBoxMapperConfig,
|
AnchorBBoxMapperConfig,
|
||||||
ROIMapperConfig,
|
ROIMapperConfig,
|
||||||
@ -36,13 +36,14 @@ from batdetect2.targets.types import (
|
|||||||
SoundEventFilter,
|
SoundEventFilter,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
)
|
)
|
||||||
|
from batdetect2.targets.utils import check_target_compatibility
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AnchorBBoxMapperConfig",
|
"AnchorBBoxMapperConfig",
|
||||||
"Position",
|
"Position",
|
||||||
"ROIMappingConfig",
|
|
||||||
"ROIMapperProtocol",
|
|
||||||
"ROIMapperConfig",
|
"ROIMapperConfig",
|
||||||
|
"ROIMapperProtocol",
|
||||||
|
"ROIMappingConfig",
|
||||||
"ROITargetMapper",
|
"ROITargetMapper",
|
||||||
"Size",
|
"Size",
|
||||||
"SoundEventDecoder",
|
"SoundEventDecoder",
|
||||||
@ -52,12 +53,14 @@ __all__ = [
|
|||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
"TargetProtocol",
|
"TargetProtocol",
|
||||||
"Targets",
|
"Targets",
|
||||||
"build_roi_mapping",
|
"build_default_target_config",
|
||||||
"build_roi_mapper",
|
"build_roi_mapper",
|
||||||
|
"build_roi_mapping",
|
||||||
"build_sound_event_decoder",
|
"build_sound_event_decoder",
|
||||||
"build_sound_event_encoder",
|
"build_sound_event_encoder",
|
||||||
"build_targets",
|
"build_targets",
|
||||||
"call_type",
|
"call_type",
|
||||||
|
"check_target_compatibility",
|
||||||
"data_source",
|
"data_source",
|
||||||
"generic_class",
|
"generic_class",
|
||||||
"get_class_names_from_config",
|
"get_class_names_from_config",
|
||||||
|
|||||||
@ -2,6 +2,7 @@ from collections import Counter
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.targets.classes import (
|
from batdetect2.targets.classes import (
|
||||||
@ -13,6 +14,7 @@ from batdetect2.targets.rois import ROIMappingConfig
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TargetConfig",
|
"TargetConfig",
|
||||||
|
"build_default_target_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@ -42,3 +44,20 @@ class TargetConfig(BaseConfig):
|
|||||||
f"{', '.join(duplicates)}"
|
f"{', '.join(duplicates)}"
|
||||||
)
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
def build_default_target_config(class_names: list[str]) -> TargetConfig:
|
||||||
|
"""Build a default target configuration object."""
|
||||||
|
return TargetConfig(
|
||||||
|
detection_target=DEFAULT_DETECTION_CLASS,
|
||||||
|
classification_targets=[
|
||||||
|
TargetClassConfig(
|
||||||
|
name=class_name,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="class", value=class_name),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
for class_name in class_names
|
||||||
|
],
|
||||||
|
roi=ROIMappingConfig(),
|
||||||
|
)
|
||||||
|
|||||||
29
src/batdetect2/targets/utils.py
Normal file
29
src/batdetect2/targets/utils.py
Normal file
@ -0,0 +1,29 @@
|
|||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from batdetect2.targets.types import TargetProtocol
|
||||||
|
|
||||||
|
|
||||||
|
def check_target_compatibility(
|
||||||
|
targets: "TargetProtocol",
|
||||||
|
class_names: list[str],
|
||||||
|
) -> bool:
|
||||||
|
"""Check if a target definition can decode a model's outputs.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
targets : TargetProtocol
|
||||||
|
Target definition that would be used with the model outputs.
|
||||||
|
class_names : list[str]
|
||||||
|
Class names produced by the model checkpoint.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
bool
|
||||||
|
True when every model class name exists in the provided targets,
|
||||||
|
False otherwise.
|
||||||
|
"""
|
||||||
|
target_class_names = set(targets.class_names)
|
||||||
|
model_class_names = set(class_names)
|
||||||
|
|
||||||
|
return model_class_names.issubset(target_class_names)
|
||||||
@ -10,6 +10,7 @@ from batdetect2.logging import get_image_logger
|
|||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
from batdetect2.train.dataset import ValidationDataset
|
from batdetect2.train.dataset import ValidationDataset
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
from batdetect2.train.types import TrainExample
|
from batdetect2.train.types import TrainExample
|
||||||
@ -19,11 +20,15 @@ class ValidationMetrics(Callback):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
roi_mapper: ROIMapperProtocol,
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.evaluator = evaluator
|
self.evaluator = evaluator
|
||||||
|
self.targets = targets
|
||||||
|
self.roi_mapper = roi_mapper
|
||||||
self.output_transform = output_transform
|
self.output_transform = output_transform
|
||||||
|
|
||||||
self._clip_annotations: List[data.ClipAnnotation] = []
|
self._clip_annotations: List[data.ClipAnnotation] = []
|
||||||
@ -93,8 +98,8 @@ class ValidationMetrics(Callback):
|
|||||||
model = pl_module.model
|
model = pl_module.model
|
||||||
if self.output_transform is None:
|
if self.output_transform is None:
|
||||||
self.output_transform = build_output_transform(
|
self.output_transform = build_output_transform(
|
||||||
targets=model.targets,
|
targets=self.targets,
|
||||||
roi_mapper=model.roi_mapper,
|
roi_mapper=self.roi_mapper,
|
||||||
)
|
)
|
||||||
|
|
||||||
output_transform = self.output_transform
|
output_transform = self.output_transform
|
||||||
|
|||||||
@ -9,9 +9,7 @@ from batdetect2.train.optimizers import build_optimizer
|
|||||||
from batdetect2.train.schedulers import build_scheduler
|
from batdetect2.train.schedulers import build_scheduler
|
||||||
from batdetect2.train.types import LossProtocol, TrainExample
|
from batdetect2.train.types import LossProtocol, TrainExample
|
||||||
|
|
||||||
__all__ = [
|
__all__ = ["TrainingModule"]
|
||||||
"TrainingModule",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingModule(L.LightningModule):
|
class TrainingModule(L.LightningModule):
|
||||||
@ -21,6 +19,8 @@ class TrainingModule(L.LightningModule):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_config: dict | None = None,
|
model_config: dict | None = None,
|
||||||
|
class_names: list[str] | None = None,
|
||||||
|
dimension_names: list[str] | None = None,
|
||||||
train_config: dict | None = None,
|
train_config: dict | None = None,
|
||||||
loss: LossProtocol | None = None,
|
loss: LossProtocol | None = None,
|
||||||
model: Model | None = None,
|
model: Model | None = None,
|
||||||
@ -30,13 +30,31 @@ class TrainingModule(L.LightningModule):
|
|||||||
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
||||||
|
|
||||||
self.model_config = ModelConfig.model_validate(model_config or {})
|
self.model_config = ModelConfig.model_validate(model_config or {})
|
||||||
|
self.class_names = list(class_names or [])
|
||||||
|
self.dimension_names = list(dimension_names or [])
|
||||||
self.train_config = TrainingConfig.model_validate(train_config or {})
|
self.train_config = TrainingConfig.model_validate(train_config or {})
|
||||||
|
|
||||||
if loss is None:
|
if loss is None:
|
||||||
loss = build_loss(config=self.train_config.loss)
|
loss = build_loss(config=self.train_config.loss)
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
model = build_model(config=self.model_config)
|
if not self.class_names:
|
||||||
|
raise ValueError(
|
||||||
|
"class_names must be provided when rebuilding a training "
|
||||||
|
"module without a model."
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.dimension_names:
|
||||||
|
raise ValueError(
|
||||||
|
"dimension_names must be provided when rebuilding a "
|
||||||
|
"training module without a model."
|
||||||
|
)
|
||||||
|
|
||||||
|
model = build_model(
|
||||||
|
config=self.model_config,
|
||||||
|
class_names=self.class_names,
|
||||||
|
dimension_names=self.dimension_names,
|
||||||
|
)
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -110,8 +128,7 @@ def load_model_from_checkpoint(
|
|||||||
-------
|
-------
|
||||||
tuple[Model, ModelConfig]
|
tuple[Model, ModelConfig]
|
||||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||||
describes its architecture, preprocessing, postprocessing, and
|
describes its architecture, preprocessing, and postprocessing.
|
||||||
targets.
|
|
||||||
"""
|
"""
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
return module.model, module.model_config
|
return module.model, module.model_config
|
||||||
@ -119,6 +136,8 @@ def load_model_from_checkpoint(
|
|||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
model_config: ModelConfig | None = None,
|
model_config: ModelConfig | None = None,
|
||||||
|
class_names: list[str] | None = None,
|
||||||
|
dimension_names: list[str] | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
model: Model | None = None,
|
model: Model | None = None,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
@ -130,6 +149,8 @@ def build_training_module(
|
|||||||
|
|
||||||
return TrainingModule(
|
return TrainingModule(
|
||||||
model_config=model_config.model_dump(mode="json"),
|
model_config=model_config.model_dump(mode="json"),
|
||||||
|
class_names=class_names,
|
||||||
|
dimension_names=dimension_names,
|
||||||
train_config=train_config.model_dump(mode="json"),
|
train_config=train_config.model_dump(mode="json"),
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -17,6 +17,7 @@ from batdetect2.models import Model, ModelConfig, build_model
|
|||||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
ROIMapperProtocol,
|
ROIMapperProtocol,
|
||||||
|
TargetConfig,
|
||||||
TargetProtocol,
|
TargetProtocol,
|
||||||
build_roi_mapping,
|
build_roi_mapping,
|
||||||
build_targets,
|
build_targets,
|
||||||
@ -46,6 +47,7 @@ def run_train(
|
|||||||
labeller: Optional["ClipLabeller"] = None,
|
labeller: Optional["ClipLabeller"] = None,
|
||||||
audio_config: Optional[AudioConfig] = None,
|
audio_config: Optional[AudioConfig] = None,
|
||||||
model_config: Optional[ModelConfig] = None,
|
model_config: Optional[ModelConfig] = None,
|
||||||
|
targets_config: TargetConfig | None = None,
|
||||||
train_config: Optional[TrainingConfig] = None,
|
train_config: Optional[TrainingConfig] = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
trainer: Trainer | None = None,
|
trainer: Trainer | None = None,
|
||||||
@ -62,22 +64,33 @@ def run_train(
|
|||||||
seed_everything(seed)
|
seed_everything(seed)
|
||||||
|
|
||||||
model_config = model_config or ModelConfig()
|
model_config = model_config or ModelConfig()
|
||||||
|
targets_config = targets_config or TargetConfig()
|
||||||
audio_config = audio_config or AudioConfig()
|
audio_config = audio_config or AudioConfig()
|
||||||
train_config = train_config or TrainingConfig()
|
train_config = train_config or TrainingConfig()
|
||||||
|
|
||||||
if model is not None:
|
if model is not None:
|
||||||
_validate_model_compatibility(model=model, model_config=model_config)
|
if targets is None:
|
||||||
|
raise ValueError(
|
||||||
|
"targets must be provided when training with an existing "
|
||||||
|
"model."
|
||||||
|
)
|
||||||
|
|
||||||
|
if roi_mapper is None:
|
||||||
|
raise ValueError(
|
||||||
|
"roi_mapper must be provided when training with an existing "
|
||||||
|
"model."
|
||||||
|
)
|
||||||
|
|
||||||
|
targets = targets or build_targets(config=targets_config)
|
||||||
|
|
||||||
|
roi_mapper = roi_mapper or build_roi_mapping(config=targets_config.roi)
|
||||||
|
|
||||||
if model is not None:
|
if model is not None:
|
||||||
targets = targets or model.targets
|
_validate_model_compatibility(
|
||||||
|
model=model,
|
||||||
if roi_mapper is None and targets is model.targets:
|
model_config=model_config,
|
||||||
roi_mapper = model.roi_mapper
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
targets = targets or build_targets(config=model_config.targets)
|
|
||||||
|
|
||||||
roi_mapper = roi_mapper or build_roi_mapping(
|
|
||||||
config=model_config.targets.roi
|
|
||||||
)
|
)
|
||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
audio_loader = audio_loader or build_audio_loader(config=audio_config)
|
||||||
@ -119,18 +132,24 @@ def run_train(
|
|||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
evaluator = build_evaluator(
|
||||||
|
train_config.validation,
|
||||||
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
|
)
|
||||||
|
|
||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
train_config,
|
train_config,
|
||||||
logger_config=logger_config,
|
logger_config=logger_config,
|
||||||
evaluator=build_evaluator(
|
evaluator=evaluator,
|
||||||
train_config.validation,
|
|
||||||
targets=targets,
|
targets=targets,
|
||||||
roi_mapper=roi_mapper,
|
roi_mapper=roi_mapper,
|
||||||
),
|
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
num_epochs=num_epochs,
|
num_epochs=num_epochs,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
@ -152,8 +171,14 @@ def run_train(
|
|||||||
def _validate_model_compatibility(
|
def _validate_model_compatibility(
|
||||||
model: Model,
|
model: Model,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
|
class_names: list[str],
|
||||||
|
dimension_names: list[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
reference_model = build_model(config=model_config)
|
reference_model = build_model(
|
||||||
|
config=model_config,
|
||||||
|
class_names=class_names,
|
||||||
|
dimension_names=dimension_names,
|
||||||
|
)
|
||||||
|
|
||||||
expected_shapes = {
|
expected_shapes = {
|
||||||
key: tuple(value.shape)
|
key: tuple(value.shape)
|
||||||
@ -196,6 +221,8 @@ def build_trainer(
|
|||||||
config: TrainingConfig,
|
config: TrainingConfig,
|
||||||
logger_config: LoggerConfig | None,
|
logger_config: LoggerConfig | None,
|
||||||
evaluator: "EvaluatorProtocol",
|
evaluator: "EvaluatorProtocol",
|
||||||
|
targets: "TargetProtocol",
|
||||||
|
roi_mapper: "ROIMapperProtocol",
|
||||||
checkpoint_dir: Path | None = None,
|
checkpoint_dir: Path | None = None,
|
||||||
log_dir: Path | None = None,
|
log_dir: Path | None = None,
|
||||||
experiment_name: str | None = None,
|
experiment_name: str | None = None,
|
||||||
@ -234,6 +261,6 @@ def build_trainer(
|
|||||||
experiment_name=experiment_name,
|
experiment_name=experiment_name,
|
||||||
run_name=run_name,
|
run_name=run_name,
|
||||||
),
|
),
|
||||||
ValidationMetrics(evaluator),
|
ValidationMetrics(evaluator, targets, roi_mapper),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|||||||
@ -13,13 +13,14 @@ from soundevent import data, terms
|
|||||||
from batdetect2.audio import build_audio_loader
|
from batdetect2.audio import build_audio_loader
|
||||||
from batdetect2.audio.clips import build_clipper
|
from batdetect2.audio.clips import build_clipper
|
||||||
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
from batdetect2.audio.types import AudioLoader, ClipperProtocol
|
||||||
from batdetect2.config import BatDetect2Config
|
|
||||||
from batdetect2.data import DatasetConfig, load_dataset
|
from batdetect2.data import DatasetConfig, load_dataset
|
||||||
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
|
ROIMapperProtocol,
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
|
build_roi_mapping,
|
||||||
build_targets,
|
build_targets,
|
||||||
call_type,
|
call_type,
|
||||||
)
|
)
|
||||||
@ -404,6 +405,13 @@ def sample_targets(
|
|||||||
return build_targets(sample_target_config)
|
return build_targets(sample_target_config)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_roi_mapper(
|
||||||
|
sample_target_config: TargetConfig,
|
||||||
|
) -> ROIMapperProtocol:
|
||||||
|
return build_roi_mapping(sample_target_config.roi)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_labeller(
|
def sample_labeller(
|
||||||
sample_targets: TargetProtocol,
|
sample_targets: TargetProtocol,
|
||||||
@ -458,8 +466,15 @@ def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tiny_checkpoint_path(tmp_path: Path) -> Path:
|
def tiny_checkpoint_path(
|
||||||
module = build_training_module(model_config=BatDetect2Config().model)
|
sample_targets: TargetProtocol,
|
||||||
|
sample_roi_mapper: ROIMapperProtocol,
|
||||||
|
tmp_path: Path,
|
||||||
|
) -> Path:
|
||||||
|
module = build_training_module(
|
||||||
|
class_names=sample_targets.class_names,
|
||||||
|
dimension_names=sample_roi_mapper.dimension_names,
|
||||||
|
)
|
||||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||||
checkpoint_path = tmp_path / "model.ckpt"
|
checkpoint_path = tmp_path / "model.ckpt"
|
||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
|
|||||||
@ -8,20 +8,43 @@ import torch
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.models.detectors import Detector
|
from batdetect2.models.detectors import Detector
|
||||||
from batdetect2.models.heads import ClassifierHead
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train import load_model_from_checkpoint
|
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def api_v2() -> BatDetect2API:
|
def train_config() -> TrainingConfig:
|
||||||
|
"""Train config with a small batch size for testing."""
|
||||||
|
return TrainingConfig.model_validate({"train_loader": {"batch_size": 2}})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def inference_config() -> InferenceConfig:
|
||||||
|
"""Inference config with a small batch size for testing."""
|
||||||
|
return InferenceConfig.model_validate({"loader": {"batch_size": 2}})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def example_targets_config(example_data_dir: Path) -> TargetConfig:
|
||||||
|
return TargetConfig.load(example_data_dir / "targets.yaml")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def api_v2(
|
||||||
|
train_config: TrainingConfig,
|
||||||
|
inference_config: InferenceConfig,
|
||||||
|
) -> BatDetect2API:
|
||||||
"""User story: users can create a ready-to-use API from config."""
|
"""User story: users can create a ready-to-use API from config."""
|
||||||
|
|
||||||
config = BatDetect2Config()
|
api = BatDetect2API.from_config(
|
||||||
config.inference.loader.batch_size = 2
|
train_config=train_config,
|
||||||
return BatDetect2API.from_config(config)
|
inference_config=inference_config,
|
||||||
|
)
|
||||||
|
assert api.inference_config.loader.batch_size == 2
|
||||||
|
return api
|
||||||
|
|
||||||
|
|
||||||
def test_process_file_returns_recording_level_predictions(
|
def test_process_file_returns_recording_level_predictions(
|
||||||
@ -30,8 +53,10 @@ def test_process_file_returns_recording_level_predictions(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: process a file and get detections in recording time."""
|
"""User story: process a file and get detections in recording time."""
|
||||||
|
|
||||||
|
# When
|
||||||
prediction = api_v2.process_file(example_audio_files[0])
|
prediction = api_v2.process_file(example_audio_files[0])
|
||||||
|
|
||||||
|
# Then
|
||||||
assert prediction.clip.recording.path == example_audio_files[0]
|
assert prediction.clip.recording.path == example_audio_files[0]
|
||||||
assert prediction.clip.start_time == 0
|
assert prediction.clip.start_time == 0
|
||||||
assert prediction.clip.end_time == prediction.clip.recording.duration
|
assert prediction.clip.end_time == prediction.clip.recording.duration
|
||||||
@ -53,9 +78,11 @@ def test_process_files_is_batch_size_invariant(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: changing batch size should not change predictions."""
|
"""User story: changing batch size should not change predictions."""
|
||||||
|
|
||||||
|
# When
|
||||||
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
|
preds_batch_1 = api_v2.process_files(example_audio_files, batch_size=1)
|
||||||
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
|
preds_batch_3 = api_v2.process_files(example_audio_files, batch_size=3)
|
||||||
|
|
||||||
|
# Then
|
||||||
assert len(preds_batch_1) == len(preds_batch_3)
|
assert len(preds_batch_1) == len(preds_batch_3)
|
||||||
|
|
||||||
by_key_1 = {
|
by_key_1 = {
|
||||||
@ -91,12 +118,14 @@ def test_process_audio_matches_process_spectrogram(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: users can call either audio or spectrogram entrypoint."""
|
"""User story: users can call either audio or spectrogram entrypoint."""
|
||||||
|
|
||||||
|
# When
|
||||||
audio = api_v2.load_audio(example_audio_files[0])
|
audio = api_v2.load_audio(example_audio_files[0])
|
||||||
from_audio = api_v2.process_audio(audio)
|
from_audio = api_v2.process_audio(audio)
|
||||||
|
|
||||||
spec = api_v2.generate_spectrogram(audio)
|
spec = api_v2.generate_spectrogram(audio)
|
||||||
from_spec = api_v2.process_spectrogram(spec)
|
from_spec = api_v2.process_spectrogram(spec)
|
||||||
|
|
||||||
|
# Then
|
||||||
assert len(from_audio) == len(from_spec)
|
assert len(from_audio) == len(from_spec)
|
||||||
|
|
||||||
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
|
for det_audio, det_spec in zip(from_audio, from_spec, strict=True):
|
||||||
@ -116,8 +145,10 @@ def test_process_spectrogram_rejects_batched_input(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: invalid batched input gives a clear error."""
|
"""User story: invalid batched input gives a clear error."""
|
||||||
|
|
||||||
|
# Given
|
||||||
spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)
|
spec = torch.zeros((2, 1, 128, 64), dtype=torch.float32)
|
||||||
|
|
||||||
|
# When/Then
|
||||||
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
|
with pytest.raises(ValueError, match="Batched spectrograms not supported"):
|
||||||
api_v2.process_spectrogram(spec)
|
api_v2.process_spectrogram(spec)
|
||||||
|
|
||||||
@ -184,26 +215,34 @@ def test_user_can_read_extracted_features_per_detection(
|
|||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
def test_user_can_load_checkpoint_and_finetune(
|
def test_user_can_load_checkpoint_and_finetune(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
|
example_targets_config: TargetConfig,
|
||||||
example_annotations,
|
example_annotations,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""User story: load a checkpoint and continue training from it."""
|
"""User story: load a checkpoint and continue training from it."""
|
||||||
|
|
||||||
module = build_training_module(model_config=BatDetect2Config().model)
|
api = BatDetect2API.from_config(
|
||||||
|
targets_config=example_targets_config,
|
||||||
|
)
|
||||||
|
module = build_training_module(
|
||||||
|
model_config=api.model_config,
|
||||||
|
class_names=api.targets.class_names,
|
||||||
|
dimension_names=api.roi_mapper.dimension_names,
|
||||||
|
)
|
||||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||||
checkpoint_path = tmp_path / "base.ckpt"
|
checkpoint_path = tmp_path / "base.ckpt"
|
||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(checkpoint_path)
|
trainer.save_checkpoint(checkpoint_path)
|
||||||
|
|
||||||
config = BatDetect2Config()
|
train_config = api.train_config.model_copy(deep=True)
|
||||||
config.train.trainer.limit_train_batches = 1
|
train_config.trainer.limit_train_batches = 1
|
||||||
config.train.trainer.limit_val_batches = 1
|
train_config.trainer.limit_val_batches = 1
|
||||||
config.train.trainer.log_every_n_steps = 1
|
train_config.trainer.log_every_n_steps = 1
|
||||||
config.train.train_loader.batch_size = 1
|
train_config.train_loader.batch_size = 1
|
||||||
config.train.train_loader.augmentations.enabled = False
|
train_config.train_loader.augmentations.enabled = False
|
||||||
|
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
train_config=config.train,
|
train_config=train_config,
|
||||||
)
|
)
|
||||||
finetune_dir = tmp_path / "finetuned"
|
finetune_dir = tmp_path / "finetuned"
|
||||||
|
|
||||||
@ -222,62 +261,36 @@ def test_user_can_load_checkpoint_and_finetune(
|
|||||||
assert checkpoints
|
assert checkpoints
|
||||||
|
|
||||||
|
|
||||||
def test_user_can_load_checkpoint_with_new_targets(
|
|
||||||
tmp_path: Path,
|
|
||||||
sample_targets,
|
|
||||||
) -> None:
|
|
||||||
"""User story: start from checkpoint with a new target definition."""
|
|
||||||
|
|
||||||
module = build_training_module(model_config=BatDetect2Config().model)
|
|
||||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
|
||||||
checkpoint_path = tmp_path / "base_transfer.ckpt"
|
|
||||||
trainer.strategy.connect(module)
|
|
||||||
trainer.save_checkpoint(checkpoint_path)
|
|
||||||
|
|
||||||
source_model, _ = load_model_from_checkpoint(checkpoint_path)
|
|
||||||
api = BatDetect2API.from_checkpoint(
|
|
||||||
checkpoint_path,
|
|
||||||
targets_config=sample_targets.config,
|
|
||||||
)
|
|
||||||
source_detector = cast(Detector, source_model.detector)
|
|
||||||
detector = cast(Detector, api.model.detector)
|
|
||||||
classifier_head = cast(ClassifierHead, detector.classifier_head)
|
|
||||||
|
|
||||||
assert api.targets.config == sample_targets.config # type: ignore
|
|
||||||
assert detector.num_classes == len(sample_targets.class_names)
|
|
||||||
assert (
|
|
||||||
classifier_head.classifier.out_channels
|
|
||||||
== len(sample_targets.class_names) + 1
|
|
||||||
)
|
|
||||||
|
|
||||||
source_backbone = source_detector.backbone.state_dict()
|
|
||||||
target_backbone = detector.backbone.state_dict()
|
|
||||||
assert source_backbone
|
|
||||||
for key, value in source_backbone.items():
|
|
||||||
assert key in target_backbone
|
|
||||||
torch.testing.assert_close(target_backbone[key], value)
|
|
||||||
|
|
||||||
|
|
||||||
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||||
|
example_targets_config: TargetConfig,
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""User story: same targets config does not rebuild prediction heads."""
|
"""User story: same targets config does not rebuild prediction heads."""
|
||||||
|
|
||||||
module = build_training_module(model_config=BatDetect2Config().model)
|
# Given
|
||||||
|
source_api = BatDetect2API.from_config(
|
||||||
|
targets_config=example_targets_config
|
||||||
|
)
|
||||||
|
module = build_training_module(
|
||||||
|
model_config=source_api.model_config,
|
||||||
|
class_names=source_api.targets.class_names,
|
||||||
|
dimension_names=source_api.roi_mapper.dimension_names,
|
||||||
|
)
|
||||||
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
trainer = L.Trainer(enable_checkpointing=False, logger=False)
|
||||||
checkpoint_path = tmp_path / "same_targets.ckpt"
|
checkpoint_path = tmp_path / "same_targets.ckpt"
|
||||||
trainer.strategy.connect(module)
|
trainer.strategy.connect(module)
|
||||||
trainer.save_checkpoint(checkpoint_path)
|
trainer.save_checkpoint(checkpoint_path)
|
||||||
|
|
||||||
source_model, source_model_config = load_model_from_checkpoint(
|
source_model, _ = load_model_from_checkpoint(checkpoint_path)
|
||||||
checkpoint_path
|
|
||||||
)
|
|
||||||
source_detector = cast(Detector, source_model.detector)
|
source_detector = cast(Detector, source_model.detector)
|
||||||
|
|
||||||
|
# When
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
targets_config=source_model_config.targets,
|
targets_config=example_targets_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Then
|
||||||
detector = cast(Detector, api.model.detector)
|
detector = cast(Detector, api.model.detector)
|
||||||
|
|
||||||
for key, value in source_detector.classifier_head.state_dict().items():
|
for key, value in source_detector.classifier_head.state_dict().items():
|
||||||
@ -302,7 +315,7 @@ def test_user_can_finetune_only_heads(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""User story: fine-tune only prediction heads."""
|
"""User story: fine-tune only prediction heads."""
|
||||||
|
|
||||||
api = BatDetect2API.from_config(BatDetect2Config())
|
api = BatDetect2API.from_config()
|
||||||
finetune_dir = tmp_path / "heads_only"
|
finetune_dir = tmp_path / "heads_only"
|
||||||
|
|
||||||
api.finetune(
|
api.finetune(
|
||||||
@ -348,8 +361,6 @@ def test_user_can_evaluate_small_dataset_and_get_metrics(
|
|||||||
|
|
||||||
assert isinstance(metrics, list)
|
assert isinstance(metrics, list)
|
||||||
assert len(metrics) == 1
|
assert len(metrics) == 1
|
||||||
assert isinstance(metrics[0], dict)
|
|
||||||
assert len(metrics[0]) > 0
|
|
||||||
assert isinstance(predictions, list)
|
assert isinstance(predictions, list)
|
||||||
assert len(predictions) == 1
|
assert len(predictions) == 1
|
||||||
|
|
||||||
@ -450,8 +461,17 @@ def test_detection_threshold_override_changes_spectrogram_results(
|
|||||||
spec = api_v2.generate_spectrogram(audio)
|
spec = api_v2.generate_spectrogram(audio)
|
||||||
default_detections = api_v2.process_spectrogram(spec)
|
default_detections = api_v2.process_spectrogram(spec)
|
||||||
strict_detections = api_v2.process_spectrogram(
|
strict_detections = api_v2.process_spectrogram(
|
||||||
spec,
|
spec, detection_threshold=1.0
|
||||||
detection_threshold=1.0,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
assert len(strict_detections) <= len(default_detections)
|
assert len(strict_detections) <= len(default_detections)
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_can_create_api_with_custom_targets_and_model_metadata_matches(
|
||||||
|
sample_targets,
|
||||||
|
) -> None:
|
||||||
|
"""User story: custom targets define model output names for a new API."""
|
||||||
|
|
||||||
|
api = BatDetect2API.from_config(targets_config=sample_targets.config)
|
||||||
|
|
||||||
|
assert api.model.class_names == sample_targets.class_names
|
||||||
|
|||||||
@ -5,7 +5,6 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.config import BatDetect2Config
|
|
||||||
from batdetect2.outputs import build_output_formatter
|
from batdetect2.outputs import build_output_formatter
|
||||||
from batdetect2.outputs.formats import (
|
from batdetect2.outputs.formats import (
|
||||||
BatDetect2OutputConfig,
|
BatDetect2OutputConfig,
|
||||||
@ -18,7 +17,7 @@ from batdetect2.postprocess.types import ClipDetections
|
|||||||
def api_v2() -> BatDetect2API:
|
def api_v2() -> BatDetect2API:
|
||||||
"""User story: API object manages prediction IO formats."""
|
"""User story: API object manages prediction IO formats."""
|
||||||
|
|
||||||
return BatDetect2API.from_config(BatDetect2Config())
|
return BatDetect2API.from_config()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|||||||
40
tests/test_targets/test_utils.py
Normal file
40
tests/test_targets/test_utils.py
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.targets import (
|
||||||
|
TargetClassConfig,
|
||||||
|
TargetConfig,
|
||||||
|
build_targets,
|
||||||
|
check_target_compatibility,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _target_class(name: str) -> TargetClassConfig:
|
||||||
|
return TargetClassConfig(
|
||||||
|
name=name,
|
||||||
|
tags=[data.Tag(key="class", value=name)],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_target_compatibility_accepts_superset_targets() -> None:
|
||||||
|
config = TargetConfig(
|
||||||
|
classification_targets=[
|
||||||
|
_target_class("pip35"),
|
||||||
|
_target_class("myo"),
|
||||||
|
_target_class("extra"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
targets = build_targets(config)
|
||||||
|
|
||||||
|
assert check_target_compatibility(targets, ["pip35", "myo"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_check_target_compatibility_rejects_missing_model_classes() -> None:
|
||||||
|
config = TargetConfig(
|
||||||
|
classification_targets=[
|
||||||
|
_target_class("pip35"),
|
||||||
|
_target_class("myo"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
targets = build_targets(config)
|
||||||
|
|
||||||
|
assert not check_target_compatibility(targets, ["pip35", "nyc"])
|
||||||
@ -10,9 +10,8 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
|||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio.types import AudioLoader
|
from batdetect2.audio.types import AudioLoader
|
||||||
from batdetect2.config import BatDetect2Config
|
|
||||||
from batdetect2.models import ModelConfig, build_model
|
from batdetect2.models import ModelConfig, build_model
|
||||||
from batdetect2.targets.classes import TargetClassConfig
|
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||||
from batdetect2.train import (
|
from batdetect2.train import (
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
TrainingModule,
|
TrainingModule,
|
||||||
@ -24,11 +23,21 @@ from batdetect2.train.schedulers import CosineAnnealingSchedulerConfig
|
|||||||
from batdetect2.train.train import build_training_module
|
from batdetect2.train.train import build_training_module
|
||||||
|
|
||||||
|
|
||||||
def build_default_module(config: BatDetect2Config | None = None):
|
def build_default_module(
|
||||||
config = config or BatDetect2Config()
|
target_config: TargetConfig | None = None,
|
||||||
|
model_config: ModelConfig | None = None,
|
||||||
|
train_config: TrainingConfig | None = None,
|
||||||
|
):
|
||||||
|
target_config = target_config or TargetConfig()
|
||||||
|
model_config = model_config or ModelConfig()
|
||||||
|
train_config = train_config or TrainingConfig()
|
||||||
|
targets = build_targets(target_config)
|
||||||
|
roi_mapper = build_roi_mapping(target_config.roi)
|
||||||
return build_training_module(
|
return build_training_module(
|
||||||
model_config=config.model,
|
model_config=model_config,
|
||||||
train_config=config.train,
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
|
train_config=train_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -72,8 +81,13 @@ def test_load_model_from_checkpoint_returns_model_and_config(
|
|||||||
input_model_config.model_dump(mode="json")
|
input_model_config.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
train_config = TrainingConfig()
|
train_config = TrainingConfig()
|
||||||
|
targets_config = TargetConfig()
|
||||||
|
targets = build_targets(targets_config)
|
||||||
|
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=input_model_config,
|
model_config=input_model_config,
|
||||||
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
)
|
)
|
||||||
trainer = L.Trainer()
|
trainer = L.Trainer()
|
||||||
@ -87,6 +101,8 @@ def test_load_model_from_checkpoint_returns_model_and_config(
|
|||||||
assert loaded_model_config.model_dump(
|
assert loaded_model_config.model_dump(
|
||||||
mode="json"
|
mode="json"
|
||||||
) == expected_model_config.model_dump(mode="json")
|
) == expected_model_config.model_dump(mode="json")
|
||||||
|
assert model.class_names == targets.class_names
|
||||||
|
assert model.dimension_names == roi_mapper.dimension_names
|
||||||
|
|
||||||
recovered = TrainingModule.load_from_checkpoint(path)
|
recovered = TrainingModule.load_from_checkpoint(path)
|
||||||
assert recovered.train_config.model_dump(
|
assert recovered.train_config.model_dump(
|
||||||
@ -100,6 +116,9 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
|||||||
model_config.model_dump(mode="json")
|
model_config.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
train_config = TrainingConfig()
|
train_config = TrainingConfig()
|
||||||
|
targets_config = TargetConfig()
|
||||||
|
targets = build_targets(targets_config)
|
||||||
|
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||||
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
||||||
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
|
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=123)
|
||||||
train_config.trainer.max_epochs = 3
|
train_config.trainer.max_epochs = 3
|
||||||
@ -107,6 +126,8 @@ def test_checkpoint_stores_train_config_hyperparameters(tmp_path: Path):
|
|||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
)
|
)
|
||||||
trainer = L.Trainer()
|
trainer = L.Trainer()
|
||||||
@ -131,11 +152,16 @@ def test_configure_optimizers_uses_train_config_values(tmp_path: Path):
|
|||||||
model_config.model_dump(mode="json")
|
model_config.model_dump(mode="json")
|
||||||
)
|
)
|
||||||
train_config = TrainingConfig()
|
train_config = TrainingConfig()
|
||||||
|
targets_config = TargetConfig()
|
||||||
|
targets = build_targets(targets_config)
|
||||||
|
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||||
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
train_config.optimizer = AdamOptimizerConfig(learning_rate=5e-4)
|
||||||
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
|
train_config.scheduler = CosineAnnealingSchedulerConfig(t_max=321)
|
||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
train_config=train_config,
|
train_config=train_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -189,19 +215,26 @@ def test_train_smoke_produces_loadable_checkpoint(
|
|||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
sample_audio_loader: AudioLoader,
|
sample_audio_loader: AudioLoader,
|
||||||
):
|
):
|
||||||
config = BatDetect2Config()
|
# Given
|
||||||
config.train.trainer.limit_train_batches = 1
|
train_config = TrainingConfig.model_validate(
|
||||||
config.train.trainer.limit_val_batches = 1
|
{
|
||||||
config.train.trainer.log_every_n_steps = 1
|
"trainer": {
|
||||||
config.train.train_loader.batch_size = 1
|
"limit_train_batches": 1,
|
||||||
config.train.train_loader.augmentations.enabled = False
|
"limit_val_batches": 1,
|
||||||
|
"log_every_n_steps": 1,
|
||||||
|
},
|
||||||
|
"train_loader": {
|
||||||
|
"batch_size": 1,
|
||||||
|
"augmentations": {"enabled": False},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# When
|
||||||
run_train(
|
run_train(
|
||||||
train_annotations=example_annotations[:1],
|
train_annotations=example_annotations[:1],
|
||||||
val_annotations=example_annotations[:1],
|
val_annotations=example_annotations[:1],
|
||||||
train_config=config.train,
|
train_config=train_config,
|
||||||
model_config=config.model,
|
|
||||||
audio_config=config.audio,
|
|
||||||
num_epochs=1,
|
num_epochs=1,
|
||||||
train_workers=0,
|
train_workers=0,
|
||||||
val_workers=0,
|
val_workers=0,
|
||||||
@ -209,18 +242,11 @@ def test_train_smoke_produces_loadable_checkpoint(
|
|||||||
seed=0,
|
seed=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Then
|
||||||
checkpoints = list(tmp_path.rglob("*.ckpt"))
|
checkpoints = list(tmp_path.rglob("*.ckpt"))
|
||||||
assert checkpoints
|
assert checkpoints
|
||||||
|
|
||||||
model, model_config = load_model_from_checkpoint(checkpoints[0])
|
model, model_config = load_model_from_checkpoint(checkpoints[0])
|
||||||
assert model_config.samplerate == config.model.samplerate
|
|
||||||
assert model_config.architecture.name == config.model.architecture.name
|
|
||||||
assert model_config.preprocess.model_dump(
|
|
||||||
mode="json"
|
|
||||||
) == config.model.preprocess.model_dump(mode="json")
|
|
||||||
assert model_config.postprocess.model_dump(
|
|
||||||
mode="json"
|
|
||||||
) == config.model.postprocess.model_dump(mode="json")
|
|
||||||
|
|
||||||
wav = torch.tensor(
|
wav = torch.tensor(
|
||||||
sample_audio_loader.load_clip(example_annotations[0].clip)
|
sample_audio_loader.load_clip(example_annotations[0].clip)
|
||||||
@ -230,10 +256,18 @@ def test_train_smoke_produces_loadable_checkpoint(
|
|||||||
|
|
||||||
|
|
||||||
def test_build_training_module_uses_provided_model() -> None:
|
def test_build_training_module_uses_provided_model() -> None:
|
||||||
model = build_model(ModelConfig())
|
targets = build_targets(TargetConfig())
|
||||||
|
roi_mapper = build_roi_mapping(TargetConfig().roi)
|
||||||
|
model = build_model(
|
||||||
|
ModelConfig(),
|
||||||
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
|
)
|
||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
model_config=ModelConfig(),
|
model_config=ModelConfig(),
|
||||||
|
class_names=targets.class_names,
|
||||||
|
dimension_names=roi_mapper.dimension_names,
|
||||||
train_config=TrainingConfig(),
|
train_config=TrainingConfig(),
|
||||||
model=model,
|
model=model,
|
||||||
)
|
)
|
||||||
@ -244,15 +278,18 @@ def test_build_training_module_uses_provided_model() -> None:
|
|||||||
def test_run_train_rejects_incompatible_model_config(
|
def test_run_train_rejects_incompatible_model_config(
|
||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
) -> None:
|
) -> None:
|
||||||
model = build_model(ModelConfig())
|
# Given
|
||||||
|
targets_config = TargetConfig()
|
||||||
|
targets = build_targets(targets_config)
|
||||||
|
roi_mapper = build_roi_mapping(targets_config.roi)
|
||||||
incompatible_config = ModelConfig()
|
incompatible_config = ModelConfig()
|
||||||
incompatible_config.targets.classification_targets.append(
|
incompatible_model = build_model(
|
||||||
TargetClassConfig(
|
incompatible_config,
|
||||||
name="dummy_class",
|
class_names=targets.class_names,
|
||||||
tags=[data.Tag(key="class", value="Dummy class")],
|
dimension_names=[*roi_mapper.dimension_names, "extra_dim"],
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# When/Then
|
||||||
with pytest.raises(
|
with pytest.raises(
|
||||||
ValueError,
|
ValueError,
|
||||||
match="Provided model is incompatible with model_config",
|
match="Provided model is incompatible with model_config",
|
||||||
@ -260,7 +297,10 @@ def test_run_train_rejects_incompatible_model_config(
|
|||||||
run_train(
|
run_train(
|
||||||
train_annotations=example_annotations[:1],
|
train_annotations=example_annotations[:1],
|
||||||
val_annotations=None,
|
val_annotations=None,
|
||||||
model=model,
|
model=incompatible_model,
|
||||||
|
targets=targets,
|
||||||
|
roi_mapper=roi_mapper,
|
||||||
model_config=incompatible_config,
|
model_config=incompatible_config,
|
||||||
|
targets_config=targets_config,
|
||||||
train_config=TrainingConfig(),
|
train_config=TrainingConfig(),
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user