mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Compare commits
No commits in common. "9816985bb17ddff3cfd082a0f2dfc4f92b8023b8" and "839a632aa261c237faabd73de5cc4184e61dfe71" have entirely different histories.
9816985bb1
...
839a632aa2
@ -22,8 +22,8 @@ __all__ = ["train_command"]
|
|||||||
@click.option("--model-path", type=click.Path(exists=True))
|
@click.option("--model-path", type=click.Path(exists=True))
|
||||||
@click.option("--config", type=click.Path(exists=True))
|
@click.option("--config", type=click.Path(exists=True))
|
||||||
@click.option("--config-field", type=str)
|
@click.option("--config-field", type=str)
|
||||||
@click.option("--train-workers", type=int)
|
@click.option("--train-workers", type=int, default=0)
|
||||||
@click.option("--val-workers", type=int)
|
@click.option("--val-workers", type=int, default=0)
|
||||||
@click.option(
|
@click.option(
|
||||||
"-v",
|
"-v",
|
||||||
"--verbose",
|
"--verbose",
|
||||||
|
|||||||
@ -45,29 +45,10 @@ class PLTrainerConfig(BaseConfig):
|
|||||||
val_check_interval: Optional[Union[int, float]] = None
|
val_check_interval: Optional[Union[int, float]] = None
|
||||||
|
|
||||||
|
|
||||||
class DataLoaderConfig(BaseConfig):
|
class TrainingConfig(PLTrainerConfig):
|
||||||
batch_size: int
|
batch_size: int = 8
|
||||||
shuffle: bool
|
|
||||||
num_workers: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TRAIN_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=True)
|
|
||||||
DEFAULT_VAL_LOADER_CONFIG = DataLoaderConfig(batch_size=8, shuffle=False)
|
|
||||||
|
|
||||||
|
|
||||||
class LoadersConfig(BaseConfig):
|
|
||||||
train: DataLoaderConfig = Field(
|
|
||||||
default_factory=lambda: DEFAULT_TRAIN_LOADER_CONFIG.model_copy()
|
|
||||||
)
|
|
||||||
val: DataLoaderConfig = Field(
|
|
||||||
default_factory=lambda: DEFAULT_VAL_LOADER_CONFIG.model_copy()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingConfig(BaseConfig):
|
|
||||||
learning_rate: float = 1e-3
|
learning_rate: float = 1e-3
|
||||||
t_max: int = 100
|
t_max: int = 100
|
||||||
dataloaders: LoadersConfig = Field(default_factory=LoadersConfig)
|
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
augmentations: Optional[AugmentationsConfig] = Field(
|
augmentations: Optional[AugmentationsConfig] = Field(
|
||||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||||
|
|||||||
@ -93,6 +93,7 @@ def create_tensorboard_logger(config: TensorBoardLoggerConfig) -> Logger:
|
|||||||
name=config.name,
|
name=config.name,
|
||||||
version=config.version,
|
version=config.version,
|
||||||
log_graph=config.log_graph,
|
log_graph=config.log_graph,
|
||||||
|
flush_logs_every_n_steps=config.flush_logs_every_n_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import yaml
|
||||||
from lightning import Trainer
|
from lightning import Trainer
|
||||||
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
from lightning.pytorch.callbacks import Callback, ModelCheckpoint
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
@ -19,7 +20,11 @@ from batdetect2.targets import TargetProtocol
|
|||||||
from batdetect2.train.augmentations import build_augmentations
|
from batdetect2.train.augmentations import build_augmentations
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.clips import build_clipper
|
from batdetect2.train.clips import build_clipper
|
||||||
from batdetect2.train.config import FullTrainingConfig, TrainingConfig
|
from batdetect2.train.config import (
|
||||||
|
FullTrainingConfig,
|
||||||
|
PLTrainerConfig,
|
||||||
|
TrainingConfig,
|
||||||
|
)
|
||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
LabeledDataset,
|
LabeledDataset,
|
||||||
RandomExampleSource,
|
RandomExampleSource,
|
||||||
@ -43,8 +48,8 @@ def train(
|
|||||||
val_examples: Optional[Sequence[data.PathLike]] = None,
|
val_examples: Optional[Sequence[data.PathLike]] = None,
|
||||||
config: Optional[FullTrainingConfig] = None,
|
config: Optional[FullTrainingConfig] = None,
|
||||||
model_path: Optional[data.PathLike] = None,
|
model_path: Optional[data.PathLike] = None,
|
||||||
train_workers: Optional[int] = None,
|
train_workers: int = 0,
|
||||||
val_workers: Optional[int] = None,
|
val_workers: int = 0,
|
||||||
):
|
):
|
||||||
conf = config or FullTrainingConfig()
|
conf = config or FullTrainingConfig()
|
||||||
|
|
||||||
@ -105,13 +110,16 @@ def build_trainer(
|
|||||||
conf: FullTrainingConfig,
|
conf: FullTrainingConfig,
|
||||||
targets: TargetProtocol,
|
targets: TargetProtocol,
|
||||||
) -> Trainer:
|
) -> Trainer:
|
||||||
trainer_conf = conf.train.trainer
|
trainer_conf = PLTrainerConfig.model_validate(
|
||||||
|
conf.train.model_dump(mode="python")
|
||||||
|
)
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Building trainer with config: \n{config}",
|
"Building trainer with config: \n{config}",
|
||||||
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
config=lambda: trainer_conf.to_yaml_string(exclude_none=True),
|
||||||
)
|
)
|
||||||
return Trainer(
|
return Trainer(
|
||||||
**trainer_conf.model_dump(exclude_none=True),
|
**trainer_conf.model_dump(exclude_none=True),
|
||||||
|
val_check_interval=conf.train.val_check_interval,
|
||||||
logger=build_logger(conf.train.logger),
|
logger=build_logger(conf.train.logger),
|
||||||
callbacks=build_trainer_callbacks(targets),
|
callbacks=build_trainer_callbacks(targets),
|
||||||
)
|
)
|
||||||
@ -129,17 +137,22 @@ def build_train_loader(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
loader_conf = config.dataloaders.train
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Training data loader config: \n{config}",
|
"Training data loader config: \n{}",
|
||||||
config=loader_conf.to_yaml_string(exclude_none=True),
|
lambda: yaml.dump(
|
||||||
|
{
|
||||||
|
"batch_size": config.batch_size,
|
||||||
|
"shuffle": True,
|
||||||
|
"num_workers": num_workers or 0,
|
||||||
|
"collate_fn": str(collate_fn),
|
||||||
|
}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
num_workers = num_workers or loader_conf.num_workers
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=loader_conf.batch_size,
|
batch_size=config.batch_size,
|
||||||
shuffle=loader_conf.shuffle,
|
shuffle=True,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers or 0,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -154,17 +167,22 @@ def build_val_loader(
|
|||||||
val_examples,
|
val_examples,
|
||||||
config=config,
|
config=config,
|
||||||
)
|
)
|
||||||
loader_conf = config.dataloaders.val
|
|
||||||
logger.opt(lazy=True).debug(
|
logger.opt(lazy=True).debug(
|
||||||
"Validation data loader config: \n{config}",
|
"Validation data loader config: \n{}",
|
||||||
config=loader_conf.to_yaml_string(exclude_none=True),
|
lambda: yaml.dump(
|
||||||
|
{
|
||||||
|
"batch_size": config.batch_size,
|
||||||
|
"shuffle": False,
|
||||||
|
"num_workers": num_workers or 0,
|
||||||
|
"collate_fn": str(collate_fn),
|
||||||
|
}
|
||||||
|
),
|
||||||
)
|
)
|
||||||
num_workers = num_workers or loader_conf.num_workers
|
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=loader_conf.batch_size,
|
batch_size=config.batch_size,
|
||||||
shuffle=loader_conf.shuffle,
|
shuffle=False,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers or 0,
|
||||||
collate_fn=collate_fn,
|
collate_fn=collate_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user