mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Compare commits
2 Commits
839a632aa2
...
9816985bb1
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9816985bb1 | ||
|
|
62f0c5c397 |
@ -22,8 +22,8 @@ __all__ = ["train_command"]
|
||||
@click.option("--model-path", type=click.Path(exists=True))
|
||||
@click.option("--config", type=click.Path(exists=True))
|
||||
@click.option("--config-field", type=str)
|
||||
@click.option("--train-workers", type=int, default=0)
|
||||
@click.option("--val-workers", type=int, default=0)
|
||||
@click.option("--train-workers", type=int)
|
||||
@click.option("--val-workers", type=int)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
|
||||
@ -45,10 +45,29 @@ class PLTrainerConfig(BaseConfig):
|
||||
val_check_interval: Optional[Union[int, float]] = None
|
||||
|
||||
|
||||
class TrainingConfig(PLTrainerConfig):
|
||||
batch_size: int = 8
|
||||
class DataLoaderConfig(BaseConfig):
|
||||
batch_size: int
|
||||
shuffle: bool
|
||||
num_workers: int = 0
|
||||
|
||||
|
||||
DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
|
||||
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False)
|
||||
|
||||
|
||||
class LoadersConfig(BaseConfig):
|
||||
train: DataLoaderConfig = Field(
|
||||
default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
|
||||
)
|
||||
val: DataLoaderConfig = Field(
|
||||
default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
|
||||
)
|
||||
|
||||
|
||||
class TrainingConfig(BaseConfig):
|
||||
learning_rate: float = 1e-3
|
||||
t_max: int = 100
|
||||
dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
augmentations: Optional[AugmentationsConfig] = Field(
|
||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||
|
||||
@ -93,7 +93,6 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
|
||||
name=config.name,
|
||||
version=config.version,
|
||||
log_graph=config.log_graph,
|
||||
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import List, Optional
|
||||
|
||||
import yaml
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||
from loguru import logger
|
||||
@ -20,11 +19,7 @@ from batdetect2.targets import TargetProtocol
|
||||
from batdetect2.train.augmentations import build_augmentations
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.config import (
|
||||
FullTrainingConfig,
|
||||
PLTrainerConfig,
|
||||
TrainingConfig,
|
||||
)
|
||||
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
|
||||
from batdetect2.train.dataset import (
|
||||
LabeledDataset,
|
||||
RandomExampleSource,
|
||||
@ -48,8 +43,8 @@ def train(
|
||||
val_examples: Optional[Sequence[data.PathLike]] = None,
|
||||
config: Optional[FullTrainingConfig] = None,
|
||||
model_path: Optional[data.PathLike] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
train_workers: Optional[int] = None,
|
||||
val_workers: Optional[int] = None,
|
||||
):
|
||||
conf = config or FullTrainingConfig()
|
||||
|
||||
@ -110,16 +105,13 @@ def build_trainer(
|
||||
conf: FullTrainingConfig,
|
||||
targets: TargetProtocol,
|
||||
) -> Trainer:
|
||||
trainer_conf = PLTrainerConfig.model_validate(
|
||||
conf.train.model_dump(mode="python")
|
||||
)
|
||||
trainer_conf = conf.train.trainer
|
||||
logger.opt(lazy=True).debug(
|
||||
"Building trainer with config: \n{config}",
|
||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
return Trainer(
|
||||
**trainer_conf.model_dump(exclude_none=True),
|
||||
val_check_interval=conf.train.val_check_interval,
|
||||
logger=build_logger(conf.train.logger),
|
||||
callbacks=build_trainer_callbacks(targets),
|
||||
)
|
||||
@ -137,22 +129,17 @@ def build_train_loader(
|
||||
preprocessor=preprocessor,
|
||||
config=config,
|
||||
)
|
||||
loader_conf = config.dataloaders.train
|
||||
logger.opt(lazy=True).debug(
|
||||
"Training data loader config: \n{}",
|
||||
lambda: yaml.dump(
|
||||
{
|
||||
"batch_size": config.batch_size,
|
||||
"shuffle": True,
|
||||
"num_workers": num_workers or 0,
|
||||
"collate_fn": str(collate_fn),
|
||||
}
|
||||
),
|
||||
"Training data loader config: \n{config}",
|
||||
config=loader_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
num_workers = num_workers or loader_conf.num_workers
|
||||
return DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=num_workers or 0,
|
||||
batch_size=loader_conf.batch_size,
|
||||
shuffle=loader_conf.shuffle,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
@ -167,22 +154,17 @@ def build_val_loader(
|
||||
val_examples,
|
||||
config=config,
|
||||
)
|
||||
loader_conf = config.dataloaders.val
|
||||
logger.opt(lazy=True).debug(
|
||||
"Validation data loader config: \n{}",
|
||||
lambda: yaml.dump(
|
||||
{
|
||||
"batch_size": config.batch_size,
|
||||
"shuffle": False,
|
||||
"num_workers": num_workers or 0,
|
||||
"collate_fn": str(collate_fn),
|
||||
}
|
||||
),
|
||||
"Validation data loader config: \n{config}",
|
||||
config=loader_conf.to_yaml_string(exclude_none=True),
|
||||
)
|
||||
num_workers = num_workers or loader_conf.num_workers
|
||||
return DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=num_workers or 0,
|
||||
batch_size=loader_conf.batch_size,
|
||||
shuffle=loader_conf.shuffle,
|
||||
num_workers=num_workers,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user