mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
refactor: replace abstract model types with protocols
This commit is contained in:
parent
a27d1bbfd3
commit
2008c8000f
@ -11,7 +11,7 @@ from batdetect2.evaluate.dataset import build_test_loader
|
||||
from batdetect2.evaluate.evaluator import build_evaluator
|
||||
from batdetect2.evaluate.lightning import EvaluationModule
|
||||
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
@ -22,7 +22,7 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
|
||||
def run_evaluate(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
test_annotations: Sequence[data.ClipAnnotation],
|
||||
targets: TargetProtocol,
|
||||
roi_mapper: ROIMapperProtocol,
|
||||
|
||||
@ -7,14 +7,14 @@ from torch.utils.data import DataLoader
|
||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import get_image_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
|
||||
|
||||
class EvaluationModule(LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
evaluator: EvaluatorProtocol,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.inference.clips import get_clips_from_files
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.inference.dataset import build_inference_loader
|
||||
from batdetect2.inference.lightning import InferenceModule
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.outputs import (
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
@ -22,7 +22,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
|
||||
|
||||
def run_batch_inference(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
clips: Sequence[data.Clip],
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
@ -86,7 +86,7 @@ def run_batch_inference(
|
||||
|
||||
|
||||
def process_file_list(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
paths: Sequence[data.PathLike],
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
|
||||
@ -4,7 +4,7 @@ from lightning import LightningModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||
from batdetect2.postprocess.types import ClipDetections
|
||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
@ -13,7 +13,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||
class InferenceModule(LightningModule):
|
||||
def __init__(
|
||||
self,
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
targets: TargetProtocol | None = None,
|
||||
roi_mapper: ROIMapperProtocol | None = None,
|
||||
output_transform: OutputTransformProtocol | None = None,
|
||||
|
||||
@ -27,6 +27,7 @@ from typing import Annotated, Literal
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from loguru import logger
|
||||
from pydantic import Field, TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
@ -52,7 +53,7 @@ from batdetect2.models.encoder import (
|
||||
build_encoder,
|
||||
)
|
||||
from batdetect2.models.types import (
|
||||
BackboneModel,
|
||||
BackboneProtocol,
|
||||
BottleneckProtocol,
|
||||
DecoderProtocol,
|
||||
EncoderProtocol,
|
||||
@ -104,7 +105,7 @@ class UNetBackboneConfig(BaseConfig):
|
||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||
|
||||
|
||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
||||
backbone_registry: Registry[BackboneProtocol, []] = Registry("backbone")
|
||||
|
||||
|
||||
@add_import_config(backbone_registry)
|
||||
@ -118,7 +119,7 @@ class BackboneImportConfig(ImportConfig):
|
||||
name: Literal["import"] = "import"
|
||||
|
||||
|
||||
class UNetBackbone(BackboneModel):
|
||||
class UNetBackbone(torch.nn.Module):
|
||||
"""U-Net-style encoder-decoder backbone network.
|
||||
|
||||
Combines an encoder, a bottleneck, and a decoder into a single module
|
||||
@ -225,7 +226,7 @@ class UNetBackbone(BackboneModel):
|
||||
|
||||
@backbone_registry.register(UNetBackboneConfig)
|
||||
@staticmethod
|
||||
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
||||
def from_config(config: UNetBackboneConfig) -> BackboneProtocol:
|
||||
encoder = build_encoder(
|
||||
in_channels=config.in_channels,
|
||||
input_height=config.input_height,
|
||||
@ -266,7 +267,7 @@ BackboneConfig = Annotated[
|
||||
]
|
||||
|
||||
|
||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneProtocol:
|
||||
"""Build a backbone network from configuration.
|
||||
|
||||
Looks up the backbone class corresponding to ``config.name`` in the
|
||||
@ -282,10 +283,14 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
||||
|
||||
Returns
|
||||
-------
|
||||
BackboneModel
|
||||
BackboneProtocol
|
||||
An initialised backbone module.
|
||||
"""
|
||||
config = config or UNetBackboneConfig()
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building model backbone with config: \n{}",
|
||||
lambda: config.to_yaml_string(),
|
||||
)
|
||||
return backbone_registry.build(config)
|
||||
|
||||
|
||||
|
||||
@ -1,21 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import NamedTuple, Protocol
|
||||
from typing import Any, NamedTuple, Protocol
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
|
||||
__all__ = [
|
||||
"BackboneModel",
|
||||
"BackboneProtocol",
|
||||
"BlockProtocol",
|
||||
"BottleneckProtocol",
|
||||
"ClassifierHeadProtocol",
|
||||
"DecoderProtocol",
|
||||
"DetectionModel",
|
||||
"EncoderDecoderModel",
|
||||
"DetectorProtocol",
|
||||
"EncoderProtocol",
|
||||
"ModelOutput",
|
||||
"ModelProtocol",
|
||||
"ModuleProtocol",
|
||||
"SizeHeadProtocol",
|
||||
]
|
||||
|
||||
|
||||
class BlockProtocol(Protocol):
|
||||
class ModuleProtocol(Protocol):
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def train(self, mode: bool = True) -> torch.nn.Module: ...
|
||||
|
||||
def eval(self) -> torch.nn.Module: ...
|
||||
|
||||
def state_dict(
|
||||
self, *args: Any, **kwargs: Any
|
||||
) -> dict[str, torch.Tensor]: ...
|
||||
|
||||
def load_state_dict(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||
|
||||
def parameters(self) -> Any: ...
|
||||
|
||||
|
||||
class BlockProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
|
||||
@ -24,7 +45,7 @@ class BlockProtocol(Protocol):
|
||||
def get_output_height(self, input_height: int) -> int: ...
|
||||
|
||||
|
||||
class EncoderProtocol(Protocol):
|
||||
class EncoderProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
@ -33,7 +54,7 @@ class EncoderProtocol(Protocol):
|
||||
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
|
||||
|
||||
|
||||
class BottleneckProtocol(Protocol):
|
||||
class BottleneckProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
@ -41,7 +62,7 @@ class BottleneckProtocol(Protocol):
|
||||
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DecoderProtocol(Protocol):
|
||||
class DecoderProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
out_channels: int
|
||||
input_height: int
|
||||
@ -62,29 +83,42 @@ class ModelOutput(NamedTuple):
|
||||
features: torch.Tensor
|
||||
|
||||
|
||||
class BackboneModel(ABC, torch.nn.Module):
|
||||
class BackboneProtocol(ModuleProtocol, Protocol):
|
||||
input_height: int
|
||||
out_channels: int
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
def forward(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class EncoderDecoderModel(BackboneModel):
|
||||
bottleneck_channels: int
|
||||
class ClassifierHeadProtocol(ModuleProtocol, Protocol):
|
||||
num_classes: int
|
||||
in_channels: int
|
||||
class_names: list[str]
|
||||
|
||||
@abstractmethod
|
||||
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
@abstractmethod
|
||||
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DetectionModel(ABC, torch.nn.Module):
|
||||
backbone: BackboneModel
|
||||
classifier_head: torch.nn.Module
|
||||
bbox_head: torch.nn.Module
|
||||
class SizeHeadProtocol(ModuleProtocol, Protocol):
|
||||
in_channels: int
|
||||
num_sizes: int
|
||||
dimension_names: list[str]
|
||||
|
||||
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
|
||||
|
||||
|
||||
class DetectorProtocol(ModuleProtocol, Protocol):
|
||||
backbone: BackboneProtocol
|
||||
classifier_head: ClassifierHeadProtocol
|
||||
size_head: SizeHeadProtocol
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
||||
|
||||
|
||||
class ModelProtocol(ModuleProtocol, Protocol):
|
||||
detector: DetectorProtocol
|
||||
preprocessor: PreprocessorProtocol
|
||||
postprocessor: PostprocessorProtocol
|
||||
class_names: list[str]
|
||||
dimension_names: list[str]
|
||||
|
||||
def get_config(self) -> dict[str, Any]: ...
|
||||
|
||||
@ -4,8 +4,8 @@ import lightning as L
|
||||
import torch
|
||||
from soundevent.data import PathLike
|
||||
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelOutput
|
||||
from batdetect2.models import ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelOutput, ModelProtocol
|
||||
from batdetect2.targets import TargetConfig
|
||||
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
@ -21,7 +21,7 @@ __all__ = [
|
||||
|
||||
|
||||
class TrainingModule(L.LightningModule):
|
||||
model: Model
|
||||
model: ModelProtocol
|
||||
loss: LossProtocol
|
||||
|
||||
def __init__(
|
||||
@ -32,7 +32,7 @@ class TrainingModule(L.LightningModule):
|
||||
dimension_names: list[str] | None = None,
|
||||
train_config: dict | None = None,
|
||||
loss: LossProtocol | None = None,
|
||||
model: Model | None = None,
|
||||
model: ModelProtocol | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -133,7 +133,7 @@ class StoredConfig:
|
||||
|
||||
def load_model_from_checkpoint(
|
||||
path: PathLike | str | None = None,
|
||||
) -> tuple[Model, StoredConfig]:
|
||||
) -> tuple[ModelProtocol, StoredConfig]:
|
||||
"""Load a model and its configuration from a Lightning checkpoint.
|
||||
|
||||
Parameters
|
||||
@ -144,7 +144,7 @@ def load_model_from_checkpoint(
|
||||
|
||||
Returns
|
||||
-------
|
||||
tuple[Model, ModelConfig]
|
||||
tuple[ModelProtocol, ModelConfig]
|
||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||
describes its architecture, preprocessing, and postprocessing.
|
||||
"""
|
||||
@ -169,7 +169,7 @@ def build_training_module(
|
||||
class_names: list[str] | None = None,
|
||||
dimension_names: list[str] | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
model: Model | None = None,
|
||||
model: ModelProtocol | None = None,
|
||||
) -> TrainingModule:
|
||||
if model_config is None:
|
||||
model_config = ModelConfig()
|
||||
|
||||
@ -15,7 +15,8 @@ from batdetect2.logging import (
|
||||
TensorBoardLoggerConfig,
|
||||
build_logger,
|
||||
)
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.models import ModelConfig, build_model
|
||||
from batdetect2.models.types import ModelProtocol
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
@ -50,7 +51,7 @@ DEFAULT_LOG_DIR = Path("outputs") / "logs"
|
||||
def run_train(
|
||||
train_annotations: Sequence[data.ClipAnnotation],
|
||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||
model: Model | None = None,
|
||||
model: ModelProtocol | None = None,
|
||||
targets: Optional["TargetProtocol"] = None,
|
||||
roi_mapper: Optional["ROIMapperProtocol"] = None,
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
@ -217,7 +218,7 @@ def run_train(
|
||||
|
||||
|
||||
def _validate_model_compatibility(
|
||||
model: Model,
|
||||
model: ModelProtocol,
|
||||
model_config: ModelConfig,
|
||||
class_names: list[str],
|
||||
dimension_names: list[str],
|
||||
|
||||
@ -13,7 +13,6 @@ from batdetect2.models.backbones import (
|
||||
build_backbone,
|
||||
load_backbone_config,
|
||||
)
|
||||
from batdetect2.models.types import BackboneModel
|
||||
|
||||
|
||||
def test_unet_backbone_config_defaults():
|
||||
@ -61,10 +60,11 @@ def test_build_backbone_custom_config():
|
||||
assert backbone.encoder.in_channels == 2
|
||||
|
||||
|
||||
def test_build_backbone_returns_backbone_model():
|
||||
"""build_backbone always returns a BackboneModel instance."""
|
||||
def test_build_backbone_returns_unet_backbone():
|
||||
"""build_backbone returns the default UNet backbone."""
|
||||
backbone = build_backbone()
|
||||
assert isinstance(backbone, BackboneModel)
|
||||
|
||||
assert isinstance(backbone, UNetBackbone)
|
||||
|
||||
|
||||
def test_registry_has_unet_backbone():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user