From 4b6acd5e6e6883bce415e3eeff165d2e0cc7a6d9 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Thu, 26 Jun 2025 11:59:33 -0600 Subject: [PATCH] Add manual logging of hyperparams --- src/batdetect2/train/lightning.py | 2 -- src/batdetect2/train/train.py | 2 ++ 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/batdetect2/train/lightning.py b/src/batdetect2/train/lightning.py index c88c7d2..f4f1959 100644 --- a/src/batdetect2/train/lightning.py +++ b/src/batdetect2/train/lightning.py @@ -40,8 +40,6 @@ class TrainingModule(L.LightningModule): self.learning_rate = learning_rate self.t_max = t_max - self.save_hyperparameters() - def forward(self, spec: torch.Tensor) -> ModelOutput: return self.detector(spec) diff --git a/src/batdetect2/train/train.py b/src/batdetect2/train/train.py index 8778498..d386e98 100644 --- a/src/batdetect2/train/train.py +++ b/src/batdetect2/train/train.py @@ -84,6 +84,8 @@ def train( ) logger = build_logger(config.logger) + if logger and hasattr(logger, 'log_hyperparams'): + logger.log_hyperparams(config.model_dump(exclude_none=True)) trainer = Trainer( **config.trainer.model_dump(exclude_none=True, exclude={"logger"}),