diff --git a/batdetect2/cli/train.py b/batdetect2/cli/train.py index cf6954a..b0477c0 100644 --- a/batdetect2/cli/train.py +++ b/batdetect2/cli/train.py @@ -44,7 +44,7 @@ DEFAULT_CONFIG_FILE = Path("config.yaml") default=DEFAULT_CONFIG_FILE, ) @click.option( - "--train-field", + "--train-config-field", type=str, default="train", ) @@ -108,6 +108,16 @@ DEFAULT_CONFIG_FILE = Path("config.yaml") type=str, default="model", ) +@click.option( + "--train-workers", + type=int, + default=0, +) +@click.option( + "--val-workers", + type=int, + default=0, +) def train_command( train_examples: Path, val_examples: Optional[Path] = None, @@ -122,6 +132,8 @@ def train_command( postprocess_config_field: str = "postprocess", model_config: Path = DEFAULT_CONFIG_FILE, model_config_field: str = "model", + train_workers: int = 0, + val_workers: int = 0, ): logger.info("Starting training!") @@ -176,15 +188,27 @@ def train_command( targets=targets, config=postprocess_config_loaded, ) + logger.debug( + "Loaded postprocessor from file {path}", + path=train_config, + ) except IOError: + logger.debug( + "Could not load postprocessor config from file. Using default" + ) postprocessor = build_postprocessor(targets=targets) try: train_config_loaded = load_train_config( path=train_config, field=train_config_field ) + logger.debug( + "Loaded training config from file {path}", + path=train_config, + ) except IOError: train_config_loaded = TrainingConfig() + logger.debug("Could not load training config from file. Using default") train_files = list_preprocessed_files(train_examples) @@ -212,4 +236,6 @@ def train_command( ] ) ], + train_workers=train_workers, + val_workers=val_workers, ) diff --git a/batdetect2/configs.py b/batdetect2/configs.py index 83a982a..e5549cd 100644 --- a/batdetect2/configs.py +++ b/batdetect2/configs.py @@ -88,7 +88,7 @@ def get_object_field(obj: dict, current_key: str) -> Any: KeyError: 'x' """ if "." not in current_key: - return obj[current_key] + return obj.get(current_key, {}) current_key, rest = current_key.split(".", 1) subobj = obj[current_key] diff --git a/batdetect2/targets/__init__.py b/batdetect2/targets/__init__.py index 452907a..9a5da1e 100644 --- a/batdetect2/targets/__init__.py +++ b/batdetect2/targets/__init__.py @@ -24,6 +24,7 @@ object is via the `build_targets` or `load_targets` functions. from typing import List, Optional import numpy as np +from pydantic import Field from soundevent import data from batdetect2.configs import BaseConfig, load_config @@ -157,7 +158,9 @@ class TargetConfig(BaseConfig): filtering: Optional[FilterConfig] = None transforms: Optional[TransformConfig] = None - classes: ClassesConfig + classes: ClassesConfig = Field( + default_factory=lambda: DEFAULT_CLASSES_CONFIG + ) roi: Optional[ROIConfig] = None @@ -438,6 +441,84 @@ class Targets(TargetProtocol): return self._roi_mapper.recover_roi(pos, dims) +DEFAULT_CLASSES = [ + TargetClass( + tags=[TagInfo(value="Myotis mystacinus")], + name="myomys", + ), + TargetClass( + tags=[TagInfo(value="Myotis alcathoe")], + name="myoalc", + ), + TargetClass( + tags=[TagInfo(value="Eptesicus serotinus")], + name="eptser", + ), + TargetClass( + tags=[TagInfo(value="Pipistrellus nathusii")], + name="pipnat", + ), + TargetClass( + tags=[TagInfo(value="Barbastellus barbastellus")], + name="barbar", + ), + TargetClass( + tags=[TagInfo(value="Myotis nattereri")], + name="myonat", + ), + TargetClass( + tags=[TagInfo(value="Myotis daubentonii")], + name="myodau", + ), + TargetClass( + tags=[TagInfo(value="Myotis brandtii")], + name="myobra", + ), + TargetClass( + tags=[TagInfo(value="Pipistrellus pipistrellus")], + name="pippip", + ), + TargetClass( + tags=[TagInfo(value="Myotis bechsteinii")], + name="myobec", + ), + TargetClass( + tags=[TagInfo(value="Pipistrellus pygmaeus")], + name="pippyg", + ), + TargetClass( + tags=[TagInfo(value="Rhinolophus hipposideros")], + name="rhihip", + ), + TargetClass( + tags=[TagInfo(value="Nyctalus leisleri")], + name="nyclei", + ), + TargetClass( + tags=[TagInfo(value="Rhinolophus ferrumequinum")], + name="rhifer", + ), + TargetClass( + tags=[TagInfo(value="Plecotus auritus")], + name="pleaur", + ), + TargetClass( + tags=[TagInfo(value="Nyctalus noctula")], + name="nycnoc", + ), + TargetClass( + tags=[TagInfo(value="Plecotus austriacus")], + name="pleaus", + ), +] + + +DEFAULT_CLASSES_CONFIG: ClassesConfig = ClassesConfig( + classes=DEFAULT_CLASSES, + generic_class=[TagInfo(value="Bat")], +) + + DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( filtering=FilterConfig( rules=[ @@ -455,79 +536,7 @@ DEFAULT_TARGET_CONFIG: TargetConfig = TargetConfig( ), ] ), - classes=ClassesConfig( - classes=[ - TargetClass( - tags=[TagInfo(value="Myotis mystacinus")], - name="myomys", - ), - TargetClass( - tags=[TagInfo(value="Myotis alcathoe")], - name="myoalc", - ), - TargetClass( - tags=[TagInfo(value="Eptesicus serotinus")], - name="eptser", - ), - TargetClass( - tags=[TagInfo(value="Pipistrellus nathusii")], - name="pipnat", - ), - TargetClass( - tags=[TagInfo(value="Barbastellus barbastellus")], - name="barbar", - ), - TargetClass( - tags=[TagInfo(value="Myotis nattereri")], - name="myonat", - ), - TargetClass( - tags=[TagInfo(value="Myotis daubentonii")], - name="myodau", - ), - TargetClass( - tags=[TagInfo(value="Myotis brandtii")], - name="myobra", - ), - TargetClass( - tags=[TagInfo(value="Pipistrellus pipistrellus")], - name="pippip", - ), - TargetClass( - tags=[TagInfo(value="Myotis bechsteinii")], - name="myobec", - ), - TargetClass( - tags=[TagInfo(value="Pipistrellus pygmaeus")], - name="pippyg", - ), - TargetClass( - tags=[TagInfo(value="Rhinolophus hipposideros")], - name="rhihip", - ), - TargetClass( - tags=[TagInfo(value="Nyctalus leisleri")], - name="nyclei", - ), - TargetClass( - tags=[TagInfo(value="Rhinolophus ferrumequinum")], - name="rhifer", - ), - TargetClass( - tags=[TagInfo(value="Plecotus auritus")], - name="pleaur", - ), - TargetClass( - tags=[TagInfo(value="Nyctalus noctula")], - name="nycnoc", - ), - TargetClass( - tags=[TagInfo(value="Plecotus austriacus")], - name="pleaus", - ), - ], - generic_class=[TagInfo(value="Bat")], - ), + classes=DEFAULT_CLASSES_CONFIG, ) diff --git a/batdetect2/train/train.py b/batdetect2/train/train.py index c0c29a9..04d0b23 100644 --- a/batdetect2/train/train.py +++ b/batdetect2/train/train.py @@ -38,6 +38,8 @@ def train( config: Optional[TrainingConfig] = None, callbacks: Optional[List[Callback]] = None, model_path: Optional[data.PathLike] = None, + train_workers: int = 0, + val_workers: int = 0, **trainer_kwargs, ) -> None: config = config or TrainingConfig() @@ -85,6 +87,7 @@ def train( train_dataset, batch_size=config.batch_size, shuffle=True, + num_workers=train_workers, ) val_dataloader = None @@ -97,6 +100,7 @@ def train( val_dataset, batch_size=config.batch_size, shuffle=False, + num_workers=val_workers, ) trainer.fit( diff --git a/example_conf.yaml b/config.yaml similarity index 100% rename from example_conf.yaml rename to config.yaml