mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Add num workers to cli
This commit is contained in:
parent
3b5623ddca
commit
fdab0860fd
@ -44,7 +44,7 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
|
|||||||
default=DEFAULT_CONFIG_FILE,
|
default=DEFAULT_CONFIG_FILE,
|
||||||
)
|
)
|
||||||
@click.option(
|
@click.option(
|
||||||
"--train-field",
|
"--train-config-field",
|
||||||
type=str,
|
type=str,
|
||||||
default="train",
|
default="train",
|
||||||
)
|
)
|
||||||
@ -108,6 +108,16 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
|
|||||||
type=str,
|
type=str,
|
||||||
default="model",
|
default="model",
|
||||||
)
|
)
|
||||||
|
@click.option(
|
||||||
|
"--train-workers",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
|
@click.option(
|
||||||
|
"--val-workers",
|
||||||
|
type=int,
|
||||||
|
default=0,
|
||||||
|
)
|
||||||
def train_command(
|
def train_command(
|
||||||
train_examples: Path,
|
train_examples: Path,
|
||||||
val_examples: Optional[Path] = None,
|
val_examples: Optional[Path] = None,
|
||||||
@ -122,6 +132,8 @@ def train_command(
|
|||||||
postprocess_config_field: str = "postprocess",
|
postprocess_config_field: str = "postprocess",
|
||||||
model_config: Path = DEFAULT_CONFIG_FILE,
|
model_config: Path = DEFAULT_CONFIG_FILE,
|
||||||
model_config_field: str = "model",
|
model_config_field: str = "model",
|
||||||
|
train_workers: int = 0,
|
||||||
|
val_workers: int = 0,
|
||||||
):
|
):
|
||||||
logger.info("Starting training!")
|
logger.info("Starting training!")
|
||||||
|
|
||||||
@ -176,15 +188,27 @@ def train_command(
|
|||||||
targets=targets,
|
targets=targets,
|
||||||
config=postprocess_config_loaded,
|
config=postprocess_config_loaded,
|
||||||
)
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Loaded postprocessor from file {path}",
|
||||||
|
path=train_config,
|
||||||
|
)
|
||||||
except IOError:
|
except IOError:
|
||||||
|
logger.debug(
|
||||||
|
"Could not load postprocessor config from file. Using default"
|
||||||
|
)
|
||||||
postprocessor = build_postprocessor(targets=targets)
|
postprocessor = build_postprocessor(targets=targets)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
train_config_loaded = load_train_config(
|
train_config_loaded = load_train_config(
|
||||||
path=train_config, field=train_config_field
|
path=train_config, field=train_config_field
|
||||||
)
|
)
|
||||||
|
logger.debug(
|
||||||
|
"Loaded training config from file {path}",
|
||||||
|
path=train_config,
|
||||||
|
)
|
||||||
except IOError:
|
except IOError:
|
||||||
train_config_loaded = TrainingConfig()
|
train_config_loaded = TrainingConfig()
|
||||||
|
logger.debug("Could not load training config from file. Using default")
|
||||||
|
|
||||||
train_files = list_preprocessed_files(train_examples)
|
train_files = list_preprocessed_files(train_examples)
|
||||||
|
|
||||||
@ -212,4 +236,6 @@ def train_command(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
|
train_workers=train_workers,
|
||||||
|
val_workers=val_workers,
|
||||||
)
|
)
|
||||||
|
@ -88,7 +88,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
|
|||||||
KeyError: 'x'
|
KeyError: 'x'
|
||||||
"""
|
"""
|
||||||
if "." not in current_key:
|
if "." not in current_key:
|
||||||
return obj[current_key]
|
return obj.get(current_key, {})
|
||||||
|
|
||||||
current_key, rest = current_key.split(".", 1)
|
current_key, rest = current_key.split(".", 1)
|
||||||
subobj = obj[current_key]
|
subobj = obj[current_key]
|
||||||
|
@ -24,6 +24,7 @@ object is via the `build_targets` or `load_targets` functions.
|
|||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from pydantic import Field
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig, load_config
|
from batdetect2.configs import BaseConfig, load_config
|
||||||
@ -157,7 +158,9 @@ class TargetConfig(BaseConfig):
|
|||||||
|
|
||||||
filtering: Optional[FilterConfig] = None
|
filtering: Optional[FilterConfig] = None
|
||||||
transforms: Optional[TransformConfig] = None
|
transforms: Optional[TransformConfig] = None
|
||||||
classes: ClassesConfig
|
classes: ClassesConfig = Field(
|
||||||
|
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
||||||
|
)
|
||||||
roi: Optional[ROIConfig] = None
|
roi: Optional[ROIConfig] = None
|
||||||
|
|
||||||
|
|
||||||
@ -438,6 +441,84 @@ class Targets(TargetProtocol):
|
|||||||
return self._roi_mapper.recover_roi(pos, dims)
|
return self._roi_mapper.recover_roi(pos, dims)
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_CLASSES = [
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis mystacinus")],
|
||||||
|
name="myomys",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis alcathoe")],
|
||||||
|
name="myoalc",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Eptesicus serotinus")],
|
||||||
|
name="eptser",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus nathusii")],
|
||||||
|
name="pipnat",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Barbastellus barbastellus")],
|
||||||
|
name="barbar",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis nattereri")],
|
||||||
|
name="myonat",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis daubentonii")],
|
||||||
|
name="myodau",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis brandtii")],
|
||||||
|
name="myobra",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus pipistrellus")],
|
||||||
|
name="pippip",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Myotis bechsteinii")],
|
||||||
|
name="myobec",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Pipistrellus pygmaeus")],
|
||||||
|
name="pippyg",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Rhinolophus hipposideros")],
|
||||||
|
name="rhihip",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Nyctalus leisleri")],
|
||||||
|
name="nyclei",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
||||||
|
name="rhifer",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Plecotus auritus")],
|
||||||
|
name="pleaur",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Nyctalus noctula")],
|
||||||
|
name="nycnoc",
|
||||||
|
),
|
||||||
|
TargetClass(
|
||||||
|
tags=[TagInfo(value="Plecotus austriacus")],
|
||||||
|
name="pleaus",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
|
||||||
|
classes=DEFAULT_CLASSES,
|
||||||
|
generic_class=[TagInfo(value="Bat")],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||||
filtering=FilterConfig(
|
filtering=FilterConfig(
|
||||||
rules=[
|
rules=[
|
||||||
@ -455,79 +536,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
),
|
),
|
||||||
classes=ClassesConfig(
|
classes=DEFAULT_CLASSES_CONFIG,
|
||||||
classes=[
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis mystacinus")],
|
|
||||||
name="myomys",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis alcathoe")],
|
|
||||||
name="myoalc",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Eptesicus serotinus")],
|
|
||||||
name="eptser",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus nathusii")],
|
|
||||||
name="pipnat",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Barbastellus barbastellus")],
|
|
||||||
name="barbar",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis nattereri")],
|
|
||||||
name="myonat",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis daubentonii")],
|
|
||||||
name="myodau",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis brandtii")],
|
|
||||||
name="myobra",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus pipistrellus")],
|
|
||||||
name="pippip",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Myotis bechsteinii")],
|
|
||||||
name="myobec",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Pipistrellus pygmaeus")],
|
|
||||||
name="pippyg",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Rhinolophus hipposideros")],
|
|
||||||
name="rhihip",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Nyctalus leisleri")],
|
|
||||||
name="nyclei",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
|
|
||||||
name="rhifer",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Plecotus auritus")],
|
|
||||||
name="pleaur",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Nyctalus noctula")],
|
|
||||||
name="nycnoc",
|
|
||||||
),
|
|
||||||
TargetClass(
|
|
||||||
tags=[TagInfo(value="Plecotus austriacus")],
|
|
||||||
name="pleaus",
|
|
||||||
),
|
|
||||||
],
|
|
||||||
generic_class=[TagInfo(value="Bat")],
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,6 +38,8 @@ def train(
|
|||||||
config: Optional[TrainingConfig] = None,
|
config: Optional[TrainingConfig] = None,
|
||||||
callbacks: Optional[List[Callback]] = None,
|
callbacks: Optional[List[Callback]] = None,
|
||||||
model_path: Optional[data.PathLike] = None,
|
model_path: Optional[data.PathLike] = None,
|
||||||
|
train_workers: int = 0,
|
||||||
|
val_workers: int = 0,
|
||||||
**trainer_kwargs,
|
**trainer_kwargs,
|
||||||
) -> None:
|
) -> None:
|
||||||
config = config or TrainingConfig()
|
config = config or TrainingConfig()
|
||||||
@ -85,6 +87,7 @@ def train(
|
|||||||
train_dataset,
|
train_dataset,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
|
num_workers=train_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
val_dataloader = None
|
val_dataloader = None
|
||||||
@ -97,6 +100,7 @@ def train(
|
|||||||
val_dataset,
|
val_dataset,
|
||||||
batch_size=config.batch_size,
|
batch_size=config.batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
|
num_workers=val_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
|
Loading…
Reference in New Issue
Block a user