mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Change train to use full config
This commit is contained in:
parent
6d91153a56
commit
587742b41e
@ -2,13 +2,13 @@ from batdetect2.cli.base import cli
|
|||||||
from batdetect2.cli.compat import detect
|
from batdetect2.cli.compat import detect
|
||||||
from batdetect2.cli.data import data
|
from batdetect2.cli.data import data
|
||||||
from batdetect2.cli.preprocess import preprocess
|
from batdetect2.cli.preprocess import preprocess
|
||||||
from batdetect2.cli.train import train
|
from batdetect2.cli.train import train_detector
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"cli",
|
"cli",
|
||||||
"detect",
|
"detect",
|
||||||
"data",
|
"data",
|
||||||
"train",
|
"train_detector",
|
||||||
"preprocess",
|
"preprocess",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -5,236 +5,53 @@ import click
|
|||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
from batdetect2.evaluate.metrics import (
|
from batdetect2.train import (
|
||||||
ClassificationAccuracy,
|
FullTrainingConfig,
|
||||||
ClassificationMeanAveragePrecision,
|
load_full_training_config,
|
||||||
DetectionAveragePrecision,
|
train,
|
||||||
)
|
)
|
||||||
from batdetect2.models import build_model
|
|
||||||
from batdetect2.models.backbones import load_backbone_config
|
|
||||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
|
||||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
|
||||||
from batdetect2.targets import build_targets, load_target_config
|
|
||||||
from batdetect2.train import train
|
|
||||||
from batdetect2.train.callbacks import ValidationMetrics
|
|
||||||
from batdetect2.train.config import TrainingConfig, load_train_config
|
|
||||||
from batdetect2.train.dataset import list_preprocessed_files
|
from batdetect2.train.dataset import list_preprocessed_files
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"train_command",
|
"train_command",
|
||||||
]
|
]
|
||||||
|
|
||||||
DEFAULT_CONFIG_FILE = Path("config.yaml")
|
|
||||||
|
|
||||||
|
|
||||||
@cli.command(name="train")
|
@cli.command(name="train")
|
||||||
@click.option(
|
@click.option("--train-dir", type=click.Path(exists=True), required=True)
|
||||||
"--train-examples",
|
@click.option("--val-dir", type=click.Path(exists=True))
|
||||||
type=click.Path(exists=True),
|
@click.option("--model-path", type=click.Path(exists=True))
|
||||||
required=True,
|
@click.option("--config", type=click.Path(exists=True))
|
||||||
)
|
@click.option("--config-field", type=str)
|
||||||
@click.option("--val-examples", type=click.Path(exists=True))
|
@click.option("--train-workers", type=int, default=0)
|
||||||
@click.option(
|
@click.option("--val-workers", type=int, default=0)
|
||||||
"--model-path",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--train-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--train-config-field",
|
|
||||||
type=str,
|
|
||||||
default="train",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--preprocess-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the preprocessing configuration file. This file tells "
|
|
||||||
"the program how to prepare your audio data before training, such "
|
|
||||||
"as resampling or applying filters."
|
|
||||||
),
|
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--preprocess-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the preprocessing settings are inside a nested dictionary "
|
|
||||||
"within the preprocessing configuration file, specify the key "
|
|
||||||
"here to access them. If the preprocessing settings are at the "
|
|
||||||
"top level, you don't need to specify this."
|
|
||||||
),
|
|
||||||
default="preprocess",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--target-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
help=(
|
|
||||||
"Path to the training target configuration file. This file "
|
|
||||||
"specifies what sounds the model should learn to predict."
|
|
||||||
),
|
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--target-config-field",
|
|
||||||
type=str,
|
|
||||||
help=(
|
|
||||||
"If the target settings are inside a nested dictionary "
|
|
||||||
"within the target configuration file, specify the key here. "
|
|
||||||
"If the settings are at the top level, you don't need to specify this."
|
|
||||||
),
|
|
||||||
default="targets",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--postprocess-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--postprocess-config-field",
|
|
||||||
type=str,
|
|
||||||
default="postprocess",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--model-config",
|
|
||||||
type=click.Path(exists=True),
|
|
||||||
default=DEFAULT_CONFIG_FILE,
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--model-config-field",
|
|
||||||
type=str,
|
|
||||||
default="model",
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--train-workers",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
)
|
|
||||||
@click.option(
|
|
||||||
"--val-workers",
|
|
||||||
type=int,
|
|
||||||
default=0,
|
|
||||||
)
|
|
||||||
def train_command(
|
def train_command(
|
||||||
train_examples: Path,
|
train_dir: Path,
|
||||||
val_examples: Optional[Path] = None,
|
val_dir: Optional[Path] = None,
|
||||||
model_path: Optional[Path] = None,
|
model_path: Optional[Path] = None,
|
||||||
train_config: Path = DEFAULT_CONFIG_FILE,
|
config: Optional[Path] = None,
|
||||||
train_config_field: str = "train",
|
config_field: Optional[str] = None,
|
||||||
preprocess_config: Path = DEFAULT_CONFIG_FILE,
|
|
||||||
preprocess_config_field: str = "preprocess",
|
|
||||||
target_config: Path = DEFAULT_CONFIG_FILE,
|
|
||||||
target_config_field: str = "targets",
|
|
||||||
postprocess_config: Path = DEFAULT_CONFIG_FILE,
|
|
||||||
postprocess_config_field: str = "postprocess",
|
|
||||||
model_config: Path = DEFAULT_CONFIG_FILE,
|
|
||||||
model_config_field: str = "model",
|
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
val_workers: int = 0,
|
val_workers: int = 0,
|
||||||
):
|
):
|
||||||
logger.info("Starting training!")
|
logger.info("Starting training!")
|
||||||
|
|
||||||
try:
|
conf = (
|
||||||
target_config_loaded = load_target_config(
|
load_full_training_config(config, field=config_field)
|
||||||
path=target_config,
|
if config is not None
|
||||||
field=target_config_field,
|
else FullTrainingConfig()
|
||||||
)
|
|
||||||
targets = build_targets(config=target_config_loaded)
|
|
||||||
logger.debug(
|
|
||||||
"Loaded targets info from config file {path}", path=target_config
|
|
||||||
)
|
|
||||||
except IOError:
|
|
||||||
logger.debug(
|
|
||||||
"Could not load target info from config file, using default"
|
|
||||||
)
|
|
||||||
targets = build_targets()
|
|
||||||
|
|
||||||
try:
|
|
||||||
preprocess_config_loaded = load_preprocessing_config(
|
|
||||||
path=preprocess_config,
|
|
||||||
field=preprocess_config_field,
|
|
||||||
)
|
|
||||||
preprocessor = build_preprocessor(preprocess_config_loaded)
|
|
||||||
logger.debug(
|
|
||||||
"Loaded preprocessor from config file {path}", path=target_config
|
|
||||||
)
|
|
||||||
|
|
||||||
except IOError:
|
|
||||||
logger.debug(
|
|
||||||
"Could not load preprocessor from config file, using default"
|
|
||||||
)
|
|
||||||
preprocessor = build_preprocessor()
|
|
||||||
|
|
||||||
try:
|
|
||||||
model_config_loaded = load_backbone_config(
|
|
||||||
path=model_config, field=model_config_field
|
|
||||||
)
|
|
||||||
model = build_model(
|
|
||||||
num_classes=len(targets.class_names),
|
|
||||||
config=model_config_loaded,
|
|
||||||
)
|
|
||||||
except IOError:
|
|
||||||
model = build_model(num_classes=len(targets.class_names))
|
|
||||||
|
|
||||||
try:
|
|
||||||
postprocess_config_loaded = load_postprocess_config(
|
|
||||||
path=postprocess_config,
|
|
||||||
field=postprocess_config_field,
|
|
||||||
)
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
targets=targets,
|
|
||||||
config=postprocess_config_loaded,
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Loaded postprocessor from file {path}", path=postprocess_config
|
|
||||||
)
|
|
||||||
except IOError:
|
|
||||||
logger.debug(
|
|
||||||
"Could not load postprocessor config from file. Using default"
|
|
||||||
)
|
|
||||||
postprocessor = build_postprocessor(targets=targets)
|
|
||||||
|
|
||||||
try:
|
|
||||||
train_config_loaded = load_train_config(
|
|
||||||
path=train_config, field=train_config_field
|
|
||||||
)
|
|
||||||
logger.debug(
|
|
||||||
"Loaded training config from file {path}",
|
|
||||||
path=train_config,
|
|
||||||
)
|
|
||||||
except IOError:
|
|
||||||
train_config_loaded = TrainingConfig()
|
|
||||||
logger.debug("Could not load training config from file. Using default")
|
|
||||||
|
|
||||||
train_files = list_preprocessed_files(train_examples)
|
|
||||||
|
|
||||||
val_files = (
|
|
||||||
None if val_examples is None else list_preprocessed_files(val_examples)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return train(
|
train_examples = list_preprocessed_files(train_dir)
|
||||||
detector=model,
|
val_examples = (
|
||||||
train_examples=train_files, # type: ignore
|
list_preprocessed_files(val_dir) if val_dir is not None else None
|
||||||
val_examples=val_files, # type: ignore
|
)
|
||||||
|
|
||||||
|
train(
|
||||||
|
train_examples=train_examples,
|
||||||
|
val_examples=val_examples,
|
||||||
|
config=conf,
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
targets=targets,
|
|
||||||
config=train_config_loaded,
|
|
||||||
callbacks=[
|
|
||||||
ValidationMetrics(
|
|
||||||
metrics=[
|
|
||||||
DetectionAveragePrecision(),
|
|
||||||
ClassificationMeanAveragePrecision(
|
|
||||||
class_names=targets.class_names,
|
|
||||||
),
|
|
||||||
ClassificationAccuracy(class_names=targets.class_names),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
],
|
|
||||||
train_workers=train_workers,
|
train_workers=train_workers,
|
||||||
val_workers=val_workers,
|
val_workers=val_workers,
|
||||||
)
|
)
|
||||||
|
@ -15,7 +15,7 @@ from batdetect2.train.augmentations import (
|
|||||||
)
|
)
|
||||||
from batdetect2.train.clips import build_clipper, select_subclip
|
from batdetect2.train.clips import build_clipper, select_subclip
|
||||||
from batdetect2.train.config import (
|
from batdetect2.train.config import (
|
||||||
TrainerConfig,
|
PLTrainerConfig,
|
||||||
TrainingConfig,
|
TrainingConfig,
|
||||||
load_train_config,
|
load_train_config,
|
||||||
)
|
)
|
||||||
@ -39,8 +39,14 @@ from batdetect2.train.preprocess import (
|
|||||||
preprocess_annotations,
|
preprocess_annotations,
|
||||||
)
|
)
|
||||||
from batdetect2.train.train import (
|
from batdetect2.train.train import (
|
||||||
|
FullTrainingConfig,
|
||||||
build_train_dataset,
|
build_train_dataset,
|
||||||
|
build_train_loader,
|
||||||
|
build_trainer,
|
||||||
|
build_training_module,
|
||||||
build_val_dataset,
|
build_val_dataset,
|
||||||
|
build_val_loader,
|
||||||
|
load_full_training_config,
|
||||||
train,
|
train,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -50,14 +56,15 @@ __all__ = [
|
|||||||
"DetectionLossConfig",
|
"DetectionLossConfig",
|
||||||
"EchoAugmentationConfig",
|
"EchoAugmentationConfig",
|
||||||
"FrequencyMaskAugmentationConfig",
|
"FrequencyMaskAugmentationConfig",
|
||||||
|
"FullTrainingConfig",
|
||||||
"LabeledDataset",
|
"LabeledDataset",
|
||||||
"LossConfig",
|
"LossConfig",
|
||||||
"LossFunction",
|
"LossFunction",
|
||||||
|
"PLTrainerConfig",
|
||||||
"RandomExampleSource",
|
"RandomExampleSource",
|
||||||
"SizeLossConfig",
|
"SizeLossConfig",
|
||||||
"TimeMaskAugmentationConfig",
|
"TimeMaskAugmentationConfig",
|
||||||
"TrainExample",
|
"TrainExample",
|
||||||
"TrainerConfig",
|
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"VolumeAugmentationConfig",
|
"VolumeAugmentationConfig",
|
||||||
"WarpAugmentationConfig",
|
"WarpAugmentationConfig",
|
||||||
@ -67,9 +74,14 @@ __all__ = [
|
|||||||
"build_clipper",
|
"build_clipper",
|
||||||
"build_loss",
|
"build_loss",
|
||||||
"build_train_dataset",
|
"build_train_dataset",
|
||||||
|
"build_train_loader",
|
||||||
|
"build_trainer",
|
||||||
|
"build_training_module",
|
||||||
"build_val_dataset",
|
"build_val_dataset",
|
||||||
|
"build_val_loader",
|
||||||
"generate_train_example",
|
"generate_train_example",
|
||||||
"list_preprocessed_files",
|
"list_preprocessed_files",
|
||||||
|
"load_full_training_config",
|
||||||
"load_label_config",
|
"load_label_config",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
"mask_frequency",
|
"mask_frequency",
|
||||||
@ -79,6 +91,5 @@ __all__ = [
|
|||||||
"scale_volume",
|
"scale_volume",
|
||||||
"select_subclip",
|
"select_subclip",
|
||||||
"train",
|
"train",
|
||||||
"train",
|
|
||||||
"warp_spectrogram",
|
"warp_spectrogram",
|
||||||
]
|
]
|
||||||
|
@ -13,18 +13,12 @@ from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
|
|||||||
from batdetect2.train.losses import LossConfig
|
from batdetect2.train.losses import LossConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"OptimizerConfig",
|
|
||||||
"TrainingConfig",
|
"TrainingConfig",
|
||||||
"load_train_config",
|
"load_train_config",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class OptimizerConfig(BaseConfig):
|
class PLTrainerConfig(BaseConfig):
|
||||||
learning_rate: float = 1e-3
|
|
||||||
t_max: int = 100
|
|
||||||
|
|
||||||
|
|
||||||
class TrainerConfig(BaseConfig):
|
|
||||||
accelerator: str = "auto"
|
accelerator: str = "auto"
|
||||||
accumulate_grad_batches: int = 1
|
accumulate_grad_batches: int = 1
|
||||||
deterministic: bool = True
|
deterministic: bool = True
|
||||||
@ -45,15 +39,16 @@ class TrainerConfig(BaseConfig):
|
|||||||
val_check_interval: Optional[Union[int, float]] = None
|
val_check_interval: Optional[Union[int, float]] = None
|
||||||
|
|
||||||
|
|
||||||
class TrainingConfig(BaseConfig):
|
class TrainingConfig(PLTrainerConfig):
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
|
learning_rate: float = 1e-3
|
||||||
|
t_max: int = 100
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
|
||||||
augmentations: AugmentationsConfig = Field(
|
augmentations: AugmentationsConfig = Field(
|
||||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||||
)
|
)
|
||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
|
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,20 +1,28 @@
|
|||||||
|
from collections.abc import Sequence
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from lightning import Trainer
|
from lightning import Trainer
|
||||||
from lightning.pytorch.callbacks import Callback
|
from lightning.pytorch.callbacks import Callback
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from batdetect2.models.types import DetectionModel
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
from batdetect2.postprocess import build_postprocessor
|
from batdetect2.evaluate.metrics import (
|
||||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
ClassificationAccuracy,
|
||||||
from batdetect2.preprocess import build_preprocessor
|
ClassificationMeanAveragePrecision,
|
||||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
DetectionAveragePrecision,
|
||||||
from batdetect2.targets import build_targets
|
|
||||||
from batdetect2.targets.types import TargetProtocol
|
|
||||||
from batdetect2.train.augmentations import (
|
|
||||||
build_augmentations,
|
|
||||||
)
|
)
|
||||||
|
from batdetect2.models import BackboneConfig, build_model
|
||||||
|
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
||||||
|
from batdetect2.preprocess import (
|
||||||
|
PreprocessingConfig,
|
||||||
|
PreprocessorProtocol,
|
||||||
|
build_preprocessor,
|
||||||
|
)
|
||||||
|
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||||
|
from batdetect2.train.augmentations import build_augmentations
|
||||||
|
from batdetect2.train.callbacks import ValidationMetrics
|
||||||
from batdetect2.train.clips import build_clipper
|
from batdetect2.train.clips import build_clipper
|
||||||
from batdetect2.train.config import TrainingConfig
|
from batdetect2.train.config import TrainingConfig
|
||||||
from batdetect2.train.dataset import (
|
from batdetect2.train.dataset import (
|
||||||
@ -27,94 +35,71 @@ from batdetect2.train.logging import build_logger
|
|||||||
from batdetect2.train.losses import build_loss
|
from batdetect2.train.losses import build_loss
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"train",
|
"FullTrainingConfig",
|
||||||
"build_val_dataset",
|
|
||||||
"build_train_dataset",
|
"build_train_dataset",
|
||||||
|
"build_train_loader",
|
||||||
|
"build_trainer",
|
||||||
|
"build_training_module",
|
||||||
|
"build_val_dataset",
|
||||||
|
"build_val_loader",
|
||||||
|
"load_full_training_config",
|
||||||
|
"train",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class FullTrainingConfig(BaseConfig):
|
||||||
|
"""Full training configuration."""
|
||||||
|
|
||||||
|
train: TrainingConfig = Field(default_factory=TrainingConfig)
|
||||||
|
targets: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
|
model: BackboneConfig = Field(default_factory=BackboneConfig)
|
||||||
|
preprocess: PreprocessingConfig = Field(
|
||||||
|
default_factory=PreprocessingConfig
|
||||||
|
)
|
||||||
|
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||||
|
|
||||||
|
|
||||||
|
def load_full_training_config(
|
||||||
|
path: data.PathLike,
|
||||||
|
field: Optional[str] = None,
|
||||||
|
) -> FullTrainingConfig:
|
||||||
|
"""Load the full training configuration."""
|
||||||
|
return load_config(path, schema=FullTrainingConfig, field=field)
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
detector: DetectionModel,
|
train_examples: Sequence[data.PathLike],
|
||||||
train_examples: List[data.PathLike],
|
val_examples: Optional[Sequence[data.PathLike]] = None,
|
||||||
targets: Optional[TargetProtocol] = None,
|
config: Optional[FullTrainingConfig] = None,
|
||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
|
||||||
postprocessor: Optional[PostprocessorProtocol] = None,
|
|
||||||
val_examples: Optional[List[data.PathLike]] = None,
|
|
||||||
config: Optional[TrainingConfig] = None,
|
|
||||||
callbacks: Optional[List[Callback]] = None,
|
|
||||||
model_path: Optional[data.PathLike] = None,
|
model_path: Optional[data.PathLike] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
val_workers: int = 0,
|
val_workers: int = 0,
|
||||||
**trainer_kwargs,
|
):
|
||||||
) -> None:
|
conf = config or FullTrainingConfig()
|
||||||
config = config or TrainingConfig()
|
|
||||||
|
|
||||||
if model_path is None:
|
if model_path is not None:
|
||||||
if preprocessor is None:
|
|
||||||
preprocessor = build_preprocessor()
|
|
||||||
|
|
||||||
if targets is None:
|
|
||||||
targets = build_targets()
|
|
||||||
|
|
||||||
if postprocessor is None:
|
|
||||||
postprocessor = build_postprocessor(
|
|
||||||
targets,
|
|
||||||
min_freq=preprocessor.min_freq,
|
|
||||||
max_freq=preprocessor.max_freq,
|
|
||||||
)
|
|
||||||
|
|
||||||
loss = build_loss(config.loss)
|
|
||||||
|
|
||||||
module = TrainingModule(
|
|
||||||
detector=detector,
|
|
||||||
loss=loss,
|
|
||||||
targets=targets,
|
|
||||||
preprocessor=preprocessor,
|
|
||||||
postprocessor=postprocessor,
|
|
||||||
learning_rate=config.optimizer.learning_rate,
|
|
||||||
t_max=config.optimizer.t_max,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||||
|
else:
|
||||||
|
module = build_training_module(conf)
|
||||||
|
|
||||||
train_dataset = build_train_dataset(
|
trainer = build_trainer(conf, targets=module.targets)
|
||||||
|
|
||||||
|
train_dataloader = build_train_loader(
|
||||||
train_examples,
|
train_examples,
|
||||||
preprocessor=module.preprocessor,
|
preprocessor=module.preprocessor,
|
||||||
config=config,
|
config=conf.train,
|
||||||
)
|
|
||||||
|
|
||||||
logger = build_logger(config.logger)
|
|
||||||
if logger and hasattr(logger, 'log_hyperparams'):
|
|
||||||
logger.log_hyperparams(config.model_dump(exclude_none=True))
|
|
||||||
|
|
||||||
trainer = Trainer(
|
|
||||||
**config.trainer.model_dump(exclude_none=True, exclude={"logger"}),
|
|
||||||
callbacks=callbacks,
|
|
||||||
logger=logger,
|
|
||||||
**trainer_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
train_dataloader = DataLoader(
|
|
||||||
train_dataset,
|
|
||||||
batch_size=config.batch_size,
|
|
||||||
shuffle=True,
|
|
||||||
num_workers=train_workers,
|
num_workers=train_workers,
|
||||||
collate_fn=collate_fn,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
val_dataloader = None
|
val_dataloader = (
|
||||||
if val_examples:
|
build_val_loader(
|
||||||
val_dataset = build_val_dataset(
|
|
||||||
val_examples,
|
val_examples,
|
||||||
config=config,
|
config=conf.train,
|
||||||
)
|
|
||||||
val_dataloader = DataLoader(
|
|
||||||
val_dataset,
|
|
||||||
batch_size=config.batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=val_workers,
|
num_workers=val_workers,
|
||||||
collate_fn=collate_fn,
|
|
||||||
)
|
)
|
||||||
|
if val_examples is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
module,
|
module,
|
||||||
@ -123,8 +108,106 @@ def train(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_training_module(conf: FullTrainingConfig) -> TrainingModule:
|
||||||
|
preprocessor = build_preprocessor(conf.preprocess)
|
||||||
|
|
||||||
|
targets = build_targets(conf.targets)
|
||||||
|
|
||||||
|
postprocessor = build_postprocessor(
|
||||||
|
targets,
|
||||||
|
min_freq=preprocessor.min_freq,
|
||||||
|
max_freq=preprocessor.max_freq,
|
||||||
|
)
|
||||||
|
|
||||||
|
model = build_model(
|
||||||
|
num_classes=len(targets.class_names),
|
||||||
|
config=conf.model,
|
||||||
|
)
|
||||||
|
|
||||||
|
loss = build_loss(conf.train.loss)
|
||||||
|
|
||||||
|
return TrainingModule(
|
||||||
|
detector=model,
|
||||||
|
loss=loss,
|
||||||
|
targets=targets,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
learning_rate=conf.train.learning_rate,
|
||||||
|
t_max=conf.train.t_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
|
||||||
|
return [
|
||||||
|
ValidationMetrics(
|
||||||
|
metrics=[
|
||||||
|
DetectionAveragePrecision(),
|
||||||
|
ClassificationMeanAveragePrecision(
|
||||||
|
class_names=targets.class_names
|
||||||
|
),
|
||||||
|
ClassificationAccuracy(class_names=targets.class_names),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_trainer(
|
||||||
|
conf: FullTrainingConfig,
|
||||||
|
targets: TargetProtocol,
|
||||||
|
) -> Trainer:
|
||||||
|
logger = build_logger(conf.train.logger)
|
||||||
|
|
||||||
|
if logger and hasattr(logger, "log_hyperparams"):
|
||||||
|
logger.log_hyperparams(conf.model_dump(exclude_none=True))
|
||||||
|
|
||||||
|
return Trainer(
|
||||||
|
accelerator=conf.train.accelerator,
|
||||||
|
logger=logger,
|
||||||
|
callbacks=build_trainer_callbacks(targets),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_train_loader(
|
||||||
|
train_examples: Sequence[data.PathLike],
|
||||||
|
preprocessor: PreprocessorProtocol,
|
||||||
|
config: TrainingConfig,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
) -> DataLoader:
|
||||||
|
train_dataset = build_train_dataset(
|
||||||
|
train_examples,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
|
||||||
|
return DataLoader(
|
||||||
|
train_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=True,
|
||||||
|
num_workers=num_workers or 0,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_val_loader(
|
||||||
|
val_examples: Sequence[data.PathLike],
|
||||||
|
config: TrainingConfig,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
|
):
|
||||||
|
val_dataset = build_val_dataset(
|
||||||
|
val_examples,
|
||||||
|
config=config,
|
||||||
|
)
|
||||||
|
return DataLoader(
|
||||||
|
val_dataset,
|
||||||
|
batch_size=config.batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=num_workers or 0,
|
||||||
|
collate_fn=collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_train_dataset(
|
def build_train_dataset(
|
||||||
examples: List[data.PathLike],
|
examples: Sequence[data.PathLike],
|
||||||
preprocessor: PreprocessorProtocol,
|
preprocessor: PreprocessorProtocol,
|
||||||
config: Optional[TrainingConfig] = None,
|
config: Optional[TrainingConfig] = None,
|
||||||
) -> LabeledDataset:
|
) -> LabeledDataset:
|
||||||
@ -133,7 +216,7 @@ def build_train_dataset(
|
|||||||
clipper = build_clipper(config.cliping, random=True)
|
clipper = build_clipper(config.cliping, random=True)
|
||||||
|
|
||||||
random_example_source = RandomExampleSource(
|
random_example_source = RandomExampleSource(
|
||||||
examples,
|
list(examples),
|
||||||
clipper=clipper,
|
clipper=clipper,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -151,7 +234,7 @@ def build_train_dataset(
|
|||||||
|
|
||||||
|
|
||||||
def build_val_dataset(
|
def build_val_dataset(
|
||||||
examples: List[data.PathLike],
|
examples: Sequence[data.PathLike],
|
||||||
config: Optional[TrainingConfig] = None,
|
config: Optional[TrainingConfig] = None,
|
||||||
train: bool = True,
|
train: bool = True,
|
||||||
) -> LabeledDataset:
|
) -> LabeledDataset:
|
||||||
|
Loading…
Reference in New Issue
Block a user