mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Change train to use full config
This commit is contained in:
parent
6d91153a56
commit
587742b41e
@ -2,13 +2,13 @@ from batdetect2.cli.base import cli
|
||||
from batdetect2.cli.compat import detect
|
||||
from batdetect2.cli.data import data
|
||||
from batdetect2.cli.preprocess import preprocess
|
||||
from batdetect2.cli.train import train
|
||||
from batdetect2.cli.train import train_detector
|
||||
|
||||
__all__ = [
|
||||
"cli",
|
||||
"detect",
|
||||
"data",
|
||||
"train",
|
||||
"train_detector",
|
||||
"preprocess",
|
||||
]
|
||||
|
||||
|
@ -5,236 +5,53 @@ import click
|
||||
from loguru import logger
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
from batdetect2.evaluate.metrics import (
|
||||
ClassificationAccuracy,
|
||||
ClassificationMeanAveragePrecision,
|
||||
DetectionAveragePrecision,
|
||||
from batdetect2.train import (
|
||||
FullTrainingConfig,
|
||||
load_full_training_config,
|
||||
train,
|
||||
)
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.models.backbones import load_backbone_config
|
||||
from batdetect2.postprocess import build_postprocessor, load_postprocess_config
|
||||
from batdetect2.preprocess import build_preprocessor, load_preprocessing_config
|
||||
from batdetect2.targets import build_targets, load_target_config
|
||||
from batdetect2.train import train
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.config import TrainingConfig, load_train_config
|
||||
from batdetect2.train.dataset import list_preprocessed_files
|
||||
|
||||
__all__ = [
|
||||
"train_command",
|
||||
]
|
||||
|
||||
DEFAULT_CONFIG_FILE = Path("config.yaml")
|
||||
|
||||
|
||||
@cli.command(name="train")
|
||||
@click.option(
|
||||
"--train-examples",
|
||||
type=click.Path(exists=True),
|
||||
required=True,
|
||||
)
|
||||
@click.option("--val-examples", type=click.Path(exists=True))
|
||||
@click.option(
|
||||
"--model-path",
|
||||
type=click.Path(exists=True),
|
||||
)
|
||||
@click.option(
|
||||
"--train-config",
|
||||
type=click.Path(exists=True),
|
||||
default=DEFAULT_CONFIG_FILE,
|
||||
)
|
||||
@click.option(
|
||||
"--train-config-field",
|
||||
type=str,
|
||||
default="train",
|
||||
)
|
||||
@click.option(
|
||||
"--preprocess-config",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to the preprocessing configuration file. This file tells "
|
||||
"the program how to prepare your audio data before training, such "
|
||||
"as resampling or applying filters."
|
||||
),
|
||||
default=DEFAULT_CONFIG_FILE,
|
||||
)
|
||||
@click.option(
|
||||
"--preprocess-config-field",
|
||||
type=str,
|
||||
help=(
|
||||
"If the preprocessing settings are inside a nested dictionary "
|
||||
"within the preprocessing configuration file, specify the key "
|
||||
"here to access them. If the preprocessing settings are at the "
|
||||
"top level, you don't need to specify this."
|
||||
),
|
||||
default="preprocess",
|
||||
)
|
||||
@click.option(
|
||||
"--target-config",
|
||||
type=click.Path(exists=True),
|
||||
help=(
|
||||
"Path to the training target configuration file. This file "
|
||||
"specifies what sounds the model should learn to predict."
|
||||
),
|
||||
default=DEFAULT_CONFIG_FILE,
|
||||
)
|
||||
@click.option(
|
||||
"--target-config-field",
|
||||
type=str,
|
||||
help=(
|
||||
"If the target settings are inside a nested dictionary "
|
||||
"within the target configuration file, specify the key here. "
|
||||
"If the settings are at the top level, you don't need to specify this."
|
||||
),
|
||||
default="targets",
|
||||
)
|
||||
@click.option(
|
||||
"--postprocess-config",
|
||||
type=click.Path(exists=True),
|
||||
default=DEFAULT_CONFIG_FILE,
|
||||
)
|
||||
@click.option(
|
||||
"--postprocess-config-field",
|
||||
type=str,
|
||||
default="postprocess",
|
||||
)
|
||||
@click.option(
|
||||
"--model-config",
|
||||
type=click.Path(exists=True),
|
||||
default=DEFAULT_CONFIG_FILE,
|
||||
)
|
||||
@click.option(
|
||||
"--model-config-field",
|
||||
type=str,
|
||||
default="model",
|
||||
)
|
||||
@click.option(
|
||||
"--train-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
@click.option(
|
||||
"--val-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
@click.option("--train-dir", type=click.Path(exists=True), required=True)
|
||||
@click.option("--val-dir", type=click.Path(exists=True))
|
||||
@click.option("--model-path", type=click.Path(exists=True))
|
||||
@click.option("--config", type=click.Path(exists=True))
|
||||
@click.option("--config-field", type=str)
|
||||
@click.option("--train-workers", type=int, default=0)
|
||||
@click.option("--val-workers", type=int, default=0)
|
||||
def train_command(
|
||||
train_examples: Path,
|
||||
val_examples: Optional[Path] = None,
|
||||
train_dir: Path,
|
||||
val_dir: Optional[Path] = None,
|
||||
model_path: Optional[Path] = None,
|
||||
train_config: Path = DEFAULT_CONFIG_FILE,
|
||||
train_config_field: str = "train",
|
||||
preprocess_config: Path = DEFAULT_CONFIG_FILE,
|
||||
preprocess_config_field: str = "preprocess",
|
||||
target_config: Path = DEFAULT_CONFIG_FILE,
|
||||
target_config_field: str = "targets",
|
||||
postprocess_config: Path = DEFAULT_CONFIG_FILE,
|
||||
postprocess_config_field: str = "postprocess",
|
||||
model_config: Path = DEFAULT_CONFIG_FILE,
|
||||
model_config_field: str = "model",
|
||||
config: Optional[Path] = None,
|
||||
config_field: Optional[str] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
):
|
||||
logger.info("Starting training!")
|
||||
|
||||
try:
|
||||
target_config_loaded = load_target_config(
|
||||
path=target_config,
|
||||
field=target_config_field,
|
||||
)
|
||||
targets = build_targets(config=target_config_loaded)
|
||||
logger.debug(
|
||||
"Loaded targets info from config file {path}", path=target_config
|
||||
)
|
||||
except IOError:
|
||||
logger.debug(
|
||||
"Could not load target info from config file, using default"
|
||||
)
|
||||
targets = build_targets()
|
||||
|
||||
try:
|
||||
preprocess_config_loaded = load_preprocessing_config(
|
||||
path=preprocess_config,
|
||||
field=preprocess_config_field,
|
||||
)
|
||||
preprocessor = build_preprocessor(preprocess_config_loaded)
|
||||
logger.debug(
|
||||
"Loaded preprocessor from config file {path}", path=target_config
|
||||
conf = (
|
||||
load_full_training_config(config, field=config_field)
|
||||
if config is not None
|
||||
else FullTrainingConfig()
|
||||
)
|
||||
|
||||
except IOError:
|
||||
logger.debug(
|
||||
"Could not load preprocessor from config file, using default"
|
||||
)
|
||||
preprocessor = build_preprocessor()
|
||||
|
||||
try:
|
||||
model_config_loaded = load_backbone_config(
|
||||
path=model_config, field=model_config_field
|
||||
)
|
||||
model = build_model(
|
||||
num_classes=len(targets.class_names),
|
||||
config=model_config_loaded,
|
||||
)
|
||||
except IOError:
|
||||
model = build_model(num_classes=len(targets.class_names))
|
||||
|
||||
try:
|
||||
postprocess_config_loaded = load_postprocess_config(
|
||||
path=postprocess_config,
|
||||
field=postprocess_config_field,
|
||||
)
|
||||
postprocessor = build_postprocessor(
|
||||
targets=targets,
|
||||
config=postprocess_config_loaded,
|
||||
)
|
||||
logger.debug(
|
||||
"Loaded postprocessor from file {path}", path=postprocess_config
|
||||
)
|
||||
except IOError:
|
||||
logger.debug(
|
||||
"Could not load postprocessor config from file. Using default"
|
||||
)
|
||||
postprocessor = build_postprocessor(targets=targets)
|
||||
|
||||
try:
|
||||
train_config_loaded = load_train_config(
|
||||
path=train_config, field=train_config_field
|
||||
)
|
||||
logger.debug(
|
||||
"Loaded training config from file {path}",
|
||||
path=train_config,
|
||||
)
|
||||
except IOError:
|
||||
train_config_loaded = TrainingConfig()
|
||||
logger.debug("Could not load training config from file. Using default")
|
||||
|
||||
train_files = list_preprocessed_files(train_examples)
|
||||
|
||||
val_files = (
|
||||
None if val_examples is None else list_preprocessed_files(val_examples)
|
||||
train_examples = list_preprocessed_files(train_dir)
|
||||
val_examples = (
|
||||
list_preprocessed_files(val_dir) if val_dir is not None else None
|
||||
)
|
||||
|
||||
return train(
|
||||
detector=model,
|
||||
train_examples=train_files, # type: ignore
|
||||
val_examples=val_files, # type: ignore
|
||||
train(
|
||||
train_examples=train_examples,
|
||||
val_examples=val_examples,
|
||||
config=conf,
|
||||
model_path=model_path,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
targets=targets,
|
||||
config=train_config_loaded,
|
||||
callbacks=[
|
||||
ValidationMetrics(
|
||||
metrics=[
|
||||
DetectionAveragePrecision(),
|
||||
ClassificationMeanAveragePrecision(
|
||||
class_names=targets.class_names,
|
||||
),
|
||||
ClassificationAccuracy(class_names=targets.class_names),
|
||||
]
|
||||
)
|
||||
],
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
)
|
||||
|
@ -15,7 +15,7 @@ from batdetect2.train.augmentations import (
|
||||
)
|
||||
from batdetect2.train.clips import build_clipper, select_subclip
|
||||
from batdetect2.train.config import (
|
||||
TrainerConfig,
|
||||
PLTrainerConfig,
|
||||
TrainingConfig,
|
||||
load_train_config,
|
||||
)
|
||||
@ -39,8 +39,14 @@ from batdetect2.train.preprocess import (
|
||||
preprocess_annotations,
|
||||
)
|
||||
from batdetect2.train.train import (
|
||||
FullTrainingConfig,
|
||||
build_train_dataset,
|
||||
build_train_loader,
|
||||
build_trainer,
|
||||
build_training_module,
|
||||
build_val_dataset,
|
||||
build_val_loader,
|
||||
load_full_training_config,
|
||||
train,
|
||||
)
|
||||
|
||||
@ -50,14 +56,15 @@ __all__ = [
|
||||
"DetectionLossConfig",
|
||||
"EchoAugmentationConfig",
|
||||
"FrequencyMaskAugmentationConfig",
|
||||
"FullTrainingConfig",
|
||||
"LabeledDataset",
|
||||
"LossConfig",
|
||||
"LossFunction",
|
||||
"PLTrainerConfig",
|
||||
"RandomExampleSource",
|
||||
"SizeLossConfig",
|
||||
"TimeMaskAugmentationConfig",
|
||||
"TrainExample",
|
||||
"TrainerConfig",
|
||||
"TrainingConfig",
|
||||
"VolumeAugmentationConfig",
|
||||
"WarpAugmentationConfig",
|
||||
@ -67,9 +74,14 @@ __all__ = [
|
||||
"build_clipper",
|
||||
"build_loss",
|
||||
"build_train_dataset",
|
||||
"build_train_loader",
|
||||
"build_trainer",
|
||||
"build_training_module",
|
||||
"build_val_dataset",
|
||||
"build_val_loader",
|
||||
"generate_train_example",
|
||||
"list_preprocessed_files",
|
||||
"load_full_training_config",
|
||||
"load_label_config",
|
||||
"load_train_config",
|
||||
"mask_frequency",
|
||||
@ -79,6 +91,5 @@ __all__ = [
|
||||
"scale_volume",
|
||||
"select_subclip",
|
||||
"train",
|
||||
"train",
|
||||
"warp_spectrogram",
|
||||
]
|
||||
|
@ -13,18 +13,12 @@ from batdetect2.train.logging import CSVLoggerConfig, LoggerConfig
|
||||
from batdetect2.train.losses import LossConfig
|
||||
|
||||
__all__ = [
|
||||
"OptimizerConfig",
|
||||
"TrainingConfig",
|
||||
"load_train_config",
|
||||
]
|
||||
|
||||
|
||||
class OptimizerConfig(BaseConfig):
|
||||
learning_rate: float = 1e-3
|
||||
t_max: int = 100
|
||||
|
||||
|
||||
class TrainerConfig(BaseConfig):
|
||||
class PLTrainerConfig(BaseConfig):
|
||||
accelerator: str = "auto"
|
||||
accumulate_grad_batches: int = 1
|
||||
deterministic: bool = True
|
||||
@ -45,15 +39,16 @@ class TrainerConfig(BaseConfig):
|
||||
val_check_interval: Optional[Union[int, float]] = None
|
||||
|
||||
|
||||
class TrainingConfig(BaseConfig):
|
||||
class TrainingConfig(PLTrainerConfig):
|
||||
batch_size: int = 8
|
||||
learning_rate: float = 1e-3
|
||||
t_max: int = 100
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
optimizer: OptimizerConfig = Field(default_factory=OptimizerConfig)
|
||||
augmentations: AugmentationsConfig = Field(
|
||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||
)
|
||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||
trainer: TrainerConfig = Field(default_factory=TrainerConfig)
|
||||
trainer: PLTrainerConfig = Field(default_factory=PLTrainerConfig)
|
||||
logger: LoggerConfig = Field(default_factory=CSVLoggerConfig)
|
||||
|
||||
|
||||
|
@ -1,20 +1,28 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import List, Optional
|
||||
|
||||
from lightning import Trainer
|
||||
from lightning.pytorch.callbacks import Callback
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from batdetect2.models.types import DetectionModel
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.postprocess.types import PostprocessorProtocol
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.preprocess.types import PreprocessorProtocol
|
||||
from batdetect2.targets import build_targets
|
||||
from batdetect2.targets.types import TargetProtocol
|
||||
from batdetect2.train.augmentations import (
|
||||
build_augmentations,
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
from batdetect2.evaluate.metrics import (
|
||||
ClassificationAccuracy,
|
||||
ClassificationMeanAveragePrecision,
|
||||
DetectionAveragePrecision,
|
||||
)
|
||||
from batdetect2.models import BackboneConfig, build_model
|
||||
from batdetect2.postprocess import PostprocessConfig, build_postprocessor
|
||||
from batdetect2.preprocess import (
|
||||
PreprocessingConfig,
|
||||
PreprocessorProtocol,
|
||||
build_preprocessor,
|
||||
)
|
||||
from batdetect2.targets import TargetConfig, TargetProtocol, build_targets
|
||||
from batdetect2.train.augmentations import build_augmentations
|
||||
from batdetect2.train.callbacks import ValidationMetrics
|
||||
from batdetect2.train.clips import build_clipper
|
||||
from batdetect2.train.config import TrainingConfig
|
||||
from batdetect2.train.dataset import (
|
||||
@ -27,93 +35,70 @@ from batdetect2.train.logging import build_logger
|
||||
from batdetect2.train.losses import build_loss
|
||||
|
||||
__all__ = [
|
||||
"train",
|
||||
"build_val_dataset",
|
||||
"FullTrainingConfig",
|
||||
"build_train_dataset",
|
||||
"build_train_loader",
|
||||
"build_trainer",
|
||||
"build_training_module",
|
||||
"build_val_dataset",
|
||||
"build_val_loader",
|
||||
"load_full_training_config",
|
||||
"train",
|
||||
]
|
||||
|
||||
|
||||
class FullTrainingConfig(BaseConfig):
    """Single configuration object bundling every section needed to train.

    Each section is optional; when omitted it is filled in by calling the
    corresponding config class with no arguments.
    """

    # Training-loop settings (batch size, loss, logger, ...).
    train: TrainingConfig = Field(default_factory=TrainingConfig)

    # Definition of the targets the model is trained to predict.
    targets: TargetConfig = Field(default_factory=TargetConfig)

    # Backbone / model architecture settings.
    model: BackboneConfig = Field(default_factory=BackboneConfig)

    # Preprocessing settings.
    preprocess: PreprocessingConfig = Field(default_factory=PreprocessingConfig)

    # Postprocessing settings.
    postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
|
||||
|
||||
|
||||
def load_full_training_config(
    path: data.PathLike,
    field: Optional[str] = None,
) -> FullTrainingConfig:
    """Read a ``FullTrainingConfig`` from a configuration file.

    Parameters
    ----------
    path
        Location of the configuration file.
    field
        Forwarded unchanged to ``load_config``. Presumably selects a nested
        section of the file holding the training configuration; confirm
        against ``load_config``.

    Returns
    -------
    FullTrainingConfig
        The configuration produced by ``load_config`` for the
        ``FullTrainingConfig`` schema.
    """
    full_config = load_config(
        path,
        schema=FullTrainingConfig,
        field=field,
    )
    return full_config
|
||||
|
||||
|
||||
def train(
|
||||
detector: DetectionModel,
|
||||
train_examples: List[data.PathLike],
|
||||
targets: Optional[TargetProtocol] = None,
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
postprocessor: Optional[PostprocessorProtocol] = None,
|
||||
val_examples: Optional[List[data.PathLike]] = None,
|
||||
config: Optional[TrainingConfig] = None,
|
||||
callbacks: Optional[List[Callback]] = None,
|
||||
train_examples: Sequence[data.PathLike],
|
||||
val_examples: Optional[Sequence[data.PathLike]] = None,
|
||||
config: Optional[FullTrainingConfig] = None,
|
||||
model_path: Optional[data.PathLike] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
**trainer_kwargs,
|
||||
) -> None:
|
||||
config = config or TrainingConfig()
|
||||
):
|
||||
conf = config or FullTrainingConfig()
|
||||
|
||||
if model_path is None:
|
||||
if preprocessor is None:
|
||||
preprocessor = build_preprocessor()
|
||||
|
||||
if targets is None:
|
||||
targets = build_targets()
|
||||
|
||||
if postprocessor is None:
|
||||
postprocessor = build_postprocessor(
|
||||
targets,
|
||||
min_freq=preprocessor.min_freq,
|
||||
max_freq=preprocessor.max_freq,
|
||||
)
|
||||
|
||||
loss = build_loss(config.loss)
|
||||
|
||||
module = TrainingModule(
|
||||
detector=detector,
|
||||
loss=loss,
|
||||
targets=targets,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
learning_rate=config.optimizer.learning_rate,
|
||||
t_max=config.optimizer.t_max,
|
||||
)
|
||||
else:
|
||||
if model_path is not None:
|
||||
module = TrainingModule.load_from_checkpoint(model_path) # type: ignore
|
||||
else:
|
||||
module = build_training_module(conf)
|
||||
|
||||
train_dataset = build_train_dataset(
|
||||
trainer = build_trainer(conf, targets=module.targets)
|
||||
|
||||
train_dataloader = build_train_loader(
|
||||
train_examples,
|
||||
preprocessor=module.preprocessor,
|
||||
config=config,
|
||||
)
|
||||
|
||||
logger = build_logger(config.logger)
|
||||
if logger and hasattr(logger, 'log_hyperparams'):
|
||||
logger.log_hyperparams(config.model_dump(exclude_none=True))
|
||||
|
||||
trainer = Trainer(
|
||||
**config.trainer.model_dump(exclude_none=True, exclude={"logger"}),
|
||||
callbacks=callbacks,
|
||||
logger=logger,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=True,
|
||||
config=conf.train,
|
||||
num_workers=train_workers,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
|
||||
val_dataloader = None
|
||||
if val_examples:
|
||||
val_dataset = build_val_dataset(
|
||||
val_dataloader = (
|
||||
build_val_loader(
|
||||
val_examples,
|
||||
config=config,
|
||||
)
|
||||
val_dataloader = DataLoader(
|
||||
val_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
config=conf.train,
|
||||
num_workers=val_workers,
|
||||
collate_fn=collate_fn,
|
||||
)
|
||||
if val_examples is not None
|
||||
else None
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
@ -123,8 +108,106 @@ def train(
|
||||
)
|
||||
|
||||
|
||||
def build_training_module(conf: FullTrainingConfig) -> TrainingModule:
    """Assemble a fresh ``TrainingModule`` from a full training configuration.

    Parameters
    ----------
    conf
        Full training configuration. The ``preprocess``, ``targets``,
        ``postprocess`` and ``model`` sections configure the respective
        components; ``train`` supplies the loss configuration, the learning
        rate and ``t_max``.

    Returns
    -------
    TrainingModule
        Module wrapping the newly built detector, loss, targets,
        preprocessor and postprocessor.
    """
    preprocessor = build_preprocessor(conf.preprocess)

    targets = build_targets(conf.targets)

    # Pass the user-supplied postprocess section explicitly; otherwise the
    # postprocessor is always built with default settings and
    # ``conf.postprocess`` is silently ignored.
    postprocessor = build_postprocessor(
        targets,
        config=conf.postprocess,
        min_freq=preprocessor.min_freq,
        max_freq=preprocessor.max_freq,
    )

    # The model is built with one output class per target class name.
    model = build_model(
        num_classes=len(targets.class_names),
        config=conf.model,
    )

    loss = build_loss(conf.train.loss)

    return TrainingModule(
        detector=model,
        loss=loss,
        targets=targets,
        preprocessor=preprocessor,
        postprocessor=postprocessor,
        learning_rate=conf.train.learning_rate,
        t_max=conf.train.t_max,
    )
|
||||
|
||||
|
||||
def build_trainer_callbacks(targets: TargetProtocol) -> List[Callback]:
    """Create the callbacks attached to the trainer.

    Parameters
    ----------
    targets
        Target definition; its ``class_names`` are handed to the
        classification metrics.

    Returns
    -------
    List[Callback]
        A single-element list holding a ``ValidationMetrics`` callback
        configured with detection average precision, classification mean
        average precision and classification accuracy.
    """
    detection_ap = DetectionAveragePrecision()
    classification_map = ClassificationMeanAveragePrecision(
        class_names=targets.class_names
    )
    classification_accuracy = ClassificationAccuracy(
        class_names=targets.class_names
    )

    validation_metrics = ValidationMetrics(
        metrics=[detection_ap, classification_map, classification_accuracy]
    )
    return [validation_metrics]
|
||||
|
||||
|
||||
def build_trainer(
    conf: FullTrainingConfig,
    targets: TargetProtocol,
) -> Trainer:
    """Create the Lightning ``Trainer`` used for a training run.

    Parameters
    ----------
    conf
        Full training configuration. ``conf.train.logger`` configures the
        logger and ``conf.train.accelerator`` selects the accelerator.
    targets
        Target definition, forwarded to ``build_trainer_callbacks``.

    Returns
    -------
    Trainer
        Trainer wired with the logger and the validation callbacks.
    """
    logger = build_logger(conf.train.logger)

    # Record the complete configuration with the run, but only for loggers
    # that actually expose ``log_hyperparams``.
    supports_hyperparams = bool(logger) and hasattr(logger, "log_hyperparams")
    if supports_hyperparams:
        logger.log_hyperparams(conf.model_dump(exclude_none=True))

    # NOTE(review): only ``accelerator`` is forwarded to the Trainer here;
    # any other trainer settings present in the config are not applied.
    # Confirm whether that is intentional.
    trainer_options = {
        "accelerator": conf.train.accelerator,
        "logger": logger,
        "callbacks": build_trainer_callbacks(targets),
    }
    return Trainer(**trainer_options)
|
||||
|
||||
|
||||
def build_train_loader(
    train_examples: Sequence[data.PathLike],
    preprocessor: PreprocessorProtocol,
    config: TrainingConfig,
    num_workers: Optional[int] = None,
) -> DataLoader:
    """Build the shuffled ``DataLoader`` over the training examples.

    Parameters
    ----------
    train_examples
        Paths of the training examples, passed to ``build_train_dataset``.
    preprocessor
        Preprocessor passed to ``build_train_dataset``.
    config
        Training configuration; also provides ``batch_size``.
    num_workers
        Number of loader workers. ``None`` (or ``0``) results in ``0``.

    Returns
    -------
    DataLoader
        Loader over the training dataset with shuffling enabled.
    """
    dataset = build_train_dataset(
        train_examples,
        preprocessor=preprocessor,
        config=config,
    )

    worker_count = num_workers if num_workers else 0

    loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=worker_count,
        collate_fn=collate_fn,
    )
    return loader
|
||||
|
||||
|
||||
def build_val_loader(
    val_examples: Sequence[data.PathLike],
    config: TrainingConfig,
    num_workers: Optional[int] = None,
):
    """Build the non-shuffled ``DataLoader`` over the validation examples.

    Parameters
    ----------
    val_examples
        Paths of the validation examples, passed to ``build_val_dataset``.
    config
        Training configuration; also provides ``batch_size``.
    num_workers
        Number of loader workers. ``None`` (or ``0``) results in ``0``.

    Returns
    -------
    DataLoader
        Loader over the validation dataset with shuffling disabled.
    """
    dataset = build_val_dataset(val_examples, config=config)

    worker_count = num_workers if num_workers else 0

    loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=worker_count,
        collate_fn=collate_fn,
    )
    return loader
|
||||
|
||||
|
||||
def build_train_dataset(
|
||||
examples: List[data.PathLike],
|
||||
examples: Sequence[data.PathLike],
|
||||
preprocessor: PreprocessorProtocol,
|
||||
config: Optional[TrainingConfig] = None,
|
||||
) -> LabeledDataset:
|
||||
@ -133,7 +216,7 @@ def build_train_dataset(
|
||||
clipper = build_clipper(config.cliping, random=True)
|
||||
|
||||
random_example_source = RandomExampleSource(
|
||||
examples,
|
||||
list(examples),
|
||||
clipper=clipper,
|
||||
)
|
||||
|
||||
@ -151,7 +234,7 @@ def build_train_dataset(
|
||||
|
||||
|
||||
def build_val_dataset(
|
||||
examples: List[data.PathLike],
|
||||
examples: Sequence[data.PathLike],
|
||||
config: Optional[TrainingConfig] = None,
|
||||
train: bool = True,
|
||||
) -> LabeledDataset:
|
||||
|
Loading…
Reference in New Issue
Block a user