mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-04-04 15:20:19 +02:00
Move the logging config out of the train/eval configs
This commit is contained in:
parent
bf5b88016a
commit
22a3d18d45
@ -22,7 +22,11 @@ from batdetect2.inference import (
|
||||
process_file_list,
|
||||
run_batch_inference,
|
||||
)
|
||||
from batdetect2.logging import DEFAULT_LOGS_DIR
|
||||
from batdetect2.logging import (
|
||||
DEFAULT_LOGS_DIR,
|
||||
AppLoggingConfig,
|
||||
LoggerConfig,
|
||||
)
|
||||
from batdetect2.models import (
|
||||
Model,
|
||||
ModelConfig,
|
||||
@ -64,6 +68,7 @@ class BatDetect2API:
|
||||
evaluation_config: EvaluationConfig,
|
||||
inference_config: InferenceConfig,
|
||||
outputs_config: OutputsConfig,
|
||||
logging_config: AppLoggingConfig,
|
||||
targets: TargetProtocol,
|
||||
audio_loader: AudioLoader,
|
||||
preprocessor: PreprocessorProtocol,
|
||||
@ -79,6 +84,7 @@ class BatDetect2API:
|
||||
self.evaluation_config = evaluation_config
|
||||
self.inference_config = inference_config
|
||||
self.outputs_config = outputs_config
|
||||
self.logging_config = logging_config
|
||||
self.targets = targets
|
||||
self.audio_loader = audio_loader
|
||||
self.preprocessor = preprocessor
|
||||
@ -112,6 +118,7 @@ class BatDetect2API:
|
||||
model_config: ModelConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
):
|
||||
run_train(
|
||||
train_annotations=train_annotations,
|
||||
@ -131,6 +138,7 @@ class BatDetect2API:
|
||||
seed=seed,
|
||||
train_config=train_config or self.train_config,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
logger_config=logger_config or self.logging_config.train,
|
||||
)
|
||||
return self
|
||||
|
||||
@ -152,6 +160,7 @@ class BatDetect2API:
|
||||
model_config: ModelConfig | None = None,
|
||||
audio_config: AudioConfig | None = None,
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
"""Fine-tune the model with trainable-parameter selection."""
|
||||
|
||||
@ -175,6 +184,7 @@ class BatDetect2API:
|
||||
seed=seed,
|
||||
audio_config=audio_config or self.audio_config,
|
||||
train_config=train_config or self.train_config,
|
||||
logger_config=logger_config or self.logging_config.train,
|
||||
)
|
||||
return self
|
||||
|
||||
@ -189,6 +199,7 @@ class BatDetect2API:
|
||||
audio_config: AudioConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||
return run_evaluate(
|
||||
self.model,
|
||||
@ -199,6 +210,7 @@ class BatDetect2API:
|
||||
audio_config=audio_config or self.audio_config,
|
||||
evaluation_config=evaluation_config or self.evaluation_config,
|
||||
output_config=outputs_config or self.outputs_config,
|
||||
logger_config=logger_config or self.logging_config.evaluation,
|
||||
num_workers=num_workers,
|
||||
output_dir=output_dir,
|
||||
experiment_name=experiment_name,
|
||||
@ -486,6 +498,7 @@ class BatDetect2API:
|
||||
evaluation_config=config.evaluation,
|
||||
inference_config=config.inference,
|
||||
outputs_config=config.outputs,
|
||||
logging_config=config.logging,
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
@ -506,6 +519,7 @@ class BatDetect2API:
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
inference_config: InferenceConfig | None = None,
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logging_config: AppLoggingConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
model, model_config = load_model_from_checkpoint(path)
|
||||
|
||||
@ -516,6 +530,7 @@ class BatDetect2API:
|
||||
evaluation_config = evaluation_config or EvaluationConfig()
|
||||
inference_config = inference_config or InferenceConfig()
|
||||
outputs_config = outputs_config or OutputsConfig()
|
||||
logging_config = logging_config or AppLoggingConfig()
|
||||
|
||||
if (
|
||||
targets_config is not None
|
||||
@ -571,6 +586,7 @@ class BatDetect2API:
|
||||
evaluation_config=evaluation_config,
|
||||
inference_config=inference_config,
|
||||
outputs_config=outputs_config,
|
||||
logging_config=logging_config,
|
||||
targets=targets,
|
||||
audio_loader=audio_loader,
|
||||
preprocessor=preprocessor,
|
||||
|
||||
@ -19,6 +19,7 @@ DEFAULT_OUTPUT_DIR = Path("outputs") / "evaluation"
|
||||
@click.option("--evaluation-config", type=click.Path(exists=True))
|
||||
@click.option("--inference-config", type=click.Path(exists=True))
|
||||
@click.option("--outputs-config", type=click.Path(exists=True))
|
||||
@click.option("--logging-config", type=click.Path(exists=True))
|
||||
@click.option("--base-dir", type=click.Path(), default=Path.cwd())
|
||||
@click.option("--output-dir", type=click.Path(), default=DEFAULT_OUTPUT_DIR)
|
||||
@click.option("--experiment-name", type=str)
|
||||
@ -33,6 +34,7 @@ def evaluate_command(
|
||||
evaluation_config: Path | None,
|
||||
inference_config: Path | None,
|
||||
outputs_config: Path | None,
|
||||
logging_config: Path | None,
|
||||
output_dir: Path = DEFAULT_OUTPUT_DIR,
|
||||
num_workers: int = 0,
|
||||
experiment_name: str | None = None,
|
||||
@ -43,6 +45,7 @@ def evaluate_command(
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.evaluate import load_evaluation_config
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import load_logging_config
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import load_target_config
|
||||
|
||||
@ -81,6 +84,11 @@ def evaluate_command(
|
||||
if outputs_config is not None
|
||||
else None
|
||||
)
|
||||
logging_conf = (
|
||||
load_logging_config(logging_config)
|
||||
if logging_config is not None
|
||||
else None
|
||||
)
|
||||
|
||||
api = BatDetect2API.from_checkpoint(
|
||||
model_path,
|
||||
@ -89,6 +97,7 @@ def evaluate_command(
|
||||
evaluation_config=eval_conf,
|
||||
inference_config=inference_conf,
|
||||
outputs_config=outputs_conf,
|
||||
logging_config=logging_conf,
|
||||
)
|
||||
|
||||
api.evaluate(
|
||||
|
||||
@ -19,6 +19,7 @@ __all__ = ["train_command"]
|
||||
@click.option("--evaluation-config", type=click.Path(exists=True))
|
||||
@click.option("--inference-config", type=click.Path(exists=True))
|
||||
@click.option("--outputs-config", type=click.Path(exists=True))
|
||||
@click.option("--logging-config", type=click.Path(exists=True))
|
||||
@click.option("--ckpt-dir", type=click.Path(exists=True))
|
||||
@click.option("--log-dir", type=click.Path(exists=True))
|
||||
@click.option("--train-workers", type=int)
|
||||
@ -40,6 +41,7 @@ def train_command(
|
||||
evaluation_config: Path | None = None,
|
||||
inference_config: Path | None = None,
|
||||
outputs_config: Path | None = None,
|
||||
logging_config: Path | None = None,
|
||||
seed: int | None = None,
|
||||
num_epochs: int | None = None,
|
||||
train_workers: int = 0,
|
||||
@ -53,6 +55,7 @@ def train_command(
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
from batdetect2.evaluate import load_evaluation_config
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import load_logging_config
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.targets import load_target_config
|
||||
@ -92,6 +95,11 @@ def train_command(
|
||||
if outputs_config is not None
|
||||
else None
|
||||
)
|
||||
logging_conf = (
|
||||
load_logging_config(logging_config)
|
||||
if logging_config is not None
|
||||
else None
|
||||
)
|
||||
|
||||
if target_conf is not None:
|
||||
logger.info("Loaded targets configuration.")
|
||||
@ -141,6 +149,8 @@ def train_command(
|
||||
conf.inference = inference_conf
|
||||
if outputs_conf is not None:
|
||||
conf.outputs = outputs_conf
|
||||
if logging_conf is not None:
|
||||
conf.logging = logging_conf
|
||||
|
||||
api = BatDetect2API.from_config(conf)
|
||||
else:
|
||||
@ -152,6 +162,7 @@ def train_command(
|
||||
evaluation_config=eval_conf,
|
||||
inference_config=inference_conf,
|
||||
outputs_config=outputs_conf,
|
||||
logging_config=logging_conf,
|
||||
)
|
||||
|
||||
return api.train(
|
||||
|
||||
@ -10,6 +10,7 @@ from batdetect2.evaluate.config import (
|
||||
get_default_eval_config,
|
||||
)
|
||||
from batdetect2.inference.config import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import ModelConfig
|
||||
from batdetect2.outputs import OutputsConfig
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
@ -32,6 +33,7 @@ class BatDetect2Config(BaseConfig):
|
||||
audio: AudioConfig = Field(default_factory=AudioConfig)
|
||||
inference: InferenceConfig = Field(default_factory=InferenceConfig)
|
||||
outputs: OutputsConfig = Field(default_factory=OutputsConfig)
|
||||
logging: AppLoggingConfig = Field(default_factory=AppLoggingConfig)
|
||||
|
||||
|
||||
def validate_config(config: dict | None) -> BatDetect2Config:
|
||||
|
||||
@ -7,7 +7,6 @@ from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.tasks import TaskConfig
|
||||
from batdetect2.evaluate.tasks.classification import ClassificationTaskConfig
|
||||
from batdetect2.evaluate.tasks.detection import DetectionTaskConfig
|
||||
from batdetect2.logging import CSVLoggerConfig, LoggerConfig
|
||||
|
||||
__all__ = [
|
||||
"EvaluationConfig",
|
||||
@ -22,7 +21,6 @@ class EvaluationConfig(BaseConfig):
|
||||
ClassificationTaskConfig(),
|
||||
]
|
||||
)
|
||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||
|
||||
|
||||
def get_default_eval_config() -> EvaluationConfig:
|
||||
|
||||
@ -10,7 +10,7 @@ from batdetect2.evaluate import EvaluationConfig
|
||||
from batdetect2.evaluate.dataset import build_test_loader
|
||||
from batdetect2.evaluate.evaluator import build_evaluator
|
||||
from batdetect2.evaluate.lightning import EvaluationModule
|
||||
from batdetect2.logging import build_logger
|
||||
from batdetect2.logging import CSVLoggerConfig, LoggerConfig, build_logger
|
||||
from batdetect2.models import Model
|
||||
from batdetect2.outputs import OutputsConfig, build_output_transform
|
||||
from batdetect2.outputs.types import OutputFormatterProtocol
|
||||
@ -30,6 +30,7 @@ def run_evaluate(
|
||||
audio_config: AudioConfig | None = None,
|
||||
evaluation_config: EvaluationConfig | None = None,
|
||||
output_config: OutputsConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
formatter: OutputFormatterProtocol | None = None,
|
||||
num_workers: int = 0,
|
||||
output_dir: data.PathLike = DEFAULT_EVAL_DIR,
|
||||
@ -64,7 +65,7 @@ def run_evaluate(
|
||||
)
|
||||
|
||||
logger = build_logger(
|
||||
evaluation_config.logger,
|
||||
logger_config or CSVLoggerConfig(),
|
||||
log_dir=Path(output_dir),
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
|
||||
@ -26,10 +26,26 @@ from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
|
||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||
|
||||
__all__ = [
|
||||
"AppLoggingConfig",
|
||||
"BaseLoggerConfig",
|
||||
"CSVLoggerConfig",
|
||||
"DEFAULT_LOGS_DIR",
|
||||
"DVCLiveConfig",
|
||||
"LoggerConfig",
|
||||
"MLFlowLoggerConfig",
|
||||
"TensorBoardLoggerConfig",
|
||||
"build_logger",
|
||||
"enable_logging",
|
||||
"get_image_logger",
|
||||
"get_table_logger",
|
||||
"load_logging_config",
|
||||
]
|
||||
|
||||
|
||||
def enable_logging(level: int):
|
||||
logger.remove()
|
||||
@ -84,6 +100,19 @@ LoggerConfig = Annotated[
|
||||
]
|
||||
|
||||
|
||||
class AppLoggingConfig(BaseConfig):
|
||||
train: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
|
||||
evaluation: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||
inference: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||
|
||||
|
||||
def load_logging_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
) -> AppLoggingConfig:
|
||||
return load_config(path, schema=AppLoggingConfig, field=field)
|
||||
|
||||
|
||||
T = TypeVar("T", bound=LoggerConfig, contravariant=True)
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,6 @@ from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.config import EvaluationConfig
|
||||
from batdetect2.logging import LoggerConfig, TensorBoardLoggerConfig
|
||||
from batdetect2.train.checkpoints import CheckpointConfig
|
||||
from batdetect2.train.dataset import TrainLoaderConfig, ValLoaderConfig
|
||||
from batdetect2.train.labels import LabelConfig
|
||||
@ -50,7 +49,6 @@ class TrainingConfig(BaseConfig):
|
||||
)
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||
logger: LoggerConfig = Field(default_factory=TensorBoardLoggerConfig)
|
||||
labels: LabelConfig = Field(default_factory=LabelConfig)
|
||||
validation: EvaluationConfig = Field(default_factory=EvaluationConfig)
|
||||
checkpoints: CheckpointConfig = Field(default_factory=CheckpointConfig)
|
||||
|
||||
@ -10,7 +10,11 @@ from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.audio.types import AudioLoader
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.evaluate.types import EvaluatorProtocol
|
||||
from batdetect2.logging import build_logger
|
||||
from batdetect2.logging import (
|
||||
LoggerConfig,
|
||||
TensorBoardLoggerConfig,
|
||||
build_logger,
|
||||
)
|
||||
from batdetect2.models import Model, ModelConfig, build_model
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
@ -41,6 +45,7 @@ def run_train(
|
||||
audio_config: Optional[AudioConfig] = None,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
train_config: Optional[TrainingConfig] = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
trainer: Trainer | None = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
@ -113,6 +118,7 @@ def run_train(
|
||||
|
||||
trainer = trainer or build_trainer(
|
||||
train_config,
|
||||
logger_config=logger_config,
|
||||
evaluator=build_evaluator(
|
||||
train_config.validation,
|
||||
targets=targets,
|
||||
@ -180,6 +186,7 @@ def _validate_model_compatibility(
|
||||
|
||||
def build_trainer(
|
||||
config: TrainingConfig,
|
||||
logger_config: LoggerConfig | None,
|
||||
evaluator: "EvaluatorProtocol",
|
||||
checkpoint_dir: Path | None = None,
|
||||
log_dir: Path | None = None,
|
||||
@ -194,7 +201,7 @@ def build_trainer(
|
||||
)
|
||||
|
||||
train_logger = build_logger(
|
||||
config.logger,
|
||||
logger_config or TensorBoardLoggerConfig(),
|
||||
log_dir=log_dir,
|
||||
experiment_name=experiment_name,
|
||||
run_name=run_name,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user