mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-01-10 17:19:34 +01:00
Added batch_size and num_workers to API
This commit is contained in:
parent
8366410332
commit
110432bd40
@ -236,6 +236,8 @@ class BatDetect2API:
|
|||||||
def process_clips(
|
def process_clips(
|
||||||
self,
|
self,
|
||||||
clips: Sequence[data.Clip],
|
clips: Sequence[data.Clip],
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
|
num_workers: Optional[int] = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[BatDetect2Prediction]:
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
self.model,
|
self.model,
|
||||||
@ -244,6 +246,8 @@ class BatDetect2API:
|
|||||||
audio_loader=self.audio_loader,
|
audio_loader=self.audio_loader,
|
||||||
preprocessor=self.preprocessor,
|
preprocessor=self.preprocessor,
|
||||||
config=self.config,
|
config=self.config,
|
||||||
|
batch_size=batch_size,
|
||||||
|
num_workers=num_workers,
|
||||||
)
|
)
|
||||||
|
|
||||||
def save_predictions(
|
def save_predictions(
|
||||||
|
|||||||
@ -25,7 +25,10 @@ class HasTag:
|
|||||||
def __call__(
|
def __call__(
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return self.tag in sound_event_annotation.tags
|
return any(
|
||||||
|
self.tag.term.name == tag.term.name and self.tag.value == tag.value
|
||||||
|
for tag in sound_event_annotation.tags
|
||||||
|
)
|
||||||
|
|
||||||
@conditions.register(HasTagConfig)
|
@conditions.register(HasTagConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -43,12 +46,14 @@ class HasAllTags:
|
|||||||
if not tags:
|
if not tags:
|
||||||
raise ValueError("Need to specify at least one tag")
|
raise ValueError("Need to specify at least one tag")
|
||||||
|
|
||||||
self.tags = set(tags)
|
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return self.tags.issubset(sound_event_annotation.tags)
|
return self.tags.issubset(
|
||||||
|
{(tag.term.name, tag.value) for tag in sound_event_annotation.tags}
|
||||||
|
)
|
||||||
|
|
||||||
@conditions.register(HasAllTagsConfig)
|
@conditions.register(HasAllTagsConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -66,12 +71,19 @@ class HasAnyTag:
|
|||||||
if not tags:
|
if not tags:
|
||||||
raise ValueError("Need to specify at least one tag")
|
raise ValueError("Need to specify at least one tag")
|
||||||
|
|
||||||
self.tags = set(tags)
|
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, sound_event_annotation: data.SoundEventAnnotation
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
) -> bool:
|
) -> bool:
|
||||||
return bool(self.tags.intersection(sound_event_annotation.tags))
|
return bool(
|
||||||
|
self.tags.intersection(
|
||||||
|
{
|
||||||
|
(tag.term.name, tag.value)
|
||||||
|
for tag in sound_event_annotation.tags
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
@conditions.register(HasAnyTagConfig)
|
@conditions.register(HasAnyTagConfig)
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|||||||
@ -29,6 +29,7 @@ def run_batch_inference(
|
|||||||
preprocessor: Optional["PreprocessorProtocol"] = None,
|
preprocessor: Optional["PreprocessorProtocol"] = None,
|
||||||
config: Optional["BatDetect2Config"] = None,
|
config: Optional["BatDetect2Config"] = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
) -> List[BatDetect2Prediction]:
|
) -> List[BatDetect2Prediction]:
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
|
||||||
@ -48,6 +49,7 @@ def run_batch_inference(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
config=config.inference.loader,
|
config=config.inference.loader,
|
||||||
num_workers=num_workers,
|
num_workers=num_workers,
|
||||||
|
batch_size=batch_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
module = InferenceModule(model)
|
module = InferenceModule(model)
|
||||||
|
|||||||
@ -70,6 +70,7 @@ def build_inference_loader(
|
|||||||
preprocessor: Optional[PreprocessorProtocol] = None,
|
preprocessor: Optional[PreprocessorProtocol] = None,
|
||||||
config: Optional[InferenceLoaderConfig] = None,
|
config: Optional[InferenceLoaderConfig] = None,
|
||||||
num_workers: Optional[int] = None,
|
num_workers: Optional[int] = None,
|
||||||
|
batch_size: Optional[int] = None,
|
||||||
) -> DataLoader[DatasetItem]:
|
) -> DataLoader[DatasetItem]:
|
||||||
logger.info("Building inference data loader...")
|
logger.info("Building inference data loader...")
|
||||||
config = config or InferenceLoaderConfig()
|
config = config or InferenceLoaderConfig()
|
||||||
@ -80,10 +81,12 @@ def build_inference_loader(
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
batch_size = batch_size or config.batch_size
|
||||||
|
|
||||||
num_workers = num_workers or config.num_workers
|
num_workers = num_workers or config.num_workers
|
||||||
return DataLoader(
|
return DataLoader(
|
||||||
inference_dataset,
|
inference_dataset,
|
||||||
batch_size=config.batch_size,
|
batch_size=batch_size,
|
||||||
shuffle=False,
|
shuffle=False,
|
||||||
num_workers=config.num_workers,
|
num_workers=config.num_workers,
|
||||||
collate_fn=_collate_fn,
|
collate_fn=_collate_fn,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user