diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 9f872b5..0513d63 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -109,6 +109,7 @@ class BatDetect2API: ): from batdetect2.train import run_train + self.model.train() run_train( train_annotations=train_annotations, val_annotations=val_annotations, @@ -130,6 +131,7 @@ class BatDetect2API: audio_config=audio_config or self.audio_config, logger_config=logger_config or self.logging_config.train, ) + self.model.eval() return self def finetune(