WIP updating training code

mbsantiago 2024-04-24 10:06:04 -06:00
parent 343bc5f87c
commit c66d14b7c7
45 changed files with 4484 additions and 1567 deletions


@@ -0,0 +1,11 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect
__all__ = [
"cli",
"detect",
]
if __name__ == "__main__":
cli()

27 batdetect2/cli/base.py Normal file

@@ -0,0 +1,27 @@
"""BatDetect2 command line interface."""
import os
import click
__all__ = [
"cli",
]
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""
@click.group()
def cli():
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)


@@ -1,5 +1,4 @@
"""BatDetect2 command line interface."""
import os
import click
@@ -8,21 +7,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""
@click.group()
def cli():
"""BatDetect2 - Bat Call Detection and Classification."""
click.echo(INFO_STR)
from batdetect2.cli.base import cli
@cli.command()
@@ -147,7 +132,3 @@ def print_config(config: ProcessingConfiguration):
click.echo("\nProcessing Configuration:")
click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
if __name__ == "__main__":
cli()


@@ -0,0 +1,304 @@
from functools import wraps
from typing import Callable, List, Optional, Tuple
import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds
ClipAugmentation = Callable[[data.ClipAnnotation], data.ClipAnnotation]
AudioAugmentation = Callable[
[xr.DataArray, data.ClipAnnotation],
Tuple[xr.DataArray, data.ClipAnnotation],
]
SpecAugmentation = Callable[
[xr.DataArray, data.ClipAnnotation],
Tuple[xr.DataArray, data.ClipAnnotation],
]
ClipProvider = Callable[
[data.ClipAnnotation], Tuple[xr.DataArray, data.ClipAnnotation]
]
"""A function that provides some clip and its annotation.
Usually this function loads a random clip from a dataset. Takes
as input a clip annotation that can be used to filter the clips
to load (in case you want to avoid loading the same clip multiple times).
"""
AUGMENTATION_PROBABILITY = 0.2
MAX_DELAY = 0.005
STRETCH_SQUEEZE_DELTA = 0.04
MASK_MAX_TIME_PERC: float = 0.05
MASK_MAX_FREQ_PERC: float = 0.10
def maybe_apply(
augmentation: Callable,
prob: float = AUGMENTATION_PROBABILITY,
) -> Callable:
"""Apply an augmentation with a given probability."""
@wraps(augmentation)
def _augmentation(x):
if np.random.rand() > prob:
return x
return augmentation(x)
return _augmentation
def select_random_subclip(
clip_annotation: data.ClipAnnotation,
duration: Optional[float] = None,
proportion: float = 0.9,
) -> data.ClipAnnotation:
"""Select a random subclip from a clip."""
clip = clip_annotation.clip
if duration is None:
clip_duration = clip.end_time - clip.start_time
duration = clip_duration * proportion
start_time = np.random.uniform(clip.start_time, clip.end_time - duration)
return clip_annotation.model_copy(
update=dict(
clip=clip.model_copy(
update=dict(
start_time=start_time,
end_time=start_time + duration,
)
)
)
)
def combine_audio(
audio1: xr.DataArray,
audio2: xr.DataArray,
alpha: Optional[float] = None,
min_alpha: float = 0.3,
max_alpha: float = 0.7,
) -> xr.DataArray:
"""Combine two audio clips."""
if alpha is None:
alpha = np.random.uniform(min_alpha, max_alpha)
return alpha * audio1 + (1 - alpha) * audio2.data
def random_mix(
audio: xr.DataArray,
clip: data.ClipAnnotation,
provider: Optional[ClipProvider] = None,
alpha: Optional[float] = None,
min_alpha: float = 0.3,
max_alpha: float = 0.7,
join_annotations: bool = True,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Mix two audio clips."""
if provider is None:
raise ValueError("No audio provider given.")
try:
other_audio, other_clip = provider(clip)
except (StopIteration, ValueError):
raise ValueError("No more audio sources available.")
new_audio = combine_audio(
audio,
other_audio,
alpha=alpha,
min_alpha=min_alpha,
max_alpha=max_alpha,
)
if join_annotations:
clip = clip.model_copy(
update=dict(
sound_events=clip.sound_events + other_clip.sound_events,
)
)
return new_audio, clip
def add_echo(
audio: xr.DataArray,
clip: data.ClipAnnotation,
delay: Optional[float] = None,
alpha: Optional[float] = None,
min_alpha: float = 0.0,
max_alpha: float = 1.0,
max_delay: float = MAX_DELAY,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Add a delay to the audio."""
if delay is None:
delay = np.random.uniform(0, max_delay)
if alpha is None:
alpha = np.random.uniform(min_alpha, max_alpha)
samplerate = audio.attrs["samplerate"]
offset = int(delay * samplerate)
# NOTE: We use the copy method to avoid modifying the original audio
# data.
new_audio = audio.copy()
new_audio[offset:] += alpha * audio.data[:-offset]
return new_audio, clip
def scale_volume(
spec: xr.DataArray,
clip: data.ClipAnnotation,
factor: Optional[float] = None,
max_scaling: float = 2,
min_scaling: float = 0,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Scale the volume of a spectrogram."""
if factor is None:
factor = np.random.uniform(min_scaling, max_scaling)
return spec * factor, clip
def scale_sound_event_annotation(
sound_event_annotation: data.SoundEventAnnotation,
time_factor: float = 1,
frequency_factor: float = 1,
) -> data.SoundEventAnnotation:
sound_event = sound_event_annotation.sound_event
geometry = sound_event.geometry
if geometry is None:
return sound_event_annotation
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
new_geometry = data.BoundingBox(
coordinates=[
start_time * time_factor,
low_freq * frequency_factor,
end_time * time_factor,
high_freq * frequency_factor,
]
)
return sound_event_annotation.model_copy(
update=dict(
sound_event=sound_event.model_copy(
update=dict(
geometry=new_geometry,
)
)
)
)
def warp_spectrogram(
spec: xr.DataArray,
clip: data.ClipAnnotation,
factor: Optional[float] = None,
delta: float = STRETCH_SQUEEZE_DELTA,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Warp a spectrogram."""
if factor is None:
factor = np.random.uniform(1 - delta, 1 + delta)
start_time = clip.clip.start_time
end_time = clip.clip.end_time
duration = end_time - start_time
new_time = np.linspace(
start_time,
start_time + duration * factor,
spec.time.size,
)
scaled_spec = spec.interp(
time=new_time,
method="linear",
kwargs={"fill_value": 0},
)
scaled_spec.coords["time"] = spec.time
scaled_clip = clip.model_copy(
update=dict(
sound_events=[
scale_sound_event_annotation(
sound_event_annotation,
time_factor=1 / factor,
)
for sound_event_annotation in clip.sound_events
]
)
)
return scaled_spec, scaled_clip
def mask_axis(
array: xr.DataArray,
axis: str,
start: float,
end: float,
mask_value: float = 0,
) -> xr.DataArray:
if axis not in array.dims:
raise ValueError(f"Axis {axis} not found in array")
coord = array[axis]
return array.where((coord < start) | (coord > end), mask_value)
def mask_time(
spec: xr.DataArray,
clip: data.ClipAnnotation,
max_time_mask: float = MASK_MAX_TIME_PERC,
max_num_masks: int = 3,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Mask a random section of the time axis."""
num_masks = np.random.randint(1, max_num_masks + 1)
for _ in range(num_masks):
mask_size = np.random.uniform(0, max_time_mask)
start = np.random.uniform(0, spec.time[-1] - mask_size)
end = start + mask_size
spec = mask_axis(spec, "time", start, end)
return spec, clip
def mask_frequency(
spec: xr.DataArray,
clip: data.ClipAnnotation,
max_freq_mask: float = MASK_MAX_FREQ_PERC,
max_num_masks: int = 3,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
"""Mask a random section of the frequency axis."""
num_masks = np.random.randint(1, max_num_masks + 1)
for _ in range(num_masks):
mask_size = np.random.uniform(0, max_freq_mask)
start = np.random.uniform(0, spec.frequency[-1] - mask_size)
end = start + mask_size
spec = mask_axis(spec, "frequency", start, end)
return spec, clip
CLIP_AUGMENTATIONS: List[ClipAugmentation] = [
select_random_subclip,
]
AUDIO_AUGMENTATIONS: List[AudioAugmentation] = [
add_echo,
random_mix,
]
SPEC_AUGMENTATIONS: List[SpecAugmentation] = [
scale_volume,
warp_spectrogram,
mask_time,
mask_frequency,
]
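# Usage sketch (an assumption, not part of the module above): one way the
# SPEC_AUGMENTATIONS list could be composed into a pipeline, applying each
# augmentation independently with the default probability.
def apply_spec_augmentations(
    spec: xr.DataArray,
    clip_annotation: data.ClipAnnotation,
    prob: float = AUGMENTATION_PROBABILITY,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Apply every spectrogram augmentation with probability ``prob``."""
    for augmentation in SPEC_AUGMENTATIONS:
        if np.random.rand() < prob:
            spec, clip_annotation = augmentation(spec, clip_annotation)
    return spec, clip_annotation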

332 batdetect2/data/compat.py Normal file

@@ -0,0 +1,332 @@
"""Compatibility functions between old and new data structures."""
import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union
import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2 import types
from batdetect2.data.labels import LabelFn
PathLike = Union[Path, str, os.PathLike]
__all__ = [
"convert_to_annotation_group",
"load_annotation_project",
]
SPECIES_TAG_KEY = "species"
ECHOLOCATION_EVENT = "Echolocation"
UNKNOWN_CLASS = "__UNKNOWN__"
NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")
EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]
ClassFn = Callable[[data.Recording], int]
IndividualFn = Callable[[data.SoundEventAnnotation], int]
def get_recording_class_name(recording: data.Recording) -> str:
"""Get the class name for a recording."""
tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
if tag is None:
return UNKNOWN_CLASS
return tag.value
def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
"""Get the notes for a ClipAnnotation."""
all_notes = [
*annotation.notes,
*annotation.clip.recording.notes,
]
messages = [note.message for note in all_notes if note.message is not None]
return "\n".join(messages)
def convert_to_annotation_group(
annotation: data.ClipAnnotation,
label_fn: LabelFn = lambda _: None,
event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
class_fn: ClassFn = lambda _: 0,
individual_fn: IndividualFn = lambda _: 0,
) -> types.AudioLoaderAnnotationGroup:
"""Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
recording = annotation.clip.recording
start_times = []
end_times = []
low_freqs = []
high_freqs = []
class_ids = []
x_inds = []
y_inds = []
individual_ids = []
annotations: List[types.Annotation] = []
class_id_file = class_fn(recording)
for sound_event in annotation.sound_events:
geometry = sound_event.sound_event.geometry
if geometry is None:
continue
start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
# NOTE: An `or -1` fallback would remap a valid id of 0 to -1, so check
# for None explicitly.
class_id = label_fn(sound_event)
class_id = -1 if class_id is None else class_id
event = event_fn(sound_event)
individual_id = individual_fn(sound_event)
individual_id = -1 if individual_id is None else individual_id
start_times.append(start_time)
end_times.append(end_time)
low_freqs.append(low_freq)
high_freqs.append(high_freq)
class_ids.append(class_id)
individual_ids.append(individual_id)
# NOTE: This will be computed later so we just put a placeholder
# here for now.
x_inds.append(0)
y_inds.append(0)
annotations.append(
{
"start_time": start_time,
"end_time": end_time,
"low_freq": low_freq,
"high_freq": high_freq,
"class_prob": 1.0,
"det_prob": 1.0,
"individual": "0",
"event": event,
"class_id": class_id, # type: ignore
}
)
return {
"id": str(recording.path),
"duration": recording.duration,
"issues": False,
"file_path": str(recording.path),
"time_exp": recording.time_expansion,
"class_name": get_recording_class_name(recording),
"notes": get_annotation_notes(annotation),
"annotated": True,
"start_times": np.array(start_times),
"end_times": np.array(end_times),
"low_freqs": np.array(low_freqs),
"high_freqs": np.array(high_freqs),
"class_ids": np.array(class_ids),
"x_inds": np.array(x_inds),
"y_inds": np.array(y_inds),
"individual_ids": np.array(individual_ids),
"annotation": annotations,
"class_id_file": class_id_file,
}
class Annotation(BaseModel):
"""Annotation class to hold batdetect annotations."""
label: str = Field(alias="class")
event: str
individual: int = 0
start_time: float
end_time: float
low_freq: float
high_freq: float
class FileAnnotation(BaseModel):
"""FileAnnotation class to hold batdetect annotations for a file."""
id: str
duration: float
time_exp: float = 1
label: str = Field(alias="class_name")
annotation: List[Annotation]
annotated: bool = False
issues: bool = False
notes: str = ""
def load_file_annotation(path: PathLike) -> FileAnnotation:
"""Load annotation from batdetect format."""
path = Path(path)
return FileAnnotation.model_validate_json(path.read_text())
def annotation_to_sound_event(
annotation: Annotation,
recording: data.Recording,
label_key: str = "class",
event_key: str = "event",
individual_key: str = "individual",
) -> data.SoundEventAnnotation:
"""Convert annotation to sound event annotation."""
sound_event = data.SoundEvent(
uuid=uuid.uuid5(
NAMESPACE,
f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
),
recording=recording,
geometry=data.BoundingBox(
coordinates=[
annotation.start_time,
annotation.low_freq,
annotation.end_time,
annotation.high_freq,
],
),
)
return data.SoundEventAnnotation(
uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
sound_event=sound_event,
tags=[
data.Tag(key=label_key, value=annotation.label),
data.Tag(key=event_key, value=annotation.event),
data.Tag(key=individual_key, value=str(annotation.individual)),
],
)
def file_annotation_to_clip(
file_annotation: FileAnnotation,
audio_dir: PathLike = Path.cwd(),
) -> data.Clip:
"""Convert file annotation to recording."""
full_path = Path(audio_dir) / file_annotation.id
if not full_path.exists():
raise FileNotFoundError(f"File {full_path} not found.")
recording = data.Recording.from_file(
full_path,
time_expansion=file_annotation.time_exp,
)
return data.Clip(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
recording=recording,
start_time=0,
end_time=recording.duration,
)
def file_annotation_to_clip_annotation(
file_annotation: FileAnnotation,
clip: data.Clip,
label_key: str = "class",
event_key: str = "event",
individual_key: str = "individual",
) -> data.ClipAnnotation:
"""Convert file annotation to clip annotation."""
notes = []
if file_annotation.notes:
notes.append(data.Note(message=file_annotation.notes))
return data.ClipAnnotation(
uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
clip=clip,
notes=notes,
tags=[data.Tag(key=label_key, value=file_annotation.label)],
sound_events=[
annotation_to_sound_event(
annotation,
clip.recording,
label_key=label_key,
event_key=event_key,
individual_key=individual_key,
)
for annotation in file_annotation.annotation
],
)
def file_annotation_to_annotation_task(
file_annotation: FileAnnotation,
clip: data.Clip,
) -> data.AnnotationTask:
status_badges = []
if file_annotation.issues:
status_badges.append(
data.StatusBadge(state=data.AnnotationState.rejected)
)
elif file_annotation.annotated:
status_badges.append(
data.StatusBadge(state=data.AnnotationState.completed)
)
return data.AnnotationTask(
uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
clip=clip,
status_badges=status_badges,
)
def list_file_annotations(path: PathLike) -> List[Path]:
"""List all annotations in a directory."""
path = Path(path)
return list(path.glob("*.json"))
def load_annotation_project(
path: PathLike,
name: Optional[str] = None,
audio_dir: PathLike = Path.cwd(),
) -> data.AnnotationProject:
"""Convert annotations to annotation project."""
paths = list_file_annotations(path)
if name is None:
name = str(path)
annotations = []
tasks = []
for p in paths:
try:
file_annotation = load_file_annotation(p)
except FileNotFoundError:
continue
try:
clip = file_annotation_to_clip(
file_annotation,
audio_dir=audio_dir,
)
except FileNotFoundError:
continue
annotations.append(
file_annotation_to_clip_annotation(
file_annotation,
clip,
)
)
tasks.append(
file_annotation_to_annotation_task(
file_annotation,
clip,
)
)
return data.AnnotationProject(
name=name,
clip_annotations=annotations,
tasks=tasks,
)
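# Usage sketch (the directory names are placeholders, not paths used by this
# project): build an AnnotationProject from a folder of batdetect-format JSON
# annotations and the matching audio directory.
if __name__ == "__main__":
    project = load_annotation_project(
        "example_annotations",
        audio_dir="example_audio",
    )
    print(f"Loaded {len(project.clip_annotations)} annotated clips")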


@@ -0,0 +1,58 @@
from typing import Callable, Generic, Iterable, List, TypeVar
from soundevent import data
from torch.utils.data import Dataset
__all__ = [
"ClipAnnotationDataset",
"ClipDataset",
]
E = TypeVar("E")
class ClipAnnotationDataset(Dataset, Generic[E]):
clip_annotations: List[data.ClipAnnotation]
transform: Callable[[data.ClipAnnotation], E]
def __init__(
self,
clip_annotations: Iterable[data.ClipAnnotation],
transform: Callable[[data.ClipAnnotation], E],
name: str = "ClipAnnotationDataset",
):
self.clip_annotations = list(clip_annotations)
self.transform = transform
self.name = name
def __len__(self) -> int:
return len(self.clip_annotations)
def __getitem__(self, idx: int) -> E:
return self.transform(self.clip_annotations[idx])
class ClipDataset(Dataset, Generic[E]):
clips: List[data.Clip]
transform: Callable[[data.Clip], E]
def __init__(
self,
clips: Iterable[data.Clip],
transform: Callable[[data.Clip], E],
name: str = "ClipDataset",
):
self.clips = list(clips)
self.transform = transform
self.name = name
def __len__(self) -> int:
return len(self.clips)
def __getitem__(self, idx: int) -> E:
return self.transform(self.clips[idx])
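# Usage sketch (an assumption, for illustration only): any
# Callable[[data.ClipAnnotation], E] can serve as the transform, for example
# one that maps an annotation to the path of its recording. A transform that
# builds spectrograms and heatmaps would be used for actual training.
def make_path_dataset(
    annotations: Iterable[data.ClipAnnotation],
) -> ClipAnnotationDataset:
    return ClipAnnotationDataset(
        annotations,
        transform=lambda annotation: annotation.clip.recording.path,
        name="recording_paths",
    )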

231 batdetect2/data/labels.py Normal file

@@ -0,0 +1,231 @@
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import xarray as xr
from scipy.ndimage import gaussian_filter
from soundevent import data, geometry
__all__ = [
"generate_heatmaps",
]
PositionFn = Callable[[data.SoundEvent], Tuple[float, float]]
"""Convert a sound event to a single position in time-frequency space."""
SizeFn = Callable[[data.SoundEvent, float, float], np.ndarray]
"""Compute the size of a sound event in time-frequency space.
The time and frequency scales are provided as arguments to allow
modifying the size of the sound event based on the spectrogram
parameters.
"""
LabelFn = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Convert a sound event annotation to a label.
When the label is None, this indicates that the sound event does not
belong to any of the classes of interest.
"""
TARGET_SIGMA = 3.0
GENERIC_LABEL = "__UNKNOWN__"
def get_lower_left_position(
sound_event: data.SoundEvent,
) -> Tuple[float, float]:
if sound_event.geometry is None:
raise ValueError("Sound event has no geometry.")
start_time, low_freq, _, _ = geometry.compute_bounds(sound_event.geometry)
return start_time, low_freq
def get_bbox_size(
sound_event: data.SoundEvent,
time_scale: float = 1.0,
frequency_scale: float = 1.0,
) -> np.ndarray:
if sound_event.geometry is None:
raise ValueError("Sound event has no geometry.")
start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
sound_event.geometry
)
return np.array(
[
time_scale * (end_time - start_time),
frequency_scale * (high_freq - low_freq),
]
)
def _tag_key(tag: data.Tag) -> Tuple[str, str]:
return (tag.key, tag.value)
def set_value_at_position(
array: xr.DataArray,
value: Any,
**query,
) -> xr.DataArray:
dims = {dim: n for n, dim in enumerate(array.dims)}
indexer: List[Union[slice, int]] = [slice(None) for _ in range(array.ndim)]
for key, coord in query.items():
if key not in dims:
raise ValueError(f"Dimension {key} not found in array.")
coordinates = array.indexes[key]
indexer[dims[key]] = coordinates.get_loc(coordinates.asof(coord))
if isinstance(value, (tuple, list)):
value = np.array(value)
array.data[tuple(indexer)] = value
return array
def generate_heatmaps(
clip_annotation: data.ClipAnnotation,
spec: xr.DataArray,
num_classes: int = 1,
label_fn: LabelFn = lambda _: None,
target_sigma: float = TARGET_SIGMA,
size_fn: SizeFn = get_bbox_size,
position_fn: PositionFn = get_lower_left_position,
class_labels: Optional[List[str]] = None,
dtype=np.float32,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
if class_labels is None:
class_labels = [str(i) for i in range(num_classes)]
if len(class_labels) != num_classes:
raise ValueError(
"Number of class labels must match the number of classes."
)
shape = dict(zip(spec.dims, spec.shape))
if "time" not in shape or "frequency" not in shape:
raise ValueError(
"Spectrogram must have time and frequency dimensions."
)
time_duration = spec.time.attrs["max"] - spec.time.attrs["min"]
freq_bandwidth = spec.frequency.attrs["max"] - spec.frequency.attrs["min"]
# Compute the size factors
time_scale = 1 / time_duration
frequency_scale = 1 / freq_bandwidth
# Initialize heatmaps
detection_heatmap = xr.zeros_like(spec, dtype=dtype)
class_heatmap = xr.DataArray(
data=np.zeros((num_classes, *spec.shape), dtype=dtype),
dims=["category", *spec.dims],
coords={
"category": class_labels,
**spec.coords,
},
)
size_heatmap = xr.DataArray(
data=np.zeros((2, *spec.shape), dtype=dtype),
dims=["dimension", *spec.dims],
coords={
"dimension": ["width", "height"],
**spec.coords,
},
)
for sound_event_annotation in clip_annotation.sound_events:
# Get the position of the sound event
time, frequency = position_fn(sound_event_annotation.sound_event)
# Set 1.0 at the position of the sound event in the detection heatmap
detection_heatmap = set_value_at_position(
detection_heatmap,
1.0,
time=time,
frequency=frequency,
)
# Set the size of the sound event at the position in the size heatmap
size = size_fn(
sound_event_annotation.sound_event,
time_scale,
frequency_scale,
)
size_heatmap = set_value_at_position(
size_heatmap,
size,
time=time,
frequency=frequency,
)
# Get the label id for the sound event
label = label_fn(sound_event_annotation)
if label is None or label not in class_labels:
# If the label is None or not in the class labels, we skip the
# sound event
continue
# Set 1.0 at the position and category of the sound event in the class
# heatmap
class_heatmap = set_value_at_position(
class_heatmap,
1.0,
time=time,
frequency=frequency,
category=label,
)
# Apply gaussian filters
detection_heatmap = xr.apply_ufunc(
gaussian_filter,
detection_heatmap,
target_sigma,
)
class_heatmap = class_heatmap.groupby("category").map(
gaussian_filter, # type: ignore
args=(target_sigma,),
)
# Normalize heatmaps
detection_heatmap = (
detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
).fillna(0.0)
class_heatmap = (
class_heatmap / class_heatmap.max(dim=["time", "frequency"])
).fillna(0.0)
return detection_heatmap, class_heatmap, size_heatmap
class Labeler:
def __init__(self, tags: List[data.Tag]):
"""Create a labeler from a list of tags.
Each tag is assigned a unique label. The labeler can then be used
to convert sound event annotations to labels.
"""
self.tags = tags
self._label_map = {_tag_key(tag): i for i, tag in enumerate(tags)}
self._inverse_label_map = {v: k for k, v in self._label_map.items()}
def __call__(
self, sound_event_annotation: data.SoundEventAnnotation
) -> Optional[int]:
for tag in sound_event_annotation.tags:
key = _tag_key(tag)
if key in self._label_map:
return self._label_map[key]
return None
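# Usage sketch (the tags are illustrative assumptions): glue `Labeler` and
# `generate_heatmaps` together. `Labeler` returns an integer index, while
# `generate_heatmaps` matches labels against string class labels (by default
# "0", "1", ...), so the index is converted with `str` here.
def heatmaps_for_tags(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
    tags: List[data.Tag],
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
    labeler = Labeler(tags)
    def label_fn(annotation: data.SoundEventAnnotation) -> Optional[str]:
        index = labeler(annotation)
        return None if index is None else str(index)
    return generate_heatmaps(
        clip_annotation,
        spec,
        num_classes=len(tags),
        label_fn=label_fn,
    )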


@@ -0,0 +1,586 @@
"""Module containing functions for preprocessing audio clips."""
import random
from typing import List, Optional, Tuple
import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from scipy.signal import resample_poly
from soundevent import audio, data
TARGET_SAMPLERATE_HZ = 256000
SCALE_RAW_AUDIO = False
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
DEFAULT_DURATION = 1
SPEC_HEIGHT = 128
SPEC_WIDTH = 256
SPEC_SCALE = "pcen"
SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False
def preprocess_audio_clip(
clip: data.Clip,
target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
scale_audio: bool = SCALE_RAW_AUDIO,
fft_win_length: float = FFT_WIN_LENGTH_S,
fft_overlap: float = FFT_OVERLAP,
max_freq: int = MAX_FREQ_HZ,
min_freq: int = MIN_FREQ_HZ,
spec_scale: str = SPEC_SCALE,
denoise_spec_avg: bool = True,
max_scale_spec: bool = False,
duration: Optional[float] = DEFAULT_DURATION,
spec_height: int = SPEC_HEIGHT,
spec_time_period: float = SPEC_TIME_PERIOD,
) -> xr.DataArray:
"""Preprocesses audio clip to generate spectrogram.
Parameters
----------
clip
The audio clip to preprocess.
target_sampling_rate
Target sampling rate for the audio. If the audio has a different
sampling rate, it will be resampled to this rate.
scale_audio
Whether to scale the audio amplitudes to a range of [-1, 1].
By default, the audio is not scaled.
fft_win_length
Length of the FFT window in seconds.
fft_overlap
Amount of overlap between FFT windows as a fraction of the window
length.
max_freq
Maximum frequency for spectrogram. Anything above this frequency will
be cropped.
min_freq
Minimum frequency for spectrogram. Anything below this frequency will
be cropped.
spec_scale
Scaling method for the spectrogram. Can be "pcen", "log" or
"amplitude".
denoise_spec_avg
Whether to denoise the spectrogram. Denoising is done by subtracting
the mean over time (per frequency bin) from the spectrogram and
clipping negative values to 0.
max_scale_spec
Whether to max scale the spectrogram. Max scaling is done by dividing
the spectrogram by its maximum value thus scaling values to [0, 1].
duration
Duration of the spectrogram in seconds. If the clip duration is
different from this value, the spectrogram will be cropped or extended
to match this duration. If None, the spectrogram will have the same
duration as the clip.
spec_height
Number of frequency bins for the spectrogram. This is the height of
the final spectrogram.
spec_time_period
Time period for each spectrogram bin in seconds. The spectrogram array
will be resized (using bilinear interpolation) to have this time
period.
Returns
-------
xr.DataArray
Preprocessed spectrogram.
"""
wav = load_clip_audio(
clip,
target_sampling_rate=target_sampling_rate,
scale=scale_audio,
)
wav = wav.assign_attrs(
recording_id=str(wav.attrs["recording_id"]),
clip_id=str(wav.attrs["clip_id"]),
path=str(wav.attrs["path"]),
)
spec = compute_spectrogram(
wav,
fft_win_length=fft_win_length,
fft_overlap=fft_overlap,
max_freq=max_freq,
min_freq=min_freq,
spec_scale=spec_scale,
denoise_spec_avg=denoise_spec_avg,
max_scale_spec=max_scale_spec,
)
if duration is not None:
spec = adjust_spec_duration(clip, spec, duration)
duration = get_dim_width(spec, dim="time")
return resize_spectrogram(
spec,
time_bins=int(np.ceil(duration / spec_time_period)),
freq_bins=spec_height,
)
def adjust_spec_duration(
clip: data.Clip,
spec: xr.DataArray,
duration: float,
) -> xr.DataArray:
current_duration = clip.end_time - clip.start_time
if current_duration == duration:
return spec
if current_duration > duration:
return crop_axis(
spec,
dim="time",
start=clip.start_time,
end=clip.start_time + duration,
)
return extend_axis(
spec,
dim="time",
start=clip.start_time,
end=clip.start_time + duration,
)
def load_clip_audio(
clip: data.Clip,
target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
scale: bool = SCALE_RAW_AUDIO,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
wav = audio.load_clip(clip).sel(channel=0)
wav = resample_audio(wav, target_sampling_rate, dtype=dtype)
if scale:
wav = scale_audio(wav)
wav.coords["time"] = wav.time.assign_attrs(
unit="s",
long_name="Seconds since start of recording",
min=clip.start_time,
max=clip.end_time,
)
return wav
def resample_audio(
wav: xr.DataArray,
target_samplerate: int = TARGET_SAMPLERATE_HZ,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
if "samplerate" not in wav.attrs:
raise ValueError("Audio must have a 'samplerate' attribute")
if "time" not in wav.dims:
raise ValueError("Audio must have a time dimension")
time_axis: int = wav.get_axis_num("time") # type: ignore
original_samplerate = wav.attrs["samplerate"]
if original_samplerate == target_samplerate:
return wav.astype(dtype)
gcd = np.gcd(original_samplerate, target_samplerate)
resampled = resample_poly(
wav.values,
target_samplerate // gcd,
original_samplerate // gcd,
axis=time_axis,
)
resampled_times = np.linspace(
wav.time[0],
wav.time[-1],
len(resampled),
endpoint=False,
dtype=dtype,
)
return xr.DataArray(
data=resampled.astype(dtype),
dims=wav.dims,
coords={
**wav.coords,
"time": resampled_times,
},
attrs={
**wav.attrs,
"samplerate": target_samplerate,
},
)
def scale_audio(
audio: xr.DataArray,
eps: float = 10e-6,
) -> xr.DataArray:
audio = audio - audio.mean()
return audio / np.add(np.abs(audio).max(), eps, dtype=audio.dtype)
def compute_spectrogram(
wav: xr.DataArray,
fft_win_length: float = FFT_WIN_LENGTH_S,
fft_overlap: float = FFT_OVERLAP,
max_freq: int = MAX_FREQ_HZ,
min_freq: int = MIN_FREQ_HZ,
spec_scale: str = SPEC_SCALE,
denoise_spec_avg: bool = True,
max_scale_spec: bool = False,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
spec = gen_mag_spectrogram(
wav,
window_len=fft_win_length,
overlap_perc=fft_overlap,
dtype=dtype,
)
spec = crop_axis(
spec,
dim="frequency",
start=min_freq,
end=max_freq,
)
spec = scale_spectrogram(spec, scale=spec_scale)
if denoise_spec_avg:
spec = denoise_spectrogram(spec)
if max_scale_spec:
spec = max_scale_spectrogram(spec)
return spec
def crop_axis(
arr: xr.DataArray,
dim: str,
start: float,
end: float,
right_closed: bool = False,
left_closed: bool = True,
eps: float = 10e-6,
) -> xr.DataArray:
coord = arr.coords[dim]
if not all(attr in coord.attrs for attr in ["min", "max"]):
raise ValueError(
f"Coordinate '{dim}' must have 'min' and 'max' attributes"
)
current_min = coord.attrs["min"]
current_max = coord.attrs["max"]
if start < current_min or end > current_max:
raise ValueError(
f"Cannot select axis '{dim}' from {start} to {end}. "
f"Axis range is {current_min} to {current_max}"
)
slice_end = end
if not right_closed:
slice_end = end - eps
slice_start = start
if not left_closed:
slice_start = start + eps
arr = arr.sel({dim: slice(slice_start, slice_end)})
arr.coords[dim].attrs.update(
min=start,
max=end,
)
return arr
def extend_axis(
arr: xr.DataArray,
dim: str,
start: float,
end: float,
fill_value: float = 0,
) -> xr.DataArray:
coord = arr.coords[dim]
if not all(attr in coord.attrs for attr in ["min", "max", "period"]):
raise ValueError(
f"Coordinate '{dim}' must have 'min', 'max' and 'period' attributes"
" to extend axis"
)
current_min = coord.attrs["min"]
current_max = coord.attrs["max"]
period = coord.attrs["period"]
coords = coord.data
if start < current_min:
new_coords = np.arange(
current_min,
start,
-period,
dtype=coord.dtype,
)[1:][::-1]
coords = np.concatenate([new_coords, coords])
if end > current_max:
new_coords = np.arange(
current_max,
end,
period,
dtype=coord.dtype,
)[1:]
coords = np.concatenate([coords, new_coords])
arr = arr.reindex(
{dim: coords},
fill_value=fill_value, # type: ignore
)
arr.coords[dim].attrs.update(
min=start,
max=end,
)
return arr
def gen_mag_spectrogram(
audio: xr.DataArray,
window_len: float,
overlap_perc: float,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
sampling_rate = audio.attrs["samplerate"]
hop_len = window_len * (1 - overlap_perc)
nfft = int(window_len * sampling_rate)
noverlap = int(overlap_perc * nfft)
start_time = audio.time.attrs["min"]
end_time = audio.time.attrs["max"]
# compute spec
spec, _ = librosa.core.spectrum._spectrogram(
y=audio.data,
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
center=False,
)
spec = xr.DataArray(
data=spec.astype(dtype),
dims=["frequency", "time"],
coords={
"frequency": np.linspace(
0,
sampling_rate / 2,
spec.shape[0],
endpoint=False,
dtype=dtype,
),
"time": np.linspace(
start_time,
end_time - (window_len - hop_len),
spec.shape[1],
endpoint=False,
dtype=dtype,
),
},
attrs={
**audio.attrs,
"nfft": nfft,
"noverlap": noverlap,
},
)
# Add metadata to coordinates
spec.coords["time"].attrs.update(
unit="s",
long_name="Time",
min=start_time,
max=end_time - (window_len - hop_len),
period=(nfft - noverlap) / sampling_rate,
)
spec.coords["frequency"].attrs.update(
unit="Hz",
long_name="Frequency",
period=(sampling_rate / nfft),
min=0,
max=sampling_rate / 2,
)
return spec
def denoise_spectrogram(
spec: xr.DataArray,
) -> xr.DataArray:
return xr.DataArray(
data=(spec - spec.mean("time")).clip(0),
dims=spec.dims,
coords=spec.coords,
attrs={
**spec.attrs,
"denoised": 1,
},
)
def scale_spectrogram(
spec: xr.DataArray,
scale: str = SPEC_SCALE,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
if scale == "pcen":
return pcen(spec, dtype=dtype)
if scale == "log":
return log_scale(spec, dtype=dtype)
return spec
def log_scale(
spec: xr.DataArray,
dtype: DTypeLike = np.float32,
) -> xr.DataArray:
nfft = spec.attrs["nfft"]
sampling_rate = spec.attrs["samplerate"]
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
)
return xr.DataArray(
data=np.log1p(log_scaling * spec).astype(dtype),
dims=spec.dims,
coords=spec.coords,
attrs={
**spec.attrs,
"scale": "log",
},
)
def pcen(spec: xr.DataArray, dtype: DTypeLike = np.float32) -> xr.DataArray:
sampling_rate = spec.attrs["samplerate"]
data = librosa.pcen(
spec.data * (2**31),
sr=sampling_rate / 10,
)
return xr.DataArray(
data=data.astype(dtype),
dims=spec.dims,
coords=spec.coords,
attrs={
**spec.attrs,
"scale": "pcen",
},
)
def max_scale_spectrogram(spec: xr.DataArray, eps=10e-6) -> xr.DataArray:
return xr.DataArray(
data=spec / np.add(spec.max(), eps, dtype=spec.dtype),
dims=spec.dims,
coords=spec.coords,
attrs={
**spec.attrs,
"max_scaled": 1,
},
)
def resize_spectrogram(
spec: xr.DataArray,
time_bins: int,
freq_bins: int,
) -> xr.DataArray:
new_times = np.linspace(
spec.time[0],
spec.time[-1],
time_bins,
dtype=spec.time.dtype,
endpoint=True,
)
new_frequencies = np.linspace(
spec.frequency[0],
spec.frequency[-1],
freq_bins,
dtype=spec.frequency.dtype,
endpoint=True,
)
return spec.interp(
coords=dict(
time=new_times,
frequency=new_frequencies,
),
method="linear",
)
def get_dim_width(arr: xr.DataArray, dim: str) -> float:
coord = arr.coords[dim]
attrs = coord.attrs
if "min" in attrs and "max" in attrs:
return attrs["max"] - attrs["min"]
coord_min = coord.min()
coord_max = coord.max()
return float(coord_max - coord_min)
class RandomClipProvider:
def __init__(
self,
clip_annotations: List[data.ClipAnnotation],
target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
scale_audio: bool = SCALE_RAW_AUDIO,
):
self.target_sampling_rate = target_sampling_rate
self.scale_audio = scale_audio
self.clip_annotations = clip_annotations
def get_next_clip(self, clip: data.ClipAnnotation) -> data.ClipAnnotation:
tries = 0
while True:
random_clip = random.choice(self.clip_annotations)
if random_clip.clip != clip.clip:
return random_clip
tries += 1
if tries > 4:
raise ValueError("Could not find a different clip")
def __call__(
self,
clip: data.ClipAnnotation,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
random_clip = self.get_next_clip(clip)
wav = load_clip_audio(
random_clip.clip,
target_sampling_rate=self.target_sampling_rate,
scale=self.scale_audio,
)
return wav, random_clip
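# Usage sketch (the file name is a placeholder): load a recording, take the
# first second as a clip, and compute the model-ready spectrogram using the
# defaults above. With the default settings the result is a
# (frequency, time) = (128, 256) array covering 10-120 kHz.
if __name__ == "__main__":
    recording = data.Recording.from_file("example.wav")
    clip = data.Clip(recording=recording, start_time=0, end_time=1)
    spec = preprocess_audio_clip(clip)
    print(spec.shape)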


@@ -53,7 +53,13 @@ class SelfAttention(nn.Module):
class ConvBlockDownCoordF(nn.Module):
def __init__(
self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
self,
in_chn,
out_chn,
ip_height,
k_size=3,
pad_size=1,
stride=1,
):
super(ConvBlockDownCoordF, self).__init__()
self.coords = nn.Parameter(


@@ -1,5 +1,4 @@
import torch
import torch.fft
import torch.nn.functional as F
from torch import nn
@@ -207,7 +206,7 @@ class Net2DFastNoAttn(nn.Module):
num_filts // 4, 2, kernel_size=1, padding=0
)
self.conv_classes_op = nn.Conv2d(
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
)
if self.emb_dim > 0:


@@ -28,6 +28,7 @@ MAX_SCALE_SPEC = False
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"checkpoints",
"Net2DFast_UK_same.pth.tar",
)


@@ -68,6 +68,7 @@ def run_nms(
params["fft_win_length"],
params["fft_overlap"],
)
print("duration", duration)
top_k = int(duration * params["nms_top_k_per_sec"])
scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)


@@ -7,16 +7,16 @@ import copy
import json
import os
import torch
import numpy as np
import pandas as pd
import torch
from sklearn.ensemble import RandomForestClassifier
from batdetect2.detector import parameters
import batdetect2.train.evaluate as evl
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2.detector import parameters
def get_blank_annotation(ip_str):
@@ -330,7 +330,8 @@ def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
for dd in datasets:
print("\n" + dd["dataset_name"])
gt_dataset = tu.load_set_of_anns(
[dd], events_of_interest=events_of_interest, verbose=True
[dd],
events_of_interest=events_of_interest,
)
gt_dataset = [
parse_data(gg, class_names, classes_to_ignore, False)
@@ -553,7 +554,9 @@ if __name__ == "__main__":
test_dict["dataset_name"] = args["test_file"].replace(".json", "")
test_dict["is_test"] = True
test_dict["is_binary"] = True
test_dict["ann_path"] = os.path.join(args["ann_dir"], args["test_file"])
test_dict["ann_path"] = os.path.join(
args["ann_dir"], args["test_file"]
)
test_dict["wav_path"] = args["data_dir"]
test_sets = [test_dict]


@@ -0,0 +1,91 @@
import os
from typing import Tuple, Union
import torch
from batdetect2.models.encoders import (
Net2DFast,
Net2DFastNoAttn,
Net2DFastNoCoordConv,
)
from batdetect2.models.typing import DetectionModel
__all__ = [
"load_model",
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
]
DEFAULT_MODEL_PATH = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"models",
"checkpoints",
"Net2DFast_UK_same.pth.tar",
)
def load_model(
model_path: str = DEFAULT_MODEL_PATH,
load_weights: bool = True,
device: Union[torch.device, str, None] = None,
) -> Tuple[DetectionModel, dict]:
"""Load model from file.
Args:
model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
load_weights (bool, optional): Load weights. Defaults to True.
device (torch.device, str, optional): Device to load the model onto.
    Defaults to CUDA if available, otherwise CPU.
Returns:
model, params: Model and parameters.
Raises:
FileNotFoundError: Model file not found.
ValueError: Unknown model name.
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if not os.path.isfile(model_path):
raise FileNotFoundError("Model file not found.")
net_params = torch.load(model_path, map_location=device)
params = net_params["params"]
model: DetectionModel
if params["model_name"] == "Net2DFast":
model = Net2DFast(
params["num_filters"],
num_classes=len(params["class_names"]),
emb_dim=params["emb_dim"],
ip_height=params["ip_height"],
resize_factor=params["resize_factor"],
)
elif params["model_name"] == "Net2DFastNoAttn":
model = Net2DFastNoAttn(
params["num_filters"],
num_classes=len(params["class_names"]),
emb_dim=params["emb_dim"],
ip_height=params["ip_height"],
resize_factor=params["resize_factor"],
)
elif params["model_name"] == "Net2DFastNoCoordConv":
model = Net2DFastNoCoordConv(
params["num_filters"],
num_classes=len(params["class_names"]),
emb_dim=params["emb_dim"],
ip_height=params["ip_height"],
resize_factor=params["resize_factor"],
)
else:
raise ValueError("Unknown model.")
if load_weights:
model.load_state_dict(net_params["state_dict"])
model = model.to(device)
model.eval()
return model, params
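# Usage sketch: load the bundled default checkpoint on the CPU and inspect the
# parameters it was trained with. The "model_name" and "class_names" keys are
# the ones read by `load_model` itself.
if __name__ == "__main__":
    model, params = load_model(device="cpu")
    print(params["model_name"], params["class_names"])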

219 batdetect2/models/blocks.py Normal file

@@ -0,0 +1,219 @@
"""Module containing custom NN blocks.
All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""
from typing import Tuple
import torch
import torch.nn.functional as F
from torch import nn
__all__ = [
"SelfAttention",
"ConvBlockDownCoordF",
"ConvBlockDownStandard",
"ConvBlockUpF",
"ConvBlockUpStandard",
]
class SelfAttention(nn.Module):
"""Self-Attention module.
This module implements self-attention mechanism.
"""
def __init__(self, ip_dim: int, att_dim: int):
super().__init__()
# Note: does not encode position information (absolute or relative).
self.temperature = 1.0
self.att_dim = att_dim
self.key_fun = nn.Linear(ip_dim, att_dim)
self.val_fun = nn.Linear(ip_dim, att_dim)
self.que_fun = nn.Linear(ip_dim, att_dim)
self.pro_fun = nn.Linear(att_dim, ip_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.squeeze(2).permute(0, 2, 1)
key = torch.matmul(
x, self.key_fun.weight.T
) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
query = torch.matmul(
x, self.que_fun.weight.T
) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
value = torch.matmul(
x, self.val_fun.weight.T
) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)
kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
self.temperature * self.att_dim
)
att_weights = F.softmax(kk_qq, 1)
att = torch.bmm(value.permute(0, 2, 1), att_weights)
op = torch.matmul(
att.permute(0, 2, 1), self.pro_fun.weight.T
) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
op = op.permute(0, 2, 1).unsqueeze(2)
return op
class ConvBlockDownCoordF(nn.Module):
"""Convolutional Block with Downsampling and Coord Feature.
This block performs convolution followed by downsampling
and concatenates with coordinate information.
"""
def __init__(
self,
in_chn: int,
out_chn: int,
ip_height: int,
k_size: int = 3,
pad_size: int = 1,
stride: int = 1,
):
super().__init__()
self.coords = nn.Parameter(
torch.linspace(-1, 1, ip_height)[None, None, ..., None],
requires_grad=False,
)
self.conv = nn.Conv2d(
in_chn + 1,
out_chn,
kernel_size=k_size,
padding=pad_size,
stride=stride,
)
self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x: torch.Tensor) -> torch.Tensor:
freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
x = torch.cat((x, freq_info), 1)
x = F.max_pool2d(self.conv(x), 2, 2)
x = F.relu(self.conv_bn(x), inplace=True)
return x
class ConvBlockDownStandard(nn.Module):
"""Convolutional Block with Downsampling.
This block performs convolution followed by downsampling.
"""
def __init__(
self,
in_chn,
out_chn,
k_size=3,
pad_size=1,
stride=1,
):
super(ConvBlockDownStandard, self).__init__()
self.conv = nn.Conv2d(
in_chn,
out_chn,
kernel_size=k_size,
padding=pad_size,
stride=stride,
)
self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x):
x = F.max_pool2d(self.conv(x), 2, 2)
x = F.relu(self.conv_bn(x), inplace=True)
return x
class ConvBlockUpF(nn.Module):
"""Convolutional Block with Upsampling and Coord Feature.
This block performs convolution followed by upsampling
and concatenates with coordinate information.
"""
def __init__(
self,
in_chn: int,
out_chn: int,
ip_height: int,
k_size: int = 3,
pad_size: int = 1,
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
):
super().__init__()
self.up_scale = up_scale
self.up_mode = up_mode
self.coords = nn.Parameter(
torch.linspace(-1, 1, ip_height * up_scale[0])[
None, None, ..., None
],
requires_grad=False,
)
self.conv = nn.Conv2d(
in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
)
self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x: torch.Tensor) -> torch.Tensor:
op = F.interpolate(
x,
size=(
x.shape[-2] * self.up_scale[0],
x.shape[-1] * self.up_scale[1],
),
mode=self.up_mode,
align_corners=False,
)
freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
op = torch.cat((op, freq_info), 1)
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
return op
class ConvBlockUpStandard(nn.Module):
"""Convolutional Block with Upsampling.
This block performs convolution followed by upsampling.
"""
def __init__(
self,
in_chn: int,
out_chn: int,
k_size: int = 3,
pad_size: int = 1,
up_mode: str = "bilinear",
up_scale: Tuple[int, int] = (2, 2),
):
super(ConvBlockUpStandard, self).__init__()
self.up_scale = up_scale
self.up_mode = up_mode
self.conv = nn.Conv2d(
in_chn, out_chn, kernel_size=k_size, padding=pad_size
)
self.conv_bn = nn.BatchNorm2d(out_chn)
def forward(self, x: torch.Tensor) -> torch.Tensor:
op = F.interpolate(
x,
size=(
x.shape[-2] * self.up_scale[0],
x.shape[-1] * self.up_scale[1],
),
mode=self.up_mode,
align_corners=False,
)
op = self.conv(op)
op = F.relu(self.conv_bn(op), inplace=True)
return op
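# Quick shape check (illustrative): every down block halves both spatial
# dimensions through the 2x2 max-pool, and the coord variant concatenates a
# frequency-coordinate channel before its convolution.
if __name__ == "__main__":
    block = ConvBlockDownCoordF(in_chn=1, out_chn=8, ip_height=128)
    x = torch.rand(2, 1, 128, 256)
    print(block(x).shape)  # torch.Size([2, 8, 64, 128])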


@@ -0,0 +1,148 @@
import pytorch_lightning as L
import torch
import xarray as xr
from soundevent import data
from torch import nn, optim
from batdetect2.data.preprocessing import preprocess_audio_clip
from batdetect2.models.typing import EncoderModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
from batdetect2.models.post_process import (
PostprocessConfig,
postprocess_model_outputs,
)
from batdetect2.train.preprocess import PreprocessingConfig
class DetectorModel(L.LightningModule):
def __init__(
self,
encoder: EncoderModel,
num_classes: int,
learning_rate: float = 1e-3,
preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
postprocessing_config: PostprocessConfig = PostprocessConfig(),
):
super().__init__()
self.preprocessing_config = preprocessing_config
self.postprocessing_config = postprocessing_config
self.num_classes = num_classes
self.learning_rate = learning_rate
self.encoder = encoder
self.classifier = nn.Conv2d(
self.encoder.num_filts // 4,
self.num_classes + 1,
kernel_size=1,
padding=0,
)
self.bbox = nn.Conv2d(
self.encoder.num_filts // 4,
2,
kernel_size=1,
padding=0,
)
def forward(self, spec: torch.Tensor) -> ModelOutput: # type: ignore
features = self.encoder(spec)
classification_logits = self.classifier(features)
classification_probs = torch.softmax(classification_logits, dim=1)
detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
return ModelOutput(
detection_probs=detection_probs,
size_preds=self.bbox(features),
class_probs=classification_probs,
features=features,
)
def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
config = self.preprocessing_config
return preprocess_audio_clip(
clip,
target_sampling_rate=config.target_samplerate,
scale_audio=config.scale_audio,
fft_win_length=config.fft_win_length,
fft_overlap=config.fft_overlap,
max_freq=config.max_freq,
min_freq=config.min_freq,
spec_scale=config.spec_scale,
denoise_spec_avg=config.denoise_spec_avg,
max_scale_spec=config.max_scale_spec,
)
def process_clip(self, clip: data.Clip):
spectrogram = self.compute_spectrogram(clip)
spec_tensor = (
torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
)
outputs = self(spec_tensor)
config = self.postprocessing_config
return postprocess_model_outputs(
outputs,
[clip],
nms_kernel_size=config.nms_kernel_size,
detection_threshold=config.detection_threshold,
min_freq=config.min_freq,
max_freq=config.max_freq,
top_k_per_sec=config.top_k_per_sec,
)
def compute_loss(
self,
outputs: ModelOutput,
batch: TrainExample,
) -> torch.Tensor:
detection_loss = losses.focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
)
size_loss = losses.bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = losses.focal_loss(
outputs.class_probs,
batch.class_heatmap,
valid_mask=valid_mask,
)
return detection_loss + size_loss + classification_loss
def training_step( # type: ignore
self,
batch: TrainExample,
):
features = self.encoder(batch.spec)
classification_logits = self.classifier(features)
classification_probs = torch.softmax(classification_logits, dim=1)
detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)
loss = self.compute_loss(
ModelOutput(
detection_probs=detection_probs,
size_preds=self.bbox(features),
class_probs=classification_probs,
features=features,
),
batch,
)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
return [optimizer], [scheduler]
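# Training sketch (an assumption, for illustration only): hand the
# LightningModule to a standard Trainer. The DataLoader is expected to yield
# TrainExample batches; how it is built depends on the dataset code elsewhere
# in this commit.
def train(
    model: DetectorModel,
    train_loader,
    max_epochs: int = 100,
) -> None:
    trainer = L.Trainer(max_epochs=max_epochs)
    trainer.fit(model, train_dataloaders=train_loader)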


@@ -0,0 +1,319 @@
from typing import Tuple
import torch
import torch.fft
import torch.nn.functional as F
from torch import nn
from batdetect2.models.typing import EncoderModel
from batdetect2.models.blocks import (
ConvBlockDownCoordF,
ConvBlockDownStandard,
ConvBlockUpF,
ConvBlockUpStandard,
SelfAttention,
)
__all__ = [
"Net2DFast",
"Net2DFastNoAttn",
"Net2DFastNoCoordConv",
]
class Net2DFast(EncoderModel):
def __init__(
self,
num_filts: int,
input_height: int = 128,
):
super().__init__()
self.num_filts = num_filts
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
# encoder
self.conv_dn_0 = ConvBlockDownCoordF(
1,
self.num_filts // 4,
self.input_height,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
self.num_filts // 4,
self.num_filts // 2,
self.input_height // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
self.num_filts // 2,
self.num_filts,
self.input_height // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_filts,
self.num_filts * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
# bottleneck
self.conv_1d = nn.Conv2d(
self.num_filts * 2,
self.num_filts * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)
# decoder
self.conv_up_2 = ConvBlockUpF(
self.num_filts * 2,
self.num_filts // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpF(
self.num_filts // 2,
self.num_filts // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpF(
self.num_filts // 4,
self.num_filts // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_filts // 4,
self.num_filts // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
h, w = spec.shape[2:]
h_pad = (32 - h % 32) % 32
w_pad = (32 - w % 32) % 32
return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad
def forward(self, spec: torch.Tensor) -> torch.Tensor:
# encoder
spec, h_pad, w_pad = self.pad_adjust(spec)
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
# bottleneck
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x)
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
# decoder
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
# Restore original size
if h_pad > 0:
x = x[:, :, :-h_pad, :]
if w_pad > 0:
x = x[:, :, :, :-w_pad]
return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoAttn(EncoderModel):
def __init__(
self,
num_filts: int,
input_height: int = 128,
):
super().__init__()
self.num_filts = num_filts
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
self.conv_dn_0 = ConvBlockDownCoordF(
1,
self.num_filts // 4,
self.input_height,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownCoordF(
self.num_filts // 4,
self.num_filts // 2,
self.input_height // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownCoordF(
self.num_filts // 2,
self.num_filts,
self.input_height // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_filts,
self.num_filts * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
self.conv_1d = nn.Conv2d(
self.num_filts * 2,
self.num_filts * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
self.conv_up_2 = ConvBlockUpF(
self.num_filts * 2,
self.num_filts // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpF(
self.num_filts // 2,
self.num_filts // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpF(
self.num_filts // 4,
self.num_filts // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_filts // 4,
self.num_filts // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
return F.relu_(self.conv_op_bn(self.conv_op(x)))
class Net2DFastNoCoordConv(EncoderModel):
def __init__(
self,
num_filts: int,
input_height: int = 128,
):
super().__init__()
self.num_filts = num_filts
self.input_height = input_height
self.bottleneck_height = self.input_height // 32
self.conv_dn_0 = ConvBlockDownStandard(
1,
self.num_filts // 4,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_1 = ConvBlockDownStandard(
self.num_filts // 4,
self.num_filts // 2,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_2 = ConvBlockDownStandard(
self.num_filts // 2,
self.num_filts,
k_size=3,
pad_size=1,
stride=1,
)
self.conv_dn_3 = nn.Conv2d(
self.num_filts,
self.num_filts * 2,
3,
padding=1,
)
self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)
self.conv_1d = nn.Conv2d(
self.num_filts * 2,
self.num_filts * 2,
(self.input_height // 8, 1),
padding=0,
)
self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)
self.conv_up_2 = ConvBlockUpStandard(
self.num_filts * 2,
self.num_filts // 2,
self.input_height // 8,
)
self.conv_up_3 = ConvBlockUpStandard(
self.num_filts // 2,
self.num_filts // 4,
self.input_height // 4,
)
self.conv_up_4 = ConvBlockUpStandard(
self.num_filts // 4,
self.num_filts // 4,
self.input_height // 2,
)
self.conv_op = nn.Conv2d(
self.num_filts // 4,
self.num_filts // 4,
kernel_size=3,
padding=1,
)
self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)
def forward(self, spec: torch.Tensor) -> torch.Tensor:
x1 = self.conv_dn_0(spec)
x2 = self.conv_dn_1(x1)
x3 = self.conv_dn_2(x2)
x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))
x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
x = self.att(x)
x = x.repeat([1, 1, self.bottleneck_height * 4, 1])
x = self.conv_up_2(x + x3)
x = self.conv_up_3(x + x2)
x = self.conv_up_4(x + x1)
return F.relu_(self.conv_op_bn(self.conv_op(x)))
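# Quick shape check (illustrative): with the default input height the encoder
# returns a feature map with num_filts // 4 channels at the input resolution,
# which the detection and classification heads then consume.
if __name__ == "__main__":
    net = Net2DFast(num_filts=128)
    spec = torch.rand(1, 1, 128, 512)
    print(net(spec).shape)  # torch.Size([1, 32, 128, 512])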


@@ -0,0 +1,310 @@
"""Module for postprocessing model outputs."""
from typing import Callable, List, Tuple, Union
from pydantic import BaseModel, Field
import numpy as np
import torch
from soundevent import data
from torch import nn
from batdetect2.models.typing import ModelOutput
__all__ = [
"postprocess_model_outputs",
"PostprocessConfig",
]
NMS_KERNEL_SIZE = 9
DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200
class PostprocessConfig(BaseModel):
"""Configuration for postprocessing model outputs."""
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
min_freq: int = Field(default=10000, gt=0)
max_freq: int = Field(default=120000, gt=0)
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
TagFunction = Callable[[int], List[data.Tag]]
def postprocess_model_outputs(
outputs: ModelOutput,
clips: List[data.Clip],
nms_kernel_size: int = NMS_KERNEL_SIZE,
detection_threshold: float = DETECTION_THRESHOLD,
min_freq: int = 10000,
max_freq: int = 120000,
top_k_per_sec: int = TOP_K_PER_SEC,
) -> List[data.ClipPrediction]:
"""Postprocesses model outputs to generate clip predictions.
This function takes the output from the model, applies non-maximum suppression,
selects the top-k scores, computes sound events from the outputs, and returns
clip predictions based on these processed outputs.
Parameters
----------
outputs
Output from the model containing detection probabilities, size
predictions, class logits, and features. All tensors are expected
to have a batch dimension.
clips
List of clips for which predictions are made. The number of clips
must match the batch dimension of the model outputs.
nms_kernel_size
Size of the non-maximum suppression kernel. Default is 9.
detection_threshold
Detection threshold. Default is 0.01.
min_freq
Minimum frequency. Default is 10000.
max_freq
Maximum frequency. Default is 120000.
top_k_per_sec
Top k per second. Default is 200.
Returns
-------
predictions: List[data.ClipPrediction]
List of clip predictions containing predicted sound events.
Raises
------
ValueError
If the number of predictions does not match the number of clips.
"""
num_predictions = len(outputs.detection_probs)
if num_predictions == 0:
return []
if num_predictions != len(clips):
raise ValueError(
"Number of predictions must match the number of clips."
)
detection_probs = non_max_suppression(
outputs.detection_probs,
kernel_size=nms_kernel_size,
)
duration = clips[0].end_time - clips[0].start_time
scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
detection_probs,
int(top_k_per_sec * duration / 2),
)
predictions: List[data.ClipPrediction] = []
for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip(
scores_batch,
y_pos_batch,
x_pos_batch,
outputs.size_preds,
outputs.class_probs,
outputs.features,
clips,
):
sound_events = compute_sound_events_from_outputs(
clip,
scores,
y_pos,
x_pos,
size_preds,
class_probs,
features,
min_freq=min_freq,
max_freq=max_freq,
detection_threshold=detection_threshold,
)
predictions.append(
data.ClipPrediction(
clip=clip,
sound_events=sound_events,
)
)
return predictions
def compute_sound_events_from_outputs(
clip: data.Clip,
scores: torch.Tensor,
y_pos: torch.Tensor,
x_pos: torch.Tensor,
size_preds: torch.Tensor,
class_probs: torch.Tensor,
features: torch.Tensor,
tag_fn: TagFunction = lambda _: [],
min_freq: int = 10000,
max_freq: int = 120000,
detection_threshold: float = DETECTION_THRESHOLD,
) -> List[data.SoundEventPrediction]:
_, freq_bins, time_bins = size_preds.shape
sorted_indices = torch.argsort(x_pos)
valid_indices = sorted_indices[
scores[sorted_indices] > detection_threshold
]
scores = scores[valid_indices]
x_pos = x_pos[valid_indices]
y_pos = y_pos[valid_indices]
predictions: List[data.SoundEventPrediction] = []
for score, x, y in zip(scores, x_pos, y_pos):
        width, height = size_preds[:, y, x]
class_prob = class_probs[:, y, x]
feature = features[:, y, x]
start_time = np.interp(
x.item(),
[0, time_bins],
[clip.start_time, clip.end_time],
)
end_time = np.interp(
x.item() + width.item(),
[0, time_bins],
[clip.start_time, clip.end_time],
)
low_freq = np.interp(
y.item(),
[0, freq_bins],
[max_freq, min_freq],
)
high_freq = np.interp(
y.item() - height.item(),
[0, freq_bins],
[max_freq, min_freq],
)
predicted_tags: List[data.PredictedTag] = []
for label_id, class_score in enumerate(class_prob):
corresponding_tags = tag_fn(label_id)
predicted_tags.extend(
[
data.PredictedTag(
tag=tag,
score=class_score.item(),
)
for tag in corresponding_tags
]
)
start_time, end_time = sorted([float(start_time), float(end_time)])
low_freq, high_freq = sorted([float(low_freq), float(high_freq)])
sound_event = data.SoundEvent(
recording=clip.recording,
geometry=data.BoundingBox(
coordinates=[
start_time,
low_freq,
end_time,
high_freq,
]
),
features=[
data.Feature(
name=f"batdetect2_{i}",
value=value.item(),
)
for i, value in enumerate(feature)
],
)
predictions.append(
data.SoundEventPrediction(
sound_event=sound_event,
score=score.item(),
tags=predicted_tags,
)
)
return predictions
def non_max_suppression(
tensor: torch.Tensor,
kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
"""Run non-maximum suppression on a tensor.
This function removes values from the input tensor that are not local
maxima in the neighborhood of the given kernel size.
All non-maximum values are set to zero.
Parameters
----------
tensor : torch.Tensor
Input tensor.
kernel_size : Union[int, Tuple[int, int]], optional
Size of the neighborhood to consider for non-maximum suppression.
If an integer is given, the neighborhood will be a square of the
given size. If a tuple is given, the neighborhood will be a
rectangle with the given height and width.
Returns
-------
torch.Tensor
Tensor with non-maximum suppressed values.
"""
if isinstance(kernel_size, int):
kernel_size_h = kernel_size
kernel_size_w = kernel_size
else:
kernel_size_h, kernel_size_w = kernel_size
pad_h = (kernel_size_h - 1) // 2
pad_w = (kernel_size_w - 1) // 2
hmax = nn.functional.max_pool2d(
tensor,
(kernel_size_h, kernel_size_w),
stride=1,
padding=(pad_h, pad_w),
)
keep = (hmax == tensor).float()
return tensor * keep
def get_topk_scores(
scores: torch.Tensor,
K: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Get the top-k scores and their indices.
Parameters
----------
scores : torch.Tensor
Tensor with scores. Expects input of size: `batch x 1 x height x width`.
K : int
Number of top scores to return.
Returns
-------
scores : torch.Tensor
Top-k scores.
ys : torch.Tensor
Y coordinates of the top-k scores.
xs : torch.Tensor
X coordinates of the top-k scores.
"""
batch, _, height, width = scores.size()
topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
topk_inds = topk_inds % (height * width)
topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
topk_xs = (topk_inds % width).long()
return topk_scores, topk_ys, topk_xs
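To make the post-processing flow above concrete, here is a minimal, self-contained sketch that runs non-maximum suppression and top-k selection on a random score map. The batch size, map size, `K`, and the module path in the import are illustrative assumptions, not values taken from this commit.

import torch

# Module path assumed; adjust to wherever this post-processing module lives.
from batdetect2.models.post_process import get_topk_scores, non_max_suppression

# Fake detection heatmap: batch x 1 x height x width.
scores = torch.rand(1, 1, 64, 128)

# Zero out everything that is not a local maximum in a 9x9 neighbourhood.
suppressed = non_max_suppression(scores, kernel_size=9)

# Keep the 20 strongest surviving peaks and their grid coordinates.
topk_scores, ys, xs = get_topk_scores(suppressed, K=20)
print(topk_scores.shape, ys.shape, xs.shape)  # each is (1, 20)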

View File

@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from typing import NamedTuple
import torch
import torch.nn as nn
class ModelOutput(NamedTuple):
"""Output of the detection model.
Each of the tensors has a shape of
`(batch_size, num_channels, spec_height, spec_width)`.
Where `spec_height` and `spec_width` are the height and width of the
input spectrograms.
They contain localised information of:
1. The probability of a bounding box detection at the given location.
2. The predicted size of the bounding box at the given location.
3. The probabilities of each class at the given location before softmax.
4. Features used to make the predictions at the given location.
"""
detection_probs: torch.Tensor
"""Tensor with predict detection probabilities."""
size_preds: torch.Tensor
"""Tensor with predicted bounding box sizes."""
class_probs: torch.Tensor
"""Tensor with predicted class probabilities."""
features: torch.Tensor
"""Tensor with intermediate features."""
class EncoderModel(ABC, nn.Module):
input_height: int
"""Height of the input spectrogram."""
num_filts: int
"""Dimension of the feature tensor."""
@abstractmethod
def forward(self, spec: torch.Tensor) -> torch.Tensor:
"""Forward pass of the encoder model."""
class DetectionModel(ABC, nn.Module):
    @abstractmethod
    def forward(self, spec: torch.Tensor) -> ModelOutput:
        """Forward pass of the detection model, returning a `ModelOutput`."""

View File

@ -102,6 +102,7 @@ def spectrogram(
return ax
def spectrogram_with_detections(
spec: Union[torch.Tensor, np.ndarray],
dets: List[Annotation],

View File

View File

@ -0,0 +1,22 @@
"""General plotting utilities."""
from typing import Optional, Tuple
import matplotlib.pyplot as plt
from matplotlib import axes
__all__ = [
"create_ax",
]
def create_ax(
ax: Optional[axes.Axes] = None,
figsize: Tuple[int, int] = (10, 10),
**kwargs,
) -> axes.Axes:
"""Create a new axis if none is provided"""
if ax is None:
_, ax = plt.subplots(figsize=figsize, **kwargs) # type: ignore
return ax # type: ignore

View File

@ -0,0 +1,27 @@
"""Plot heatmaps"""
from typing import Optional, Tuple
import matplotlib.pyplot as plt
import xarray as xr
from matplotlib import axes
from batdetect2.plotting.common import create_ax
def plot_heatmap(
heatmap: xr.DataArray,
ax: Optional[axes.Axes] = None,
figsize: Tuple[int, int] = (10, 10),
) -> axes.Axes:
ax = create_ax(ax, figsize=figsize)
ax.pcolormesh(
heatmap.time,
heatmap.frequency,
heatmap,
vmax=1,
vmin=0,
)
return ax
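A small usage sketch for `plot_heatmap`, assuming a heatmap stored as an `xarray.DataArray` with `time` and `frequency` coordinates (as the training pipeline produces); the module path in the import is an assumption.

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

from batdetect2.plotting.heatmaps import plot_heatmap  # module path assumed

# Toy heatmap: 128 frequency bins x 256 time steps with named coordinates.
heatmap = xr.DataArray(
    np.random.rand(128, 256),
    dims=("frequency", "time"),
    coords={
        "frequency": np.linspace(10_000, 120_000, 128),
        "time": np.linspace(0, 0.5, 256),
    },
)

ax = plot_heatmap(heatmap)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Frequency (Hz)")
plt.show()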

View File

@ -1,4 +1,5 @@
"""Functions and dataloaders for training and testing the model."""
import copy
from typing import List, Optional, Tuple
@ -199,8 +200,7 @@ def draw_gaussian(
x0 = y0 = size // 2
# g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
g = np.exp(
-((x - x0) ** 2) / (2 * sigmax**2)
- ((y - y0) ** 2) / (2 * sigmay**2)
-((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
)
g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
@ -399,6 +399,8 @@ def echo_aug(
sample_offset = (
int(echo_max_delay * np.random.random() * sampling_rate) + 1
)
# NOTE: This seems to be wrong, as the echo should be added to the
# end of the audio, not the beginning.
audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
return audio
@ -820,9 +822,10 @@ class AudioLoader(torch.utils.data.Dataset):
# )
# create spectrogram
spec = au.generate_spectrogram(
spec, _ = au.generate_spectrogram(
audio,
sampling_rate,
params=dict(
fft_win_length=self.params["fft_win_length"],
fft_overlap=self.params["fft_overlap"],
max_freq=self.params["max_freq"],
@ -830,6 +833,7 @@ class AudioLoader(torch.utils.data.Dataset):
spec_scale=self.params["spec_scale"],
denoise_spec_avg=self.params["denoise_spec_avg"],
max_scale_spec=self.params["max_scale_spec"],
),
)
rsf = self.params["resize_factor"]
spec_op_shape = (

View File

@ -0,0 +1,86 @@
import os
from pathlib import Path
from typing import Dict, NamedTuple, Sequence, Union
from soundevent import data
from torch.utils.data import Dataset
import torch
import xarray as xr
from batdetect2.train.preprocess import PreprocessingConfig
__all__ = [
"TrainExample",
"LabeledDataset",
]
PathLike = Union[Path, str, os.PathLike]
class TrainExample(NamedTuple):
spec: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
idx: torch.Tensor
def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
return list(Path(directory).glob(f"*{extension}"))
class LabeledDataset(Dataset):
def __init__(self, filenames: Sequence[PathLike]):
self.filenames = filenames
def __len__(self):
return len(self.filenames)
def __getitem__(self, idx) -> TrainExample:
        # Local name chosen so it does not shadow the `soundevent.data` import.
        example = self.load(self.filenames[idx])
        return TrainExample(
            spec=example["spectrogram"],
            detection_heatmap=example["detection"],
            class_heatmap=example["class"],
            size_heatmap=example["size"],
            idx=torch.tensor(idx),
        )
@classmethod
def from_directory(cls, directory: PathLike, extension: str = ".nc"):
return cls(get_files(directory, extension))
def load(self, filename: PathLike) -> Dict[str, torch.Tensor]:
dataset = xr.open_dataset(filename)
spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0)
return {
"spectrogram": spectrogram,
"detection": torch.tensor(dataset["detection"].values),
"class": torch.tensor(dataset["class"].values),
"size": torch.tensor(dataset["size"].values),
}
def get_spectrogram(self, idx):
return xr.open_dataset(self.filenames[idx])["spectrogram"]
def get_detection_mask(self, idx):
return xr.open_dataset(self.filenames[idx])["detection"]
def get_class_mask(self, idx):
return xr.open_dataset(self.filenames[idx])["class"]
def get_size_mask(self, idx):
return xr.open_dataset(self.filenames[idx])["size"]
def get_clip_annotation(self, idx):
filename = self.filenames[idx]
dataset = xr.open_dataset(filename)
clip_annotation = dataset.attrs["clip_annotation"]
return data.ClipAnnotation.model_validate_json(clip_annotation)
def get_preprocessing_configuration(self, idx):
config = xr.open_dataset(self.filenames[idx]).attrs["configuration"]
return PreprocessingConfig.model_validate_json(config)
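A usage sketch for the dataset above, assuming a directory of preprocessed `.nc` files written by the preprocessing step; the directory name is a placeholder and the batching assumes all examples share the same spectrogram size (e.g. fixed-duration clips).

from torch.utils.data import DataLoader

from batdetect2.train.dataset import LabeledDataset

dataset = LabeledDataset.from_directory("train_data")  # placeholder path
loader = DataLoader(dataset, batch_size=4, shuffle=True)

for batch in loader:
    # Default collation stacks the TrainExample fields into batched tensors.
    print(batch.spec.shape, batch.detection_heatmap.shape)
    break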

56
batdetect2/train/light.py Normal file
View File

@ -0,0 +1,56 @@
import pytorch_lightning as L
from torch import Tensor, optim
from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
__all__ = [
"LitDetectorModel",
]
class LitDetectorModel(L.LightningModule):
model: DetectionModel
def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
super().__init__()
self.model = model
self.learning_rate = learning_rate
def compute_loss(
self,
outputs: ModelOutput,
batch: TrainExample,
) -> Tensor:
detection_loss = losses.focal_loss(
outputs.detection_probs,
batch.detection_heatmap,
)
size_loss = losses.bbox_size_loss(
outputs.size_preds,
batch.size_heatmap,
)
valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
classification_loss = losses.focal_loss(
outputs.class_probs,
batch.class_heatmap,
valid_mask=valid_mask,
)
return detection_loss + size_loss + classification_loss
def training_step(self, batch: TrainExample, batch_idx: int): # type: ignore
outputs: ModelOutput = self.model(batch.spec)
loss = self.compute_loss(outputs, batch)
self.log("train_loss", loss)
return loss
def configure_optimizers(self):
optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
return [optimizer], [scheduler]
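To see how this module is wired up end to end, here is a hedged sketch that wraps a throwaway `DetectionModel` and trains it with a Lightning `Trainer`. The tiny network, the single "bat" class, and the data directory are illustrative assumptions; only `LitDetectorModel`, `LabeledDataset`, and the typing module come from this commit.

import pytorch_lightning as L
import torch
from torch import nn
from torch.utils.data import DataLoader

from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train.dataset import LabeledDataset
from batdetect2.train.light import LitDetectorModel


class TinyDetector(DetectionModel):
    """Throwaway detector used only to illustrate the wiring."""

    def __init__(self, num_classes: int = 1):  # one class matches the default PreprocessingConfig
        super().__init__()
        self.backbone = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.detector = nn.Conv2d(32, 1, kernel_size=1)
        self.bbox = nn.Conv2d(32, 2, kernel_size=1)
        self.classifier = nn.Conv2d(32, num_classes, kernel_size=1)

    def forward(self, spec: torch.Tensor) -> ModelOutput:
        features = torch.relu(self.backbone(spec))
        return ModelOutput(
            detection_probs=torch.sigmoid(self.detector(features)),
            size_preds=self.bbox(features),
            class_probs=torch.sigmoid(self.classifier(features)),
            features=features,
        )


module = LitDetectorModel(TinyDetector(), learning_rate=1e-3)
train_loader = DataLoader(
    LabeledDataset.from_directory("train_data"),  # placeholder directory of .nc examples
    batch_size=4,
    shuffle=True,
)
trainer = L.Trainer(max_epochs=1)
trainer.fit(module, train_loader)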

View File

@ -22,15 +22,15 @@ def focal_loss(
gt: torch.Tensor,
weights: Optional[torch.Tensor] = None,
valid_mask: Optional[torch.Tensor] = None,
eps: float = 1e-5,
beta: float = 4,
alpha: float = 2,
) -> torch.Tensor:
"""
Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
pred (batch x c x h x w)
gt (batch x c x h x w)
"""
eps = 1e-5
beta = 4
alpha = 2
pos_inds = gt.eq(1).float()
neg_inds = gt.lt(1).float()
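For reference, a standalone sketch of the CornerNet-style focal loss the docstring points to; it follows the published formulation with the same default `alpha`, `beta`, and `eps`, and may differ from the repository's own implementation in details such as masking and per-class weighting.

import torch


def cornernet_focal_loss(
    pred: torch.Tensor,
    gt: torch.Tensor,
    alpha: float = 2,
    beta: float = 4,
    eps: float = 1e-5,
) -> torch.Tensor:
    """Penalty-reduced pixel-wise focal loss from CornerNet."""
    pos_inds = gt.eq(1).float()  # exact keypoint locations
    neg_inds = gt.lt(1).float()  # everything else
    neg_weights = torch.pow(1 - gt, beta)  # down-weight pixels close to a keypoint

    pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = torch.log(1 - pred + eps) * torch.pow(pred, alpha) * neg_weights * neg_inds

    num_pos = pos_inds.sum()
    if num_pos == 0:
        return -neg_loss.sum()
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos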

View File

@ -0,0 +1,209 @@
"""Module for preprocessing data for training."""
import os
import warnings
from functools import partial
from pathlib import Path
from typing import Callable, Optional, Sequence, Union
from tqdm.auto import tqdm
from multiprocessing import Pool
import xarray as xr
from pydantic import BaseModel, Field
from soundevent import data
from batdetect2.data.labels import TARGET_SIGMA, LabelFn, generate_heatmaps
from batdetect2.data.preprocessing import (
DENOISE_SPEC_AVG,
FFT_OVERLAP,
FFT_WIN_LENGTH_S,
MAX_FREQ_HZ,
MAX_SCALE_SPEC,
MIN_FREQ_HZ,
SCALE_RAW_AUDIO,
SPEC_SCALE,
TARGET_SAMPLERATE_HZ,
preprocess_audio_clip,
)
PathLike = Union[Path, str, os.PathLike]
FilenameFn = Callable[[data.ClipAnnotation], str]
__all__ = [
"preprocess_annotations",
]
class PreprocessingConfig(BaseModel):
"""Configuration for preprocessing data."""
target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
spec_scale: str = Field(default=SPEC_SCALE)
denoise_spec_avg: bool = DENOISE_SPEC_AVG
max_scale_spec: bool = MAX_SCALE_SPEC
target_sigma: float = Field(default=TARGET_SIGMA, gt=0)
class_labels: Sequence[str] = ["bat"]
def generate_train_example(
clip_annotation: data.ClipAnnotation,
label_fn: LabelFn = lambda _: None,
config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
"""Generate a training example."""
if config is None:
config = PreprocessingConfig()
spectrogram = preprocess_audio_clip(
clip_annotation.clip,
target_sampling_rate=config.target_samplerate,
scale_audio=config.scale_audio,
fft_win_length=config.fft_win_length,
fft_overlap=config.fft_overlap,
max_freq=config.max_freq,
min_freq=config.min_freq,
spec_scale=config.spec_scale,
denoise_spec_avg=config.denoise_spec_avg,
max_scale_spec=config.max_scale_spec,
)
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
clip_annotation,
spectrogram,
target_sigma=config.target_sigma,
num_classes=len(config.class_labels),
class_labels=list(config.class_labels),
label_fn=label_fn,
)
dataset = xr.Dataset(
{
"spectrogram": spectrogram,
"detection": detection_heatmap,
"class": class_heatmap,
"size": size_heatmap,
}
)
return dataset.assign_attrs(
title=f"Training example for {clip_annotation.uuid}",
configuration=config.model_dump_json(),
clip_annotation=clip_annotation.model_dump_json(),
)
def save_to_file(
dataset: xr.Dataset,
path: PathLike,
) -> None:
dataset.to_netcdf(
path,
encoding={
"spectrogram": {"zlib": True},
"size": {"zlib": True},
"class": {"zlib": True},
"detection": {"zlib": True},
},
)
def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
"""Load configuration from file."""
path = Path(path)
if not path.is_file():
warnings.warn(f"Config file not found: {path}. Using default config.")
return PreprocessingConfig(**kwargs)
try:
return PreprocessingConfig.model_validate_json(path.read_text())
except ValueError as e:
warnings.warn(
f"Failed to load config file: {e}. Using default config."
)
return PreprocessingConfig(**kwargs)
def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
return f"{clip_annotation.uuid}.nc"
def preprocess_single_annotation(
clip_annotation: data.ClipAnnotation,
output_dir: PathLike,
config: PreprocessingConfig,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
label_fn: LabelFn = lambda _: None,
) -> None:
output_dir = Path(output_dir)
filename = filename_fn(clip_annotation)
path = output_dir / filename
if path.is_file() and not replace:
return
sample = generate_train_example(
clip_annotation,
label_fn=label_fn,
config=config,
)
save_to_file(sample, path)
def preprocess_annotations(
clip_annotations: Sequence[data.ClipAnnotation],
output_dir: PathLike,
filename_fn: FilenameFn = _get_filename,
replace: bool = False,
config_file: Optional[PathLike] = None,
label_fn: LabelFn = lambda _: None,
max_workers: Optional[int] = None,
**kwargs,
) -> None:
"""Preprocess annotations and save to disk."""
output_dir = Path(output_dir)
if not output_dir.is_dir():
output_dir.mkdir(parents=True)
if config_file is not None:
config = load_config(config_file, **kwargs)
else:
config = PreprocessingConfig(**kwargs)
with Pool(max_workers) as pool:
list(
tqdm(
pool.imap_unordered(
partial(
preprocess_single_annotation,
output_dir=output_dir,
config=config,
filename_fn=filename_fn,
replace=replace,
label_fn=label_fn,
),
clip_annotations,
),
total=len(clip_annotations),
)
)
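A sketch of how this entry point is meant to be called; the output directory is a placeholder and the empty annotation list stands in for `ClipAnnotation` objects loaded elsewhere, so running it as-is only creates the directory.

from batdetect2.train.preprocess import preprocess_annotations

clip_annotations = []  # placeholder: load or build soundevent ClipAnnotation objects here

preprocess_annotations(
    clip_annotations,
    output_dir="train_data",    # one .nc file per annotation is written here
    target_samplerate=256_000,  # extra kwargs are forwarded into PreprocessingConfig
    max_workers=4,
)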

82
batdetect2/train/train.py Normal file
View File

@ -0,0 +1,82 @@
from typing import Callable, NamedTuple, Optional
import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from batdetect2.data.datasets import ClipAnnotationDataset
from batdetect2.models.typing import DetectionModel
from batdetect2.train import losses
class TrainInputs(NamedTuple):
spec: torch.Tensor
detection_heatmap: torch.Tensor
class_heatmap: torch.Tensor
size_heatmap: torch.Tensor
def train_loop(
model: DetectionModel,
train_dataset: ClipAnnotationDataset[TrainInputs],
validation_dataset: ClipAnnotationDataset[TrainInputs],
device: Optional[torch.device] = None,
num_epochs: int = 100,
learning_rate: float = 1e-4,
):
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
validation_loader = DataLoader(validation_dataset, batch_size=32)
model.to(device)
optimizer = Adam(model.parameters(), lr=learning_rate)
scheduler = CosineAnnealingLR(
optimizer,
num_epochs * len(train_loader),
)
for epoch in range(num_epochs):
train_loss = train_single_epoch(
model,
train_loader,
optimizer,
device,
scheduler,
)
def train_single_epoch(
model: DetectionModel,
train_loader: DataLoader,
optimizer: Adam,
device: torch.device,
scheduler: CosineAnnealingLR,
):
model.train()
    train_loss = 0.0  # plain running sum; the original `tu.AverageMeter` was never defined or imported
for batch in train_loader:
optimizer.zero_grad()
spec = batch.spec.to(device)
detection_heatmap = batch.detection_heatmap.to(device)
class_heatmap = batch.class_heatmap.to(device)
size_heatmap = batch.size_heatmap.to(device)
outputs = model(spec)
        # Mirror the three-part loss used in LitDetectorModel.compute_loss.
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            detection_heatmap,
        )
        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            size_heatmap,
        )
        valid_mask = class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            class_heatmap,
            valid_mask=valid_mask,
        )
        loss = detection_loss + size_loss + classification_loss
        train_loss += loss.item() * spec.shape[0]
        loss.backward()
        optimizer.step()
        scheduler.step()
    return train_loss / len(train_loader.dataset)

View File

@ -1,3 +1,4 @@
import sys
import json
from collections import Counter
from pathlib import Path
@ -7,6 +8,11 @@ import numpy as np
from batdetect2 import types
if sys.version_info >= (3, 9):
StringCounter = Counter[str]
else:
from typing import Counter as StringCounter
def write_notes_file(file_name: str, text: str):
with open(file_name, "a") as da:
@ -148,7 +154,7 @@ def format_annotation(
def get_class_names(
data: List[types.FileAnnotation],
classes_to_ignore: Optional[List[str]] = None,
) -> Tuple[Counter[str], List[float]]:
) -> Tuple[StringCounter, List[float]]:
"""Extracts class names and their inverse frequencies.
Parameters
@ -182,7 +188,7 @@ def get_class_names(
return counts, [mean_counts / counts[cc] for cc in class_names_list]
def report_class_counts(class_names: Counter[str]):
def report_class_counts(class_names: StringCounter):
print("Class count:")
str_len = np.max([len(cc) for cc in class_names]) + 5
for index, (class_name, count) in enumerate(class_names.most_common()):
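The inverse-frequency weights returned by `get_class_names` are simply `mean_count / class_count`; a tiny worked example with made-up class names and counts:

from collections import Counter

counts = Counter({"class_a": 120, "class_b": 40, "class_c": 20})
mean_counts = sum(counts.values()) / len(counts)  # 60.0
weights = [mean_counts / counts[name] for name in counts]
print(weights)  # [0.5, 1.5, 3.0] - rarer classes get larger weights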

View File

@ -1,11 +1,13 @@
import warnings
from typing import Optional, Tuple, Union, overload
from typing import Optional, Tuple
import librosa
import librosa.core.spectrum
import numpy as np
import torch
from . import wavfile
__all__ = [
"load_audio",
"generate_spectrogram",
@ -13,171 +15,113 @@ __all__ = [
]
@overload
def time_to_x_coords(
time_in_file: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> np.ndarray:
...
@overload
def time_to_x_coords(
time_in_file: float,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> float:
...
def time_to_x_coords(
time_in_file: Union[float, np.ndarray],
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
) -> Union[float, np.ndarray]:
nfft = np.floor(fft_win_length * sampling_rate)
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
noverlap = np.floor(fft_overlap * nfft)
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
# NOTE this is also defined in post_process
def x_coords_to_time(
x_pos: float,
sampling_rate: int,
fft_win_length: float,
fft_overlap: float,
) -> float:
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
nfft = np.floor(fft_win_length * sampling_rate)
noverlap = np.floor(fft_overlap * nfft)
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for
# center of temporal window
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
def generate_spectrogram(
audio: np.ndarray,
sampling_rate: float,
fft_win_length: float,
fft_overlap: float,
max_freq: float,
min_freq: float,
spec_scale: str,
denoise_spec_avg: bool = False,
max_scale_spec: bool = False,
) -> np.ndarray:
audio,
sampling_rate,
params,
return_spec_for_viz=False,
check_spec_size=True,
):
# generate spectrogram
spec = gen_mag_spectrogram(
audio,
sampling_rate,
window_len=fft_win_length,
overlap_perc=fft_overlap,
)
spec = crop_spectrogram(
spec,
fft_win_length=fft_win_length,
max_freq=max_freq,
min_freq=min_freq,
)
spec = scale_spectrogram(
spec,
sampling_rate,
spec_scale=spec_scale,
fft_win_length=fft_win_length,
params["fft_win_length"],
params["fft_overlap"],
)
if denoise_spec_avg:
spec = denoise_spectrogram(spec)
if max_scale_spec:
spec = max_scale_spectrogram(spec)
return spec
def crop_spectrogram(
spec: np.ndarray,
fft_win_length: float,
max_freq: float,
min_freq: float,
) -> np.ndarray:
# crop to min/max freq
max_freq = round(max_freq * fft_win_length)
min_freq = round(min_freq * fft_win_length)
max_freq = round(params["max_freq"] * params["fft_win_length"])
min_freq = round(params["min_freq"] * params["fft_win_length"])
if spec.shape[0] < max_freq:
freq_pad = max_freq - spec.shape[0]
spec = np.vstack(
(np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
)
return spec[-max_freq : spec.shape[0] - min_freq, :]
spec = spec[-max_freq : spec.shape[0] - min_freq, :]
def denoise_spectrogram(spec: np.ndarray) -> np.ndarray:
spec = spec - np.mean(spec, 1)[:, np.newaxis]
return spec.clip(min=0)
def max_scale_spectrogram(spec: np.ndarray) -> np.ndarray:
return spec / (spec.max() + 10e-6)
def log_scale(
spec: np.ndarray,
sampling_rate: float,
fft_win_length: float,
) -> np.ndarray:
if params["spec_scale"] == "log":
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(np.hanning(int(fft_win_length * sampling_rate))) ** 2
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
).sum()
)
)
return np.log1p(log_scaling * spec)
# log_scaling = (1.0 / sampling_rate)*0.1
# log_scaling = (1.0 / sampling_rate)*10e4
spec = np.log1p(log_scaling * spec)
elif params["spec_scale"] == "pcen":
        spec = pcen(spec, sampling_rate)
elif params["spec_scale"] == "none":
pass
def scale_spectrogram(
spec: np.ndarray,
sampling_rate: float,
spec_scale: str,
fft_win_length: float,
) -> np.ndarray:
if spec_scale == "log":
return log_scale(spec, sampling_rate, fft_win_length)
if params["denoise_spec_avg"]:
spec = spec - np.mean(spec, 1)[:, np.newaxis]
spec.clip(min=0, out=spec)
if spec_scale == "pcen":
return pcen(spec, sampling_rate)
if params["max_scale_spec"]:
spec = spec / (spec.max() + 10e-6)
return spec
# needs to be divisible by specific factor - if not it should have been padded
# if check_spec_size:
# assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
# assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
def prepare_spec_for_viz(
spec: np.ndarray,
sampling_rate: int,
fft_win_length: float,
) -> np.ndarray:
# for visualization purposes - use log scaled spectrogram
return log_scale(
spec,
sampling_rate,
fft_win_length=fft_win_length,
).astype(np.float32)
if return_spec_for_viz:
log_scaling = (
2.0
* (1.0 / sampling_rate)
* (
1.0
/ (
np.abs(
np.hanning(
int(params["fft_win_length"] * sampling_rate)
)
)
** 2
).sum()
)
)
spec_for_viz = np.log1p(log_scaling * spec).astype(np.float32)
else:
spec_for_viz = None
return spec, spec_for_viz
def load_audio(
audio_file: str,
time_exp_fact: float,
target_sampling_rate: int,
target_samp_rate: int,
scale: bool = False,
max_duration: Optional[float] = None,
) -> Tuple[float, np.ndarray]:
) -> Tuple[int, np.ndarray]:
"""Load an audio file and resample it to the target sampling rate.
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
@ -208,82 +152,63 @@ def load_audio(
"""
with warnings.catch_warnings():
audio, sampling_rate = librosa.load(
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
# sampling_rate, audio_raw = wavfile.read(audio_file)
audio_raw, sampling_rate = librosa.load(
audio_file,
sr=None,
dtype=np.float32,
)
if len(audio.shape) > 1:
if len(audio_raw.shape) > 1:
raise ValueError("Currently does not handle stereo files")
sampling_rate = sampling_rate * time_exp_fact
# resample - need to do this after correcting for time expansion
audio = resample_audio(audio, sampling_rate, target_sampling_rate)
if max_duration is not None:
audio = clip_audio(audio, target_sampling_rate, max_duration)
# scale to [-1, 1]
if scale:
audio = scale_audio(audio)
return target_sampling_rate, audio
def resample_audio(
audio: np.ndarray,
sr_orig: float,
sr_target: float,
) -> np.ndarray:
if sr_orig != sr_target:
return librosa.resample(
audio,
orig_sr=sr_orig,
target_sr=sr_target,
sampling_rate_old = sampling_rate
sampling_rate = target_samp_rate
if sampling_rate_old != sampling_rate:
audio_raw = librosa.resample(
audio_raw,
orig_sr=sampling_rate_old,
target_sr=sampling_rate,
res_type="polyphase",
)
return audio
def clip_audio(
audio: np.ndarray,
sampling_rate: float,
max_duration: float,
) -> np.ndarray:
# clipping maximum duration
if max_duration is not None:
max_duration = int(
np.minimum(
int(sampling_rate * max_duration),
audio.shape[0],
audio_raw.shape[0],
)
)
return audio[:max_duration]
audio_raw = audio_raw[:max_duration]
# scale to [-1, 1]
if scale:
audio_raw = audio_raw - audio_raw.mean()
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
def scale_audio(
audio: np.ndarray,
eps: float = 10e-6,
) -> np.ndarray:
return (audio - audio.mean()) / (np.abs(audio).max() + eps)
return sampling_rate, audio_raw
def pad_audio(
audio_raw: np.ndarray,
sampling_rate: float,
window_len: float,
overlap_perc: float,
resize_factor: float,
divide_factor: float,
fixed_width: Optional[int] = None,
) -> np.ndarray:
audio_raw,
fs,
ms,
overlap_perc,
resize_factor,
divide_factor,
fixed_width=None,
):
    # Adds zeros to the end of the raw data so that the generated spectrogram
# will be evenly divisible by `divide_factor`
# Also deals with very short audio clips and fixed_width during training
# This code could be clearer, clean up
nfft = int(window_len * sampling_rate)
nfft = int(ms * fs)
noverlap = int(overlap_perc * nfft)
step = nfft - noverlap
min_size = int(divide_factor * (1.0 / resize_factor))
@ -320,23 +245,22 @@ def pad_audio(
return audio_raw
def gen_mag_spectrogram(
audio: np.ndarray,
sampling_rate: float,
window_len: float,
overlap_perc: float,
) -> np.ndarray:
def gen_mag_spectrogram(x, fs, ms, overlap_perc):
# Computes magnitude spectrogram by specifying time.
audio = audio.astype(np.float32)
nfft = int(window_len * sampling_rate)
x = x.astype(np.float32)
nfft = int(ms * fs)
noverlap = int(overlap_perc * nfft)
# window data
step = nfft - noverlap
# compute spec
spec, _ = librosa.core.spectrum._spectrogram(
y=audio,
y=x,
power=1,
n_fft=nfft,
hop_length=nfft - noverlap,
hop_length=step,
center=False,
)
@ -346,25 +270,24 @@ def gen_mag_spectrogram(
return spec.astype(np.float32)
def gen_mag_spectrogram_pt(
audio: torch.Tensor,
sampling_rate: float,
window_len: float,
overlap_perc: float,
) -> torch.Tensor:
nfft = int(window_len * sampling_rate)
def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
nfft = int(ms * fs)
nstep = round((1.0 - overlap_perc) * nfft)
han_win = torch.hann_window(nfft, periodic=False).to(audio.device)
complex_spec = torch.stft(audio, nfft, nstep, window=han_win, center=False)
han_win = torch.hann_window(nfft, periodic=False).to(x.device)
complex_spec = torch.stft(x, nfft, nstep, window=han_win, center=False)
spec = complex_spec.pow(2.0).sum(-1)
# remove DC component and flip vertically
return torch.flipud(spec[0, 1:, :])
spec = torch.flipud(spec[0, 1:, :])
return spec
def pcen(spec: np.ndarray, sampling_rate: float) -> np.ndarray:
def pcen(spec_cropped, sampling_rate):
# TODO should be passing hop_length too i.e. step
return librosa.pcen(spec * (2**31), sr=sampling_rate / 10).astype(
spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype(
np.float32
)
return spec
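As a small numeric check of the STFT framing used throughout this module, the sketch below builds a spectrogram from synthetic audio and maps a timestamp to a column index; the 512-sample window at 256 kHz with 75% overlap mirrors the defaults used elsewhere in the codebase but is restated here as an assumption.

import numpy as np

from batdetect2.utils.audio_utils import gen_mag_spectrogram, time_to_x_coords

fs = 256_000          # sampling rate in Hz
ms = 512 / fs         # FFT window length in seconds (512 samples)
overlap_perc = 0.75   # 75% overlap between adjacent windows

# Half a second of a 40 kHz tone.
t = np.arange(int(0.5 * fs)) / fs
audio = (0.5 * np.sin(2 * np.pi * 40_000 * t)).astype(np.float32)

spec = gen_mag_spectrogram(audio, fs, ms, overlap_perc)
print(spec.shape)  # (freq_bins, time_bins)

# Column index corresponding to t = 0.25 s, using the same framing maths.
print(time_to_x_coords(0.25, fs, ms, overlap_perc))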

View File

@ -437,7 +437,7 @@ def compute_spectrogram(
)
# generate spectrogram
spec = au.generate_spectrogram(audio, sampling_rate, params)
spec, _ = au.generate_spectrogram(audio, sampling_rate, params)
# convert to pytorch
spec = torch.from_numpy(spec).to(device)
@ -746,7 +746,7 @@ def process_file(
sampling_rate, audio_full = au.load_audio(
audio_file,
time_exp_fact=config.get("time_expansion", 1) or 1,
target_sampling_rate=config["target_samp_rate"],
target_samp_rate=config["target_samp_rate"],
scale=config["scale_raw_audio"],
max_duration=config.get("max_duration"),
)

View File

@ -1,17 +0,0 @@
name: batdetect2
channels:
- defaults
- conda-forge
- pytorch
- nvidia
dependencies:
- python==3.10
- matplotlib
- pandas
- scikit-learn
- numpy
- pytorch
- scipy
- torchvision
- librosa
- torchaudio

1272
pdm.lock generated

File diff suppressed because it is too large.

View File

@ -1,3 +1,9 @@
[tool]
rye = { dev-dependencies = [
"ipykernel>=6.29.4",
"setuptools>=69.5.1",
"pytest>=8.1.1",
] }
[tool.pdm]
[tool.pdm.dev-dependencies]
dev = [
@ -22,12 +28,17 @@ dependencies = [
"torch>=1.13.1",
"torchaudio",
"torchvision",
"click",
"soundevent>=1.3.5",
"click",
"soundevent[audio,geometry,plot]>=1.3.5",
"click>=8.1.7",
"netcdf4>=1.6.5",
"tqdm>=4.66.2",
"pytorch-lightning>=2.2.2",
"cf-xarray>=0.9.0",
"onnx>=1.16.0",
"lightning[extra]>=2.2.2",
"tensorboard>=2.16.2",
]
requires-python = ">=3.8,<3.12"
requires-python = ">=3.9"
readme = "README.md"
license = { text = "CC-by-nc-4" }
classifiers = [
@ -65,6 +76,9 @@ line-length = 79
profile = "black"
line_length = 79
[tool.ruff]
line-length = 79
[[tool.mypy.overrides]]
module = [
"librosa",

518
requirements-dev.lock Normal file
View File

@ -0,0 +1,518 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
# with-sources: false
-e file:.
absl-py==2.1.0
# via tensorboard
aiobotocore==2.12.3
# via s3fs
aiohttp==3.9.5
# via aiobotocore
# via fsspec
# via lightning
# via s3fs
aioitertools==0.11.0
# via aiobotocore
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via hydra-core
# via omegaconf
anyio==4.3.0
# via starlette
arrow==1.3.0
# via lightning
asttokens==2.4.1
# via stack-data
async-timeout==4.0.3
# via aiohttp
# via redis
attrs==23.2.0
# via aiohttp
audioread==3.0.1
# via librosa
backcall==0.2.0
# via ipython
backoff==2.2.1
# via lightning
beautifulsoup4==4.12.3
# via lightning
bitsandbytes==0.41.0
# via lightning
blessed==1.20.0
# via inquirer
boto3==1.34.69
# via lightning-cloud
botocore==1.34.69
# via aiobotocore
# via boto3
# via s3transfer
certifi==2024.2.2
# via netcdf4
# via requests
cf-xarray==0.9.0
# via batdetect2
cffi==1.16.0
# via soundfile
cftime==1.6.3
# via netcdf4
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via batdetect2
# via lightning
# via lightning-cloud
# via uvicorn
comm==0.2.2
# via ipykernel
contourpy==1.1.1
# via matplotlib
croniter==1.4.1
# via lightning
cycler==0.12.1
# via matplotlib
cython==3.0.10
# via soundevent
dateutils==0.6.12
# via lightning
debugpy==1.8.1
# via ipykernel
decorator==5.1.1
# via ipython
# via librosa
deepdiff==6.7.1
# via lightning
dnspython==2.6.1
# via email-validator
docker==6.1.3
# via lightning
docstring-parser==0.16
# via jsonargparse
editor==1.6.6
# via inquirer
email-validator==2.1.1
# via soundevent
exceptiongroup==1.2.1
# via anyio
# via pytest
executing==2.0.1
# via stack-data
fastapi==0.110.2
# via lightning
# via lightning-cloud
filelock==3.13.4
# via torch
# via triton
fonttools==4.51.0
# via matplotlib
frozenlist==1.4.1
# via aiohttp
# via aiosignal
fsspec==2023.12.2
# via lightning
# via lightning-fabric
# via pytorch-lightning
# via s3fs
# via torch
grpcio==1.62.2
# via tensorboard
h11==0.14.0
# via uvicorn
hydra-core==1.3.2
# via lightning
idna==3.7
# via anyio
# via email-validator
# via requests
# via yarl
importlib-metadata==7.1.0
# via jupyter-client
# via markdown
importlib-resources==6.4.0
# via matplotlib
# via typeshed-client
iniconfig==2.0.0
# via pytest
inquirer==3.2.4
# via lightning
ipykernel==6.29.4
ipython==8.12.3
# via ipykernel
jedi==0.19.1
# via ipython
jinja2==3.1.3
# via lightning
# via torch
jmespath==1.0.1
# via boto3
# via botocore
joblib==1.4.0
# via librosa
# via scikit-learn
jsonargparse==4.28.0
# via lightning
jupyter-client==8.6.1
# via ipykernel
jupyter-core==5.7.2
# via ipykernel
# via jupyter-client
kiwisolver==1.4.5
# via matplotlib
lazy-loader==0.4
# via librosa
librosa==0.10.1
# via batdetect2
lightning==2.2.2
# via batdetect2
lightning-api-access==0.0.5
# via lightning
lightning-cloud==0.5.65
# via lightning
lightning-fabric==2.2.2
# via lightning
lightning-utilities==0.11.2
# via lightning
# via lightning-fabric
# via pytorch-lightning
# via torchmetrics
llvmlite==0.41.1
# via numba
markdown==3.6
# via tensorboard
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
# via jinja2
# via werkzeug
matplotlib==3.7.5
# via batdetect2
# via lightning
# via soundevent
matplotlib-inline==0.1.7
# via ipykernel
# via ipython
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
# via sympy
msgpack==1.0.8
# via librosa
multidict==6.0.5
# via aiohttp
# via yarl
nest-asyncio==1.6.0
# via ipykernel
netcdf4==1.6.5
# via batdetect2
networkx==3.1
# via torch
numba==0.58.1
# via librosa
numpy==1.24.4
# via batdetect2
# via cftime
# via contourpy
# via librosa
# via lightning
# via lightning-fabric
# via matplotlib
# via netcdf4
# via numba
# via onnx
# via pandas
# via pytorch-lightning
# via scikit-learn
# via scipy
# via shapely
# via soxr
# via tensorboard
# via tensorboardx
# via torchmetrics
# via torchvision
# via xarray
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.19.3
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
omegaconf==2.3.0
# via hydra-core
# via lightning
onnx==1.16.0
# via batdetect2
ordered-set==4.1.0
# via deepdiff
packaging==24.0
# via docker
# via hydra-core
# via ipykernel
# via lazy-loader
# via lightning
# via lightning-fabric
# via lightning-utilities
# via matplotlib
# via pooch
# via pytest
# via pytorch-lightning
# via tensorboardx
# via torchmetrics
# via xarray
pandas==2.0.3
# via batdetect2
# via xarray
parso==0.8.4
# via jedi
pexpect==4.9.0
# via ipython
pickleshare==0.7.5
# via ipython
pillow==10.3.0
# via matplotlib
# via torchvision
platformdirs==4.2.0
# via jupyter-core
# via pooch
pluggy==1.4.0
# via pytest
pooch==1.8.1
# via librosa
prompt-toolkit==3.0.43
# via ipython
protobuf==5.26.1
# via onnx
# via tensorboard
# via tensorboardx
psutil==5.9.8
# via ipykernel
# via lightning
ptyprocess==0.7.0
# via pexpect
pure-eval==0.2.2
# via stack-data
pycparser==2.22
# via cffi
pydantic==2.7.0
# via fastapi
# via lightning
# via soundevent
pydantic-core==2.18.1
# via pydantic
pygments==2.17.2
# via ipython
# via rich
pyjwt==2.8.0
# via lightning-cloud
pyparsing==3.1.2
# via matplotlib
pytest==8.1.1
python-dateutil==2.9.0.post0
# via arrow
# via botocore
# via croniter
# via dateutils
# via jupyter-client
# via matplotlib
# via pandas
python-multipart==0.0.9
# via lightning
# via lightning-cloud
pytorch-lightning==2.2.2
# via batdetect2
# via lightning
pytz==2024.1
# via dateutils
# via pandas
pyyaml==6.0.1
# via jsonargparse
# via lightning
# via omegaconf
# via pytorch-lightning
pyzmq==26.0.0
# via ipykernel
# via jupyter-client
readchar==4.0.6
# via inquirer
redis==5.0.4
# via lightning
requests==2.31.0
# via docker
# via fsspec
# via lightning
# via lightning-cloud
# via pooch
rich==13.7.1
# via lightning
# via lightning-cloud
runs==1.2.2
# via editor
s3fs==2023.12.2
# via lightning
s3transfer==0.10.1
# via boto3
scikit-learn==1.3.2
# via batdetect2
# via librosa
scipy==1.10.1
# via batdetect2
# via librosa
# via scikit-learn
# via soundevent
setuptools==69.5.1
# via lightning-utilities
# via readchar
# via tensorboard
shapely==2.0.3
# via soundevent
six==1.16.0
# via asttokens
# via blessed
# via lightning-cloud
# via python-dateutil
# via tensorboard
sniffio==1.3.1
# via anyio
soundevent==1.3.5
# via batdetect2
soundfile==0.12.1
# via librosa
# via soundevent
soupsieve==2.5
# via beautifulsoup4
soxr==0.3.7
# via librosa
stack-data==0.6.3
# via ipython
starlette==0.37.2
# via fastapi
# via lightning
sympy==1.12
# via torch
tensorboard==2.16.2
# via batdetect2
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.2.2
# via lightning
threadpoolctl==3.4.0
# via scikit-learn
tomli==2.0.1
# via pytest
torch==2.2.2
# via batdetect2
# via lightning
# via lightning-fabric
# via pytorch-lightning
# via torchaudio
# via torchmetrics
# via torchvision
torchaudio==2.2.2
# via batdetect2
torchmetrics==1.3.2
# via lightning
# via pytorch-lightning
torchvision==0.17.2
# via batdetect2
tornado==6.4
# via ipykernel
# via jupyter-client
tqdm==4.66.2
# via batdetect2
# via lightning
# via pytorch-lightning
traitlets==5.14.2
# via comm
# via ipykernel
# via ipython
# via jupyter-client
# via jupyter-core
# via lightning
# via matplotlib-inline
triton==2.2.0
# via torch
types-python-dateutil==2.9.0.20240316
# via arrow
typeshed-client==2.5.1
# via jsonargparse
typing-extensions==4.11.0
# via aioitertools
# via anyio
# via fastapi
# via ipython
# via jsonargparse
# via librosa
# via lightning
# via lightning-fabric
# via lightning-utilities
# via pydantic
# via pydantic-core
# via pytorch-lightning
# via starlette
# via torch
# via typeshed-client
# via uvicorn
tzdata==2024.1
# via pandas
urllib3==1.26.18
# via botocore
# via docker
# via lightning
# via lightning-cloud
# via requests
uvicorn==0.29.0
# via lightning
# via lightning-cloud
wcwidth==0.2.13
# via blessed
# via prompt-toolkit
websocket-client==1.7.0
# via docker
# via lightning
# via lightning-cloud
websockets==11.0.3
# via lightning
werkzeug==3.0.2
# via tensorboard
wrapt==1.16.0
# via aiobotocore
xarray==2023.1.0
# via cf-xarray
# via soundevent
xmod==1.8.1
# via editor
# via runs
yarl==1.9.4
# via aiohttp
zipp==3.18.1
# via importlib-metadata
# via importlib-resources

448
requirements.lock Normal file
View File

@ -0,0 +1,448 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
# pre: false
# features: []
# all-features: false
# with-sources: false
-e file:.
absl-py==2.1.0
# via tensorboard
aiobotocore==2.12.3
# via s3fs
aiohttp==3.9.5
# via aiobotocore
# via fsspec
# via lightning
# via s3fs
aioitertools==0.11.0
# via aiobotocore
aiosignal==1.3.1
# via aiohttp
annotated-types==0.6.0
# via pydantic
antlr4-python3-runtime==4.9.3
# via hydra-core
# via omegaconf
anyio==4.3.0
# via starlette
arrow==1.3.0
# via lightning
async-timeout==4.0.3
# via aiohttp
# via redis
attrs==23.2.0
# via aiohttp
audioread==3.0.1
# via librosa
backoff==2.2.1
# via lightning
beautifulsoup4==4.12.3
# via lightning
bitsandbytes==0.41.0
# via lightning
blessed==1.20.0
# via inquirer
boto3==1.34.69
# via lightning-cloud
botocore==1.34.69
# via aiobotocore
# via boto3
# via s3transfer
certifi==2024.2.2
# via netcdf4
# via requests
cf-xarray==0.9.0
# via batdetect2
cffi==1.16.0
# via soundfile
cftime==1.6.3
# via netcdf4
charset-normalizer==3.3.2
# via requests
click==8.1.7
# via batdetect2
# via lightning
# via lightning-cloud
# via uvicorn
contourpy==1.1.1
# via matplotlib
croniter==1.4.1
# via lightning
cycler==0.12.1
# via matplotlib
cython==3.0.10
# via soundevent
dateutils==0.6.12
# via lightning
decorator==5.1.1
# via librosa
deepdiff==6.7.1
# via lightning
dnspython==2.6.1
# via email-validator
docker==6.1.3
# via lightning
docstring-parser==0.16
# via jsonargparse
editor==1.6.6
# via inquirer
email-validator==2.1.1
# via soundevent
exceptiongroup==1.2.1
# via anyio
fastapi==0.110.2
# via lightning
# via lightning-cloud
filelock==3.13.4
# via torch
# via triton
fonttools==4.51.0
# via matplotlib
frozenlist==1.4.1
# via aiohttp
# via aiosignal
fsspec==2023.12.2
# via lightning
# via lightning-fabric
# via pytorch-lightning
# via s3fs
# via torch
grpcio==1.62.2
# via tensorboard
h11==0.14.0
# via uvicorn
hydra-core==1.3.2
# via lightning
idna==3.7
# via anyio
# via email-validator
# via requests
# via yarl
importlib-metadata==7.1.0
# via markdown
importlib-resources==6.4.0
# via matplotlib
# via typeshed-client
inquirer==3.2.4
# via lightning
jinja2==3.1.3
# via lightning
# via torch
jmespath==1.0.1
# via boto3
# via botocore
joblib==1.4.0
# via librosa
# via scikit-learn
jsonargparse==4.28.0
# via lightning
kiwisolver==1.4.5
# via matplotlib
lazy-loader==0.4
# via librosa
librosa==0.10.1
# via batdetect2
lightning==2.2.2
# via batdetect2
lightning-api-access==0.0.5
# via lightning
lightning-cloud==0.5.65
# via lightning
lightning-fabric==2.2.2
# via lightning
lightning-utilities==0.11.2
# via lightning
# via lightning-fabric
# via pytorch-lightning
# via torchmetrics
llvmlite==0.41.1
# via numba
markdown==3.6
# via tensorboard
markdown-it-py==3.0.0
# via rich
markupsafe==2.1.5
# via jinja2
# via werkzeug
matplotlib==3.7.5
# via batdetect2
# via lightning
# via soundevent
mdurl==0.1.2
# via markdown-it-py
mpmath==1.3.0
# via sympy
msgpack==1.0.8
# via librosa
multidict==6.0.5
# via aiohttp
# via yarl
netcdf4==1.6.5
# via batdetect2
networkx==3.1
# via torch
numba==0.58.1
# via librosa
numpy==1.24.4
# via batdetect2
# via cftime
# via contourpy
# via librosa
# via lightning
# via lightning-fabric
# via matplotlib
# via netcdf4
# via numba
# via onnx
# via pandas
# via pytorch-lightning
# via scikit-learn
# via scipy
# via shapely
# via soxr
# via tensorboard
# via tensorboardx
# via torchmetrics
# via torchvision
# via xarray
nvidia-cublas-cu12==12.1.3.1
# via nvidia-cudnn-cu12
# via nvidia-cusolver-cu12
# via torch
nvidia-cuda-cupti-cu12==12.1.105
# via torch
nvidia-cuda-nvrtc-cu12==12.1.105
# via torch
nvidia-cuda-runtime-cu12==12.1.105
# via torch
nvidia-cudnn-cu12==8.9.2.26
# via torch
nvidia-cufft-cu12==11.0.2.54
# via torch
nvidia-curand-cu12==10.3.2.106
# via torch
nvidia-cusolver-cu12==11.4.5.107
# via torch
nvidia-cusparse-cu12==12.1.0.106
# via nvidia-cusolver-cu12
# via torch
nvidia-nccl-cu12==2.19.3
# via torch
nvidia-nvjitlink-cu12==12.4.127
# via nvidia-cusolver-cu12
# via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
# via torch
omegaconf==2.3.0
# via hydra-core
# via lightning
onnx==1.16.0
# via batdetect2
ordered-set==4.1.0
# via deepdiff
packaging==24.0
# via docker
# via hydra-core
# via lazy-loader
# via lightning
# via lightning-fabric
# via lightning-utilities
# via matplotlib
# via pooch
# via pytorch-lightning
# via tensorboardx
# via torchmetrics
# via xarray
pandas==2.0.3
# via batdetect2
# via xarray
pillow==10.3.0
# via matplotlib
# via torchvision
platformdirs==4.2.0
# via pooch
pooch==1.8.1
# via librosa
protobuf==5.26.1
# via onnx
# via tensorboard
# via tensorboardx
psutil==5.9.8
# via lightning
pycparser==2.22
# via cffi
pydantic==2.7.0
# via fastapi
# via lightning
# via soundevent
pydantic-core==2.18.1
# via pydantic
pygments==2.17.2
# via rich
pyjwt==2.8.0
# via lightning-cloud
pyparsing==3.1.2
# via matplotlib
python-dateutil==2.9.0.post0
# via arrow
# via botocore
# via croniter
# via dateutils
# via matplotlib
# via pandas
python-multipart==0.0.9
# via lightning
# via lightning-cloud
pytorch-lightning==2.2.2
# via batdetect2
# via lightning
pytz==2024.1
# via dateutils
# via pandas
pyyaml==6.0.1
# via jsonargparse
# via lightning
# via omegaconf
# via pytorch-lightning
readchar==4.0.6
# via inquirer
redis==5.0.4
# via lightning
requests==2.31.0
# via docker
# via fsspec
# via lightning
# via lightning-cloud
# via pooch
rich==13.7.1
# via lightning
# via lightning-cloud
runs==1.2.2
# via editor
s3fs==2023.12.2
# via lightning
s3transfer==0.10.1
# via boto3
scikit-learn==1.3.2
# via batdetect2
# via librosa
scipy==1.10.1
# via batdetect2
# via librosa
# via scikit-learn
# via soundevent
setuptools==69.5.1
# via lightning-utilities
# via readchar
# via tensorboard
shapely==2.0.3
# via soundevent
six==1.16.0
# via blessed
# via lightning-cloud
# via python-dateutil
# via tensorboard
sniffio==1.3.1
# via anyio
soundevent==1.3.5
# via batdetect2
soundfile==0.12.1
# via librosa
# via soundevent
soupsieve==2.5
# via beautifulsoup4
soxr==0.3.7
# via librosa
starlette==0.37.2
# via fastapi
# via lightning
sympy==1.12
# via torch
tensorboard==2.16.2
# via batdetect2
tensorboard-data-server==0.7.2
# via tensorboard
tensorboardx==2.6.2.2
# via lightning
threadpoolctl==3.4.0
# via scikit-learn
torch==2.2.2
# via batdetect2
# via lightning
# via lightning-fabric
# via pytorch-lightning
# via torchaudio
# via torchmetrics
# via torchvision
torchaudio==2.2.2
# via batdetect2
torchmetrics==1.3.2
# via lightning
# via pytorch-lightning
torchvision==0.17.2
# via batdetect2
tqdm==4.66.2
# via batdetect2
# via lightning
# via pytorch-lightning
traitlets==5.14.3
# via lightning
triton==2.2.0
# via torch
types-python-dateutil==2.9.0.20240316
# via arrow
typeshed-client==2.5.1
# via jsonargparse
typing-extensions==4.11.0
# via aioitertools
# via anyio
# via fastapi
# via jsonargparse
# via librosa
# via lightning
# via lightning-fabric
# via lightning-utilities
# via pydantic
# via pydantic-core
# via pytorch-lightning
# via starlette
# via torch
# via typeshed-client
# via uvicorn
tzdata==2024.1
# via pandas
urllib3==1.26.18
# via botocore
# via docker
# via lightning
# via lightning-cloud
# via requests
uvicorn==0.29.0
# via lightning
# via lightning-cloud
wcwidth==0.2.13
# via blessed
websocket-client==1.7.0
# via docker
# via lightning
# via lightning-cloud
websockets==11.0.3
# via lightning
werkzeug==3.0.2
# via tensorboard
wrapt==1.16.0
# via aiobotocore
xarray==2023.1.0
# via cf-xarray
# via soundevent
xmod==1.8.1
# via editor
# via runs
yarl==1.9.4
# via aiohttp
zipp==3.18.1
# via importlib-metadata
# via importlib-resources

View File

@ -1,10 +0,0 @@
librosa==0.9.2
matplotlib==3.6.2
numpy==1.23.4
pandas==1.5.2
scikit_learn==1.2.0
scipy==1.9.3
torch==1.13.0
torchaudio==0.13.0
torchvision==0.14.0
click

View File

View File

@ -0,0 +1,19 @@
"""Test suite for loading batdetect annotations."""
from pathlib import Path
from soundevent import data
from batdetect2.data.compat import load_annotation_project
ROOT_DIR = Path(__file__).parent.parent.parent
def test_load_example_annotation_project():
path = ROOT_DIR / "example_data" / "anns"
audio_dir = ROOT_DIR / "example_data" / "audio"
project = load_annotation_project(path, audio_dir=audio_dir)
assert isinstance(project, data.AnnotationProject)
assert project.name == str(path)
assert len(project.clip_annotations) == 3
assert len(project.tasks) == 3

View File

@ -133,17 +133,16 @@ def test_compute_max_power_bb(max_power: int):
audio = np.zeros((int(duration * samplerate),))
# Add a signal during the time and frequency range of interest
audio[
int(start_time * samplerate) : int(end_time * samplerate)
] = 0.5 * librosa.tone(
audio[int(start_time * samplerate) : int(end_time * samplerate)] = (
0.5
* librosa.tone(
max_power, sr=samplerate, duration=end_time - start_time
)
)
# Add a more powerful signal outside frequency range of interest
audio[
int(start_time * samplerate) : int(end_time * samplerate)
] += 2 * librosa.tone(
80_000, sr=samplerate, duration=end_time - start_time
audio[int(start_time * samplerate) : int(end_time * samplerate)] += (
2 * librosa.tone(80_000, sr=samplerate, duration=end_time - start_time)
)
params = api.get_config(
@ -152,7 +151,7 @@ def test_compute_max_power_bb(max_power: int):
target_samp_rate=samplerate,
)
spec = au.generate_spectrogram(
spec, _ = au.generate_spectrogram(
audio,
samplerate,
params,
@ -221,18 +220,18 @@ def test_compute_max_power():
audio = np.zeros((int(duration * samplerate),))
# Add a signal during the time and frequency range of interest
audio[
int(start_time * samplerate) : int(end_time * samplerate)
] = 0.5 * librosa.tone(
3_500, sr=samplerate, duration=end_time - start_time
audio[int(start_time * samplerate) : int(end_time * samplerate)] = (
0.5
* librosa.tone(3_500, sr=samplerate, duration=end_time - start_time)
)
# Add a more powerful signal outside frequency range of interest
audio[
int(start_time * samplerate) : int(end_time * samplerate)
] += 2 * librosa.tone(
audio[int(start_time * samplerate) : int(end_time * samplerate)] += (
2
* librosa.tone(
max_power, sr=samplerate, duration=end_time - start_time
)
)
params = api.get_config(
min_freq=min_freq,
@ -240,7 +239,7 @@ def test_compute_max_power():
target_samp_rate=samplerate,
)
spec = au.generate_spectrogram(
spec, _ = au.generate_spectrogram(
audio,
samplerate,
params,

View File

View File

@ -0,0 +1,120 @@
from pathlib import Path
import numpy as np
import pytest
from soundevent import data
from batdetect2.data import preprocessing
from batdetect2.utils import audio_utils
ROOT_DIR = Path(__file__).parent.parent.parent
EXAMPLE_AUDIO = ROOT_DIR / "example_data" / "audio"
TEST_AUDIO = ROOT_DIR / "tests" / "data"
TEST_FILES = [
EXAMPLE_AUDIO / "20170701_213954-MYOMYS-LR_0_0.5.wav",
EXAMPLE_AUDIO / "20180530_213516-EPTSER-LR_0_0.5.wav",
EXAMPLE_AUDIO / "20180627_215323-RHIFER-LR_0_0.5.wav",
TEST_AUDIO / "20230322_172000_selec2.wav",
]
@pytest.mark.parametrize("audio_file", TEST_FILES)
@pytest.mark.parametrize("scale", [True, False])
def test_audio_loading_hasnt_changed(
audio_file,
scale,
):
time_expansion = 1
target_sampling_rate = 256_000
recording = data.Recording.from_file(
audio_file,
time_expansion=time_expansion,
)
clip = data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
)
_, audio_original = audio_utils.load_audio(
audio_file,
time_expansion,
target_samp_rate=target_sampling_rate,
scale=scale,
)
audio_new = preprocessing.load_clip_audio(
clip,
target_sampling_rate=target_sampling_rate,
scale=scale,
dtype=np.float32,
)
assert audio_original.shape == audio_new.shape
assert audio_original.dtype == audio_new.dtype
assert np.isclose(audio_original, audio_new.data).all()
@pytest.mark.parametrize("audio_file", TEST_FILES)
@pytest.mark.parametrize("spec_scale", ["log", "pcen", "amplitude"])
@pytest.mark.parametrize("denoise_spec_avg", [True, False])
@pytest.mark.parametrize("max_scale_spec", [True, False])
@pytest.mark.parametrize("fft_win_length", [512 / 256_000, 1024 / 256_000])
def test_spectrogram_generation_hasnt_changed(
audio_file,
spec_scale,
denoise_spec_avg,
max_scale_spec,
fft_win_length,
):
time_expansion = 1
target_sampling_rate = 256_000
min_freq = 10_000
max_freq = 120_000
fft_overlap = 0.75
recording = data.Recording.from_file(
audio_file,
time_expansion=time_expansion,
)
clip = data.Clip(
recording=recording,
start_time=0,
end_time=recording.duration,
)
audio = preprocessing.load_clip_audio(
clip,
target_sampling_rate=target_sampling_rate,
)
spec_original, _ = audio_utils.generate_spectrogram(
audio.data,
sampling_rate=target_sampling_rate,
params=dict(
fft_win_length=fft_win_length,
fft_overlap=fft_overlap,
max_freq=max_freq,
min_freq=min_freq,
spec_scale=spec_scale,
denoise_spec_avg=denoise_spec_avg,
max_scale_spec=max_scale_spec,
),
)
new_spec = preprocessing.compute_spectrogram(
audio,
fft_win_length=fft_win_length,
fft_overlap=fft_overlap,
max_freq=max_freq,
min_freq=min_freq,
spec_scale=spec_scale,
denoise_spec_avg=denoise_spec_avg,
max_scale_spec=max_scale_spec,
dtype=np.float32,
)
assert spec_original.shape == new_spec.shape
assert spec_original.dtype == new_spec.dtype
# NOTE: The original spectrogram is flipped vertically
assert np.isclose(spec_original, np.flipud(new_spec.data)).all()