Writing batch inference code

mirror of https://github.com/macaodha/batdetect2.git
commit 981e37c346 (parent 30159d64a9)
@@ -1,27 +1,29 @@
 from pathlib import Path
 from typing import List, Optional, Sequence

 import numpy as np
 import torch
 from soundevent import data
 from soundevent.audio.files import get_audio_files

 from batdetect2.audio import build_audio_loader
 from batdetect2.config import BatDetect2Config
 from batdetect2.evaluate import build_evaluator, evaluate
+from batdetect2.inference import process_file_list, run_batch_inference
 from batdetect2.models import Model, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.postprocess.decoding import to_raw_predictions
+from batdetect2.postprocess import build_postprocessor, to_raw_predictions
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets.targets import build_targets
-from batdetect2.train import train
-from batdetect2.train.lightning import load_model_from_checkpoint
+from batdetect2.targets import build_targets
+from batdetect2.train import load_model_from_checkpoint, train
 from batdetect2.typing import (
     AudioLoader,
+    BatDetect2Prediction,
     EvaluatorProtocol,
     PostprocessorProtocol,
     PreprocessorProtocol,
+    RawPrediction,
     TargetProtocol,
 )
-from batdetect2.typing.postprocess import RawPrediction


 class BatDetect2API:
@@ -95,17 +97,94 @@ class BatDetect2API:
             run_name=run_name,
         )

+    def load_audio(self, path: data.PathLike) -> np.ndarray:
+        return self.audio_loader.load_file(path)
+
+    def load_clip(self, clip: data.Clip) -> np.ndarray:
+        return self.audio_loader.load_clip(clip)
+
+    def generate_spectrogram(
+        self,
+        audio: np.ndarray,
+    ) -> torch.Tensor:
+        tensor = torch.tensor(audio).unsqueeze(0)
+        return self.preprocessor(tensor)
+
+    def process_file(self, audio_file: str) -> BatDetect2Prediction:
+        recording = data.Recording.from_file(audio_file, compute_hash=False)
+        wav = self.audio_loader.load_recording(recording)
+        detections = self.process_audio(wav)
+        return BatDetect2Prediction(
+            clip=data.Clip(
+                uuid=recording.uuid,
+                recording=recording,
+                start_time=0,
+                end_time=recording.duration,
+            ),
+            predictions=detections,
+        )
+
+    def process_audio(
+        self,
+        audio: np.ndarray,
+    ) -> List[RawPrediction]:
+        spec = self.generate_spectrogram(audio)
+        return self.process_spectrogram(spec)
+
     def process_spectrogram(
         self,
         spec: torch.Tensor,
-        start_times: Optional[Sequence[float]] = None,
-    ) -> List[List[RawPrediction]]:
+        start_time: float = 0,
+    ) -> List[RawPrediction]:
+        if spec.ndim == 4 and spec.shape[0] > 1:
+            raise ValueError("Batched spectrograms not supported.")
+
         if spec.ndim == 3:
             spec = spec.unsqueeze(0)

         outputs = self.model.detector(spec)
-        clip_detections = self.postprocessor(outputs, start_times=start_times)
-        return [
-            to_raw_predictions(clip_dets.numpy(), self.targets)
-            for clip_dets in clip_detections
-        ]
+        detections = self.model.postprocessor(
+            outputs,
+            start_times=[start_time],
+        )[0]
+
+        return to_raw_predictions(detections.numpy(), targets=self.targets)
+
+    def process_directory(
+        self,
+        audio_dir: data.PathLike,
+    ) -> List[BatDetect2Prediction]:
+        files = list(get_audio_files(audio_dir))
+        return self.process_files(files)
+
+    def process_files(
+        self,
+        audio_files: Sequence[data.PathLike],
+        num_workers: Optional[int] = None,
+    ) -> List[BatDetect2Prediction]:
+        return process_file_list(
+            self.model,
+            audio_files,
+            config=self.config,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            num_workers=num_workers,
+        )
+
+    def process_clips(
+        self,
+        clips: Sequence[data.Clip],
+    ) -> List[BatDetect2Prediction]:
+        return run_batch_inference(
+            self.model,
+            clips,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            config=self.config,
+        )
+
     @classmethod
     def from_config(cls, config: BatDetect2Config):
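Taken together, the new methods give BatDetect2API a single-file path (process_file, process_audio, process_spectrogram) and a batch path (process_directory, process_files, process_clips) that delegates to the new inference module. A minimal usage sketch; the batdetect2.api import path and the file names are assumptions, not part of this diff:

from batdetect2.api import BatDetect2API  # import path assumed
from batdetect2.config import BatDetect2Config

api = BatDetect2API.from_config(BatDetect2Config())

# Single recording: load, preprocess, detect, postprocess.
result = api.process_file("recording.wav")  # hypothetical file
print(len(result.predictions))

# Whole directory: files are windowed into clips, batched, and run
# through Lightning's Trainer.predict via run_batch_inference.
results = api.process_directory("audio/")  # hypothetical directory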
@@ -48,12 +48,25 @@ class RandomClip:
         self,
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation:
-        return get_subclip_annotation(
+        subclip = self.get_subclip(clip_annotation.clip)
+        sound_events = select_sound_event_annotations(
             clip_annotation,
+            subclip,
+            min_overlap=self.min_sound_event_overlap,
+        )
+        return clip_annotation.model_copy(
+            update=dict(
+                clip=subclip,
+                sound_events=sound_events,
+            )
+        )
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
+        return select_random_subclip(
+            clip,
             random=self.random,
             duration=self.duration,
             max_empty=self.max_empty,
             min_sound_event_overlap=self.min_sound_event_overlap,
         )

     @clipper_registry.register(RandomClipConfig)
@@ -75,7 +88,7 @@ def get_subclip_annotation(
 ) -> data.ClipAnnotation:
     clip = clip_annotation.clip

-    subclip = select_subclip(
+    subclip = select_random_subclip(
         clip,
         random=random,
         duration=duration,
@@ -96,7 +109,7 @@ def get_subclip_annotation(
     )


-def select_subclip(
+def select_random_subclip(
     clip: data.Clip,
     random: bool = True,
     duration: float = 0.5,
@@ -170,6 +183,10 @@ class PaddedClip:
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation:
         clip = clip_annotation.clip
+        clip = self.get_subclip(clip)
+        return clip_annotation.model_copy(update=dict(clip=clip))
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
         duration = clip.duration

         target_duration = float(
@@ -180,7 +197,7 @@ class PaddedClip:
                 end_time=clip.start_time + target_duration,
             )
         )
-        return clip_annotation.model_copy(update=dict(clip=clip))
+        return clip

     @clipper_registry.register(PaddedClipConfig)
     @staticmethod
@@ -188,8 +205,52 @@ class PaddedClip:
         return PaddedClip(chunk_size=config.chunk_size)


+class FixedDurationClipConfig(BaseConfig):
+    name: Literal["fixed_duration"] = "fixed_duration"
+    duration: float = DEFAULT_TRAIN_CLIP_DURATION
+
+
+class FixedDurationClip:
+    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.duration = duration
+
+    def __call__(
+        self,
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation:
+        clip = self.get_subclip(clip_annotation.clip)
+        sound_events = select_sound_event_annotations(
+            clip_annotation,
+            clip,
+            min_overlap=0,
+        )
+        return clip_annotation.model_copy(
+            update=dict(
+                clip=clip,
+                sound_events=sound_events,
+            )
+        )
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
+        return clip.model_copy(
+            update=dict(
+                end_time=clip.start_time + self.duration,
+            )
+        )
+
+    @clipper_registry.register(FixedDurationClipConfig)
+    @staticmethod
+    def from_config(config: FixedDurationClipConfig):
+        return FixedDurationClip(duration=config.duration)
+
+
 ClipConfig = Annotated[
-    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
+    Union[
+        RandomClipConfig,
+        PaddedClipConfig,
+        FixedDurationClipConfig,
+    ],
+    Field(discriminator="name"),
 ]
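The new FixedDurationClipConfig variant joins the discriminated union, so a plain mapping with name set to "fixed_duration" resolves to it. A small sketch of how the discriminator behaves, assuming pydantic v2 and a hypothetical module path for the clippers:

from pydantic import TypeAdapter

from batdetect2.train.clips import ClipConfig  # module path assumed

# "name" is the discriminator field, so this mapping selects
# FixedDurationClipConfig and overrides its default duration.
config = TypeAdapter(ClipConfig).validate_python(
    {"name": "fixed_duration", "duration": 0.512}
)
print(type(config).__name__)  # FixedDurationClipConfig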
@@ -7,6 +7,7 @@ from batdetect2.audio import AudioConfig
 from batdetect2.core import BaseConfig
 from batdetect2.core.configs import load_config
 from batdetect2.evaluate.config import EvaluationConfig
+from batdetect2.inference.config import InferenceConfig
 from batdetect2.models.config import BackboneConfig
 from batdetect2.postprocess.config import PostprocessConfig
 from batdetect2.preprocess.config import PreprocessingConfig
@@ -31,6 +32,7 @@ class BatDetect2Config(BaseConfig):
     postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
     audio: AudioConfig = Field(default_factory=AudioConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
+    inference: InferenceConfig = Field(default_factory=InferenceConfig)


 def load_full_config(
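With inference hanging off the top-level config, batch size and clipping behaviour can now be set in the same place as the rest of the pipeline. A sketch, assuming BaseConfig is a pydantic v2 model so model_validate is available; the override values are examples:

from batdetect2.config import BatDetect2Config

config = BatDetect2Config.model_validate(
    {"inference": {"clipping": {"duration": 0.5, "overlap": 0.25}}}
)
print(config.inference.loader.batch_size)  # default: 8
print(config.inference.clipping.overlap)   # 0.25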
@@ -40,13 +40,14 @@ def evaluate(

     config = config or BatDetect2Config()

-    audio_loader = audio_loader or build_audio_loader()
+    audio_loader = audio_loader or build_audio_loader(config=config.audio)

     preprocessor = preprocessor or build_preprocessor(
         config=config.preprocess,
+        input_samplerate=audio_loader.samplerate,
     )

-    targets = targets or build_targets()
+    targets = targets or build_targets(config=config.targets)

     loader = build_test_loader(
         test_annotations,
@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Any

 from lightning import LightningModule
 from torch.utils.data import DataLoader
@@ -7,7 +7,7 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample
 from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
-from batdetect2.typing import ClipMatches, EvaluatorProtocol
+from batdetect2.typing import EvaluatorProtocol


 class EvaluationModule(LightningModule):
@@ -54,7 +54,7 @@ class EvaluationModule(LightningModule):
         self.log_metrics(self.clip_evaluations)
         self.plot_examples(self.clip_evaluations)

-    def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
+    def plot_examples(self, evaluated_clips: Any):
         plotter = get_image_logger(self.logger)  # type: ignore

         if plotter is None:
@@ -63,7 +63,7 @@ class EvaluationModule(LightningModule):
         for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
             plotter(figure_name, fig, self.global_step)

-    def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
+    def log_metrics(self, evaluated_clips: Any):
         metrics = self.evaluator.compute_metrics(evaluated_clips)
         self.log_dict(metrics)
src/batdetect2/inference/__init__.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+from batdetect2.inference.batch import process_file_list, run_batch_inference
+from batdetect2.inference.clips import get_clips_from_files
+from batdetect2.inference.config import InferenceConfig
+
+__all__ = [
+    "process_file_list",
+    "run_batch_inference",
+    "InferenceConfig",
+    "get_clips_from_files",
+]
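With this __init__, the batch entry points are importable from the package root. A quick smoke test of the exported surface, using only names re-exported above:

from batdetect2.inference import (
    InferenceConfig,
    get_clips_from_files,
    process_file_list,
    run_batch_inference,
)

# Defaults come from the new config module (loader and clipping blocks).
print(InferenceConfig())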
src/batdetect2/inference/batch.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from typing import TYPE_CHECKING, List, Optional, Sequence
+
+from lightning import Trainer
+from soundevent import data
+
+from batdetect2.audio.loader import build_audio_loader
+from batdetect2.inference.clips import get_clips_from_files
+from batdetect2.inference.dataset import build_inference_loader
+from batdetect2.inference.lightning import InferenceModule
+from batdetect2.models import Model
+from batdetect2.preprocess.preprocessor import build_preprocessor
+from batdetect2.targets.targets import build_targets
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+if TYPE_CHECKING:
+    from batdetect2.config import BatDetect2Config
+    from batdetect2.typing import (
+        AudioLoader,
+        PreprocessorProtocol,
+        TargetProtocol,
+    )
+
+
+def run_batch_inference(
+    model,
+    clips: Sequence[data.Clip],
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    config: Optional["BatDetect2Config"] = None,
+    num_workers: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    from batdetect2.config import BatDetect2Config
+
+    config = config or BatDetect2Config()
+
+    audio_loader = audio_loader or build_audio_loader()
+
+    preprocessor = preprocessor or build_preprocessor(
+        input_samplerate=audio_loader.samplerate,
+    )
+
+    targets = targets or build_targets()
+
+    loader = build_inference_loader(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        config=config.inference.loader,
+        num_workers=num_workers,
+    )
+
+    module = InferenceModule(model)
+    trainer = Trainer(enable_checkpointing=False)
+    outputs = trainer.predict(module, loader)
+    return [
+        clip_prediction
+        for clip_predictions in outputs  # type: ignore
+        for clip_prediction in clip_predictions
+    ]
+
+
+def process_file_list(
+    model: Model,
+    paths: Sequence[data.PathLike],
+    config: "BatDetect2Config",
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    num_workers: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    clip_config = config.inference.clipping
+    clips = get_clips_from_files(
+        paths,
+        duration=clip_config.duration,
+        overlap=clip_config.overlap,
+        max_empty=clip_config.max_empty,
+        discard_empty=clip_config.discard_empty,
+    )
+    return run_batch_inference(
+        model,
+        clips,
+        targets=targets,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        config=config,
+        num_workers=num_workers,
+    )
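Both entry points can also be used without the API wrapper. A sketch, assuming the model comes from load_model_from_checkpoint (re-exported from batdetect2.train in this commit); the checkpoint and audio paths are hypothetical:

from batdetect2.config import BatDetect2Config
from batdetect2.inference import process_file_list
from batdetect2.train import load_model_from_checkpoint

model = load_model_from_checkpoint("checkpoint.ckpt")  # hypothetical path

predictions = process_file_list(
    model,
    ["night_01.wav", "night_02.wav"],  # hypothetical files
    config=BatDetect2Config(),
)
for prediction in predictions:
    print(prediction.clip.start_time, len(prediction.predictions))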
src/batdetect2/inference/clips.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+from typing import List, Sequence
+from uuid import uuid5
+
+import numpy as np
+from soundevent import data
+
+
+def get_clips_from_files(
+    paths: Sequence[data.PathLike],
+    duration: float,
+    overlap: float = 0.0,
+    max_empty: float = 0.0,
+    discard_empty: bool = True,
+    compute_hash: bool = False,
+) -> List[data.Clip]:
+    clips: List[data.Clip] = []
+
+    for path in paths:
+        recording = data.Recording.from_file(path, compute_hash=compute_hash)
+        clips.extend(
+            get_recording_clips(
+                recording,
+                duration,
+                overlap=overlap,
+                max_empty=max_empty,
+                discard_empty=discard_empty,
+            )
+        )
+
+    return clips
+
+
+def get_recording_clips(
+    recording: data.Recording,
+    duration: float,
+    overlap: float = 0.0,
+    max_empty: float = 0.0,
+    discard_empty: bool = True,
+) -> Sequence[data.Clip]:
+    start_time = 0
+    recording_duration = recording.duration
+    hop = duration * (1 - overlap)
+
+    num_clips = int(np.ceil(recording_duration / hop))
+
+    if num_clips == 0:
+        # This should only happen if the recording's duration is zero,
+        # which should never happen in practice, but just in case...
+        return []
+
+    clips = []
+    for i in range(num_clips):
+        start = start_time + i * hop
+        end = start + duration
+
+        if end > recording_duration:
+            empty_duration = end - recording_duration
+
+            if empty_duration > max_empty and discard_empty:
+                # Discard clips that contain too much empty space
+                continue
+
+        clips.append(
+            data.Clip(
+                uuid=uuid5(recording.uuid, f"{start}_{end}"),
+                recording=recording,
+                start_time=start,
+                end_time=end,
+            )
+        )
+
+    if discard_empty:
+        clips = [clip for clip in clips if clip.duration > max_empty]
+
+    return clips
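The windowing arithmetic in get_recording_clips is easiest to check with concrete numbers: 0.5 s clips at 50% overlap over a 1.2 s recording.

import numpy as np

duration, overlap, recording_duration = 0.5, 0.5, 1.2
hop = duration * (1 - overlap)                      # 0.25 s
num_clips = int(np.ceil(recording_duration / hop))  # ceil(4.8) = 5
starts = [i * hop for i in range(num_clips)]
# starts: [0.0, 0.25, 0.5, 0.75, 1.0]
# The windows starting at 0.75 and 1.0 overrun the recording end by
# 0.05 s and 0.3 s; with max_empty=0.0 and discard_empty=True both are
# skipped, leaving three clips.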
src/batdetect2/inference/config.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from pydantic import Field
+
+from batdetect2.core.configs import BaseConfig
+from batdetect2.inference.dataset import InferenceLoaderConfig
+
+__all__ = ["InferenceConfig"]
+
+
+class ClippingConfig(BaseConfig):
+    enabled: bool = True
+    duration: float = 0.5
+    overlap: float = 0.0
+    max_empty: float = 0.0
+    discard_empty: bool = True
+
+
+class InferenceConfig(BaseConfig):
+    loader: InferenceLoaderConfig = Field(
+        default_factory=InferenceLoaderConfig
+    )
+    clipping: ClippingConfig = Field(default_factory=ClippingConfig)
src/batdetect2/inference/dataset.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+from typing import List, NamedTuple, Optional, Sequence
+
+import torch
+from loguru import logger
+from soundevent import data
+from torch.utils.data import DataLoader, Dataset
+
+from batdetect2.audio import build_audio_loader
+from batdetect2.core import BaseConfig
+from batdetect2.core.arrays import adjust_width
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol
+
+__all__ = [
+    "InferenceDataset",
+    "build_inference_dataset",
+    "build_inference_loader",
+]
+
+
+DEFAULT_INFERENCE_CLIP_DURATION = 0.512
+
+
+class DatasetItem(NamedTuple):
+    spec: torch.Tensor
+    idx: torch.Tensor
+    start_time: torch.Tensor
+    end_time: torch.Tensor
+
+
+class InferenceDataset(Dataset[DatasetItem]):
+    clips: List[data.Clip]
+
+    def __init__(
+        self,
+        clips: Sequence[data.Clip],
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        audio_dir: Optional[data.PathLike] = None,
+    ):
+        self.clips = list(clips)
+        self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
+        self.audio_dir = audio_dir
+
+    def __len__(self):
+        return len(self.clips)
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        clip = self.clips[idx]
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        wav_tensor = torch.tensor(wav).unsqueeze(0)
+        spectrogram = self.preprocessor(wav_tensor)
+        return DatasetItem(
+            spec=spectrogram,
+            idx=torch.tensor(idx),
+            start_time=torch.tensor(clip.start_time),
+            end_time=torch.tensor(clip.end_time),
+        )
+
+
+class InferenceLoaderConfig(BaseConfig):
+    num_workers: int = 0
+    batch_size: int = 8
+
+
+def build_inference_loader(
+    clips: Sequence[data.Clip],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[InferenceLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+) -> DataLoader[DatasetItem]:
+    logger.info("Building inference data loader...")
+    config = config or InferenceLoaderConfig()
+
+    inference_dataset = build_inference_dataset(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+    )
+
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        inference_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        num_workers=num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def build_inference_dataset(
+    clips: Sequence[data.Clip],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+) -> InferenceDataset:
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    return InferenceDataset(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+    )
+
+
+def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return DatasetItem(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
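The custom collate exists because clips of unequal duration produce spectrograms of unequal width, and torch.stack requires equal shapes. A self-contained sketch of the same padding idea, with torch.nn.functional.pad standing in for adjust_width (whose implementation is not part of this diff):

import torch

specs = [torch.rand(1, 128, 90), torch.rand(1, 128, 100)]
max_width = max(spec.shape[-1] for spec in specs)

# Right-pad the time axis of each spectrogram to the batch maximum.
padded = [
    torch.nn.functional.pad(spec, (0, max_width - spec.shape[-1]))
    for spec in specs
]
batch = torch.stack(padded)
print(batch.shape)  # torch.Size([2, 1, 128, 100])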
src/batdetect2/inference/lightning.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from typing import Sequence
+
+from lightning import LightningModule
+from torch.utils.data import DataLoader
+
+from batdetect2.inference.dataset import DatasetItem, InferenceDataset
+from batdetect2.models import Model
+from batdetect2.postprocess import to_raw_predictions
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+
+class InferenceModule(LightningModule):
+    def __init__(self, model: Model):
+        super().__init__()
+        self.model = model
+
+    def predict_step(
+        self,
+        batch: DatasetItem,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> Sequence[BatDetect2Prediction]:
+        dataset = self.get_dataset()
+
+        clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
+
+        outputs = self.model.detector(batch.spec)
+
+        clip_detections = self.model.postprocessor(
+            outputs,
+            start_times=[clip.start_time for clip in clips],
+        )
+
+        predictions = [
+            BatDetect2Prediction(
+                clip=clip,
+                predictions=to_raw_predictions(
+                    clip_dets.numpy(),
+                    targets=self.model.targets,
+                ),
+            )
+            for clip, clip_dets in zip(clips, clip_detections)
+        ]
+
+        return predictions
+
+    def get_dataset(self) -> InferenceDataset:
+        dataloaders = self.trainer.predict_dataloaders
+        assert isinstance(dataloaders, DataLoader)
+        dataset = dataloaders.dataset
+        assert isinstance(dataset, InferenceDataset)
+        return dataset
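predict_step returns one list of predictions per batch, so Trainer.predict hands run_batch_inference a list of lists; the comprehension there flattens it. The same shape in miniature:

# Stand-in per-batch outputs, as returned by Trainer.predict.
outputs = [["pred_a", "pred_b"], ["pred_c"]]

flat = [
    clip_prediction
    for clip_predictions in outputs
    for clip_prediction in clip_predictions
]
assert flat == ["pred_a", "pred_b", "pred_c"]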
@@ -1,9 +1,9 @@
 from batdetect2.train.augmentations import (
-    AugmentationsConfig,
     AddEchoConfig,
+    AugmentationsConfig,
     MaskFrequencyConfig,
-    RandomAudioSource,
     MaskTimeConfig,
+    RandomAudioSource,
     ScaleVolumeConfig,
     WarpConfig,
     add_echo,
@@ -28,7 +28,10 @@ from batdetect2.train.dataset import (
     build_val_loader,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
-from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.lightning import (
+    TrainingModule,
+    load_model_from_checkpoint,
+)
 from batdetect2.train.losses import (
     ClassificationLossConfig,
     DetectionLossConfig,
@@ -74,4 +77,5 @@ __all__ = [
     "scale_volume",
     "train",
     "warp_spectrogram",
+    "load_model_from_checkpoint",
 ]
@@ -83,8 +83,8 @@ class ClipDetectionsTensor(NamedTuple):

 @dataclass
 class BatDetect2Prediction:
-    raw: RawPrediction
-    sound_event_prediction: data.SoundEventPrediction
+    clip: data.Clip
+    predictions: List[RawPrediction]


 class PostprocessorProtocol(Protocol):
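After this change a BatDetect2Prediction describes one clip and carries all of that clip's raw predictions, rather than one object per detection. A hedged sketch of how a consumer iterates the new shape (the helper name is hypothetical):

def count_detections(predictions):
    # One entry per clip; each holds a list of raw predictions.
    return {
        prediction.clip.uuid: len(prediction.predictions)
        for prediction in predictions
    }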
@@ -1,4 +1,4 @@
-from typing import Callable, NamedTuple, Protocol, Tuple
+from typing import Callable, List, NamedTuple, Protocol, Tuple

 import torch
 from soundevent import data
@@ -104,3 +104,5 @@ class ClipperProtocol(Protocol):
         self,
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation: ...
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip: ...