diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py index 5eecb4d..ddaa459 100644 --- a/src/batdetect2/evaluate/evaluate.py +++ b/src/batdetect2/evaluate/evaluate.py @@ -11,7 +11,7 @@ from batdetect2.evaluate.dataset import build_test_loader from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger -from batdetect2.models import Model +from batdetect2.models.types import ModelProtocol from batdetect2.outputs import OutputsConfig, build_output_transform from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.postprocess.types import ClipDetections @@ -22,7 +22,7 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations" def run_evaluate( - model: Model, + model: ModelProtocol, test_annotations: Sequence[data.ClipAnnotation], targets: TargetProtocol, roi_mapper: ROIMapperProtocol, diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py index 48703d5..881b692 100644 --- a/src/batdetect2/evaluate/lightning.py +++ b/src/batdetect2/evaluate/lightning.py @@ -7,14 +7,14 @@ from torch.utils.data import DataLoader from batdetect2.evaluate.dataset import TestDataset, TestExample from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.logging import get_image_logger -from batdetect2.models import Model +from batdetect2.models.types import ModelProtocol from batdetect2.postprocess.types import ClipDetections class EvaluationModule(LightningModule): def __init__( self, - model: Model, + model: ModelProtocol, evaluator: EvaluatorProtocol, ): super().__init__() diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index 4be9dab..d91df29 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -10,7 +10,7 @@ from batdetect2.inference.clips import get_clips_from_files from 
batdetect2.inference.config import InferenceConfig from batdetect2.inference.dataset import build_inference_loader from batdetect2.inference.lightning import InferenceModule -from batdetect2.models import Model +from batdetect2.models.types import ModelProtocol from batdetect2.outputs import ( OutputsConfig, OutputTransformProtocol, @@ -22,7 +22,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol def run_batch_inference( - model: Model, + model: ModelProtocol, clips: Sequence[data.Clip], targets: TargetProtocol | None = None, roi_mapper: ROIMapperProtocol | None = None, @@ -86,7 +86,7 @@ def run_batch_inference( def process_file_list( - model: Model, + model: ModelProtocol, paths: Sequence[data.PathLike], targets: TargetProtocol | None = None, roi_mapper: ROIMapperProtocol | None = None, diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py index 02cba66..d879d83 100644 --- a/src/batdetect2/inference/lightning.py +++ b/src/batdetect2/inference/lightning.py @@ -4,7 +4,7 @@ from lightning import LightningModule from torch.utils.data import DataLoader from batdetect2.inference.dataset import DatasetItem, InferenceDataset -from batdetect2.models import Model +from batdetect2.models.types import ModelProtocol from batdetect2.outputs import OutputTransformProtocol, build_output_transform from batdetect2.postprocess.types import ClipDetections from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol @@ -13,7 +13,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol class InferenceModule(LightningModule): def __init__( self, - model: Model, + model: ModelProtocol, targets: TargetProtocol | None = None, roi_mapper: ROIMapperProtocol | None = None, output_transform: OutputTransformProtocol | None = None, diff --git a/src/batdetect2/models/backbones.py b/src/batdetect2/models/backbones.py index c3ef77c..03a62d3 100644 --- a/src/batdetect2/models/backbones.py +++ 
b/src/batdetect2/models/backbones.py @@ -27,6 +27,7 @@ from typing import Annotated, Literal import torch import torch.nn.functional as F +from loguru import logger from pydantic import Field, TypeAdapter from soundevent import data @@ -52,7 +53,7 @@ from batdetect2.models.encoder import ( build_encoder, ) from batdetect2.models.types import ( - BackboneModel, + BackboneProtocol, BottleneckProtocol, DecoderProtocol, EncoderProtocol, @@ -104,7 +105,7 @@ class UNetBackboneConfig(BaseConfig): decoder: DecoderConfig = DEFAULT_DECODER_CONFIG -backbone_registry: Registry[BackboneModel, []] = Registry("backbone") +backbone_registry: Registry[BackboneProtocol, []] = Registry("backbone") @add_import_config(backbone_registry) @@ -118,7 +119,7 @@ class BackboneImportConfig(ImportConfig): name: Literal["import"] = "import" -class UNetBackbone(BackboneModel): +class UNetBackbone(torch.nn.Module): """U-Net-style encoder-decoder backbone network. Combines an encoder, a bottleneck, and a decoder into a single module @@ -225,7 +226,7 @@ class UNetBackbone(BackboneModel): @backbone_registry.register(UNetBackboneConfig) @staticmethod - def from_config(config: UNetBackboneConfig) -> BackboneModel: + def from_config(config: UNetBackboneConfig) -> BackboneProtocol: encoder = build_encoder( in_channels=config.in_channels, input_height=config.input_height, @@ -266,7 +267,7 @@ BackboneConfig = Annotated[ ] -def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: +def build_backbone(config: BackboneConfig | None = None) -> BackboneProtocol: """Build a backbone network from configuration. Looks up the backbone class corresponding to ``config.name`` in the @@ -282,10 +283,14 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel: Returns ------- - BackboneModel + BackboneProtocol An initialised backbone module. 
""" config = config or UNetBackboneConfig() + logger.opt(lazy=True).debug( + "Building model backbone with config: \n{}", + lambda: config.to_yaml_string(), + ) return backbone_registry.build(config) diff --git a/src/batdetect2/models/types.py b/src/batdetect2/models/types.py index eb4302a..eb8a181 100644 --- a/src/batdetect2/models/types.py +++ b/src/batdetect2/models/types.py @@ -1,21 +1,42 @@ -from abc import ABC, abstractmethod -from typing import NamedTuple, Protocol +from typing import Any, NamedTuple, Protocol import torch +from batdetect2.postprocess.types import PostprocessorProtocol +from batdetect2.preprocess.types import PreprocessorProtocol + __all__ = [ - "BackboneModel", + "BackboneProtocol", "BlockProtocol", "BottleneckProtocol", + "ClassifierHeadProtocol", "DecoderProtocol", - "DetectionModel", - "EncoderDecoderModel", + "DetectorProtocol", "EncoderProtocol", "ModelOutput", + "ModelProtocol", + "ModuleProtocol", + "SizeHeadProtocol", ] -class BlockProtocol(Protocol): +class ModuleProtocol(Protocol): + def __call__(self, *args: Any, **kwargs: Any) -> Any: ... + + def train(self, mode: bool = True) -> torch.nn.Module: ... + + def eval(self) -> torch.nn.Module: ... + + def state_dict( + self, *args: Any, **kwargs: Any + ) -> dict[str, torch.Tensor]: ... + + def load_state_dict(self, *args: Any, **kwargs: Any) -> Any: ... + + def parameters(self) -> Any: ... + + +class BlockProtocol(ModuleProtocol, Protocol): in_channels: int out_channels: int @@ -24,7 +45,7 @@ class BlockProtocol(Protocol): def get_output_height(self, input_height: int) -> int: ... -class EncoderProtocol(Protocol): +class EncoderProtocol(ModuleProtocol, Protocol): in_channels: int out_channels: int input_height: int @@ -33,7 +54,7 @@ class EncoderProtocol(Protocol): def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ... 
-class BottleneckProtocol(Protocol): +class BottleneckProtocol(ModuleProtocol, Protocol): in_channels: int out_channels: int input_height: int @@ -41,7 +62,7 @@ class BottleneckProtocol(Protocol): def __call__(self, x: torch.Tensor) -> torch.Tensor: ... -class DecoderProtocol(Protocol): +class DecoderProtocol(ModuleProtocol, Protocol): in_channels: int out_channels: int input_height: int @@ -62,29 +83,42 @@ class ModelOutput(NamedTuple): features: torch.Tensor -class BackboneModel(ABC, torch.nn.Module): +class BackboneProtocol(ModuleProtocol, Protocol): input_height: int out_channels: int - @abstractmethod - def forward(self, spec: torch.Tensor) -> torch.Tensor: - raise NotImplementedError + def forward(self, spec: torch.Tensor) -> torch.Tensor: ... -class EncoderDecoderModel(BackboneModel): - bottleneck_channels: int +class ClassifierHeadProtocol(ModuleProtocol, Protocol): + num_classes: int + in_channels: int + class_names: list[str] - @abstractmethod - def encode(self, spec: torch.Tensor) -> torch.Tensor: ... - - @abstractmethod - def decode(self, encoded: torch.Tensor) -> torch.Tensor: ... + def forward(self, features: torch.Tensor) -> torch.Tensor: ... -class DetectionModel(ABC, torch.nn.Module): - backbone: BackboneModel - classifier_head: torch.nn.Module - bbox_head: torch.nn.Module +class SizeHeadProtocol(ModuleProtocol, Protocol): + in_channels: int + num_sizes: int + dimension_names: list[str] + + def forward(self, features: torch.Tensor) -> torch.Tensor: ... + + +class DetectorProtocol(ModuleProtocol, Protocol): + backbone: BackboneProtocol + classifier_head: ClassifierHeadProtocol + size_head: SizeHeadProtocol - @abstractmethod def forward(self, spec: torch.Tensor) -> ModelOutput: ... + + +class ModelProtocol(ModuleProtocol, Protocol): + detector: DetectorProtocol + preprocessor: PreprocessorProtocol + postprocessor: PostprocessorProtocol + class_names: list[str] + dimension_names: list[str] + + def get_config(self) -> dict[str, Any]: ... 
diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index 30fe1af..3e7833e 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -4,8 +4,8 @@ import lightning as L import torch from soundevent.data import PathLike -from batdetect2.models import Model, ModelConfig, build_model -from batdetect2.models.types import ModelOutput +from batdetect2.models import ModelConfig, build_model +from batdetect2.models.types import ModelOutput, ModelProtocol from batdetect2.targets import TargetConfig from batdetect2.train.checkpoints import resolve_checkpoint_path from batdetect2.train.config import TrainingConfig @@ -21,7 +21,7 @@ __all__ = [ class TrainingModule(L.LightningModule): - model: Model + model: ModelProtocol loss: LossProtocol def __init__( @@ -32,7 +32,7 @@ class TrainingModule(L.LightningModule): dimension_names: list[str] | None = None, train_config: dict | None = None, loss: LossProtocol | None = None, - model: Model | None = None, + model: ModelProtocol | None = None, ): super().__init__() @@ -133,7 +133,7 @@ class StoredConfig: def load_model_from_checkpoint( path: PathLike | str | None = None, -) -> tuple[Model, StoredConfig]: +) -> tuple[ModelProtocol, StoredConfig]: """Load a model and its configuration from a Lightning checkpoint. Parameters @@ -144,7 +144,7 @@ def load_model_from_checkpoint( Returns ------- - tuple[Model, ModelConfig] + tuple[ModelProtocol, StoredConfig] The restored ``Model`` instance and the ``ModelConfig`` that describes its architecture, preprocessing, and postprocessing.
""" @@ -169,7 +169,7 @@ def build_training_module( class_names: list[str] | None = None, dimension_names: list[str] | None = None, train_config: TrainingConfig | None = None, - model: Model | None = None, + model: ModelProtocol | None = None, ) -> TrainingModule: if model_config is None: model_config = ModelConfig() diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index c56d370..c1632d7 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -15,7 +15,8 @@ from batdetect2.logging import ( TensorBoardLoggerConfig, build_logger, ) -from batdetect2.models import Model, ModelConfig, build_model +from batdetect2.models import ModelConfig, build_model +from batdetect2.models.types import ModelProtocol from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor from batdetect2.targets import ( ROIMapperProtocol, @@ -50,7 +51,7 @@ DEFAULT_LOG_DIR = Path("outputs") / "logs" def run_train( train_annotations: Sequence[data.ClipAnnotation], val_annotations: Sequence[data.ClipAnnotation] | None = None, - model: Model | None = None, + model: ModelProtocol | None = None, targets: Optional["TargetProtocol"] = None, roi_mapper: Optional["ROIMapperProtocol"] = None, preprocessor: Optional["PreprocessorProtocol"] = None, @@ -217,7 +218,7 @@ def run_train( def _validate_model_compatibility( - model: Model, + model: ModelProtocol, model_config: ModelConfig, class_names: list[str], dimension_names: list[str], diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py index 9f121f9..99fc4e3 100644 --- a/tests/test_models/test_backbones.py +++ b/tests/test_models/test_backbones.py @@ -13,7 +13,6 @@ from batdetect2.models.backbones import ( build_backbone, load_backbone_config, ) -from batdetect2.models.types import BackboneModel def test_unet_backbone_config_defaults(): @@ -61,10 +60,11 @@ def test_build_backbone_custom_config(): assert backbone.encoder.in_channels == 2 -def 
test_build_backbone_returns_backbone_model(): - """build_backbone always returns a BackboneModel instance.""" +def test_build_backbone_returns_unet_backbone(): + """build_backbone returns the default UNet backbone.""" backbone = build_backbone() - assert isinstance(backbone, BackboneModel) + + assert isinstance(backbone, UNetBackbone) def test_registry_has_unet_backbone():