diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index c88c7d2..f4f1959 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -40,8 +40,6 @@ class TrainingModule(L.LightningModule): self.learning_rate = learning_rate self.t_max = t_max - self.save_hyperparameters() - def forward(self, spec: torch.Tensor) -> ModelOutput: return self.detector(spec) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 8778498..d386e98 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -84,6 +84,8 @@ def train( ) logger = build_logger(config.logger) + if logger and hasattr(logger, 'log_hyperparams'): + logger.log_hyperparams(config.model_dump(exclude_none=True)) trainer = Trainer( **config.trainer.model_dump(exclude_none=True, exclude={"logger"}),