Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)

Compare commits: cd4955d4f3...115084fd2b (3 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 115084fd2b |  |
|  | 951dc59718 |  |
|  | 3376be06a4 |  |
Training configuration:

```diff
@@ -109,7 +109,7 @@ train:
     sigma: 3
 
   trainer:
-    max_epochs: 40
+    max_epochs: 5
 
   dataloaders:
     train:
@@ -136,9 +136,9 @@ train:
       weight: 0.1
 
   logger:
-    logger_type: csv
-    # save_dir: outputs/log/
-    # name: logs
+    name: mlflow
+    tracking_uri: http://10.20.20.211:9000
+    log_model: true
 
   augmentations:
     enabled: true
```
Project dependencies:

```diff
@@ -23,7 +23,7 @@ dependencies = [
     "tqdm>=4.66.2",
     "cf-xarray>=0.9.0",
     "onnx>=1.16.0",
-    "lightning[extra]>=2.2.2",
+    "lightning[extra]==2.5.0",
     "tensorboard>=2.16.2",
     "omegaconf>=2.3.0",
     "pyyaml>=6.0.2",
```
Training CLI (`train_command`):

```diff
@@ -26,6 +26,9 @@ __all__ = ["train_command"]
 @click.option("--config-field", type=str)
 @click.option("--train-workers", type=int)
 @click.option("--val-workers", type=int)
+@click.option("--experiment-name", type=str)
+@click.option("--run-name", type=str)
+@click.option("--seed", type=int)
 @click.option(
     "-v",
     "--verbose",
@@ -40,8 +43,11 @@ def train_command(
     log_dir: Optional[Path] = None,
     config: Optional[Path] = None,
     config_field: Optional[str] = None,
+    seed: Optional[int] = None,
     train_workers: int = 0,
     val_workers: int = 0,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
     verbose: int = 0,
 ):
     logger.remove()
@@ -87,6 +93,9 @@ def train_command(
         model_path=model_path,
         train_workers=train_workers,
         val_workers=val_workers,
+        experiment_name=experiment_name,
         log_dir=log_dir,
         checkpoint_dir=ckpt_dir,
+        seed=seed,
+        run_name=run_name,
     )
```
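The three new options follow click's usual mapping from `--kebab-case` flags to snake_case parameters and default to `None` when omitted, so downstream code can fall back to values from the config. A minimal, self-contained sketch of that wiring (the command and option names below are illustrative, not the project's actual CLI module):

```python
# Minimal sketch of the option-to-parameter wiring used above.
# Names are illustrative only, not the project's CLI module.
from typing import Optional

import click


@click.command()
@click.option("--experiment-name", type=str)
@click.option("--run-name", type=str)
@click.option("--seed", type=int)
def train_command(
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
    seed: Optional[int] = None,
):
    # click passes None for omitted flags, so the caller can fall back
    # to values taken from the logger configuration.
    click.echo(f"experiment={experiment_name} run={run_name} seed={seed}")


if __name__ == "__main__":
    train_command()
```

Run as, for example, `python sketch.py --experiment-name bats --run-name baseline --seed 42`.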
Validation metrics callback:

```diff
@@ -68,7 +68,7 @@ class ValidationMetrics(Callback):
             n_examples=4,
         ):
             plotter(
-                f"images/{class_name}_examples",
+                f"examples/{class_name}",
                 fig,
                 pl_module.global_step,
             )
```
Logger configuration and factory module:

```diff
@@ -1,5 +1,16 @@
 import io
-from typing import Annotated, Any, Literal, Optional, Union
+from pathlib import Path
+from typing import (
+    Annotated,
+    Any,
+    Dict,
+    Generic,
+    Literal,
+    Optional,
+    Protocol,
+    TypeVar,
+    Union,
+)
 
 import numpy as np
 from lightning.pytorch.loggers import Logger, MLFlowLogger, TensorBoardLogger
@@ -9,39 +20,34 @@ from soundevent import data
 
 from batdetect2.configs import BaseConfig
 
-DEFAULT_LOGS_DIR: str = "outputs"
+DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
 
 
-class DVCLiveConfig(BaseConfig):
-    logger_type: Literal["dvclive"] = "dvclive"
-    dir: str = DEFAULT_LOGS_DIR
+class BaseLoggerConfig(BaseConfig):
+    log_dir: Path = DEFAULT_LOGS_DIR
+    experiment_name: Optional[str] = None
     run_name: Optional[str] = None
+
+
+class DVCLiveConfig(BaseLoggerConfig):
+    name: Literal["dvclive"] = "dvclive"
     prefix: str = ""
     log_model: Union[bool, Literal["all"]] = False
     monitor_system: bool = False
 
 
-class CSVLoggerConfig(BaseConfig):
-    logger_type: Literal["csv"] = "csv"
-    save_dir: str = DEFAULT_LOGS_DIR
-    name: Optional[str] = "logs"
-    version: Optional[str] = None
+class CSVLoggerConfig(BaseLoggerConfig):
+    name: Literal["csv"] = "csv"
     flush_logs_every_n_steps: int = 100
 
 
-class TensorBoardLoggerConfig(BaseConfig):
-    logger_type: Literal["tensorboard"] = "tensorboard"
-    save_dir: str = DEFAULT_LOGS_DIR
-    name: Optional[str] = "logs"
-    version: Optional[str] = None
+class TensorBoardLoggerConfig(BaseLoggerConfig):
+    name: Literal["tensorboard"] = "tensorboard"
     log_graph: bool = False
 
 
-class MLFlowLoggerConfig(BaseConfig):
-    logger_type: Literal["mlflow"] = "mlflow"
-    experiment_name: str = "default"
-    run_name: Optional[str] = None
-    save_dir: Optional[str] = "./mlruns"
+class MLFlowLoggerConfig(BaseLoggerConfig):
+    name: Literal["mlflow"] = "mlflow"
     tracking_uri: Optional[str] = None
     tags: Optional[dict[str, Any]] = None
     log_model: bool = False
@@ -54,13 +60,28 @@ LoggerConfig = Annotated[
         TensorBoardLoggerConfig,
         MLFlowLoggerConfig,
     ],
-    Field(discriminator="logger_type"),
+    Field(discriminator="name"),
 ]
 
 
+T = TypeVar("T", bound=LoggerConfig, contravariant=True)
+
+
+class LoggerBuilder(Protocol, Generic[T]):
+    def __call__(
+        self,
+        config: T,
+        log_dir: Optional[Path] = None,
+        experiment_name: Optional[str] = None,
+        run_name: Optional[str] = None,
+    ) -> Logger: ...
+
+
 def create_dvclive_logger(
     config: DVCLiveConfig,
-    log_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     try:
         from dvclive.lightning import DVCLiveLogger  # type: ignore
```
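Changing the discriminator from `logger_type` to `name` means pydantic now picks the concrete logger config class from the `name` key of the `logger:` section. A minimal sketch of that behaviour with stand-in models (pydantic v2 `TypeAdapter` assumed; these are not the project's real classes):

```python
# Sketch of how a tagged union like LoggerConfig resolves on the "name" field.
# Stand-in models only; pydantic v2 is assumed.
from typing import Annotated, Literal, Optional, Union

from pydantic import BaseModel, Field, TypeAdapter


class CSVLoggerConfig(BaseModel):
    name: Literal["csv"] = "csv"
    flush_logs_every_n_steps: int = 100


class MLFlowLoggerConfig(BaseModel):
    name: Literal["mlflow"] = "mlflow"
    tracking_uri: Optional[str] = None
    log_model: bool = False


LoggerConfig = Annotated[
    Union[CSVLoggerConfig, MLFlowLoggerConfig],
    Field(discriminator="name"),
]

# A dict like the YAML `logger:` block above selects the matching model
# from its "name" key alone.
config = TypeAdapter(LoggerConfig).validate_python(
    {"name": "mlflow", "tracking_uri": "http://10.20.20.211:9000", "log_model": True}
)
print(type(config).__name__)  # MLFlowLoggerConfig
```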
```diff
@@ -72,8 +93,11 @@ def create_dvclive_logger(
         ) from error
 
     return DVCLiveLogger(
-        dir=log_dir if log_dir is not None else config.dir,
-        run_name=config.run_name,
+        dir=log_dir if log_dir is not None else config.log_dir,
+        run_name=run_name if run_name is not None else config.run_name,
+        experiment=experiment_name
+        if experiment_name is not None
+        else config.experiment_name,
         prefix=config.prefix,
         log_model=config.log_model,
         monitor_system=config.monitor_system,
@@ -82,28 +106,58 @@ def create_dvclive_logger(
 
 def create_csv_logger(
     config: CSVLoggerConfig,
-    log_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import CSVLogger
 
+    if log_dir is None:
+        log_dir = Path(config.log_dir)
+
+    if run_name is None:
+        run_name = config.run_name
+
+    if experiment_name is None:
+        experiment_name = config.experiment_name
+
+    name = run_name
+
+    if run_name is not None and experiment_name is not None:
+        name = str(Path(experiment_name) / run_name)
+
     return CSVLogger(
-        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
-        name=config.name,
-        version=config.version,
+        save_dir=str(log_dir),
+        name=name,
         flush_logs_every_n_steps=config.flush_logs_every_n_steps,
     )
 
 
 def create_tensorboard_logger(
     config: TensorBoardLoggerConfig,
-    log_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     from lightning.pytorch.loggers import TensorBoardLogger
 
+    if log_dir is None:
+        log_dir = Path(config.log_dir)
+
+    if run_name is None:
+        run_name = config.run_name
+
+    if experiment_name is None:
+        experiment_name = config.experiment_name
+
+    name = run_name
+
+    if run_name is not None and experiment_name is not None:
+        name = str(Path(experiment_name) / run_name)
+
     return TensorBoardLogger(
-        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
-        name=config.name,
-        version=config.version,
+        save_dir=str(log_dir),
+        name=name,
         log_graph=config.log_graph,
     )
 
```
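For the CSV and TensorBoard loggers the experiment and run names are folded into Lightning's `name` argument, so runs nest under their experiment inside `log_dir`. A rough sketch of the resulting layout (the `version_0` subfolder follows Lightning's default versioning and is an assumption here, as are the example names):

```python
# Rough sketch of where CSV/TensorBoard output lands under the new naming scheme.
from pathlib import Path

log_dir = Path("outputs") / "logs"
experiment_name = "bats"
run_name = "baseline"

# Same nesting rule as in the diff above.
name = run_name
if run_name is not None and experiment_name is not None:
    name = str(Path(experiment_name) / run_name)

print(log_dir / name)                # outputs/logs/bats/baseline
print(log_dir / name / "version_0")  # e.g. where CSVLogger would write metrics.csv
```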
```diff
@@ -111,6 +165,8 @@ def create_tensorboard_logger(
 def create_mlflow_logger(
     config: MLFlowLoggerConfig,
     log_dir: Optional[data.PathLike] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     try:
         from lightning.pytorch.loggers import MLFlowLogger
@@ -121,17 +177,25 @@ def create_mlflow_logger(
             "or `uv add mlflow`"
         ) from error
 
+    if experiment_name is None:
+        experiment_name = config.experiment_name or "Default"
+
+    if log_dir is None:
+        log_dir = config.log_dir
+
     return MLFlowLogger(
-        experiment_name=config.experiment_name,
-        run_name=config.run_name,
-        save_dir=str(log_dir) if log_dir is not None else config.save_dir,
+        experiment_name=experiment_name
+        if experiment_name is not None
+        else config.experiment_name,
+        run_name=run_name if run_name is not None else config.run_name,
+        save_dir=str(log_dir),
         tracking_uri=config.tracking_uri,
         tags=config.tags,
         log_model=config.log_model,
     )
 
 
-LOGGER_FACTORY = {
+LOGGER_FACTORY: Dict[str, LoggerBuilder] = {
     "dvclive": create_dvclive_logger,
     "csv": create_csv_logger,
     "tensorboard": create_tensorboard_logger,
```
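Typing the factory as `Dict[str, LoggerBuilder]` relies on structural typing: any callable whose signature is compatible with `LoggerBuilder.__call__` can be registered, without a shared base class. A self-contained sketch of the idea (stand-in types, not the project's code):

```python
# Sketch of the structural typing behind LoggerBuilder: a plain function with a
# compatible signature satisfies the Protocol, so the factory dict type-checks.
from pathlib import Path
from typing import Dict, Optional, Protocol


class Logger:  # stand-in for lightning.pytorch.loggers.Logger
    pass


class LoggerBuilder(Protocol):
    def __call__(
        self,
        config: object,
        log_dir: Optional[Path] = None,
        experiment_name: Optional[str] = None,
        run_name: Optional[str] = None,
    ) -> Logger: ...


def create_csv_logger(
    config: object,
    log_dir: Optional[Path] = None,
    experiment_name: Optional[str] = None,
    run_name: Optional[str] = None,
) -> Logger:
    return Logger()


LOGGER_FACTORY: Dict[str, LoggerBuilder] = {"csv": create_csv_logger}
builder = LOGGER_FACTORY["csv"]
print(builder(config=None, run_name="baseline"))
```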
```diff
@@ -141,7 +205,9 @@ LOGGER_FACTORY = {
 
 def build_logger(
     config: LoggerConfig,
-    log_dir: Optional[data.PathLike] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
 ) -> Logger:
     """
     Creates a logger instance from a validated Pydantic config object.
@@ -150,14 +216,19 @@ def build_logger(
         "Building logger with config: \n{}",
         lambda: config.to_yaml_string(),
     )
-    logger_type = config.logger_type
+    logger_type = config.name
 
     if logger_type not in LOGGER_FACTORY:
         raise ValueError(f"Unknown logger type: {logger_type}")
 
     creation_func = LOGGER_FACTORY[logger_type]
 
-    return creation_func(config, log_dir=log_dir)
+    return creation_func(
+        config,
+        log_dir=log_dir,
+        experiment_name=experiment_name,
+        run_name=run_name,
+    )
 
 
 def get_image_plotter(logger: Logger):
@@ -173,8 +244,8 @@ def get_image_plotter(logger: Logger):
     def plot_figure(name, figure, step):
         image = _convert_figure_to_image(figure)
         return logger.experiment.log_image(
-            run_id=logger.run_id,
-            image=image,
+            logger.run_id,
+            image,
             key=name,
             step=step,
         )
```
Training module (`train`, `build_trainer_callbacks`, `build_trainer`):

```diff
@@ -1,8 +1,9 @@
 from collections.abc import Sequence
+from pathlib import Path
 from typing import List, Optional
 
 import torch
-from lightning import Trainer
+from lightning import Trainer, seed_everything
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
 from soundevent import data
@@ -45,6 +46,8 @@ __all__ = [
     "train",
 ]
 
+DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
+
 
 def train(
     train_annotations: Sequence[data.ClipAnnotation],
@@ -53,9 +56,15 @@ def train(
     model_path: Optional[data.PathLike] = None,
     train_workers: Optional[int] = None,
     val_workers: Optional[int] = None,
-    checkpoint_dir: Optional[data.PathLike] = None,
-    log_dir: Optional[data.PathLike] = None,
+    checkpoint_dir: Optional[Path] = None,
+    log_dir: Optional[Path] = None,
+    experiment_name: Optional[str] = None,
+    run_name: Optional[str] = None,
+    seed: Optional[int] = None,
 ):
+    if seed is not None:
+        seed_everything(seed)
+
     config = config or FullTrainingConfig()
 
     targets = build_targets(config.targets)
```
|
|||||||
targets=targets,
|
targets=targets,
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("Starting main training loop...")
|
logger.info("Starting main training loop...")
|
||||||
@ -134,17 +145,32 @@ def build_trainer_callbacks(
|
|||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: EvaluationConfig,
|
config: EvaluationConfig,
|
||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
) -> List[Callback]:
|
) -> List[Callback]:
|
||||||
if checkpoint_dir is None:
|
if checkpoint_dir is None:
|
||||||
checkpoint_dir = "outputs/checkpoints"
|
checkpoint_dir = DEFAULT_CHECKPOINT_DIR
|
||||||
|
|
||||||
|
filename = "best-{epoch:02d}-{val_loss:.0f}"
|
||||||
|
|
||||||
|
if run_name is not None:
|
||||||
|
filename = f"run_{run_name}_{filename}"
|
||||||
|
|
||||||
|
if experiment_name is not None:
|
||||||
|
filename = f"experiment_{experiment_name}_{filename}"
|
||||||
|
|
||||||
|
model_checkpoint = ModelCheckpoint(
|
||||||
|
dirpath=str(checkpoint_dir),
|
||||||
|
save_top_k=1,
|
||||||
|
filename=filename,
|
||||||
|
monitor="total_loss/val",
|
||||||
|
)
|
||||||
|
|
||||||
|
model_checkpoint.CHECKPOINT_EQUALS_CHAR = "_" # type: ignore
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ModelCheckpoint(
|
model_checkpoint,
|
||||||
dirpath=str(checkpoint_dir),
|
|
||||||
save_top_k=1,
|
|
||||||
monitor="total_loss/val",
|
|
||||||
),
|
|
||||||
ValidationMetrics(
|
ValidationMetrics(
|
||||||
metrics=[
|
metrics=[
|
||||||
DetectionAveragePrecision(),
|
DetectionAveragePrecision(),
|
||||||
@ -162,15 +188,22 @@ def build_trainer_callbacks(
|
|||||||
def build_trainer(
|
def build_trainer(
|
||||||
conf: FullTrainingConfig,
|
conf: FullTrainingConfig,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
checkpoint_dir: Optional[data.PathLike] = None,
|
checkpoint_dir: Optional[Path] = None,
|
||||||
log_dir: Optional[data.PathLike] = None,
|
log_dir: Optional[Path] = None,
|
||||||
|
experiment_name: Optional[str] = None,
|
||||||
|
run_name: Optional[str] = None,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = conf.train.trainer
|
trainer_conf = conf.train.trainer
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
train_logger = build_logger(conf.train.logger, log_dir=log_dir)
|
train_logger = build_logger(
|
||||||
|
conf.train.logger,
|
||||||
|
log_dir=log_dir,
|
||||||
|
experiment_name=experiment_name,
|
||||||
|
run_name=run_name,
|
||||||
|
)
|
||||||
|
|
||||||
train_logger.log_hyperparams(
|
train_logger.log_hyperparams(
|
||||||
conf.model_dump(
|
conf.model_dump(
|
||||||
@ -187,6 +220,7 @@ def build_trainer(
|
|||||||
config=conf.evaluation,
|
config=conf.evaluation,
|
||||||
preprocessor=build_preprocessor(conf.preprocess),
|
preprocessor=build_preprocessor(conf.preprocess),
|
||||||
checkpoint_dir=checkpoint_dir,
|
checkpoint_dir=checkpoint_dir,
|
||||||
|
experiment_name=train_logger.name,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||