From eec126a5029763f7ced81cbd71ff222b45cfe00e Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Mon, 4 May 2026 21:20:02 +0100 Subject: [PATCH] fix: align cli and helpers with model refactor --- src/batdetect2/cli/train.py | 34 ++++++------------- .../data/conditions/sound_events.py | 2 +- src/batdetect2/logging.py | 3 -- src/batdetect2/models/blocks.py | 8 +++-- 4 files changed, 17 insertions(+), 30 deletions(-) diff --git a/src/batdetect2/cli/train.py b/src/batdetect2/cli/train.py index e169e3e..ac2321e 100644 --- a/src/batdetect2/cli/train.py +++ b/src/batdetect2/cli/train.py @@ -143,7 +143,6 @@ def train_command( """ from batdetect2.api_v2 import BatDetect2API from batdetect2.audio import AudioConfig - from batdetect2.config import BatDetect2Config from batdetect2.data import load_dataset_from_config from batdetect2.evaluate import EvaluationConfig from batdetect2.inference import InferenceConfig @@ -196,9 +195,6 @@ def train_command( if target_conf is not None: logger.info("Loaded targets configuration.") - if model_conf is not None and target_conf is not None: - model_conf = model_conf.model_copy(update={"targets": target_conf}) - logger.info("Loading training dataset...") train_annotations = load_dataset_from_config( train_dataset, @@ -231,26 +227,16 @@ def train_command( ) if model_path is None: - conf = BatDetect2Config() - if model_conf is not None: - conf.model = model_conf - elif target_conf is not None: - conf.model = conf.model.model_copy(update={"targets": target_conf}) - - if train_conf is not None: - conf.train = train_conf - if audio_conf is not None: - conf.audio = audio_conf - if eval_conf is not None: - conf.evaluation = eval_conf - if inference_conf is not None: - conf.inference = inference_conf - if outputs_conf is not None: - conf.outputs = outputs_conf - if logging_conf is not None: - conf.logging = logging_conf - - api = BatDetect2API.from_config(conf) + api = BatDetect2API.from_config( + model_config=model_conf, + targets_config=target_conf, + train_config=train_conf, + audio_config=audio_conf, + evaluation_config=eval_conf, + inference_config=inference_conf, + outputs_config=outputs_conf, + logging_config=logging_conf, + ) else: api = BatDetect2API.from_checkpoint( model_path, diff --git a/src/batdetect2/data/conditions/sound_events.py b/src/batdetect2/data/conditions/sound_events.py index 0597906..fe47710 100644 --- a/src/batdetect2/data/conditions/sound_events.py +++ b/src/batdetect2/data/conditions/sound_events.py @@ -94,7 +94,7 @@ def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]: return partial(operator.ge, value) if op == "eq": - return partial(operator.eq, b=value) + return partial(operator.eq, value) raise ValueError(f"Invalid operator {op}") diff --git a/src/batdetect2/logging.py b/src/batdetect2/logging.py index 0080bef..6d14980 100644 --- a/src/batdetect2/logging.py +++ b/src/batdetect2/logging.py @@ -25,10 +25,7 @@ if TYPE_CHECKING: import numpy as np import pandas as pd from lightning.pytorch.loggers import ( - CSVLogger, Logger, - MLFlowLogger, - TensorBoardLogger, ) from matplotlib.figure import Figure from soundevent import data diff --git a/src/batdetect2/models/blocks.py b/src/batdetect2/models/blocks.py index 8bf1c69..ccce61c 100644 --- a/src/batdetect2/models/blocks.py +++ b/src/batdetect2/models/blocks.py @@ -53,8 +53,12 @@ import torch.nn.functional as F from pydantic import Field from torch import nn -from batdetect2.core import ImportConfig, Registry, add_import_config -from batdetect2.core.configs import BaseConfig +from batdetect2.core import ( + BaseConfig, + ImportConfig, + Registry, + add_import_config, +) __all__ = [ "BlockImportConfig",