fix: align cli and helpers with model refactor

This commit is contained in:
mbsantiago 2026-05-04 21:20:02 +01:00
parent 57236fc82a
commit eec126a502
4 changed files with 17 additions and 30 deletions

View File

@ -143,7 +143,6 @@ def train_command(
""" """
from batdetect2.api_v2 import BatDetect2API from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import EvaluationConfig from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig from batdetect2.inference import InferenceConfig
@ -196,9 +195,6 @@ def train_command(
if target_conf is not None: if target_conf is not None:
logger.info("Loaded targets configuration.") logger.info("Loaded targets configuration.")
if model_conf is not None and target_conf is not None:
model_conf = model_conf.model_copy(update={"targets": target_conf})
logger.info("Loading training dataset...") logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config( train_annotations = load_dataset_from_config(
train_dataset, train_dataset,
@ -231,26 +227,16 @@ def train_command(
) )
if model_path is None: if model_path is None:
conf = BatDetect2Config() api = BatDetect2API.from_config(
if model_conf is not None: model_config=model_conf,
conf.model = model_conf targets_config=target_conf,
elif target_conf is not None: train_config=train_conf,
conf.model = conf.model.model_copy(update={"targets": target_conf}) audio_config=audio_conf,
evaluation_config=eval_conf,
if train_conf is not None: inference_config=inference_conf,
conf.train = train_conf outputs_config=outputs_conf,
if audio_conf is not None: logging_config=logging_conf,
conf.audio = audio_conf )
if eval_conf is not None:
conf.evaluation = eval_conf
if inference_conf is not None:
conf.inference = inference_conf
if outputs_conf is not None:
conf.outputs = outputs_conf
if logging_conf is not None:
conf.logging = logging_conf
api = BatDetect2API.from_config(conf)
else: else:
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, model_path,

View File

@ -94,7 +94,7 @@ def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
return partial(operator.ge, value) return partial(operator.ge, value)
if op == "eq": if op == "eq":
return partial(operator.eq, b=value) return partial(operator.eq, value)
raise ValueError(f"Invalid operator {op}") raise ValueError(f"Invalid operator {op}")

View File

@ -25,10 +25,7 @@ if TYPE_CHECKING:
import numpy as np import numpy as np
import pandas as pd import pandas as pd
from lightning.pytorch.loggers import ( from lightning.pytorch.loggers import (
CSVLogger,
Logger, Logger,
MLFlowLogger,
TensorBoardLogger,
) )
from matplotlib.figure import Figure from matplotlib.figure import Figure
from soundevent import data from soundevent import data

View File

@ -53,8 +53,12 @@ import torch.nn.functional as F
from pydantic import Field from pydantic import Field
from torch import nn from torch import nn
from batdetect2.core import ImportConfig, Registry, add_import_config from batdetect2.core import (
from batdetect2.core.configs import BaseConfig BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
__all__ = [ __all__ = [
"BlockImportConfig", "BlockImportConfig",