Compare commits


2 Commits

Author      SHA1        Message                                    Date
mbsantiago  9816985bb1  Create dataloader config                   2025-07-29 12:10:51 +01:00
mbsantiago  62f0c5c397  Remove stale tensorboard logger argument   2025-07-29 11:47:07 +01:00
4 changed files with 41 additions and 41 deletions
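
Together, these commits extract the data loader settings (batch size, shuffle, worker count) out of TrainingConfig into dedicated DataLoaderConfig and LoadersConfig models, let the CLI and train() worker arguments fall back to that config when unset, and drop a stale flush_logs_every_n_steps argument from the TensorBoard logger setup.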

File 1 of 4: CLI options for the train command

@@ -22,8 +22,8 @@ __all__ = ["train_command"]
 @click.option("--model-path", type=click.Path(exists=True))
 @click.option("--config", type=click.Path(exists=True))
 @click.option("--config-field", type=str)
-@click.option("--train-workers", type=int, default=0)
-@click.option("--val-workers", type=int, default=0)
+@click.option("--train-workers", type=int)
+@click.option("--val-workers", type=int)
 @click.option(
     "-v",
     "--verbose",

File 2 of 4: training configuration (batdetect2.train.config)

@@ -45,10 +45,29 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None


-class TrainingConfig(PLTrainerConfig):
-    batch_size: int = 8
+class DataLoaderConfig(BaseConfig):
+    batch_size: int
+    shuffle: bool
+    num_workers: int = 0
+
+
+DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
+DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False)
+
+
+class LoadersConfig(BaseConfig):
+    train: DataLoaderConfig = Field(
+        default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
+    )
+    val: DataLoaderConfig = Field(
+        default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
+    )
+
+
+class TrainingConfig(BaseConfig):
     learning_rate: float = 1e-3
     t_max: int = 100
+    dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
     augmentations: Optional[AugmentationsConfig] = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
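
The default factories call model_copy() so that every LoadersConfig instance gets its own copy of the shared module-level defaults instead of aliasing them. A minimal usage sketch of the new shape, assuming BaseConfig is a pydantic-style model and that DataLoaderConfig and LoadersConfig are exported from batdetect2.train.config alongside TrainingConfig (only the TrainingConfig import is confirmed by the last file in this compare; the other two exports are an assumption):

    # Hypothetical usage sketch; the DataLoaderConfig and LoadersConfig
    # exports are assumed, not confirmed by this compare view.
    from batdetect2.train.config import (
        DataLoaderConfig,
        LoadersConfig,
        TrainingConfig,
    )

    # Override only the training loader; the validation loader keeps the
    # DEFAULT_VAL_LOADER_CONFIG defaults (batch_size=8, shuffle=False).
    config = TrainingConfig(
        dataloaders=LoadersConfig(
            train=DataLoaderConfig(batch_size=16, shuffle=True, num_workers=4),
        )
    )
    assert config.dataloaders.val.batch_size == 8
    assert config.dataloaders.train.num_workers == 4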

File 3 of 4: TensorBoard logger construction

@@ -93,7 +93,6 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
         name=config.name,
         version=config.version,
         log_graph=config.log_graph,
-        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
     )
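
This is the change from commit 62f0c5c397. The keyword appears to be stale because Lightning's TensorBoardLogger does not take a flush_logs_every_n_steps parameter (log flushing was historically a Trainer-level setting), so forwarding it from the config would fail at runtime.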

File 4 of 4: training entry point and loader builders

@@ -1,7 +1,6 @@
 from collections.abc import Sequence
 from typing import List, Optional

-import yaml
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
@@ -20,11 +19,7 @@ from batdetect2.targets import TargetProtocol
 from batdetect2.train.augmentations import build_augmentations
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import (
-    FullTrainingConfig,
-    PLTrainerConfig,
-    TrainingConfig,
-)
+from batdetect2.train.config import FullTrainingConfig, TrainingConfig
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
@@ -48,8 +43,8 @@ def train(
     val_examples: Optional[Sequence[data.PathLike]] = None,
     config: Optional[FullTrainingConfig] = None,
     model_path: Optional[data.PathLike] = None,
-    train_workers: int = 0,
-    val_workers: int = 0,
+    train_workers: Optional[int] = None,
+    val_workers: Optional[int] = None,
 ):
     conf = config or FullTrainingConfig()
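
train_workers and val_workers mirror the CLI change: a value of None now means "use the config". The fallback applied in the loader builders below is num_workers or loader_conf.num_workers, so an explicit 0 is also treated as unset. A plain-Python sketch of that truthiness behaviour (resolve_workers is a hypothetical name, not part of the diff):

    def resolve_workers(explicit, config_value):
        # Mirrors `num_workers or loader_conf.num_workers` below.
        return explicit or config_value

    assert resolve_workers(None, 4) == 4  # omitted: config value wins
    assert resolve_workers(2, 4) == 2     # explicit positive value wins
    assert resolve_workers(0, 4) == 4     # explicit 0 also falls back to config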
@@ -110,16 +105,13 @@ def build_trainer(
     conf: FullTrainingConfig,
     targets: TargetProtocol,
 ) -> Trainer:
-    trainer_conf = PLTrainerConfig.model_validate(
-        conf.train.model_dump(mode="python")
-    )
+    trainer_conf = conf.train.trainer
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
-        val_check_interval=conf.train.val_check_interval,
         logger=build_logger(conf.train.logger),
         callbacks=build_trainer_callbacks(targets),
     )
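
build_trainer now reads the Lightning settings from a nested conf.train.trainer object instead of re-validating the whole TrainingConfig as a PLTrainerConfig. This implies TrainingConfig also gained a trainer: PLTrainerConfig field, which is not visible in the config hunk above. Dropping the explicit val_check_interval keyword also removes a latent duplicate-argument error: val_check_interval is itself a PLTrainerConfig field, so whenever it was set it would have been passed both through the **model_dump(...) unpacking and as the explicit keyword.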
@@ -137,22 +129,17 @@ def build_train_loader(
         preprocessor=preprocessor,
         config=config,
     )
+    loader_conf = config.dataloaders.train
     logger.opt(lazy=True).debug(
-        "Training data loader config: \n{}",
-        lambda: yaml.dump(
-            {
-                "batch_size": config.batch_size,
-                "shuffle": True,
-                "num_workers": num_workers or 0,
-                "collate_fn": str(collate_fn),
-            }
-        ),
+        "Training data loader config: \n{config}",
+        config=loader_conf.to_yaml_string(exclude_none=True),
     )
+    num_workers = num_workers or loader_conf.num_workers
     return DataLoader(
         train_dataset,
-        batch_size=config.batch_size,
-        shuffle=True,
-        num_workers=num_workers or 0,
+        batch_size=loader_conf.batch_size,
+        shuffle=loader_conf.shuffle,
+        num_workers=num_workers,
         collate_fn=collate_fn,
     )
@@ -167,22 +154,17 @@ def build_val_loader(
         val_examples,
         config=config,
     )
+    loader_conf = config.dataloaders.val
     logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{}",
-        lambda: yaml.dump(
-            {
-                "batch_size": config.batch_size,
-                "shuffle": False,
-                "num_workers": num_workers or 0,
-                "collate_fn": str(collate_fn),
-            }
-        ),
+        "Validation data loader config: \n{config}",
+        config=loader_conf.to_yaml_string(exclude_none=True),
     )
+    num_workers = num_workers or loader_conf.num_workers
     return DataLoader(
         val_dataset,
-        batch_size=config.batch_size,
-        shuffle=False,
-        num_workers=num_workers or 0,
+        batch_size=loader_conf.batch_size,
+        shuffle=loader_conf.shuffle,
+        num_workers=num_workers,
         collate_fn=collate_fn,
     )
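
One wrinkle in both loader builders: logger.opt(lazy=True) makes loguru call each format argument lazily, which is why build_trainer wraps its value in a lambda. Here config= receives the already-rendered string from to_yaml_string(), so the YAML is computed eagerly and, being non-callable, would trip loguru's lazy evaluation if DEBUG logging is enabled; wrapping the call in a lambda, as build_trainer does, would restore the intended behaviour.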