Moved lightning module to root module

mbsantiago 2025-04-03 16:48:31 +01:00
parent d9f7304a0f
commit 22cf47ed39

@@ -1,7 +1,7 @@
 from pathlib import Path
 from typing import Optional
 
-import pytorch_lightning as L
+import lightning as L
 import torch
 from pydantic import Field
 from soundevent import data
@@ -10,11 +10,12 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
 from torch.utils.data import DataLoader
 
 from batdetect2.configs import BaseConfig
+from batdetect2.evaluate.evaluate import match_predictions_and_annotations
 from batdetect2.models import (
     BBoxHead,
     ClassifierHead,
     ModelConfig,
-    get_backbone,
+    build_architecture,
 )
 from batdetect2.models.typing import ModelOutput
 from batdetect2.post_process import (
@@ -22,9 +23,9 @@ from batdetect2.post_process import (
     postprocess_model_outputs,
 )
 from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
+from batdetect2.train.config import TrainingConfig
 from batdetect2.train.dataset import LabeledDataset, TrainExample
-from batdetect2.train.evaluate import match_predictions_and_annotations
-from batdetect2.train.losses import LossConfig, compute_loss
+from batdetect2.train.losses import compute_loss
 from batdetect2.train.targets import (
     TargetConfig,
     build_decoder,
@@ -32,22 +33,15 @@ from batdetect2.train.targets import (
     get_class_names,
 )
 
-
-class OptimizerConfig(BaseConfig):
-    learning_rate: float = 1e-3
-    t_max: int = 100
-
-
-class TrainingConfig(BaseConfig):
-    loss: LossConfig = Field(default_factory=LossConfig)
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
+__all__ = [
+    "DetectorModel",
+]
 
 
 class ModuleConfig(BaseConfig):
     train: TrainingConfig = Field(default_factory=TrainingConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
-    optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
-    backbone: ModelConfig = Field(default_factory=ModelConfig)
+    architecture: ModelConfig = Field(default_factory=ModelConfig)
     preprocessing: PreprocessingConfig = Field(
         default_factory=PreprocessingConfig
     )
@@ -64,12 +58,11 @@ class DetectorModel(L.LightningModule):
         config: Optional[ModuleConfig] = None,
     ):
         super().__init__()
 
         self.config = config or ModuleConfig()
         self.save_hyperparameters()
         size = self.config.preprocessing.spectrogram.size
         assert size is not None
-
-        self.backbone = get_backbone(self.config.backbone)
+        self.backbone = build_architecture(self.config.architecture)
         self.classifier = ClassifierHead(
             num_classes=len(self.config.targets.classes),
@@ -83,8 +76,7 @@ class DetectorModel(L.LightningModule):
             torch.tensor(conf.class_weights) if conf.class_weights else None
         )
 
-        self.validation_predictions = []
         # Training targets
         self.class_names = get_class_names(self.config.targets.classes)
         self.encoder = build_encoder(
             self.config.targets.classes,
@@ -92,6 +84,10 @@
         )
         self.decoder = build_decoder(self.config.targets.classes)
+
+        self.validation_predictions = []
+
+        self.example_input_array = torch.randn([1, 1, 128, 512])
 
     def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
         features = self.backbone(spec)
         detection_probs, classification_probs = self.classifier(features)
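
For reference, a minimal usage sketch of the renamed pieces (not part of this commit). It assumes the moved module is importable from a path such as batdetect2.lightning, which this diff does not show, and otherwise only uses names visible above (ModuleConfig, DetectorModel, the lightning import, and the registered example input shape).

import lightning as L
import torch

# Hypothetical import path: the commit only says the lightning module moved
# towards the root module, not where it now lives.
from batdetect2.lightning import DetectorModel, ModuleConfig

# ModuleConfig bundles train, targets, architecture, and preprocessing
# sub-configs, all with default factories, so defaults are enough here.
model = DetectorModel(config=ModuleConfig())

# Dummy input matching example_input_array: (batch, channels, freq bins, time bins).
spec = torch.randn([1, 1, 128, 512])
output = model(spec)  # ModelOutput from the detection/classification heads

# Training would go through a standard Lightning Trainer:
trainer = L.Trainer(max_epochs=1, accelerator="cpu")
# trainer.fit(model, train_dataloaders=...)  # dataloader wiring not shown in this diff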