Writing batch inference code

mbsantiago 2025-09-30 13:22:03 +01:00
parent 30159d64a9
commit 981e37c346
15 changed files with 546 additions and 31 deletions

View File

@@ -1,27 +1,29 @@
 from pathlib import Path
 from typing import List, Optional, Sequence

+import numpy as np
 import torch
 from soundevent import data
+from soundevent.audio.files import get_audio_files

 from batdetect2.audio import build_audio_loader
 from batdetect2.config import BatDetect2Config
 from batdetect2.evaluate import build_evaluator, evaluate
+from batdetect2.inference import process_file_list, run_batch_inference
 from batdetect2.models import Model, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.postprocess.decoding import to_raw_predictions
+from batdetect2.postprocess import build_postprocessor, to_raw_predictions
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets.targets import build_targets
-from batdetect2.train import train
-from batdetect2.train.lightning import load_model_from_checkpoint
+from batdetect2.targets import build_targets
+from batdetect2.train import load_model_from_checkpoint, train
 from batdetect2.typing import (
     AudioLoader,
+    BatDetect2Prediction,
     EvaluatorProtocol,
     PostprocessorProtocol,
     PreprocessorProtocol,
+    RawPrediction,
     TargetProtocol,
 )
-from batdetect2.typing.postprocess import RawPrediction


 class BatDetect2API:
@@ -95,17 +97,94 @@ class BatDetect2API:
             run_name=run_name,
         )

+    def load_audio(self, path: data.PathLike) -> np.ndarray:
+        return self.audio_loader.load_file(path)
+
+    def load_clip(self, clip: data.Clip) -> np.ndarray:
+        return self.audio_loader.load_clip(clip)
+
+    def generate_spectrogram(
+        self,
+        audio: np.ndarray,
+    ) -> torch.Tensor:
+        tensor = torch.tensor(audio).unsqueeze(0)
+        return self.preprocessor(tensor)
+
+    def process_file(self, audio_file: str) -> BatDetect2Prediction:
+        recording = data.Recording.from_file(audio_file, compute_hash=False)
+        wav = self.audio_loader.load_recording(recording)
+        detections = self.process_audio(wav)
+        return BatDetect2Prediction(
+            clip=data.Clip(
+                uuid=recording.uuid,
+                recording=recording,
+                start_time=0,
+                end_time=recording.duration,
+            ),
+            predictions=detections,
+        )
+
+    def process_audio(
+        self,
+        audio: np.ndarray,
+    ) -> List[RawPrediction]:
+        spec = self.generate_spectrogram(audio)
+        return self.process_spectrogram(spec)
+
     def process_spectrogram(
         self,
         spec: torch.Tensor,
-        start_times: Optional[Sequence[float]] = None,
-    ) -> List[List[RawPrediction]]:
+        start_time: float = 0,
+    ) -> List[RawPrediction]:
+        if spec.ndim == 4 and spec.shape[0] > 1:
+            raise ValueError("Batched spectrograms not supported.")
+
+        if spec.ndim == 3:
+            spec = spec.unsqueeze(0)
+
         outputs = self.model.detector(spec)
-        clip_detections = self.postprocessor(outputs, start_times=start_times)
-        return [
-            to_raw_predictions(clip_dets.numpy(), self.targets)
-            for clip_dets in clip_detections
-        ]
+        detections = self.model.postprocessor(
+            outputs,
+            start_times=[start_time],
+        )[0]
+        return to_raw_predictions(detections.numpy(), targets=self.targets)
+
+    def process_directory(
+        self,
+        audio_dir: data.PathLike,
+    ) -> List[BatDetect2Prediction]:
+        files = list(get_audio_files(audio_dir))
+        return self.process_files(files)
+
+    def process_files(
+        self,
+        audio_files: Sequence[data.PathLike],
+        num_workers: Optional[int] = None,
+    ) -> List[BatDetect2Prediction]:
+        return process_file_list(
+            self.model,
+            audio_files,
+            config=self.config,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            num_workers=num_workers,
+        )
+
+    def process_clips(
+        self,
+        clips: Sequence[data.Clip],
+    ) -> List[BatDetect2Prediction]:
+        return run_batch_inference(
+            self.model,
+            clips,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            config=self.config,
+        )
+
     @classmethod
     def from_config(cls, config: BatDetect2Config):
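
A hedged sketch of how the new API surface might be driven end to end, assuming from_config wires up a usable model; the file and directory names are hypothetical, while from_config and the process_* methods are the ones added above:

    from batdetect2.config import BatDetect2Config

    api = BatDetect2API.from_config(BatDetect2Config())

    # One recording at a time: a single BatDetect2Prediction spanning the file.
    result = api.process_file("night_001.wav")  # hypothetical file
    print(result.clip.end_time, len(result.predictions))

    # Many files, or a whole directory: clipped and batched internally.
    results = api.process_files(["night_001.wav", "night_002.wav"], num_workers=2)
    results = api.process_directory("recordings/")  # hypothetical directory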

View File

@@ -48,12 +48,25 @@ class RandomClip:
         self,
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation:
-        return get_subclip_annotation(
+        subclip = self.get_subclip(clip_annotation.clip)
+        sound_events = select_sound_event_annotations(
             clip_annotation,
+            subclip,
+            min_overlap=self.min_sound_event_overlap,
+        )
+        return clip_annotation.model_copy(
+            update=dict(
+                clip=subclip,
+                sound_events=sound_events,
+            )
+        )
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
+        return select_random_subclip(
+            clip,
             random=self.random,
             duration=self.duration,
             max_empty=self.max_empty,
-            min_sound_event_overlap=self.min_sound_event_overlap,
         )


 @clipper_registry.register(RandomClipConfig)
@@ -75,7 +88,7 @@ def get_subclip_annotation(
 ) -> data.ClipAnnotation:
     clip = clip_annotation.clip

-    subclip = select_subclip(
+    subclip = select_random_subclip(
         clip,
         random=random,
         duration=duration,
@@ -96,7 +109,7 @@ def get_subclip_annotation(
     )


-def select_subclip(
+def select_random_subclip(
     clip: data.Clip,
     random: bool = True,
     duration: float = 0.5,
@@ -170,6 +183,10 @@ class PaddedClip:
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation:
         clip = clip_annotation.clip
+        clip = self.get_subclip(clip)
+        return clip_annotation.model_copy(update=dict(clip=clip))
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
         duration = clip.duration
         target_duration = float(
@@ -180,7 +197,7 @@ class PaddedClip:
                 end_time=clip.start_time + target_duration,
             )
         )
-        return clip_annotation.model_copy(update=dict(clip=clip))
+        return clip

     @clipper_registry.register(PaddedClipConfig)
     @staticmethod
@@ -188,8 +205,52 @@ class PaddedClip:
         return PaddedClip(chunk_size=config.chunk_size)


+class FixedDurationClipConfig(BaseConfig):
+    name: Literal["fixed_duration"] = "fixed_duration"
+    duration: float = DEFAULT_TRAIN_CLIP_DURATION
+
+
+class FixedDurationClip:
+    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.duration = duration
+
+    def __call__(
+        self,
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation:
+        clip = self.get_subclip(clip_annotation.clip)
+        sound_events = select_sound_event_annotations(
+            clip_annotation,
+            clip,
+            min_overlap=0,
+        )
+        return clip_annotation.model_copy(
+            update=dict(
+                clip=clip,
+                sound_events=sound_events,
+            )
+        )
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
+        return clip.model_copy(
+            update=dict(
+                end_time=clip.start_time + self.duration,
+            )
+        )
+
+    @clipper_registry.register(FixedDurationClipConfig)
+    @staticmethod
+    def from_config(config: FixedDurationClipConfig):
+        return FixedDurationClip(duration=config.duration)
+
+
 ClipConfig = Annotated[
-    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
+    Union[
+        RandomClipConfig,
+        PaddedClipConfig,
+        FixedDurationClipConfig,
+    ],
+    Field(discriminator="name"),
 ]
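
A minimal sketch of the new clipper in isolation. Only FixedDurationClip comes from this commit; building a Recording directly in memory (rather than via Recording.from_file) is an assumption about soundevent's API made purely for illustration:

    from soundevent import data

    # Hypothetical 2 s recording built in memory for illustration.
    recording = data.Recording(
        path="example.wav", duration=2.0, samplerate=256_000, channels=1
    )
    clip = data.Clip(recording=recording, start_time=0.0, end_time=2.0)
    annotation = data.ClipAnnotation(clip=clip)

    clipper = FixedDurationClip(duration=0.5)
    trimmed = clipper(annotation)
    assert trimmed.clip.end_time == 0.5  # start_time + duration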

View File

@@ -7,6 +7,7 @@ from batdetect2.audio import AudioConfig
 from batdetect2.core import BaseConfig
 from batdetect2.core.configs import load_config
 from batdetect2.evaluate.config import EvaluationConfig
+from batdetect2.inference.config import InferenceConfig
 from batdetect2.models.config import BackboneConfig
 from batdetect2.postprocess.config import PostprocessConfig
 from batdetect2.preprocess.config import PreprocessingConfig
@@ -31,6 +32,7 @@ class BatDetect2Config(BaseConfig):
     postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
     audio: AudioConfig = Field(default_factory=AudioConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
+    inference: InferenceConfig = Field(default_factory=InferenceConfig)


 def load_full_config(
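
With the new field in place, inference settings travel with the rest of the configuration. A small sketch using only defaults defined elsewhere in this commit:

    from batdetect2.config import BatDetect2Config

    config = BatDetect2Config()
    # Defaults from InferenceConfig and its sub-configs.
    print(config.inference.clipping.duration)  # 0.5
    print(config.inference.loader.batch_size)  # 8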

View File

@@ -40,13 +40,14 @@ def evaluate(
     config = config or BatDetect2Config()

-    audio_loader = audio_loader or build_audio_loader()
+    audio_loader = audio_loader or build_audio_loader(config=config.audio)

     preprocessor = preprocessor or build_preprocessor(
+        config=config.preprocess,
         input_samplerate=audio_loader.samplerate,
     )

-    targets = targets or build_targets()
+    targets = targets or build_targets(config=config.targets)

     loader = build_test_loader(
         test_annotations,

View File

@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Any

 from lightning import LightningModule
 from torch.utils.data import DataLoader
@@ -7,7 +7,7 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample
 from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
-from batdetect2.typing import ClipMatches, EvaluatorProtocol
+from batdetect2.typing import EvaluatorProtocol


 class EvaluationModule(LightningModule):
@@ -54,7 +54,7 @@ class EvaluationModule(LightningModule):
         self.log_metrics(self.clip_evaluations)
         self.plot_examples(self.clip_evaluations)

-    def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
+    def plot_examples(self, evaluated_clips: Any):
         plotter = get_image_logger(self.logger)  # type: ignore

         if plotter is None:
@@ -63,7 +63,7 @@ class EvaluationModule(LightningModule):
         for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
             plotter(figure_name, fig, self.global_step)

-    def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
+    def log_metrics(self, evaluated_clips: Any):
         metrics = self.evaluator.compute_metrics(evaluated_clips)
         self.log_dict(metrics)

View File

@@ -0,0 +1,10 @@
+from batdetect2.inference.batch import process_file_list, run_batch_inference
+from batdetect2.inference.clips import get_clips_from_files
+from batdetect2.inference.config import InferenceConfig
+
+__all__ = [
+    "process_file_list",
+    "run_batch_inference",
+    "InferenceConfig",
+    "get_clips_from_files",
+]

View File

@@ -0,0 +1,88 @@
+from typing import TYPE_CHECKING, List, Optional, Sequence
+
+from lightning import Trainer
+from soundevent import data
+
+from batdetect2.audio.loader import build_audio_loader
+from batdetect2.inference.clips import get_clips_from_files
+from batdetect2.inference.dataset import build_inference_loader
+from batdetect2.inference.lightning import InferenceModule
+from batdetect2.models import Model
+from batdetect2.preprocess.preprocessor import build_preprocessor
+from batdetect2.targets.targets import build_targets
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+if TYPE_CHECKING:
+    from batdetect2.config import BatDetect2Config
+    from batdetect2.typing import (
+        AudioLoader,
+        PreprocessorProtocol,
+        TargetProtocol,
+    )
+
+
+def run_batch_inference(
+    model: Model,
+    clips: Sequence[data.Clip],
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    config: Optional["BatDetect2Config"] = None,
+    num_workers: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    from batdetect2.config import BatDetect2Config
+
+    config = config or BatDetect2Config()
+    audio_loader = audio_loader or build_audio_loader()
+    preprocessor = preprocessor or build_preprocessor(
+        input_samplerate=audio_loader.samplerate,
+    )
+    targets = targets or build_targets()
+
+    loader = build_inference_loader(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        config=config.inference.loader,
+        num_workers=num_workers,
+    )
+
+    module = InferenceModule(model)
+    trainer = Trainer(enable_checkpointing=False)
+    outputs = trainer.predict(module, loader)
+
+    return [
+        clip_prediction
+        for clip_predictions in outputs  # type: ignore
+        for clip_prediction in clip_predictions
+    ]
+
+
+def process_file_list(
+    model: Model,
+    paths: Sequence[data.PathLike],
+    config: "BatDetect2Config",
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    num_workers: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    clip_config = config.inference.clipping
+    clips = get_clips_from_files(
+        paths,
+        duration=clip_config.duration,
+        overlap=clip_config.overlap,
+        max_empty=clip_config.max_empty,
+        discard_empty=clip_config.discard_empty,
+    )
+    return run_batch_inference(
+        model,
+        clips,
+        targets=targets,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        config=config,
+        num_workers=num_workers,
+    )
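
A hedged usage sketch of the two entry points. The checkpoint path and file names are hypothetical, and it is an assumption that load_model_from_checkpoint (re-exported from batdetect2.train in this commit) returns an object exposing detector and postprocessor the way Model does:

    from batdetect2.config import BatDetect2Config
    from batdetect2.inference import process_file_list
    from batdetect2.train import load_model_from_checkpoint

    model = load_model_from_checkpoint("checkpoints/last.ckpt")  # hypothetical path
    config = BatDetect2Config()

    predictions = process_file_list(
        model,
        ["night_001.wav", "night_002.wav"],  # hypothetical files
        config=config,
    )
    for pred in predictions:
        print(pred.clip.start_time, pred.clip.end_time, len(pred.predictions))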

View File

@@ -0,0 +1,75 @@
+from typing import List, Sequence
+from uuid import uuid5
+
+import numpy as np
+from soundevent import data
+
+
+def get_clips_from_files(
+    paths: Sequence[data.PathLike],
+    duration: float,
+    overlap: float = 0.0,
+    max_empty: float = 0.0,
+    discard_empty: bool = True,
+    compute_hash: bool = False,
+) -> List[data.Clip]:
+    clips: List[data.Clip] = []
+    for path in paths:
+        recording = data.Recording.from_file(path, compute_hash=compute_hash)
+        clips.extend(
+            get_recording_clips(
+                recording,
+                duration,
+                overlap=overlap,
+                max_empty=max_empty,
+                discard_empty=discard_empty,
+            )
+        )
+    return clips
+
+
+def get_recording_clips(
+    recording: data.Recording,
+    duration: float,
+    overlap: float = 0.0,
+    max_empty: float = 0.0,
+    discard_empty: bool = True,
+) -> Sequence[data.Clip]:
+    start_time = 0
+    # Keep the recording length separate from the requested clip duration;
+    # reusing the name `duration` here would shadow the parameter and break
+    # the window arithmetic below.
+    recording_duration = recording.duration
+    hop = duration * (1 - overlap)
+    num_clips = int(np.ceil(recording_duration / hop))
+
+    if num_clips == 0:
+        # This should only happen if the recording's duration is zero,
+        # which should never happen in practice, but just in case...
+        return []
+
+    clips = []
+    for i in range(num_clips):
+        start = start_time + i * hop
+        end = start + duration
+
+        if end > recording_duration:
+            empty_duration = end - recording_duration
+            if empty_duration > max_empty and discard_empty:
+                # Discard clips that contain too much empty space
+                continue
+
+        clips.append(
+            data.Clip(
+                uuid=uuid5(recording.uuid, f"{start}_{end}"),
+                recording=recording,
+                start_time=start,
+                end_time=end,
+            )
+        )
+
+    if discard_empty:
+        clips = [clip for clip in clips if clip.duration > max_empty]
+
+    return clips
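
To make the windowing arithmetic concrete: with duration=0.5 and overlap=0.5, hop = 0.5 * (1 - 0.5) = 0.25 s, so a 1.8 s recording yields num_clips = ceil(1.8 / 0.25) = 8 candidate windows; the two starting at 1.5 s and 1.75 s run past the end of the recording by more than the default max_empty=0 and are discarded. A sketch of that case (building a Recording directly in memory is an assumption for illustration):

    from soundevent import data

    recording = data.Recording(
        path="example.wav", duration=1.8, samplerate=256_000, channels=1
    )
    clips = get_recording_clips(recording, duration=0.5, overlap=0.5)
    # 8 candidate windows minus the 2 that overhang the recording end.
    assert len(clips) == 6
    assert clips[0].start_time == 0.0 and clips[0].end_time == 0.5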

View File

@@ -0,0 +1,21 @@
+from pydantic import Field
+
+from batdetect2.core.configs import BaseConfig
+from batdetect2.inference.dataset import InferenceLoaderConfig
+
+__all__ = ["InferenceConfig"]
+
+
+class ClippingConfig(BaseConfig):
+    enabled: bool = True
+    duration: float = 0.5
+    overlap: float = 0.0
+    max_empty: float = 0.0
+    discard_empty: bool = True
+
+
+class InferenceConfig(BaseConfig):
+    loader: InferenceLoaderConfig = Field(
+        default_factory=InferenceLoaderConfig
+    )
+    clipping: ClippingConfig = Field(default_factory=ClippingConfig)

View File

@@ -0,0 +1,120 @@
+from typing import List, NamedTuple, Optional, Sequence
+
+import torch
+from loguru import logger
+from soundevent import data
+from torch.utils.data import DataLoader, Dataset
+
+from batdetect2.audio import build_audio_loader
+from batdetect2.core import BaseConfig
+from batdetect2.core.arrays import adjust_width
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol
+
+__all__ = [
+    "InferenceDataset",
+    "build_inference_dataset",
+    "build_inference_loader",
+]
+
+DEFAULT_INFERENCE_CLIP_DURATION = 0.512
+
+
+class DatasetItem(NamedTuple):
+    spec: torch.Tensor
+    idx: torch.Tensor
+    start_time: torch.Tensor
+    end_time: torch.Tensor
+
+
+class InferenceDataset(Dataset[DatasetItem]):
+    clips: List[data.Clip]
+
+    def __init__(
+        self,
+        clips: Sequence[data.Clip],
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        audio_dir: Optional[data.PathLike] = None,
+    ):
+        self.clips = list(clips)
+        self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
+        self.audio_dir = audio_dir
+
+    def __len__(self):
+        return len(self.clips)
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        clip = self.clips[idx]
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        wav_tensor = torch.tensor(wav).unsqueeze(0)
+        spectrogram = self.preprocessor(wav_tensor)
+        return DatasetItem(
+            spec=spectrogram,
+            idx=torch.tensor(idx),
+            start_time=torch.tensor(clip.start_time),
+            end_time=torch.tensor(clip.end_time),
+        )
+
+
+class InferenceLoaderConfig(BaseConfig):
+    num_workers: int = 0
+    batch_size: int = 8
+
+
+def build_inference_loader(
+    clips: Sequence[data.Clip],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[InferenceLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+) -> DataLoader[DatasetItem]:
+    logger.info("Building inference data loader...")
+    config = config or InferenceLoaderConfig()
+    inference_dataset = build_inference_dataset(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+    )
+    # Fall back to the configured value, then pass the resolved worker
+    # count to the DataLoader (not config.num_workers, which would ignore
+    # the override).
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        inference_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def build_inference_dataset(
+    clips: Sequence[data.Clip],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+) -> InferenceDataset:
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    return InferenceDataset(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+    )
+
+
+def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return DatasetItem(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
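
The custom collate exists because clips can produce spectrograms of slightly different widths; it pads every item in the batch to the widest one before stacking. A sketch of that behaviour, assuming adjust_width pads along the last axis and using illustrative tensor shapes:

    import torch

    a = DatasetItem(
        spec=torch.zeros(1, 128, 100),
        idx=torch.tensor(0),
        start_time=torch.tensor(0.0),
        end_time=torch.tensor(0.5),
    )
    b = DatasetItem(
        spec=torch.zeros(1, 128, 120),
        idx=torch.tensor(1),
        start_time=torch.tensor(0.5),
        end_time=torch.tensor(1.0),
    )
    batch = _collate_fn([a, b])
    assert batch.spec.shape == (2, 1, 128, 120)  # padded to the max width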

View File

@@ -0,0 +1,52 @@
+from typing import Sequence
+
+from lightning import LightningModule
+from torch.utils.data import DataLoader
+
+from batdetect2.inference.dataset import DatasetItem, InferenceDataset
+from batdetect2.models import Model
+from batdetect2.postprocess import to_raw_predictions
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+
+class InferenceModule(LightningModule):
+    def __init__(self, model: Model):
+        super().__init__()
+        self.model = model
+
+    def predict_step(
+        self,
+        batch: DatasetItem,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> Sequence[BatDetect2Prediction]:
+        dataset = self.get_dataset()
+        clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
+
+        outputs = self.model.detector(batch.spec)
+        clip_detections = self.model.postprocessor(
+            outputs,
+            start_times=[clip.start_time for clip in clips],
+        )
+
+        predictions = [
+            BatDetect2Prediction(
+                clip=clip,
+                predictions=to_raw_predictions(
+                    clip_dets.numpy(),
+                    targets=self.model.targets,
+                ),
+            )
+            for clip, clip_dets in zip(clips, clip_detections)
+        ]
+        return predictions
+
+    def get_dataset(self) -> InferenceDataset:
+        dataloaders = self.trainer.predict_dataloaders
+        assert isinstance(dataloaders, DataLoader)
+        dataset = dataloaders.dataset
+        assert isinstance(dataset, InferenceDataset)
+        return dataset
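
predict_step recovers each item's Clip through the idx tensor rather than passing Clip objects through the DataLoader, since tensors collate cleanly while arbitrary objects do not. A hedged sketch of driving the module directly, mirroring what run_batch_inference does internally (model and clips are assumed to exist already):

    from lightning import Trainer

    module = InferenceModule(model)
    loader = build_inference_loader(clips)
    trainer = Trainer(enable_checkpointing=False)
    batches = trainer.predict(module, loader)  # one sequence of predictions per batch
    predictions = [p for batch in batches for p in batch]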

View File

@ -1,9 +1,9 @@
from batdetect2.train.augmentations import ( from batdetect2.train.augmentations import (
AugmentationsConfig,
AddEchoConfig, AddEchoConfig,
AugmentationsConfig,
MaskFrequencyConfig, MaskFrequencyConfig,
RandomAudioSource,
MaskTimeConfig, MaskTimeConfig,
RandomAudioSource,
ScaleVolumeConfig, ScaleVolumeConfig,
WarpConfig, WarpConfig,
add_echo, add_echo,
@ -28,7 +28,10 @@ from batdetect2.train.dataset import (
build_val_loader, build_val_loader,
) )
from batdetect2.train.labels import build_clip_labeler, load_label_config from batdetect2.train.labels import build_clip_labeler, load_label_config
from batdetect2.train.lightning import TrainingModule from batdetect2.train.lightning import (
TrainingModule,
load_model_from_checkpoint,
)
from batdetect2.train.losses import ( from batdetect2.train.losses import (
ClassificationLossConfig, ClassificationLossConfig,
DetectionLossConfig, DetectionLossConfig,
@ -74,4 +77,5 @@ __all__ = [
"scale_volume", "scale_volume",
"train", "train",
"warp_spectrogram", "warp_spectrogram",
"load_model_from_checkpoint",
] ]

View File

@@ -83,8 +83,8 @@ class ClipDetectionsTensor(NamedTuple):
 @dataclass
 class BatDetect2Prediction:
-    raw: RawPrediction
-    sound_event_prediction: data.SoundEventPrediction
+    clip: data.Clip
+    predictions: List[RawPrediction]


 class PostprocessorProtocol(Protocol):

View File

@@ -1,4 +1,4 @@
-from typing import Callable, NamedTuple, Protocol, Tuple
+from typing import Callable, List, NamedTuple, Protocol, Tuple

 import torch
 from soundevent import data
@@ -104,3 +104,5 @@ class ClipperProtocol(Protocol):
         self,
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation: ...
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip: ...