From 981e37c346301a76a80d7a53e706cad292b42a9f Mon Sep 17 00:00:00 2001
From: mbsantiago
Date: Tue, 30 Sep 2025 13:22:03 +0100
Subject: [PATCH] Writing batch inference code

---
 src/batdetect2/{api/__init__.py => api.py} |   0
 src/batdetect2/{api/base.py => api_v2.py}  | 105 +++++++++++++++---
 src/batdetect2/audio/clips.py              |  73 +++++++++++--
 src/batdetect2/config.py                   |   2 +
 src/batdetect2/evaluate/evaluate.py        |   5 +-
 src/batdetect2/evaluate/lightning.py       |   8 +-
 src/batdetect2/inference/__init__.py       |  10 ++
 src/batdetect2/inference/batch.py          |  88 +++++++++++++++
 src/batdetect2/inference/clips.py          |  75 +++++++++++++
 src/batdetect2/inference/config.py         |  21 ++++
 src/batdetect2/inference/dataset.py        | 120 +++++++++++++++++++++
 src/batdetect2/inference/lightning.py      |  52 +++++++++
 src/batdetect2/train/__init__.py           |  10 +-
 src/batdetect2/typing/postprocess.py       |   4 +-
 src/batdetect2/typing/train.py             |   4 +-
 15 files changed, 546 insertions(+), 31 deletions(-)
 rename src/batdetect2/{api/__init__.py => api.py} (100%)
 rename src/batdetect2/{api/base.py => api_v2.py} (63%)
 create mode 100644 src/batdetect2/inference/batch.py
 create mode 100644 src/batdetect2/inference/clips.py
 create mode 100644 src/batdetect2/inference/config.py
 create mode 100644 src/batdetect2/inference/dataset.py
 create mode 100644 src/batdetect2/inference/lightning.py

diff --git a/src/batdetect2/api/__init__.py b/src/batdetect2/api.py
similarity index 100%
rename from src/batdetect2/api/__init__.py
rename to src/batdetect2/api.py
diff --git a/src/batdetect2/api/base.py b/src/batdetect2/api_v2.py
similarity index 63%
rename from src/batdetect2/api/base.py
rename to src/batdetect2/api_v2.py
index 11992a8..3b31bb7 100644
--- a/src/batdetect2/api/base.py
+++ b/src/batdetect2/api_v2.py
@@ -1,27 +1,29 @@
 from pathlib import Path
 from typing import List, Optional, Sequence
 
+import numpy as np
 import torch
 from soundevent import data
+from soundevent.audio.files import get_audio_files
 
 from batdetect2.audio import build_audio_loader
 from batdetect2.config import BatDetect2Config
 from batdetect2.evaluate import build_evaluator, evaluate
+from batdetect2.inference import process_file_list, run_batch_inference
 from batdetect2.models import Model, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.postprocess.decoding import to_raw_predictions
+from batdetect2.postprocess import build_postprocessor, to_raw_predictions
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets.targets import build_targets
-from batdetect2.train import train
-from batdetect2.train.lightning import load_model_from_checkpoint
+from batdetect2.targets import build_targets
+from batdetect2.train import load_model_from_checkpoint, train
 from batdetect2.typing import (
     AudioLoader,
+    BatDetect2Prediction,
     EvaluatorProtocol,
     PostprocessorProtocol,
     PreprocessorProtocol,
+    RawPrediction,
     TargetProtocol,
 )
-from batdetect2.typing.postprocess import RawPrediction
 
 
 class BatDetect2API:
@@ -95,17 +97,94 @@ class BatDetect2API:
             run_name=run_name,
         )
 
+    def load_audio(self, path: data.PathLike) -> np.ndarray:
+        return self.audio_loader.load_file(path)
+
+    def load_clip(self, clip: data.Clip) -> np.ndarray:
+        return self.audio_loader.load_clip(clip)
+
+    def generate_spectrogram(
+        self,
+        audio: np.ndarray,
+    ) -> torch.Tensor:
+        tensor = torch.tensor(audio).unsqueeze(0)
+        return self.preprocessor(tensor)
+
+    def process_file(self, audio_file: data.PathLike) -> BatDetect2Prediction:
+        recording = data.Recording.from_file(audio_file, compute_hash=False)
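+        # Hashing is skipped (compute_hash=False): detection does not use
+        # the file hash, and hashing long recordings is slow.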
+        wav = self.audio_loader.load_recording(recording)
+        detections = self.process_audio(wav)
+        return BatDetect2Prediction(
+            clip=data.Clip(
+                uuid=recording.uuid,
+                recording=recording,
+                start_time=0,
+                end_time=recording.duration,
+            ),
+            predictions=detections,
+        )
+
+    def process_audio(
+        self,
+        audio: np.ndarray,
+    ) -> List[RawPrediction]:
+        spec = self.generate_spectrogram(audio)
+        return self.process_spectrogram(spec)
+
     def process_spectrogram(
         self,
         spec: torch.Tensor,
-        start_times: Optional[Sequence[float]] = None,
-    ) -> List[List[RawPrediction]]:
+        start_time: float = 0,
+    ) -> List[RawPrediction]:
+        if spec.ndim == 4 and spec.shape[0] > 1:
+            raise ValueError("Batched spectrograms not supported.")
+
+        if spec.ndim == 3:
+            spec = spec.unsqueeze(0)
+
         outputs = self.model.detector(spec)
-        clip_detections = self.postprocessor(outputs, start_times=start_times)
-        return [
-            to_raw_predictions(clip_dets.numpy(), self.targets)
-            for clip_dets in clip_detections
-        ]
+
+        detections = self.model.postprocessor(
+            outputs,
+            start_times=[start_time],
+        )[0]
+
+        return to_raw_predictions(detections.numpy(), targets=self.targets)
+
+    def process_directory(
+        self,
+        audio_dir: data.PathLike,
+    ) -> List[BatDetect2Prediction]:
+        files = list(get_audio_files(audio_dir))
+        return self.process_files(files)
+
+    def process_files(
+        self,
+        audio_files: Sequence[data.PathLike],
+        num_workers: Optional[int] = None,
+    ) -> List[BatDetect2Prediction]:
+        return process_file_list(
+            self.model,
+            audio_files,
+            config=self.config,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            num_workers=num_workers,
+        )
+
+    def process_clips(
+        self,
+        clips: Sequence[data.Clip],
+    ) -> List[BatDetect2Prediction]:
+        return run_batch_inference(
+            self.model,
+            clips,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            config=self.config,
+        )
 
     @classmethod
     def from_config(cls, config: BatDetect2Config):
diff --git a/src/batdetect2/audio/clips.py b/src/batdetect2/audio/clips.py
index 86ddf18..1a2a41e 100644
--- a/src/batdetect2/audio/clips.py
+++ b/src/batdetect2/audio/clips.py
@@ -48,12 +48,25 @@ class RandomClip:
         self,
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation:
-        return get_subclip_annotation(
+        subclip = self.get_subclip(clip_annotation.clip)
+        sound_events = select_sound_event_annotations(
             clip_annotation,
+            subclip,
+            min_overlap=self.min_sound_event_overlap,
+        )
+        return clip_annotation.model_copy(
+            update=dict(
+                clip=subclip,
+                sound_events=sound_events,
+            )
+        )
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
+        return select_random_subclip(
+            clip,
             random=self.random,
             duration=self.duration,
             max_empty=self.max_empty,
-            min_sound_event_overlap=self.min_sound_event_overlap,
         )
 
     @clipper_registry.register(RandomClipConfig)
@@ -75,7 +88,7 @@ def get_subclip_annotation(
 ) -> data.ClipAnnotation:
     clip = clip_annotation.clip
 
-    subclip = select_subclip(
+    subclip = select_random_subclip(
         clip,
         random=random,
         duration=duration,
@@ -96,7 +109,7 @@
     )
 
 
-def select_subclip(
+def select_random_subclip(
     clip: data.Clip,
     random: bool = True,
     duration: float = 0.5,
@@ -170,6 +183,10 @@ class PaddedClip:
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation:
         clip = clip_annotation.clip
+        clip = self.get_subclip(clip)
+        return clip_annotation.model_copy(update=dict(clip=clip))
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
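+        # Return a copy of the clip padded out to the target duration
+        # computed below from the configured chunk size.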
         duration = clip.duration
         target_duration = float(
@@ -180,7 +197,7 @@ class PaddedClip:
                 end_time=clip.start_time + target_duration,
             )
         )
-        return clip_annotation.model_copy(update=dict(clip=clip))
+        return clip
 
     @clipper_registry.register(PaddedClipConfig)
     @staticmethod
@@ -188,8 +205,52 @@ class PaddedClip:
         return PaddedClip(chunk_size=config.chunk_size)
 
 
+class FixedDurationClipConfig(BaseConfig):
+    name: Literal["fixed_duration"] = "fixed_duration"
+    duration: float = DEFAULT_TRAIN_CLIP_DURATION
+
+
+class FixedDurationClip:
+    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.duration = duration
+
+    def __call__(
+        self,
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation:
+        clip = self.get_subclip(clip_annotation.clip)
+        sound_events = select_sound_event_annotations(
+            clip_annotation,
+            clip,
+            min_overlap=0,
+        )
+        return clip_annotation.model_copy(
+            update=dict(
+                clip=clip,
+                sound_events=sound_events,
+            )
+        )
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
+        return clip.model_copy(
+            update=dict(
+                end_time=clip.start_time + self.duration,
+            )
+        )
+
+    @clipper_registry.register(FixedDurationClipConfig)
+    @staticmethod
+    def from_config(config: FixedDurationClipConfig):
+        return FixedDurationClip(duration=config.duration)
+
+
 ClipConfig = Annotated[
-    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
+    Union[
+        RandomClipConfig,
+        PaddedClipConfig,
+        FixedDurationClipConfig,
+    ],
+    Field(discriminator="name"),
 ]
diff --git a/src/batdetect2/config.py b/src/batdetect2/config.py
index bffd563..8fef59a 100644
--- a/src/batdetect2/config.py
+++ b/src/batdetect2/config.py
@@ -7,6 +7,7 @@ from batdetect2.audio import AudioConfig
 from batdetect2.core import BaseConfig
 from batdetect2.core.configs import load_config
 from batdetect2.evaluate.config import EvaluationConfig
+from batdetect2.inference.config import InferenceConfig
 from batdetect2.models.config import BackboneConfig
 from batdetect2.postprocess.config import PostprocessConfig
 from batdetect2.preprocess.config import PreprocessingConfig
@@ -31,6 +32,7 @@ class BatDetect2Config(BaseConfig):
     postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
     audio: AudioConfig = Field(default_factory=AudioConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
+    inference: InferenceConfig = Field(default_factory=InferenceConfig)
 
 
 def load_full_config(
diff --git a/src/batdetect2/evaluate/evaluate.py b/src/batdetect2/evaluate/evaluate.py
index 2fd723f..cbc0cd1 100644
--- a/src/batdetect2/evaluate/evaluate.py
+++ b/src/batdetect2/evaluate/evaluate.py
@@ -40,13 +40,14 @@ def evaluate(
 
     config = config or BatDetect2Config()
 
-    audio_loader = audio_loader or build_audio_loader()
+    audio_loader = audio_loader or build_audio_loader(config=config.audio)
 
     preprocessor = preprocessor or build_preprocessor(
+        config=config.preprocess,
         input_samplerate=audio_loader.samplerate,
     )
 
-    targets = targets or build_targets()
+    targets = targets or build_targets(config=config.targets)
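+
+    # The audio loader, preprocessor, and targets above are all built from
+    # the same BatDetect2Config, keeping evaluation consistent with the
+    # rest of the pipeline.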
 
     loader = build_test_loader(
         test_annotations,
diff --git a/src/batdetect2/evaluate/lightning.py b/src/batdetect2/evaluate/lightning.py
index ccca917..6a02d5f 100644
--- a/src/batdetect2/evaluate/lightning.py
+++ b/src/batdetect2/evaluate/lightning.py
@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Any
 
 from lightning import LightningModule
 from torch.utils.data import DataLoader
@@ -7,7 +7,7 @@
 from batdetect2.evaluate.dataset import TestDataset, TestExample
 from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
-from batdetect2.typing import ClipMatches, EvaluatorProtocol
+from batdetect2.typing import EvaluatorProtocol
 
 
 class EvaluationModule(LightningModule):
@@ -54,7 +54,7 @@ class EvaluationModule(LightningModule):
         self.log_metrics(self.clip_evaluations)
         self.plot_examples(self.clip_evaluations)
 
-    def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
+    def plot_examples(self, evaluated_clips: Any):
         plotter = get_image_logger(self.logger)  # type: ignore
 
         if plotter is None:
@@ -63,7 +63,7 @@ class EvaluationModule(LightningModule):
         for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
             plotter(figure_name, fig, self.global_step)
 
-    def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
+    def log_metrics(self, evaluated_clips: Any):
         metrics = self.evaluator.compute_metrics(evaluated_clips)
         self.log_dict(metrics)
 
diff --git a/src/batdetect2/inference/__init__.py b/src/batdetect2/inference/__init__.py
index e69de29..5cd37bc 100644
--- a/src/batdetect2/inference/__init__.py
+++ b/src/batdetect2/inference/__init__.py
@@ -0,0 +1,10 @@
+from batdetect2.inference.batch import process_file_list, run_batch_inference
+from batdetect2.inference.clips import get_clips_from_files
+from batdetect2.inference.config import InferenceConfig
+
+__all__ = [
+    "process_file_list",
+    "run_batch_inference",
+    "InferenceConfig",
+    "get_clips_from_files",
+]
diff --git a/src/batdetect2/inference/batch.py b/src/batdetect2/inference/batch.py
new file mode 100644
index 0000000..b0d878a
--- /dev/null
+++ b/src/batdetect2/inference/batch.py
@@ -0,0 +1,88 @@
+from typing import TYPE_CHECKING, List, Optional, Sequence
+
+from lightning import Trainer
+from soundevent import data
+
+from batdetect2.audio.loader import build_audio_loader
+from batdetect2.inference.clips import get_clips_from_files
+from batdetect2.inference.dataset import build_inference_loader
+from batdetect2.inference.lightning import InferenceModule
+from batdetect2.models import Model
+from batdetect2.preprocess.preprocessor import build_preprocessor
+from batdetect2.targets.targets import build_targets
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+if TYPE_CHECKING:
+    from batdetect2.config import BatDetect2Config
+    from batdetect2.typing import (
+        AudioLoader,
+        PreprocessorProtocol,
+        TargetProtocol,
+    )
+
+
+def run_batch_inference(
+    model: Model,
+    clips: Sequence[data.Clip],
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    config: Optional["BatDetect2Config"] = None,
+    num_workers: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    # Imported lazily to avoid a circular import with batdetect2.config.
+    from batdetect2.config import BatDetect2Config
+
+    config = config or BatDetect2Config()
+
+    audio_loader = audio_loader or build_audio_loader(config=config.audio)
+
+    preprocessor = preprocessor or build_preprocessor(
+        config=config.preprocess,
+        input_samplerate=audio_loader.samplerate,
+    )
+
+    targets = targets or build_targets(config=config.targets)
+
+    loader = build_inference_loader(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        config=config.inference.loader,
+        num_workers=num_workers,
+    )
+
+    module = InferenceModule(model)
+    trainer = Trainer(enable_checkpointing=False)
+    outputs = trainer.predict(module, loader)
+    return [
+        clip_prediction
+        for clip_predictions in outputs  # type: ignore
+        for clip_prediction in clip_predictions
+    ]
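+
+# Hypothetical usage sketch (assumes a trained `model` and example paths;
+# the names here are illustrative, not part of the module):
+#
+#     clips = get_clips_from_files(["recording.wav"], duration=0.5)
+#     predictions = run_batch_inference(model, clips)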
"BatDetect2Config", + targets: Optional["TargetProtocol"] = None, + audio_loader: Optional["AudioLoader"] = None, + preprocessor: Optional["PreprocessorProtocol"] = None, + num_workers: Optional[int] = None, +) -> List[BatDetect2Prediction]: + clip_config = config.inference.clipping + clips = get_clips_from_files( + paths, + duration=clip_config.duration, + overlap=clip_config.overlap, + max_empty=clip_config.max_empty, + discard_empty=clip_config.discard_empty, + ) + return run_batch_inference( + model, + clips, + targets=targets, + audio_loader=audio_loader, + preprocessor=preprocessor, + config=config, + num_workers=num_workers, + ) diff --git a/src/batdetect2/inference/clips.py b/src/batdetect2/inference/clips.py new file mode 100644 index 0000000..b69e066 --- /dev/null +++ b/src/batdetect2/inference/clips.py @@ -0,0 +1,75 @@ +from typing import List, Sequence +from uuid import uuid5 + +import numpy as np +from soundevent import data + + +def get_clips_from_files( + paths: Sequence[data.PathLike], + duration: float, + overlap: float = 0.0, + max_empty: float = 0.0, + discard_empty: bool = True, + compute_hash: bool = False, +) -> List[data.Clip]: + clips: List[data.Clip] = [] + + for path in paths: + recording = data.Recording.from_file(path, compute_hash=compute_hash) + clips.extend( + get_recording_clips( + recording, + duration, + overlap=overlap, + max_empty=max_empty, + discard_empty=discard_empty, + ) + ) + + return clips + + +def get_recording_clips( + recording: data.Recording, + duration: float, + overlap: float = 0.0, + max_empty: float = 0.0, + discard_empty: bool = True, +) -> Sequence[data.Clip]: + start_time = 0 + duration = recording.duration + hop = duration * (1 - overlap) + + num_clips = int(np.ceil(duration / hop)) + + if num_clips == 0: + # This should only happen if the clip's duration is zero, + # which should never happen in practice, but just in case... 
+
+    if num_clips == 0:
+        # This should only happen if the recording's duration is zero,
+        # which should never happen in practice, but just in case...
+        return []
+
+    clips = []
+    for i in range(num_clips):
+        start = start_time + i * hop
+        end = start + duration
+
+        if end > recording_duration:
+            empty_duration = end - recording_duration
+
+            if empty_duration > max_empty and discard_empty:
+                # Discard clips that contain too much empty space
+                continue
+
+        clips.append(
+            data.Clip(
+                uuid=uuid5(recording.uuid, f"{start}_{end}"),
+                recording=recording,
+                start_time=start,
+                end_time=end,
+            )
+        )
+
+    if discard_empty:
+        clips = [clip for clip in clips if clip.duration > max_empty]
+
+    return clips
diff --git a/src/batdetect2/inference/config.py b/src/batdetect2/inference/config.py
new file mode 100644
index 0000000..1db715d
--- /dev/null
+++ b/src/batdetect2/inference/config.py
@@ -0,0 +1,21 @@
+from pydantic import Field
+
+from batdetect2.core.configs import BaseConfig
+from batdetect2.inference.dataset import InferenceLoaderConfig
+
+__all__ = ["InferenceConfig"]
+
+
+class ClippingConfig(BaseConfig):
+    enabled: bool = True
+    duration: float = 0.5
+    overlap: float = 0.0
+    max_empty: float = 0.0
+    discard_empty: bool = True
+
+
+class InferenceConfig(BaseConfig):
+    loader: InferenceLoaderConfig = Field(
+        default_factory=InferenceLoaderConfig
+    )
+    clipping: ClippingConfig = Field(default_factory=ClippingConfig)
diff --git a/src/batdetect2/inference/dataset.py b/src/batdetect2/inference/dataset.py
new file mode 100644
index 0000000..76d868d
--- /dev/null
+++ b/src/batdetect2/inference/dataset.py
@@ -0,0 +1,120 @@
+from typing import List, NamedTuple, Optional, Sequence
+
+import torch
+from loguru import logger
+from soundevent import data
+from torch.utils.data import DataLoader, Dataset
+
+from batdetect2.audio import build_audio_loader
+from batdetect2.core import BaseConfig
+from batdetect2.core.arrays import adjust_width
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol
+
+__all__ = [
+    "InferenceDataset",
+    "build_inference_dataset",
+    "build_inference_loader",
+]
+
+
+DEFAULT_INFERENCE_CLIP_DURATION = 0.512
+
+
+class DatasetItem(NamedTuple):
+    spec: torch.Tensor
+    idx: torch.Tensor
+    start_time: torch.Tensor
+    end_time: torch.Tensor
+
+
+class InferenceDataset(Dataset[DatasetItem]):
+    clips: List[data.Clip]
+
+    def __init__(
+        self,
+        clips: Sequence[data.Clip],
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        audio_dir: Optional[data.PathLike] = None,
+    ):
+        self.clips = list(clips)
+        self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
+        self.audio_dir = audio_dir
+
+    def __len__(self):
+        return len(self.clips)
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        clip = self.clips[idx]
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        wav_tensor = torch.tensor(wav).unsqueeze(0)
+        spectrogram = self.preprocessor(wav_tensor)
+        return DatasetItem(
+            spec=spectrogram,
+            idx=torch.tensor(idx),
+            start_time=torch.tensor(clip.start_time),
+            end_time=torch.tensor(clip.end_time),
+        )
+
+
+class InferenceLoaderConfig(BaseConfig):
+    num_workers: int = 0
+    batch_size: int = 8
+
+
+def build_inference_loader(
+    clips: Sequence[data.Clip],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[InferenceLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+) -> DataLoader[DatasetItem]:
+    logger.info("Building inference data loader...")
+    config = config or InferenceLoaderConfig()
+
+    inference_dataset = build_inference_dataset(
+        clips,
+        audio_loader=audio_loader,
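+        # build_inference_dataset constructs default components for any
+        # argument left as None.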
+        preprocessor=preprocessor,
+    )
+
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        inference_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def build_inference_dataset(
+    clips: Sequence[data.Clip],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+) -> InferenceDataset:
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    return InferenceDataset(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+    )
+
+
+def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
+    # Spectrograms in a batch can differ in width (number of time frames),
+    # so pad them all to the widest before stacking.
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return DatasetItem(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
diff --git a/src/batdetect2/inference/lightning.py b/src/batdetect2/inference/lightning.py
new file mode 100644
index 0000000..d6ff5fb
--- /dev/null
+++ b/src/batdetect2/inference/lightning.py
@@ -0,0 +1,52 @@
+from typing import Sequence
+
+from lightning import LightningModule
+from torch.utils.data import DataLoader
+
+from batdetect2.inference.dataset import DatasetItem, InferenceDataset
+from batdetect2.models import Model
+from batdetect2.postprocess import to_raw_predictions
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+
+class InferenceModule(LightningModule):
+    def __init__(self, model: Model):
+        super().__init__()
+        self.model = model
+
+    def predict_step(
+        self,
+        batch: DatasetItem,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> Sequence[BatDetect2Prediction]:
+        dataset = self.get_dataset()
+
+        clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
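+        # batch.idx maps each spectrogram in the (width-padded) batch back
+        # to the clip it was generated from.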
+
+        outputs = self.model.detector(batch.spec)
+
+        clip_detections = self.model.postprocessor(
+            outputs,
+            start_times=[clip.start_time for clip in clips],
+        )
+
+        predictions = [
+            BatDetect2Prediction(
+                clip=clip,
+                predictions=to_raw_predictions(
+                    clip_dets.numpy(),
+                    targets=self.model.targets,
+                ),
+            )
+            for clip, clip_dets in zip(clips, clip_detections)
+        ]
+
+        return predictions
+
+    def get_dataset(self) -> InferenceDataset:
+        dataloaders = self.trainer.predict_dataloaders
+        assert isinstance(dataloaders, DataLoader)
+        dataset = dataloaders.dataset
+        assert isinstance(dataset, InferenceDataset)
+        return dataset
diff --git a/src/batdetect2/train/__init__.py b/src/batdetect2/train/__init__.py
index 226dfe8..5581e56 100644
--- a/src/batdetect2/train/__init__.py
+++ b/src/batdetect2/train/__init__.py
@@ -1,9 +1,9 @@
 from batdetect2.train.augmentations import (
-    AugmentationsConfig,
     AddEchoConfig,
+    AugmentationsConfig,
     MaskFrequencyConfig,
-    RandomAudioSource,
     MaskTimeConfig,
+    RandomAudioSource,
     ScaleVolumeConfig,
     WarpConfig,
     add_echo,
@@ -28,7 +28,10 @@ from batdetect2.train.dataset import (
     build_val_loader,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
-from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.lightning import (
+    TrainingModule,
+    load_model_from_checkpoint,
+)
 from batdetect2.train.losses import (
     ClassificationLossConfig,
     DetectionLossConfig,
@@ -74,4 +77,5 @@ __all__ = [
     "scale_volume",
     "train",
     "warp_spectrogram",
+    "load_model_from_checkpoint",
 ]
diff --git a/src/batdetect2/typing/postprocess.py b/src/batdetect2/typing/postprocess.py
index df45759..ece0f15 100644
--- a/src/batdetect2/typing/postprocess.py
+++ b/src/batdetect2/typing/postprocess.py
@@ -83,8 +83,8 @@ class ClipDetectionsTensor(NamedTuple):
 
 @dataclass
 class BatDetect2Prediction:
-    raw: RawPrediction
-    sound_event_prediction: data.SoundEventPrediction
+    clip: data.Clip
+    predictions: List[RawPrediction]
 
 
 class PostprocessorProtocol(Protocol):
diff --git a/src/batdetect2/typing/train.py b/src/batdetect2/typing/train.py
index 7edd401..fc3f0e9 100644
--- a/src/batdetect2/typing/train.py
+++ b/src/batdetect2/typing/train.py
@@ -1,4 +1,4 @@
-from typing import Callable, NamedTuple, Protocol, Tuple
+from typing import Callable, List, NamedTuple, Protocol, Tuple
 
 import torch
 from soundevent import data
@@ -104,3 +104,5 @@ class ClipperProtocol(Protocol):
         self,
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation: ...
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip: ...