mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 22:51:58 +02:00
Add logging config
This commit is contained in:
parent
a462beaeb8
commit
136949c4e7
@ -9,6 +9,7 @@ from batdetect2.train.augmentations import (
|
|||||||
AugmentationsConfig,
|
AugmentationsConfig,
|
||||||
)
|
)
|
||||||
from batdetect2.train.clips import ClipingConfig
|
from batdetect2.train.clips import ClipingConfig
|
||||||
|
from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
|
||||||
from batdetect2.train.losses import LossConfig
|
from batdetect2.train.losses import LossConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -59,6 +60,8 @@ class TrainingConfig(BaseConfig):
|
|||||||
|
|
||||||
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
|
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
|
||||||
|
|
||||||
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
|
|
||||||
|
|
||||||
def load_train_config(
|
def load_train_config(
|
||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
|
104
src/batdetect2/train/logging.py
Normal file
104
src/batdetect2/train/logging.py
Normal file
@ -0,0 +1,104 @@
|
|||||||
|
from typing import Annotated, Literal, Optional, Union
|
||||||
|
|
||||||
|
from lightning.pytorch.loggers import Logger
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent.data import PathLike
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
|
|
||||||
|
DEFAULT_LOGS_DIR: str = "logs"
|
||||||
|
|
||||||
|
|
||||||
|
class DVCLiveConfig(BaseConfig):
|
||||||
|
logger_type: Literal["dvclive"] = "dvclive"
|
||||||
|
dir: str = DEFAULT_LOGS_DIR
|
||||||
|
run_name: Optional[str] = None
|
||||||
|
prefix: str = ""
|
||||||
|
log_model: Union[bool, Literal["all"]] = False
|
||||||
|
monitor_system: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class CSVLoggerConfig(BaseConfig):
|
||||||
|
logger_type: Literal["csv"] = "csv"
|
||||||
|
save_dir: str = DEFAULT_LOGS_DIR
|
||||||
|
name: Optional[str] = "logs"
|
||||||
|
version: Optional[str] = None
|
||||||
|
flush_logs_every_n_steps: int = 100
|
||||||
|
|
||||||
|
|
||||||
|
class TensorBoardLoggerConfig(BaseConfig):
|
||||||
|
logger_type: Literal["tensorboard"] = "tensorboard"
|
||||||
|
save_dir: str = DEFAULT_LOGS_DIR
|
||||||
|
name: Optional[str] = "default"
|
||||||
|
version: Optional[str] = None
|
||||||
|
log_graph: bool = False
|
||||||
|
flush_logs_every_n_steps: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
LoggerConfig = Annotated[
|
||||||
|
Union[DVCLiveConfig, CSVLoggerConfig, TensorBoardLoggerConfig],
|
||||||
|
Field(discriminator="logger_type"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def create_dvclive_logger(config: DVCLiveConfig) -> Logger:
|
||||||
|
try:
|
||||||
|
from dvclive.lightning import DVCLiveLogger # type: ignore
|
||||||
|
except ImportError as error:
|
||||||
|
raise ValueError(
|
||||||
|
"DVCLive is not installed and cannot be used for logging"
|
||||||
|
"Make sure you have it installed by running `pip install dvclive`"
|
||||||
|
"or `uv add dvclive`"
|
||||||
|
) from error
|
||||||
|
|
||||||
|
return DVCLiveLogger(
|
||||||
|
dir=config.dir,
|
||||||
|
run_name=config.run_name,
|
||||||
|
prefix=config.prefix,
|
||||||
|
log_model=config.log_model,
|
||||||
|
monitor_system=config.monitor_system,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_csv_logger(config: CSVLoggerConfig) -> Logger:
|
||||||
|
from lightning.pytorch.loggers import CSVLogger
|
||||||
|
|
||||||
|
return CSVLogger(
|
||||||
|
save_dir=config.save_dir,
|
||||||
|
name=config.name,
|
||||||
|
version=config.version,
|
||||||
|
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
|
||||||
|
from lightning.pytorch.loggers import TensorBoardLogger
|
||||||
|
|
||||||
|
return TensorBoardLogger(
|
||||||
|
save_dir=config.save_dir,
|
||||||
|
name=config.name,
|
||||||
|
version=config.version,
|
||||||
|
log_graph=config.log_graph,
|
||||||
|
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
LOGGER_FACTORY = {
|
||||||
|
"dvclive": create_dvclive_logger,
|
||||||
|
"csv": create_csv_logger,
|
||||||
|
"tensorboard": create_tensorboard_logger,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def build_logger(config: LoggerConfig) -> Logger:
|
||||||
|
"""
|
||||||
|
Creates a logger instance from a validated Pydantic config object.
|
||||||
|
"""
|
||||||
|
logger_type = config.logger_type
|
||||||
|
|
||||||
|
if logger_type not in LOGGER_FACTORY:
|
||||||
|
raise ValueError(f"Unknown logger type: {logger_type}")
|
||||||
|
|
||||||
|
creation_func = LOGGER_FACTORY[logger_type]
|
||||||
|
|
||||||
|
return creation_func(config)
|
@ -23,6 +23,7 @@ from batdetect2.train.dataset import (
|
|||||||
collate_fn,
|
collate_fn,
|
||||||
)
|
)
|
||||||
from batdetect2.train.lightning import TrainingModule
|
from batdetect2.train.lightning import TrainingModule
|
||||||
|
from batdetect2.train.logging import build_logger
|
||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -47,6 +48,7 @@ def train(
|
|||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
config = config or TrainingConfig()
|
config = config or TrainingConfig()
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is None:
|
||||||
if preprocessor is None:
|
if preprocessor is None:
|
||||||
preprocessor = build_preprocessor()
|
preprocessor = build_preprocessor()
|
||||||
@ -81,9 +83,12 @@ def train(
|
|||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = build_logger(config.logger)
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
**config.trainer.model_dump(exclude_none=True),
|
**config.trainer.model_dump(exclude_none=True, exclude={"logger"}),
|
||||||
callbacks=callbacks,
|
callbacks=callbacks,
|
||||||
|
logger=logger,
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user