Add num workers to cli

This commit is contained in:
mbsantiago 2025-04-30 23:40:26 +01:00
parent 3b5623ddca
commit fdab0860fd
5 changed files with 115 additions and 76 deletions

View File

@ -44,7 +44,7 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
default=DEFAULT_CONFIG_FILE,
)
@click.option(
"--train-field",
"--train-config-field",
type=str,
default="train",
)
@ -108,6 +108,16 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
type=str,
default="model",
)
@click.option(
"--train-workers",
type=int,
default=0,
)
@click.option(
"--val-workers",
type=int,
default=0,
)
def train_command(
train_examples: Path,
val_examples: Optional[Path] = None,
@ -122,6 +132,8 @@ def train_command(
postprocess_config_field: str = "postprocess",
model_config: Path = DEFAULT_CONFIG_FILE,
model_config_field: str = "model",
train_workers: int = 0,
val_workers: int = 0,
):
logger.info("Starting training!")
@ -176,15 +188,27 @@ def train_command(
targets=targets,
config=postprocess_config_loaded,
)
logger.debug(
"Loaded postprocessor from file {path}",
path=train_config,
)
except IOError:
logger.debug(
"Could not load postprocessor config from file. Using default"
)
postprocessor = build_postprocessor(targets=targets)
try:
train_config_loaded = load_train_config(
path=train_config, field=train_config_field
)
logger.debug(
"Loaded training config from file {path}",
path=train_config,
)
except IOError:
train_config_loaded = TrainingConfig()
logger.debug("Could not load training config from file. Using default")
train_files = list_preprocessed_files(train_examples)
@ -212,4 +236,6 @@ def train_command(
]
)
],
train_workers=train_workers,
val_workers=val_workers,
)

View File

@ -88,7 +88,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
KeyError: 'x'
"""
if "." not in current_key:
return obj[current_key]
return obj.get(current_key, {})
current_key, rest = current_key.split(".", 1)
subobj = obj[current_key]

View File

@ -24,6 +24,7 @@ object is via the `build_targets` or `load_targets` functions.
from typing import List, Optional
import numpy as np
from pydantic import Field
from soundevent import data
from batdetect2.configs import BaseConfig, load_config
@ -157,7 +158,9 @@ class TargetConfig(BaseConfig):
filtering: Optional[FilterConfig] = None
transforms: Optional[TransformConfig] = None
classes: ClassesConfig
classes: ClassesConfig = Field(
default_factory=lambda: DEFAULT_CLASSES_CONFIG
)
roi: Optional[ROIConfig] = None
@ -438,25 +441,7 @@ class Targets(TargetProtocol):
return self._roi_mapper.recover_roi(pos, dims)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
filtering=FilterConfig(
rules=[
FilterRule(
match_type="all",
tags=[TagInfo(key="event", value="Echolocation")],
),
FilterRule(
match_type="exclude",
tags=[
TagInfo(key="event", value="Feeding"),
TagInfo(key="event", value="Unknown"),
TagInfo(key="event", value="Not Bat"),
],
),
]
),
classes=ClassesConfig(
classes=[
DEFAULT_CLASSES = [
TargetClass(
tags=[TagInfo(value="Myotis mystacinus")],
name="myomys",
@ -525,9 +510,33 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
tags=[TagInfo(value="Plecotus austriacus")],
name="pleaus",
),
],
]
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
classes=DEFAULT_CLASSES,
generic_class=[TagInfo(value="Bat")],
)
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
filtering=FilterConfig(
rules=[
FilterRule(
match_type="all",
tags=[TagInfo(key="event", value="Echolocation")],
),
FilterRule(
match_type="exclude",
tags=[
TagInfo(key="event", value="Feeding"),
TagInfo(key="event", value="Unknown"),
TagInfo(key="event", value="Not Bat"),
],
),
]
),
classes=DEFAULT_CLASSES_CONFIG,
)

View File

@ -38,6 +38,8 @@ def train(
config: Optional[TrainingConfig] = None,
callbacks: Optional[List[Callback]] = None,
model_path: Optional[data.PathLike] = None,
train_workers: int = 0,
val_workers: int = 0,
**trainer_kwargs,
) -> None:
config = config or TrainingConfig()
@ -85,6 +87,7 @@ def train(
train_dataset,
batch_size=config.batch_size,
shuffle=True,
num_workers=train_workers,
)
val_dataloader = None
@ -97,6 +100,7 @@ def train(
val_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=val_workers,
)
trainer.fit(