Add num workers to cli

This commit is contained in:
mbsantiago 2025-04-30 23:40:26 +01:00
parent 3b5623ddca
commit fdab0860fd
5 changed files with 115 additions and 76 deletions

View File

@ -44,7 +44,7 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
default=DEFAULT_CONFIG_FILE, default=DEFAULT_CONFIG_FILE,
) )
@click.option( @click.option(
"--train-field", "--train-config-field",
type=str, type=str,
default="train", default="train",
) )
@ -108,6 +108,16 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
type=str, type=str,
default="model", default="model",
) )
@click.option(
"--train-workers",
type=int,
default=0,
)
@click.option(
"--val-workers",
type=int,
default=0,
)
def train_command( def train_command(
train_examples: Path, train_examples: Path,
val_examples: Optional[Path] = None, val_examples: Optional[Path] = None,
@ -122,6 +132,8 @@ def train_command(
postprocess_config_field: str = "postprocess", postprocess_config_field: str = "postprocess",
model_config: Path = DEFAULT_CONFIG_FILE, model_config: Path = DEFAULT_CONFIG_FILE,
model_config_field: str = "model", model_config_field: str = "model",
train_workers: int = 0,
val_workers: int = 0,
): ):
logger.info("Starting training!") logger.info("Starting training!")
@ -176,15 +188,27 @@ def train_command(
targets=targets, targets=targets,
config=postprocess_config_loaded, config=postprocess_config_loaded,
) )
logger.debug(
"Loaded postprocessor from file {path}",
path=train_config,
)
except IOError: except IOError:
logger.debug(
"Could not load postprocessor config from file. Using default"
)
postprocessor = build_postprocessor(targets=targets) postprocessor = build_postprocessor(targets=targets)
try: try:
train_config_loaded = load_train_config( train_config_loaded = load_train_config(
path=train_config, field=train_config_field path=train_config, field=train_config_field
) )
logger.debug(
"Loaded training config from file {path}",
path=train_config,
)
except IOError: except IOError:
train_config_loaded = TrainingConfig() train_config_loaded = TrainingConfig()
logger.debug("Could not load training config from file. Using default")
train_files = list_preprocessed_files(train_examples) train_files = list_preprocessed_files(train_examples)
@ -212,4 +236,6 @@ def train_command(
] ]
) )
], ],
train_workers=train_workers,
val_workers=val_workers,
) )

View File

@ -88,7 +88,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
KeyError: 'x' KeyError: 'x'
""" """
if "." not in current_key: if "." not in current_key:
return obj[current_key] return obj.get(current_key, {})
current_key, rest = current_key.split(".", 1) current_key, rest = current_key.split(".", 1)
subobj = obj[current_key] subobj = obj[current_key]

View File

@ -24,6 +24,7 @@ object is via the `build_targets` or `load_targets` functions.
from typing import List, Optional from typing import List, Optional
import numpy as np import numpy as np
from pydantic import Field
from soundevent import data from soundevent import data
from batdetect2.configs import BaseConfig, load_config from batdetect2.configs import BaseConfig, load_config
@ -157,7 +158,9 @@ class TargetConfig(BaseConfig):
filtering: Optional[FilterConfig] = None filtering: Optional[FilterConfig] = None
transforms: Optional[TransformConfig] = None transforms: Optional[TransformConfig] = None
classes: ClassesConfig classes: ClassesConfig = Field(
default_factory=lambda: DEFAULT_CLASSES_CONFIG
)
roi: Optional[ROIConfig] = None roi: Optional[ROIConfig] = None
@ -438,6 +441,84 @@ class Targets(TargetProtocol):
return self._roi_mapper.recover_roi(pos, dims) return self._roi_mapper.recover_roi(pos, dims)
DEFAULT_CLASSES = [
TargetClass(
tags=[TagInfo(value="Myotis mystacinus")],
name="myomys",
),
TargetClass(
tags=[TagInfo(value="Myotis alcathoe")],
name="myoalc",
),
TargetClass(
tags=[TagInfo(value="Eptesicus serotinus")],
name="eptser",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus nathusii")],
name="pipnat",
),
TargetClass(
tags=[TagInfo(value="Barbastellus barbastellus")],
name="barbar",
),
TargetClass(
tags=[TagInfo(value="Myotis nattereri")],
name="myonat",
),
TargetClass(
tags=[TagInfo(value="Myotis daubentonii")],
name="myodau",
),
TargetClass(
tags=[TagInfo(value="Myotis brandtii")],
name="myobra",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pipistrellus")],
name="pippip",
),
TargetClass(
tags=[TagInfo(value="Myotis bechsteinii")],
name="myobec",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pygmaeus")],
name="pippyg",
),
TargetClass(
tags=[TagInfo(value="Rhinolophus hipposideros")],
name="rhihip",
),
TargetClass(
tags=[TagInfo(value="Nyctalus leisleri")],
name="nyclei",
),
TargetClass(
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
name="rhifer",
),
TargetClass(
tags=[TagInfo(value="Plecotus auritus")],
name="pleaur",
),
TargetClass(
tags=[TagInfo(value="Nyctalus noctula")],
name="nycnoc",
),
TargetClass(
tags=[TagInfo(value="Plecotus austriacus")],
name="pleaus",
),
]
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
classes=DEFAULT_CLASSES,
generic_class=[TagInfo(value="Bat")],
)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
filtering=FilterConfig( filtering=FilterConfig(
rules=[ rules=[
@ -455,79 +536,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
), ),
] ]
), ),
classes=ClassesConfig( classes=DEFAULT_CLASSES_CONFIG,
classes=[
TargetClass(
tags=[TagInfo(value="Myotis mystacinus")],
name="myomys",
),
TargetClass(
tags=[TagInfo(value="Myotis alcathoe")],
name="myoalc",
),
TargetClass(
tags=[TagInfo(value="Eptesicus serotinus")],
name="eptser",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus nathusii")],
name="pipnat",
),
TargetClass(
tags=[TagInfo(value="Barbastellus barbastellus")],
name="barbar",
),
TargetClass(
tags=[TagInfo(value="Myotis nattereri")],
name="myonat",
),
TargetClass(
tags=[TagInfo(value="Myotis daubentonii")],
name="myodau",
),
TargetClass(
tags=[TagInfo(value="Myotis brandtii")],
name="myobra",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pipistrellus")],
name="pippip",
),
TargetClass(
tags=[TagInfo(value="Myotis bechsteinii")],
name="myobec",
),
TargetClass(
tags=[TagInfo(value="Pipistrellus pygmaeus")],
name="pippyg",
),
TargetClass(
tags=[TagInfo(value="Rhinolophus hipposideros")],
name="rhihip",
),
TargetClass(
tags=[TagInfo(value="Nyctalus leisleri")],
name="nyclei",
),
TargetClass(
tags=[TagInfo(value="Rhinolophus ferrumequinum")],
name="rhifer",
),
TargetClass(
tags=[TagInfo(value="Plecotus auritus")],
name="pleaur",
),
TargetClass(
tags=[TagInfo(value="Nyctalus noctula")],
name="nycnoc",
),
TargetClass(
tags=[TagInfo(value="Plecotus austriacus")],
name="pleaus",
),
],
generic_class=[TagInfo(value="Bat")],
),
) )

View File

@ -38,6 +38,8 @@ def train(
config: Optional[TrainingConfig] = None, config: Optional[TrainingConfig] = None,
callbacks: Optional[List[Callback]] = None, callbacks: Optional[List[Callback]] = None,
model_path: Optional[data.PathLike] = None, model_path: Optional[data.PathLike] = None,
train_workers: int = 0,
val_workers: int = 0,
**trainer_kwargs, **trainer_kwargs,
) -> None: ) -> None:
config = config or TrainingConfig() config = config or TrainingConfig()
@ -85,6 +87,7 @@ def train(
train_dataset, train_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=True, shuffle=True,
num_workers=train_workers,
) )
val_dataloader = None val_dataloader = None
@ -97,6 +100,7 @@ def train(
val_dataset, val_dataset,
batch_size=config.batch_size, batch_size=config.batch_size,
shuffle=False, shuffle=False,
num_workers=val_workers,
) )
trainer.fit( trainer.fit(