mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Improve augmentations config and logging
This commit is contained in:
parent
ed67d8ceec
commit
19e873dd0b
@ -1,3 +1,4 @@
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@ -25,6 +26,12 @@ __all__ = [
|
||||
@click.option("--config-field", type=str)
|
||||
@click.option("--train-workers", type=int, default=0)
|
||||
@click.option("--val-workers", type=int, default=0)
|
||||
@click.option(
|
||||
"-v",
|
||||
"--verbose",
|
||||
count=True,
|
||||
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||
)
|
||||
def train_command(
|
||||
train_dir: Path,
|
||||
val_dir: Optional[Path] = None,
|
||||
@ -33,7 +40,17 @@ def train_command(
|
||||
config_field: Optional[str] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
verbose: int = 0,
|
||||
):
|
||||
logger.remove()
|
||||
if verbose == 0:
|
||||
log_level = "WARNING"
|
||||
elif verbose == 1:
|
||||
log_level = "INFO"
|
||||
else:
|
||||
log_level = "DEBUG"
|
||||
logger.add(sys.stderr, level=log_level)
|
||||
|
||||
logger.info("Starting training!")
|
||||
|
||||
conf = (
|
||||
|
||||
@ -50,7 +50,7 @@ class TrainingConfig(PLTrainerConfig):
|
||||
learning_rate: float = 1e-3
|
||||
t_max: int = 100
|
||||
loss: LossConfig = Field(default_factory=LossConfig)
|
||||
augmentations: AugmentationsConfig = Field(
|
||||
augmentations: Optional[AugmentationsConfig] = Field(
|
||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||
)
|
||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||
|
||||
@ -164,10 +164,17 @@ def build_train_dataset(
|
||||
clipper=clipper,
|
||||
)
|
||||
|
||||
augmentations = build_augmentations(
|
||||
preprocessor,
|
||||
config=config.augmentations,
|
||||
example_source=random_example_source,
|
||||
logger.debug(
|
||||
"Augmentations config: {}.", config.augmentations
|
||||
)
|
||||
augmentations = (
|
||||
build_augmentations(
|
||||
preprocessor,
|
||||
config=config.augmentations,
|
||||
example_source=random_example_source,
|
||||
)
|
||||
if config.augmentations
|
||||
else None
|
||||
)
|
||||
|
||||
return LabeledDataset(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user