Moved lightning module to root module

This commit is contained in:
mbsantiago 2025-04-03 16:48:31 +01:00
parent d9f7304a0f
commit 22cf47ed39

View File

@ -1,7 +1,7 @@
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
import pytorch_lightning as L import lightning as L
import torch import torch
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
@ -10,11 +10,12 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from batdetect2.configs import BaseConfig from batdetect2.configs import BaseConfig
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
from batdetect2.models import ( from batdetect2.models import (
BBoxHead, BBoxHead,
ClassifierHead, ClassifierHead,
ModelConfig, ModelConfig,
get_backbone, build_architecture,
) )
from batdetect2.models.typing import ModelOutput from batdetect2.models.typing import ModelOutput
from batdetect2.post_process import ( from batdetect2.post_process import (
@ -22,9 +23,9 @@ from batdetect2.post_process import (
postprocess_model_outputs, postprocess_model_outputs,
) )
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
from batdetect2.train.config import TrainingConfig
from batdetect2.train.dataset import LabeledDataset, TrainExample from batdetect2.train.dataset import LabeledDataset, TrainExample
from batdetect2.train.evaluate import match_predictions_and_annotations from batdetect2.train.losses import compute_loss
from batdetect2.train.losses import LossConfig, compute_loss
from batdetect2.train.targets import ( from batdetect2.train.targets import (
TargetConfig, TargetConfig,
build_decoder, build_decoder,
@ -32,22 +33,15 @@ from batdetect2.train.targets import (
get_class_names, get_class_names,
) )
__all__ = [
class OptimizerConfig(BaseConfig): "DetectorModel",
learning_rate: float = 1e-3 ]
t_max: int = 100
class TrainingConfig(BaseConfig):
loss: LossConfig = Field(default_factory=LossConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
class ModuleConfig(BaseConfig): class ModuleConfig(BaseConfig):
train: TrainingConfig = Field(default_factory=TrainingConfig) train: TrainingConfig = Field(default_factory=TrainingConfig)
targets: TargetConfig = Field(default_factory=TargetConfig) targets: TargetConfig = Field(default_factory=TargetConfig)
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) architecture: ModelConfig = Field(default_factory=ModelConfig)
backbone: ModelConfig = Field(default_factory=ModelConfig)
preprocessing: PreprocessingConfig = Field( preprocessing: PreprocessingConfig = Field(
default_factory=PreprocessingConfig default_factory=PreprocessingConfig
) )
@ -64,12 +58,11 @@ class DetectorModel(L.LightningModule):
config: Optional[ModuleConfig] = None, config: Optional[ModuleConfig] = None,
): ):
super().__init__() super().__init__()
self.config = config or ModuleConfig() self.config = config or ModuleConfig()
self.save_hyperparameters() self.save_hyperparameters()
size = self.config.preprocessing.spectrogram.size self.backbone = build_architecture(self.config.architecture)
assert size is not None
self.backbone = get_backbone(self.config.backbone)
self.classifier = ClassifierHead( self.classifier = ClassifierHead(
num_classes=len(self.config.targets.classes), num_classes=len(self.config.targets.classes),
@ -83,8 +76,7 @@ class DetectorModel(L.LightningModule):
torch.tensor(conf.class_weights) if conf.class_weights else None torch.tensor(conf.class_weights) if conf.class_weights else None
) )
self.validation_predictions = [] # Training targets
self.class_names = get_class_names(self.config.targets.classes) self.class_names = get_class_names(self.config.targets.classes)
self.encoder = build_encoder( self.encoder = build_encoder(
self.config.targets.classes, self.config.targets.classes,
@ -92,6 +84,10 @@ class DetectorModel(L.LightningModule):
) )
self.decoder = build_decoder(self.config.targets.classes) self.decoder = build_decoder(self.config.targets.classes)
self.validation_predictions = []
self.example_input_array = torch.randn([1, 1, 128, 512])
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.backbone(spec) features = self.backbone(spec)
detection_probs, classification_probs = self.classifier(features) detection_probs, classification_probs = self.classifier(features)