mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Add num workers to cli
This commit is contained in:
parent
3b5623ddca
commit
fdab0860fd
@ -44,7 +44,7 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
|
||||
default=DEFAULT_CONFIG_FILE,
|
||||
)
|
||||
@click.option(
|
||||
"--train-field",
|
||||
"--train-config-field",
|
||||
type=str,
|
||||
default="train",
|
||||
)
|
||||
@ -108,6 +108,16 @@ DEFAULT_CONFIG_FILE = Path("config.yaml")
|
||||
type=str,
|
||||
default="model",
|
||||
)
|
||||
@click.option(
|
||||
"--train-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
@click.option(
|
||||
"--val-workers",
|
||||
type=int,
|
||||
default=0,
|
||||
)
|
||||
def train_command(
|
||||
train_examples: Path,
|
||||
val_examples: Optional[Path] = None,
|
||||
@ -122,6 +132,8 @@ def train_command(
|
||||
postprocess_config_field: str = "postprocess",
|
||||
model_config: Path = DEFAULT_CONFIG_FILE,
|
||||
model_config_field: str = "model",
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
):
|
||||
logger.info("Starting training!")
|
||||
|
||||
@ -176,15 +188,27 @@ def train_command(
|
||||
targets=targets,
|
||||
config=postprocess_config_loaded,
|
||||
)
|
||||
logger.debug(
|
||||
"Loaded postprocessor from file {path}",
|
||||
path=train_config,
|
||||
)
|
||||
except IOError:
|
||||
logger.debug(
|
||||
"Could not load postprocessor config from file. Using default"
|
||||
)
|
||||
postprocessor = build_postprocessor(targets=targets)
|
||||
|
||||
try:
|
||||
train_config_loaded = load_train_config(
|
||||
path=train_config, field=train_config_field
|
||||
)
|
||||
logger.debug(
|
||||
"Loaded training config from file {path}",
|
||||
path=train_config,
|
||||
)
|
||||
except IOError:
|
||||
train_config_loaded = TrainingConfig()
|
||||
logger.debug("Could not load training config from file. Using default")
|
||||
|
||||
train_files = list_preprocessed_files(train_examples)
|
||||
|
||||
@ -212,4 +236,6 @@ def train_command(
|
||||
]
|
||||
)
|
||||
],
|
||||
train_workers=train_workers,
|
||||
val_workers=val_workers,
|
||||
)
|
||||
|
@ -88,7 +88,7 @@ def get_object_field(obj: dict, current_key: str) -> Any:
|
||||
KeyError: 'x'
|
||||
"""
|
||||
if "." not in current_key:
|
||||
return obj[current_key]
|
||||
return obj.get(current_key, {})
|
||||
|
||||
current_key, rest = current_key.split(".", 1)
|
||||
subobj = obj[current_key]
|
||||
|
@ -24,6 +24,7 @@ object is via the `build_targets` or `load_targets` functions.
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.configs import BaseConfig, load_config
|
||||
@ -157,7 +158,9 @@ class TargetConfig(BaseConfig):
|
||||
|
||||
filtering: Optional[FilterConfig] = None
|
||||
transforms: Optional[TransformConfig] = None
|
||||
classes: ClassesConfig
|
||||
classes: ClassesConfig = Field(
|
||||
default_factory=lambda: DEFAULT_CLASSES_CONFIG
|
||||
)
|
||||
roi: Optional[ROIConfig] = None
|
||||
|
||||
|
||||
@ -438,25 +441,7 @@ class Targets(TargetProtocol):
|
||||
return self._roi_mapper.recover_roi(pos, dims)
|
||||
|
||||
|
||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
filtering=FilterConfig(
|
||||
rules=[
|
||||
FilterRule(
|
||||
match_type="all",
|
||||
tags=[TagInfo(key="event", value="Echolocation")],
|
||||
),
|
||||
FilterRule(
|
||||
match_type="exclude",
|
||||
tags=[
|
||||
TagInfo(key="event", value="Feeding"),
|
||||
TagInfo(key="event", value="Unknown"),
|
||||
TagInfo(key="event", value="Not Bat"),
|
||||
],
|
||||
),
|
||||
]
|
||||
),
|
||||
classes=ClassesConfig(
|
||||
classes=[
|
||||
DEFAULT_CLASSES = [
|
||||
TargetClass(
|
||||
tags=[TagInfo(value="Myotis mystacinus")],
|
||||
name="myomys",
|
||||
@ -525,9 +510,33 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
tags=[TagInfo(value="Plecotus austriacus")],
|
||||
name="pleaus",
|
||||
),
|
||||
],
|
||||
]
|
||||
|
||||
|
||||
DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig(
|
||||
classes=DEFAULT_CLASSES,
|
||||
generic_class=[TagInfo(value="Bat")],
|
||||
)
|
||||
|
||||
|
||||
DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig(
|
||||
filtering=FilterConfig(
|
||||
rules=[
|
||||
FilterRule(
|
||||
match_type="all",
|
||||
tags=[TagInfo(key="event", value="Echolocation")],
|
||||
),
|
||||
FilterRule(
|
||||
match_type="exclude",
|
||||
tags=[
|
||||
TagInfo(key="event", value="Feeding"),
|
||||
TagInfo(key="event", value="Unknown"),
|
||||
TagInfo(key="event", value="Not Bat"),
|
||||
],
|
||||
),
|
||||
]
|
||||
),
|
||||
classes=DEFAULT_CLASSES_CONFIG,
|
||||
)
|
||||
|
||||
|
||||
|
@ -38,6 +38,8 @@ def train(
|
||||
config: Optional[TrainingConfig] = None,
|
||||
callbacks: Optional[List[Callback]] = None,
|
||||
model_path: Optional[data.PathLike] = None,
|
||||
train_workers: int = 0,
|
||||
val_workers: int = 0,
|
||||
**trainer_kwargs,
|
||||
) -> None:
|
||||
config = config or TrainingConfig()
|
||||
@ -85,6 +87,7 @@ def train(
|
||||
train_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=train_workers,
|
||||
)
|
||||
|
||||
val_dataloader = None
|
||||
@ -97,6 +100,7 @@ def train(
|
||||
val_dataset,
|
||||
batch_size=config.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=val_workers,
|
||||
)
|
||||
|
||||
trainer.fit(
|
||||
|
Loading…
Reference in New Issue
Block a user