refactor: replace abstract model types with protocols

This commit is contained in:
mbsantiago 2026-05-06 12:50:32 +01:00
parent a27d1bbfd3
commit 2008c8000f
9 changed files with 94 additions and 54 deletions

View File

@ -11,7 +11,7 @@ from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
from batdetect2.models import Model
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol
from batdetect2.postprocess.types import ClipDetections
@ -22,7 +22,7 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
def run_evaluate(
model: Model,
model: ModelProtocol,
test_annotations: Sequence[data.ClipAnnotation],
targets: TargetProtocol,
roi_mapper: ROIMapperProtocol,

View File

@ -7,14 +7,14 @@ from torch.utils.data import DataLoader
from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.models.types import ModelProtocol
from batdetect2.postprocess.types import ClipDetections
class EvaluationModule(LightningModule):
def __init__(
self,
model: Model,
model: ModelProtocol,
evaluator: EvaluatorProtocol,
):
super().__init__()

View File

@ -10,7 +10,7 @@ from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import (
OutputsConfig,
OutputTransformProtocol,
@ -22,7 +22,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
def run_batch_inference(
model: Model,
model: ModelProtocol,
clips: Sequence[data.Clip],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
@ -86,7 +86,7 @@ def run_batch_inference(
def process_file_list(
model: Model,
model: ModelProtocol,
paths: Sequence[data.PathLike],
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,

View File

@ -4,7 +4,7 @@ from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.models.types import ModelProtocol
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
from batdetect2.postprocess.types import ClipDetections
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
@ -13,7 +13,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
class InferenceModule(LightningModule):
def __init__(
self,
model: Model,
model: ModelProtocol,
targets: TargetProtocol | None = None,
roi_mapper: ROIMapperProtocol | None = None,
output_transform: OutputTransformProtocol | None = None,

View File

@ -27,6 +27,7 @@ from typing import Annotated, Literal
import torch
import torch.nn.functional as F
from loguru import logger
from pydantic import Field, TypeAdapter
from soundevent import data
@ -52,7 +53,7 @@ from batdetect2.models.encoder import (
build_encoder,
)
from batdetect2.models.types import (
BackboneModel,
BackboneProtocol,
BottleneckProtocol,
DecoderProtocol,
EncoderProtocol,
@ -104,7 +105,7 @@ class UNetBackboneConfig(BaseConfig):
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
backbone_registry: Registry[BackboneProtocol, []] = Registry("backbone")
@add_import_config(backbone_registry)
@ -118,7 +119,7 @@ class BackboneImportConfig(ImportConfig):
name: Literal["import"] = "import"
class UNetBackbone(BackboneModel):
class UNetBackbone(torch.nn.Module):
"""U-Net-style encoder-decoder backbone network.
Combines an encoder, a bottleneck, and a decoder into a single module
@ -225,7 +226,7 @@ class UNetBackbone(BackboneModel):
@backbone_registry.register(UNetBackboneConfig)
@staticmethod
def from_config(config: UNetBackboneConfig) -> BackboneModel:
def from_config(config: UNetBackboneConfig) -> BackboneProtocol:
encoder = build_encoder(
in_channels=config.in_channels,
input_height=config.input_height,
@ -266,7 +267,7 @@ BackboneConfig = Annotated[
]
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
def build_backbone(config: BackboneConfig | None = None) -> BackboneProtocol:
"""Build a backbone network from configuration.
Looks up the backbone class corresponding to ``config.name`` in the
@ -282,10 +283,14 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
Returns
-------
BackboneModel
BackboneProtocol
An initialised backbone module.
"""
config = config or UNetBackboneConfig()
logger.opt(lazy=True).debug(
"Building model backbone with config: \n{}",
lambda: config.to_yaml_string(),
)
return backbone_registry.build(config)

View File

@ -1,21 +1,42 @@
from abc import ABC, abstractmethod
from typing import NamedTuple, Protocol
from typing import Any, NamedTuple, Protocol
import torch
from batdetect2.postprocess.types import PostprocessorProtocol
from batdetect2.preprocess.types import PreprocessorProtocol
__all__ = [
"BackboneModel",
"BackboneProtocol",
"BlockProtocol",
"BottleneckProtocol",
"ClassifierHeadProtocol",
"DecoderProtocol",
"DetectionModel",
"EncoderDecoderModel",
"DetectorProtocol",
"EncoderProtocol",
"ModelOutput",
"ModelProtocol",
"ModuleProtocol",
"SizeHeadProtocol",
]
class BlockProtocol(Protocol):
class ModuleProtocol(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
def train(self, mode: bool = True) -> torch.nn.Module: ...
def eval(self) -> torch.nn.Module: ...
def state_dict(
self, *args: Any, **kwargs: Any
) -> dict[str, torch.Tensor]: ...
def load_state_dict(self, *args: Any, **kwargs: Any) -> Any: ...
def parameters(self) -> Any: ...
class BlockProtocol(ModuleProtocol, Protocol):
in_channels: int
out_channels: int
@ -24,7 +45,7 @@ class BlockProtocol(Protocol):
def get_output_height(self, input_height: int) -> int: ...
class EncoderProtocol(Protocol):
class EncoderProtocol(ModuleProtocol, Protocol):
in_channels: int
out_channels: int
input_height: int
@ -33,7 +54,7 @@ class EncoderProtocol(Protocol):
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
class BottleneckProtocol(Protocol):
class BottleneckProtocol(ModuleProtocol, Protocol):
in_channels: int
out_channels: int
input_height: int
@ -41,7 +62,7 @@ class BottleneckProtocol(Protocol):
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
class DecoderProtocol(Protocol):
class DecoderProtocol(ModuleProtocol, Protocol):
in_channels: int
out_channels: int
input_height: int
@ -62,29 +83,42 @@ class ModelOutput(NamedTuple):
features: torch.Tensor
class BackboneModel(ABC, torch.nn.Module):
class BackboneProtocol(ModuleProtocol, Protocol):
input_height: int
out_channels: int
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
raise NotImplementedError
def forward(self, spec: torch.Tensor) -> torch.Tensor: ...
class EncoderDecoderModel(BackboneModel):
bottleneck_channels: int
class ClassifierHeadProtocol(ModuleProtocol, Protocol):
num_classes: int
in_channels: int
class_names: list[str]
@abstractmethod
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
@abstractmethod
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
class DetectionModel(ABC, torch.nn.Module):
backbone: BackboneModel
classifier_head: torch.nn.Module
bbox_head: torch.nn.Module
class SizeHeadProtocol(ModuleProtocol, Protocol):
in_channels: int
num_sizes: int
dimension_names: list[str]
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
class DetectorProtocol(ModuleProtocol, Protocol):
backbone: BackboneProtocol
classifier_head: ClassifierHeadProtocol
size_head: SizeHeadProtocol
@abstractmethod
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
class ModelProtocol(ModuleProtocol, Protocol):
detector: DetectorProtocol
preprocessor: PreprocessorProtocol
postprocessor: PostprocessorProtocol
class_names: list[str]
dimension_names: list[str]
def get_config(self) -> dict[str, Any]: ...

View File

@ -4,8 +4,8 @@ import lightning as L
import torch
from soundevent.data import PathLike
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.models.types import ModelOutput
from batdetect2.models import ModelConfig, build_model
from batdetect2.models.types import ModelOutput, ModelProtocol
from batdetect2.targets import TargetConfig
from batdetect2.train.checkpoints import resolve_checkpoint_path
from batdetect2.train.config import TrainingConfig
@ -21,7 +21,7 @@ __all__ = [
class TrainingModule(L.LightningModule):
model: Model
model: ModelProtocol
loss: LossProtocol
def __init__(
@ -32,7 +32,7 @@ class TrainingModule(L.LightningModule):
dimension_names: list[str] | None = None,
train_config: dict | None = None,
loss: LossProtocol | None = None,
model: Model | None = None,
model: ModelProtocol | None = None,
):
super().__init__()
@ -133,7 +133,7 @@ class StoredConfig:
def load_model_from_checkpoint(
path: PathLike | str | None = None,
) -> tuple[Model, StoredConfig]:
) -> tuple[ModelProtocol, StoredConfig]:
"""Load a model and its configuration from a Lightning checkpoint.
Parameters
@ -144,7 +144,7 @@ def load_model_from_checkpoint(
Returns
-------
tuple[Model, ModelConfig]
tuple[ModelProtocol, ModelConfig]
The restored ``Model`` instance and the ``ModelConfig`` that
describes its architecture, preprocessing, and postprocessing.
"""
@ -169,7 +169,7 @@ def build_training_module(
class_names: list[str] | None = None,
dimension_names: list[str] | None = None,
train_config: TrainingConfig | None = None,
model: Model | None = None,
model: ModelProtocol | None = None,
) -> TrainingModule:
if model_config is None:
model_config = ModelConfig()

View File

@ -15,7 +15,8 @@ from batdetect2.logging import (
TensorBoardLoggerConfig,
build_logger,
)
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.models import ModelConfig, build_model
from batdetect2.models.types import ModelProtocol
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import (
ROIMapperProtocol,
@ -50,7 +51,7 @@ DEFAULT_LOG_DIR = Path("outputs") / "logs"
def run_train(
train_annotations: Sequence[data.ClipAnnotation],
val_annotations: Sequence[data.ClipAnnotation] | None = None,
model: Model | None = None,
model: ModelProtocol | None = None,
targets: Optional["TargetProtocol"] = None,
roi_mapper: Optional["ROIMapperProtocol"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
@ -217,7 +218,7 @@ def run_train(
def _validate_model_compatibility(
model: Model,
model: ModelProtocol,
model_config: ModelConfig,
class_names: list[str],
dimension_names: list[str],

View File

@ -13,7 +13,6 @@ from batdetect2.models.backbones import (
build_backbone,
load_backbone_config,
)
from batdetect2.models.types import BackboneModel
def test_unet_backbone_config_defaults():
@ -61,10 +60,11 @@ def test_build_backbone_custom_config():
assert backbone.encoder.in_channels == 2
def test_build_backbone_returns_backbone_model():
"""build_backbone always returns a BackboneModel instance."""
def test_build_backbone_returns_unet_backbone():
"""build_backbone returns the default UNet backbone."""
backbone = build_backbone()
assert isinstance(backbone, BackboneModel)
assert isinstance(backbone, UNetBackbone)
def test_registry_has_unet_backbone():