mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Restructure model config
This commit is contained in:
parent
1a7c0b4b3a
commit
65bd0dc6ae
@ -282,18 +282,18 @@ class BatDetect2API:
|
|||||||
cls,
|
cls,
|
||||||
config: BatDetect2Config,
|
config: BatDetect2Config,
|
||||||
):
|
):
|
||||||
targets = build_targets(config=config.targets)
|
targets = build_targets(config=config.model.targets)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(
|
preprocessor = build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=config.preprocess,
|
config=config.model.preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
preprocessor,
|
preprocessor,
|
||||||
config=config.postprocess,
|
config=config.model.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||||
@ -301,18 +301,7 @@ class BatDetect2API:
|
|||||||
# NOTE: Better to have a separate instance of
|
# NOTE: Better to have a separate instance of
|
||||||
# preprocessor and postprocessor as these may be moved
|
# preprocessor and postprocessor as these may be moved
|
||||||
# to another device.
|
# to another device.
|
||||||
model = build_model(
|
model = build_model(config=config.model)
|
||||||
config=config.model,
|
|
||||||
targets=targets,
|
|
||||||
preprocessor=build_preprocessor(
|
|
||||||
input_samplerate=audio_loader.samplerate,
|
|
||||||
config=config.preprocess,
|
|
||||||
),
|
|
||||||
postprocessor=build_postprocessor(
|
|
||||||
preprocessor,
|
|
||||||
config=config.postprocess,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
formatter = build_output_formatter(targets, config=config.output)
|
formatter = build_output_formatter(targets, config=config.output)
|
||||||
|
|
||||||
@ -333,24 +322,30 @@ class BatDetect2API:
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
config: BatDetect2Config | None = None,
|
config: BatDetect2Config | None = None,
|
||||||
):
|
):
|
||||||
model, stored_config = load_model_from_checkpoint(path)
|
from batdetect2.audio import AudioConfig
|
||||||
|
|
||||||
config = (
|
model, model_config = load_model_from_checkpoint(path)
|
||||||
merge_configs(stored_config, config) if config else stored_config
|
|
||||||
|
# Reconstruct a full BatDetect2Config from the checkpoint's
|
||||||
|
# ModelConfig, then overlay any caller-supplied overrides.
|
||||||
|
base = BatDetect2Config(
|
||||||
|
model=model_config,
|
||||||
|
audio=AudioConfig(samplerate=model_config.samplerate),
|
||||||
)
|
)
|
||||||
|
config = merge_configs(base, config) if config else base
|
||||||
|
|
||||||
targets = build_targets(config=config.targets)
|
targets = build_targets(config=config.model.targets)
|
||||||
|
|
||||||
audio_loader = build_audio_loader(config=config.audio)
|
audio_loader = build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(
|
preprocessor = build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=config.preprocess,
|
config=config.model.preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
postprocessor = build_postprocessor(
|
||||||
preprocessor,
|
preprocessor,
|
||||||
config=config.postprocess,
|
config=config.model.postprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
evaluator = build_evaluator(config=config.evaluation, targets=targets)
|
||||||
|
|||||||
@ -12,10 +12,7 @@ from batdetect2.evaluate.config import (
|
|||||||
get_default_eval_config,
|
get_default_eval_config,
|
||||||
)
|
)
|
||||||
from batdetect2.inference.config import InferenceConfig
|
from batdetect2.inference.config import InferenceConfig
|
||||||
from batdetect2.models.backbones import BackboneConfig, UNetBackboneConfig
|
from batdetect2.models import ModelConfig
|
||||||
from batdetect2.postprocess.config import PostprocessConfig
|
|
||||||
from batdetect2.preprocess.config import PreprocessingConfig
|
|
||||||
from batdetect2.targets.config import TargetConfig
|
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -32,13 +29,8 @@ class BatDetect2Config(BaseConfig):
|
|||||||
evaluation: EvaluationConfig = Field(
|
evaluation: EvaluationConfig = Field(
|
||||||
default_factory=get_default_eval_config
|
default_factory=get_default_eval_config
|
||||||
)
|
)
|
||||||
model: BackboneConfig = Field(default_factory=UNetBackboneConfig)
|
model: ModelConfig = Field(default_factory=ModelConfig)
|
||||||
preprocess: PreprocessingConfig = Field(
|
|
||||||
default_factory=PreprocessingConfig
|
|
||||||
)
|
|
||||||
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
|
||||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
|
||||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||||
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
output: OutputFormatConfig = Field(default_factory=RawOutputConfig)
|
||||||
|
|
||||||
|
|||||||
@ -45,12 +45,8 @@ def evaluate(
|
|||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
preprocessor = preprocessor or build_preprocessor(
|
preprocessor = preprocessor or model.preprocessor
|
||||||
config=config.preprocess,
|
targets = targets or model.targets
|
||||||
input_samplerate=audio_loader.samplerate,
|
|
||||||
)
|
|
||||||
|
|
||||||
targets = targets or build_targets(config=config.targets)
|
|
||||||
|
|
||||||
loader = build_test_loader(
|
loader = build_test_loader(
|
||||||
test_annotations,
|
test_annotations,
|
||||||
|
|||||||
@ -27,7 +27,10 @@ is the ``build_model`` factory function exported from this module.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from batdetect2.audio.loader import TARGET_SAMPLERATE_HZ
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
from batdetect2.models.backbones import (
|
from batdetect2.models.backbones import (
|
||||||
BackboneConfig,
|
BackboneConfig,
|
||||||
UNetBackbone,
|
UNetBackbone,
|
||||||
@ -59,6 +62,9 @@ from batdetect2.models.encoder import (
|
|||||||
build_encoder,
|
build_encoder,
|
||||||
)
|
)
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead, DetectorHead
|
||||||
|
from batdetect2.postprocess.config import PostprocessConfig
|
||||||
|
from batdetect2.preprocess.config import PreprocessingConfig
|
||||||
|
from batdetect2.targets.config import TargetConfig
|
||||||
from batdetect2.typing import (
|
from batdetect2.typing import (
|
||||||
ClipDetectionsTensor,
|
ClipDetectionsTensor,
|
||||||
DetectionModel,
|
DetectionModel,
|
||||||
@ -92,10 +98,50 @@ __all__ = [
|
|||||||
"build_detector",
|
"build_detector",
|
||||||
"load_backbone_config",
|
"load_backbone_config",
|
||||||
"Model",
|
"Model",
|
||||||
|
"ModelConfig",
|
||||||
"build_model",
|
"build_model",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(BaseConfig):
|
||||||
|
"""Complete configuration describing a BatDetect2 model.
|
||||||
|
|
||||||
|
Bundles every parameter that defines a model's behaviour: the input
|
||||||
|
sample rate, backbone architecture, preprocessing pipeline,
|
||||||
|
postprocessing pipeline, and detection targets.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
samplerate : int
|
||||||
|
Expected input audio sample rate in Hz. Audio must be resampled
|
||||||
|
to this rate before being passed to the model. Defaults to
|
||||||
|
``TARGET_SAMPLERATE_HZ`` (256 000 Hz).
|
||||||
|
architecture : BackboneConfig
|
||||||
|
Configuration for the encoder-decoder backbone network. Defaults
|
||||||
|
to ``UNetBackboneConfig()``.
|
||||||
|
preprocess : PreprocessingConfig
|
||||||
|
Parameters for the audio-to-spectrogram preprocessing pipeline
|
||||||
|
(STFT, frequency crop, transforms, resize). Defaults to
|
||||||
|
``PreprocessingConfig()``.
|
||||||
|
postprocess : PostprocessConfig
|
||||||
|
Parameters for converting raw model outputs into detections (NMS
|
||||||
|
kernel, thresholds, top-k limit). Defaults to
|
||||||
|
``PostprocessConfig()``.
|
||||||
|
targets : TargetConfig
|
||||||
|
Detection and classification target definitions (class list,
|
||||||
|
detection target, bounding-box mapper). Defaults to
|
||||||
|
``TargetConfig()``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||||
|
architecture: BackboneConfig = Field(default_factory=UNetBackboneConfig)
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
|
||||||
|
|
||||||
class Model(torch.nn.Module):
|
class Model(torch.nn.Module):
|
||||||
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
|
"""End-to-end BatDetect2 model wrapping preprocessing and postprocessing.
|
||||||
|
|
||||||
@ -166,55 +212,61 @@ class Model(torch.nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
def build_model(
|
def build_model(
|
||||||
config: BackboneConfig | None = None,
|
config: ModelConfig | None = None,
|
||||||
targets: TargetProtocol | None = None,
|
targets: TargetProtocol | None = None,
|
||||||
preprocessor: PreprocessorProtocol | None = None,
|
preprocessor: PreprocessorProtocol | None = None,
|
||||||
postprocessor: PostprocessorProtocol | None = None,
|
postprocessor: PostprocessorProtocol | None = None,
|
||||||
) -> "Model":
|
) -> Model:
|
||||||
"""Build a complete, ready-to-use BatDetect2 model.
|
"""Build a complete, ready-to-use BatDetect2 model.
|
||||||
|
|
||||||
Assembles a ``Model`` instance from optional configuration and component
|
Assembles a ``Model`` instance from a ``ModelConfig`` and optional
|
||||||
overrides. Any argument left as ``None`` will be replaced by a sensible
|
component overrides. Any component argument left as ``None`` is built
|
||||||
default built with the project's own builder functions.
|
from the configuration. Passing a pre-built component overrides the
|
||||||
|
corresponding config fields for that component only.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
config : BackboneConfig, optional
|
config : ModelConfig, optional
|
||||||
Configuration describing the backbone architecture (encoder,
|
Full model configuration (samplerate, architecture, preprocessing,
|
||||||
bottleneck, decoder). Defaults to ``UNetBackboneConfig()`` if not
|
postprocessing, targets). Defaults to ``ModelConfig()`` if not
|
||||||
provided.
|
provided.
|
||||||
targets : TargetProtocol, optional
|
targets : TargetProtocol, optional
|
||||||
Describes the target bat species or call types to detect. Determines
|
Pre-built targets object. If given, overrides
|
||||||
the number of output classes. Defaults to the standard BatDetect2
|
``config.targets``.
|
||||||
target set.
|
|
||||||
preprocessor : PreprocessorProtocol, optional
|
preprocessor : PreprocessorProtocol, optional
|
||||||
Converts raw audio waveforms to spectrograms. Defaults to the
|
Pre-built preprocessor. If given, overrides
|
||||||
standard BatDetect2 preprocessor.
|
``config.preprocess`` and ``config.samplerate`` for the
|
||||||
|
preprocessing step.
|
||||||
postprocessor : PostprocessorProtocol, optional
|
postprocessor : PostprocessorProtocol, optional
|
||||||
Converts raw model outputs to detection tensors. Defaults to the
|
Pre-built postprocessor. If given, overrides
|
||||||
standard BatDetect2 postprocessor. If a custom ``preprocessor`` is
|
``config.postprocess``. When omitted and a custom
|
||||||
given without a matching ``postprocessor``, the default postprocessor
|
``preprocessor`` is supplied, the default postprocessor is built
|
||||||
will be built using the provided preprocessor so that frequency and
|
using that preprocessor so that frequency and time scaling remain
|
||||||
time scaling remain consistent.
|
consistent.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Model
|
Model
|
||||||
A fully assembled ``Model`` instance ready for inference or training.
|
A fully assembled ``Model`` instance ready for inference or
|
||||||
|
training.
|
||||||
"""
|
"""
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.postprocess import build_postprocessor
|
||||||
from batdetect2.preprocess import build_preprocessor
|
from batdetect2.preprocess import build_preprocessor
|
||||||
from batdetect2.targets import build_targets
|
from batdetect2.targets import build_targets
|
||||||
|
|
||||||
config = config or UNetBackboneConfig()
|
config = config or ModelConfig()
|
||||||
targets = targets or build_targets()
|
targets = targets or build_targets(config=config.targets)
|
||||||
preprocessor = preprocessor or build_preprocessor()
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
|
config=config.preprocess,
|
||||||
|
input_samplerate=config.samplerate,
|
||||||
|
)
|
||||||
postprocessor = postprocessor or build_postprocessor(
|
postprocessor = postprocessor or build_postprocessor(
|
||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
|
config=config.postprocess,
|
||||||
)
|
)
|
||||||
detector = build_detector(
|
detector = build_detector(
|
||||||
num_classes=len(targets.class_names),
|
num_classes=len(targets.class_names),
|
||||||
config=config,
|
config=config.architecture,
|
||||||
)
|
)
|
||||||
return Model(
|
return Model(
|
||||||
detector=detector,
|
detector=detector,
|
||||||
|
|||||||
@ -921,20 +921,6 @@ class StandardConvUpBlock(Block):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
LayerConfig = Annotated[
|
|
||||||
ConvConfig
|
|
||||||
| BlockImportConfig
|
|
||||||
| FreqCoordConvDownConfig
|
|
||||||
| StandardConvDownConfig
|
|
||||||
| FreqCoordConvUpConfig
|
|
||||||
| StandardConvUpConfig
|
|
||||||
| SelfAttentionConfig
|
|
||||||
| "LayerGroupConfig",
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
"""Type alias for the discriminated union of block configuration models."""
|
|
||||||
|
|
||||||
|
|
||||||
class LayerGroupConfig(BaseConfig):
|
class LayerGroupConfig(BaseConfig):
|
||||||
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
|
"""Configuration for a ``LayerGroup`` — a sequential chain of blocks.
|
||||||
|
|
||||||
@ -951,7 +937,20 @@ class LayerGroupConfig(BaseConfig):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: Literal["LayerGroup"] = "LayerGroup"
|
name: Literal["LayerGroup"] = "LayerGroup"
|
||||||
layers: list[LayerConfig]
|
layers: list["LayerConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
LayerConfig = Annotated[
|
||||||
|
ConvConfig
|
||||||
|
| FreqCoordConvDownConfig
|
||||||
|
| StandardConvDownConfig
|
||||||
|
| FreqCoordConvUpConfig
|
||||||
|
| StandardConvUpConfig
|
||||||
|
| SelfAttentionConfig
|
||||||
|
| LayerGroupConfig,
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
"""Type alias for the discriminated union of block configuration models."""
|
||||||
|
|
||||||
|
|
||||||
class LayerGroup(nn.Module):
|
class LayerGroup(nn.Module):
|
||||||
|
|||||||
@ -169,7 +169,7 @@ def build_detector(
|
|||||||
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building model with config: \n{}",
|
"Building model with config: \n{}",
|
||||||
lambda: config.to_yaml_string(),
|
lambda: config.to_yaml_string(), # type: ignore
|
||||||
)
|
)
|
||||||
backbone = build_backbone(config=config)
|
backbone = build_backbone(config=config)
|
||||||
classifier_head = ClassifierHead(
|
classifier_head = ClassifierHead(
|
||||||
|
|||||||
@ -6,15 +6,12 @@ from soundevent.data import PathLike
|
|||||||
from torch.optim.adam import Adam
|
from torch.optim.adam import Adam
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
|
|
||||||
from batdetect2.models import Model, build_model
|
from batdetect2.models import Model, ModelConfig, build_model
|
||||||
from batdetect2.postprocess import build_postprocessor
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
from batdetect2.typing import ModelOutput, TrainExample
|
from batdetect2.typing import ModelOutput, TrainExample
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from batdetect2.config import BatDetect2Config
|
pass
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"TrainingModule",
|
"TrainingModule",
|
||||||
@ -26,43 +23,25 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
config: dict | None = None,
|
model_config: dict | None = None,
|
||||||
t_max: int = 100,
|
t_max: int = 100,
|
||||||
model: Model | None = None,
|
learning_rate: float = 1e-3,
|
||||||
loss: torch.nn.Module | None = None,
|
loss: torch.nn.Module | None = None,
|
||||||
|
model: Model | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.config import validate_config
|
|
||||||
|
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.save_hyperparameters(logger=False)
|
self.save_hyperparameters(ignore=["model", "loss"], logger=False)
|
||||||
|
|
||||||
self.config = validate_config(config)
|
self.model_config = ModelConfig.model_validate(model_config or {})
|
||||||
self.input_samplerate = self.config.audio.samplerate
|
self.learning_rate = learning_rate
|
||||||
self.learning_rate = self.config.train.optimizer.learning_rate
|
|
||||||
self.t_max = t_max
|
self.t_max = t_max
|
||||||
|
|
||||||
if loss is None:
|
if loss is None:
|
||||||
loss = build_loss(self.config.train.loss)
|
loss = build_loss()
|
||||||
|
|
||||||
if model is None:
|
if model is None:
|
||||||
targets = build_targets(self.config.targets)
|
model = build_model(config=self.model_config)
|
||||||
|
|
||||||
preprocessor = build_preprocessor(
|
|
||||||
config=self.config.preprocess,
|
|
||||||
input_samplerate=self.input_samplerate,
|
|
||||||
)
|
|
||||||
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
preprocessor, config=self.config.postprocess
|
|
||||||
)
|
|
||||||
|
|
||||||
model = build_model(
|
|
||||||
config=self.config.model,
|
|
||||||
targets=targets,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.loss = loss
|
self.loss = loss
|
||||||
self.model = model
|
self.model = model
|
||||||
@ -97,13 +76,39 @@ class TrainingModule(L.LightningModule):
|
|||||||
|
|
||||||
def load_model_from_checkpoint(
|
def load_model_from_checkpoint(
|
||||||
path: PathLike,
|
path: PathLike,
|
||||||
) -> tuple[Model, "BatDetect2Config"]:
|
) -> tuple[Model, ModelConfig]:
|
||||||
|
"""Load a model and its configuration from a Lightning checkpoint.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path : PathLike
|
||||||
|
Path to a ``.ckpt`` file produced by the BatDetect2 training
|
||||||
|
pipeline.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
tuple[Model, ModelConfig]
|
||||||
|
The restored ``Model`` instance and the ``ModelConfig`` that
|
||||||
|
describes its architecture, preprocessing, postprocessing, and
|
||||||
|
targets.
|
||||||
|
"""
|
||||||
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(path) # type: ignore
|
||||||
return module.model, module.config
|
return module.model, module.model_config
|
||||||
|
|
||||||
|
|
||||||
def build_training_module(
|
def build_training_module(
|
||||||
config: dict | None = None,
|
model_config: dict | None = None,
|
||||||
t_max: int = 200,
|
t_max: int = 200,
|
||||||
|
learning_rate: float = 1e-3,
|
||||||
|
loss_config: dict | None = None,
|
||||||
) -> TrainingModule:
|
) -> TrainingModule:
|
||||||
return TrainingModule(config=config, t_max=t_max)
|
from batdetect2.train.config import LossConfig
|
||||||
|
from batdetect2.train.losses import build_loss
|
||||||
|
|
||||||
|
loss = build_loss(LossConfig.model_validate(loss_config or {}))
|
||||||
|
return TrainingModule(
|
||||||
|
model_config=model_config,
|
||||||
|
t_max=t_max,
|
||||||
|
learning_rate=learning_rate,
|
||||||
|
loss=loss,
|
||||||
|
)
|
||||||
|
|||||||
@ -58,13 +58,13 @@ def train(
|
|||||||
|
|
||||||
config = config or BatDetect2Config()
|
config = config or BatDetect2Config()
|
||||||
|
|
||||||
targets = targets or build_targets(config=config.targets)
|
targets = targets or build_targets(config=config.model.targets)
|
||||||
|
|
||||||
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
audio_loader = audio_loader or build_audio_loader(config=config.audio)
|
||||||
|
|
||||||
preprocessor = preprocessor or build_preprocessor(
|
preprocessor = preprocessor or build_preprocessor(
|
||||||
input_samplerate=audio_loader.samplerate,
|
input_samplerate=audio_loader.samplerate,
|
||||||
config=config.preprocess,
|
config=config.model.preprocess,
|
||||||
)
|
)
|
||||||
|
|
||||||
labeller = labeller or build_clip_labeler(
|
labeller = labeller or build_clip_labeler(
|
||||||
@ -97,8 +97,10 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
module = build_training_module(
|
module = build_training_module(
|
||||||
config.model_dump(mode="json"),
|
model_config=config.model.model_dump(mode="json"),
|
||||||
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
t_max=config.train.optimizer.t_max * len(train_dataloader),
|
||||||
|
learning_rate=config.train.optimizer.learning_rate,
|
||||||
|
loss_config=config.train.loss.model_dump(mode="json"),
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer = trainer or build_trainer(
|
trainer = trainer or build_trainer(
|
||||||
|
|||||||
@ -317,9 +317,11 @@ def convert_results(
|
|||||||
]
|
]
|
||||||
|
|
||||||
# combine into final results dictionary
|
# combine into final results dictionary
|
||||||
results: RunResults = RunResults({ # type: ignore
|
results: RunResults = RunResults( # type: ignore[missing-argument]
|
||||||
|
{
|
||||||
"pred_dict": pred_dict,
|
"pred_dict": pred_dict,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
|
|
||||||
# add spectrogram features if they exist
|
# add spectrogram features if they exist
|
||||||
if len(spec_feats) > 0 and params["spec_features"]:
|
if len(spec_feats) > 0 and params["spec_features"]:
|
||||||
|
|||||||
@ -2,8 +2,10 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from batdetect2.models import UNetBackbone
|
||||||
from batdetect2.models.backbones import UNetBackboneConfig
|
from batdetect2.models.backbones import UNetBackboneConfig
|
||||||
from batdetect2.models.detectors import Detector, build_detector
|
from batdetect2.models.detectors import Detector, build_detector
|
||||||
|
from batdetect2.models.encoder import Encoder
|
||||||
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
from batdetect2.models.heads import BBoxHead, ClassifierHead
|
||||||
from batdetect2.typing.models import ModelOutput
|
from batdetect2.typing.models import ModelOutput
|
||||||
|
|
||||||
@ -34,6 +36,8 @@ def test_build_detector_custom_config():
|
|||||||
|
|
||||||
assert isinstance(model, Detector)
|
assert isinstance(model, Detector)
|
||||||
assert model.backbone.input_height == 128
|
assert model.backbone.input_height == 128
|
||||||
|
|
||||||
|
assert isinstance(model.backbone.encoder, Encoder)
|
||||||
assert model.backbone.encoder.in_channels == 2
|
assert model.backbone.encoder.in_channels == 2
|
||||||
|
|
||||||
|
|
||||||
@ -80,6 +84,7 @@ def test_detector_forward_pass_shapes(dummy_spectrogram):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Check features shape: (B, out_channels, H, W)
|
# Check features shape: (B, out_channels, H, W)
|
||||||
|
assert isinstance(model.backbone, UNetBackbone)
|
||||||
out_channels = model.backbone.out_channels
|
out_channels = model.backbone.out_channels
|
||||||
assert output.features.shape == (
|
assert output.features.shape == (
|
||||||
batch_size,
|
batch_size,
|
||||||
|
|||||||
@ -12,7 +12,9 @@ from batdetect2.typing.preprocess import AudioLoader
|
|||||||
|
|
||||||
def build_default_module():
|
def build_default_module():
|
||||||
config = BatDetect2Config()
|
config = BatDetect2Config()
|
||||||
return build_training_module(config=config.model_dump())
|
return build_training_module(
|
||||||
|
model_config=config.model.model_dump(mode="json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_can_initialize_default_module():
|
def test_can_initialize_default_module():
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user