diff --git a/batdetect2/data/preprocessing.py b/batdetect2/data/preprocessing.py index 211b191..f114df7 100644 --- a/batdetect2/data/preprocessing.py +++ b/batdetect2/data/preprocessing.py @@ -1,6 +1,7 @@ """Module containing functions for preprocessing audio clips.""" -from typing import Optional +from typing import Optional, Union +from pathlib import Path import librosa import librosa.core.spectrum @@ -60,6 +61,46 @@ class PreprocessingConfig(BaseModel): spec_time_period: float = SPEC_TIME_PERIOD + @classmethod + def from_file( + cls, + path: Union[str, Path], + ) -> "PreprocessingConfig": + """Load configuration from a file. + + Parameters + ---------- + path + Path to the configuration file. + + Returns + ------- + PreprocessingConfig + The configuration loaded from the file. + + Raises + ------ + FileNotFoundError + If the configuration file does not exist. + pydantic.ValidationError + If the configuration file is invalid. + """ + path = Path(path) + + if not path.is_file(): + raise FileNotFoundError(f"Config file not found: {path}") + + return cls.model_validate_json(path.read_text()) + + def to_file(self, path: Union[str, Path]) -> None: + """Save configuration to a file.""" + path = Path(path) + + if not path.parent.exists(): + path.parent.mkdir(parents=True) + + path.write_text(self.model_dump_json()) + def preprocess_audio_clip( clip: data.Clip, @@ -105,6 +146,7 @@ def preprocess_audio_clip( spec, time=int(np.ceil(duration / config.spec_time_period)), frequency=config.spec_height, + dtype=np.float32, )