Writing batch inference code

This commit is contained in:
mbsantiago 2025-09-30 13:22:03 +01:00
parent 30159d64a9
commit 981e37c346
15 changed files with 546 additions and 31 deletions

View File

@ -1,27 +1,29 @@
from pathlib import Path
from typing import List, Optional, Sequence
import numpy as np
import torch
from soundevent import data
from soundevent.audio.files import get_audio_files
from batdetect2.audio import build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.evaluate import build_evaluator, evaluate
from batdetect2.inference import process_file_list, run_batch_inference
from batdetect2.models import Model, build_model
from batdetect2.postprocess import build_postprocessor
from batdetect2.postprocess.decoding import to_raw_predictions
from batdetect2.postprocess import build_postprocessor, to_raw_predictions
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.train import train
from batdetect2.train.lightning import load_model_from_checkpoint
from batdetect2.targets import build_targets
from batdetect2.train import load_model_from_checkpoint, train
from batdetect2.typing import (
AudioLoader,
BatDetect2Prediction,
EvaluatorProtocol,
PostprocessorProtocol,
PreprocessorProtocol,
RawPrediction,
TargetProtocol,
)
from batdetect2.typing.postprocess import RawPrediction
class BatDetect2API:
@ -95,17 +97,94 @@ class BatDetect2API:
run_name=run_name,
)
def load_audio(self, path: data.PathLike) -> np.ndarray:
return self.audio_loader.load_file(path)
def load_clip(self, clip: data.Clip) -> np.ndarray:
return self.audio_loader.load_clip(clip)
def generate_spectrogram(
self,
audio: np.ndarray,
) -> torch.Tensor:
tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor)
def process_file(self, audio_file: str) -> BatDetect2Prediction:
recording = data.Recording.from_file(audio_file, compute_hash=False)
wav = self.audio_loader.load_recording(recording)
detections = self.process_audio(wav)
return BatDetect2Prediction(
clip=data.Clip(
uuid=recording.uuid,
recording=recording,
start_time=0,
end_time=recording.duration,
),
predictions=detections,
)
def process_audio(
self,
audio: np.ndarray,
) -> List[RawPrediction]:
spec = self.generate_spectrogram(audio)
return self.process_spectrogram(spec)
def process_spectrogram(
self,
spec: torch.Tensor,
start_times: Optional[Sequence[float]] = None,
) -> List[List[RawPrediction]]:
start_time: float = 0,
) -> List[RawPrediction]:
if spec.ndim == 4 and spec.shape[0] > 1:
raise ValueError("Batched spectrograms not supported.")
if spec.ndim == 3:
spec = spec.unsqueeze(0)
outputs = self.model.detector(spec)
clip_detections = self.postprocessor(outputs, start_times=start_times)
return [
to_raw_predictions(clip_dets.numpy(), self.targets)
for clip_dets in clip_detections
]
detections = self.model.postprocessor(
outputs,
start_times=[start_time],
)[0]
return to_raw_predictions(detections.numpy(), targets=self.targets)
def process_directory(
self,
audio_dir: data.PathLike,
) -> List[BatDetect2Prediction]:
files = list(get_audio_files(audio_dir))
return self.process_files(files)
def process_files(
self,
audio_files: Sequence[data.PathLike],
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
return process_file_list(
self.model,
audio_files,
config=self.config,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
num_workers=num_workers,
)
def process_clips(
self,
clips: Sequence[data.Clip],
) -> List[BatDetect2Prediction]:
return run_batch_inference(
self.model,
clips,
targets=self.targets,
audio_loader=self.audio_loader,
preprocessor=self.preprocessor,
config=self.config,
)
@classmethod
def from_config(cls, config: BatDetect2Config):
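As a quick orientation (not part of the diff), a minimal usage sketch of the new inference methods, assuming `api` is an already constructed BatDetect2API (for example via the from_config classmethod above) with usable model weights:

# Hypothetical sketch; assumes `api` is a ready BatDetect2API instance.
# Single file: one BatDetect2Prediction covering the whole recording.
single = api.process_file("recordings/night_01.wav")

# Many files: clips are cut according to config.inference.clipping and
# processed in batches through Lightning's predict loop (see batch.py below).
batch = api.process_files(
    ["recordings/night_01.wav", "recordings/night_02.wav"],
    num_workers=2,
)

# Or scan a directory for audio files.
all_predictions = api.process_directory("recordings/")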

View File

@ -48,12 +48,25 @@ class RandomClip:
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation:
return get_subclip_annotation(
subclip = self.get_subclip(clip_annotation.clip)
sound_events = select_sound_event_annotations(
clip_annotation,
subclip,
min_overlap=self.min_sound_event_overlap,
)
return clip_annotation.model_copy(
update=dict(
clip=subclip,
sound_events=sound_events,
)
)
def get_subclip(self, clip: data.Clip) -> data.Clip:
return select_random_subclip(
clip,
random=self.random,
duration=self.duration,
max_empty=self.max_empty,
min_sound_event_overlap=self.min_sound_event_overlap,
)
@clipper_registry.register(RandomClipConfig)
@ -75,7 +88,7 @@ def get_subclip_annotation(
) -> data.ClipAnnotation:
clip = clip_annotation.clip
subclip = select_subclip(
subclip = select_random_subclip(
clip,
random=random,
duration=duration,
@ -96,7 +109,7 @@ def get_subclip_annotation(
)
def select_subclip(
def select_random_subclip(
clip: data.Clip,
random: bool = True,
duration: float = 0.5,
@ -170,6 +183,10 @@ class PaddedClip:
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation:
clip = clip_annotation.clip
clip = self.get_subclip(clip)
return clip_annotation.model_copy(update=dict(clip=clip))
def get_subclip(self, clip: data.Clip) -> data.Clip:
duration = clip.duration
target_duration = float(
@ -180,7 +197,7 @@ class PaddedClip:
end_time=clip.start_time + target_duration,
)
)
return clip_annotation.model_copy(update=dict(clip=clip))
return clip
@clipper_registry.register(PaddedClipConfig)
@staticmethod
@ -188,8 +205,52 @@ class PaddedClip:
return PaddedClip(chunk_size=config.chunk_size)
class FixedDurationClipConfig(BaseConfig):
name: Literal["fixed_duration"] = "fixed_duration"
duration: float = DEFAULT_TRAIN_CLIP_DURATION
class FixedDurationClip:
def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
self.duration = duration
def __call__(
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation:
clip = self.get_subclip(clip_annotation.clip)
sound_events = select_sound_event_annotations(
clip_annotation,
clip,
min_overlap=0,
)
return clip_annotation.model_copy(
update=dict(
clip=clip,
sound_events=sound_events,
)
)
def get_subclip(self, clip: data.Clip) -> data.Clip:
return clip.model_copy(
update=dict(
end_time=clip.start_time + self.duration,
)
)
@clipper_registry.register(FixedDurationClipConfig)
@staticmethod
def from_config(config: FixedDurationClipConfig):
return FixedDurationClip(duration=config.duration)
ClipConfig = Annotated[
Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
Union[
RandomClipConfig,
PaddedClipConfig,
FixedDurationClipConfig,
],
Field(discriminator="name"),
]
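For reference (not in the diff), a small sketch of how the extended discriminated union resolves, assuming BaseConfig is a pydantic v2 model as the Field(discriminator=...) and model_copy usage here suggests; it relies on the names defined in this module being in scope:

from pydantic import TypeAdapter

# The "name" field routes the dict to the matching config class.
cfg = TypeAdapter(ClipConfig).validate_python(
    {"name": "fixed_duration", "duration": 0.5}
)
assert isinstance(cfg, FixedDurationClipConfig)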

View File

@ -7,6 +7,7 @@ from batdetect2.audio import AudioConfig
from batdetect2.core import BaseConfig
from batdetect2.core.configs import load_config
from batdetect2.evaluate.config import EvaluationConfig
from batdetect2.inference.config import InferenceConfig
from batdetect2.models.config import BackboneConfig
from batdetect2.postprocess.config import PostprocessConfig
from batdetect2.preprocess.config import PreprocessingConfig
@ -31,6 +32,7 @@ class BatDetect2Config(BaseConfig):
postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
audio: AudioConfig = Field(default_factory=AudioConfig)
targets: TargetConfig = Field(default_factory=TargetConfig)
inference: InferenceConfig = Field(default_factory=InferenceConfig)
def load_full_config(

View File

@ -40,13 +40,14 @@ def evaluate(
config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader()
audio_loader = audio_loader or build_audio_loader(config=config.audio)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=audio_loader.samplerate,
)
targets = targets or build_targets()
targets = targets or build_targets(config=config.targets)
loader = build_test_loader(
test_annotations,

View File

@ -1,4 +1,4 @@
from typing import Sequence
from typing import Any
from lightning import LightningModule
from torch.utils.data import DataLoader
@ -7,7 +7,7 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample
from batdetect2.logging import get_image_logger
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing import ClipMatches, EvaluatorProtocol
from batdetect2.typing import EvaluatorProtocol
class EvaluationModule(LightningModule):
@ -54,7 +54,7 @@ class EvaluationModule(LightningModule):
self.log_metrics(self.clip_evaluations)
self.plot_examples(self.clip_evaluations)
def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
def plot_examples(self, evaluated_clips: Any):
plotter = get_image_logger(self.logger) # type: ignore
if plotter is None:
@ -63,7 +63,7 @@ class EvaluationModule(LightningModule):
for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
plotter(figure_name, fig, self.global_step)
def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
def log_metrics(self, evaluated_clips: Any):
metrics = self.evaluator.compute_metrics(evaluated_clips)
self.log_dict(metrics)

View File

@ -0,0 +1,10 @@
from batdetect2.inference.batch import process_file_list, run_batch_inference
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.config import InferenceConfig
__all__ = [
"process_file_list",
"run_batch_inference",
"InferenceConfig",
"get_clips_from_files",
]

View File

@ -0,0 +1,88 @@
from typing import TYPE_CHECKING, List, Optional, Sequence
from lightning import Trainer
from soundevent import data
from batdetect2.audio.loader import build_audio_loader
from batdetect2.inference.clips import get_clips_from_files
from batdetect2.inference.dataset import build_inference_loader
from batdetect2.inference.lightning import InferenceModule
from batdetect2.models import Model
from batdetect2.preprocess.preprocessor import build_preprocessor
from batdetect2.targets.targets import build_targets
from batdetect2.typing.postprocess import BatDetect2Prediction
if TYPE_CHECKING:
from batdetect2.config import BatDetect2Config
from batdetect2.typing import (
AudioLoader,
PreprocessorProtocol,
TargetProtocol,
)
def run_batch_inference(
model: Model,
clips: Sequence[data.Clip],
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
config: Optional["BatDetect2Config"] = None,
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
from batdetect2.config import BatDetect2Config
config = config or BatDetect2Config()
audio_loader = audio_loader or build_audio_loader(config=config.audio)
preprocessor = preprocessor or build_preprocessor(
config=config.preprocess,
input_samplerate=audio_loader.samplerate,
)
targets = targets or build_targets(config=config.targets)
loader = build_inference_loader(
clips,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config.inference.loader,
num_workers=num_workers,
)
module = InferenceModule(model)
trainer = Trainer(enable_checkpointing=False)
outputs = trainer.predict(module, loader)
return [
clip_prediction
for clip_predictions in outputs # type: ignore
for clip_prediction in clip_predictions
]
def process_file_list(
model: Model,
paths: Sequence[data.PathLike],
config: "BatDetect2Config",
targets: Optional["TargetProtocol"] = None,
audio_loader: Optional["AudioLoader"] = None,
preprocessor: Optional["PreprocessorProtocol"] = None,
num_workers: Optional[int] = None,
) -> List[BatDetect2Prediction]:
clip_config = config.inference.clipping
clips = get_clips_from_files(
paths,
duration=clip_config.duration,
overlap=clip_config.overlap,
max_empty=clip_config.max_empty,
discard_empty=clip_config.discard_empty,
)
return run_batch_inference(
model,
clips,
targets=targets,
audio_loader=audio_loader,
preprocessor=preprocessor,
config=config,
num_workers=num_workers,
)
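A minimal sketch (not in the diff) of calling the batch helpers directly, assuming `model` is an already loaded batdetect2 Model, e.g. obtained via load_model_from_checkpoint, which this commit re-exports from batdetect2.train:

from batdetect2.config import BatDetect2Config
from batdetect2.inference import process_file_list

config = BatDetect2Config()
predictions = process_file_list(
    model,  # assumed: a loaded batdetect2 Model
    ["recordings/night_01.wav", "recordings/night_02.wav"],
    config=config,
    num_workers=2,
)
for file_prediction in predictions:
    print(file_prediction.clip.recording.path, len(file_prediction.predictions))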

View File

@ -0,0 +1,75 @@
from typing import List, Sequence
from uuid import uuid5
import numpy as np
from soundevent import data
def get_clips_from_files(
paths: Sequence[data.PathLike],
duration: float,
overlap: float = 0.0,
max_empty: float = 0.0,
discard_empty: bool = True,
compute_hash: bool = False,
) -> List[data.Clip]:
clips: List[data.Clip] = []
for path in paths:
recording = data.Recording.from_file(path, compute_hash=compute_hash)
clips.extend(
get_recording_clips(
recording,
duration,
overlap=overlap,
max_empty=max_empty,
discard_empty=discard_empty,
)
)
return clips
def get_recording_clips(
recording: data.Recording,
duration: float,
overlap: float = 0.0,
max_empty: float = 0.0,
discard_empty: bool = True,
) -> Sequence[data.Clip]:
start_time = 0.0
recording_duration = recording.duration
hop = duration * (1 - overlap)
num_clips = int(np.ceil(recording_duration / hop))
if num_clips == 0:
# This should only happen if the recording's duration is zero,
# which should never happen in practice, but just in case...
return []
clips = []
for i in range(num_clips):
start = start_time + i * hop
end = start + duration
if end > recording_duration:
empty_duration = end - recording_duration
if empty_duration > max_empty and discard_empty:
# Discard clips that contain too much empty space
continue
clips.append(
data.Clip(
uuid=uuid5(recording.uuid, f"{start}_{end}"),
recording=recording,
start_time=start,
end_time=end,
)
)
if discard_empty:
clips = [clip for clip in clips if clip.duration > max_empty]
return clips
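A worked example (not in the diff) of the windowing arithmetic above, assuming a 2.0 s recording with duration=0.5 and overlap=0.25:

import numpy as np

duration, overlap, recording_duration = 0.5, 0.25, 2.0
hop = duration * (1 - overlap)                       # 0.375 s
num_clips = int(np.ceil(recording_duration / hop))   # 6 windows
starts = [i * hop for i in range(num_clips)]         # 0.0, 0.375, ..., 1.875
# The last window, [1.875, 2.375], runs 0.375 s past the recording end, so with
# max_empty=0.0 and discard_empty=True it is dropped; the other five are kept.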

View File

@ -0,0 +1,21 @@
from pydantic import Field
from batdetect2.core.configs import BaseConfig
from batdetect2.inference.dataset import InferenceLoaderConfig
__all__ = ["InferenceConfig"]
class ClippingConfig(BaseConfig):
enabled: bool = True
duration: float = 0.5
overlap: float = 0.0
max_empty: float = 0.0
discard_empty: bool = True
class InferenceConfig(BaseConfig):
loader: InferenceLoaderConfig = Field(
default_factory=InferenceLoaderConfig
)
clipping: ClippingConfig = Field(default_factory=ClippingConfig)
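For illustration (not in the diff), overriding the new settings in code, relying on pydantic's usual coercion of plain dicts into the nested config models:

from batdetect2.inference import InferenceConfig

config = InferenceConfig(
    loader={"batch_size": 16, "num_workers": 4},
    clipping={"duration": 0.512, "overlap": 0.25},
)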

View File

@ -0,0 +1,120 @@
from typing import List, NamedTuple, Optional, Sequence
import torch
from loguru import logger
from soundevent import data
from torch.utils.data import DataLoader, Dataset
from batdetect2.audio import build_audio_loader
from batdetect2.core import BaseConfig
from batdetect2.core.arrays import adjust_width
from batdetect2.preprocess import build_preprocessor
from batdetect2.typing import AudioLoader, PreprocessorProtocol
__all__ = [
"InferenceDataset",
"build_inference_dataset",
"build_inference_loader",
]
DEFAULT_INFERENCE_CLIP_DURATION = 0.512
class DatasetItem(NamedTuple):
spec: torch.Tensor
idx: torch.Tensor
start_time: torch.Tensor
end_time: torch.Tensor
class InferenceDataset(Dataset[DatasetItem]):
clips: List[data.Clip]
def __init__(
self,
clips: Sequence[data.Clip],
audio_loader: AudioLoader,
preprocessor: PreprocessorProtocol,
audio_dir: Optional[data.PathLike] = None,
):
self.clips = list(clips)
self.preprocessor = preprocessor
self.audio_loader = audio_loader
self.audio_dir = audio_dir
def __len__(self):
return len(self.clips)
def __getitem__(self, idx: int) -> DatasetItem:
clip = self.clips[idx]
wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
wav_tensor = torch.tensor(wav).unsqueeze(0)
spectrogram = self.preprocessor(wav_tensor)
return DatasetItem(
spec=spectrogram,
idx=torch.tensor(idx),
start_time=torch.tensor(clip.start_time),
end_time=torch.tensor(clip.end_time),
)
class InferenceLoaderConfig(BaseConfig):
num_workers: int = 0
batch_size: int = 8
def build_inference_loader(
clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
config: Optional[InferenceLoaderConfig] = None,
num_workers: Optional[int] = None,
) -> DataLoader[DatasetItem]:
logger.info("Building inference data loader...")
config = config or InferenceLoaderConfig()
inference_dataset = build_inference_dataset(
clips,
audio_loader=audio_loader,
preprocessor=preprocessor,
)
num_workers = num_workers if num_workers is not None else config.num_workers
return DataLoader(
inference_dataset,
batch_size=config.batch_size,
shuffle=False,
num_workers=num_workers,
collate_fn=_collate_fn,
)
def build_inference_dataset(
clips: Sequence[data.Clip],
audio_loader: Optional[AudioLoader] = None,
preprocessor: Optional[PreprocessorProtocol] = None,
) -> InferenceDataset:
if audio_loader is None:
audio_loader = build_audio_loader()
if preprocessor is None:
preprocessor = build_preprocessor()
return InferenceDataset(
clips,
audio_loader=audio_loader,
preprocessor=preprocessor,
)
def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
max_width = max(item.spec.shape[-1] for item in batch)
return DatasetItem(
spec=torch.stack(
[adjust_width(item.spec, max_width) for item in batch]
),
idx=torch.stack([item.idx for item in batch]),
start_time=torch.stack([item.start_time for item in batch]),
end_time=torch.stack([item.end_time for item in batch]),
)
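To illustrate the padding behaviour (not in the diff), a small sketch run against _collate_fn in this module, assuming adjust_width pads each spectrogram up to the requested number of time bins:

import torch

a = DatasetItem(spec=torch.zeros(1, 128, 100), idx=torch.tensor(0),
                start_time=torch.tensor(0.0), end_time=torch.tensor(0.5))
b = DatasetItem(spec=torch.zeros(1, 128, 80), idx=torch.tensor(1),
                start_time=torch.tensor(0.5), end_time=torch.tensor(1.0))
batch = _collate_fn([a, b])
# Both spectrograms are padded to 100 time bins before stacking:
print(batch.spec.shape)  # expected torch.Size([2, 1, 128, 100])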

View File

@ -0,0 +1,52 @@
from typing import Sequence
from lightning import LightningModule
from torch.utils.data import DataLoader
from batdetect2.inference.dataset import DatasetItem, InferenceDataset
from batdetect2.models import Model
from batdetect2.postprocess import to_raw_predictions
from batdetect2.typing.postprocess import BatDetect2Prediction
class InferenceModule(LightningModule):
def __init__(self, model: Model):
super().__init__()
self.model = model
def predict_step(
self,
batch: DatasetItem,
batch_idx: int,
dataloader_idx: int = 0,
) -> Sequence[BatDetect2Prediction]:
dataset = self.get_dataset()
clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
outputs = self.model.detector(batch.spec)
clip_detections = self.model.postprocessor(
outputs,
start_times=[clip.start_time for clip in clips],
)
predictions = [
BatDetect2Prediction(
clip=clip,
predictions=to_raw_predictions(
clip_dets.numpy(),
targets=self.model.targets,
),
)
for clip, clip_dets in zip(clips, clip_detections)
]
return predictions
def get_dataset(self) -> InferenceDataset:
dataloaders = self.trainer.predict_dataloaders
assert isinstance(dataloaders, DataLoader)
dataset = dataloaders.dataset
assert isinstance(dataset, InferenceDataset)
return dataset

View File

@ -1,9 +1,9 @@
from batdetect2.train.augmentations import (
AugmentationsConfig,
AddEchoConfig,
AugmentationsConfig,
MaskFrequencyConfig,
RandomAudioSource,
MaskTimeConfig,
RandomAudioSource,
ScaleVolumeConfig,
WarpConfig,
add_echo,
@ -28,7 +28,10 @@ from batdetect2.train.dataset import (
build_val_loader,
)
from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import TrainingModule
from batdetect2.train.lightning import (
TrainingModule,
load_model_from_checkpoint,
)
from batdetect2.train.losses import (
ClassificationLossConfig,
DetectionLossConfig,
@ -74,4 +77,5 @@ __all__ = [
"scale_volume",
"train",
"warp_spectrogram",
"load_model_from_checkpoint",
]

View File

@ -83,8 +83,8 @@ class ClipDetectionsTensor(NamedTuple):
@dataclass
class BatDetect2Prediction:
raw: RawPrediction
sound_event_prediction: data.SoundEventPrediction
clip: data.Clip
predictions: List[RawPrediction]
class PostprocessorProtocol(Protocol):

View File

@ -1,4 +1,4 @@
from typing import Callable, NamedTuple, Protocol, Tuple
from typing import Callable, List, NamedTuple, Protocol, Tuple
import torch
from soundevent import data
@ -104,3 +104,5 @@ class ClipperProtocol(Protocol):
self,
clip_annotation: data.ClipAnnotation,
) -> data.ClipAnnotation: ...
def get_subclip(self, clip: data.Clip) -> data.Clip: ...