mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
fix: align cli and helpers with model refactor
This commit is contained in:
parent
57236fc82a
commit
eec126a502
@ -143,7 +143,6 @@ def train_command(
|
|||||||
"""
|
"""
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
from batdetect2.audio import AudioConfig
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.config import BatDetect2Config
|
|
||||||
from batdetect2.data import load_dataset_from_config
|
from batdetect2.data import load_dataset_from_config
|
||||||
from batdetect2.evaluate import EvaluationConfig
|
from batdetect2.evaluate import EvaluationConfig
|
||||||
from batdetect2.inference import InferenceConfig
|
from batdetect2.inference import InferenceConfig
|
||||||
@ -196,9 +195,6 @@ def train_command(
|
|||||||
if target_conf is not None:
|
if target_conf is not None:
|
||||||
logger.info("Loaded targets configuration.")
|
logger.info("Loaded targets configuration.")
|
||||||
|
|
||||||
if model_conf is not None and target_conf is not None:
|
|
||||||
model_conf = model_conf.model_copy(update={"targets": target_conf})
|
|
||||||
|
|
||||||
logger.info("Loading training dataset...")
|
logger.info("Loading training dataset...")
|
||||||
train_annotations = load_dataset_from_config(
|
train_annotations = load_dataset_from_config(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
@ -231,26 +227,16 @@ def train_command(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
conf = BatDetect2Config()
|
api = BatDetect2API.from_config(
|
||||||
if model_conf is not None:
|
model_config=model_conf,
|
||||||
conf.model = model_conf
|
targets_config=target_conf,
|
||||||
elif target_conf is not None:
|
train_config=train_conf,
|
||||||
conf.model = conf.model.model_copy(update={"targets": target_conf})
|
audio_config=audio_conf,
|
||||||
|
evaluation_config=eval_conf,
|
||||||
if train_conf is not None:
|
inference_config=inference_conf,
|
||||||
conf.train = train_conf
|
outputs_config=outputs_conf,
|
||||||
if audio_conf is not None:
|
logging_config=logging_conf,
|
||||||
conf.audio = audio_conf
|
)
|
||||||
if eval_conf is not None:
|
|
||||||
conf.evaluation = eval_conf
|
|
||||||
if inference_conf is not None:
|
|
||||||
conf.inference = inference_conf
|
|
||||||
if outputs_conf is not None:
|
|
||||||
conf.outputs = outputs_conf
|
|
||||||
if logging_conf is not None:
|
|
||||||
conf.logging = logging_conf
|
|
||||||
|
|
||||||
api = BatDetect2API.from_config(conf)
|
|
||||||
else:
|
else:
|
||||||
api = BatDetect2API.from_checkpoint(
|
api = BatDetect2API.from_checkpoint(
|
||||||
model_path,
|
model_path,
|
||||||
|
|||||||
@ -94,7 +94,7 @@ def _build_comparator(op: Operator, value: float) -> Callable[[float], bool]:
|
|||||||
return partial(operator.ge, value)
|
return partial(operator.ge, value)
|
||||||
|
|
||||||
if op == "eq":
|
if op == "eq":
|
||||||
return partial(operator.eq, b=value)
|
return partial(operator.eq, value)
|
||||||
|
|
||||||
raise ValueError(f"Invalid operator {op}")
|
raise ValueError(f"Invalid operator {op}")
|
||||||
|
|
||||||
|
|||||||
@ -25,10 +25,7 @@ if TYPE_CHECKING:
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from lightning.pytorch.loggers import (
|
from lightning.pytorch.loggers import (
|
||||||
CSVLogger,
|
|
||||||
Logger,
|
Logger,
|
||||||
MLFlowLogger,
|
|
||||||
TensorBoardLogger,
|
|
||||||
)
|
)
|
||||||
from matplotlib.figure import Figure
|
from matplotlib.figure import Figure
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|||||||
@ -53,8 +53,12 @@ import torch.nn.functional as F
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
from batdetect2.core import ImportConfig, Registry, add_import_config
|
from batdetect2.core import (
|
||||||
from batdetect2.core.configs import BaseConfig
|
BaseConfig,
|
||||||
|
ImportConfig,
|
||||||
|
Registry,
|
||||||
|
add_import_config,
|
||||||
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BlockImportConfig",
|
"BlockImportConfig",
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user