Compare commits


No commits in common. "9816985bb17ddff3cfd082a0f2dfc4f92b8023b8" and "839a632aa261c237faabd73de5cc4184e61dfe71" have entirely different histories.

4 changed files with 41 additions and 41 deletions

View File

@@ -22,8 +22,8 @@ __all__ = ["train_command"]
 @click.option("--model-path", type=click.Path(exists=True))
 @click.option("--config", type=click.Path(exists=True))
 @click.option("--config-field", type=str)
-@click.option("--train-workers", type=int)
-@click.option("--val-workers", type=int)
+@click.option("--train-workers", type=int, default=0)
+@click.option("--val-workers", type=int, default=0)
 @click.option(
     "-v",
     "--verbose",

View File

@@ -45,29 +45,10 @@ class PLTrainerConfig(BaseConfig):
     val_check_interval: Optional[Union[int, float]] = None
 
 
-class DataLoaderConfig(BaseConfig):
-    batch_size: int
-    shuffle: bool
-    num_workers: int = 0
-
-
-DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
-DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False)
-
-
-class LoadersConfig(BaseConfig):
-    train: DataLoaderConfig = Field(
-        default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
-    )
-    val: DataLoaderConfig = Field(
-        default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
-    )
-
-
-class TrainingConfig(BaseConfig):
+class TrainingConfig(PLTrainerConfig):
+    batch_size: int = 8
     learning_rate: float = 1e-3
     t_max: int = 100
-    dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
     loss: LossConfig = Field(default_factory=LossConfig)
     augmentations: Optional[AugmentationsConfig] = Field(
         default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
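
The nested DataLoaderConfig/LoadersConfig tree is gone: batch_size moves onto TrainingConfig itself, shuffle is hardcoded per loader, and worker counts become function arguments. A rough before/after sketch of the access path, assuming BaseConfig behaves like a plain pydantic v2 model:

from typing import Optional, Union

from pydantic import BaseModel

class PLTrainerConfig(BaseModel):
    val_check_interval: Optional[Union[int, float]] = None

class TrainingConfig(PLTrainerConfig):
    batch_size: int = 8
    learning_rate: float = 1e-3
    t_max: int = 100

conf = TrainingConfig(batch_size=16)
# Old access path: conf.dataloaders.train.batch_size
# New access path:
print(conf.batch_size)  # 16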

View File

@@ -93,6 +93,7 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
         name=config.name,
         version=config.version,
         log_graph=config.log_graph,
+        flush_logs_every_n_steps=config.flush_logs_every_n_steps,
     )

View File

@@ -1,6 +1,7 @@
 from collections.abc import Sequence
 from typing import List, Optional
 
+import yaml
 from lightning import Trainer
 from lightning.pytorch.callbacks import Callback, ModelCheckpoint
 from loguru import logger
@@ -19,7 +20,11 @@ from batdetect2.targets import TargetProtocol
 from batdetect2.train.augmentations import build_augmentations
 from batdetect2.train.callbacks import ValidationMetrics
 from batdetect2.train.clips import build_clipper
-from batdetect2.train.config import FullTrainingConfig, TrainingConfig
+from batdetect2.train.config import (
+    FullTrainingConfig,
+    PLTrainerConfig,
+    TrainingConfig,
+)
 from batdetect2.train.dataset import (
     LabeledDataset,
     RandomExampleSource,
@@ -43,8 +48,8 @@ def train(
     val_examples: Optional[Sequence[data.PathLike]] = None,
     config: Optional[FullTrainingConfig] = None,
     model_path: Optional[data.PathLike] = None,
-    train_workers: Optional[int] = None,
-    val_workers: Optional[int] = None,
+    train_workers: int = 0,
+    val_workers: int = 0,
 ):
     conf = config or FullTrainingConfig()
@@ -105,13 +110,16 @@ def build_trainer(
     conf: FullTrainingConfig,
     targets: TargetProtocol,
 ) -> Trainer:
-    trainer_conf = conf.train.trainer
+    trainer_conf = PLTrainerConfig.model_validate(
+        conf.train.model_dump(mode="python")
+    )
     logger.opt(lazy=True).debug(
         "Building trainer with config: \n{config}",
         config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
     )
     return Trainer(
         **trainer_conf.model_dump(exclude_none=True),
+        val_check_interval=conf.train.val_check_interval,
         logger=build_logger(conf.train.logger),
         callbacks=build_trainer_callbacks(targets),
     )
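
Since TrainingConfig now subclasses PLTrainerConfig, dumping the full training config and re-validating it as PLTrainerConfig projects out just the Trainer-relevant fields; pydantic v2 models ignore undeclared keys by default (extra="ignore"). A self-contained sketch of that projection (field names here are illustrative, not the project's):

from pydantic import BaseModel

class PLTrainerConfig(BaseModel):
    max_epochs: int = 100  # stand-in for a Trainer field

class TrainingConfig(PLTrainerConfig):
    learning_rate: float = 1e-3  # training-only field, not a Trainer kwarg

train_conf = TrainingConfig(max_epochs=5)
trainer_conf = PLTrainerConfig.model_validate(
    train_conf.model_dump(mode="python")
)
# learning_rate was silently dropped; only trainer fields survive.
print(trainer_conf.model_dump())  # {'max_epochs': 5}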
@@ -129,17 +137,22 @@ def build_train_loader(
         preprocessor=preprocessor,
         config=config,
     )
 
-    loader_conf = config.dataloaders.train
     logger.opt(lazy=True).debug(
-        "Training data loader config: \n{config}",
-        config=loader_conf.to_yaml_string(exclude_none=True),
+        "Training data loader config: \n{}",
+        lambda: yaml.dump(
+            {
+                "batch_size": config.batch_size,
+                "shuffle": True,
+                "num_workers": num_workers or 0,
+                "collate_fn": str(collate_fn),
+            }
+        ),
     )
-    num_workers = num_workers or loader_conf.num_workers
     return DataLoader(
         train_dataset,
-        batch_size=loader_conf.batch_size,
-        shuffle=loader_conf.shuffle,
-        num_workers=num_workers,
+        batch_size=config.batch_size,
+        shuffle=True,
+        num_workers=num_workers or 0,
         collate_fn=collate_fn,
     )
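
The lambda passed to logger.opt(lazy=True) matters here: loguru only calls lazy arguments if a DEBUG sink is actually enabled, so the yaml.dump of the loader settings costs nothing in normal runs. A standalone demonstration (expensive_summary is made up for illustration):

import sys

from loguru import logger

def expensive_summary() -> str:
    print("computing...")  # only visible if the lambda actually runs
    return "summary"

logger.remove()
logger.add(sys.stderr, level="INFO")  # DEBUG messages are filtered out

# At INFO level the lambda is never invoked, so the work is skipped.
logger.opt(lazy=True).debug("state: {}", lambda: expensive_summary())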
@@ -154,17 +167,22 @@ def build_val_loader(
         val_examples,
         config=config,
     )
 
-    loader_conf = config.dataloaders.val
     logger.opt(lazy=True).debug(
-        "Validation data loader config: \n{config}",
-        config=loader_conf.to_yaml_string(exclude_none=True),
+        "Validation data loader config: \n{}",
+        lambda: yaml.dump(
+            {
+                "batch_size": config.batch_size,
+                "shuffle": False,
+                "num_workers": num_workers or 0,
+                "collate_fn": str(collate_fn),
+            }
+        ),
     )
-    num_workers = num_workers or loader_conf.num_workers
     return DataLoader(
         val_dataset,
-        batch_size=loader_conf.batch_size,
-        shuffle=loader_conf.shuffle,
-        num_workers=num_workers,
+        batch_size=config.batch_size,
+        shuffle=False,
+        num_workers=num_workers or 0,
         collate_fn=collate_fn,
     )
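
Both loaders now fall back to num_workers or 0, matching the new CLI default; with num_workers=0 PyTorch loads batches synchronously in the main process instead of spawning worker processes. A tiny standalone check of that behaviour, using only torch:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8).float())
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=0)
for (batch,) in loader:
    print(batch)  # tensor([0., 1., 2., 3.]), then tensor([4., 5., 6., 7.])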