Mirror of https://github.com/macaodha/batdetect2.git (synced 2026-01-10 17:19:34 +01:00)

Writing batch inference code

commit 981e37c346 (parent 30159d64a9)
@@ -1,27 +1,29 @@
 from pathlib import Path
 from typing import List, Optional, Sequence
 
+import numpy as np
 import torch
 
 from soundevent import data
+from soundevent.audio.files import get_audio_files
 
 from batdetect2.audio import build_audio_loader
 from batdetect2.config import BatDetect2Config
 from batdetect2.evaluate import build_evaluator, evaluate
+from batdetect2.inference import process_file_list, run_batch_inference
 from batdetect2.models import Model, build_model
-from batdetect2.postprocess import build_postprocessor
-from batdetect2.postprocess.decoding import to_raw_predictions
+from batdetect2.postprocess import build_postprocessor, to_raw_predictions
 from batdetect2.preprocess import build_preprocessor
-from batdetect2.targets.targets import build_targets
+from batdetect2.targets import build_targets
-from batdetect2.train import train
-from batdetect2.train.lightning import load_model_from_checkpoint
+from batdetect2.train import load_model_from_checkpoint, train
 from batdetect2.typing import (
     AudioLoader,
+    BatDetect2Prediction,
     EvaluatorProtocol,
     PostprocessorProtocol,
     PreprocessorProtocol,
+    RawPrediction,
     TargetProtocol,
 )
-from batdetect2.typing.postprocess import RawPrediction
 
 
 class BatDetect2API:
@@ -95,17 +97,94 @@ class BatDetect2API:
             run_name=run_name,
         )
 
+    def load_audio(self, path: data.PathLike) -> np.ndarray:
+        return self.audio_loader.load_file(path)
+
+    def load_clip(self, clip: data.Clip) -> np.ndarray:
+        return self.audio_loader.load_clip(clip)
+
+    def generate_spectrogram(
+        self,
+        audio: np.ndarray,
+    ) -> torch.Tensor:
+        tensor = torch.tensor(audio).unsqueeze(0)
+        return self.preprocessor(tensor)
+
+    def process_file(self, audio_file: str) -> BatDetect2Prediction:
+        recording = data.Recording.from_file(audio_file, compute_hash=False)
+        wav = self.audio_loader.load_recording(recording)
+        detections = self.process_audio(wav)
+        return BatDetect2Prediction(
+            clip=data.Clip(
+                uuid=recording.uuid,
+                recording=recording,
+                start_time=0,
+                end_time=recording.duration,
+            ),
+            predictions=detections,
+        )
+
+    def process_audio(
+        self,
+        audio: np.ndarray,
+    ) -> List[RawPrediction]:
+        spec = self.generate_spectrogram(audio)
+        return self.process_spectrogram(spec)
+
     def process_spectrogram(
         self,
         spec: torch.Tensor,
-        start_times: Optional[Sequence[float]] = None,
-    ) -> List[List[RawPrediction]]:
+        start_time: float = 0,
+    ) -> List[RawPrediction]:
+        if spec.ndim == 4 and spec.shape[0] > 1:
+            raise ValueError("Batched spectrograms not supported.")
+
+        if spec.ndim == 3:
+            spec = spec.unsqueeze(0)
+
         outputs = self.model.detector(spec)
-        clip_detections = self.postprocessor(outputs, start_times=start_times)
-        return [
-            to_raw_predictions(clip_dets.numpy(), self.targets)
-            for clip_dets in clip_detections
-        ]
+        detections = self.model.postprocessor(
+            outputs,
+            start_times=[start_time],
+        )[0]
+
+        return to_raw_predictions(detections.numpy(), targets=self.targets)
+
+    def process_directory(
+        self,
+        audio_dir: data.PathLike,
+    ) -> List[BatDetect2Prediction]:
+        files = list(get_audio_files(audio_dir))
+        return self.process_files(files)
+
+    def process_files(
+        self,
+        audio_files: Sequence[data.PathLike],
+        num_workers: Optional[int] = None,
+    ) -> List[BatDetect2Prediction]:
+        return process_file_list(
+            self.model,
+            audio_files,
+            config=self.config,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            num_workers=num_workers,
+        )
+
+    def process_clips(
+        self,
+        clips: Sequence[data.Clip],
+    ) -> List[BatDetect2Prediction]:
+        return run_batch_inference(
+            self.model,
+            clips,
+            targets=self.targets,
+            audio_loader=self.audio_loader,
+            preprocessor=self.preprocessor,
+            config=self.config,
+        )
+
     @classmethod
     def from_config(cls, config: BatDetect2Config):
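
The new convenience methods make BatDetect2API a one-stop entry point, from a single file up to a whole directory. A minimal usage sketch (the import path and file names are assumptions, not shown in this diff):

    from batdetect2 import BatDetect2API  # import path assumed
    from batdetect2.config import BatDetect2Config

    api = BatDetect2API.from_config(BatDetect2Config())

    # Single file: one BatDetect2Prediction holding the clip plus its
    # list of RawPrediction detections.
    result = api.process_file("recording.wav")

    # Whole directory: clips are built per file and run through the
    # batched Lightning pipeline (process_file_list) under the hood.
    results = api.process_directory("recordings/")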
@@ -48,12 +48,25 @@ class RandomClip:
         self,
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation:
-        return get_subclip_annotation(
+        subclip = self.get_subclip(clip_annotation.clip)
+        sound_events = select_sound_event_annotations(
             clip_annotation,
+            subclip,
+            min_overlap=self.min_sound_event_overlap,
+        )
+        return clip_annotation.model_copy(
+            update=dict(
+                clip=subclip,
+                sound_events=sound_events,
+            )
+        )
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
+        return select_random_subclip(
+            clip,
             random=self.random,
             duration=self.duration,
             max_empty=self.max_empty,
-            min_sound_event_overlap=self.min_sound_event_overlap,
         )
 
     @clipper_registry.register(RandomClipConfig)
@@ -75,7 +88,7 @@ def get_subclip_annotation(
 ) -> data.ClipAnnotation:
     clip = clip_annotation.clip
 
-    subclip = select_subclip(
+    subclip = select_random_subclip(
         clip,
         random=random,
         duration=duration,
@@ -96,7 +109,7 @@ def get_subclip_annotation(
     )
 
 
-def select_subclip(
+def select_random_subclip(
     clip: data.Clip,
     random: bool = True,
     duration: float = 0.5,
@@ -170,6 +183,10 @@ class PaddedClip:
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation:
         clip = clip_annotation.clip
+        clip = self.get_subclip(clip)
+        return clip_annotation.model_copy(update=dict(clip=clip))
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
         duration = clip.duration
 
         target_duration = float(
@@ -180,7 +197,7 @@ class PaddedClip:
                 end_time=clip.start_time + target_duration,
             )
         )
-        return clip_annotation.model_copy(update=dict(clip=clip))
+        return clip
 
     @clipper_registry.register(PaddedClipConfig)
     @staticmethod
@@ -188,8 +205,52 @@ class PaddedClip:
         return PaddedClip(chunk_size=config.chunk_size)
 
 
+class FixedDurationClipConfig(BaseConfig):
+    name: Literal["fixed_duration"] = "fixed_duration"
+    duration: float = DEFAULT_TRAIN_CLIP_DURATION
+
+
+class FixedDurationClip:
+    def __init__(self, duration: float = DEFAULT_TRAIN_CLIP_DURATION):
+        self.duration = duration
+
+    def __call__(
+        self,
+        clip_annotation: data.ClipAnnotation,
+    ) -> data.ClipAnnotation:
+        clip = self.get_subclip(clip_annotation.clip)
+        sound_events = select_sound_event_annotations(
+            clip_annotation,
+            clip,
+            min_overlap=0,
+        )
+        return clip_annotation.model_copy(
+            update=dict(
+                clip=clip,
+                sound_events=sound_events,
+            )
+        )
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip:
+        return clip.model_copy(
+            update=dict(
+                end_time=clip.start_time + self.duration,
+            )
+        )
+
+    @clipper_registry.register(FixedDurationClipConfig)
+    @staticmethod
+    def from_config(config: FixedDurationClipConfig):
+        return FixedDurationClip(duration=config.duration)
+
+
 ClipConfig = Annotated[
-    Union[RandomClipConfig, PaddedClipConfig], Field(discriminator="name")
+    Union[
+        RandomClipConfig,
+        PaddedClipConfig,
+        FixedDurationClipConfig,
+    ],
+    Field(discriminator="name"),
 ]
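
Because ClipConfig is a pydantic discriminated union, the new clipper is selected purely by its "name" field. A small sketch of how such a config would be parsed (assuming pydantic v2's TypeAdapter):

    from pydantic import TypeAdapter

    adapter = TypeAdapter(ClipConfig)
    cfg = adapter.validate_python({"name": "fixed_duration", "duration": 1.0})
    assert isinstance(cfg, FixedDurationClipConfig)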
@@ -7,6 +7,7 @@ from batdetect2.audio import AudioConfig
 from batdetect2.core import BaseConfig
 from batdetect2.core.configs import load_config
 from batdetect2.evaluate.config import EvaluationConfig
+from batdetect2.inference.config import InferenceConfig
 from batdetect2.models.config import BackboneConfig
 from batdetect2.postprocess.config import PostprocessConfig
 from batdetect2.preprocess.config import PreprocessingConfig
@@ -31,6 +32,7 @@ class BatDetect2Config(BaseConfig):
     postprocess: PostprocessConfig = Field(default_factory=PostprocessConfig)
     audio: AudioConfig = Field(default_factory=AudioConfig)
     targets: TargetConfig = Field(default_factory=TargetConfig)
+    inference: InferenceConfig = Field(default_factory=InferenceConfig)
 
 
 def load_full_config(
@@ -40,13 +40,14 @@ def evaluate(
 
     config = config or BatDetect2Config()
 
-    audio_loader = audio_loader or build_audio_loader()
+    audio_loader = audio_loader or build_audio_loader(config=config.audio)
 
     preprocessor = preprocessor or build_preprocessor(
+        config=config.preprocess,
         input_samplerate=audio_loader.samplerate,
     )
 
-    targets = targets or build_targets()
+    targets = targets or build_targets(config=config.targets)
 
     loader = build_test_loader(
         test_annotations,
@@ -1,4 +1,4 @@
-from typing import Sequence
+from typing import Any
 
 from lightning import LightningModule
 from torch.utils.data import DataLoader
@@ -7,7 +7,7 @@ from batdetect2.evaluate.dataset import TestDataset, TestExample
 from batdetect2.logging import get_image_logger
 from batdetect2.models import Model
 from batdetect2.postprocess import to_raw_predictions
-from batdetect2.typing import ClipMatches, EvaluatorProtocol
+from batdetect2.typing import EvaluatorProtocol
 
 
 class EvaluationModule(LightningModule):
@@ -54,7 +54,7 @@ class EvaluationModule(LightningModule):
         self.log_metrics(self.clip_evaluations)
         self.plot_examples(self.clip_evaluations)
 
-    def plot_examples(self, evaluated_clips: Sequence[ClipMatches]):
+    def plot_examples(self, evaluated_clips: Any):
         plotter = get_image_logger(self.logger)  # type: ignore
 
         if plotter is None:
@@ -63,7 +63,7 @@ class EvaluationModule(LightningModule):
         for figure_name, fig in self.evaluator.generate_plots(evaluated_clips):
             plotter(figure_name, fig, self.global_step)
 
-    def log_metrics(self, evaluated_clips: Sequence[ClipMatches]):
+    def log_metrics(self, evaluated_clips: Any):
         metrics = self.evaluator.compute_metrics(evaluated_clips)
         self.log_dict(metrics)
 
@@ -0,0 +1,10 @@
+from batdetect2.inference.batch import process_file_list, run_batch_inference
+from batdetect2.inference.clips import get_clips_from_files
+from batdetect2.inference.config import InferenceConfig
+
+__all__ = [
+    "process_file_list",
+    "run_batch_inference",
+    "InferenceConfig",
+    "get_clips_from_files",
+]
src/batdetect2/inference/batch.py (new file, 88 lines)
@@ -0,0 +1,88 @@
+from typing import TYPE_CHECKING, List, Optional, Sequence
+
+from lightning import Trainer
+from soundevent import data
+
+from batdetect2.audio.loader import build_audio_loader
+from batdetect2.inference.clips import get_clips_from_files
+from batdetect2.inference.dataset import build_inference_loader
+from batdetect2.inference.lightning import InferenceModule
+from batdetect2.models import Model
+from batdetect2.preprocess.preprocessor import build_preprocessor
+from batdetect2.targets.targets import build_targets
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+if TYPE_CHECKING:
+    from batdetect2.config import BatDetect2Config
+    from batdetect2.typing import (
+        AudioLoader,
+        PreprocessorProtocol,
+        TargetProtocol,
+    )
+
+
+def run_batch_inference(
+    model,
+    clips: Sequence[data.Clip],
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    config: Optional["BatDetect2Config"] = None,
+    num_workers: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    from batdetect2.config import BatDetect2Config
+
+    config = config or BatDetect2Config()
+
+    audio_loader = audio_loader or build_audio_loader()
+
+    preprocessor = preprocessor or build_preprocessor(
+        input_samplerate=audio_loader.samplerate,
+    )
+
+    targets = targets or build_targets()
+
+    loader = build_inference_loader(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        config=config.inference.loader,
+        num_workers=num_workers,
+    )
+
+    module = InferenceModule(model)
+    trainer = Trainer(enable_checkpointing=False)
+    outputs = trainer.predict(module, loader)
+    return [
+        clip_prediction
+        for clip_predictions in outputs  # type: ignore
+        for clip_prediction in clip_predictions
+    ]
+
+
+def process_file_list(
+    model: Model,
+    paths: Sequence[data.PathLike],
+    config: "BatDetect2Config",
+    targets: Optional["TargetProtocol"] = None,
+    audio_loader: Optional["AudioLoader"] = None,
+    preprocessor: Optional["PreprocessorProtocol"] = None,
+    num_workers: Optional[int] = None,
+) -> List[BatDetect2Prediction]:
+    clip_config = config.inference.clipping
+    clips = get_clips_from_files(
+        paths,
+        duration=clip_config.duration,
+        overlap=clip_config.overlap,
+        max_empty=clip_config.max_empty,
+        discard_empty=clip_config.discard_empty,
+    )
+    return run_batch_inference(
+        model,
+        clips,
+        targets=targets,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+        config=config,
+        num_workers=num_workers,
+    )
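
Taken together, process_file_list is the file-level driver: it windows each file into clips, then hands them to run_batch_inference, which builds a DataLoader and runs prediction through a Lightning Trainer. A usage sketch (the model variable and file names are assumptions):

    from batdetect2.config import BatDetect2Config
    from batdetect2.inference import process_file_list

    predictions = process_file_list(
        model,  # a loaded Model, e.g. from load_model_from_checkpoint
        ["night1.wav", "night2.wav"],
        config=BatDetect2Config(),
        num_workers=4,
    )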
src/batdetect2/inference/clips.py (new file, 75 lines)
@@ -0,0 +1,75 @@
+from typing import List, Sequence
+from uuid import uuid5
+
+import numpy as np
+from soundevent import data
+
+
+def get_clips_from_files(
+    paths: Sequence[data.PathLike],
+    duration: float,
+    overlap: float = 0.0,
+    max_empty: float = 0.0,
+    discard_empty: bool = True,
+    compute_hash: bool = False,
+) -> List[data.Clip]:
+    clips: List[data.Clip] = []
+
+    for path in paths:
+        recording = data.Recording.from_file(path, compute_hash=compute_hash)
+        clips.extend(
+            get_recording_clips(
+                recording,
+                duration,
+                overlap=overlap,
+                max_empty=max_empty,
+                discard_empty=discard_empty,
+            )
+        )
+
+    return clips
+
+
+def get_recording_clips(
+    recording: data.Recording,
+    duration: float,
+    overlap: float = 0.0,
+    max_empty: float = 0.0,
+    discard_empty: bool = True,
+) -> Sequence[data.Clip]:
+    start_time = 0
+    duration = recording.duration
+    hop = duration * (1 - overlap)
+
+    num_clips = int(np.ceil(duration / hop))
+
+    if num_clips == 0:
+        # This should only happen if the clip's duration is zero,
+        # which should never happen in practice, but just in case...
+        return []
+
+    clips = []
+    for i in range(num_clips):
+        start = start_time + i * hop
+        end = start + duration
+
+        if end > duration:
+            empty_duration = end - duration
+
+            if empty_duration > max_empty and discard_empty:
+                # Discard clips that contain too much empty space
+                continue
+
+        clips.append(
+            data.Clip(
+                uuid=uuid5(recording.uuid, f"{start}_{end}"),
+                recording=recording,
+                start_time=start,
+                end_time=end,
+            )
+        )
+
+    if discard_empty:
+        clips = [clip for clip in clips if clip.duration > max_empty]
+
+    return clips
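
Note that get_recording_clips reassigns duration to recording.duration before computing the hop, so the window length the caller passes in is effectively ignored. The arithmetic the parameters suggest is hop = duration * (1 - overlap) with ceil(total / hop) windows; a standalone sketch of that scheme (illustrative only, not the committed helper):

    import math

    def window_bounds(total: float, duration: float, overlap: float = 0.0):
        # e.g. 0.5 s windows at 50% overlap -> 0.25 s hop
        hop = duration * (1 - overlap)
        num = math.ceil(total / hop)
        return [(i * hop, i * hop + duration) for i in range(num)]

    print(window_bounds(1.2, 0.5, 0.5))
    # [(0.0, 0.5), (0.25, 0.75), (0.5, 1.0), (0.75, 1.25), (1.0, 1.5)]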
src/batdetect2/inference/config.py (new file, 21 lines)
@@ -0,0 +1,21 @@
+from pydantic import Field
+
+from batdetect2.core.configs import BaseConfig
+from batdetect2.inference.dataset import InferenceLoaderConfig
+
+__all__ = ["InferenceConfig"]
+
+
+class ClipingConfig(BaseConfig):
+    enabled: bool = True
+    duration: float = 0.5
+    overlap: float = 0.0
+    max_empty: float = 0.0
+    discard_empty: bool = True
+
+
+class InferenceConfig(BaseConfig):
+    loader: InferenceLoaderConfig = Field(
+        default_factory=InferenceLoaderConfig
+    )
+    clipping: ClipingConfig = Field(default_factory=ClipingConfig)
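
Since these are pydantic models, per-run overrides compose naturally; a sketch:

    config = InferenceConfig(
        clipping=ClipingConfig(duration=1.0, overlap=0.25),
    )
    assert config.loader.batch_size == 8  # InferenceLoaderConfig default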
src/batdetect2/inference/dataset.py (new file, 120 lines)
@@ -0,0 +1,120 @@
+from typing import List, NamedTuple, Optional, Sequence
+
+import torch
+from loguru import logger
+from soundevent import data
+from torch.utils.data import DataLoader, Dataset
+
+from batdetect2.audio import build_audio_loader
+from batdetect2.core import BaseConfig
+from batdetect2.core.arrays import adjust_width
+from batdetect2.preprocess import build_preprocessor
+from batdetect2.typing import AudioLoader, PreprocessorProtocol
+
+__all__ = [
+    "InferenceDataset",
+    "build_inference_dataset",
+    "build_inference_loader",
+]
+
+
+DEFAULT_INFERENCE_CLIP_DURATION = 0.512
+
+
+class DatasetItem(NamedTuple):
+    spec: torch.Tensor
+    idx: torch.Tensor
+    start_time: torch.Tensor
+    end_time: torch.Tensor
+
+
+class InferenceDataset(Dataset[DatasetItem]):
+    clips: List[data.Clip]
+
+    def __init__(
+        self,
+        clips: Sequence[data.Clip],
+        audio_loader: AudioLoader,
+        preprocessor: PreprocessorProtocol,
+        audio_dir: Optional[data.PathLike] = None,
+    ):
+        self.clips = list(clips)
+        self.preprocessor = preprocessor
+        self.audio_loader = audio_loader
+        self.audio_dir = audio_dir
+
+    def __len__(self):
+        return len(self.clips)
+
+    def __getitem__(self, idx: int) -> DatasetItem:
+        clip = self.clips[idx]
+        wav = self.audio_loader.load_clip(clip, audio_dir=self.audio_dir)
+        wav_tensor = torch.tensor(wav).unsqueeze(0)
+        spectrogram = self.preprocessor(wav_tensor)
+        return DatasetItem(
+            spec=spectrogram,
+            idx=torch.tensor(idx),
+            start_time=torch.tensor(clip.start_time),
+            end_time=torch.tensor(clip.end_time),
+        )
+
+
+class InferenceLoaderConfig(BaseConfig):
+    num_workers: int = 0
+    batch_size: int = 8
+
+
+def build_inference_loader(
+    clips: Sequence[data.Clip],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+    config: Optional[InferenceLoaderConfig] = None,
+    num_workers: Optional[int] = None,
+) -> DataLoader[DatasetItem]:
+    logger.info("Building inference data loader...")
+    config = config or InferenceLoaderConfig()
+
+    inference_dataset = build_inference_dataset(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+    )
+
+    num_workers = num_workers or config.num_workers
+    return DataLoader(
+        inference_dataset,
+        batch_size=config.batch_size,
+        shuffle=False,
+        num_workers=config.num_workers,
+        collate_fn=_collate_fn,
+    )
+
+
+def build_inference_dataset(
+    clips: Sequence[data.Clip],
+    audio_loader: Optional[AudioLoader] = None,
+    preprocessor: Optional[PreprocessorProtocol] = None,
+) -> InferenceDataset:
+    if audio_loader is None:
+        audio_loader = build_audio_loader()
+
+    if preprocessor is None:
+        preprocessor = build_preprocessor()
+
+    return InferenceDataset(
+        clips,
+        audio_loader=audio_loader,
+        preprocessor=preprocessor,
+    )
+
+
+def _collate_fn(batch: List[DatasetItem]) -> DatasetItem:
+    max_width = max(item.spec.shape[-1] for item in batch)
+    return DatasetItem(
+        spec=torch.stack(
+            [adjust_width(item.spec, max_width) for item in batch]
+        ),
+        idx=torch.stack([item.idx for item in batch]),
+        start_time=torch.stack([item.start_time for item in batch]),
+        end_time=torch.stack([item.end_time for item in batch]),
+    )
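
The collate step pads every spectrogram in a batch to the width of the widest one before stacking, so clips of slightly different lengths can share a batch. The same effect with plain torch (a sketch standing in for batdetect2's adjust_width, assuming right zero-padding along the time axis):

    import torch
    import torch.nn.functional as F

    specs = [torch.rand(1, 128, 90), torch.rand(1, 128, 100)]
    max_width = max(s.shape[-1] for s in specs)
    batch = torch.stack(
        [F.pad(s, (0, max_width - s.shape[-1])) for s in specs]
    )
    print(batch.shape)  # torch.Size([2, 1, 128, 100])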
src/batdetect2/inference/lightning.py (new file, 52 lines)
@@ -0,0 +1,52 @@
+from typing import Sequence
+
+from lightning import LightningModule
+from torch.utils.data import DataLoader
+
+from batdetect2.inference.dataset import DatasetItem, InferenceDataset
+from batdetect2.models import Model
+from batdetect2.postprocess import to_raw_predictions
+from batdetect2.typing.postprocess import BatDetect2Prediction
+
+
+class InferenceModule(LightningModule):
+    def __init__(self, model: Model):
+        super().__init__()
+        self.model = model
+
+    def predict_step(
+        self,
+        batch: DatasetItem,
+        batch_idx: int,
+        dataloader_idx: int = 0,
+    ) -> Sequence[BatDetect2Prediction]:
+        dataset = self.get_dataset()
+
+        clips = [dataset.clips[int(example_idx)] for example_idx in batch.idx]
+
+        outputs = self.model.detector(batch.spec)
+
+        clip_detections = self.model.postprocessor(
+            outputs,
+            start_times=[clip.start_time for clip in clips],
+        )
+
+        predictions = [
+            BatDetect2Prediction(
+                clip=clip,
+                predictions=to_raw_predictions(
+                    clip_dets.numpy(),
+                    targets=self.model.targets,
+                ),
+            )
+            for clip, clip_dets in zip(clips, clip_detections)
+        ]
+
+        return predictions
+
+    def get_dataset(self) -> InferenceDataset:
+        dataloaders = self.trainer.predict_dataloaders
+        assert isinstance(dataloaders, DataLoader)
+        dataset = dataloaders.dataset
+        assert isinstance(dataset, InferenceDataset)
+        return dataset
@@ -1,9 +1,9 @@
 from batdetect2.train.augmentations import (
-    AugmentationsConfig,
     AddEchoConfig,
+    AugmentationsConfig,
     MaskFrequencyConfig,
-    RandomAudioSource,
     MaskTimeConfig,
+    RandomAudioSource,
     ScaleVolumeConfig,
     WarpConfig,
     add_echo,
@@ -28,7 +28,10 @@ from batdetect2.train.dataset import (
     build_val_loader,
 )
 from batdetect2.train.labels import build_clip_labeler, load_label_config
-from batdetect2.train.lightning import TrainingModule
+from batdetect2.train.lightning import (
+    TrainingModule,
+    load_model_from_checkpoint,
+)
 from batdetect2.train.losses import (
     ClassificationLossConfig,
     DetectionLossConfig,
@@ -74,4 +77,5 @@ __all__ = [
     "scale_volume",
     "train",
     "warp_spectrogram",
+    "load_model_from_checkpoint",
 ]
@@ -83,8 +83,8 @@ class ClipDetectionsTensor(NamedTuple):
 
 @dataclass
 class BatDetect2Prediction:
-    raw: RawPrediction
-    sound_event_prediction: data.SoundEventPrediction
+    clip: data.Clip
+    predictions: List[RawPrediction]
 
 
 class PostprocessorProtocol(Protocol):
@@ -1,4 +1,4 @@
-from typing import Callable, NamedTuple, Protocol, Tuple
+from typing import Callable, List, NamedTuple, Protocol, Tuple
 
 import torch
 from soundevent import data
@@ -104,3 +104,5 @@ class ClipperProtocol(Protocol):
         self,
         clip_annotation: data.ClipAnnotation,
     ) -> data.ClipAnnotation: ...
+
+    def get_subclip(self, clip: data.Clip) -> data.Clip: ...