Added batch_size and num_workers to API

mbsantiago 2025-10-29 12:13:16 +00:00
parent 8366410332
commit 110432bd40
4 changed files with 27 additions and 6 deletions

View File

@@ -236,6 +236,8 @@ class BatDetect2API:
     def process_clips(
         self,
         clips: Sequence[data.Clip],
+        batch_size: Optional[int] = None,
+        num_workers: Optional[int] = None,
     ) -> List[BatDetect2Prediction]:
         return run_batch_inference(
             self.model,
@@ -244,6 +246,8 @@
             audio_loader=self.audio_loader,
             preprocessor=self.preprocessor,
             config=self.config,
+            batch_size=batch_size,
+            num_workers=num_workers,
         )
 
     def save_predictions(
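
For reference, a rough usage sketch of the extended method. The api object and the clips sequence are assumed to exist already (their construction is not part of this diff), and the literal values are arbitrary:

    predictions = api.process_clips(
        clips,
        batch_size=16,   # overrides the configured batch size for this call
        num_workers=4,   # overrides the configured number of loader workers
    )
    # Omitting both keeps the previous behaviour: the values configured in
    # config.inference.loader are used.
    predictions = api.process_clips(clips)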

View File

@@ -25,7 +25,10 @@ class HasTag:
     def __call__(
         self, sound_event_annotation: data.SoundEventAnnotation
     ) -> bool:
-        return self.tag in sound_event_annotation.tags
+        return any(
+            self.tag.term.name == tag.term.name and self.tag.value == tag.value
+            for tag in sound_event_annotation.tags
+        )
 
     @conditions.register(HasTagConfig)
     @staticmethod
@@ -43,12 +46,14 @@ class HasAllTags:
         if not tags:
             raise ValueError("Need to specify at least one tag")
-        self.tags = set(tags)
+        self.tags = {(tag.term.name, tag.value) for tag in tags}
 
     def __call__(
         self, sound_event_annotation: data.SoundEventAnnotation
     ) -> bool:
-        return self.tags.issubset(sound_event_annotation.tags)
+        return self.tags.issubset(
+            {(tag.term.name, tag.value) for tag in sound_event_annotation.tags}
+        )
 
     @conditions.register(HasAllTagsConfig)
     @staticmethod
@@ -66,12 +71,19 @@ class HasAnyTag:
         if not tags:
             raise ValueError("Need to specify at least one tag")
-        self.tags = set(tags)
+        self.tags = {(tag.term.name, tag.value) for tag in tags}
 
     def __call__(
         self, sound_event_annotation: data.SoundEventAnnotation
     ) -> bool:
-        return bool(self.tags.intersection(sound_event_annotation.tags))
+        return bool(
+            self.tags.intersection(
+                {
+                    (tag.term.name, tag.value)
+                    for tag in sound_event_annotation.tags
+                }
+            )
+        )
 
     @conditions.register(HasAnyTagConfig)
     @staticmethod
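
All three conditions now match tags by term name and value instead of testing membership of the tag objects themselves. A minimal standalone sketch of that matching logic, with plain (term_name, value) tuples standing in for data.Tag objects and made-up values; the real classes build these tuples from tag.term.name and tag.value:

    target = ("species", "Myotis myotis")
    wanted = {target, ("call_type", "echolocation")}
    present = {("species", "Myotis myotis"), ("quality", "good")}

    has_tag = any(pair == target for pair in present)  # HasTag: True, name and value agree
    has_all = wanted.issubset(present)                 # HasAllTags: False, call_type is missing
    has_any = bool(wanted.intersection(present))       # HasAnyTag: True, the species pair matches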

View File

@@ -29,6 +29,7 @@ def run_batch_inference(
     preprocessor: Optional["PreprocessorProtocol"] = None,
     config: Optional["BatDetect2Config"] = None,
     num_workers: Optional[int] = None,
+    batch_size: Optional[int] = None,
 ) -> List[BatDetect2Prediction]:
     from batdetect2.config import BatDetect2Config
@@ -48,6 +49,7 @@
         preprocessor=preprocessor,
         config=config.inference.loader,
         num_workers=num_workers,
+        batch_size=batch_size,
     )
 
     module = InferenceModule(model)
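
batch_size now travels alongside num_workers from run_batch_inference into build_inference_loader. A hedged sketch of calling the function directly; the model, clips, audio_loader and preprocessor objects, and the exact order of the two positional arguments, are assumptions rather than something shown in this diff:

    predictions = run_batch_inference(
        model,                      # assumed: a loaded BatDetect2 model
        clips,                      # assumed: a Sequence[data.Clip]
        audio_loader=audio_loader,
        preprocessor=preprocessor,
        batch_size=64,              # forwarded to build_inference_loader
        num_workers=8,              # forwarded to build_inference_loader
    )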

View File

@@ -70,6 +70,7 @@ def build_inference_loader(
     preprocessor: Optional[PreprocessorProtocol] = None,
     config: Optional[InferenceLoaderConfig] = None,
     num_workers: Optional[int] = None,
+    batch_size: Optional[int] = None,
 ) -> DataLoader[DatasetItem]:
     logger.info("Building inference data loader...")
     config = config or InferenceLoaderConfig()
@@ -80,10 +81,12 @@
         preprocessor=preprocessor,
     )
+    batch_size = batch_size or config.batch_size
+    num_workers = num_workers or config.num_workers
     return DataLoader(
         inference_dataset,
-        batch_size=config.batch_size,
+        batch_size=batch_size,
         shuffle=False,
         num_workers=config.num_workers,
         collate_fn=_collate_fn,
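
In build_inference_loader an explicit argument now takes precedence over the config (batch_size = batch_size or config.batch_size); in the lines above only batch_size is swapped into the DataLoader call, while num_workers is still read from the config. A small self-contained illustration of the fallback pattern (pick is a hypothetical helper, not part of the codebase):

    def pick(arg, config_value):
        return arg or config_value

    assert pick(None, 8) == 8  # no override: the config value is used
    assert pick(16, 8) == 16   # an explicit override wins
    assert pick(0, 8) == 8     # caveat: 0 is falsy, so it also falls back to the config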