mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
refactor: replace abstract model types with protocols
This commit is contained in:
parent
a27d1bbfd3
commit
2008c8000f
@ -11,7 +11,7 @@ from batdetect2.evaluate.dataset import build_test_loader
|
|||||||
from batdetect2.evaluate.evaluator import build_evaluator
|
from batdetect2.evaluate.evaluator import build_evaluator
|
||||||
from batdetect2.evaluate.lightning import EvaluationModule
|
from batdetect2.evaluate.lightning import EvaluationModule
|
||||||
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
|
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.outputs import OutputsConfig, build_output_transform
|
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
@ -22,7 +22,7 @@ DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
|||||||
|
|
||||||
|
|
||||||
def run_evaluate(
|
def run_evaluate(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
test_annotations: Sequence[data.ClipAnnotation],
|
test_annotations: Sequence[data.ClipAnnotation],
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
roi_mapper: ROIMapperProtocol,
|
roi_mapper: ROIMapperProtocol,
|
||||||
|
|||||||
@ -7,14 +7,14 @@ from torch.utils.data import DataLoader
|
|||||||
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
from batdetect2.evaluate.dataset import TestDataset, TestExample
|
||||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||||
from batdetect2.logging import get_image_logger
|
from batdetect2.logging import get_image_logger
|
||||||
from batdetect2.models import Model
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
|
|
||||||
|
|
||||||
class EvaluationModule(LightningModule):
|
class EvaluationModule(LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
evaluator: EvaluatorProtocol,
|
evaluator: EvaluatorProtocol,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from batdetect2.inference.clips import get_clips_from_files
|
|||||||
from batdetect2.inference.config import InferenceConfig
|
from batdetect2.inference.config import InferenceConfig
|
||||||
from batdetect2.inference.dataset import build_inference_loader
|
from batdetect2.inference.dataset import build_inference_loader
|
||||||
from batdetect2.inference.lightning import InferenceModule
|
from batdetect2.inference.lightning import InferenceModule
|
||||||
from batdetect2.models import Model
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.outputs import (
|
from batdetect2.outputs import (
|
||||||
OutputsConfig,
|
OutputsConfig,
|
||||||
OutputTransformProtocol,
|
OutputTransformProtocol,
|
||||||
@ -22,7 +22,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
|||||||
|
|
||||||
|
|
||||||
def run_batch_inference(
|
def run_batch_inference(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
@ -86,7 +86,7 @@ def run_batch_inference(
|
|||||||
|
|
||||||
|
|
||||||
def process_file_list(
|
def process_file_list(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
paths: Sequence[data.PathLike],
|
paths: Sequence[data.PathLike],
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from lightning import LightningModule
|
|||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
|
||||||
from batdetect2.models import Model
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
from batdetect2.outputs import OutputTransformProtocol, build_output_transform
|
||||||
from batdetect2.postprocess.types import ClipDetections
|
from batdetect2.postprocess.types import ClipDetections
|
||||||
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
||||||
@ -13,7 +13,7 @@ from batdetect2.targets.types import ROIMapperProtocol, TargetProtocol
|
|||||||
class InferenceModule(LightningModule):
|
class InferenceModule(LightningModule):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
roi_mapper: ROIMapperProtocol | None = None,
|
roi_mapper: ROIMapperProtocol | None = None,
|
||||||
output_transform: OutputTransformProtocol | None = None,
|
output_transform: OutputTransformProtocol | None = None,
|
||||||
|
|||||||
@ -27,6 +27,7 @@ from typing import Annotated, Literal
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from loguru import logger
|
||||||
from pydantic import Field, TypeAdapter
|
from pydantic import Field, TypeAdapter
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
@ -52,7 +53,7 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.types import (
|
from batdetect2.models.types import (
|
||||||
BackboneModel,
|
BackboneProtocol,
|
||||||
BottleneckProtocol,
|
BottleneckProtocol,
|
||||||
DecoderProtocol,
|
DecoderProtocol,
|
||||||
EncoderProtocol,
|
EncoderProtocol,
|
||||||
@ -104,7 +105,7 @@ class UNetBackboneConfig(BaseConfig):
|
|||||||
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
decoder: DecoderConfig = DEFAULT_DECODER_CONFIG
|
||||||
|
|
||||||
|
|
||||||
backbone_registry: Registry[BackboneModel, []] = Registry("backbone")
|
backbone_registry: Registry[BackboneProtocol, []] = Registry("backbone")
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(backbone_registry)
|
@add_import_config(backbone_registry)
|
||||||
@ -118,7 +119,7 @@ class BackboneImportConfig(ImportConfig):
|
|||||||
name: Literal["import"] = "import"
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
class UNetBackbone(BackboneModel):
|
class UNetBackbone(torch.nn.Module):
|
||||||
"""U-Net-style encoder-decoder backbone network.
|
"""U-Net-style encoder-decoder backbone network.
|
||||||
|
|
||||||
Combines an encoder, a bottleneck, and a decoder into a single module
|
Combines an encoder, a bottleneck, and a decoder into a single module
|
||||||
@ -225,7 +226,7 @@ class UNetBackbone(BackboneModel):
|
|||||||
|
|
||||||
@backbone_registry.register(UNetBackboneConfig)
|
@backbone_registry.register(UNetBackboneConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_config(config: UNetBackboneConfig) -> BackboneModel:
|
def from_config(config: UNetBackboneConfig) -> BackboneProtocol:
|
||||||
encoder = build_encoder(
|
encoder = build_encoder(
|
||||||
in_channels=config.in_channels,
|
in_channels=config.in_channels,
|
||||||
input_height=config.input_height,
|
input_height=config.input_height,
|
||||||
@ -266,7 +267,7 @@ BackboneConfig = Annotated[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
def build_backbone(config: BackboneConfig | None = None) -> BackboneProtocol:
|
||||||
"""Build a backbone network from configuration.
|
"""Build a backbone network from configuration.
|
||||||
|
|
||||||
Looks up the backbone class corresponding to ``config.name`` in the
|
Looks up the backbone class corresponding to ``config.name`` in the
|
||||||
@ -282,10 +283,14 @@ def build_backbone(config: BackboneConfig | None = None) -> BackboneModel:
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
BackboneModel
|
BackboneProtocol
|
||||||
An initialised backbone module.
|
An initialised backbone module.
|
||||||
"""
|
"""
|
||||||
config = config or UNetBackboneConfig()
|
config = config or UNetBackboneConfig()
|
||||||
|
logger.opt(lazy=True).debug(
|
||||||
|
"Building model backbone with config: \n{}",
|
||||||
|
lambda: config.to_yaml_string(),
|
||||||
|
)
|
||||||
return backbone_registry.build(config)
|
return backbone_registry.build(config)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,21 +1,42 @@
|
|||||||
from abc import ABC, abstractmethod
|
from typing import Any, NamedTuple, Protocol
|
||||||
from typing import NamedTuple, Protocol
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||||
|
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BackboneModel",
|
"BackboneProtocol",
|
||||||
"BlockProtocol",
|
"BlockProtocol",
|
||||||
"BottleneckProtocol",
|
"BottleneckProtocol",
|
||||||
|
"ClassifierHeadProtocol",
|
||||||
"DecoderProtocol",
|
"DecoderProtocol",
|
||||||
"DetectionModel",
|
"DetectorProtocol",
|
||||||
"EncoderDecoderModel",
|
|
||||||
"EncoderProtocol",
|
"EncoderProtocol",
|
||||||
"ModelOutput",
|
"ModelOutput",
|
||||||
|
"ModelProtocol",
|
||||||
|
"ModuleProtocol",
|
||||||
|
"SizeHeadProtocol",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class BlockProtocol(Protocol):
|
class ModuleProtocol(Protocol):
|
||||||
|
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
|
|
||||||
|
def train(self, mode: bool = True) -> torch.nn.Module: ...
|
||||||
|
|
||||||
|
def eval(self) -> torch.nn.Module: ...
|
||||||
|
|
||||||
|
def state_dict(
|
||||||
|
self, *args: Any, **kwargs: Any
|
||||||
|
) -> dict[str, torch.Tensor]: ...
|
||||||
|
|
||||||
|
def load_state_dict(self, *args: Any, **kwargs: Any) -> Any: ...
|
||||||
|
|
||||||
|
def parameters(self) -> Any: ...
|
||||||
|
|
||||||
|
|
||||||
|
class BlockProtocol(ModuleProtocol, Protocol):
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
|
|
||||||
@ -24,7 +45,7 @@ class BlockProtocol(Protocol):
|
|||||||
def get_output_height(self, input_height: int) -> int: ...
|
def get_output_height(self, input_height: int) -> int: ...
|
||||||
|
|
||||||
|
|
||||||
class EncoderProtocol(Protocol):
|
class EncoderProtocol(ModuleProtocol, Protocol):
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
input_height: int
|
input_height: int
|
||||||
@ -33,7 +54,7 @@ class EncoderProtocol(Protocol):
|
|||||||
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
|
def __call__(self, x: torch.Tensor) -> list[torch.Tensor]: ...
|
||||||
|
|
||||||
|
|
||||||
class BottleneckProtocol(Protocol):
|
class BottleneckProtocol(ModuleProtocol, Protocol):
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
input_height: int
|
input_height: int
|
||||||
@ -41,7 +62,7 @@ class BottleneckProtocol(Protocol):
|
|||||||
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
def __call__(self, x: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
class DecoderProtocol(Protocol):
|
class DecoderProtocol(ModuleProtocol, Protocol):
|
||||||
in_channels: int
|
in_channels: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
input_height: int
|
input_height: int
|
||||||
@ -62,29 +83,42 @@ class ModelOutput(NamedTuple):
|
|||||||
features: torch.Tensor
|
features: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class BackboneModel(ABC, torch.nn.Module):
|
class BackboneProtocol(ModuleProtocol, Protocol):
|
||||||
input_height: int
|
input_height: int
|
||||||
out_channels: int
|
out_channels: int
|
||||||
|
|
||||||
@abstractmethod
|
def forward(self, spec: torch.Tensor) -> torch.Tensor: ...
|
||||||
def forward(self, spec: torch.Tensor) -> torch.Tensor:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class EncoderDecoderModel(BackboneModel):
|
class ClassifierHeadProtocol(ModuleProtocol, Protocol):
|
||||||
bottleneck_channels: int
|
num_classes: int
|
||||||
|
in_channels: int
|
||||||
|
class_names: list[str]
|
||||||
|
|
||||||
@abstractmethod
|
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
|
||||||
def encode(self, spec: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def decode(self, encoded: torch.Tensor) -> torch.Tensor: ...
|
|
||||||
|
|
||||||
|
|
||||||
class DetectionModel(ABC, torch.nn.Module):
|
class SizeHeadProtocol(ModuleProtocol, Protocol):
|
||||||
backbone: BackboneModel
|
in_channels: int
|
||||||
classifier_head: torch.nn.Module
|
num_sizes: int
|
||||||
bbox_head: torch.nn.Module
|
dimension_names: list[str]
|
||||||
|
|
||||||
|
def forward(self, features: torch.Tensor) -> torch.Tensor: ...
|
||||||
|
|
||||||
|
|
||||||
|
class DetectorProtocol(ModuleProtocol, Protocol):
|
||||||
|
backbone: BackboneProtocol
|
||||||
|
classifier_head: ClassifierHeadProtocol
|
||||||
|
size_head: SizeHeadProtocol
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
def forward(self, spec: torch.Tensor) -> ModelOutput: ...
|
||||||
|
|
||||||
|
|
||||||
|
class ModelProtocol(ModuleProtocol, Protocol):
|
||||||
|
detector: DetectorProtocol
|
||||||
|
preprocessor: PreprocessorProtocol
|
||||||
|
postprocessor: PostprocessorProtocol
|
||||||
|
class_names: list[str]
|
||||||
|
dimension_names: list[str]
|
||||||
|
|
||||||
|
def get_config(self) -> dict[str, Any]: ...
|
||||||
|
|||||||
@ -4,8 +4,8 @@ import lightning as L
|
|||||||
import torch
|
import torch
|
||||||
from soundevent.data import PathLike
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import ModelConfig, build_model
|
||||||
from batdetect2.models.types import ModelOutput
|
from batdetect2.models.types import ModelOutput, ModelProtocol
|
||||||
from batdetect2.targets import TargetConfig
|
from batdetect2.targets import TargetConfig
|
||||||
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
from batdetect2.train.checkpoints import resolve_checkpoint_path
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
@ -21,7 +21,7 @@ __all__ = [
|
|||||||
|
|
||||||
|
|
||||||
class TrainingModule(L.LightningModule):
|
class TrainingModule(L.LightningModule):
|
||||||
model: Model
|
model: ModelProtocol
|
||||||
loss: LossProtocol
|
loss: LossProtocol
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -32,7 +32,7 @@ class TrainingModule(L.LightningModule):
|
|||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
train_config: dict | None = None,
|
train_config: dict | None = None,
|
||||||
loss: LossProtocol | None = None,
|
loss: LossProtocol | None = None,
|
||||||
model: Model | None = None,
|
model: ModelProtocol | None = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -133,7 +133,7 @@ class StoredConfig:
|
|||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike | str | None = None,
|
path: PathLike | str | None = None,
|
||||||
) -> tuple[Model, StoredConfig]:
|
) -> tuple[ModelProtocol, StoredConfig]:
|
||||||
"""Load a model and its configuration from a Lightning checkpoint.
|
"""Load a model and its configuration from a Lightning checkpoint.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -144,7 +144,7 @@ def load_model_from_checkpoint(
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
tuple[Model, ModelConfig]
|
tuple[ModelProtocol, ModelConfig]
|
||||||
The restored ``Model`` instance and the ``ModelConfig`` that
|
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||||
describes its architecture, preprocessing, and postprocessing.
|
describes its architecture, preprocessing, and postprocessing.
|
||||||
"""
|
"""
|
||||||
@ -169,7 +169,7 @@ def build_training_module(
|
|||||||
class_names: list[str] | None = None,
|
class_names: list[str] | None = None,
|
||||||
dimension_names: list[str] | None = None,
|
dimension_names: list[str] | None = None,
|
||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
model: Model | None = None,
|
model: ModelProtocol | None = None,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
if model_config is None:
|
if model_config is None:
|
||||||
model_config = ModelConfig()
|
model_config = ModelConfig()
|
||||||
|
|||||||
@ -15,7 +15,8 @@ from batdetect2.logging import (
|
|||||||
TensorBoardLoggerConfig,
|
TensorBoardLoggerConfig,
|
||||||
build_logger,
|
build_logger,
|
||||||
)
|
)
|
||||||
from batdetect2.models import Model, ModelConfig, build_model
|
from batdetect2.models import ModelConfig, build_model
|
||||||
|
from batdetect2.models.types import ModelProtocol
|
||||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
from batdetect2.targets import (
|
from batdetect2.targets import (
|
||||||
ROIMapperProtocol,
|
ROIMapperProtocol,
|
||||||
@ -50,7 +51,7 @@ DEFAULT_LOG_DIR = Path("outputs") / "logs"
|
|||||||
def run_train(
|
def run_train(
|
||||||
train_annotations: Sequence[data.ClipAnnotation],
|
train_annotations: Sequence[data.ClipAnnotation],
|
||||||
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
val_annotations: Sequence[data.ClipAnnotation] | None = None,
|
||||||
model: Model | None = None,
|
model: ModelProtocol | None = None,
|
||||||
targets: Optional["TargetProtocol"] = None,
|
targets: Optional["TargetProtocol"] = None,
|
||||||
roi_mapper: Optional["ROIMapperProtocol"] = None,
|
roi_mapper: Optional["ROIMapperProtocol"] = None,
|
||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
@ -217,7 +218,7 @@ def run_train(
|
|||||||
|
|
||||||
|
|
||||||
def _validate_model_compatibility(
|
def _validate_model_compatibility(
|
||||||
model: Model,
|
model: ModelProtocol,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
class_names: list[str],
|
class_names: list[str],
|
||||||
dimension_names: list[str],
|
dimension_names: list[str],
|
||||||
|
|||||||
@ -13,7 +13,6 @@ from batdetect2.models.backbones import (
|
|||||||
build_backbone,
|
build_backbone,
|
||||||
load_backbone_config,
|
load_backbone_config,
|
||||||
)
|
)
|
||||||
from batdetect2.models.types import BackboneModel
|
|
||||||
|
|
||||||
|
|
||||||
def test_unet_backbone_config_defaults():
|
def test_unet_backbone_config_defaults():
|
||||||
@ -61,10 +60,11 @@ def test_build_backbone_custom_config():
|
|||||||
assert backbone.encoder.in_channels == 2
|
assert backbone.encoder.in_channels == 2
|
||||||
|
|
||||||
|
|
||||||
def test_build_backbone_returns_backbone_model():
|
def test_build_backbone_returns_unet_backbone():
|
||||||
"""build_backbone always returns a BackboneModel instance."""
|
"""build_backbone returns the default UNet backbone."""
|
||||||
backbone = build_backbone()
|
backbone = build_backbone()
|
||||||
assert isinstance(backbone, BackboneModel)
|
|
||||||
|
assert isinstance(backbone, UNetBackbone)
|
||||||
|
|
||||||
|
|
||||||
def test_registry_has_unet_backbone():
|
def test_registry_has_unet_backbone():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user