mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Moved lightning module to root module
This commit is contained in:
parent
d9f7304a0f
commit
22cf47ed39
@ -1,7 +1,7 @@
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytorch_lightning as L
|
||||
import lightning as L
|
||||
import torch
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
@ -10,11 +10,12 @@ from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.configs import BaseConfig
|
||||
from batdetect2.evaluate.evaluate import match_predictions_and_annotations
|
||||
from batdetect2.models import (
|
||||
BBoxHead,
|
||||
ClassifierHead,
|
||||
ModelConfig,
|
||||
get_backbone,
|
||||
build_architecture,
|
||||
)
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
from batdetect2.post_process import (
|
||||
@ -22,9 +23,9 @@ from batdetect2.post_process import (
|
||||
postprocess_model_outputs,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import LabeledDataset, TrainExample
|
||||
from batdetect2.train.evaluate import match_predictions_and_annotations
|
||||
from batdetect2.train.losses import LossConfig, compute_loss
|
||||
from batdetect2.train.losses import compute_loss
|
||||
from batdetect2.train.targets import (
|
||||
TargetConfig,
|
||||
build_decoder,
|
||||
@ -32,22 +33,15 @@ from batdetect2.train.targets import (
|
||||
get_class_names,
|
||||
)
|
||||
|
||||
|
||||
class OptimizerConfig(BaseConfig):
|
||||
learning_rate: float = 1e-3
|
||||
t_max: int = 100
|
||||
|
||||
|
||||
class TrainingConfig(BaseConfig):
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||
__all__ = [
|
||||
"DetectorModel",
|
||||
]
|
||||
|
||||
|
||||
class ModuleConfig(BaseConfig):
|
||||
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||
backbone: ModelConfig = Field(default_factory=ModelConfig)
|
||||
architecture: ModelConfig = Field(default_factory=ModelConfig)
|
||||
preprocessing: PreprocessingConfig = Field(
|
||||
default_factory=PreprocessingConfig
|
||||
)
|
||||
@ -64,12 +58,11 @@ class DetectorModel(L.LightningModule):
|
||||
config: Optional[ModuleConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.config = config or ModuleConfig()
|
||||
self.save_hyperparameters()
|
||||
|
||||
size = self.config.preprocessing.spectrogram.size
|
||||
assert size is not None
|
||||
self.backbone = get_backbone(self.config.backbone)
|
||||
self.backbone = build_architecture(self.config.architecture)
|
||||
|
||||
self.classifier = ClassifierHead(
|
||||
num_classes=len(self.config.targets.classes),
|
||||
@ -83,8 +76,7 @@ class DetectorModel(L.LightningModule):
|
||||
torch.tensor(conf.class_weights) if conf.class_weights else None
|
||||
)
|
||||
|
||||
self.validation_predictions = []
|
||||
|
||||
# Training targets
|
||||
self.class_names = get_class_names(self.config.targets.classes)
|
||||
self.encoder = build_encoder(
|
||||
self.config.targets.classes,
|
||||
@ -92,6 +84,10 @@ class DetectorModel(L.LightningModule):
|
||||
)
|
||||
self.decoder = build_decoder(self.config.targets.classes)
|
||||
|
||||
self.validation_predictions = []
|
||||
|
||||
self.example_input_array = torch.randn([1, 1, 128, 512])
|
||||
|
||||
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
|
||||
features = self.backbone(spec)
|
||||
detection_probs, classification_probs = self.classifier(features)
|
Loading…
Reference in New Issue
Block a user