From 110432bd408625540e75e8e895fe4eb1b8729141 Mon Sep 17 00:00:00 2001 From: mbsantiago Date: Wed, 29 Oct 2025 12:13:16 +0000 Subject: [PATCH] Added batch_size and num_workers to API --- src/batdetect2/api_v2.py | 4 ++++ src/batdetect2/data/conditions.py | 22 +++++++++++++++++----- src/batdetect2/inference/batch.py | 2 ++ src/batdetect2/inference/dataset.py | 5 ++++- 4 files changed, 27 insertions(+), 6 deletions(-) diff --git a/src/batdetect2/api_v2.py b/src/batdetect2/api_v2.py index 20f6234..7c9a55a 100644 --- a/src/batdetect2/api_v2.py +++ b/src/batdetect2/api_v2.py @@ -236,6 +236,8 @@ class BatDetect2API: def process_clips( self, clips: Sequence[data.Clip], + batch_size: Optional[int] = None, + num_workers: Optional[int] = None, ) -> List[BatDetect2Prediction]: return run_batch_inference( self.model, @@ -244,6 +246,8 @@ class BatDetect2API: audio_loader=self.audio_loader, preprocessor=self.preprocessor, config=self.config, + batch_size=batch_size, + num_workers=num_workers, ) def save_predictions( diff --git a/src/batdetect2/data/conditions.py b/src/batdetect2/data/conditions.py index 4ee8590..f9bca3d 100644 --- a/src/batdetect2/data/conditions.py +++ b/src/batdetect2/data/conditions.py @@ -25,7 +25,10 @@ class HasTag: def __call__( self, sound_event_annotation: data.SoundEventAnnotation ) -> bool: - return self.tag in sound_event_annotation.tags + return any( + self.tag.term.name == tag.term.name and self.tag.value == tag.value + for tag in sound_event_annotation.tags + ) @conditions.register(HasTagConfig) @staticmethod @@ -43,12 +46,14 @@ class HasAllTags: if not tags: raise ValueError("Need to specify at least one tag") - self.tags = set(tags) + self.tags = {(tag.term.name, tag.value) for tag in tags} def __call__( self, sound_event_annotation: data.SoundEventAnnotation ) -> bool: - return self.tags.issubset(sound_event_annotation.tags) + return self.tags.issubset( + {(tag.term.name, tag.value) for tag in sound_event_annotation.tags} + ) 
@conditions.register(HasAllTagsConfig) @staticmethod @@ -66,12 +71,19 @@ class HasAnyTag: if not tags: raise ValueError("Need to specify at least one tag") - self.tags = set(tags) + self.tags = {(tag.term.name, tag.value) for tag in tags} def __call__( self, sound_event_annotation: data.SoundEventAnnotation ) -> bool: - return bool(self.tags.intersection(sound_event_annotation.tags)) + return bool( + self.tags.intersection( + { + (tag.term.name, tag.value) + for tag in sound_event_annotation.tags + } + ) + ) @conditions.register(HasAnyTagConfig) @staticmethod diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py index 06e67b4..e3aee99 100644 --- a/src/batdetect2/inference/batch.py +++ b/src/batdetect2/inference/batch.py @@ -29,6 +29,7 @@ def run_batch_inference( preprocessor: Optional["PreprocessorProtocol"] = None, config: Optional["BatDetect2Config"] = None, num_workers: Optional[int] = None, + batch_size: Optional[int] = None, ) -> List[BatDetect2Prediction]: from batdetect2.config import BatDetect2Config @@ -48,6 +49,7 @@ def run_batch_inference( preprocessor=preprocessor, config=config.inference.loader, num_workers=num_workers, + batch_size=batch_size, ) module = InferenceModule(model) diff --git a/src/batdetect2/inference/dataset.py b/src/batdetect2/inference/dataset.py index 76d868d..b2fba56 100644 --- a/src/batdetect2/inference/dataset.py +++ b/src/batdetect2/inference/dataset.py @@ -70,6 +70,7 @@ def build_inference_loader( preprocessor: Optional[PreprocessorProtocol] = None, config: Optional[InferenceLoaderConfig] = None, num_workers: Optional[int] = None, + batch_size: Optional[int] = None, ) -> DataLoader[DatasetItem]: logger.info("Building inference data loader...") config = config or InferenceLoaderConfig() @@ -80,10 +81,12 @@ def build_inference_loader( preprocessor=preprocessor, ) + batch_size = batch_size or config.batch_size + num_workers = num_workers or config.num_workers return DataLoader( inference_dataset, - 
batch_size=config.batch_size, + batch_size=batch_size, shuffle=False, - num_workers=config.num_workers, + num_workers=num_workers, collate_fn=_collate_fn,