refactor: replace abstract model types with protocols

This commit is contained in:
mbsantiago 2026-05-06 12:50:32 +01:00
parent a27d1bbfd3
commit 2008c8000f
9 changed files with 94 additions and 54 deletions

View File

@ -11,7 +11,7 @@ from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
from batdetect2.models import Model from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import OutputsConfig, build_output_transform from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
@ -22,7 +22,7 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def run_evaluate( def run_evaluate(
model: Model, model: ModelProtocol,
test_annotations: Sequence[data.ClipAnnotation], test_annotations: Sequence[data.ClipAnnotation],
targets: TargetProtocol, targets: TargetProtocol,
roi_mapper: ROIMapperProtocol, roi_mapper: ROIMapperProtocol,

View File

@ -7,14 +7,14 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger from batdetect2.logging import get_image_logger
from batdetect2.models import Model from batdetect2.models.types import ModelProtocol
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
class EvaluationModule(LightningModule): class EvaluationModule(LightningModule):
def __init__( def __init__(
self, self,
model: Model, model: ModelProtocol,
evaluator: EvaluatorProtocol, evaluator: EvaluatorProtocol,
): ):
super().__init__() super().__init__()

View File

@ -10,7 +10,7 @@ from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig from batdetect2.inference.config import InferenceConfig
from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import ( from batdetect2.outputs import (
OutputsConfig, OutputsConfig,
OutputTransformProtocol, OutputTransformProtocol,
@ -22,7 +22,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
def run_batch_inference( def run_batch_inference(
model: Model, model: ModelProtocol,
clips: Sequence[data.Clip], clips: Sequence[data.Clip],
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None, roi_mapper: ROIMapperProtocol | None = None,
@ -86,7 +86,7 @@ def run_batch_inference(
def process_file_list( def process_file_list(
model: Model, model: ModelProtocol,
paths: Sequence[data.PathLike], paths: Sequence[data.PathLike],
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None, roi_mapper: ROIMapperProtocol | None = None,

View File

@ -4,7 +4,7 @@ from lightning import LightningModule
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
@ -13,7 +13,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
class InferenceModule(LightningModule): class InferenceModule(LightningModule):
def __init__( def __init__(
self, self,
model: Model, model: ModelProtocol,
targets: TargetProtocol | None = None, targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None, roi_mapper: ROIMapperProtocol | None = None,
output_transform: OutputTransformProtocol | None = None, output_transform: OutputTransformProtocol | None = None,

View File

@ -27,6 +27,7 @@ from typing import Annotated, Literal
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from loguru import logger
from pydantic import Field, TypeAdapter from pydantic import Field, TypeAdapter
from soundevent import data from soundevent import data
@ -52,7 +53,7 @@ from batdetect2.models.encoder import (
build_encoder, build_encoder,
) )
from batdetect2.models.types import ( from batdetect2.models.types import (
BackboneModel, BackboneProtocol,
BottleneckProtocol, BottleneckProtocol,
DecoderProtocol, DecoderProtocol,
EncoderProtocol, EncoderProtocol,
@ -104,7 +105,7 @@ class UNetBackboneConfig(BaseConfig):
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
backbone_registry: Registry[BackboneModel, []] = Registry("backbone") backbone_registry: Registry[BackboneProtocol, []] = Registry("backbone")
@add_import_config(backbone_registry) @add_import_config(backbone_registry)
@ -118,7 +119,7 @@ class BackboneImportConfig(ImportConfig):
name: Literal["import"] = "import" name: Literal["import"] = "import"
class UNetBackbone(BackboneModel): class UNetBackbone(torch.nn.Module):
"""U-Net-style encoder-decoder backbone network. """U-Net-style encoder-decoder backbone network.
Combines an encoder, a bottleneck, and a decoder into a single module Combines an encoder, a bottleneck, and a decoder into a single module
@ -225,7 +226,7 @@ class UNetBackbone(BackboneModel):
@backbone_registry.register(UNetBackboneConfig) @backbone_registry.register(UNetBackboneConfig)
@staticmethod @staticmethod
def from_config(config: UNetBackboneConfig) -> BackboneModel: def from_config(config: UNetBackboneConfig) -> BackboneProtocol:
encoder = build_encoder( encoder = build_encoder(
in_channels=config.in_channels, in_channels=config.in_channels,
input_height=config.input_height, input_height=config.input_height,
@ -266,7 +267,7 @@ BackboneConfig = Annotated[
] ]
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: def build_backbone(config: BackboneConfig | None = None) -> BackboneProtocol:
"""Build a backbone network from configuration. """Build a backbone network from configuration.
Looks up the backbone class corresponding to ``config.name`` in the Looks up the backbone class corresponding to ``config.name`` in the
@ -282,10 +283,14 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
Returns Returns
------- -------
BackboneModel BackboneProtocol
An initialised backbone module. An initialised backbone module.
""" """
config = config or UNetBackboneConfig() config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model backbone with config: \n{}",
lambda: config.to_yaml_string(),
)
return backbone_registry.build(config) return backbone_registry.build(config)

View File

@ -1,21 +1,42 @@
from abc import ABC, abstractmethod from typing import Any, NamedTuple, Protocol
from typing import NamedTuple, Protocol
import torch import torch
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [ __all__ = [
"BackboneModel", "BackboneProtocol",
"BlockProtocol", "BlockProtocol",
"BottleneckProtocol", "BottleneckProtocol",
"ClassifierHeadProtocol",
"DecoderProtocol", "DecoderProtocol",
"DetectionModel", "DetectorProtocol",
"EncoderDecoderModel",
"EncoderProtocol", "EncoderProtocol",
"ModelOutput", "ModelOutput",
"ModelProtocol",
"ModuleProtocol",
"SizeHeadProtocol",
] ]
class BlockProtocol(Protocol): class ModuleProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
def train(self, mode: bool = True) -> torch.nn.Module: ...
def eval(self) -> torch.nn.Module: ...
def state_dict(
self, *args: Any, **kwargs: Any
) -> dict[str, torch.Tensor]: ...
def load_state_dict(self, *args: Any, **kwargs: Any) -> Any: ...
def parameters(self) -> Any: ...
class BlockProtocol(ModuleProtocol, Protocol):
in_channels: int in_channels: int
out_channels: int out_channels: int
@ -24,7 +45,7 @@ class BlockProtocol(Protocol):
def get_output_height(self, input_height: int) -> int: ... def get_output_height(self, input_height: int) -> int: ...
class EncoderProtocol(Protocol): class EncoderProtocol(ModuleProtocol, Protocol):
in_channels: int in_channels: int
out_channels: int out_channels: int
input_height: int input_height: int
@ -33,7 +54,7 @@ class EncoderProtocol(Protocol):
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ... def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
class BottleneckProtocol(Protocol): class BottleneckProtocol(ModuleProtocol, Protocol):
in_channels: int in_channels: int
out_channels: int out_channels: int
input_height: int input_height: int
@ -41,7 +62,7 @@ class BottleneckProtocol(Protocol):
def __call__(self, x: torch.Tensor) -> torch.Tensor: ... def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
class DecoderProtocol(Protocol): class DecoderProtocol(ModuleProtocol, Protocol):
in_channels: int in_channels: int
out_channels: int out_channels: int
input_height: int input_height: int
@ -62,29 +83,42 @@ class ModelOutput(NamedTuple):
features: torch.Tensor features: torch.Tensor
class BackboneModel(ABC, torch.nn.Module): class BackboneProtocol(ModuleProtocol, Protocol):
input_height: int input_height: int
out_channels: int out_channels: int
@abstractmethod def forward(self, spec: torch.Tensor) -> torch.Tensor: ...
def forward(self, spec: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
class EncoderDecoderModel(BackboneModel): class ClassifierHeadProtocol(ModuleProtocol, Protocol):
bottleneck_channels: int num_classes: int
in_channels: int
class_names: list[str]
@abstractmethod def forward(self, features: torch.Tensor) -> torch.Tensor: ...
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
@abstractmethod
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
class DetectionModel(ABC, torch.nn.Module): class SizeHeadProtocol(ModuleProtocol, Protocol):
backbone: BackboneModel in_channels: int
classifier_head: torch.nn.Module num_sizes: int
bbox_head: torch.nn.Module dimension_names: list[str]
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
class DetectorProtocol(ModuleProtocol, Protocol):
backbone: BackboneProtocol
classifier_head: ClassifierHeadProtocol
size_head: SizeHeadProtocol
@abstractmethod
def forward(self, spec: torch.Tensor) -> ModelOutput: ... def forward(self, spec: torch.Tensor) -> ModelOutput: ...
class ModelProtocol(ModuleProtocol, Protocol):
detector: DetectorProtocol
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
class_names: list[str]
dimension_names: list[str]
def get_config(self) -> dict[str, Any]: ...

View File

@ -4,8 +4,8 @@ import lightning as L
import torch import torch
from soundevent.data import PathLike from soundevent.data import PathLike
from batdetect2.models import Model, ModelConfig, build_model from batdetect2.models import ModelConfig, build_model
from batdetect2.models.types import ModelOutput from batdetect2.models.types import ModelOutput, ModelProtocol
from batdetect2.targets import TargetConfig from batdetect2.targets import TargetConfig
from batdetect2.train.checkpoints import resolve_checkpoint_path from batdetect2.train.checkpoints import resolve_checkpoint_path
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
@ -21,7 +21,7 @@ __all__ = [
class TrainingModule(L.LightningModule): class TrainingModule(L.LightningModule):
model: Model model: ModelProtocol
loss: LossProtocol loss: LossProtocol
def __init__( def __init__(
@ -32,7 +32,7 @@ class TrainingModule(L.LightningModule):
dimension_names: list[str] | None = None, dimension_names: list[str] | None = None,
train_config: dict | None = None, train_config: dict | None = None,
loss: LossProtocol | None = None, loss: LossProtocol | None = None,
model: Model | None = None, model: ModelProtocol | None = None,
): ):
super().__init__() super().__init__()
@ -133,7 +133,7 @@ class StoredConfig:
def load_model_from_checkpoint( def load_model_from_checkpoint(
path: PathLike | str | None = None, path: PathLike | str | None = None,
) -> tuple[Model, StoredConfig]: ) -> tuple[ModelProtocol, StoredConfig]:
"""Load a model and its configuration from a Lightning checkpoint. """Load a model and its configuration from a Lightning checkpoint.
Parameters Parameters
@ -144,7 +144,7 @@ def load_model_from_checkpoint(
Returns Returns
------- -------
tuple[Model, ModelConfig] tuple[ModelProtocol, ModelConfig]
The restored ``Model`` instance and the ``ModelConfig`` that The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, and postprocessing. describes its architecture, preprocessing, and postprocessing.
""" """
@ -169,7 +169,7 @@ def build_training_module(
class_names: list[str] | None = None, class_names: list[str] | None = None,
dimension_names: list[str] | None = None, dimension_names: list[str] | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
model: Model | None = None, model: ModelProtocol | None = None,
) -> TrainingModule: ) -> TrainingModule:
if model_config is None: if model_config is None:
model_config = ModelConfig() model_config = ModelConfig()

View File

@ -15,7 +15,8 @@ from batdetect2.logging import (
TensorBoardLoggerConfig, TensorBoardLoggerConfig,
build_logger, build_logger,
) )
from batdetect2.models import Model, ModelConfig, build_model from batdetect2.models import ModelConfig, build_model
from batdetect2.models.types import ModelProtocol
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import ( from batdetect2.targets import (
ROIMapperProtocol, ROIMapperProtocol,
@ -50,7 +51,7 @@ DEFAULT_LOG_DIR = Path("outputs") / "logs"
def run_train( def run_train(
train_annotations: Sequence[data.ClipAnnotation], train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None, val_annotations: Sequence[data.ClipAnnotation] | None = None,
model: Model | None = None, model: ModelProtocol | None = None,
targets: Optional["TargetProtocol"] = None, targets: Optional["TargetProtocol"] = None,
roi_mapper: Optional["ROIMapperProtocol"] = None, roi_mapper: Optional["ROIMapperProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None,
@ -217,7 +218,7 @@ def run_train(
def _validate_model_compatibility( def _validate_model_compatibility(
model: Model, model: ModelProtocol,
model_config: ModelConfig, model_config: ModelConfig,
class_names: list[str], class_names: list[str],
dimension_names: list[str], dimension_names: list[str],

View File

@ -13,7 +13,6 @@ from batdetect2.models.backbones import (
build_backbone, build_backbone,
load_backbone_config, load_backbone_config,
) )
from batdetect2.models.types import BackboneModel
def test_unet_backbone_config_defaults(): def test_unet_backbone_config_defaults():
@ -61,10 +60,11 @@ def test_build_backbone_custom_config():
assert backbone.encoder.in_channels == 2 assert backbone.encoder.in_channels == 2
def test_build_backbone_returns_backbone_model(): def test_build_backbone_returns_unet_backbone():
"""build_backbone always returns a BackboneModel instance.""" """build_backbone returns the default UNet backbone."""
backbone = build_backbone() backbone = build_backbone()
assert isinstance(backbone, BackboneModel)
assert isinstance(backbone, UNetBackbone)
def test_registry_has_unet_backbone(): def test_registry_has_unet_backbone():