mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-11 17:29:34 +01:00
Improve augmentations config and logging
This commit is contained in:
parent
ed67d8ceec
commit
19e873dd0b
@ -1,3 +1,4 @@
|
|||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@ -25,6 +26,12 @@ __all__ = [
|
|||||||
@click.option("--config-field", type=str)
|
@click.option("--config-field", type=str)
|
||||||
@click.option("--train-workers", type=int, default=0)
|
@click.option("--train-workers", type=int, default=0)
|
||||||
@click.option("--val-workers", type=int, default=0)
|
@click.option("--val-workers", type=int, default=0)
|
||||||
|
@click.option(
|
||||||
|
"-v",
|
||||||
|
"--verbose",
|
||||||
|
count=True,
|
||||||
|
help="Increase verbosity. -v for INFO, -vv for DEBUG.",
|
||||||
|
)
|
||||||
def train_command(
|
def train_command(
|
||||||
train_dir: Path,
|
train_dir: Path,
|
||||||
val_dir: Optional[Path] = None,
|
val_dir: Optional[Path] = None,
|
||||||
@ -33,7 +40,17 @@ def train_command(
|
|||||||
config_field: Optional[str] = None,
|
config_field: Optional[str] = None,
|
||||||
train_workers: int = 0,
|
train_workers: int = 0,
|
||||||
val_workers: int = 0,
|
val_workers: int = 0,
|
||||||
|
verbose: int = 0,
|
||||||
):
|
):
|
||||||
|
logger.remove()
|
||||||
|
if verbose == 0:
|
||||||
|
log_level = "WARNING"
|
||||||
|
elif verbose == 1:
|
||||||
|
log_level = "INFO"
|
||||||
|
else:
|
||||||
|
log_level = "DEBUG"
|
||||||
|
logger.add(sys.stderr, level=log_level)
|
||||||
|
|
||||||
logger.info("Starting training!")
|
logger.info("Starting training!")
|
||||||
|
|
||||||
conf = (
|
conf = (
|
||||||
|
|||||||
@ -50,7 +50,7 @@ class TrainingConfig(PLTrainerConfig):
|
|||||||
learning_rate: float = 1e-3
|
learning_rate: float = 1e-3
|
||||||
t_max: int = 100
|
t_max: int = 100
|
||||||
loss: LossConfig = Field(default_factory=LossConfig)
|
loss: LossConfig = Field(default_factory=LossConfig)
|
||||||
augmentations: AugmentationsConfig = Field(
|
augmentations: Optional[AugmentationsConfig] = Field(
|
||||||
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
default_factory=lambda: DEFAULT_AUGMENTATION_CONFIG
|
||||||
)
|
)
|
||||||
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
cliping: ClipingConfig = Field(default_factory=ClipingConfig)
|
||||||
|
|||||||
@ -164,11 +164,18 @@ def build_train_dataset(
|
|||||||
clipper=clipper,
|
clipper=clipper,
|
||||||
)
|
)
|
||||||
|
|
||||||
augmentations = build_augmentations(
|
logger.debug(
|
||||||
|
"Augmentations config: {}.", config.augmentations
|
||||||
|
)
|
||||||
|
augmentations = (
|
||||||
|
build_augmentations(
|
||||||
preprocessor,
|
preprocessor,
|
||||||
config=config.augmentations,
|
config=config.augmentations,
|
||||||
example_source=random_example_source,
|
example_source=random_example_source,
|
||||||
)
|
)
|
||||||
|
if config.augmentations
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
return LabeledDataset(
|
return LabeledDataset(
|
||||||
examples,
|
examples,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user