fix: align cli and helpers with model refactor

This commit is contained in:
mbsantiago 2026-05-04 21:20:02 +01:00
parent 57236fc82a
commit eec126a502
4 changed files with 17 additions and 30 deletions

View File

@ -143,7 +143,6 @@ def train_command(
"""
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import EvaluationConfig
from batdetect2.inference import InferenceConfig
@ -196,9 +195,6 @@ def train_command(
if target_conf is not None:
logger.info("Loaded targets configuration.")
if model_conf is not None and target_conf is not None:
model_conf = model_conf.model_copy(update={"targets": target_conf})
logger.info("Loading training dataset...")
train_annotations = load_dataset_from_config(
train_dataset,
@ -231,26 +227,16 @@ def train_command(
)
if model_path is None:
conf = BatDetect2Config()
if model_conf is not None:
conf.model = model_conf
elif target_conf is not None:
conf.model = conf.model.model_copy(update={"targets": target_conf})
if train_conf is not None:
conf.train = train_conf
if audio_conf is not None:
conf.audio = audio_conf
if eval_conf is not None:
conf.evaluation = eval_conf
if inference_conf is not None:
conf.inference = inference_conf
if outputs_conf is not None:
conf.outputs = outputs_conf
if logging_conf is not None:
conf.logging = logging_conf
api = BatDetect2API.from_config(conf)
api = BatDetect2API.from_config(
model_config=model_conf,
targets_config=target_conf,
train_config=train_conf,
audio_config=audio_conf,
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
else:
api = BatDetect2API.from_checkpoint(
model_path,

View File

@ -94,7 +94,7 @@ def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
return partial(operator.ge, value)
if op == "eq":
return partial(operator.eq, b=value)
return partial(operator.eq, value)
raise ValueError(f"Invalid operator {op}")

View File

@ -25,10 +25,7 @@ if TYPE_CHECKING:
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from matplotlib.figure import Figure
from soundevent import data

View File

@ -53,8 +53,12 @@ import torch.nn.functional as F
from pydantic import Field
from torch import nn
from batdetect2.core import ImportConfig, Registry, add_import_config
from batdetect2.core.configs import BaseConfig
from batdetect2.core import (
BaseConfig,
ImportConfig,
Registry,
add_import_config,
)
__all__ = [
"BlockImportConfig",