Move the logging config out of the train/eval configs

This commit is contained in:
mbsantiago 2026-03-18 19:32:19 +00:00
parent bf5b88016a
commit 22a3d18d45
9 changed files with 81 additions and 10 deletions

View File

@ -22,7 +22,11 @@ from batdetect2.inference import (
process_file_list, process_file_list,
run_batch_inference, run_batch_inference,
) )
from batdetect2.logging import DEFAULT_LOGS_DIR from batdetect2.logging import (
DEFAULT_LOGS_DIR,
AppLoggingConfig,
LoggerConfig,
)
from batdetect2.models import ( from batdetect2.models import (
Model, Model,
ModelConfig, ModelConfig,
@ -64,6 +68,7 @@ class BatDetect2API:
evaluation_config: EvaluationConfig, evaluation_config: EvaluationConfig,
inference_config: InferenceConfig, inference_config: InferenceConfig,
outputs_config: OutputsConfig, outputs_config: OutputsConfig,
logging_config: AppLoggingConfig,
targets: TargetProtocol, targets: TargetProtocol,
audio_loader: AudioLoader, audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol, preprocessor: PreprocessorProtocol,
@ -79,6 +84,7 @@ class BatDetect2API:
self.evaluation_config = evaluation_config self.evaluation_config = evaluation_config
self.inference_config = inference_config self.inference_config = inference_config
self.outputs_config = outputs_config self.outputs_config = outputs_config
self.logging_config = logging_config
self.targets = targets self.targets = targets
self.audio_loader = audio_loader self.audio_loader = audio_loader
self.preprocessor = preprocessor self.preprocessor = preprocessor
@ -112,6 +118,7 @@ class BatDetect2API:
model_config: ModelConfig | None = None, model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
): ):
run_train( run_train(
train_annotations=train_annotations, train_annotations=train_annotations,
@ -131,6 +138,7 @@ class BatDetect2API:
seed=seed, seed=seed,
train_config=train_config or self.train_config, train_config=train_config or self.train_config,
audio_config=audio_config or self.audio_config, audio_config=audio_config or self.audio_config,
logger_config=logger_config or self.logging_config.train,
) )
return self return self
@ -152,6 +160,7 @@ class BatDetect2API:
model_config: ModelConfig | None = None, model_config: ModelConfig | None = None,
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
train_config: TrainingConfig | None = None, train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection.""" """Fine-tune the model with trainable-parameter selection."""
@ -175,6 +184,7 @@ class BatDetect2API:
seed=seed, seed=seed,
audio_config=audio_config or self.audio_config, audio_config=audio_config or self.audio_config,
train_config=train_config or self.train_config, train_config=train_config or self.train_config,
logger_config=logger_config or self.logging_config.train,
) )
return self return self
@ -189,6 +199,7 @@ class BatDetect2API:
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None, evaluation_config: EvaluationConfig | None = None,
outputs_config: OutputsConfig | None = None, outputs_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]: ) -> tuple[dict[str, float], list[ClipDetections]]:
return run_evaluate( return run_evaluate(
self.model, self.model,
@ -199,6 +210,7 @@ class BatDetect2API:
audio_config=audio_config or self.audio_config, audio_config=audio_config or self.audio_config,
evaluation_config=evaluation_config or self.evaluation_config, evaluation_config=evaluation_config or self.evaluation_config,
output_config=outputs_config or self.outputs_config, output_config=outputs_config or self.outputs_config,
logger_config=logger_config or self.logging_config.evaluation,
num_workers=num_workers, num_workers=num_workers,
output_dir=output_dir, output_dir=output_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
@ -486,6 +498,7 @@ class BatDetect2API:
evaluation_config=config.evaluation, evaluation_config=config.evaluation,
inference_config=config.inference, inference_config=config.inference,
outputs_config=config.outputs, outputs_config=config.outputs,
logging_config=config.logging,
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,
@ -506,6 +519,7 @@ class BatDetect2API:
evaluation_config: EvaluationConfig | None = None, evaluation_config: EvaluationConfig | None = None,
inference_config: InferenceConfig | None = None, inference_config: InferenceConfig | None = None,
outputs_config: OutputsConfig | None = None, outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API": ) -> "BatDetect2API":
model, model_config = load_model_from_checkpoint(path) model, model_config = load_model_from_checkpoint(path)
@ -516,6 +530,7 @@ class BatDetect2API:
evaluation_config = evaluation_config or EvaluationConfig() evaluation_config = evaluation_config or EvaluationConfig()
inference_config = inference_config or InferenceConfig() inference_config = inference_config or InferenceConfig()
outputs_config = outputs_config or OutputsConfig() outputs_config = outputs_config or OutputsConfig()
logging_config = logging_config or AppLoggingConfig()
if ( if (
targets_config is not None targets_config is not None
@ -571,6 +586,7 @@ class BatDetect2API:
evaluation_config=evaluation_config, evaluation_config=evaluation_config,
inference_config=inference_config, inference_config=inference_config,
outputs_config=outputs_config, outputs_config=outputs_config,
logging_config=logging_config,
targets=targets, targets=targets,
audio_loader=audio_loader, audio_loader=audio_loader,
preprocessor=preprocessor, preprocessor=preprocessor,

View File

@ -19,6 +19,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
@click.option("--evaluation-config", type=click.Path(exists=True)) @click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True)) @click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True)) @click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--base-dir", type=click.Path(), default=Path.cwd()) @click.option("--base-dir", type=click.Path(), default=Path.cwd())
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR) @click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
@click.option("--experiment-name", type=str) @click.option("--experiment-name", type=str)
@ -33,6 +34,7 @@ def evaluate_command(
evaluation_config: Path | None, evaluation_config: Path | None,
inference_config: Path | None, inference_config: Path | None,
outputs_config: Path | None, outputs_config: Path | None,
logging_config: Path | None,
output_dir: Path = DEFAULT_OUTPUT_DIR, output_dir: Path = DEFAULT_OUTPUT_DIR,
num_workers: int = 0, num_workers: int = 0,
experiment_name: str | None = None, experiment_name: str | None = None,
@ -43,6 +45,7 @@ def evaluate_command(
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import load_evaluation_config from batdetect2.evaluate import load_evaluation_config
from batdetect2.inference import InferenceConfig from batdetect2.inference import InferenceConfig
from batdetect2.logging import load_logging_config
from batdetect2.outputs import OutputsConfig from batdetect2.outputs import OutputsConfig
from batdetect2.targets import load_target_config from batdetect2.targets import load_target_config
@ -81,6 +84,11 @@ def evaluate_command(
if outputs_config is not None if outputs_config is not None
else None else None
) )
logging_conf = (
load_logging_config(logging_config)
if logging_config is not None
else None
)
api = BatDetect2API.from_checkpoint( api = BatDetect2API.from_checkpoint(
model_path, model_path,
@ -89,6 +97,7 @@ def evaluate_command(
evaluation_config=eval_conf, evaluation_config=eval_conf,
inference_config=inference_conf, inference_config=inference_conf,
outputs_config=outputs_conf, outputs_config=outputs_conf,
logging_config=logging_conf,
) )
api.evaluate( api.evaluate(

View File

@ -19,6 +19,7 @@ __all__ = ["train_command"]
@click.option("--evaluation-config", type=click.Path(exists=True)) @click.option("--evaluation-config", type=click.Path(exists=True))
@click.option("--inference-config", type=click.Path(exists=True)) @click.option("--inference-config", type=click.Path(exists=True))
@click.option("--outputs-config", type=click.Path(exists=True)) @click.option("--outputs-config", type=click.Path(exists=True))
@click.option("--logging-config", type=click.Path(exists=True))
@click.option("--ckpt-dir", type=click.Path(exists=True)) @click.option("--ckpt-dir", type=click.Path(exists=True))
@click.option("--log-dir", type=click.Path(exists=True)) @click.option("--log-dir", type=click.Path(exists=True))
@click.option("--train-workers", type=int) @click.option("--train-workers", type=int)
@ -40,6 +41,7 @@ def train_command(
evaluation_config: Path | None = None, evaluation_config: Path | None = None,
inference_config: Path | None = None, inference_config: Path | None = None,
outputs_config: Path | None = None, outputs_config: Path | None = None,
logging_config: Path | None = None,
seed: int | None = None, seed: int | None = None,
num_epochs: int | None = None, num_epochs: int | None = None,
train_workers: int = 0, train_workers: int = 0,
@ -53,6 +55,7 @@ def train_command(
from batdetect2.data import load_dataset_from_config from batdetect2.data import load_dataset_from_config
from batdetect2.evaluate import load_evaluation_config from batdetect2.evaluate import load_evaluation_config
from batdetect2.inference import InferenceConfig from batdetect2.inference import InferenceConfig
from batdetect2.logging import load_logging_config
from batdetect2.models import ModelConfig from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig from batdetect2.outputs import OutputsConfig
from batdetect2.targets import load_target_config from batdetect2.targets import load_target_config
@ -92,6 +95,11 @@ def train_command(
if outputs_config is not None if outputs_config is not None
else None else None
) )
logging_conf = (
load_logging_config(logging_config)
if logging_config is not None
else None
)
if target_conf is not None: if target_conf is not None:
logger.info("Loaded targets configuration.") logger.info("Loaded targets configuration.")
@ -141,6 +149,8 @@ def train_command(
conf.inference = inference_conf conf.inference = inference_conf
if outputs_conf is not None: if outputs_conf is not None:
conf.outputs = outputs_conf conf.outputs = outputs_conf
if logging_conf is not None:
conf.logging = logging_conf
api = BatDetect2API.from_config(conf) api = BatDetect2API.from_config(conf)
else: else:
@ -152,6 +162,7 @@ def train_command(
evaluation_config=eval_conf, evaluation_config=eval_conf,
inference_config=inference_conf, inference_config=inference_conf,
outputs_config=outputs_conf, outputs_config=outputs_conf,
logging_config=logging_conf,
) )
return api.train( return api.train(

View File

@ -10,6 +10,7 @@ from batdetect2.evaluate.config import (
get_default_eval_config, get_default_eval_config,
) )
from batdetect2.inference.config import InferenceConfig from batdetect2.inference.config import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import ModelConfig from batdetect2.models import ModelConfig
from batdetect2.outputs import OutputsConfig from batdetect2.outputs import OutputsConfig
from batdetect2.train.config import TrainingConfig from batdetect2.train.config import TrainingConfig
@ -32,6 +33,7 @@ class BatDetect2Config(BaseConfig):
audio: AudioConfig = Field(default_factory=AudioConfig) audio: AudioConfig = Field(default_factory=AudioConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig) inference: InferenceConfig = Field(default_factory=InferenceConfig)
outputs: OutputsConfig = Field(default_factory=OutputsConfig) outputs: OutputsConfig = Field(default_factory=OutputsConfig)
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
def validate_config(config: dict | None) -> BatDetect2Config: def validate_config(config: dict | None) -> BatDetect2Config:

View File

@ -7,7 +7,6 @@ from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.tasks import TaskConfig from batdetect2.evaluate.tasks import TaskConfig
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
__all__ = [ __all__ = [
"EvaluationConfig", "EvaluationConfig",
@ -22,7 +21,6 @@ class EvaluationConfig(BaseConfig):
ClassificationTaskConfig(), ClassificationTaskConfig(),
] ]
) )
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def get_default_eval_config() -> EvaluationConfig: def get_default_eval_config() -> EvaluationConfig:

View File

@ -10,7 +10,7 @@ from batdetect2.evaluate import EvaluationConfig
from batdetect2.evaluate.dataset import build_test_loader from batdetect2.evaluate.dataset import build_test_loader
from batdetect2.evaluate.evaluator import build_evaluator from batdetect2.evaluate.evaluator import build_evaluator
from batdetect2.evaluate.lightning import EvaluationModule from batdetect2.evaluate.lightning import EvaluationModule
from batdetect2.logging import build_logger from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
from batdetect2.models import Model from batdetect2.models import Model
from batdetect2.outputs import OutputsConfig, build_output_transform from batdetect2.outputs import OutputsConfig, build_output_transform
from batdetect2.outputs.types import OutputFormatterProtocol from batdetect2.outputs.types import OutputFormatterProtocol
@ -30,6 +30,7 @@ def run_evaluate(
audio_config: AudioConfig | None = None, audio_config: AudioConfig | None = None,
evaluation_config: EvaluationConfig | None = None, evaluation_config: EvaluationConfig | None = None,
output_config: OutputsConfig | None = None, output_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
formatter: OutputFormatterProtocol | None = None, formatter: OutputFormatterProtocol | None = None,
num_workers: int = 0, num_workers: int = 0,
output_dir: data.PathLike = DEFAULT_EVAL_DIR, output_dir: data.PathLike = DEFAULT_EVAL_DIR,
@ -64,7 +65,7 @@ def run_evaluate(
) )
logger = build_logger( logger = build_logger(
evaluation_config.logger, logger_config or CSVLoggerConfig(),
log_dir=Path(output_dir), log_dir=Path(output_dir),
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,

View File

@ -26,10 +26,26 @@ from matplotlib.figure import Figure
from pydantic import Field from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.core.configs import BaseConfig from batdetect2.core.configs import BaseConfig, load_config
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs" DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
__all__ = [
"AppLoggingConfig",
"BaseLoggerConfig",
"CSVLoggerConfig",
"DEFAULT_LOGS_DIR",
"DVCLiveConfig",
"LoggerConfig",
"MLFlowLoggerConfig",
"TensorBoardLoggerConfig",
"build_logger",
"enable_logging",
"get_image_logger",
"get_table_logger",
"load_logging_config",
]
def enable_logging(level: int): def enable_logging(level: int):
logger.remove() logger.remove()
@ -84,6 +100,19 @@ LoggerConfig = Annotated[
] ]
class AppLoggingConfig(BaseConfig):
train: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
evaluation: LoggerConfig = Field(default_factory=CSVLoggerConfig)
inference: LoggerConfig = Field(default_factory=CSVLoggerConfig)
def load_logging_config(
path: data.PathLike,
field: str | None = None,
) -> AppLoggingConfig:
return load_config(path, schema=AppLoggingConfig, field=field)
T = TypeVar("T", bound=LoggerConfig, contravariant=True) T = TypeVar("T", bound=LoggerConfig, contravariant=True)

View File

@ -3,7 +3,6 @@ from soundevent import data
from batdetect2.core.configs import BaseConfig, load_config from batdetect2.core.configs import BaseConfig, load_config
from batdetect2.evaluate.config import EvaluationConfig from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
from batdetect2.train.checkpoints import CheckpointConfig from batdetect2.train.checkpoints import CheckpointConfig
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
from batdetect2.train.labels import LabelConfig from batdetect2.train.labels import LabelConfig
@ -50,7 +49,6 @@ class TrainingConfig(BaseConfig):
) )
loss: LossConfig = Field(default_factory=LossConfig) loss: LossConfig = Field(default_factory=LossConfig)
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig) trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
labels: LabelConfig = Field(default_factory=LabelConfig) labels: LabelConfig = Field(default_factory=LabelConfig)
validation: EvaluationConfig = Field(default_factory=EvaluationConfig) validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig) checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)

View File

@ -10,7 +10,11 @@ from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.audio.types import AudioLoader from batdetect2.audio.types import AudioLoader
from batdetect2.evaluate import build_evaluator from batdetect2.evaluate import build_evaluator
from batdetect2.evaluate.types import EvaluatorProtocol from batdetect2.evaluate.types import EvaluatorProtocol
from batdetect2.logging import build_logger from batdetect2.logging import (
LoggerConfig,
TensorBoardLoggerConfig,
build_logger,
)
from batdetect2.models import Model, ModelConfig, build_model from batdetect2.models import Model, ModelConfig, build_model
from batdetect2.preprocess import build_preprocessor from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol from batdetect2.preprocess.types import PreprocessorProtocol
@ -41,6 +45,7 @@ def run_train(
audio_config: Optional[AudioConfig] = None, audio_config: Optional[AudioConfig] = None,
model_config: Optional[ModelConfig] = None, model_config: Optional[ModelConfig] = None,
train_config: Optional[TrainingConfig] = None, train_config: Optional[TrainingConfig] = None,
logger_config: LoggerConfig | None = None,
trainer: Trainer | None = None, trainer: Trainer | None = None,
train_workers: int = 0, train_workers: int = 0,
val_workers: int = 0, val_workers: int = 0,
@ -113,6 +118,7 @@ def run_train(
trainer = trainer or build_trainer( trainer = trainer or build_trainer(
train_config, train_config,
logger_config=logger_config,
evaluator=build_evaluator( evaluator=build_evaluator(
train_config.validation, train_config.validation,
targets=targets, targets=targets,
@ -180,6 +186,7 @@ def _validate_model_compatibility(
def build_trainer( def build_trainer(
config: TrainingConfig, config: TrainingConfig,
logger_config: LoggerConfig | None,
evaluator: "EvaluatorProtocol", evaluator: "EvaluatorProtocol",
checkpoint_dir: Path | None = None, checkpoint_dir: Path | None = None,
log_dir: Path | None = None, log_dir: Path | None = None,
@ -194,7 +201,7 @@ def build_trainer(
) )
train_logger = build_logger( train_logger = build_logger(
config.logger, logger_config or TensorBoardLoggerConfig(),
log_dir=log_dir, log_dir=log_dir,
experiment_name=experiment_name, experiment_name=experiment_name,
run_name=run_name, run_name=run_name,