From acda71ea45e92841cd7164b9378c88801a04fca1 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Fri, 4 Apr 2025 10:58:54 +0100 Subject: [PATCH] Formatting --- batdetect2/evaluate/evaluate.py | 2 +- batdetect2/modules.py | 4 ++-- batdetect2/train/__init__.py | 17 +++++++++++++++-- batdetect2/train/dataset.py | 6 ++++-- batdetect2/train/preprocess.py | 4 ++-- batdetect2/train/targets.py | 4 ++-- 6 files changed, 26 insertions(+), 11 deletions(-) diff --git a/batdetect2/evaluate/evaluate.py b/batdetect2/evaluate/evaluate.py index dd0cbfe..5e403b8 100755 --- a/batdetect2/evaluate/evaluate.py +++ b/batdetect2/evaluate/evaluate.py @@ -6,7 +6,7 @@ from sklearn.metrics import auc, roc_curve from soundevent import data from soundevent.evaluation import match_geometries -from batdetect2.train.targets import build_encoder, get_class_names +from batdetect2.train.targets import build_target_encoder, get_class_names def match_predictions_and_annotations( diff --git a/batdetect2/modules.py b/batdetect2/modules.py index ae9fd68..d97a8ef 100644 --- a/batdetect2/modules.py +++ b/batdetect2/modules.py @@ -29,7 +29,7 @@ from batdetect2.train.losses import compute_loss from batdetect2.train.targets import ( TargetConfig, build_decoder, - build_encoder, + build_target_encoder, get_class_names, ) @@ -78,7 +78,7 @@ class DetectorModel(L.LightningModule): # Training targets self.class_names = get_class_names(self.config.targets.classes) - self.encoder = build_encoder( + self.encoder = build_target_encoder( self.config.targets.classes, replacement_rules=self.config.targets.replace, ) diff --git a/batdetect2/train/__init__.py b/batdetect2/train/__init__.py index f364c30..4be16c7 100644 --- a/batdetect2/train/__init__.py +++ b/batdetect2/train/__init__.py @@ -15,10 +15,19 @@ from batdetect2.train.dataset import ( LabeledDataset, SubclipConfig, TrainExample, + get_preprocessed_files, ) from batdetect2.train.labels import LabelConfig, load_label_config -from batdetect2.train.preprocess import preprocess_annotations -from batdetect2.train.targets import TargetConfig, load_target_config +from batdetect2.train.preprocess import ( + generate_train_example, + preprocess_annotations, +) +from batdetect2.train.targets import ( + TagInfo, + TargetConfig, + build_target_encoder, + load_target_config, +) from batdetect2.train.train import TrainerConfig, load_trainer_config, train __all__ = [ @@ -26,12 +35,16 @@ __all__ = [ "LabelConfig", "LabeledDataset", "SubclipConfig", + "TagInfo", "TargetConfig", "TrainExample", "TrainerConfig", "TrainingConfig", "add_echo", "augment_example", + "build_target_encoder", + "generate_train_example", + "get_preprocessed_files", "load_agumentation_config", "load_label_config", "load_target_config", diff --git a/batdetect2/train/dataset.py b/batdetect2/train/dataset.py index a5c754d..49205d9 100644 --- a/batdetect2/train/dataset.py +++ b/batdetect2/train/dataset.py @@ -101,7 +101,7 @@ class LabeledDataset(Dataset): preprocessing: Optional[PreprocessingConfig] = None, ): return cls( - get_files(directory, extension), + get_preprocessed_files(directory, extension), subclip=subclip, augmentation=augmentation, preprocessing=preprocessing, @@ -143,5 +143,7 @@ class LabeledDataset(Dataset): return adjust_width(tensor, width) -def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]: +def get_preprocessed_files( + directory: PathLike, extension: str = ".nc" +) -> Sequence[Path]: return list(Path(directory).glob(f"*{extension}")) diff --git a/batdetect2/train/preprocess.py b/batdetect2/train/preprocess.py index 92ae3ca..8b9da80 100644 --- a/batdetect2/train/preprocess.py +++ b/batdetect2/train/preprocess.py @@ -20,7 +20,7 @@ from batdetect2.preprocess import ( from batdetect2.train.labels import LabelConfig, generate_heatmaps from batdetect2.train.targets import ( TargetConfig, - build_encoder, + build_target_encoder, build_sound_event_filter, get_class_names, ) @@ -76,7 +76,7 @@ def generate_train_example( event for event in clip_annotation.sound_events if filter_fn(event) ] - encoder = build_encoder( + encoder = build_target_encoder( config.target.classes, replacement_rules=config.target.replace, ) diff --git a/batdetect2/train/targets.py b/batdetect2/train/targets.py index d1d1143..0b622fd 100644 --- a/batdetect2/train/targets.py +++ b/batdetect2/train/targets.py @@ -12,7 +12,7 @@ from batdetect2.terms import TagInfo, get_tag_from_info __all__ = [ "TargetConfig", "load_target_config", - "build_encoder", + "build_target_encoder", "build_decoder", "filter_sound_event", ] @@ -91,7 +91,7 @@ def build_replacer( return replacer -def build_encoder( +def build_target_encoder( classes: List[TagInfo], replacement_rules: Optional[List[ReplaceConfig]] = None, ) -> Callable[[Iterable[data.Tag]], Optional[str]]: