mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-09 16:59:33 +01:00
Added batch_size and num_workers to API
This commit is contained in:
parent
8366410332
commit
110432bd40
@ -236,6 +236,8 @@ class BatDetect2API:
|
||||
def process_clips(
|
||||
self,
|
||||
clips: Sequence[data.Clip],
|
||||
batch_size: Optional[int] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
@ -244,6 +246,8 @@ class BatDetect2API:
|
||||
audio_loader=self.audio_loader,
|
||||
preprocessor=self.preprocessor,
|
||||
config=self.config,
|
||||
batch_size=batch_size,
|
||||
num_workers=num_workers,
|
||||
)
|
||||
|
||||
def save_predictions(
|
||||
|
||||
@ -25,7 +25,10 @@ class HasTag:
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return self.tag in sound_event_annotation.tags
|
||||
return any(
|
||||
self.tag.term.name == tag.term.name and self.tag.value == tag.value
|
||||
for tag in sound_event_annotation.tags
|
||||
)
|
||||
|
||||
@conditions.register(HasTagConfig)
|
||||
@staticmethod
|
||||
@ -43,12 +46,14 @@ class HasAllTags:
|
||||
if not tags:
|
||||
raise ValueError("Need to specify at least one tag")
|
||||
|
||||
self.tags = set(tags)
|
||||
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return self.tags.issubset(sound_event_annotation.tags)
|
||||
return self.tags.issubset(
|
||||
{(tag.term.name, tag.value) for tag in sound_event_annotation.tags}
|
||||
)
|
||||
|
||||
@conditions.register(HasAllTagsConfig)
|
||||
@staticmethod
|
||||
@ -66,12 +71,19 @@ class HasAnyTag:
|
||||
if not tags:
|
||||
raise ValueError("Need to specify at least one tag")
|
||||
|
||||
self.tags = set(tags)
|
||||
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
||||
|
||||
def __call__(
|
||||
self, sound_event_annotation: data.SoundEventAnnotation
|
||||
) -> bool:
|
||||
return bool(self.tags.intersection(sound_event_annotation.tags))
|
||||
return bool(
|
||||
self.tags.intersection(
|
||||
{
|
||||
(tag.term.name, tag.value)
|
||||
for tag in sound_event_annotation.tags
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@conditions.register(HasAnyTagConfig)
|
||||
@staticmethod
|
||||
|
||||
@ -29,6 +29,7 @@ def run_batch_inference(
|
||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||
config: Optional["BatDetect2Config"] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> List[BatDetect2Prediction]:
|
||||
from batdetect2.config import BatDetect2Config
|
||||
|
||||
@ -48,6 +49,7 @@ def run_batch_inference(
|
||||
preprocessor=preprocessor,
|
||||
config=config.inference.loader,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
)
|
||||
|
||||
module = InferenceModule(model)
|
||||
|
||||
@ -70,6 +70,7 @@ def build_inference_loader(
|
||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||
config: Optional[InferenceLoaderConfig] = None,
|
||||
num_workers: Optional[int] = None,
|
||||
batch_size: Optional[int] = None,
|
||||
) -> DataLoader[DatasetItem]:
|
||||
logger.info("Building inference data loader...")
|
||||
config = config or InferenceLoaderConfig()
|
||||
@ -80,10 +81,12 @@ def build_inference_loader(
|
||||
preprocessor=preprocessor,
|
||||
)
|
||||
|
||||
batch_size = batch_size or config.batch_size
|
||||
|
||||
num_workers = num_workers or config.num_workers
|
||||
return DataLoader(
|
||||
inference_dataset,
|
||||
batch_size=config.batch_size,
|
||||
batch_size=batch_size,
|
||||
shuffle=False,
|
||||
num_workers=config.num_workers,
|
||||
collate_fn=_collate_fn,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user