mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00

WIP updating training code

This commit is contained in:
parent 343bc5f87c
commit c66d14b7c7
11  batdetect2/cli/__init__.py  Normal file

@@ -0,0 +1,11 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect

__all__ = [
    "cli",
    "detect",
]


if __name__ == "__main__":
    cli()
27  batdetect2/cli/base.py  Normal file

@@ -0,0 +1,27 @@
"""BatDetect2 command line interface."""

import os

import click


__all__ = [
    "cli",
]


CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""


@click.group()
def cli():
    """BatDetect2 - Bat Call Detection and Classification."""
    click.echo(INFO_STR)
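
The `cli` object is a plain `click.Group`, so other modules in this commit (e.g. `batdetect2/cli/compat.py`, which provides `detect`) register their subcommands against it. A minimal sketch of that pattern; the `inspect` command and its argument are hypothetical, not part of this commit:

import click

from batdetect2.cli.base import cli


@cli.command()
@click.argument("audio_file")
def inspect(audio_file: str):
    """Hypothetical subcommand registered on the shared group."""
    click.echo(f"Would inspect {audio_file}")


if __name__ == "__main__":
    cli()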
@@ -1,5 +1,4 @@
 """BatDetect2 command line interface."""
-import os
 
 import click
 
@@ -8,21 +7,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
 from batdetect2.types import ProcessingConfiguration
 from batdetect2.utils.detector_utils import save_results_to_file
 
-CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
-
-
-INFO_STR = """
-BatDetect2 - Detection and Classification
-Assumes audio files are mono, not stereo.
-Spaces in the input paths will throw an error. Wrap in quotes.
-Input files should be short in duration e.g. < 30 seconds.
-"""
-
-
-@click.group()
-def cli():
-    """BatDetect2 - Bat Call Detection and Classification."""
-    click.echo(INFO_STR)
+from batdetect2.cli.base import cli
 
 
 @cli.command()
@@ -147,7 +132,3 @@ def print_config(config: ProcessingConfiguration):
     click.echo("\nProcessing Configuration:")
     click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
     click.echo(f"Detection Threshold: {config.get('detection_threshold')}")
-
-
-if __name__ == "__main__":
-    cli()
0  batdetect2/data/__init__.py  Normal file
304  batdetect2/data/augmentations.py  Normal file

@@ -0,0 +1,304 @@
from functools import wraps
from typing import Callable, List, Optional, Tuple

import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds

ClipAugmentation = Callable[[data.ClipAnnotation], data.ClipAnnotation]
AudioAugmentation = Callable[
    [xr.DataArray, data.ClipAnnotation],
    Tuple[xr.DataArray, data.ClipAnnotation],
]
SpecAugmentation = Callable[
    [xr.DataArray, data.ClipAnnotation],
    Tuple[xr.DataArray, data.ClipAnnotation],
]

ClipProvider = Callable[
    [data.ClipAnnotation], Tuple[xr.DataArray, data.ClipAnnotation]
]
"""A function that provides some clip and its annotation.

Usually this function loads a random clip from a dataset. Takes
as input a clip annotation that can be used to filter the clips
to load (in case you want to avoid loading the same clip multiple times).
"""


AUGMENTATION_PROBABILITY = 0.2
MAX_DELAY = 0.005
STRETCH_SQUEEZE_DELTA = 0.04
MASK_MAX_TIME_PERC: float = 0.05
MASK_MAX_FREQ_PERC: float = 0.10


def maybe_apply(
    augmentation: Callable,
    prob: float = AUGMENTATION_PROBABILITY,
) -> Callable:
    """Apply an augmentation with a given probability."""

    @wraps(augmentation)
    def _augmentation(x):
        if np.random.rand() > prob:
            return x
        return augmentation(x)

    return _augmentation


def select_random_subclip(
    clip_annotation: data.ClipAnnotation,
    duration: Optional[float] = None,
    proportion: float = 0.9,
) -> data.ClipAnnotation:
    """Select a random subclip from a clip."""
    clip = clip_annotation.clip

    if duration is None:
        clip_duration = clip.end_time - clip.start_time
        duration = clip_duration * proportion

    start_time = np.random.uniform(clip.start_time, clip.end_time - duration)
    return clip_annotation.model_copy(
        update=dict(
            clip=clip.model_copy(
                update=dict(
                    start_time=start_time,
                    end_time=start_time + duration,
                )
            )
        )
    )


def combine_audio(
    audio1: xr.DataArray,
    audio2: xr.DataArray,
    alpha: Optional[float] = None,
    min_alpha: float = 0.3,
    max_alpha: float = 0.7,
) -> xr.DataArray:
    """Combine two audio clips."""

    if alpha is None:
        alpha = np.random.uniform(min_alpha, max_alpha)

    return alpha * audio1 + (1 - alpha) * audio2.data


def random_mix(
    audio: xr.DataArray,
    clip: data.ClipAnnotation,
    provider: Optional[ClipProvider] = None,
    alpha: Optional[float] = None,
    min_alpha: float = 0.3,
    max_alpha: float = 0.7,
    join_annotations: bool = True,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Mix two audio clips."""
    if provider is None:
        raise ValueError("No audio provider given.")

    try:
        other_audio, other_clip = provider(clip)
    except (StopIteration, ValueError):
        raise ValueError("No more audio sources available.")

    new_audio = combine_audio(
        audio,
        other_audio,
        alpha=alpha,
        min_alpha=min_alpha,
        max_alpha=max_alpha,
    )

    if join_annotations:
        clip = clip.model_copy(
            update=dict(
                sound_events=clip.sound_events + other_clip.sound_events,
            )
        )

    return new_audio, clip


def add_echo(
    audio: xr.DataArray,
    clip: data.ClipAnnotation,
    delay: Optional[float] = None,
    alpha: Optional[float] = None,
    min_alpha: float = 0.0,
    max_alpha: float = 1.0,
    max_delay: float = MAX_DELAY,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Add a delay to the audio."""
    if delay is None:
        delay = np.random.uniform(0, max_delay)

    if alpha is None:
        alpha = np.random.uniform(min_alpha, max_alpha)

    samplerate = audio.attrs["samplerate"]
    offset = int(delay * samplerate)

    # NOTE: We use the copy method to avoid modifying the original audio
    # data.
    new_audio = audio.copy()
    new_audio[offset:] += alpha * audio.data[:-offset]
    return new_audio, clip


def scale_volume(
    spec: xr.DataArray,
    clip: data.ClipAnnotation,
    factor: Optional[float] = None,
    max_scaling: float = 2,
    min_scaling: float = 0,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Scale the volume of a spectrogram."""
    if factor is None:
        factor = np.random.uniform(min_scaling, max_scaling)

    return spec * factor, clip


def scale_sound_event_annotation(
    sound_event_annotation: data.SoundEventAnnotation,
    time_factor: float = 1,
    frequency_factor: float = 1,
) -> data.SoundEventAnnotation:
    sound_event = sound_event_annotation.sound_event
    geometry = sound_event.geometry

    if geometry is None:
        return sound_event_annotation

    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
    new_geometry = data.BoundingBox(
        coordinates=[
            start_time * time_factor,
            low_freq * frequency_factor,
            end_time * time_factor,
            high_freq * frequency_factor,
        ]
    )

    return sound_event_annotation.model_copy(
        update=dict(
            sound_event=sound_event.model_copy(
                update=dict(
                    geometry=new_geometry,
                )
            )
        )
    )


def warp_spectrogram(
    spec: xr.DataArray,
    clip: data.ClipAnnotation,
    factor: Optional[float] = None,
    delta: float = STRETCH_SQUEEZE_DELTA,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Warp a spectrogram."""
    if factor is None:
        factor = np.random.uniform(1 - delta, 1 + delta)

    start_time = clip.clip.start_time
    end_time = clip.clip.end_time
    duration = end_time - start_time
    new_time = np.linspace(
        start_time,
        start_time + duration * factor,
        spec.time.size,
    )

    scaled_spec = spec.interp(
        time=new_time,
        method="linear",
        kwargs={"fill_value": 0},
    )
    scaled_spec.coords["time"] = spec.time

    scaled_clip = clip.model_copy(
        update=dict(
            sound_events=[
                scale_sound_event_annotation(
                    sound_event_annotation,
                    time_factor=1 / factor,
                )
                for sound_event_annotation in clip.sound_events
            ]
        )
    )
    return scaled_spec, scaled_clip


def mask_axis(
    array: xr.DataArray,
    axis: str,
    start: float,
    end: float,
    mask_value: float = 0,
) -> xr.DataArray:
    if axis not in array.dims:
        raise ValueError(f"Axis {axis} not found in array")

    coord = array[axis]
    return array.where((coord < start) | (coord > end), mask_value)


def mask_time(
    spec: xr.DataArray,
    clip: data.ClipAnnotation,
    max_time_mask: float = MASK_MAX_TIME_PERC,
    max_num_masks: int = 3,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Mask a random section of the time axis."""

    num_masks = np.random.randint(1, max_num_masks + 1)
    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_time_mask)
        start = np.random.uniform(0, spec.time[-1] - mask_size)
        end = start + mask_size
        spec = mask_axis(spec, "time", start, end)

    return spec, clip


def mask_frequency(
    spec: xr.DataArray,
    clip: data.ClipAnnotation,
    max_freq_mask: float = MASK_MAX_FREQ_PERC,
    max_num_masks: int = 3,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Mask a random section of the frequency axis."""

    num_masks = np.random.randint(1, max_num_masks + 1)
    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_freq_mask)
        start = np.random.uniform(0, spec.frequency[-1] - mask_size)
        end = start + mask_size
        spec = mask_axis(spec, "frequency", start, end)

    return spec, clip


CLIP_AUGMENTATIONS: List[ClipAugmentation] = [
    select_random_subclip,
]

AUDIO_AUGMENTATIONS: List[AudioAugmentation] = [
    add_echo,
    random_mix,
]

SPEC_AUGMENTATIONS: List[SpecAugmentation] = [
    scale_volume,
    warp_spectrogram,
    mask_time,
    mask_frequency,
]
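
Note on composing these: `maybe_apply` wraps a single-argument callable, while the audio and spectrogram augmentations above take an `(array, clip_annotation)` pair, so a training pipeline would inline the probability check rather than reuse the decorator directly. A rough sketch under that assumption (the helper name is hypothetical):

import numpy as np

from batdetect2.data.augmentations import (
    AUGMENTATION_PROBABILITY,
    SPEC_AUGMENTATIONS,
)


def augment_spectrogram(spec, clip_annotation, prob=AUGMENTATION_PROBABILITY):
    """Apply each spec augmentation independently with probability `prob`."""
    for augmentation in SPEC_AUGMENTATIONS:
        if np.random.rand() > prob:
            continue
        spec, clip_annotation = augmentation(spec, clip_annotation)
    return spec, clip_annotation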
332  batdetect2/data/compat.py  Normal file

@@ -0,0 +1,332 @@
"""Compatibility functions between old and new data structures."""

import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2 import types
from batdetect2.data.labels import LabelFn

PathLike = Union[Path, str, os.PathLike]

__all__ = [
    "convert_to_annotation_group",
    "load_annotation_project",
]

SPECIES_TAG_KEY = "species"
ECHOLOCATION_EVENT = "Echolocation"
UNKNOWN_CLASS = "__UNKNOWN__"

NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")


EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]

ClassFn = Callable[[data.Recording], int]

IndividualFn = Callable[[data.SoundEventAnnotation], int]


def get_recording_class_name(recording: data.Recording) -> str:
    """Get the class name for a recording."""
    tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
    if tag is None:
        return UNKNOWN_CLASS
    return tag.value


def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
    """Get the notes for a ClipAnnotation."""
    all_notes = [
        *annotation.notes,
        *annotation.clip.recording.notes,
    ]
    messages = [note.message for note in all_notes if note.message is not None]
    return "\n".join(messages)


def convert_to_annotation_group(
    annotation: data.ClipAnnotation,
    label_fn: LabelFn = lambda _: None,
    event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
    class_fn: ClassFn = lambda _: 0,
    individual_fn: IndividualFn = lambda _: 0,
) -> types.AudioLoaderAnnotationGroup:
    """Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
    recording = annotation.clip.recording

    start_times = []
    end_times = []
    low_freqs = []
    high_freqs = []
    class_ids = []
    x_inds = []
    y_inds = []
    individual_ids = []
    annotations: List[types.Annotation] = []
    class_id_file = class_fn(recording)

    for sound_event in annotation.sound_events:
        geometry = sound_event.sound_event.geometry

        if geometry is None:
            continue

        start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
        class_id = label_fn(sound_event) or -1
        event = event_fn(sound_event)
        individual_id = individual_fn(sound_event) or -1

        start_times.append(start_time)
        end_times.append(end_time)
        low_freqs.append(low_freq)
        high_freqs.append(high_freq)
        class_ids.append(class_id)
        individual_ids.append(individual_id)

        # NOTE: This will be computed later so we just put a placeholder
        # here for now.
        x_inds.append(0)
        y_inds.append(0)

        annotations.append(
            {
                "start_time": start_time,
                "end_time": end_time,
                "low_freq": low_freq,
                "high_freq": high_freq,
                "class_prob": 1.0,
                "det_prob": 1.0,
                "individual": "0",
                "event": event,
                "class_id": class_id,  # type: ignore
            }
        )

    return {
        "id": str(recording.path),
        "duration": recording.duration,
        "issues": False,
        "file_path": str(recording.path),
        "time_exp": recording.time_expansion,
        "class_name": get_recording_class_name(recording),
        "notes": get_annotation_notes(annotation),
        "annotated": True,
        "start_times": np.array(start_times),
        "end_times": np.array(end_times),
        "low_freqs": np.array(low_freqs),
        "high_freqs": np.array(high_freqs),
        "class_ids": np.array(class_ids),
        "x_inds": np.array(x_inds),
        "y_inds": np.array(y_inds),
        "individual_ids": np.array(individual_ids),
        "annotation": annotations,
        "class_id_file": class_id_file,
    }


class Annotation(BaseModel):
    """Annotation class to hold batdetect annotations."""

    label: str = Field(alias="class")
    event: str
    individual: int = 0

    start_time: float
    end_time: float
    low_freq: float
    high_freq: float


class FileAnnotation(BaseModel):
    """FileAnnotation class to hold batdetect annotations for a file."""

    id: str
    duration: float
    time_exp: float = 1

    label: str = Field(alias="class_name")

    annotation: List[Annotation]

    annotated: bool = False
    issues: bool = False
    notes: str = ""


def load_file_annotation(path: PathLike) -> FileAnnotation:
    """Load annotation from batdetect format."""
    path = Path(path)
    return FileAnnotation.model_validate_json(path.read_text())


def annotation_to_sound_event(
    annotation: Annotation,
    recording: data.Recording,
    label_key: str = "class",
    event_key: str = "event",
    individual_key: str = "individual",
) -> data.SoundEventAnnotation:
    """Convert annotation to sound event annotation."""
    sound_event = data.SoundEvent(
        uuid=uuid.uuid5(
            NAMESPACE,
            f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
        ),
        recording=recording,
        geometry=data.BoundingBox(
            coordinates=[
                annotation.start_time,
                annotation.low_freq,
                annotation.end_time,
                annotation.high_freq,
            ],
        ),
    )

    return data.SoundEventAnnotation(
        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
        sound_event=sound_event,
        tags=[
            data.Tag(key=label_key, value=annotation.label),
            data.Tag(key=event_key, value=annotation.event),
            data.Tag(key=individual_key, value=str(annotation.individual)),
        ],
    )


def file_annotation_to_clip(
    file_annotation: FileAnnotation,
    audio_dir: PathLike = Path.cwd(),
) -> data.Clip:
    """Convert file annotation to recording."""
    full_path = Path(audio_dir) / file_annotation.id

    if not full_path.exists():
        raise FileNotFoundError(f"File {full_path} not found.")

    recording = data.Recording.from_file(
        full_path,
        time_expansion=file_annotation.time_exp,
    )

    return data.Clip(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )


def file_annotation_to_clip_annotation(
    file_annotation: FileAnnotation,
    clip: data.Clip,
    label_key: str = "class",
    event_key: str = "event",
    individual_key: str = "individual",
) -> data.ClipAnnotation:
    """Convert file annotation to clip annotation."""
    notes = []
    if file_annotation.notes:
        notes.append(data.Note(message=file_annotation.notes))

    return data.ClipAnnotation(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
        clip=clip,
        notes=notes,
        tags=[data.Tag(key=label_key, value=file_annotation.label)],
        sound_events=[
            annotation_to_sound_event(
                annotation,
                clip.recording,
                label_key=label_key,
                event_key=event_key,
                individual_key=individual_key,
            )
            for annotation in file_annotation.annotation
        ],
    )


def file_annotation_to_annotation_task(
    file_annotation: FileAnnotation,
    clip: data.Clip,
) -> data.AnnotationTask:
    status_badges = []

    if file_annotation.issues:
        status_badges.append(
            data.StatusBadge(state=data.AnnotationState.rejected)
        )
    elif file_annotation.annotated:
        status_badges.append(
            data.StatusBadge(state=data.AnnotationState.completed)
        )

    return data.AnnotationTask(
        uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
        clip=clip,
        status_badges=status_badges,
    )


def list_file_annotations(path: PathLike) -> List[Path]:
    """List all annotations in a directory."""
    path = Path(path)
    return [file for file in path.glob("*.json")]


def load_annotation_project(
    path: PathLike,
    name: Optional[str] = None,
    audio_dir: PathLike = Path.cwd(),
) -> data.AnnotationProject:
    """Convert annotations to annotation project."""
    paths = list_file_annotations(path)

    if name is None:
        name = str(path)

    annotations = []
    tasks = []

    for p in paths:
        try:
            file_annotation = load_file_annotation(p)
        except FileNotFoundError:
            continue

        try:
            clip = file_annotation_to_clip(
                file_annotation,
                audio_dir=audio_dir,
            )
        except FileNotFoundError:
            continue

        annotations.append(
            file_annotation_to_clip_annotation(
                file_annotation,
                clip,
            )
        )

        tasks.append(
            file_annotation_to_annotation_task(
                file_annotation,
                clip,
            )
        )

    return data.AnnotationProject(
        name=name,
        clip_annotations=annotations,
        tasks=tasks,
    )
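
A minimal usage sketch for the loader above; the directory layout is hypothetical. Note that annotation files whose referenced audio is missing raise FileNotFoundError inside the loop and are silently skipped, so a misconfigured audio_dir simply yields an empty project:

from batdetect2.data.compat import load_annotation_project

project = load_annotation_project(
    "annotations/",      # directory of batdetect-format *.json files (assumed)
    name="example-project",
    audio_dir="audio/",  # where the referenced recordings live (assumed)
)
print(len(project.clip_annotations), "clips loaded")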
58  batdetect2/data/datasets.py  Normal file

@@ -0,0 +1,58 @@
from typing import Callable, Generic, Iterable, List, TypeVar

from soundevent import data
from torch.utils.data import Dataset

__all__ = [
    "ClipAnnotationDataset",
    "ClipDataset",
]


E = TypeVar("E")


class ClipAnnotationDataset(Dataset, Generic[E]):

    clip_annotations: List[data.ClipAnnotation]

    transform: Callable[[data.ClipAnnotation], E]

    def __init__(
        self,
        clip_annotations: Iterable[data.ClipAnnotation],
        transform: Callable[[data.ClipAnnotation], E],
        name: str = "ClipAnnotationDataset",
    ):
        self.clip_annotations = list(clip_annotations)
        self.transform = transform
        self.name = name

    def __len__(self) -> int:
        return len(self.clip_annotations)

    def __getitem__(self, idx: int) -> E:
        return self.transform(self.clip_annotations[idx])


class ClipDataset(Dataset, Generic[E]):

    clips: List[data.Clip]

    transform: Callable[[data.Clip], E]

    def __init__(
        self,
        clips: Iterable[data.Clip],
        transform: Callable[[data.Clip], E],
        name: str = "ClipDataset",
    ):
        self.clips = list(clips)
        self.transform = transform
        self.name = name

    def __len__(self) -> int:
        return len(self.clips)

    def __getitem__(self, idx: int) -> E:
        return self.transform(self.clips[idx])
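
Both wrappers defer all work to `transform`, so they plug straight into a standard `DataLoader`. A sketch with a trivial stand-in transform; the annotation list is assumed to come from elsewhere (e.g. `load_annotation_project` above):

from typing import List

from soundevent import data
from torch.utils.data import DataLoader

from batdetect2.data.datasets import ClipAnnotationDataset

annotations: List[data.ClipAnnotation] = []  # assumed loaded elsewhere


def count_events(clip_annotation: data.ClipAnnotation) -> int:
    # Stand-in transform; a real pipeline would return spectrogram + heatmaps.
    return len(clip_annotation.sound_events)


dataset = ClipAnnotationDataset(annotations, transform=count_events)
loader = DataLoader(dataset, batch_size=8, shuffle=True)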
231  batdetect2/data/labels.py  Normal file

@@ -0,0 +1,231 @@
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import xarray as xr
from scipy.ndimage import gaussian_filter
from soundevent import data, geometry

__all__ = [
    "generate_heatmaps",
]

PositionFn = Callable[[data.SoundEvent], Tuple[float, float]]
"""Convert a sound event to a single position in time-frequency space."""

SizeFn = Callable[[data.SoundEvent, float, float], np.ndarray]
"""Compute the size of a sound event in time-frequency space.

The time and frequency scales are provided as arguments to allow
modifying the size of the sound event based on the spectrogram
parameters.
"""

LabelFn = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Convert a sound event annotation to a label.

When the label is None, this indicates that the sound event does not
belong to any of the classes of interest.
"""

TARGET_SIGMA = 3.0


GENERIC_LABEL = "__UNKNOWN__"


def get_lower_left_position(
    sound_event: data.SoundEvent,
) -> Tuple[float, float]:
    if sound_event.geometry is None:
        raise ValueError("Sound event has no geometry.")

    start_time, low_freq, _, _ = geometry.compute_bounds(sound_event.geometry)
    return start_time, low_freq


def get_bbox_size(
    sound_event: data.SoundEvent,
    time_scale: float = 1.0,
    frequency_scale: float = 1.0,
) -> np.ndarray:
    if sound_event.geometry is None:
        raise ValueError("Sound event has no geometry.")

    start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
        sound_event.geometry
    )

    return np.array(
        [
            time_scale * (end_time - start_time),
            frequency_scale * (high_freq - low_freq),
        ]
    )


def _tag_key(tag: data.Tag) -> Tuple[str, str]:
    return (tag.key, tag.value)


def set_value_at_position(
    array: xr.DataArray,
    value: Any,
    **query,
) -> xr.DataArray:
    dims = {dim: n for n, dim in enumerate(array.dims)}
    indexer: List[Union[slice, int]] = [slice(None) for _ in range(array.ndim)]

    for key, coord in query.items():
        if key not in dims:
            raise ValueError(f"Dimension {key} not found in array.")

        coordinates = array.indexes[key]
        indexer[dims[key]] = coordinates.get_loc(coordinates.asof(coord))

    if isinstance(value, (tuple, list)):
        value = np.array(value)

    array.data[tuple(indexer)] = value
    return array


def generate_heatmaps(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
    num_classes: int = 1,
    label_fn: LabelFn = lambda _: None,
    target_sigma: float = TARGET_SIGMA,
    size_fn: SizeFn = get_bbox_size,
    position_fn: PositionFn = get_lower_left_position,
    class_labels: Optional[List[str]] = None,
    dtype=np.float32,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
    if class_labels is None:
        class_labels = [str(i) for i in range(num_classes)]

    if len(class_labels) != num_classes:
        raise ValueError(
            "Number of class labels must match the number of classes."
        )

    shape = dict(zip(spec.dims, spec.shape))

    if "time" not in shape or "frequency" not in shape:
        raise ValueError(
            "Spectrogram must have time and frequency dimensions."
        )

    time_duration = spec.time.attrs["max"] - spec.time.attrs["min"]
    freq_bandwidth = spec.frequency.attrs["max"] - spec.frequency.attrs["min"]

    # Compute the size factors
    time_scale = 1 / time_duration
    frequency_scale = 1 / freq_bandwidth

    # Initialize heatmaps
    detection_heatmap = xr.zeros_like(spec, dtype=dtype)
    class_heatmap = xr.DataArray(
        data=np.zeros((num_classes, *spec.shape), dtype=dtype),
        dims=["category", *spec.dims],
        coords={
            "category": class_labels,
            **spec.coords,
        },
    )
    size_heatmap = xr.DataArray(
        data=np.zeros((2, *spec.shape), dtype=dtype),
        dims=["dimension", *spec.dims],
        coords={
            "dimension": ["width", "height"],
            **spec.coords,
        },
    )

    for sound_event_annotation in clip_annotation.sound_events:
        # Get the position of the sound event
        time, frequency = position_fn(sound_event_annotation.sound_event)

        # Set 1.0 at the position of the sound event in the detection heatmap
        detection_heatmap = set_value_at_position(
            detection_heatmap,
            1.0,
            time=time,
            frequency=frequency,
        )

        # Set the size of the sound event at the position in the size heatmap
        size = size_fn(
            sound_event_annotation.sound_event,
            time_scale,
            frequency_scale,
        )
        size_heatmap = set_value_at_position(
            size_heatmap,
            size,
            time=time,
            frequency=frequency,
        )

        # Get the label id for the sound event
        label = label_fn(sound_event_annotation)

        if label is None or label not in class_labels:
            # If the label is None or not in the class labels, we skip the
            # sound event
            continue

        # Set 1.0 at the position and category of the sound event in the class
        # heatmap
        class_heatmap = set_value_at_position(
            class_heatmap,
            1.0,
            time=time,
            frequency=frequency,
            category=label,
        )

    # Apply gaussian filters
    detection_heatmap = xr.apply_ufunc(
        gaussian_filter,
        detection_heatmap,
        target_sigma,
    )

    class_heatmap = class_heatmap.groupby("category").map(
        gaussian_filter,  # type: ignore
        args=(target_sigma,),
    )

    # Normalize heatmaps
    detection_heatmap = (
        detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
    ).fillna(0.0)

    class_heatmap = (
        class_heatmap / class_heatmap.max(dim=["time", "frequency"])
    ).fillna(0.0)

    return detection_heatmap, class_heatmap, size_heatmap


class Labeler:
    def __init__(self, tags: List[data.Tag]):
        """Create a labeler from a list of tags.

        Each tag is assigned a unique label. The labeler can then be used
        to convert sound event annotations to labels.
        """
        self.tags = tags
        self._label_map = {_tag_key(tag): i for i, tag in enumerate(tags)}
        self._inverse_label_map = {v: k for k, v in self._label_map.items()}

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> Optional[int]:
        for tag in sound_event_annotation.tags:
            key = _tag_key(tag)
            if key in self._label_map:
                return self._label_map[key]

        return None
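
A sketch of calling `generate_heatmaps` directly. The species names are made up, and `spec` is assumed to be a spectrogram whose `time`/`frequency` coordinates carry the `min`/`max` attributes that the preprocessing module below attaches. (As written, `Labeler` returns integer ids while `label_fn` is expected to return string labels, so it is not a drop-in here.)

from batdetect2.data.labels import generate_heatmaps

class_labels = ["Myotis daubentonii", "Pipistrellus pipistrellus"]  # hypothetical

detection, classes, size = generate_heatmaps(
    clip_annotation,  # assumed: a data.ClipAnnotation
    spec,             # assumed: spectrogram DataArray (see preprocessing below)
    num_classes=len(class_labels),
    class_labels=class_labels,
    # Read the species tag off each sound event; None means "not a class of interest".
    label_fn=lambda ann: next(
        (tag.value for tag in ann.tags if tag.key == "species"), None
    ),
)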
586  batdetect2/data/preprocessing.py  Normal file

@@ -0,0 +1,586 @@
"""Module containing functions for preprocessing audio clips."""

import random
from typing import List, Optional, Tuple

import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from scipy.signal import resample_poly
from soundevent import audio, data

TARGET_SAMPLERATE_HZ = 256000
SCALE_RAW_AUDIO = False
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
DEFAULT_DURATION = 1
SPEC_HEIGHT = 128
SPEC_WIDTH = 256
SPEC_SCALE = "pcen"
SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False


def preprocess_audio_clip(
    clip: data.Clip,
    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
    scale_audio: bool = SCALE_RAW_AUDIO,
    fft_win_length: float = FFT_WIN_LENGTH_S,
    fft_overlap: float = FFT_OVERLAP,
    max_freq: int = MAX_FREQ_HZ,
    min_freq: int = MIN_FREQ_HZ,
    spec_scale: str = SPEC_SCALE,
    denoise_spec_avg: bool = True,
    max_scale_spec: bool = False,
    duration: Optional[float] = DEFAULT_DURATION,
    spec_height: int = SPEC_HEIGHT,
    spec_time_period: float = SPEC_TIME_PERIOD,
) -> xr.DataArray:
    """Preprocesses audio clip to generate spectrogram.

    Parameters
    ----------
    clip
        The audio clip to preprocess.
    target_sampling_rate
        Target sampling rate for the audio. If the audio has a different
        sampling rate, it will be resampled to this rate.
    scale_audio
        Whether to scale the audio amplitudes to a range of [-1, 1].
        By default, the audio is not scaled.
    fft_win_length
        Length of the FFT window in seconds.
    fft_overlap
        Amount of overlap between FFT windows as a fraction of the window
        length.
    max_freq
        Maximum frequency for spectrogram. Anything above this frequency will
        be cropped.
    min_freq
        Minimum frequency for spectrogram. Anything below this frequency will
        be cropped.
    spec_scale
        Scaling method for the spectrogram. Can be "pcen", "log" or
        "amplitude".
    denoise_spec_avg
        Whether to denoise the spectrogram. Denoising is done by subtracting
        the average of the spectrogram from the spectrogram and clipping
        negative values to 0.
    max_scale_spec
        Whether to max scale the spectrogram. Max scaling is done by dividing
        the spectrogram by its maximum value thus scaling values to [0, 1].
    duration
        Duration of the spectrogram in seconds. If the clip duration is
        different from this value, the spectrogram will be cropped or extended
        to match this duration. If None, the spectrogram will have the same
        duration as the clip.
    spec_height
        Number of frequency bins for the spectrogram. This is the height of
        the final spectrogram.
    spec_time_period
        Time period for each spectrogram bin in seconds. The spectrogram array
        will be resized (using bilinear interpolation) to have this time
        period.

    Returns
    -------
    xr.DataArray
        Preprocessed spectrogram.

    """
    wav = load_clip_audio(
        clip,
        target_sampling_rate=target_sampling_rate,
        scale=scale_audio,
    )

    wav = wav.assign_attrs(
        recording_id=str(wav.attrs["recording_id"]),
        clip_id=str(wav.attrs["clip_id"]),
        path=str(wav.attrs["path"]),
    )

    spec = compute_spectrogram(
        wav,
        fft_win_length=fft_win_length,
        fft_overlap=fft_overlap,
        max_freq=max_freq,
        min_freq=min_freq,
        spec_scale=spec_scale,
        denoise_spec_avg=denoise_spec_avg,
        max_scale_spec=max_scale_spec,
    )

    if duration is not None:
        spec = adjust_spec_duration(clip, spec, duration)

    duration = get_dim_width(spec, dim="time")
    return resize_spectrogram(
        spec,
        time_bins=int(np.ceil(duration / spec_time_period)),
        freq_bins=spec_height,
    )


def adjust_spec_duration(
    clip: data.Clip,
    spec: xr.DataArray,
    duration: float,
) -> xr.DataArray:
    current_duration = clip.end_time - clip.start_time

    if current_duration == duration:
        return spec

    if current_duration > duration:
        return crop_axis(
            spec,
            dim="time",
            start=clip.start_time,
            end=clip.start_time + duration,
        )

    return extend_axis(
        spec,
        dim="time",
        start=clip.start_time,
        end=clip.start_time + duration,
    )


def load_clip_audio(
    clip: data.Clip,
    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = SCALE_RAW_AUDIO,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    wav = audio.load_clip(clip).sel(channel=0)

    wav = resample_audio(wav, target_sampling_rate, dtype=dtype)

    if scale:
        wav = scale_audio(wav)

    wav.coords["time"] = wav.time.assign_attrs(
        unit="s",
        long_name="Seconds since start of recording",
        min=clip.start_time,
        max=clip.end_time,
    )

    return wav


def resample_audio(
    wav: xr.DataArray,
    target_samplerate: int = TARGET_SAMPLERATE_HZ,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if "samplerate" not in wav.attrs:
        raise ValueError("Audio must have a 'samplerate' attribute")

    if "time" not in wav.dims:
        raise ValueError("Audio must have a time dimension")

    time_axis: int = wav.get_axis_num("time")  # type: ignore
    original_samplerate = wav.attrs["samplerate"]

    if original_samplerate == target_samplerate:
        return wav.astype(dtype)

    gcd = np.gcd(original_samplerate, target_samplerate)
    resampled = resample_poly(
        wav.values,
        target_samplerate // gcd,
        original_samplerate // gcd,
        axis=time_axis,
    )

    resampled_times = np.linspace(
        wav.time[0],
        wav.time[-1],
        len(resampled),
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=resampled.astype(dtype),
        dims=wav.dims,
        coords={
            **wav.coords,
            "time": resampled_times,
        },
        attrs={
            **wav.attrs,
            "samplerate": target_samplerate,
        },
    )


def scale_audio(
    audio: xr.DataArray,
    eps: float = 10e-6,
) -> xr.DataArray:
    audio = audio - audio.mean()
    return audio / np.add(np.abs(audio).max(), eps, dtype=audio.dtype)


def compute_spectrogram(
    wav: xr.DataArray,
    fft_win_length: float = FFT_WIN_LENGTH_S,
    fft_overlap: float = FFT_OVERLAP,
    max_freq: int = MAX_FREQ_HZ,
    min_freq: int = MIN_FREQ_HZ,
    spec_scale: str = SPEC_SCALE,
    denoise_spec_avg: bool = True,
    max_scale_spec: bool = False,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    spec = gen_mag_spectrogram(
        wav,
        window_len=fft_win_length,
        overlap_perc=fft_overlap,
        dtype=dtype,
    )

    spec = crop_axis(
        spec,
        dim="frequency",
        start=min_freq,
        end=max_freq,
    )

    spec = scale_spectrogram(spec, scale=spec_scale)

    if denoise_spec_avg:
        spec = denoise_spectrogram(spec)

    if max_scale_spec:
        spec = max_scale_spectrogram(spec)

    return spec


def crop_axis(
    arr: xr.DataArray,
    dim: str,
    start: float,
    end: float,
    right_closed: bool = False,
    left_closed: bool = True,
    eps: float = 10e-6,
) -> xr.DataArray:
    coord = arr.coords[dim]

    if not all(attr in coord.attrs for attr in ["min", "max"]):
        raise ValueError(
            f"Coordinate '{dim}' must have 'min' and 'max' attributes"
        )

    current_min = coord.attrs["min"]
    current_max = coord.attrs["max"]

    if start < current_min or end > current_max:
        raise ValueError(
            f"Cannot select axis '{dim}' from {start} to {end}. "
            f"Axis range is {current_min} to {current_max}"
        )

    slice_end = end
    if not right_closed:
        slice_end = end - eps

    slice_start = start
    if not left_closed:
        slice_start = start + eps

    arr = arr.sel({dim: slice(slice_start, slice_end)})

    arr.coords[dim].attrs.update(
        min=start,
        max=end,
    )

    return arr


def extend_axis(
    arr: xr.DataArray,
    dim: str,
    start: float,
    end: float,
    fill_value: float = 0,
) -> xr.DataArray:
    coord = arr.coords[dim]

    if not all(attr in coord.attrs for attr in ["min", "max", "period"]):
        raise ValueError(
            f"Coordinate '{dim}' must have 'min', 'max' and 'period' attributes"
            " to extend axis"
        )

    current_min = coord.attrs["min"]
    current_max = coord.attrs["max"]
    period = coord.attrs["period"]

    coords = coord.data

    if start < current_min:
        new_coords = np.arange(
            current_min,
            start,
            -period,
            dtype=coord.dtype,
        )[1:][::-1]
        coords = np.concatenate([new_coords, coords])

    if end > current_max:
        new_coords = np.arange(
            current_max,
            end,
            period,
            dtype=coord.dtype,
        )[1:]
        coords = np.concatenate([coords, new_coords])

    arr = arr.reindex(
        {dim: coords},
        fill_value=fill_value,  # type: ignore
    )

    arr.coords[dim].attrs.update(
        min=start,
        max=end,
    )

    return arr


def gen_mag_spectrogram(
    audio: xr.DataArray,
    window_len: float,
    overlap_perc: float,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    sampling_rate = audio.attrs["samplerate"]
    hop_len = window_len * (1 - overlap_perc)
    nfft = int(window_len * sampling_rate)
    noverlap = int(overlap_perc * nfft)
    start_time = audio.time.attrs["min"]
    end_time = audio.time.attrs["max"]

    # compute spec
    spec, _ = librosa.core.spectrum._spectrogram(
        y=audio.data,
        power=1,
        n_fft=nfft,
        hop_length=nfft - noverlap,
        center=False,
    )

    spec = xr.DataArray(
        data=spec.astype(dtype),
        dims=["frequency", "time"],
        coords={
            "frequency": np.linspace(
                0,
                sampling_rate / 2,
                spec.shape[0],
                endpoint=False,
                dtype=dtype,
            ),
            "time": np.linspace(
                start_time,
                end_time - (window_len - hop_len),
                spec.shape[1],
                endpoint=False,
                dtype=dtype,
            ),
        },
        attrs={
            **audio.attrs,
            "nfft": nfft,
            "noverlap": noverlap,
        },
    )

    # Add metadata to coordinates
    spec.coords["time"].attrs.update(
        unit="s",
        long_name="Time",
        min=start_time,
        max=end_time - (window_len - hop_len),
        period=(nfft - noverlap) / sampling_rate,
    )
    spec.coords["frequency"].attrs.update(
        unit="Hz",
        long_name="Frequency",
        period=(sampling_rate / nfft),
        min=0,
        max=sampling_rate / 2,
    )

    return spec


def denoise_spectrogram(
    spec: xr.DataArray,
) -> xr.DataArray:
    return xr.DataArray(
        data=(spec - spec.mean("time")).clip(0),
        dims=spec.dims,
        coords=spec.coords,
        attrs={
            **spec.attrs,
            "denoised": 1,
        },
    )


def scale_spectrogram(
    spec: xr.DataArray,
    scale: str = SPEC_SCALE,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if scale == "pcen":
        return pcen(spec, dtype=dtype)

    if scale == "log":
        return log_scale(spec, dtype=dtype)

    return spec


def log_scale(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    nfft = spec.attrs["nfft"]
    sampling_rate = spec.attrs["samplerate"]
    log_scaling = (
        2.0
        * (1.0 / sampling_rate)
        * (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
    )
    return xr.DataArray(
        data=np.log1p(log_scaling * spec).astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs={
            **spec.attrs,
            "scale": "log",
        },
    )


def pcen(spec: xr.DataArray, dtype: DTypeLike = np.float32) -> xr.DataArray:
    sampling_rate = spec.attrs["samplerate"]
    data = librosa.pcen(
        spec.data * (2**31),
        sr=sampling_rate / 10,
    )
    return xr.DataArray(
        data=data.astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs={
            **spec.attrs,
            "scale": "pcen",
        },
    )


def max_scale_spectrogram(spec: xr.DataArray, eps=10e-6) -> xr.DataArray:
    return xr.DataArray(
        data=spec / np.add(spec.max(), eps, dtype=spec.dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs={
            **spec.attrs,
            "max_scaled": 1,
        },
    )


def resize_spectrogram(
    spec: xr.DataArray,
    time_bins: int,
    freq_bins: int,
) -> xr.DataArray:
    new_times = np.linspace(
        spec.time[0],
        spec.time[-1],
        time_bins,
        dtype=spec.time.dtype,
        endpoint=True,
    )
    new_frequencies = np.linspace(
        spec.frequency[0],
        spec.frequency[-1],
        freq_bins,
        dtype=spec.frequency.dtype,
        endpoint=True,
    )

    return spec.interp(
        coords=dict(
            time=new_times,
            frequency=new_frequencies,
        ),
        method="linear",
    )


def get_dim_width(arr: xr.DataArray, dim: str) -> float:
    coord = arr.coords[dim]
    attrs = coord.attrs
    if "min" in attrs and "max" in attrs:
        return attrs["max"] - attrs["min"]

    coord_min = coord.min()
    coord_max = coord.max()
    return float(coord_max - coord_min)


class RandomClipProvider:
    def __init__(
        self,
        clip_annotations: List[data.ClipAnnotation],
        target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
        scale_audio: bool = SCALE_RAW_AUDIO,
    ):
        self.target_sampling_rate = target_sampling_rate
        self.scale_audio = scale_audio
        self.clip_annotations = clip_annotations

    def get_next_clip(self, clip: data.ClipAnnotation) -> data.ClipAnnotation:
        tries = 0
        while True:
            random_clip = random.choice(self.clip_annotations)

            if random_clip.clip != clip.clip:
                return random_clip

            tries += 1
            if tries > 4:
                raise ValueError("Could not find a different clip")

    def __call__(
        self,
        clip: data.ClipAnnotation,
    ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
        random_clip = self.get_next_clip(clip)

        wav = load_clip_audio(
            random_clip.clip,
            target_sampling_rate=self.target_sampling_rate,
            scale=self.scale_audio,
        )

        return wav, random_clip
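
End to end, the defaults above resample to 256 kHz, take a magnitude STFT, crop to 10-120 kHz, apply PCEN, denoise, and resize to 128 frequency bins at roughly 256 time bins per second. A sketch with a hypothetical file path:

from soundevent import data

from batdetect2.data.preprocessing import preprocess_audio_clip

recording = data.Recording.from_file("example.wav")  # path assumed
clip = data.Clip(recording=recording, start_time=0.0, end_time=1.0)

spec = preprocess_audio_clip(clip)
# With the defaults: frequency size == SPEC_HEIGHT (128) and
# time size == ceil(duration / SPEC_TIME_PERIOD) == 256 for a 1 s clip.
print(dict(spec.sizes))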
@@ -53,7 +53,13 @@ class SelfAttention(nn.Module):
 
 class ConvBlockDownCoordF(nn.Module):
     def __init__(
-        self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
+        self,
+        in_chn,
+        out_chn,
+        ip_height,
+        k_size=3,
+        pad_size=1,
+        stride=1,
     ):
         super(ConvBlockDownCoordF, self).__init__()
         self.coords = nn.Parameter(
@@ -1,5 +1,4 @@
 import torch
-import torch.fft
 import torch.nn.functional as F
 from torch import nn
 
@@ -207,7 +206,7 @@ class Net2DFastNoAttn(nn.Module):
             num_filts // 4, 2, kernel_size=1, padding=0
         )
         self.conv_classes_op = nn.Conv2d(
-            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
+            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
         )
 
         if self.emb_dim > 0:
@@ -28,6 +28,7 @@ MAX_SCALE_SPEC = False
 DEFAULT_MODEL_PATH = os.path.join(
     os.path.dirname(os.path.dirname(__file__)),
     "models",
+    "checkpoints",
     "Net2DFast_UK_same.pth.tar",
 )
 
@@ -68,6 +68,7 @@ def run_nms(
         params["fft_win_length"],
         params["fft_overlap"],
     )
+    print("duration", duration)
     top_k = int(duration * params["nms_top_k_per_sec"])
     scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)
 
@@ -7,16 +7,16 @@ import copy
 import json
 import os
 
-import torch
 import numpy as np
 import pandas as pd
+import torch
 from sklearn.ensemble import RandomForestClassifier
 
-from batdetect2.detector import parameters
 import batdetect2.train.evaluate as evl
 import batdetect2.train.train_utils as tu
 import batdetect2.utils.detector_utils as du
 import batdetect2.utils.plot_utils as pu
+from batdetect2.detector import parameters
 
 
 def get_blank_annotation(ip_str):
@@ -330,7 +330,8 @@ def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
     for dd in datasets:
         print("\n" + dd["dataset_name"])
         gt_dataset = tu.load_set_of_anns(
-            [dd], events_of_interest=events_of_interest, verbose=True
+            [dd],
+            events_of_interest=events_of_interest,
         )
         gt_dataset = [
             parse_data(gg, class_names, classes_to_ignore, False)
@@ -553,7 +554,9 @@ if __name__ == "__main__":
     test_dict["dataset_name"] = args["test_file"].replace(".json", "")
     test_dict["is_test"] = True
     test_dict["is_binary"] = True
-    test_dict["ann_path"] = os.path.join(args["ann_dir"], args["test_file"])
+    test_dict["ann_path"] = os.path.join(
+        args["ann_dir"], args["test_file"]
+    )
     test_dict["wav_path"] = args["data_dir"]
     test_sets = [test_dict]

Binary file not shown.
91
batdetect2/models/__init__.py
Normal file
@@ -0,0 +1,91 @@
import os
from typing import Tuple, Union

import torch

from batdetect2.models.encoders import (
    Net2DFast,
    Net2DFastNoAttn,
    Net2DFastNoCoordConv,
)
from batdetect2.models.typing import DetectionModel

__all__ = [
    "load_model",
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


DEFAULT_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "models",
    "checkpoints",
    "Net2DFast_UK_same.pth.tar",
)


def load_model(
    model_path: str = DEFAULT_MODEL_PATH,
    load_weights: bool = True,
    device: Union[torch.device, str, None] = None,
) -> Tuple[DetectionModel, dict]:
    """Load model from file.

    Args:
        model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
        load_weights (bool, optional): Load weights. Defaults to True.

    Returns:
        model, params: Model and parameters.

    Raises:
        FileNotFoundError: Model file not found.
        ValueError: Unknown model name.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not os.path.isfile(model_path):
        raise FileNotFoundError("Model file not found.")

    net_params = torch.load(model_path, map_location=device)

    params = net_params["params"]

    model: DetectionModel

    if params["model_name"] == "Net2DFast":
        model = Net2DFast(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoAttn":
        model = Net2DFastNoAttn(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoCoordConv":
        model = Net2DFastNoCoordConv(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    else:
        raise ValueError("Unknown model.")

    if load_weights:
        model.load_state_dict(net_params["state_dict"])

    model = model.to(device)
    model.eval()

    return model, params
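A minimal usage sketch of `load_model` above (an editor's illustration, not part of the commit; it assumes the bundled `Net2DFast_UK_same.pth.tar` checkpoint is present at `DEFAULT_MODEL_PATH`):

import torch

from batdetect2.models import load_model

# Force CPU so the sketch runs anywhere; omit `device` to auto-select CUDA.
model, params = load_model(device=torch.device("cpu"))
print(params["model_name"], params["class_names"])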
219
batdetect2/models/blocks.py
Normal file
@@ -0,0 +1,219 @@
"""Module containing custom NN blocks.

All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn

__all__ = [
    "SelfAttention",
    "ConvBlockDownCoordF",
    "ConvBlockDownStandard",
    "ConvBlockUpF",
    "ConvBlockUpStandard",
]


class SelfAttention(nn.Module):
    """Self-Attention module.

    This module implements self-attention mechanism.
    """

    def __init__(self, ip_dim: int, att_dim: int):
        super().__init__()

        # Note, does not encode position information (absolute or relative)
        self.temperature = 1.0
        self.att_dim = att_dim
        self.key_fun = nn.Linear(ip_dim, att_dim)
        self.val_fun = nn.Linear(ip_dim, att_dim)
        self.que_fun = nn.Linear(ip_dim, att_dim)
        self.pro_fun = nn.Linear(att_dim, ip_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(2).permute(0, 2, 1)

        key = torch.matmul(
            x, self.key_fun.weight.T
        ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
        query = torch.matmul(
            x, self.que_fun.weight.T
        ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
        value = torch.matmul(
            x, self.val_fun.weight.T
        ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)

        kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
            self.temperature * self.att_dim
        )
        att_weights = F.softmax(kk_qq, 1)
        att = torch.bmm(value.permute(0, 2, 1), att_weights)

        op = torch.matmul(
            att.permute(0, 2, 1), self.pro_fun.weight.T
        ) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
        op = op.permute(0, 2, 1).unsqueeze(2)

        return op


class ConvBlockDownCoordF(nn.Module):
    """Convolutional Block with Downsampling and Coord Feature.

    This block performs convolution followed by downsampling
    and concatenates with coordinate information.
    """

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        ip_height: int,
        k_size: int = 3,
        pad_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()

        self.coords = nn.Parameter(
            torch.linspace(-1, 1, ip_height)[None, None, ..., None],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_chn + 1,
            out_chn,
            kernel_size=k_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
        x = torch.cat((x, freq_info), 1)
        x = F.max_pool2d(self.conv(x), 2, 2)
        x = F.relu(self.conv_bn(x), inplace=True)
        return x


class ConvBlockDownStandard(nn.Module):
    """Convolutional Block with Downsampling.

    This block performs convolution followed by downsampling.
    """

    def __init__(
        self,
        in_chn,
        out_chn,
        k_size=3,
        pad_size=1,
        stride=1,
    ):
        super(ConvBlockDownStandard, self).__init__()
        self.conv = nn.Conv2d(
            in_chn,
            out_chn,
            kernel_size=k_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x):
        x = F.max_pool2d(self.conv(x), 2, 2)
        x = F.relu(self.conv_bn(x), inplace=True)
        return x


class ConvBlockUpF(nn.Module):
    """Convolutional Block with Upsampling and Coord Feature.

    This block performs convolution followed by upsampling
    and concatenates with coordinate information.
    """

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        ip_height: int,
        k_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
    ):
        super().__init__()

        self.up_scale = up_scale
        self.up_mode = up_mode
        self.coords = nn.Parameter(
            torch.linspace(-1, 1, ip_height * up_scale[0])[
                None, None, ..., None
            ],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
            x,
            size=(
                x.shape[-2] * self.up_scale[0],
                x.shape[-1] * self.up_scale[1],
            ),
            mode=self.up_mode,
            align_corners=False,
        )
        freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
        op = torch.cat((op, freq_info), 1)
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        return op


class ConvBlockUpStandard(nn.Module):
    """Convolutional Block with Upsampling.

    This block performs convolution followed by upsampling.
    """

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        k_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
    ):
        super(ConvBlockUpStandard, self).__init__()
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.conv = nn.Conv2d(
            in_chn, out_chn, kernel_size=k_size, padding=pad_size
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
            x,
            size=(
                x.shape[-2] * self.up_scale[0],
                x.shape[-1] * self.up_scale[1],
            ),
            mode=self.up_mode,
            align_corners=False,
        )
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        return op
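A quick shape sketch for the blocks above (illustrative only, not part of the commit; tensors are laid out as batch x channel x frequency x time, and the shapes follow from the definitions):

import torch

from batdetect2.models.blocks import ConvBlockDownCoordF, ConvBlockUpF

x = torch.rand(2, 1, 128, 256)

# Stride-1 conv plus 2x2 max-pooling halves both spatial axes; a frequency
# coordinate channel is concatenated before the conv.
down = ConvBlockDownCoordF(1, 8, ip_height=128)
print(down(x).shape)  # torch.Size([2, 8, 64, 128])

# Bilinear upsampling restores both axes.
up = ConvBlockUpF(8, 4, ip_height=64)
print(up(down(x)).shape)  # torch.Size([2, 4, 128, 256])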
148
batdetect2/models/detectors.py
Normal file
@@ -0,0 +1,148 @@
import pytorch_lightning as L
import torch
import xarray as xr
from soundevent import data
from torch import nn, optim

from batdetect2.data.preprocessing import preprocess_audio_clip
from batdetect2.models.typing import EncoderModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
from batdetect2.models.post_process import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.train.preprocess import PreprocessingConfig


class DetectorModel(L.LightningModule):
    def __init__(
        self,
        encoder: EncoderModel,
        num_classes: int,
        learning_rate: float = 1e-3,
        preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
        postprocessing_config: PostprocessConfig = PostprocessConfig(),
    ):
        super().__init__()

        self.preprocessing_config = preprocessing_config
        self.postprocessing_config = postprocessing_config
        self.num_classes = num_classes
        self.learning_rate = learning_rate

        self.encoder = encoder

        self.classifier = nn.Conv2d(
            self.encoder.num_filts // 4,
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
        )

        self.bbox = nn.Conv2d(
            self.encoder.num_filts // 4,
            2,
            kernel_size=1,
            padding=0,
        )

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
        features = self.encoder(spec)

        classification_logits = self.classifier(features)
        classification_probs = torch.softmax(classification_logits, dim=1)
        detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)

        return ModelOutput(
            detection_probs=detection_probs,
            size_preds=self.bbox(features),
            class_probs=classification_probs,
            features=features,
        )

    def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
        config = self.preprocessing_config

        return preprocess_audio_clip(
            clip,
            target_sampling_rate=config.target_samplerate,
            scale_audio=config.scale_audio,
            fft_win_length=config.fft_win_length,
            fft_overlap=config.fft_overlap,
            max_freq=config.max_freq,
            min_freq=config.min_freq,
            spec_scale=config.spec_scale,
            denoise_spec_avg=config.denoise_spec_avg,
            max_scale_spec=config.max_scale_spec,
        )

    def process_clip(self, clip: data.Clip):
        spectrogram = self.compute_spectrogram(clip)
        spec_tensor = (
            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
        )

        outputs = self(spec_tensor)

        config = self.postprocessing_config
        return postprocess_model_outputs(
            outputs,
            [clip],
            nms_kernel_size=config.nms_kernel_size,
            detection_threshold=config.detection_threshold,
            min_freq=config.min_freq,
            max_freq=config.max_freq,
            top_k_per_sec=config.top_k_per_sec,
        )

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> torch.Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(  # type: ignore
        self,
        batch: TrainExample,
    ):
        features = self.encoder(batch.spec)

        classification_logits = self.classifier(features)
        classification_probs = torch.softmax(classification_logits, dim=1)
        detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)

        loss = self.compute_loss(
            ModelOutput(
                detection_probs=detection_probs,
                size_preds=self.bbox(features),
                class_probs=classification_probs,
                features=features,
            ),
            batch,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
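A minimal sketch wiring `DetectorModel` to an encoder (illustrative only, not part of the commit; the encoders are defined in the next file, and the training pipeline feeding this module is still WIP):

import torch

from batdetect2.models.detectors import DetectorModel
from batdetect2.models.encoders import Net2DFast

detector = DetectorModel(encoder=Net2DFast(32, input_height=128), num_classes=2)

spec = torch.rand(1, 1, 128, 512)  # (batch, channel, freq_bins, time_bins)
outputs = detector(spec)
print(outputs.detection_probs.shape)  # torch.Size([1, 1, 128, 512])
print(outputs.class_probs.shape)      # torch.Size([1, 3, 128, 512]), num_classes + 1 channels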
319
batdetect2/models/encoders.py
Normal file
@@ -0,0 +1,319 @@
from typing import Tuple

import torch
import torch.fft
import torch.nn.functional as F
from torch import nn

from batdetect2.models.typing import EncoderModel
from batdetect2.models.blocks import (
    ConvBlockDownCoordF,
    ConvBlockDownStandard,
    ConvBlockUpF,
    ConvBlockUpStandard,
    SelfAttention,
)

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DFast(EncoderModel):
    def __init__(
        self,
        num_filts: int,
        input_height: int = 128,
    ):
        super().__init__()
        self.num_filts = num_filts
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_filts // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_filts // 4,
            self.num_filts // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_filts // 2,
            self.num_filts,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_filts,
            self.num_filts * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)

        # bottleneck
        self.conv_1d = nn.Conv2d(
            self.num_filts * 2,
            self.num_filts * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
        self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(
            self.num_filts * 2,
            self.num_filts // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_filts // 2,
            self.num_filts // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_filts // 4,
            self.num_filts // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_filts // 4,
            self.num_filts // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)

    def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        h, w = spec.shape[2:]
        h_pad = (32 - h % 32) % 32
        w_pad = (32 - w % 32) % 32
        return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # encoder
        spec, h_pad, w_pad = self.pad_adjust(spec)

        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        # bottleneck
        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        # decoder
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # Restore original size
        if h_pad > 0:
            x = x[:, :, :-h_pad, :]

        if w_pad > 0:
            x = x[:, :, :, :-w_pad]

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoAttn(EncoderModel):
    def __init__(
        self,
        num_filts: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_filts = num_filts
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_filts // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_filts // 4,
            self.num_filts // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_filts // 2,
            self.num_filts,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_filts,
            self.num_filts * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)

        self.conv_1d = nn.Conv2d(
            self.num_filts * 2,
            self.num_filts * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)

        self.conv_up_2 = ConvBlockUpF(
            self.num_filts * 2,
            self.num_filts // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_filts // 2,
            self.num_filts // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_filts // 4,
            self.num_filts // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_filts // 4,
            self.num_filts // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoCoordConv(EncoderModel):
    def __init__(
        self,
        num_filts: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_filts = num_filts
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownStandard(
            1,
            self.num_filts // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownStandard(
            self.num_filts // 4,
            self.num_filts // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownStandard(
            self.num_filts // 2,
            self.num_filts,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_filts,
            self.num_filts * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)

        self.conv_1d = nn.Conv2d(
            self.num_filts * 2,
            self.num_filts * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)

        self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)

        self.conv_up_2 = ConvBlockUpStandard(
            self.num_filts * 2,
            self.num_filts // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpStandard(
            self.num_filts // 2,
            self.num_filts // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpStandard(
            self.num_filts // 4,
            self.num_filts // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_filts // 4,
            self.num_filts // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))
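Note that `Net2DFast` only requires internal sizes to be multiples of 32: `pad_adjust` pads the input up and the forward pass crops back. A small illustrative check (not part of the commit):

import torch

from batdetect2.models.encoders import Net2DFast

net = Net2DFast(32, input_height=128)
spec = torch.rand(1, 1, 128, 500)  # 500 time bins, not a multiple of 32

features = net(spec)
print(features.shape)  # torch.Size([1, 8, 128, 500]), original size restored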
310
batdetect2/models/post_process.py
Normal file
@@ -0,0 +1,310 @@
"""Module for postprocessing model outputs."""

from typing import Callable, List, Tuple, Union
from pydantic import BaseModel, Field

import numpy as np
import torch
from soundevent import data
from torch import nn

from batdetect2.models.typing import ModelOutput

__all__ = [
    "postprocess_model_outputs",
    "PostprocessConfig",
]

NMS_KERNEL_SIZE = 9
DETECTION_THRESHOLD = 0.01
TOP_K_PER_SEC = 200


class PostprocessConfig(BaseModel):
    """Configuration for postprocessing model outputs."""

    nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
    detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
    min_freq: int = Field(default=10000, gt=0)
    max_freq: int = Field(default=120000, gt=0)
    top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)


TagFunction = Callable[[int], List[data.Tag]]


def postprocess_model_outputs(
    outputs: ModelOutput,
    clips: List[data.Clip],
    nms_kernel_size: int = NMS_KERNEL_SIZE,
    detection_threshold: float = DETECTION_THRESHOLD,
    min_freq: int = 10000,
    max_freq: int = 120000,
    top_k_per_sec: int = TOP_K_PER_SEC,
) -> List[data.ClipPrediction]:
    """Postprocess model outputs to generate clip predictions.

    This function takes the output from the model, applies non-maximum
    suppression, selects the top-k scores, computes sound events from the
    outputs, and returns clip predictions based on these processed outputs.

    Parameters
    ----------
    outputs
        Output from the model containing detection probabilities, size
        predictions, class logits, and features. All tensors are expected
        to have a batch dimension.
    clips
        List of clips for which predictions are made. The number of clips
        must match the batch dimension of the model outputs.
    nms_kernel_size
        Size of the non-maximum suppression kernel. Default is 9.
    detection_threshold
        Detection threshold. Default is 0.01.
    min_freq
        Minimum frequency. Default is 10000.
    max_freq
        Maximum frequency. Default is 120000.
    top_k_per_sec
        Top k per second. Default is 200.

    Returns
    -------
    predictions: List[data.ClipPrediction]
        List of clip predictions containing predicted sound events.

    Raises
    ------
    ValueError
        If the number of predictions does not match the number of clips.
    """
    num_predictions = len(outputs.detection_probs)

    if num_predictions == 0:
        return []

    if num_predictions != len(clips):
        raise ValueError(
            "Number of predictions must match the number of clips."
        )

    detection_probs = non_max_suppression(
        outputs.detection_probs,
        kernel_size=nms_kernel_size,
    )

    duration = clips[0].end_time - clips[0].start_time

    scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
        detection_probs,
        int(top_k_per_sec * duration / 2),
    )

    predictions: List[data.ClipPrediction] = []
    for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip(
        scores_batch,
        y_pos_batch,
        x_pos_batch,
        outputs.size_preds,
        outputs.class_probs,
        outputs.features,
        clips,
    ):
        sound_events = compute_sound_events_from_outputs(
            clip,
            scores,
            y_pos,
            x_pos,
            size_preds,
            class_probs,
            features,
            min_freq=min_freq,
            max_freq=max_freq,
            detection_threshold=detection_threshold,
        )

        predictions.append(
            data.ClipPrediction(
                clip=clip,
                sound_events=sound_events,
            )
        )

    return predictions


def compute_sound_events_from_outputs(
    clip: data.Clip,
    scores: torch.Tensor,
    y_pos: torch.Tensor,
    x_pos: torch.Tensor,
    size_preds: torch.Tensor,
    class_probs: torch.Tensor,
    features: torch.Tensor,
    tag_fn: TagFunction = lambda _: [],
    min_freq: int = 10000,
    max_freq: int = 120000,
    detection_threshold: float = DETECTION_THRESHOLD,
) -> List[data.SoundEventPrediction]:
    _, freq_bins, time_bins = size_preds.shape

    sorted_indices = torch.argsort(x_pos)
    valid_indices = sorted_indices[
        scores[sorted_indices] > detection_threshold
    ]

    scores = scores[valid_indices]
    x_pos = x_pos[valid_indices]
    y_pos = y_pos[valid_indices]

    predictions: List[data.SoundEventPrediction] = []
    for score, x, y in zip(scores, x_pos, y_pos):
        width, height = size_preds[:, y, x]
        print(width, height)
        class_prob = class_probs[:, y, x]
        feature = features[:, y, x]

        start_time = np.interp(
            x.item(),
            [0, time_bins],
            [clip.start_time, clip.end_time],
        )

        end_time = np.interp(
            x.item() + width.item(),
            [0, time_bins],
            [clip.start_time, clip.end_time],
        )

        low_freq = np.interp(
            y.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        high_freq = np.interp(
            y.item() - height.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        predicted_tags: List[data.PredictedTag] = []

        for label_id, class_score in enumerate(class_prob):
            corresponding_tags = tag_fn(label_id)
            predicted_tags.extend(
                [
                    data.PredictedTag(
                        tag=tag,
                        score=class_score.item(),
                    )
                    for tag in corresponding_tags
                ]
            )

        start_time, end_time = sorted([float(start_time), float(end_time)])
        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])

        sound_event = data.SoundEvent(
            recording=clip.recording,
            geometry=data.BoundingBox(
                coordinates=[
                    start_time,
                    low_freq,
                    end_time,
                    high_freq,
                ]
            ),
            features=[
                data.Feature(
                    name=f"batdetect2_{i}",
                    value=value.item(),
                )
                for i, value in enumerate(feature)
            ],
        )

        predictions.append(
            data.SoundEventPrediction(
                sound_event=sound_event,
                score=score.item(),
                tags=predicted_tags,
            )
        )

    return predictions


def non_max_suppression(
    tensor: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
    """Run non-maximum suppression on a tensor.

    This function removes values from the input tensor that are not local
    maxima in the neighborhood of the given kernel size.

    All non-maximum values are set to zero.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor.
    kernel_size : Union[int, Tuple[int, int]], optional
        Size of the neighborhood to consider for non-maximum suppression.
        If an integer is given, the neighborhood will be a square of the
        given size. If a tuple is given, the neighborhood will be a
        rectangle with the given height and width.

    Returns
    -------
    torch.Tensor
        Tensor with non-maximum suppressed values.
    """
    if isinstance(kernel_size, int):
        kernel_size_h = kernel_size
        kernel_size_w = kernel_size
    else:
        kernel_size_h, kernel_size_w = kernel_size

    pad_h = (kernel_size_h - 1) // 2
    pad_w = (kernel_size_w - 1) // 2

    hmax = nn.functional.max_pool2d(
        tensor,
        (kernel_size_h, kernel_size_w),
        stride=1,
        padding=(pad_h, pad_w),
    )
    keep = (hmax == tensor).float()
    return tensor * keep


def get_topk_scores(
    scores: torch.Tensor,
    K: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get the top-k scores and their indices.

    Parameters
    ----------
    scores : torch.Tensor
        Tensor with scores. Expects input of size: `batch x 1 x height x width`.
    K : int
        Number of top scores to return.

    Returns
    -------
    scores : torch.Tensor
        Top-k scores.
    ys : torch.Tensor
        Y coordinates of the top-k scores.
    xs : torch.Tensor
        X coordinates of the top-k scores.
    """
    batch, _, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
    topk_xs = (topk_inds % width).long()
    return topk_scores, topk_ys, topk_xs
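A toy demonstration of `non_max_suppression` above: only local maxima of the heatmap survive and everything else is zeroed (illustrative only, not part of the commit):

import torch

from batdetect2.models.post_process import non_max_suppression

heatmap = torch.tensor(
    [[[[0.1, 0.9, 0.2],
       [0.3, 0.8, 0.1],
       [0.7, 0.2, 0.6]]]]
)
# Only the single 3x3-neighbourhood peak (0.9) is kept; all other values
# are suppressed because a larger neighbour exists within the kernel.
print(non_max_suppression(heatmap, kernel_size=3))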
55
batdetect2/models/typing.py
Normal file
@@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from typing import NamedTuple

import torch
import torch.nn as nn


class ModelOutput(NamedTuple):
    """Output of the detection model.

    Each of the tensors has a shape of
    `(batch_size, num_channels, spec_height, spec_width)`,
    where `spec_height` and `spec_width` are the height and width of the
    input spectrograms.

    They contain localised information of:

    1. The probability of a bounding box detection at the given location.
    2. The predicted size of the bounding box at the given location.
    3. The probabilities of each class at the given location before softmax.
    4. Features used to make the predictions at the given location.
    """

    detection_probs: torch.Tensor
    """Tensor with predicted detection probabilities."""

    size_preds: torch.Tensor
    """Tensor with predicted bounding box sizes."""

    class_probs: torch.Tensor
    """Tensor with predicted class probabilities."""

    features: torch.Tensor
    """Tensor with intermediate features."""


class EncoderModel(ABC, nn.Module):

    input_height: int
    """Height of the input spectrogram."""

    num_filts: int
    """Dimension of the feature tensor."""

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder model."""


class DetectionModel(ABC, nn.Module):
    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Forward pass of the detection model."""
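`ModelOutput` is a plain `NamedTuple`, so it can be constructed and unpacked like any tuple (illustrative shapes for a hypothetical 2-class model on a 128 x 512 spectrogram; not part of the commit):

import torch

from batdetect2.models.typing import ModelOutput

output = ModelOutput(
    detection_probs=torch.rand(1, 1, 128, 512),
    size_preds=torch.rand(1, 2, 128, 512),
    class_probs=torch.rand(1, 3, 128, 512),
    features=torch.rand(1, 8, 128, 512),
)
detection_probs, *_ = output  # tuple unpacking works as usual
print(detection_probs.shape)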
@@ -102,6 +102,7 @@ def spectrogram(
     return ax


+
 def spectrogram_with_detections(
     spec: Union[torch.Tensor, np.ndarray],
     dets: List[Annotation],
0
batdetect2/plotting/__init__.py
Normal file
22
batdetect2/plotting/common.py
Normal file
@@ -0,0 +1,22 @@
"""General plotting utilities."""

from typing import Optional, Tuple

import matplotlib.pyplot as plt
from matplotlib import axes

__all__ = [
    "create_ax",
]


def create_ax(
    ax: Optional[axes.Axes] = None,
    figsize: Tuple[int, int] = (10, 10),
    **kwargs,
) -> axes.Axes:
    """Create a new axis if none is provided"""
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore

    return ax  # type: ignore
27
batdetect2/plotting/heatmaps.py
Normal file
@@ -0,0 +1,27 @@
"""Plot heatmaps"""

from typing import Optional, Tuple

import matplotlib.pyplot as plt
import xarray as xr
from matplotlib import axes

from batdetect2.plotting.common import create_ax


def plot_heatmap(
    heatmap: xr.DataArray,
    ax: Optional[axes.Axes] = None,
    figsize: Tuple[int, int] = (10, 10),
) -> axes.Axes:
    ax = create_ax(ax, figsize=figsize)

    ax.pcolormesh(
        heatmap.time,
        heatmap.frequency,
        heatmap,
        vmax=1,
        vmin=0,
    )

    return ax
@@ -1,4 +1,5 @@
 """Functions and dataloaders for training and testing the model."""
+
 import copy
 from typing import List, Optional, Tuple

@@ -199,8 +200,7 @@ def draw_gaussian(
     x0 = y0 = size // 2
     # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
     g = np.exp(
-        -((x - x0) ** 2) / (2 * sigmax**2)
-        - ((y - y0) ** 2) / (2 * sigmay**2)
+        -((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
     )
     g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
     g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
@@ -399,6 +399,8 @@ def echo_aug(
     sample_offset = (
         int(echo_max_delay * np.random.random() * sampling_rate) + 1
     )
+    # NOTE: This seems to be wrong, as the echo should be added to the
+    # end of the audio, not the beginning.
     audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
     return audio

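For reference, a sketch of the delayed-echo variant the NOTE above is pointing at: an attenuated copy of earlier samples added at a later offset (illustrative only, not part of the commit; the `echo_max_delay` default is hypothetical):

import numpy as np

def echo_aug_delayed(
    audio: np.ndarray, sampling_rate: float, echo_max_delay: float = 0.005
) -> np.ndarray:
    sample_offset = (
        int(echo_max_delay * np.random.random() * sampling_rate) + 1
    )
    audio = audio.copy()
    # Delay the echo: later samples receive a scaled copy of earlier ones.
    audio[sample_offset:] += np.random.random() * audio[:-sample_offset]
    return audio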
@@ -820,16 +822,18 @@ class AudioLoader(torch.utils.data.Dataset):
         # )

         # create spectrogram
-        spec = au.generate_spectrogram(
+        spec, _ = au.generate_spectrogram(
             audio,
             sampling_rate,
-            fft_win_length=self.params["fft_win_length"],
-            fft_overlap=self.params["fft_overlap"],
-            max_freq=self.params["max_freq"],
-            min_freq=self.params["min_freq"],
-            spec_scale=self.params["spec_scale"],
-            denoise_spec_avg=self.params["denoise_spec_avg"],
-            max_scale_spec=self.params["max_scale_spec"],
+            params=dict(
+                fft_win_length=self.params["fft_win_length"],
+                fft_overlap=self.params["fft_overlap"],
+                max_freq=self.params["max_freq"],
+                min_freq=self.params["min_freq"],
+                spec_scale=self.params["spec_scale"],
+                denoise_spec_avg=self.params["denoise_spec_avg"],
+                max_scale_spec=self.params["max_scale_spec"],
+            ),
         )
         rsf = self.params["resize_factor"]
         spec_op_shape = (
86
batdetect2/train/dataset.py
Normal file
@@ -0,0 +1,86 @@
import os
from typing import NamedTuple
from pathlib import Path
from typing import Sequence, Union, Dict
from soundevent import data

from torch.utils.data import Dataset
import torch
import xarray as xr

from batdetect2.train.preprocess import PreprocessingConfig


__all__ = [
    "TrainExample",
    "LabeledDataset",
]


PathLike = Union[Path, str, os.PathLike]


class TrainExample(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor
    idx: torch.Tensor


def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
    return list(Path(directory).glob(f"*{extension}"))


class LabeledDataset(Dataset):
    def __init__(self, filenames: Sequence[PathLike]):
        self.filenames = filenames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx) -> TrainExample:
        data = self.load(self.filenames[idx])
        return TrainExample(
            spec=data["spectrogram"],
            detection_heatmap=data["detection"],
            class_heatmap=data["class"],
            size_heatmap=data["size"],
            idx=torch.tensor(idx),
        )

    @classmethod
    def from_directory(cls, directory: PathLike, extension: str = ".nc"):
        return cls(get_files(directory, extension))

    def load(self, filename: PathLike) -> Dict[str, torch.Tensor]:
        dataset = xr.open_dataset(filename)
        spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0)
        return {
            "spectrogram": spectrogram,
            "detection": torch.tensor(dataset["detection"].values),
            "class": torch.tensor(dataset["class"].values),
            "size": torch.tensor(dataset["size"].values),
        }

    def get_spectrogram(self, idx):
        return xr.open_dataset(self.filenames[idx])["spectrogram"]

    def get_detection_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["detection"]

    def get_class_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["class"]

    def get_size_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["size"]

    def get_clip_annotation(self, idx):
        filename = self.filenames[idx]
        dataset = xr.open_dataset(filename)
        clip_annotation = dataset.attrs["clip_annotation"]
        return data.ClipAnnotation.model_validate_json(clip_annotation)

    def get_preprocessing_configuration(self, idx):
        config = xr.open_dataset(self.filenames[idx]).attrs["configuration"]
        return PreprocessingConfig.model_validate_json(config)
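Minimal sketch of consuming `LabeledDataset` with a PyTorch `DataLoader` (illustrative only, not part of the commit; it assumes a directory of ".nc" files produced by `batdetect2.train.preprocess`, all with equally sized spectrograms so the default collate can stack them):

from torch.utils.data import DataLoader

from batdetect2.train.dataset import LabeledDataset

dataset = LabeledDataset.from_directory("path/to/preprocessed")  # hypothetical path
loader = DataLoader(dataset, batch_size=8, shuffle=True)

for batch in loader:
    # The default collate stacks the TrainExample fields into batched tensors.
    print(batch.spec.shape)  # (batch, 1, freq_bins, time_bins)
    break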
56
batdetect2/train/light.py
Normal file
@@ -0,0 +1,56 @@
import pytorch_lightning as L
from torch import Tensor, optim

from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train import losses

from batdetect2.train.dataset import TrainExample


__all__ = [
    "LitDetectorModel",
]


class LitDetectorModel(L.LightningModule):
    model: DetectionModel

    def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(self, batch: TrainExample, batch_idx: int):  # type: ignore
        outputs: ModelOutput = self.model(batch.spec)
        loss = self.compute_loss(outputs, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
@@ -22,15 +22,15 @@ def focal_loss(
     gt: torch.Tensor,
     weights: Optional[torch.Tensor] = None,
     valid_mask: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+    beta: float = 4,
+    alpha: float = 2,
 ) -> torch.Tensor:
     """
     Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
     pred (batch x c x h x w)
     gt (batch x c x h x w)
     """
-    eps = 1e-5
-    beta = 4
-    alpha = 2

     pos_inds = gt.eq(1).float()
     neg_inds = gt.lt(1).float()
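For context, the CornerNet-style focal loss named in the docstring has this standard form (a sketch of the published formulation, not the repository's exact code; the full implementation above additionally handles `weights` and `valid_mask`):

import torch

def cornernet_focal_loss(pred, gt, eps=1e-5, beta=4, alpha=2):
    pos_inds = gt.eq(1).float()  # peaks of the ground-truth heatmap
    neg_inds = gt.lt(1).float()
    neg_weights = torch.pow(1 - gt, beta)  # down-weight negatives near peaks

    pos_loss = torch.log(pred + eps) * torch.pow(1 - pred, alpha) * pos_inds
    neg_loss = (
        torch.log(1 - pred + eps) * torch.pow(pred, alpha) * neg_weights * neg_inds
    )
    num_pos = pos_inds.sum().clamp(min=1)
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos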
209
batdetect2/train/preprocess.py
Normal file
209
batdetect2/train/preprocess.py
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
"""Module for preprocessing data for training."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
from functools import partial
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Optional, Sequence, Union
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
from multiprocessing import Pool
|
||||||
|
|
||||||
|
import xarray as xr
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.data.labels import TARGET_SIGMA, LabelFn, generate_heatmaps
|
||||||
|
from batdetect2.data.preprocessing import (
|
||||||
|
DENOISE_SPEC_AVG,
|
||||||
|
FFT_OVERLAP,
|
||||||
|
FFT_WIN_LENGTH_S,
|
||||||
|
MAX_FREQ_HZ,
|
||||||
|
MAX_SCALE_SPEC,
|
||||||
|
MIN_FREQ_HZ,
|
||||||
|
SCALE_RAW_AUDIO,
|
||||||
|
SPEC_SCALE,
|
||||||
|
TARGET_SAMPLERATE_HZ,
|
||||||
|
preprocess_audio_clip,
|
||||||
|
)
|
||||||
|
|
||||||
|
PathLike = Union[Path, str, os.PathLike]
|
||||||
|
FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"preprocess_annotations",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class PreprocessingConfig(BaseModel):
|
||||||
|
"""Configuration for preprocessing data."""
|
||||||
|
|
||||||
|
target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||||
|
|
||||||
|
scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
|
||||||
|
|
||||||
|
fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
|
||||||
|
|
||||||
|
fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
|
||||||
|
|
||||||
|
max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
|
||||||
|
|
||||||
|
min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
|
||||||
|
|
||||||
|
spec_scale: str = Field(default=SPEC_SCALE)
|
||||||
|
|
||||||
|
denoise_spec_avg: bool = DENOISE_SPEC_AVG
|
||||||
|
|
||||||
|
max_scale_spec: bool = MAX_SCALE_SPEC
|
||||||
|
|
||||||
|
target_sigma: float = Field(default=TARGET_SIGMA, gt=0)
|
||||||
|
|
||||||
|
class_labels: Sequence[str] = ["bat"]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_train_example(
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
label_fn: LabelFn = lambda _: None,
|
||||||
|
config: Optional[PreprocessingConfig] = None,
|
||||||
|
) -> xr.Dataset:
|
||||||
|
"""Generate a training example."""
|
||||||
|
if config is None:
|
||||||
|
config = PreprocessingConfig()
|
||||||
|
|
||||||
|
spectrogram = preprocess_audio_clip(
|
||||||
|
clip_annotation.clip,
|
||||||
|
target_sampling_rate=config.target_samplerate,
|
||||||
|
scale_audio=config.scale_audio,
|
||||||
|
fft_win_length=config.fft_win_length,
|
||||||
|
fft_overlap=config.fft_overlap,
|
||||||
|
max_freq=config.max_freq,
|
||||||
|
min_freq=config.min_freq,
|
||||||
|
spec_scale=config.spec_scale,
|
||||||
|
denoise_spec_avg=config.denoise_spec_avg,
|
||||||
|
max_scale_spec=config.max_scale_spec,
|
||||||
|
)
|
||||||
|
|
||||||
|
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||||
|
clip_annotation,
|
||||||
|
spectrogram,
|
||||||
|
target_sigma=config.target_sigma,
|
||||||
|
num_classes=len(config.class_labels),
|
||||||
|
class_labels=list(config.class_labels),
|
||||||
|
label_fn=label_fn,
    )

    dataset = xr.Dataset(
        {
            "spectrogram": spectrogram,
            "detection": detection_heatmap,
            "class": class_heatmap,
            "size": size_heatmap,
        }
    )

    return dataset.assign_attrs(
        title=f"Training example for {clip_annotation.uuid}",
        configuration=config.model_dump_json(),
        clip_annotation=clip_annotation.model_dump_json(),
    )


def save_to_file(
    dataset: xr.Dataset,
    path: PathLike,
) -> None:
    dataset.to_netcdf(
        path,
        encoding={
            "spectrogram": {"zlib": True},
            "size": {"zlib": True},
            "class": {"zlib": True},
            "detection": {"zlib": True},
        },
    )


def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
    """Load configuration from file."""
    path = Path(path)

    if not path.is_file():
        warnings.warn(f"Config file not found: {path}. Using default config.")
        return PreprocessingConfig(**kwargs)

    try:
        return PreprocessingConfig.model_validate_json(path.read_text())
    except ValueError as e:
        warnings.warn(
            f"Failed to load config file: {e}. Using default config."
        )
        return PreprocessingConfig(**kwargs)


def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
    return f"{clip_annotation.uuid}.nc"


def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: PathLike,
    config: PreprocessingConfig,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    label_fn: LabelFn = lambda _: None,
) -> None:
    output_dir = Path(output_dir)

    filename = filename_fn(clip_annotation)
    path = output_dir / filename

    if path.is_file() and not replace:
        return

    sample = generate_train_example(
        clip_annotation,
        label_fn=label_fn,
        config=config,
    )

    save_to_file(sample, path)


def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: PathLike,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    config_file: Optional[PathLike] = None,
    label_fn: LabelFn = lambda _: None,
    max_workers: Optional[int] = None,
    **kwargs,
) -> None:
    """Preprocess annotations and save to disk."""
    output_dir = Path(output_dir)

    if not output_dir.is_dir():
        output_dir.mkdir(parents=True)

    if config_file is not None:
        config = load_config(config_file, **kwargs)
    else:
        config = PreprocessingConfig(**kwargs)

    with Pool(max_workers) as pool:
        list(
            tqdm(
                pool.imap_unordered(
                    partial(
                        preprocess_single_annotation,
                        output_dir=output_dir,
                        config=config,
                        filename_fn=filename_fn,
                        replace=replace,
                        label_fn=label_fn,
                    ),
                    clip_annotations,
                ),
                total=len(clip_annotations),
            )
        )
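For context, a minimal usage sketch of the preprocessing entry point above. The paths are placeholders and not part of this commit; `load_annotation_project` is the loader exercised by the new test suite further down.

from batdetect2.data.compat import load_annotation_project
from batdetect2.data.preprocessing import preprocess_annotations

# "example_data/..." and "train_examples/" are illustrative paths only.
project = load_annotation_project(
    "example_data/anns",
    audio_dir="example_data/audio",
)
preprocess_annotations(
    project.clip_annotations,
    output_dir="train_examples/",
    max_workers=4,  # parallel workers for the multiprocessing Pool
)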
82  batdetect2/train/train.py  Normal file
@ -0,0 +1,82 @@
from typing import Callable, NamedTuple, Optional

import torch
from soundevent import data
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.data.datasets import ClipAnnotationDataset
from batdetect2.models.typing import DetectionModel


class TrainInputs(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor


def train_loop(
    model: DetectionModel,
    train_dataset: ClipAnnotationDataset[TrainInputs],
    validation_dataset: ClipAnnotationDataset[TrainInputs],
    device: Optional[torch.device] = None,
    num_epochs: int = 100,
    learning_rate: float = 1e-4,
):
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    validation_loader = DataLoader(validation_dataset, batch_size=32)  # NOTE: not yet used in this WIP commit

    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer,
        num_epochs * len(train_loader),
    )

    for epoch in range(num_epochs):
        train_loss = train_single_epoch(
            model,
            train_loader,
            optimizer,
            device,
            scheduler,
        )


def train_single_epoch(
    model: DetectionModel,
    train_loader: DataLoader,
    optimizer: Adam,
    device: torch.device,
    scheduler: CosineAnnealingLR,
):
    model.train()
    # NOTE: `tu` is never imported in this WIP commit; see the sketch after
    # this file for the kind of helper `tu.AverageMeter` implies.
    train_loss = tu.AverageMeter()

    for batch in train_loader:
        optimizer.zero_grad()

        spec = batch.spec.to(device)
        detection_heatmap = batch.detection_heatmap.to(device)
        class_heatmap = batch.class_heatmap.to(device)
        size_heatmap = batch.size_heatmap.to(device)

        outputs = model(spec)

        # NOTE: `loss_fun`, `gt_det`, `gt_size`, `gt_class`, `det_criterion`,
        # `params`, and `class_inv_freq` are undefined here; this loop is
        # carried over from the legacy trainer and is not yet wired up.
        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), data.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()
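For reference, a running-average helper of the kind the undefined `tu.AverageMeter` reference above suggests might look like the following sketch. This is illustrative only and not part of the commit.

class AverageMeter:
    """Track a running average of a scalar, e.g. the per-batch loss."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1) -> None:
        # Accumulate a value observed over `n` samples.
        self.total += value * n
        self.count += n

    @property
    def avg(self) -> float:
        return self.total / max(self.count, 1)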
@ -1,3 +1,4 @@
+import sys
 import json
 from collections import Counter
 from pathlib import Path
@ -7,6 +8,11 @@ import numpy as np
 from batdetect2 import types

+if sys.version_info >= (3, 9):
+    StringCounter = Counter[str]
+else:
+    from typing import Counter as StringCounter
+
+
 def write_notes_file(file_name: str, text: str):
     with open(file_name, "a") as da:
@ -148,7 +154,7 @@ def format_annotation(
 def get_class_names(
     data: List[types.FileAnnotation],
     classes_to_ignore: Optional[List[str]] = None,
-) -> Tuple[Counter[str], List[float]]:
+) -> Tuple[StringCounter, List[float]]:
     """Extracts class names and their inverse frequencies.

     Parameters
@ -182,7 +188,7 @@ def get_class_names(
     return counts, [mean_counts / counts[cc] for cc in class_names_list]


-def report_class_counts(class_names: Counter[str]):
+def report_class_counts(class_names: StringCounter):
     print("Class count:")
     str_len = np.max([len(cc) for cc in class_names]) + 5
     for index, (class_name, count) in enumerate(class_names.most_common()):
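The version gate exists because PEP 585 only made `collections.Counter` subscriptable in Python 3.9; on 3.8 the annotation must come from `typing` instead. A small illustration (not from the commit):

import sys
from collections import Counter

if sys.version_info >= (3, 9):
    StringCounter = Counter[str]  # Counter[str] raises TypeError on 3.8
else:
    from typing import Counter as StringCounter

def count_species(names) -> "StringCounter":
    return Counter(names)

assert count_species(["Myotis", "Myotis"]).most_common(1) == [("Myotis", 2)]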
@ -1,11 +1,13 @@
 import warnings
-from typing import Optional, Tuple, Union, overload
+from typing import Optional, Tuple

 import librosa
 import librosa.core.spectrum
 import numpy as np
 import torch

+from . import wavfile
+
 __all__ = [
     "load_audio",
     "generate_spectrogram",
@ -13,171 +15,113 @@ __all__ = [
 ]


-@overload
-def time_to_x_coords(
-    time_in_file: np.ndarray,
-    sampling_rate: float,
-    fft_win_length: float,
-    fft_overlap: float,
-) -> np.ndarray:
-    ...
-
-
-@overload
-def time_to_x_coords(
-    time_in_file: float,
-    sampling_rate: float,
-    fft_win_length: float,
-    fft_overlap: float,
-) -> float:
-    ...
-
-
-def time_to_x_coords(
-    time_in_file: Union[float, np.ndarray],
-    sampling_rate: float,
-    fft_win_length: float,
-    fft_overlap: float,
-) -> Union[float, np.ndarray]:
-    nfft = np.floor(fft_win_length * sampling_rate)
+def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
+    nfft = np.floor(fft_win_length * sampling_rate)  # int() uses floor
    noverlap = np.floor(fft_overlap * nfft)
    return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)


 # NOTE this is also defined in post_process
-def x_coords_to_time(
-    x_pos: float,
-    sampling_rate: int,
-    fft_win_length: float,
-    fft_overlap: float,
-) -> float:
+def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
     nfft = np.floor(fft_win_length * sampling_rate)
     noverlap = np.floor(fft_overlap * nfft)
     return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
-    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for
-    # center of temporal window
+    # return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window


 def generate_spectrogram(
-    audio: np.ndarray,
-    sampling_rate: float,
-    fft_win_length: float,
-    fft_overlap: float,
-    max_freq: float,
-    min_freq: float,
-    spec_scale: str,
-    denoise_spec_avg: bool = False,
-    max_scale_spec: bool = False,
-) -> np.ndarray:
+    audio,
+    sampling_rate,
+    params,
+    return_spec_for_viz=False,
+    check_spec_size=True,
+):
     # generate spectrogram
     spec = gen_mag_spectrogram(
         audio,
         sampling_rate,
-        window_len=fft_win_length,
-        overlap_perc=fft_overlap,
-    )
-    spec = crop_spectrogram(
-        spec,
-        fft_win_length=fft_win_length,
-        max_freq=max_freq,
-        min_freq=min_freq,
-    )
-    spec = scale_spectrogram(
-        spec,
-        sampling_rate,
-        spec_scale=spec_scale,
-        fft_win_length=fft_win_length,
+        params["fft_win_length"],
+        params["fft_overlap"],
     )

-    if denoise_spec_avg:
-        spec = denoise_spectrogram(spec)
-
-    if max_scale_spec:
-        spec = max_scale_spectrogram(spec)
-
-    return spec
-
-
-def crop_spectrogram(
-    spec: np.ndarray,
-    fft_win_length: float,
-    max_freq: float,
-    min_freq: float,
-) -> np.ndarray:
     # crop to min/max freq
-    max_freq = round(max_freq * fft_win_length)
-    min_freq = round(min_freq * fft_win_length)
+    max_freq = round(params["max_freq"] * params["fft_win_length"])
+    min_freq = round(params["min_freq"] * params["fft_win_length"])
     if spec.shape[0] < max_freq:
         freq_pad = max_freq - spec.shape[0]
         spec = np.vstack(
             (np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
         )
-    return spec[-max_freq : spec.shape[0] - min_freq, :]
-
-
-def denoise_spectrogram(spec: np.ndarray) -> np.ndarray:
-    spec = spec - np.mean(spec, 1)[:, np.newaxis]
-    return spec.clip(min=0)
-
-
-def max_scale_spectrogram(spec: np.ndarray) -> np.ndarray:
-    return spec / (spec.max() + 10e-6)
-
-
-def log_scale(
-    spec: np.ndarray,
-    sampling_rate: float,
-    fft_win_length: float,
-) -> np.ndarray:
-    log_scaling = (
-        2.0
-        * (1.0 / sampling_rate)
-        * (
-            1.0
-            / (
-                np.abs(np.hanning(int(fft_win_length * sampling_rate))) ** 2
-            ).sum()
-        )
-    )
-    return np.log1p(log_scaling * spec)
-
-
-def scale_spectrogram(
-    spec: np.ndarray,
-    sampling_rate: float,
-    spec_scale: str,
-    fft_win_length: float,
-) -> np.ndarray:
-    if spec_scale == "log":
-        return log_scale(spec, sampling_rate, fft_win_length)
-
-    if spec_scale == "pcen":
-        return pcen(spec, sampling_rate)
-
-    return spec
-
-
-def prepare_spec_for_viz(
-    spec: np.ndarray,
-    sampling_rate: int,
-    fft_win_length: float,
-) -> np.ndarray:
-    # for visualization purposes - use log scaled spectrogram
-    return log_scale(
-        spec,
-        sampling_rate,
-        fft_win_length=fft_win_length,
-    ).astype(np.float32)
+    spec = spec[-max_freq : spec.shape[0] - min_freq, :]
+
+    if params["spec_scale"] == "log":
+        log_scaling = (
+            2.0
+            * (1.0 / sampling_rate)
+            * (
+                1.0
+                / (
+                    np.abs(
+                        np.hanning(
+                            int(params["fft_win_length"] * sampling_rate)
+                        )
+                    )
+                    ** 2
+                ).sum()
+            )
+        )
+        # log_scaling = (1.0 / sampling_rate)*0.1
+        # log_scaling = (1.0 / sampling_rate)*10e4
+        spec = np.log1p(log_scaling * spec)
+    elif params["spec_scale"] == "pcen":
+        spec = pcen(spec, sampling_rate)
+    elif params["spec_scale"] == "none":
+        pass
+
+    if params["denoise_spec_avg"]:
+        spec = spec - np.mean(spec, 1)[:, np.newaxis]
+        spec.clip(min=0, out=spec)
+
+    if params["max_scale_spec"]:
+        spec = spec / (spec.max() + 10e-6)
+
+    # needs to be divisible by specific factor - if not it should have been padded
+    # if check_spec_size:
+    # assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
+    # assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
+
+    # for visualization purposes - use log scaled spectrogram
+    if return_spec_for_viz:
+        log_scaling = (
+            2.0
+            * (1.0 / sampling_rate)
+            * (
+                1.0
+                / (
+                    np.abs(
+                        np.hanning(
+                            int(params["fft_win_length"] * sampling_rate)
+                        )
+                    )
+                    ** 2
+                ).sum()
+            )
+        )
+        spec_for_viz = np.log1p(log_scaling * spec).astype(np.float32)
+    else:
+        spec_for_viz = None
+
+    return spec, spec_for_viz


 def load_audio(
     audio_file: str,
     time_exp_fact: float,
-    target_sampling_rate: int,
+    target_samp_rate: int,
     scale: bool = False,
     max_duration: Optional[float] = None,
-) -> Tuple[float, np.ndarray]:
+) -> Tuple[int, np.ndarray]:
     """Load an audio file and resample it to the target sampling rate.

     The audio is also scaled to [-1, 1] and clipped to the maximum duration.
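For orientation, the two restored coordinate helpers are inverses of each other. A quick round-trip check with illustrative values (not part of the commit):

import numpy as np

from batdetect2.utils import audio_utils as au

sr, win, overlap = 256_000, 512 / 256_000, 0.75  # illustrative parameters
x = au.time_to_x_coords(0.1, sr, win, overlap)   # seconds -> spectrogram column
t = au.x_coords_to_time(x, sr, win, overlap)     # column -> seconds
assert np.isclose(t, 0.1)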
@ -208,82 +152,63 @@ def load_audio(

     """
     with warnings.catch_warnings():
-        audio, sampling_rate = librosa.load(
+        warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
+        # sampling_rate, audio_raw = wavfile.read(audio_file)
+        audio_raw, sampling_rate = librosa.load(
             audio_file,
             sr=None,
             dtype=np.float32,
         )

-    if len(audio.shape) > 1:
+    if len(audio_raw.shape) > 1:
         raise ValueError("Currently does not handle stereo files")

     sampling_rate = sampling_rate * time_exp_fact

     # resample - need to do this after correcting for time expansion
-    audio = resample_audio(audio, sampling_rate, target_sampling_rate)
-
-    if max_duration is not None:
-        audio = clip_audio(audio, target_sampling_rate, max_duration)
-
-    # scale to [-1, 1]
-    if scale:
-        audio = scale_audio(audio)
-
-    return target_sampling_rate, audio
-
-
-def resample_audio(
-    audio: np.ndarray,
-    sr_orig: float,
-    sr_target: float,
-) -> np.ndarray:
-    if sr_orig != sr_target:
-        return librosa.resample(
-            audio,
-            orig_sr=sr_orig,
-            target_sr=sr_target,
+    sampling_rate_old = sampling_rate
+    sampling_rate = target_samp_rate
+    if sampling_rate_old != sampling_rate:
+        audio_raw = librosa.resample(
+            audio_raw,
+            orig_sr=sampling_rate_old,
+            target_sr=sampling_rate,
             res_type="polyphase",
         )

-    return audio
-
-
-def clip_audio(
-    audio: np.ndarray,
-    sampling_rate: float,
-    max_duration: float,
-) -> np.ndarray:
-    max_duration = int(
-        np.minimum(
-            int(sampling_rate * max_duration),
-            audio.shape[0],
-        )
-    )
-    return audio[:max_duration]
-
-
-def scale_audio(
-    audio: np.ndarray,
-    eps: float = 10e-6,
-) -> np.ndarray:
-    return (audio - audio.mean()) / (np.abs(audio).max() + eps)
+    # clipping maximum duration
+    if max_duration is not None:
+        max_duration = int(
+            np.minimum(
+                int(sampling_rate * max_duration),
+                audio_raw.shape[0],
+            )
+        )
+        audio_raw = audio_raw[:max_duration]
+
+    # scale to [-1, 1]
+    if scale:
+        audio_raw = audio_raw - audio_raw.mean()
+        audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
+
+    return sampling_rate, audio_raw


 def pad_audio(
-    audio_raw: np.ndarray,
-    sampling_rate: float,
-    window_len: float,
-    overlap_perc: float,
-    resize_factor: float,
-    divide_factor: float,
-    fixed_width: Optional[int] = None,
-) -> np.ndarray:
+    audio_raw,
+    fs,
+    ms,
+    overlap_perc,
+    resize_factor,
+    divide_factor,
+    fixed_width=None,
+):
     # Adds zeros to the end of the raw data so that the generated spectrogram
     # will be evenly divisible by `divide_factor`
     # Also deals with very short audio clips and fixed_width during training

     # This code could be clearer, clean up
-    nfft = int(window_len * sampling_rate)
+    nfft = int(ms * fs)
     noverlap = int(overlap_perc * nfft)
     step = nfft - noverlap
     min_size = int(divide_factor * (1.0 / resize_factor))
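A quick usage sketch of the restored `load_audio` signature; the file path is a placeholder and not part of this commit:

from batdetect2.utils import audio_utils as au

# "recording.wav" is illustrative; time_exp_fact=1 means the file was not
# time-expanded on the recorder.
sampling_rate, audio = au.load_audio(
    "recording.wav",
    time_exp_fact=1,
    target_samp_rate=256_000,
    scale=False,        # True centres the waveform and scales to [-1, 1]
    max_duration=None,  # or clip to a maximum number of seconds
)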
@ -320,23 +245,22 @@ def pad_audio(
     return audio_raw


-def gen_mag_spectrogram(
-    audio: np.ndarray,
-    sampling_rate: float,
-    window_len: float,
-    overlap_perc: float,
-) -> np.ndarray:
+def gen_mag_spectrogram(x, fs, ms, overlap_perc):
     # Computes magnitude spectrogram by specifying time.
-    audio = audio.astype(np.float32)
-    nfft = int(window_len * sampling_rate)
+    x = x.astype(np.float32)
+    nfft = int(ms * fs)
     noverlap = int(overlap_perc * nfft)

+    # window data
+    step = nfft - noverlap
+
     # compute spec
     spec, _ = librosa.core.spectrum._spectrogram(
-        y=audio,
+        y=x,
         power=1,
         n_fft=nfft,
-        hop_length=nfft - noverlap,
+        hop_length=step,
         center=False,
     )

@ -346,25 +270,24 @@ def gen_mag_spectrogram(
     return spec.astype(np.float32)


-def gen_mag_spectrogram_pt(
-    audio: torch.Tensor,
-    sampling_rate: float,
-    window_len: float,
-    overlap_perc: float,
-) -> torch.Tensor:
-    nfft = int(window_len * sampling_rate)
+def gen_mag_spectrogram_pt(x, fs, ms, overlap_perc):
+    nfft = int(ms * fs)
     nstep = round((1.0 - overlap_perc) * nfft)
-    han_win = torch.hann_window(nfft, periodic=False).to(audio.device)

-    complex_spec = torch.stft(audio, nfft, nstep, window=han_win, center=False)
+    han_win = torch.hann_window(nfft, periodic=False).to(x.device)
+
+    complex_spec = torch.stft(x, nfft, nstep, window=han_win, center=False)
     spec = complex_spec.pow(2.0).sum(-1)

     # remove DC component and flip vertically
-    return torch.flipud(spec[0, 1:, :])
+    spec = torch.flipud(spec[0, 1:, :])
+
+    return spec


-def pcen(spec: np.ndarray, sampling_rate: float) -> np.ndarray:
+def pcen(spec_cropped, sampling_rate):
     # TODO should be passing hop_length too i.e. step
-    return librosa.pcen(spec * (2**31), sr=sampling_rate / 10).astype(
+    spec = librosa.pcen(spec_cropped * (2**31), sr=sampling_rate / 10).astype(
         np.float32
     )
+    return spec
@ -437,7 +437,7 @@ def compute_spectrogram(
     )

     # generate spectrogram
-    spec = au.generate_spectrogram(audio, sampling_rate, params)
+    spec, _ = au.generate_spectrogram(audio, sampling_rate, params)

     # convert to pytorch
     spec = torch.from_numpy(spec).to(device)
@ -746,7 +746,7 @@ def process_file(
     sampling_rate, audio_full = au.load_audio(
         audio_file,
         time_exp_fact=config.get("time_expansion", 1) or 1,
-        target_sampling_rate=config["target_samp_rate"],
+        target_samp_rate=config["target_samp_rate"],
         scale=config["scale_raw_audio"],
         max_duration=config.get("max_duration"),
     )
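Since `generate_spectrogram` again returns a `(spec, spec_for_viz)` pair, callers now unpack two values, as both hunks above show. A minimal sketch with illustrative parameter values (not from the commit):

import numpy as np

from batdetect2.utils import audio_utils as au

params = {  # illustrative values matching the keys read above
    "fft_win_length": 512 / 256_000,
    "fft_overlap": 0.75,
    "max_freq": 120_000,
    "min_freq": 10_000,
    "spec_scale": "log",
    "denoise_spec_avg": True,
    "max_scale_spec": False,
}
audio = np.random.randn(256_000).astype(np.float32)  # 1 s of noise
spec, spec_for_viz = au.generate_spectrogram(audio, 256_000, params)
assert spec_for_viz is None  # only populated when return_spec_for_viz=True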
@ -1,17 +0,0 @@
-name: batdetect2
-channels:
-  - defaults
-  - conda-forge
-  - pytorch
-  - nvidia
-dependencies:
-  - python==3.10
-  - matplotlib
-  - pandas
-  - scikit-learn
-  - numpy
-  - pytorch
-  - scipy
-  - torchvision
-  - librosa
-  - torchaudio
@ -1,3 +1,9 @@
+[tool]
+rye = { dev-dependencies = [
+    "ipykernel>=6.29.4",
+    "setuptools>=69.5.1",
+    "pytest>=8.1.1",
+] }
 [tool.pdm]
 [tool.pdm.dev-dependencies]
 dev = [
@ -22,12 +28,17 @@ dependencies = [
     "torch>=1.13.1",
     "torchaudio",
     "torchvision",
-    "click",
-    "soundevent>=1.3.5",
-    "click",
+    "soundevent[audio,geometry,plot]>=1.3.5",
     "click>=8.1.7",
+    "netcdf4>=1.6.5",
+    "tqdm>=4.66.2",
+    "pytorch-lightning>=2.2.2",
+    "cf-xarray>=0.9.0",
+    "onnx>=1.16.0",
+    "lightning[extra]>=2.2.2",
+    "tensorboard>=2.16.2",
 ]
-requires-python = ">=3.8,<3.12"
+requires-python = ">=3.9"
 readme = "README.md"
 license = { text = "CC-by-nc-4" }
 classifiers = [
@ -65,6 +76,9 @@ line-length = 79
 profile = "black"
 line_length = 79

+[tool.ruff]
+line-length = 79
+
 [[tool.mypy.overrides]]
 module = [
     "librosa",
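The new `netcdf4` and `cf-xarray` dependencies back the NetCDF training examples written by `save_to_file` earlier in this commit. A minimal sketch of reading one back with xarray; the path is a placeholder:

import xarray as xr

# "train_examples/example.nc" stands in for a file written by
# preprocess_annotations above.
example = xr.open_dataset("train_examples/example.nc")
print(example["spectrogram"].shape)
print(example.attrs["configuration"])  # the PreprocessingConfig as JSON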
518  requirements-dev.lock  Normal file
@ -0,0 +1,518 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
#   pre: false
#   features: []
#   all-features: false
#   with-sources: false

-e file:.
absl-py==2.1.0
    # via tensorboard
aiobotocore==2.12.3
    # via s3fs
aiohttp==3.9.5
    # via aiobotocore
    # via fsspec
    # via lightning
    # via s3fs
aioitertools==0.11.0
    # via aiobotocore
aiosignal==1.3.1
    # via aiohttp
annotated-types==0.6.0
    # via pydantic
antlr4-python3-runtime==4.9.3
    # via hydra-core
    # via omegaconf
anyio==4.3.0
    # via starlette
arrow==1.3.0
    # via lightning
asttokens==2.4.1
    # via stack-data
async-timeout==4.0.3
    # via aiohttp
    # via redis
attrs==23.2.0
    # via aiohttp
audioread==3.0.1
    # via librosa
backcall==0.2.0
    # via ipython
backoff==2.2.1
    # via lightning
beautifulsoup4==4.12.3
    # via lightning
bitsandbytes==0.41.0
    # via lightning
blessed==1.20.0
    # via inquirer
boto3==1.34.69
    # via lightning-cloud
botocore==1.34.69
    # via aiobotocore
    # via boto3
    # via s3transfer
certifi==2024.2.2
    # via netcdf4
    # via requests
cf-xarray==0.9.0
    # via batdetect2
cffi==1.16.0
    # via soundfile
cftime==1.6.3
    # via netcdf4
charset-normalizer==3.3.2
    # via requests
click==8.1.7
    # via batdetect2
    # via lightning
    # via lightning-cloud
    # via uvicorn
comm==0.2.2
    # via ipykernel
contourpy==1.1.1
    # via matplotlib
croniter==1.4.1
    # via lightning
cycler==0.12.1
    # via matplotlib
cython==3.0.10
    # via soundevent
dateutils==0.6.12
    # via lightning
debugpy==1.8.1
    # via ipykernel
decorator==5.1.1
    # via ipython
    # via librosa
deepdiff==6.7.1
    # via lightning
dnspython==2.6.1
    # via email-validator
docker==6.1.3
    # via lightning
docstring-parser==0.16
    # via jsonargparse
editor==1.6.6
    # via inquirer
email-validator==2.1.1
    # via soundevent
exceptiongroup==1.2.1
    # via anyio
    # via pytest
executing==2.0.1
    # via stack-data
fastapi==0.110.2
    # via lightning
    # via lightning-cloud
filelock==3.13.4
    # via torch
    # via triton
fonttools==4.51.0
    # via matplotlib
frozenlist==1.4.1
    # via aiohttp
    # via aiosignal
fsspec==2023.12.2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via s3fs
    # via torch
grpcio==1.62.2
    # via tensorboard
h11==0.14.0
    # via uvicorn
hydra-core==1.3.2
    # via lightning
idna==3.7
    # via anyio
    # via email-validator
    # via requests
    # via yarl
importlib-metadata==7.1.0
    # via jupyter-client
    # via markdown
importlib-resources==6.4.0
    # via matplotlib
    # via typeshed-client
iniconfig==2.0.0
    # via pytest
inquirer==3.2.4
    # via lightning
ipykernel==6.29.4
ipython==8.12.3
    # via ipykernel
jedi==0.19.1
    # via ipython
jinja2==3.1.3
    # via lightning
    # via torch
jmespath==1.0.1
    # via boto3
    # via botocore
joblib==1.4.0
    # via librosa
    # via scikit-learn
jsonargparse==4.28.0
    # via lightning
jupyter-client==8.6.1
    # via ipykernel
jupyter-core==5.7.2
    # via ipykernel
    # via jupyter-client
kiwisolver==1.4.5
    # via matplotlib
lazy-loader==0.4
    # via librosa
librosa==0.10.1
    # via batdetect2
lightning==2.2.2
    # via batdetect2
lightning-api-access==0.0.5
    # via lightning
lightning-cloud==0.5.65
    # via lightning
lightning-fabric==2.2.2
    # via lightning
lightning-utilities==0.11.2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via torchmetrics
llvmlite==0.41.1
    # via numba
markdown==3.6
    # via tensorboard
markdown-it-py==3.0.0
    # via rich
markupsafe==2.1.5
    # via jinja2
    # via werkzeug
matplotlib==3.7.5
    # via batdetect2
    # via lightning
    # via soundevent
matplotlib-inline==0.1.7
    # via ipykernel
    # via ipython
mdurl==0.1.2
    # via markdown-it-py
mpmath==1.3.0
    # via sympy
msgpack==1.0.8
    # via librosa
multidict==6.0.5
    # via aiohttp
    # via yarl
nest-asyncio==1.6.0
    # via ipykernel
netcdf4==1.6.5
    # via batdetect2
networkx==3.1
    # via torch
numba==0.58.1
    # via librosa
numpy==1.24.4
    # via batdetect2
    # via cftime
    # via contourpy
    # via librosa
    # via lightning
    # via lightning-fabric
    # via matplotlib
    # via netcdf4
    # via numba
    # via onnx
    # via pandas
    # via pytorch-lightning
    # via scikit-learn
    # via scipy
    # via shapely
    # via soxr
    # via tensorboard
    # via tensorboardx
    # via torchmetrics
    # via torchvision
    # via xarray
nvidia-cublas-cu12==12.1.3.1
    # via nvidia-cudnn-cu12
    # via nvidia-cusolver-cu12
    # via torch
nvidia-cuda-cupti-cu12==12.1.105
    # via torch
nvidia-cuda-nvrtc-cu12==12.1.105
    # via torch
nvidia-cuda-runtime-cu12==12.1.105
    # via torch
nvidia-cudnn-cu12==8.9.2.26
    # via torch
nvidia-cufft-cu12==11.0.2.54
    # via torch
nvidia-curand-cu12==10.3.2.106
    # via torch
nvidia-cusolver-cu12==11.4.5.107
    # via torch
nvidia-cusparse-cu12==12.1.0.106
    # via nvidia-cusolver-cu12
    # via torch
nvidia-nccl-cu12==2.19.3
    # via torch
nvidia-nvjitlink-cu12==12.4.127
    # via nvidia-cusolver-cu12
    # via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
    # via torch
omegaconf==2.3.0
    # via hydra-core
    # via lightning
onnx==1.16.0
    # via batdetect2
ordered-set==4.1.0
    # via deepdiff
packaging==24.0
    # via docker
    # via hydra-core
    # via ipykernel
    # via lazy-loader
    # via lightning
    # via lightning-fabric
    # via lightning-utilities
    # via matplotlib
    # via pooch
    # via pytest
    # via pytorch-lightning
    # via tensorboardx
    # via torchmetrics
    # via xarray
pandas==2.0.3
    # via batdetect2
    # via xarray
parso==0.8.4
    # via jedi
pexpect==4.9.0
    # via ipython
pickleshare==0.7.5
    # via ipython
pillow==10.3.0
    # via matplotlib
    # via torchvision
platformdirs==4.2.0
    # via jupyter-core
    # via pooch
pluggy==1.4.0
    # via pytest
pooch==1.8.1
    # via librosa
prompt-toolkit==3.0.43
    # via ipython
protobuf==5.26.1
    # via onnx
    # via tensorboard
    # via tensorboardx
psutil==5.9.8
    # via ipykernel
    # via lightning
ptyprocess==0.7.0
    # via pexpect
pure-eval==0.2.2
    # via stack-data
pycparser==2.22
    # via cffi
pydantic==2.7.0
    # via fastapi
    # via lightning
    # via soundevent
pydantic-core==2.18.1
    # via pydantic
pygments==2.17.2
    # via ipython
    # via rich
pyjwt==2.8.0
    # via lightning-cloud
pyparsing==3.1.2
    # via matplotlib
pytest==8.1.1
python-dateutil==2.9.0.post0
    # via arrow
    # via botocore
    # via croniter
    # via dateutils
    # via jupyter-client
    # via matplotlib
    # via pandas
python-multipart==0.0.9
    # via lightning
    # via lightning-cloud
pytorch-lightning==2.2.2
    # via batdetect2
    # via lightning
pytz==2024.1
    # via dateutils
    # via pandas
pyyaml==6.0.1
    # via jsonargparse
    # via lightning
    # via omegaconf
    # via pytorch-lightning
pyzmq==26.0.0
    # via ipykernel
    # via jupyter-client
readchar==4.0.6
    # via inquirer
redis==5.0.4
    # via lightning
requests==2.31.0
    # via docker
    # via fsspec
    # via lightning
    # via lightning-cloud
    # via pooch
rich==13.7.1
    # via lightning
    # via lightning-cloud
runs==1.2.2
    # via editor
s3fs==2023.12.2
    # via lightning
s3transfer==0.10.1
    # via boto3
scikit-learn==1.3.2
    # via batdetect2
    # via librosa
scipy==1.10.1
    # via batdetect2
    # via librosa
    # via scikit-learn
    # via soundevent
setuptools==69.5.1
    # via lightning-utilities
    # via readchar
    # via tensorboard
shapely==2.0.3
    # via soundevent
six==1.16.0
    # via asttokens
    # via blessed
    # via lightning-cloud
    # via python-dateutil
    # via tensorboard
sniffio==1.3.1
    # via anyio
soundevent==1.3.5
    # via batdetect2
soundfile==0.12.1
    # via librosa
    # via soundevent
soupsieve==2.5
    # via beautifulsoup4
soxr==0.3.7
    # via librosa
stack-data==0.6.3
    # via ipython
starlette==0.37.2
    # via fastapi
    # via lightning
sympy==1.12
    # via torch
tensorboard==2.16.2
    # via batdetect2
tensorboard-data-server==0.7.2
    # via tensorboard
tensorboardx==2.6.2.2
    # via lightning
threadpoolctl==3.4.0
    # via scikit-learn
tomli==2.0.1
    # via pytest
torch==2.2.2
    # via batdetect2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via torchaudio
    # via torchmetrics
    # via torchvision
torchaudio==2.2.2
    # via batdetect2
torchmetrics==1.3.2
    # via lightning
    # via pytorch-lightning
torchvision==0.17.2
    # via batdetect2
tornado==6.4
    # via ipykernel
    # via jupyter-client
tqdm==4.66.2
    # via batdetect2
    # via lightning
    # via pytorch-lightning
traitlets==5.14.2
    # via comm
    # via ipykernel
    # via ipython
    # via jupyter-client
    # via jupyter-core
    # via lightning
    # via matplotlib-inline
triton==2.2.0
    # via torch
types-python-dateutil==2.9.0.20240316
    # via arrow
typeshed-client==2.5.1
    # via jsonargparse
typing-extensions==4.11.0
    # via aioitertools
    # via anyio
    # via fastapi
    # via ipython
    # via jsonargparse
    # via librosa
    # via lightning
    # via lightning-fabric
    # via lightning-utilities
    # via pydantic
    # via pydantic-core
    # via pytorch-lightning
    # via starlette
    # via torch
    # via typeshed-client
    # via uvicorn
tzdata==2024.1
    # via pandas
urllib3==1.26.18
    # via botocore
    # via docker
    # via lightning
    # via lightning-cloud
    # via requests
uvicorn==0.29.0
    # via lightning
    # via lightning-cloud
wcwidth==0.2.13
    # via blessed
    # via prompt-toolkit
websocket-client==1.7.0
    # via docker
    # via lightning
    # via lightning-cloud
websockets==11.0.3
    # via lightning
werkzeug==3.0.2
    # via tensorboard
wrapt==1.16.0
    # via aiobotocore
xarray==2023.1.0
    # via cf-xarray
    # via soundevent
xmod==1.8.1
    # via editor
    # via runs
yarl==1.9.4
    # via aiohttp
zipp==3.18.1
    # via importlib-metadata
    # via importlib-resources
448  requirements.lock  Normal file
@ -0,0 +1,448 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
#   pre: false
#   features: []
#   all-features: false
#   with-sources: false

-e file:.
absl-py==2.1.0
    # via tensorboard
aiobotocore==2.12.3
    # via s3fs
aiohttp==3.9.5
    # via aiobotocore
    # via fsspec
    # via lightning
    # via s3fs
aioitertools==0.11.0
    # via aiobotocore
aiosignal==1.3.1
    # via aiohttp
annotated-types==0.6.0
    # via pydantic
antlr4-python3-runtime==4.9.3
    # via hydra-core
    # via omegaconf
anyio==4.3.0
    # via starlette
arrow==1.3.0
    # via lightning
async-timeout==4.0.3
    # via aiohttp
    # via redis
attrs==23.2.0
    # via aiohttp
audioread==3.0.1
    # via librosa
backoff==2.2.1
    # via lightning
beautifulsoup4==4.12.3
    # via lightning
bitsandbytes==0.41.0
    # via lightning
blessed==1.20.0
    # via inquirer
boto3==1.34.69
    # via lightning-cloud
botocore==1.34.69
    # via aiobotocore
    # via boto3
    # via s3transfer
certifi==2024.2.2
    # via netcdf4
    # via requests
cf-xarray==0.9.0
    # via batdetect2
cffi==1.16.0
    # via soundfile
cftime==1.6.3
    # via netcdf4
charset-normalizer==3.3.2
    # via requests
click==8.1.7
    # via batdetect2
    # via lightning
    # via lightning-cloud
    # via uvicorn
contourpy==1.1.1
    # via matplotlib
croniter==1.4.1
    # via lightning
cycler==0.12.1
    # via matplotlib
cython==3.0.10
    # via soundevent
dateutils==0.6.12
    # via lightning
decorator==5.1.1
    # via librosa
deepdiff==6.7.1
    # via lightning
dnspython==2.6.1
    # via email-validator
docker==6.1.3
    # via lightning
docstring-parser==0.16
    # via jsonargparse
editor==1.6.6
    # via inquirer
email-validator==2.1.1
    # via soundevent
exceptiongroup==1.2.1
    # via anyio
fastapi==0.110.2
    # via lightning
    # via lightning-cloud
filelock==3.13.4
    # via torch
    # via triton
fonttools==4.51.0
    # via matplotlib
frozenlist==1.4.1
    # via aiohttp
    # via aiosignal
fsspec==2023.12.2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via s3fs
    # via torch
grpcio==1.62.2
    # via tensorboard
h11==0.14.0
    # via uvicorn
hydra-core==1.3.2
    # via lightning
idna==3.7
    # via anyio
    # via email-validator
    # via requests
    # via yarl
importlib-metadata==7.1.0
    # via markdown
importlib-resources==6.4.0
    # via matplotlib
    # via typeshed-client
inquirer==3.2.4
    # via lightning
jinja2==3.1.3
    # via lightning
    # via torch
jmespath==1.0.1
    # via boto3
    # via botocore
joblib==1.4.0
    # via librosa
    # via scikit-learn
jsonargparse==4.28.0
    # via lightning
kiwisolver==1.4.5
    # via matplotlib
lazy-loader==0.4
    # via librosa
librosa==0.10.1
    # via batdetect2
lightning==2.2.2
    # via batdetect2
lightning-api-access==0.0.5
    # via lightning
lightning-cloud==0.5.65
    # via lightning
lightning-fabric==2.2.2
    # via lightning
lightning-utilities==0.11.2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via torchmetrics
llvmlite==0.41.1
    # via numba
markdown==3.6
    # via tensorboard
markdown-it-py==3.0.0
    # via rich
markupsafe==2.1.5
    # via jinja2
    # via werkzeug
matplotlib==3.7.5
    # via batdetect2
    # via lightning
    # via soundevent
mdurl==0.1.2
    # via markdown-it-py
mpmath==1.3.0
    # via sympy
msgpack==1.0.8
    # via librosa
multidict==6.0.5
    # via aiohttp
    # via yarl
netcdf4==1.6.5
    # via batdetect2
networkx==3.1
    # via torch
numba==0.58.1
    # via librosa
numpy==1.24.4
    # via batdetect2
    # via cftime
    # via contourpy
    # via librosa
    # via lightning
    # via lightning-fabric
    # via matplotlib
    # via netcdf4
    # via numba
    # via onnx
    # via pandas
    # via pytorch-lightning
    # via scikit-learn
    # via scipy
    # via shapely
    # via soxr
    # via tensorboard
    # via tensorboardx
    # via torchmetrics
    # via torchvision
    # via xarray
nvidia-cublas-cu12==12.1.3.1
    # via nvidia-cudnn-cu12
    # via nvidia-cusolver-cu12
    # via torch
nvidia-cuda-cupti-cu12==12.1.105
    # via torch
nvidia-cuda-nvrtc-cu12==12.1.105
    # via torch
nvidia-cuda-runtime-cu12==12.1.105
    # via torch
nvidia-cudnn-cu12==8.9.2.26
    # via torch
nvidia-cufft-cu12==11.0.2.54
    # via torch
nvidia-curand-cu12==10.3.2.106
    # via torch
nvidia-cusolver-cu12==11.4.5.107
    # via torch
nvidia-cusparse-cu12==12.1.0.106
    # via nvidia-cusolver-cu12
    # via torch
nvidia-nccl-cu12==2.19.3
    # via torch
nvidia-nvjitlink-cu12==12.4.127
    # via nvidia-cusolver-cu12
    # via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
    # via torch
omegaconf==2.3.0
    # via hydra-core
    # via lightning
onnx==1.16.0
    # via batdetect2
ordered-set==4.1.0
    # via deepdiff
packaging==24.0
    # via docker
    # via hydra-core
    # via lazy-loader
    # via lightning
    # via lightning-fabric
    # via lightning-utilities
    # via matplotlib
    # via pooch
    # via pytorch-lightning
    # via tensorboardx
    # via torchmetrics
    # via xarray
pandas==2.0.3
    # via batdetect2
    # via xarray
pillow==10.3.0
    # via matplotlib
    # via torchvision
platformdirs==4.2.0
    # via pooch
pooch==1.8.1
    # via librosa
protobuf==5.26.1
    # via onnx
    # via tensorboard
    # via tensorboardx
psutil==5.9.8
    # via lightning
pycparser==2.22
    # via cffi
pydantic==2.7.0
    # via fastapi
    # via lightning
    # via soundevent
pydantic-core==2.18.1
    # via pydantic
pygments==2.17.2
    # via rich
pyjwt==2.8.0
    # via lightning-cloud
pyparsing==3.1.2
    # via matplotlib
python-dateutil==2.9.0.post0
    # via arrow
    # via botocore
    # via croniter
    # via dateutils
    # via matplotlib
    # via pandas
python-multipart==0.0.9
    # via lightning
    # via lightning-cloud
pytorch-lightning==2.2.2
    # via batdetect2
    # via lightning
pytz==2024.1
    # via dateutils
    # via pandas
pyyaml==6.0.1
    # via jsonargparse
    # via lightning
    # via omegaconf
    # via pytorch-lightning
readchar==4.0.6
    # via inquirer
redis==5.0.4
    # via lightning
requests==2.31.0
    # via docker
    # via fsspec
    # via lightning
    # via lightning-cloud
    # via pooch
rich==13.7.1
    # via lightning
    # via lightning-cloud
runs==1.2.2
    # via editor
s3fs==2023.12.2
    # via lightning
s3transfer==0.10.1
    # via boto3
scikit-learn==1.3.2
    # via batdetect2
    # via librosa
scipy==1.10.1
    # via batdetect2
    # via librosa
    # via scikit-learn
    # via soundevent
setuptools==69.5.1
    # via lightning-utilities
    # via readchar
    # via tensorboard
shapely==2.0.3
    # via soundevent
six==1.16.0
    # via blessed
    # via lightning-cloud
    # via python-dateutil
    # via tensorboard
sniffio==1.3.1
    # via anyio
soundevent==1.3.5
    # via batdetect2
soundfile==0.12.1
    # via librosa
    # via soundevent
soupsieve==2.5
    # via beautifulsoup4
soxr==0.3.7
    # via librosa
starlette==0.37.2
    # via fastapi
    # via lightning
sympy==1.12
    # via torch
tensorboard==2.16.2
    # via batdetect2
tensorboard-data-server==0.7.2
    # via tensorboard
tensorboardx==2.6.2.2
    # via lightning
threadpoolctl==3.4.0
    # via scikit-learn
torch==2.2.2
    # via batdetect2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via torchaudio
    # via torchmetrics
    # via torchvision
torchaudio==2.2.2
    # via batdetect2
torchmetrics==1.3.2
    # via lightning
    # via pytorch-lightning
torchvision==0.17.2
    # via batdetect2
tqdm==4.66.2
    # via batdetect2
    # via lightning
    # via pytorch-lightning
traitlets==5.14.3
    # via lightning
triton==2.2.0
    # via torch
types-python-dateutil==2.9.0.20240316
    # via arrow
typeshed-client==2.5.1
    # via jsonargparse
typing-extensions==4.11.0
    # via aioitertools
    # via anyio
    # via fastapi
    # via jsonargparse
    # via librosa
    # via lightning
    # via lightning-fabric
    # via lightning-utilities
    # via pydantic
    # via pydantic-core
    # via pytorch-lightning
    # via starlette
    # via torch
    # via typeshed-client
    # via uvicorn
tzdata==2024.1
    # via pandas
urllib3==1.26.18
    # via botocore
    # via docker
    # via lightning
    # via lightning-cloud
    # via requests
uvicorn==0.29.0
    # via lightning
    # via lightning-cloud
wcwidth==0.2.13
    # via blessed
websocket-client==1.7.0
    # via docker
    # via lightning
    # via lightning-cloud
websockets==11.0.3
    # via lightning
werkzeug==3.0.2
    # via tensorboard
wrapt==1.16.0
    # via aiobotocore
xarray==2023.1.0
    # via cf-xarray
    # via soundevent
xmod==1.8.1
    # via editor
    # via runs
yarl==1.9.4
    # via aiohttp
zipp==3.18.1
    # via importlib-metadata
    # via importlib-resources
@ -1,10 +0,0 @@
-librosa==0.9.2
-matplotlib==3.6.2
-numpy==1.23.4
-pandas==1.5.2
-scikit_learn==1.2.0
-scipy==1.9.3
-torch==1.13.0
-torchaudio==0.13.0
-torchvision==0.14.0
-click
0  tests/test_data/__init__.py  Normal file
19  tests/test_data/test_batdetect.py  Normal file
@ -0,0 +1,19 @@
"""Test suite for loading batdetect annotations."""

from pathlib import Path

from soundevent import data

from batdetect2.data.compat import load_annotation_project

ROOT_DIR = Path(__file__).parent.parent.parent


def test_load_example_annotation_project():
    path = ROOT_DIR / "example_data" / "anns"
    audio_dir = ROOT_DIR / "example_data" / "audio"
    project = load_annotation_project(path, audio_dir=audio_dir)
    assert isinstance(project, data.AnnotationProject)
    assert project.name == str(path)
    assert len(project.clip_annotations) == 3
    assert len(project.tasks) == 3
@ -133,17 +133,16 @@ def test_compute_max_power_bb(max_power: int):
     audio = np.zeros((int(duration * samplerate),))

     # Add a signal during the time and frequency range of interest
-    audio[
-        int(start_time * samplerate) : int(end_time * samplerate)
-    ] = 0.5 * librosa.tone(
-        max_power, sr=samplerate, duration=end_time - start_time
+    audio[int(start_time * samplerate) : int(end_time * samplerate)] = (
+        0.5
+        * librosa.tone(
+            max_power, sr=samplerate, duration=end_time - start_time
+        )
     )

     # Add a more powerful signal outside frequency range of interest
-    audio[
-        int(start_time * samplerate) : int(end_time * samplerate)
-    ] += 2 * librosa.tone(
-        80_000, sr=samplerate, duration=end_time - start_time
+    audio[int(start_time * samplerate) : int(end_time * samplerate)] += (
+        2 * librosa.tone(80_000, sr=samplerate, duration=end_time - start_time)
     )

     params = api.get_config(
@ -152,7 +151,7 @@ def test_compute_max_power_bb(max_power: int):
         target_samp_rate=samplerate,
     )

-    spec = au.generate_spectrogram(
+    spec, _ = au.generate_spectrogram(
         audio,
         samplerate,
         params,
@ -221,17 +220,17 @@ def test_compute_max_power():
     audio = np.zeros((int(duration * samplerate),))

     # Add a signal during the time and frequency range of interest
-    audio[
-        int(start_time * samplerate) : int(end_time * samplerate)
-    ] = 0.5 * librosa.tone(
-        3_500, sr=samplerate, duration=end_time - start_time
+    audio[int(start_time * samplerate) : int(end_time * samplerate)] = (
+        0.5
+        * librosa.tone(3_500, sr=samplerate, duration=end_time - start_time)
     )

     # Add a more powerful signal outside frequency range of interest
-    audio[
-        int(start_time * samplerate) : int(end_time * samplerate)
-    ] += 2 * librosa.tone(
-        max_power, sr=samplerate, duration=end_time - start_time
+    audio[int(start_time * samplerate) : int(end_time * samplerate)] += (
+        2
+        * librosa.tone(
+            max_power, sr=samplerate, duration=end_time - start_time
+        )
     )

     params = api.get_config(
@ -240,7 +239,7 @@ def test_compute_max_power():
         target_samp_rate=samplerate,
     )

-    spec = au.generate_spectrogram(
+    spec, _ = au.generate_spectrogram(
         audio,
         samplerate,
         params,
0  tests/test_migration/__init__.py  Normal file
120  tests/test_migration/test_preprocessing.py  Normal file
@ -0,0 +1,120 @@
from pathlib import Path

import numpy as np
import pytest
from soundevent import data

from batdetect2.data import preprocessing
from batdetect2.utils import audio_utils

ROOT_DIR = Path(__file__).parent.parent.parent
EXAMPLE_AUDIO = ROOT_DIR / "example_data" / "audio"
TEST_AUDIO = ROOT_DIR / "tests" / "data"


TEST_FILES = [
    EXAMPLE_AUDIO / "20170701_213954-MYOMYS-LR_0_0.5.wav",
    EXAMPLE_AUDIO / "20180530_213516-EPTSER-LR_0_0.5.wav",
    EXAMPLE_AUDIO / "20180627_215323-RHIFER-LR_0_0.5.wav",
    TEST_AUDIO / "20230322_172000_selec2.wav",
]


@pytest.mark.parametrize("audio_file", TEST_FILES)
@pytest.mark.parametrize("scale", [True, False])
def test_audio_loading_hasnt_changed(
    audio_file,
    scale,
):
    time_expansion = 1
    target_sampling_rate = 256_000
    recording = data.Recording.from_file(
        audio_file,
        time_expansion=time_expansion,
    )
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )

    _, audio_original = audio_utils.load_audio(
        audio_file,
        time_expansion,
        target_samp_rate=target_sampling_rate,
        scale=scale,
    )
    audio_new = preprocessing.load_clip_audio(
        clip,
        target_sampling_rate=target_sampling_rate,
        scale=scale,
        dtype=np.float32,
    )

    assert audio_original.shape == audio_new.shape
    assert audio_original.dtype == audio_new.dtype
    assert np.isclose(audio_original, audio_new.data).all()


@pytest.mark.parametrize("audio_file", TEST_FILES)
@pytest.mark.parametrize("spec_scale", ["log", "pcen", "amplitude"])
@pytest.mark.parametrize("denoise_spec_avg", [True, False])
@pytest.mark.parametrize("max_scale_spec", [True, False])
@pytest.mark.parametrize("fft_win_length", [512 / 256_000, 1024 / 256_000])
def test_spectrogram_generation_hasnt_changed(
    audio_file,
    spec_scale,
    denoise_spec_avg,
    max_scale_spec,
    fft_win_length,
):
    time_expansion = 1
    target_sampling_rate = 256_000
    min_freq = 10_000
    max_freq = 120_000
    fft_overlap = 0.75
    recording = data.Recording.from_file(
        audio_file,
        time_expansion=time_expansion,
    )
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    audio = preprocessing.load_clip_audio(
        clip,
        target_sampling_rate=target_sampling_rate,
    )

    spec_original, _ = audio_utils.generate_spectrogram(
        audio.data,
        sampling_rate=target_sampling_rate,
        params=dict(
            fft_win_length=fft_win_length,
            fft_overlap=fft_overlap,
            max_freq=max_freq,
            min_freq=min_freq,
            spec_scale=spec_scale,
            denoise_spec_avg=denoise_spec_avg,
            max_scale_spec=max_scale_spec,
        ),
    )

    new_spec = preprocessing.compute_spectrogram(
        audio,
        fft_win_length=fft_win_length,
        fft_overlap=fft_overlap,
        max_freq=max_freq,
        min_freq=min_freq,
        spec_scale=spec_scale,
        denoise_spec_avg=denoise_spec_avg,
        max_scale_spec=max_scale_spec,
        dtype=np.float32,
    )

    assert spec_original.shape == new_spec.shape
    assert spec_original.dtype == new_spec.dtype

    # NOTE: The original spectrogram is flipped vertically
    assert np.isclose(spec_original, np.flipud(new_spec.data)).all()