diff --git a/batdetect2/train/modules.py b/batdetect2/modules.py similarity index 88% rename from batdetect2/train/modules.py rename to batdetect2/modules.py index 6a6bb85..ae9fd68 100644 --- a/batdetect2/train/modules.py +++ b/batdetect2/modules.py @@ -1,7 +1,7 @@ from pathlib import Path from typing import Optional -import pytorch_lightning as L +import lightning as L import torch from pydantic import Field from soundevent import data @@ -10,11 +10,12 @@ from torch.optim.lr_scheduler import CosineAnnealingLR from torch.utils.data import DataLoader from batdetect2.configs import BaseConfig +from batdetect2.evaluate.evaluate import match_predictions_and_annotations from batdetect2.models import ( BBoxHead, ClassifierHead, ModelConfig, - get_backbone, + build_architecture, ) from batdetect2.models.typing import ModelOutput from batdetect2.post_process import ( @@ -22,9 +23,9 @@ from batdetect2.post_process import ( postprocess_model_outputs, ) from batdetect2.preprocess import PreprocessingConfig, preprocess_audio_clip +from batdetect2.train.config import TrainingConfig from batdetect2.train.dataset import LabeledDataset, TrainExample -from batdetect2.train.evaluate import match_predictions_and_annotations -from batdetect2.train.losses import LossConfig, compute_loss +from batdetect2.train.losses import compute_loss from batdetect2.train.targets import ( TargetConfig, build_decoder, @@ -32,22 +33,15 @@ from batdetect2.train.targets import ( get_class_names, ) - -class OptimizerConfig(BaseConfig): - learning_rate: float = 1e-3 - t_max: int = 100 - - -class TrainingConfig(BaseConfig): - loss: LossConfig = Field(default_factory=LossConfig) - optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) +__all__ = [ + "DetectorModel", +] class ModuleConfig(BaseConfig): train: TrainingConfig = Field(default_factory=TrainingConfig) targets: TargetConfig = Field(default_factory=TargetConfig) - optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig) - backbone: ModelConfig = Field(default_factory=ModelConfig) + architecture: ModelConfig = Field(default_factory=ModelConfig) preprocessing: PreprocessingConfig = Field( default_factory=PreprocessingConfig ) @@ -64,12 +58,11 @@ class DetectorModel(L.LightningModule): config: Optional[ModuleConfig] = None, ): super().__init__() + self.config = config or ModuleConfig() self.save_hyperparameters() - size = self.config.preprocessing.spectrogram.size - assert size is not None - self.backbone = get_backbone(self.config.backbone) + self.backbone = build_architecture(self.config.architecture) self.classifier = ClassifierHead( num_classes=len(self.config.targets.classes), @@ -83,8 +76,7 @@ class DetectorModel(L.LightningModule): torch.tensor(conf.class_weights) if conf.class_weights else None ) - self.validation_predictions = [] - + # Training targets self.class_names = get_class_names(self.config.targets.classes) self.encoder = build_encoder( self.config.targets.classes, @@ -92,6 +84,10 @@ class DetectorModel(L.LightningModule): ) self.decoder = build_decoder(self.config.targets.classes) + self.validation_predictions = [] + + self.example_input_array = torch.randn([1, 1, 128, 512]) + def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore features = self.backbone(spec) detection_probs, classification_probs = self.classifier(features)