Move the logging config out of the train/eval configs

This commit is contained in:
mbsantiago 2026-03-18 19:32:19 +00:00
parent bf5b88016a
commit 22a3d18d45
9 changed files with 81 additions and 10 deletions

View File

@ -22,7 +22,11 @@ from batdetect2.inference import (
process_file_list,
run_batch_inference,
)
from batdetect2.logging import DEFAULT_LOGS_DIR
from batdetect2.logging import (
DEFAULT_LOGS_DIR,
AppLoggingConfig,
LoggerConfig,
)
from batdetect2.models import (
Model,
ModelConfig,
@ -64,6 +68,7 @@ class BatDetect2API:
evaluation_config: EvaluationConfig,
inference_config: InferenceConfig,
outputs_config: OutputsConfig,
logging_config: AppLoggingConfig,
targets: TargetProtocol,
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
@ -79,6 +84,7 @@ class BatDetect2API:
self.evaluation_config = evaluation_config
self.inference_config = inference_config
self.outputs_config = outputs_config
self.logging_config = logging_config
self.targets = targets
self.audio_loader = audio_loader
self.preprocessor = preprocessor
@ -112,6 +118,7 @@ class BatDetect2API:
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
):
run_train(
train_annotations=train_annotations,
@ -131,6 +138,7 @@ class BatDetect2API:
seed=seed,
train_config=train_config or self.train_config,
audio_config=audio_config or self.audio_config,
logger_config=logger_config or self.logging_config.train,
)
return self
@ -152,6 +160,7 @@ class BatDetect2API:
model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection."""
@ -175,6 +184,7 @@ class BatDetect2API:
seed=seed,
audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config,
logger_config=logger_config or self.logging_config.train,
)
return self
@ -189,6 +199,7 @@ class BatDetect2API:
audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
outputs_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]:
return run_evaluate(
self.model,
@ -199,6 +210,7 @@ class BatDetect2API:
audio_config=audio_config or self.audio_config,
evaluation_config=evaluation_config or self.evaluation_config,
output_config=outputs_config or self.outputs_config,
logger_config=logger_config or self.logging_config.evaluation,
num_workers=num_workers,
output_dir=output_dir,
experiment_name=experiment_name,
@ -486,6 +498,7 @@ class BatDetect2API:
evaluation_config=config.evaluation,
inference_config=config.inference,
outputs_config=config.outputs,
logging_config=config.logging,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
@ -506,6 +519,7 @@ class BatDetect2API:
evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API":
model, model_config = load_model_from_checkpoint(path)
@ -516,6 +530,7 @@ class BatDetect2API:
evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
if (
targets_config is not None
@ -571,6 +586,7 @@ class BatDetect2API:
evaluation_config=evaluation_config,
inference_config=inference_config,
outputs_config=outputs_config,
logging_config=logging_config,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,

View File

@ -19,6 +19,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str)
@ -33,6 +34,7 @@ def evaluate_command(
evaluation_config: Path | None,
inference_config: Path | None,
outputs_config: Path | None,
logging_config: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: int = 0,
experiment_name: str | None = None,
@ -43,6 +45,7 @@ def evaluate_command(
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import load_evaluation_config
from batdetect2.inference import InferenceConfig
from batdetect2.logging import load_logging_config
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import load_target_config
@ -81,6 +84,11 @@ def evaluate_command(
if outputs_config is not None
else None
)
logging_conf = (
load_logging_config(logging_config)
if logging_config is not None
else None
)
api = BatDetect2API.from_checkpoint(
model_path,
@ -89,6 +97,7 @@ def evaluate_command(
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
api.evaluate(

View File

@ -19,6 +19,7 @@ __all__ = ["train_command"]
@click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True))
@click.option("--train-workers", type=int)
@ -40,6 +41,7 @@ def train_command(
evaluation_config: Path | None = None,
inference_config: Path | None = None,
outputs_config: Path | None = None,
logging_config: Path | None = None,
seed: int | None = None,
num_epochs: int | None = None,
train_workers: int = 0,
@ -53,6 +55,7 @@ def train_command(
from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import load_evaluation_config
from batdetect2.inference import InferenceConfig
from batdetect2.logging import load_logging_config
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.targets import load_target_config
@ -92,6 +95,11 @@ def train_command(
if outputs_config is not None
else None
)
logging_conf = (
load_logging_config(logging_config)
if logging_config is not None
else None
)
if target_conf is not None:
logger.info("Loaded targets configuration.")
@ -141,6 +149,8 @@ def train_command(
conf.inference = inference_conf
if outputs_conf is not None:
conf.outputs = outputs_conf
if logging_conf is not None:
conf.logging = logging_conf
api = BatDetect2API.from_config(conf)
else:
@ -152,6 +162,7 @@ def train_command(
evaluation_config=eval_conf,
inference_config=inference_conf,
outputs_config=outputs_conf,
logging_config=logging_conf,
)
return api.train(

View File

@ -10,6 +10,7 @@ from batdetect2.evaluate.config import (
get_default_eval_config,
)
from batdetect2.inference.config import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig
from batdetect2.train.config import TrainingConfig
@ -32,6 +33,7 @@ class BatDetect2Config(BaseConfig):
audio: AudioConfig = Field(default_factory=AudioConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
def validate_config(config: dict | None) -> BatDetect2Config:

View File

@ -7,7 +7,6 @@ from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.tasks import TaskConfig
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [
"EvaluationConfig",
@ -22,7 +21,6 @@ class EvaluationConfig(BaseConfig):
ClassificationTaskConfig(),
]
)
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def get_default_eval_config() -> EvaluationConfig:

View File

@ -10,7 +10,7 @@ from batdetect2.evaluate import EvaluationConfig
from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
from batdetect2.models import Model
from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol
@ -30,6 +30,7 @@ def run_evaluate(
audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None,
output_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
formatter: OutputFormatterProtocol | None = None,
num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
@ -64,7 +65,7 @@ def run_evaluate(
)
logger = build_logger(
evaluation_config.logger,
logger_config or CSVLoggerConfig(),
log_dir=Path(output_dir),
experiment_name=experiment_name,
run_name=run_name,

View File

@ -26,10 +26,26 @@ from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.configs import BaseConfig, load_config
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
__all__ = [
"AppLoggingConfig",
"BaseLoggerConfig",
"CSVLoggerConfig",
"DEFAULT_LOGS_DIR",
"DVCLiveConfig",
"LoggerConfig",
"MLFlowLoggerConfig",
"TensorBoardLoggerConfig",
"build_logger",
"enable_logging",
"get_image_logger",
"get_table_logger",
"load_logging_config",
]
def enable_logging(level: int):
logger.remove()
@ -84,6 +100,19 @@ LoggerConfig = Annotated[
]
class AppLoggingConfig(BaseConfig):
train: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
evaluation: LoggerConfig = Field(default_factory=CSVLoggerConfig)
inference: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def load_logging_config(
path: data.PathLike,
field: str | None = None,
) -> AppLoggingConfig:
return load_config(path, schema=AppLoggingConfig, field=field)
T = TypeVar("T", bound=LoggerConfig, contravariant=True)

View File

@ -3,7 +3,6 @@ from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
from batdetect2.train.checkpoints import CheckpointConfig
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import LabelConfig
@ -50,7 +49,6 @@ class TrainingConfig(BaseConfig):
)
loss: LossConfig = Field(default_factory=LossConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
labels: LabelConfig = Field(default_factory=LabelConfig)
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)

View File

@ -10,7 +10,11 @@ from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import build_evaluator
from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import build_logger
from batdetect2.logging import (
LoggerConfig,
TensorBoardLoggerConfig,
build_logger,
)
from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
@ -41,6 +45,7 @@ def run_train(
audio_config: Optional[AudioConfig] = None,
model_config: Optional[ModelConfig] = None,
train_config: Optional[TrainingConfig] = None,
logger_config: LoggerConfig | None = None,
trainer: Trainer | None = None,
train_workers: int = 0,
val_workers: int = 0,
@ -113,6 +118,7 @@ def run_train(
trainer = trainer or build_trainer(
train_config,
logger_config=logger_config,
evaluator=build_evaluator(
train_config.validation,
targets=targets,
@ -180,6 +186,7 @@ def _validate_model_compatibility(
def build_trainer(
config: TrainingConfig,
logger_config: LoggerConfig | None,
evaluator: "EvaluatorProtocol",
checkpoint_dir: Path | None = None,
log_dir: Path | None = None,
@ -194,7 +201,7 @@ def build_trainer(
)
train_logger = build_logger(
config.logger,
logger_config or TensorBoardLoggerConfig(),
log_dir=log_dir,
experiment_name=experiment_name,
run_name=run_name,