mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00

WIP updating training code

This commit is contained in:
parent 343bc5f87c
commit c66d14b7c7

11 batdetect2/cli/__init__.py Normal file
@@ -0,0 +1,11 @@
from batdetect2.cli.base import cli
from batdetect2.cli.compat import detect

__all__ = [
    "cli",
    "detect",
]


if __name__ == "__main__":
    cli()
27 batdetect2/cli/base.py Normal file
@@ -0,0 +1,27 @@
"""BatDetect2 command line interface."""

import os

import click


__all__ = [
    "cli",
]


CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""


@click.group()
def cli():
    """BatDetect2 - Bat Call Detection and Classification."""
    click.echo(INFO_STR)
@@ -1,5 +1,4 @@
"""BatDetect2 command line interface."""
import os

import click

@@ -8,21 +7,7 @@ from batdetect2.detector.parameters import DEFAULT_MODEL_PATH
from batdetect2.types import ProcessingConfiguration
from batdetect2.utils.detector_utils import save_results_to_file

CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))


INFO_STR = """
BatDetect2 - Detection and Classification
Assumes audio files are mono, not stereo.
Spaces in the input paths will throw an error. Wrap in quotes.
Input files should be short in duration e.g. < 30 seconds.
"""


@click.group()
def cli():
    """BatDetect2 - Bat Call Detection and Classification."""
    click.echo(INFO_STR)
from batdetect2.cli.base import cli


@cli.command()
@@ -147,7 +132,3 @@ def print_config(config: ProcessingConfiguration):
    click.echo("\nProcessing Configuration:")
    click.echo(f"Time Expansion Factor: {config.get('time_expansion')}")
    click.echo(f"Detection Threshold: {config.get('detection_threshold')}")


if __name__ == "__main__":
    cli()
0 batdetect2/data/__init__.py Normal file

304 batdetect2/data/augmentations.py Normal file
@@ -0,0 +1,304 @@
from functools import wraps
from typing import Callable, List, Optional, Tuple

import numpy as np
import xarray as xr
from soundevent import data
from soundevent.geometry import compute_bounds

ClipAugmentation = Callable[[data.ClipAnnotation], data.ClipAnnotation]
AudioAugmentation = Callable[
    [xr.DataArray, data.ClipAnnotation],
    Tuple[xr.DataArray, data.ClipAnnotation],
]
SpecAugmentation = Callable[
    [xr.DataArray, data.ClipAnnotation],
    Tuple[xr.DataArray, data.ClipAnnotation],
]

ClipProvider = Callable[
    [data.ClipAnnotation], Tuple[xr.DataArray, data.ClipAnnotation]
]
"""A function that provides some clip and its annotation.

Usually this function loads a random clip from a dataset. It takes
as input a clip annotation that can be used to filter the clips
to load (in case you want to avoid loading the same clip multiple times).
"""


AUGMENTATION_PROBABILITY = 0.2
MAX_DELAY = 0.005
STRETCH_SQUEEZE_DELTA = 0.04
MASK_MAX_TIME_PERC: float = 0.05
MASK_MAX_FREQ_PERC: float = 0.10


def maybe_apply(
    augmentation: Callable,
    prob: float = AUGMENTATION_PROBABILITY,
) -> Callable:
    """Apply an augmentation with a given probability."""

    @wraps(augmentation)
    def _augmentation(x):
        if np.random.rand() > prob:
            return x
        return augmentation(x)

    return _augmentation


def select_random_subclip(
    clip_annotation: data.ClipAnnotation,
    duration: Optional[float] = None,
    proportion: float = 0.9,
) -> data.ClipAnnotation:
    """Select a random subclip from a clip."""
    clip = clip_annotation.clip

    if duration is None:
        clip_duration = clip.end_time - clip.start_time
        duration = clip_duration * proportion

    start_time = np.random.uniform(clip.start_time, clip.end_time - duration)
    return clip_annotation.model_copy(
        update=dict(
            clip=clip.model_copy(
                update=dict(
                    start_time=start_time,
                    end_time=start_time + duration,
                )
            )
        )
    )


def combine_audio(
    audio1: xr.DataArray,
    audio2: xr.DataArray,
    alpha: Optional[float] = None,
    min_alpha: float = 0.3,
    max_alpha: float = 0.7,
) -> xr.DataArray:
    """Combine two audio clips."""

    if alpha is None:
        alpha = np.random.uniform(min_alpha, max_alpha)

    return alpha * audio1 + (1 - alpha) * audio2.data


def random_mix(
    audio: xr.DataArray,
    clip: data.ClipAnnotation,
    provider: Optional[ClipProvider] = None,
    alpha: Optional[float] = None,
    min_alpha: float = 0.3,
    max_alpha: float = 0.7,
    join_annotations: bool = True,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Mix two audio clips."""
    if provider is None:
        raise ValueError("No audio provider given.")

    try:
        other_audio, other_clip = provider(clip)
    except (StopIteration, ValueError):
        raise ValueError("No more audio sources available.")

    new_audio = combine_audio(
        audio,
        other_audio,
        alpha=alpha,
        min_alpha=min_alpha,
        max_alpha=max_alpha,
    )

    if join_annotations:
        clip = clip.model_copy(
            update=dict(
                sound_events=clip.sound_events + other_clip.sound_events,
            )
        )

    return new_audio, clip


def add_echo(
    audio: xr.DataArray,
    clip: data.ClipAnnotation,
    delay: Optional[float] = None,
    alpha: Optional[float] = None,
    min_alpha: float = 0.0,
    max_alpha: float = 1.0,
    max_delay: float = MAX_DELAY,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Add a delay to the audio."""
    if delay is None:
        delay = np.random.uniform(0, max_delay)

    if alpha is None:
        alpha = np.random.uniform(min_alpha, max_alpha)

    samplerate = audio.attrs["samplerate"]
    offset = int(delay * samplerate)

    # NOTE: We use the copy method to avoid modifying the original audio
    # data.
    new_audio = audio.copy()
    if offset > 0:
        # Guard against a zero offset: audio.data[:-0] is an empty slice,
        # which would make the in-place addition fail silently.
        new_audio[offset:] += alpha * audio.data[:-offset]
    return new_audio, clip


def scale_volume(
    spec: xr.DataArray,
    clip: data.ClipAnnotation,
    factor: Optional[float] = None,
    max_scaling: float = 2,
    min_scaling: float = 0,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Scale the volume of a spectrogram."""
    if factor is None:
        factor = np.random.uniform(min_scaling, max_scaling)

    return spec * factor, clip


def scale_sound_event_annotation(
    sound_event_annotation: data.SoundEventAnnotation,
    time_factor: float = 1,
    frequency_factor: float = 1,
) -> data.SoundEventAnnotation:
    sound_event = sound_event_annotation.sound_event
    geometry = sound_event.geometry

    if geometry is None:
        return sound_event_annotation

    start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
    new_geometry = data.BoundingBox(
        coordinates=[
            start_time * time_factor,
            low_freq * frequency_factor,
            end_time * time_factor,
            high_freq * frequency_factor,
        ]
    )

    return sound_event_annotation.model_copy(
        update=dict(
            sound_event=sound_event.model_copy(
                update=dict(
                    geometry=new_geometry,
                )
            )
        )
    )


def warp_spectrogram(
    spec: xr.DataArray,
    clip: data.ClipAnnotation,
    factor: Optional[float] = None,
    delta: float = STRETCH_SQUEEZE_DELTA,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Warp a spectrogram."""
    if factor is None:
        factor = np.random.uniform(1 - delta, 1 + delta)

    start_time = clip.clip.start_time
    end_time = clip.clip.end_time
    duration = end_time - start_time
    new_time = np.linspace(
        start_time,
        start_time + duration * factor,
        spec.time.size,
    )

    scaled_spec = spec.interp(
        time=new_time,
        method="linear",
        kwargs={"fill_value": 0},
    )
    scaled_spec.coords["time"] = spec.time

    scaled_clip = clip.model_copy(
        update=dict(
            sound_events=[
                scale_sound_event_annotation(
                    sound_event_annotation,
                    time_factor=1 / factor,
                )
                for sound_event_annotation in clip.sound_events
            ]
        )
    )
    return scaled_spec, scaled_clip


def mask_axis(
    array: xr.DataArray,
    axis: str,
    start: float,
    end: float,
    mask_value: float = 0,
) -> xr.DataArray:
    if axis not in array.dims:
        raise ValueError(f"Axis {axis} not found in array")

    coord = array[axis]
    return array.where((coord < start) | (coord > end), mask_value)


def mask_time(
    spec: xr.DataArray,
    clip: data.ClipAnnotation,
    max_time_mask: float = MASK_MAX_TIME_PERC,
    max_num_masks: int = 3,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Mask a random section of the time axis."""

    num_masks = np.random.randint(1, max_num_masks + 1)
    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_time_mask)
        start = np.random.uniform(0, spec.time[-1] - mask_size)
        end = start + mask_size
        spec = mask_axis(spec, "time", start, end)

    return spec, clip


def mask_frequency(
    spec: xr.DataArray,
    clip: data.ClipAnnotation,
    max_freq_mask: float = MASK_MAX_FREQ_PERC,
    max_num_masks: int = 3,
) -> Tuple[xr.DataArray, data.ClipAnnotation]:
    """Mask a random section of the frequency axis."""

    num_masks = np.random.randint(1, max_num_masks + 1)
    for _ in range(num_masks):
        mask_size = np.random.uniform(0, max_freq_mask)
        start = np.random.uniform(0, spec.frequency[-1] - mask_size)
        end = start + mask_size
        spec = mask_axis(spec, "frequency", start, end)

    return spec, clip


CLIP_AUGMENTATIONS: List[ClipAugmentation] = [
    select_random_subclip,
]

AUDIO_AUGMENTATIONS: List[AudioAugmentation] = [
    add_echo,
    random_mix,
]

SPEC_AUGMENTATIONS: List[SpecAugmentation] = [
    scale_volume,
    warp_spectrogram,
    mask_time,
    mask_frequency,
]
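A quick way to exercise the new module (a sketch, not part of the commit; it assumes a spectrogram with "time"/"frequency" dims, e.g. from batdetect2.data.preprocessing, and a soundevent ClipAnnotation named clip_annotation loaded elsewhere):

    import numpy as np
    import xarray as xr

    from batdetect2.data import augmentations as aug

    # Toy spectrogram standing in for a real preprocessed clip.
    spec = xr.DataArray(
        np.random.rand(128, 256),
        dims=["frequency", "time"],
        coords={
            "frequency": np.linspace(10_000, 120_000, 128),
            "time": np.linspace(0, 1, 256),
        },
    )

    # Spectrogram augmentations take and return (spec, annotation) pairs,
    # so they chain directly.
    spec, clip_annotation = aug.scale_volume(spec, clip_annotation)
    spec, clip_annotation = aug.mask_time(spec, clip_annotation)
    spec, clip_annotation = aug.mask_frequency(spec, clip_annotation)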
332 batdetect2/data/compat.py Normal file
@@ -0,0 +1,332 @@
"""Compatibility functions between old and new data structures."""

import os
import uuid
from pathlib import Path
from typing import Callable, List, Optional, Union

import numpy as np
from pydantic import BaseModel, Field
from soundevent import data
from soundevent.geometry import compute_bounds

from batdetect2 import types
from batdetect2.data.labels import LabelFn

PathLike = Union[Path, str, os.PathLike]

__all__ = [
    "convert_to_annotation_group",
    "load_annotation_project",
]

SPECIES_TAG_KEY = "species"
ECHOLOCATION_EVENT = "Echolocation"
UNKNOWN_CLASS = "__UNKNOWN__"

NAMESPACE = uuid.UUID("97a9776b-c0fd-4c68-accb-0b0ecd719242")


EventFn = Callable[[data.SoundEventAnnotation], Optional[str]]

ClassFn = Callable[[data.Recording], int]

IndividualFn = Callable[[data.SoundEventAnnotation], int]


def get_recording_class_name(recording: data.Recording) -> str:
    """Get the class name for a recording."""
    tag = data.find_tag(recording.tags, SPECIES_TAG_KEY)
    if tag is None:
        return UNKNOWN_CLASS
    return tag.value


def get_annotation_notes(annotation: data.ClipAnnotation) -> str:
    """Get the notes for a ClipAnnotation."""
    all_notes = [
        *annotation.notes,
        *annotation.clip.recording.notes,
    ]
    messages = [note.message for note in all_notes if note.message is not None]
    return "\n".join(messages)


def convert_to_annotation_group(
    annotation: data.ClipAnnotation,
    label_fn: LabelFn = lambda _: None,
    event_fn: EventFn = lambda _: ECHOLOCATION_EVENT,
    class_fn: ClassFn = lambda _: 0,
    individual_fn: IndividualFn = lambda _: 0,
) -> types.AudioLoaderAnnotationGroup:
    """Convert a ClipAnnotation to an AudioLoaderAnnotationGroup."""
    recording = annotation.clip.recording

    start_times = []
    end_times = []
    low_freqs = []
    high_freqs = []
    class_ids = []
    x_inds = []
    y_inds = []
    individual_ids = []
    annotations: List[types.Annotation] = []
    class_id_file = class_fn(recording)

    for sound_event in annotation.sound_events:
        geometry = sound_event.sound_event.geometry

        if geometry is None:
            continue

        start_time, low_freq, end_time, high_freq = compute_bounds(geometry)
        class_id = label_fn(sound_event) or -1
        event = event_fn(sound_event)
        individual_id = individual_fn(sound_event) or -1

        start_times.append(start_time)
        end_times.append(end_time)
        low_freqs.append(low_freq)
        high_freqs.append(high_freq)
        class_ids.append(class_id)
        individual_ids.append(individual_id)

        # NOTE: This will be computed later so we just put a placeholder
        # here for now.
        x_inds.append(0)
        y_inds.append(0)

        annotations.append(
            {
                "start_time": start_time,
                "end_time": end_time,
                "low_freq": low_freq,
                "high_freq": high_freq,
                "class_prob": 1.0,
                "det_prob": 1.0,
                "individual": "0",
                "event": event,
                "class_id": class_id,  # type: ignore
            }
        )

    return {
        "id": str(recording.path),
        "duration": recording.duration,
        "issues": False,
        "file_path": str(recording.path),
        "time_exp": recording.time_expansion,
        "class_name": get_recording_class_name(recording),
        "notes": get_annotation_notes(annotation),
        "annotated": True,
        "start_times": np.array(start_times),
        "end_times": np.array(end_times),
        "low_freqs": np.array(low_freqs),
        "high_freqs": np.array(high_freqs),
        "class_ids": np.array(class_ids),
        "x_inds": np.array(x_inds),
        "y_inds": np.array(y_inds),
        "individual_ids": np.array(individual_ids),
        "annotation": annotations,
        "class_id_file": class_id_file,
    }


class Annotation(BaseModel):
    """Annotation class to hold batdetect annotations."""

    label: str = Field(alias="class")
    event: str
    individual: int = 0

    start_time: float
    end_time: float
    low_freq: float
    high_freq: float


class FileAnnotation(BaseModel):
    """FileAnnotation class to hold batdetect annotations for a file."""

    id: str
    duration: float
    time_exp: float = 1

    label: str = Field(alias="class_name")

    annotation: List[Annotation]

    annotated: bool = False
    issues: bool = False
    notes: str = ""


def load_file_annotation(path: PathLike) -> FileAnnotation:
    """Load annotation from batdetect format."""
    path = Path(path)
    return FileAnnotation.model_validate_json(path.read_text())


def annotation_to_sound_event(
    annotation: Annotation,
    recording: data.Recording,
    label_key: str = "class",
    event_key: str = "event",
    individual_key: str = "individual",
) -> data.SoundEventAnnotation:
    """Convert annotation to sound event annotation."""
    sound_event = data.SoundEvent(
        uuid=uuid.uuid5(
            NAMESPACE,
            f"{recording.hash}_{annotation.start_time}_{annotation.end_time}",
        ),
        recording=recording,
        geometry=data.BoundingBox(
            coordinates=[
                annotation.start_time,
                annotation.low_freq,
                annotation.end_time,
                annotation.high_freq,
            ],
        ),
    )

    return data.SoundEventAnnotation(
        uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
        sound_event=sound_event,
        tags=[
            data.Tag(key=label_key, value=annotation.label),
            data.Tag(key=event_key, value=annotation.event),
            data.Tag(key=individual_key, value=str(annotation.individual)),
        ],
    )


def file_annotation_to_clip(
    file_annotation: FileAnnotation,
    audio_dir: PathLike = Path.cwd(),
) -> data.Clip:
    """Convert file annotation to clip."""
    full_path = Path(audio_dir) / file_annotation.id

    if not full_path.exists():
        raise FileNotFoundError(f"File {full_path} not found.")

    recording = data.Recording.from_file(
        full_path,
        time_expansion=file_annotation.time_exp,
    )

    return data.Clip(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip"),
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )


def file_annotation_to_clip_annotation(
    file_annotation: FileAnnotation,
    clip: data.Clip,
    label_key: str = "class",
    event_key: str = "event",
    individual_key: str = "individual",
) -> data.ClipAnnotation:
    """Convert file annotation to clip annotation."""
    notes = []
    if file_annotation.notes:
        notes.append(data.Note(message=file_annotation.notes))

    return data.ClipAnnotation(
        uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
        clip=clip,
        notes=notes,
        tags=[data.Tag(key=label_key, value=file_annotation.label)],
        sound_events=[
            annotation_to_sound_event(
                annotation,
                clip.recording,
                label_key=label_key,
                event_key=event_key,
                individual_key=individual_key,
            )
            for annotation in file_annotation.annotation
        ],
    )


def file_annotation_to_annotation_task(
    file_annotation: FileAnnotation,
    clip: data.Clip,
) -> data.AnnotationTask:
    status_badges = []

    if file_annotation.issues:
        status_badges.append(
            data.StatusBadge(state=data.AnnotationState.rejected)
        )
    elif file_annotation.annotated:
        status_badges.append(
            data.StatusBadge(state=data.AnnotationState.completed)
        )

    return data.AnnotationTask(
        uuid=uuid.uuid5(uuid.NAMESPACE_URL, f"{file_annotation.id}_task"),
        clip=clip,
        status_badges=status_badges,
    )


def list_file_annotations(path: PathLike) -> List[Path]:
    """List all annotations in a directory."""
    path = Path(path)
    return [file for file in path.glob("*.json")]


def load_annotation_project(
    path: PathLike,
    name: Optional[str] = None,
    audio_dir: PathLike = Path.cwd(),
) -> data.AnnotationProject:
    """Convert annotations to annotation project."""
    paths = list_file_annotations(path)

    if name is None:
        name = str(path)

    annotations = []
    tasks = []

    for p in paths:
        try:
            file_annotation = load_file_annotation(p)
        except FileNotFoundError:
            continue

        try:
            clip = file_annotation_to_clip(
                file_annotation,
                audio_dir=audio_dir,
            )
        except FileNotFoundError:
            continue

        annotations.append(
            file_annotation_to_clip_annotation(
                file_annotation,
                clip,
            )
        )

        tasks.append(
            file_annotation_to_annotation_task(
                file_annotation,
                clip,
            )
        )

    return data.AnnotationProject(
        name=name,
        clip_annotations=annotations,
        tasks=tasks,
    )
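For context, the intended entry point here is load_annotation_project; a minimal sketch (hypothetical directory names):

    from batdetect2.data.compat import load_annotation_project

    # Hypothetical layout: one legacy-format JSON per recording in
    # "annotations/", with the referenced WAV files under "audio/".
    project = load_annotation_project("annotations/", audio_dir="audio/")
    print(project.name, len(project.clip_annotations), len(project.tasks))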
58 batdetect2/data/datasets.py Normal file
@@ -0,0 +1,58 @@
from typing import Callable, Generic, Iterable, List, TypeVar

from soundevent import data
from torch.utils.data import Dataset

__all__ = [
    "ClipAnnotationDataset",
    "ClipDataset",
]


E = TypeVar("E")


class ClipAnnotationDataset(Dataset, Generic[E]):

    clip_annotations: List[data.ClipAnnotation]

    transform: Callable[[data.ClipAnnotation], E]

    def __init__(
        self,
        clip_annotations: Iterable[data.ClipAnnotation],
        transform: Callable[[data.ClipAnnotation], E],
        name: str = "ClipAnnotationDataset",
    ):
        self.clip_annotations = list(clip_annotations)
        self.transform = transform
        self.name = name

    def __len__(self) -> int:
        return len(self.clip_annotations)

    def __getitem__(self, idx: int) -> E:
        return self.transform(self.clip_annotations[idx])


class ClipDataset(Dataset, Generic[E]):

    clips: List[data.Clip]

    transform: Callable[[data.Clip], E]

    def __init__(
        self,
        clips: Iterable[data.Clip],
        transform: Callable[[data.Clip], E],
        name: str = "ClipDataset",
    ):
        self.clips = list(clips)
        self.transform = transform
        self.name = name

    def __len__(self) -> int:
        return len(self.clips)

    def __getitem__(self, idx: int) -> E:
        return self.transform(self.clips[idx])
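Both dataset classes defer all work to the transform, so any callable mapping an annotation (or clip) to a training example fits; a sketch building on the compat loader above (hypothetical paths, trivial transform for illustration):

    from batdetect2.data.compat import load_annotation_project
    from batdetect2.data.datasets import ClipAnnotationDataset

    project = load_annotation_project("annotations/", audio_dir="audio/")

    # A trivial transform: count the annotated sound events per clip.
    dataset = ClipAnnotationDataset(
        project.clip_annotations,
        transform=lambda ann: len(ann.sound_events),
    )
    print(len(dataset), dataset[0])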
231 batdetect2/data/labels.py Normal file
@@ -0,0 +1,231 @@
from typing import Any, Callable, List, Optional, Tuple, Union

import numpy as np
import xarray as xr
from scipy.ndimage import gaussian_filter
from soundevent import data, geometry

__all__ = [
    "generate_heatmaps",
]

PositionFn = Callable[[data.SoundEvent], Tuple[float, float]]
"""Convert a sound event to a single position in time-frequency space."""

SizeFn = Callable[[data.SoundEvent, float, float], np.ndarray]
"""Compute the size of a sound event in time-frequency space.

The time and frequency scales are provided as arguments to allow
modifying the size of the sound event based on the spectrogram
parameters.
"""

LabelFn = Callable[[data.SoundEventAnnotation], Optional[str]]
"""Convert a sound event annotation to a label.

When the label is None, this indicates that the sound event does not
belong to any of the classes of interest.
"""

TARGET_SIGMA = 3.0


GENERIC_LABEL = "__UNKNOWN__"


def get_lower_left_position(
    sound_event: data.SoundEvent,
) -> Tuple[float, float]:
    if sound_event.geometry is None:
        raise ValueError("Sound event has no geometry.")

    start_time, low_freq, _, _ = geometry.compute_bounds(sound_event.geometry)
    return start_time, low_freq


def get_bbox_size(
    sound_event: data.SoundEvent,
    time_scale: float = 1.0,
    frequency_scale: float = 1.0,
) -> np.ndarray:
    if sound_event.geometry is None:
        raise ValueError("Sound event has no geometry.")

    start_time, low_freq, end_time, high_freq = geometry.compute_bounds(
        sound_event.geometry
    )

    return np.array(
        [
            time_scale * (end_time - start_time),
            frequency_scale * (high_freq - low_freq),
        ]
    )


def _tag_key(tag: data.Tag) -> Tuple[str, str]:
    return (tag.key, tag.value)


def set_value_at_position(
    array: xr.DataArray,
    value: Any,
    **query,
) -> xr.DataArray:
    dims = {dim: n for n, dim in enumerate(array.dims)}
    indexer: List[Union[slice, int]] = [slice(None) for _ in range(array.ndim)]

    for key, coord in query.items():
        if key not in dims:
            raise ValueError(f"Dimension {key} not found in array.")

        coordinates = array.indexes[key]
        indexer[dims[key]] = coordinates.get_loc(coordinates.asof(coord))

    if isinstance(value, (tuple, list)):
        value = np.array(value)

    array.data[tuple(indexer)] = value
    return array


def generate_heatmaps(
    clip_annotation: data.ClipAnnotation,
    spec: xr.DataArray,
    num_classes: int = 1,
    label_fn: LabelFn = lambda _: None,
    target_sigma: float = TARGET_SIGMA,
    size_fn: SizeFn = get_bbox_size,
    position_fn: PositionFn = get_lower_left_position,
    class_labels: Optional[List[str]] = None,
    dtype=np.float32,
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
    if class_labels is None:
        class_labels = [str(i) for i in range(num_classes)]

    if len(class_labels) != num_classes:
        raise ValueError(
            "Number of class labels must match the number of classes."
        )

    shape = dict(zip(spec.dims, spec.shape))

    if "time" not in shape or "frequency" not in shape:
        raise ValueError(
            "Spectrogram must have time and frequency dimensions."
        )

    time_duration = spec.time.attrs["max"] - spec.time.attrs["min"]
    freq_bandwidth = spec.frequency.attrs["max"] - spec.frequency.attrs["min"]

    # Compute the size factors
    time_scale = 1 / time_duration
    frequency_scale = 1 / freq_bandwidth

    # Initialize heatmaps
    detection_heatmap = xr.zeros_like(spec, dtype=dtype)
    class_heatmap = xr.DataArray(
        data=np.zeros((num_classes, *spec.shape), dtype=dtype),
        dims=["category", *spec.dims],
        coords={
            "category": class_labels,
            **spec.coords,
        },
    )
    size_heatmap = xr.DataArray(
        data=np.zeros((2, *spec.shape), dtype=dtype),
        dims=["dimension", *spec.dims],
        coords={
            "dimension": ["width", "height"],
            **spec.coords,
        },
    )

    for sound_event_annotation in clip_annotation.sound_events:
        # Get the position of the sound event
        time, frequency = position_fn(sound_event_annotation.sound_event)

        # Set 1.0 at the position of the sound event in the detection heatmap
        detection_heatmap = set_value_at_position(
            detection_heatmap,
            1.0,
            time=time,
            frequency=frequency,
        )

        # Set the size of the sound event at the position in the size heatmap
        size = size_fn(
            sound_event_annotation.sound_event,
            time_scale,
            frequency_scale,
        )
        size_heatmap = set_value_at_position(
            size_heatmap,
            size,
            time=time,
            frequency=frequency,
        )

        # Get the label id for the sound event
        label = label_fn(sound_event_annotation)

        if label is None or label not in class_labels:
            # If the label is None or not in the class labels, we skip the
            # sound event
            continue

        # Set 1.0 at the position and category of the sound event in the class
        # heatmap
        class_heatmap = set_value_at_position(
            class_heatmap,
            1.0,
            time=time,
            frequency=frequency,
            category=label,
        )

    # Apply gaussian filters
    detection_heatmap = xr.apply_ufunc(
        gaussian_filter,
        detection_heatmap,
        target_sigma,
    )

    class_heatmap = class_heatmap.groupby("category").map(
        gaussian_filter,  # type: ignore
        args=(target_sigma,),
    )

    # Normalize heatmaps
    detection_heatmap = (
        detection_heatmap / detection_heatmap.max(dim=["time", "frequency"])
    ).fillna(0.0)

    class_heatmap = (
        class_heatmap / class_heatmap.max(dim=["time", "frequency"])
    ).fillna(0.0)

    return detection_heatmap, class_heatmap, size_heatmap


class Labeler:
    def __init__(self, tags: List[data.Tag]):
        """Create a labeler from a list of tags.

        Each tag is assigned a unique label. The labeler can then be used
        to convert sound event annotations to labels.
        """
        self.tags = tags
        self._label_map = {_tag_key(tag): i for i, tag in enumerate(tags)}
        self._inverse_label_map = {v: k for k, v in self._label_map.items()}

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> Optional[int]:
        for tag in sound_event_annotation.tags:
            key = _tag_key(tag)
            if key in self._label_map:
                return self._label_map[key]

        return None
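A sketch of driving the heatmap generator (assumes a spectrogram produced by preprocess_audio_clip, whose time/frequency coords carry the required "min"/"max" attrs, plus a loaded clip_annotation; the class names are hypothetical):

    from batdetect2.data.labels import generate_heatmaps

    class_labels = ["Myotis daubentonii", "Pipistrellus pipistrellus"]

    detection, classes, sizes = generate_heatmaps(
        clip_annotation,
        spec,
        num_classes=len(class_labels),
        class_labels=class_labels,
        # label_fn must return one of class_labels (or None to skip);
        # here the first "class" tag on the sound event is used.
        label_fn=lambda ann: next(
            (t.value for t in ann.tags if t.key == "class"), None
        ),
    )
    print(detection.shape, classes.shape, sizes.shape)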
586 batdetect2/data/preprocessing.py Normal file
@@ -0,0 +1,586 @@
"""Module containing functions for preprocessing audio clips."""

import random
from typing import List, Optional, Tuple

import librosa
import librosa.core.spectrum
import numpy as np
import xarray as xr
from numpy.typing import DTypeLike
from scipy.signal import resample_poly
from soundevent import audio, data

TARGET_SAMPLERATE_HZ = 256000
SCALE_RAW_AUDIO = False
FFT_WIN_LENGTH_S = 512 / 256000.0
FFT_OVERLAP = 0.75
MAX_FREQ_HZ = 120000
MIN_FREQ_HZ = 10000
DEFAULT_DURATION = 1
SPEC_HEIGHT = 128
SPEC_WIDTH = 256
SPEC_SCALE = "pcen"
SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
DENOISE_SPEC_AVG = True
MAX_SCALE_SPEC = False


def preprocess_audio_clip(
    clip: data.Clip,
    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
    scale_audio: bool = SCALE_RAW_AUDIO,
    fft_win_length: float = FFT_WIN_LENGTH_S,
    fft_overlap: float = FFT_OVERLAP,
    max_freq: int = MAX_FREQ_HZ,
    min_freq: int = MIN_FREQ_HZ,
    spec_scale: str = SPEC_SCALE,
    denoise_spec_avg: bool = True,
    max_scale_spec: bool = False,
    duration: Optional[float] = DEFAULT_DURATION,
    spec_height: int = SPEC_HEIGHT,
    spec_time_period: float = SPEC_TIME_PERIOD,
) -> xr.DataArray:
    """Preprocess an audio clip to generate a spectrogram.

    Parameters
    ----------
    clip
        The audio clip to preprocess.
    target_sampling_rate
        Target sampling rate for the audio. If the audio has a different
        sampling rate, it will be resampled to this rate.
    scale_audio
        Whether to scale the audio amplitudes to a range of [-1, 1].
        By default, the audio is not scaled.
    fft_win_length
        Length of the FFT window in seconds.
    fft_overlap
        Amount of overlap between FFT windows as a fraction of the window
        length.
    max_freq
        Maximum frequency for spectrogram. Anything above this frequency will
        be cropped.
    min_freq
        Minimum frequency for spectrogram. Anything below this frequency will
        be cropped.
    spec_scale
        Scaling method for the spectrogram. Can be "pcen", "log" or
        "amplitude".
    denoise_spec_avg
        Whether to denoise the spectrogram. Denoising is done by subtracting
        the average of the spectrogram from the spectrogram and clipping
        negative values to 0.
    max_scale_spec
        Whether to max scale the spectrogram. Max scaling is done by dividing
        the spectrogram by its maximum value thus scaling values to [0, 1].
    duration
        Duration of the spectrogram in seconds. If the clip duration is
        different from this value, the spectrogram will be cropped or extended
        to match this duration. If None, the spectrogram will have the same
        duration as the clip.
    spec_height
        Number of frequency bins for the spectrogram. This is the height of
        the final spectrogram.
    spec_time_period
        Time period for each spectrogram bin in seconds. The spectrogram array
        will be resized (using bilinear interpolation) to have this time
        period.

    Returns
    -------
    xr.DataArray
        Preprocessed spectrogram.

    """
    wav = load_clip_audio(
        clip,
        target_sampling_rate=target_sampling_rate,
        scale=scale_audio,
    )

    wav = wav.assign_attrs(
        recording_id=str(wav.attrs["recording_id"]),
        clip_id=str(wav.attrs["clip_id"]),
        path=str(wav.attrs["path"]),
    )

    spec = compute_spectrogram(
        wav,
        fft_win_length=fft_win_length,
        fft_overlap=fft_overlap,
        max_freq=max_freq,
        min_freq=min_freq,
        spec_scale=spec_scale,
        denoise_spec_avg=denoise_spec_avg,
        max_scale_spec=max_scale_spec,
    )

    if duration is not None:
        spec = adjust_spec_duration(clip, spec, duration)

    duration = get_dim_width(spec, dim="time")
    return resize_spectrogram(
        spec,
        time_bins=int(np.ceil(duration / spec_time_period)),
        freq_bins=spec_height,
    )


def adjust_spec_duration(
    clip: data.Clip,
    spec: xr.DataArray,
    duration: float,
) -> xr.DataArray:
    current_duration = clip.end_time - clip.start_time

    if current_duration == duration:
        return spec

    if current_duration > duration:
        return crop_axis(
            spec,
            dim="time",
            start=clip.start_time,
            end=clip.start_time + duration,
        )

    return extend_axis(
        spec,
        dim="time",
        start=clip.start_time,
        end=clip.start_time + duration,
    )


def load_clip_audio(
    clip: data.Clip,
    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
    scale: bool = SCALE_RAW_AUDIO,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    wav = audio.load_clip(clip).sel(channel=0)

    wav = resample_audio(wav, target_sampling_rate, dtype=dtype)

    if scale:
        wav = scale_audio(wav)

    wav.coords["time"] = wav.time.assign_attrs(
        unit="s",
        long_name="Seconds since start of recording",
        min=clip.start_time,
        max=clip.end_time,
    )

    return wav


def resample_audio(
    wav: xr.DataArray,
    target_samplerate: int = TARGET_SAMPLERATE_HZ,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if "samplerate" not in wav.attrs:
        raise ValueError("Audio must have a 'samplerate' attribute")

    if "time" not in wav.dims:
        raise ValueError("Audio must have a time dimension")

    time_axis: int = wav.get_axis_num("time")  # type: ignore
    original_samplerate = wav.attrs["samplerate"]

    if original_samplerate == target_samplerate:
        return wav.astype(dtype)

    gcd = np.gcd(original_samplerate, target_samplerate)
    resampled = resample_poly(
        wav.values,
        target_samplerate // gcd,
        original_samplerate // gcd,
        axis=time_axis,
    )

    resampled_times = np.linspace(
        wav.time[0],
        wav.time[-1],
        len(resampled),
        endpoint=False,
        dtype=dtype,
    )

    return xr.DataArray(
        data=resampled.astype(dtype),
        dims=wav.dims,
        coords={
            **wav.coords,
            "time": resampled_times,
        },
        attrs={
            **wav.attrs,
            "samplerate": target_samplerate,
        },
    )


def scale_audio(
    audio: xr.DataArray,
    eps: float = 10e-6,
) -> xr.DataArray:
    audio = audio - audio.mean()
    return audio / np.add(np.abs(audio).max(), eps, dtype=audio.dtype)


def compute_spectrogram(
    wav: xr.DataArray,
    fft_win_length: float = FFT_WIN_LENGTH_S,
    fft_overlap: float = FFT_OVERLAP,
    max_freq: int = MAX_FREQ_HZ,
    min_freq: int = MIN_FREQ_HZ,
    spec_scale: str = SPEC_SCALE,
    denoise_spec_avg: bool = True,
    max_scale_spec: bool = False,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    spec = gen_mag_spectrogram(
        wav,
        window_len=fft_win_length,
        overlap_perc=fft_overlap,
        dtype=dtype,
    )

    spec = crop_axis(
        spec,
        dim="frequency",
        start=min_freq,
        end=max_freq,
    )

    spec = scale_spectrogram(spec, scale=spec_scale)

    if denoise_spec_avg:
        spec = denoise_spectrogram(spec)

    if max_scale_spec:
        spec = max_scale_spectrogram(spec)

    return spec


def crop_axis(
    arr: xr.DataArray,
    dim: str,
    start: float,
    end: float,
    right_closed: bool = False,
    left_closed: bool = True,
    eps: float = 10e-6,
) -> xr.DataArray:
    coord = arr.coords[dim]

    if not all(attr in coord.attrs for attr in ["min", "max"]):
        raise ValueError(
            f"Coordinate '{dim}' must have 'min' and 'max' attributes"
        )

    current_min = coord.attrs["min"]
    current_max = coord.attrs["max"]

    if start < current_min or end > current_max:
        raise ValueError(
            f"Cannot select axis '{dim}' from {start} to {end}. "
            f"Axis range is {current_min} to {current_max}"
        )

    slice_end = end
    if not right_closed:
        slice_end = end - eps

    slice_start = start
    if not left_closed:
        slice_start = start + eps

    arr = arr.sel({dim: slice(slice_start, slice_end)})

    arr.coords[dim].attrs.update(
        min=start,
        max=end,
    )

    return arr


def extend_axis(
    arr: xr.DataArray,
    dim: str,
    start: float,
    end: float,
    fill_value: float = 0,
) -> xr.DataArray:
    coord = arr.coords[dim]

    if not all(attr in coord.attrs for attr in ["min", "max", "period"]):
        raise ValueError(
            f"Coordinate '{dim}' must have 'min', 'max' and 'period' attributes"
            " to extend axis"
        )

    current_min = coord.attrs["min"]
    current_max = coord.attrs["max"]
    period = coord.attrs["period"]

    coords = coord.data

    if start < current_min:
        new_coords = np.arange(
            current_min,
            start,
            -period,
            dtype=coord.dtype,
        )[1:][::-1]
        coords = np.concatenate([new_coords, coords])

    if end > current_max:
        new_coords = np.arange(
            current_max,
            end,
            period,
            dtype=coord.dtype,
        )[1:]
        coords = np.concatenate([coords, new_coords])

    arr = arr.reindex(
        {dim: coords},
        fill_value=fill_value,  # type: ignore
    )

    arr.coords[dim].attrs.update(
        min=start,
        max=end,
    )

    return arr


def gen_mag_spectrogram(
    audio: xr.DataArray,
    window_len: float,
    overlap_perc: float,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    sampling_rate = audio.attrs["samplerate"]
    hop_len = window_len * (1 - overlap_perc)
    nfft = int(window_len * sampling_rate)
    noverlap = int(overlap_perc * nfft)
    start_time = audio.time.attrs["min"]
    end_time = audio.time.attrs["max"]

    # compute spec
    spec, _ = librosa.core.spectrum._spectrogram(
        y=audio.data,
        power=1,
        n_fft=nfft,
        hop_length=nfft - noverlap,
        center=False,
    )

    spec = xr.DataArray(
        data=spec.astype(dtype),
        dims=["frequency", "time"],
        coords={
            "frequency": np.linspace(
                0,
                sampling_rate / 2,
                spec.shape[0],
                endpoint=False,
                dtype=dtype,
            ),
            "time": np.linspace(
                start_time,
                end_time - (window_len - hop_len),
                spec.shape[1],
                endpoint=False,
                dtype=dtype,
            ),
        },
        attrs={
            **audio.attrs,
            "nfft": nfft,
            "noverlap": noverlap,
        },
    )

    # Add metadata to coordinates
    spec.coords["time"].attrs.update(
        unit="s",
        long_name="Time",
        min=start_time,
        max=end_time - (window_len - hop_len),
        period=(nfft - noverlap) / sampling_rate,
    )
    spec.coords["frequency"].attrs.update(
        unit="Hz",
        long_name="Frequency",
        period=(sampling_rate / nfft),
        min=0,
        max=sampling_rate / 2,
    )

    return spec


def denoise_spectrogram(
    spec: xr.DataArray,
) -> xr.DataArray:
    return xr.DataArray(
        data=(spec - spec.mean("time")).clip(0),
        dims=spec.dims,
        coords=spec.coords,
        attrs={
            **spec.attrs,
            "denoised": 1,
        },
    )


def scale_spectrogram(
    spec: xr.DataArray,
    scale: str = SPEC_SCALE,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    if scale == "pcen":
        return pcen(spec, dtype=dtype)

    if scale == "log":
        return log_scale(spec, dtype=dtype)

    return spec


def log_scale(
    spec: xr.DataArray,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    nfft = spec.attrs["nfft"]
    sampling_rate = spec.attrs["samplerate"]
    log_scaling = (
        2.0
        * (1.0 / sampling_rate)
        * (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
    )
    return xr.DataArray(
        data=np.log1p(log_scaling * spec).astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs={
            **spec.attrs,
            "scale": "log",
        },
    )


def pcen(spec: xr.DataArray, dtype: DTypeLike = np.float32) -> xr.DataArray:
    sampling_rate = spec.attrs["samplerate"]
    data = librosa.pcen(
        spec.data * (2**31),
        sr=sampling_rate / 10,
    )
    return xr.DataArray(
        data=data.astype(dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs={
            **spec.attrs,
            "scale": "pcen",
        },
    )


def max_scale_spectrogram(spec: xr.DataArray, eps=10e-6) -> xr.DataArray:
    return xr.DataArray(
        data=spec / np.add(spec.max(), eps, dtype=spec.dtype),
        dims=spec.dims,
        coords=spec.coords,
        attrs={
            **spec.attrs,
            "max_scaled": 1,
        },
    )


def resize_spectrogram(
    spec: xr.DataArray,
    time_bins: int,
    freq_bins: int,
) -> xr.DataArray:
    new_times = np.linspace(
        spec.time[0],
        spec.time[-1],
        time_bins,
        dtype=spec.time.dtype,
        endpoint=True,
    )
    new_frequencies = np.linspace(
        spec.frequency[0],
        spec.frequency[-1],
        freq_bins,
        dtype=spec.frequency.dtype,
        endpoint=True,
    )

    return spec.interp(
        coords=dict(
            time=new_times,
            frequency=new_frequencies,
        ),
        method="linear",
    )


def get_dim_width(arr: xr.DataArray, dim: str) -> float:
    coord = arr.coords[dim]
    attrs = coord.attrs
    if "min" in attrs and "max" in attrs:
        return attrs["max"] - attrs["min"]

    coord_min = coord.min()
    coord_max = coord.max()
    return float(coord_max - coord_min)


class RandomClipProvider:
    def __init__(
        self,
        clip_annotations: List[data.ClipAnnotation],
        target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
        scale_audio: bool = SCALE_RAW_AUDIO,
    ):
        self.target_sampling_rate = target_sampling_rate
        self.scale_audio = scale_audio
        self.clip_annotations = clip_annotations

    def get_next_clip(self, clip: data.ClipAnnotation) -> data.ClipAnnotation:
        tries = 0
        while True:
            random_clip = random.choice(self.clip_annotations)

            if random_clip.clip != clip.clip:
                return random_clip

            tries += 1
            if tries > 4:
                raise ValueError("Could not find a different clip")

    def __call__(
        self,
        clip: data.ClipAnnotation,
    ) -> Tuple[xr.DataArray, data.ClipAnnotation]:
        random_clip = self.get_next_clip(clip)

        wav = load_clip_audio(
            random_clip.clip,
            target_sampling_rate=self.target_sampling_rate,
            scale=self.scale_audio,
        )

        return wav, random_clip
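End to end, the default settings turn a one-second clip into a 128x256 PCEN spectrogram; a sketch with a hypothetical recording path:

    from soundevent import data

    from batdetect2.data.preprocessing import preprocess_audio_clip

    recording = data.Recording.from_file("audio/night_survey_001.wav")
    clip = data.Clip(recording=recording, start_time=0.0, end_time=1.0)

    spec = preprocess_audio_clip(clip)
    # Defaults: spec_height=128 frequency bins and
    # ceil(1 s / (1/256 s)) = 256 time bins.
    print(spec.dims, spec.shape)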
@@ -53,7 +53,13 @@ class SelfAttention(nn.Module):

class ConvBlockDownCoordF(nn.Module):
    def __init__(
        self, in_chn, out_chn, ip_height, k_size=3, pad_size=1, stride=1
        self,
        in_chn,
        out_chn,
        ip_height,
        k_size=3,
        pad_size=1,
        stride=1,
    ):
        super(ConvBlockDownCoordF, self).__init__()
        self.coords = nn.Parameter(
@@ -1,5 +1,4 @@
import torch
import torch.fft
import torch.nn.functional as F
from torch import nn

@@ -207,7 +206,7 @@ class Net2DFastNoAttn(nn.Module):
            num_filts // 4, 2, kernel_size=1, padding=0
        )
        self.conv_classes_op = nn.Conv2d(
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0
            num_filts // 4, self.num_classes + 1, kernel_size=1, padding=0,
        )

        if self.emb_dim > 0:
@@ -28,6 +28,7 @@ MAX_SCALE_SPEC = False
DEFAULT_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "models",
    "checkpoints",
    "Net2DFast_UK_same.pth.tar",
)

@@ -68,6 +68,7 @@ def run_nms(
        params["fft_win_length"],
        params["fft_overlap"],
    )
    print("duration", duration)
    top_k = int(duration * params["nms_top_k_per_sec"])
    scores, y_pos, x_pos = get_topk_scores(pred_det_nms, top_k)

@@ -7,16 +7,16 @@ import copy
import json
import os

import torch
import numpy as np
import pandas as pd
import torch
from sklearn.ensemble import RandomForestClassifier

from batdetect2.detector import parameters
import batdetect2.train.evaluate as evl
import batdetect2.train.train_utils as tu
import batdetect2.utils.detector_utils as du
import batdetect2.utils.plot_utils as pu
from batdetect2.detector import parameters


def get_blank_annotation(ip_str):
@@ -330,7 +330,8 @@ def load_gt_data(datasets, events_of_interest, class_names, classes_to_ignore):
    for dd in datasets:
        print("\n" + dd["dataset_name"])
        gt_dataset = tu.load_set_of_anns(
            [dd], events_of_interest=events_of_interest, verbose=True
            [dd],
            events_of_interest=events_of_interest,
        )
        gt_dataset = [
            parse_data(gg, class_names, classes_to_ignore, False)
@@ -553,7 +554,9 @@ if __name__ == "__main__":
    test_dict["dataset_name"] = args["test_file"].replace(".json", "")
    test_dict["is_test"] = True
    test_dict["is_binary"] = True
    test_dict["ann_path"] = os.path.join(args["ann_dir"], args["test_file"])
    test_dict["ann_path"] = os.path.join(
        args["ann_dir"], args["test_file"]
    )
    test_dict["wav_path"] = args["data_dir"]
    test_sets = [test_dict]

Binary file not shown.
91 batdetect2/models/__init__.py Normal file
@@ -0,0 +1,91 @@
import os
from typing import Tuple, Union

import torch

from batdetect2.models.encoders import (
    Net2DFast,
    Net2DFastNoAttn,
    Net2DFastNoCoordConv,
)
from batdetect2.models.typing import DetectionModel

__all__ = [
    "load_model",
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]

DEFAULT_MODEL_PATH = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "models",
    "checkpoints",
    "Net2DFast_UK_same.pth.tar",
)


def load_model(
    model_path: str = DEFAULT_MODEL_PATH,
    load_weights: bool = True,
    device: Union[torch.device, str, None] = None,
) -> Tuple[DetectionModel, dict]:
    """Load model from file.

    Args:
        model_path (str): Path to model file. Defaults to DEFAULT_MODEL_PATH.
        load_weights (bool, optional): Load weights. Defaults to True.

    Returns:
        model, params: Model and parameters.

    Raises:
        FileNotFoundError: Model file not found.
        ValueError: Unknown model name.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if not os.path.isfile(model_path):
        raise FileNotFoundError("Model file not found.")

    net_params = torch.load(model_path, map_location=device)

    params = net_params["params"]

    model: DetectionModel

    if params["model_name"] == "Net2DFast":
        model = Net2DFast(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoAttn":
        model = Net2DFastNoAttn(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    elif params["model_name"] == "Net2DFastNoCoordConv":
        model = Net2DFastNoCoordConv(
            params["num_filters"],
            num_classes=len(params["class_names"]),
            emb_dim=params["emb_dim"],
            ip_height=params["ip_height"],
            resize_factor=params["resize_factor"],
        )
    else:
        raise ValueError("Unknown model.")

    if load_weights:
        model.load_state_dict(net_params["state_dict"])

    model = model.to(device)
    model.eval()

    return model, params
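Loading sketch for the new package entry point (uses the bundled default checkpoint; device pinned to CPU here):

    import torch

    from batdetect2.models import load_model

    model, params = load_model(device=torch.device("cpu"))
    print(params["model_name"], "-", len(params["class_names"]), "classes")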
219 batdetect2/models/blocks.py Normal file
@@ -0,0 +1,219 @@
"""Module containing custom NN blocks.

All these classes are subclasses of `torch.nn.Module` and can be used to build
complex neural network architectures.
"""

from typing import Tuple

import torch
import torch.nn.functional as F
from torch import nn

__all__ = [
    "SelfAttention",
    "ConvBlockDownCoordF",
    "ConvBlockDownStandard",
    "ConvBlockUpF",
    "ConvBlockUpStandard",
]


class SelfAttention(nn.Module):
    """Self-Attention module.

    This module implements the self-attention mechanism.
    """

    def __init__(self, ip_dim: int, att_dim: int):
        super().__init__()

        # Note, does not encode position information (absolute or relative)
        self.temperature = 1.0
        self.att_dim = att_dim
        self.key_fun = nn.Linear(ip_dim, att_dim)
        self.val_fun = nn.Linear(ip_dim, att_dim)
        self.que_fun = nn.Linear(ip_dim, att_dim)
        self.pro_fun = nn.Linear(att_dim, ip_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.squeeze(2).permute(0, 2, 1)

        key = torch.matmul(
            x, self.key_fun.weight.T
        ) + self.key_fun.bias.unsqueeze(0).unsqueeze(0)
        query = torch.matmul(
            x, self.que_fun.weight.T
        ) + self.que_fun.bias.unsqueeze(0).unsqueeze(0)
        value = torch.matmul(
            x, self.val_fun.weight.T
        ) + self.val_fun.bias.unsqueeze(0).unsqueeze(0)

        kk_qq = torch.bmm(key, query.permute(0, 2, 1)) / (
            self.temperature * self.att_dim
        )
        att_weights = F.softmax(kk_qq, 1)
        att = torch.bmm(value.permute(0, 2, 1), att_weights)

        op = torch.matmul(
            att.permute(0, 2, 1), self.pro_fun.weight.T
        ) + self.pro_fun.bias.unsqueeze(0).unsqueeze(0)
        op = op.permute(0, 2, 1).unsqueeze(2)

        return op


class ConvBlockDownCoordF(nn.Module):
    """Convolutional Block with Downsampling and Coord Feature.

    This block performs convolution followed by downsampling
    and concatenates with coordinate information.
    """

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        ip_height: int,
        k_size: int = 3,
        pad_size: int = 1,
        stride: int = 1,
    ):
        super().__init__()

        self.coords = nn.Parameter(
            torch.linspace(-1, 1, ip_height)[None, None, ..., None],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_chn + 1,
            out_chn,
            kernel_size=k_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        freq_info = self.coords.repeat(x.shape[0], 1, 1, x.shape[3])
        x = torch.cat((x, freq_info), 1)
        x = F.max_pool2d(self.conv(x), 2, 2)
        x = F.relu(self.conv_bn(x), inplace=True)
        return x


class ConvBlockDownStandard(nn.Module):
    """Convolutional Block with Downsampling.

    This block performs convolution followed by downsampling.
    """

    def __init__(
        self,
        in_chn,
        out_chn,
        k_size=3,
        pad_size=1,
        stride=1,
    ):
        super(ConvBlockDownStandard, self).__init__()
        self.conv = nn.Conv2d(
            in_chn,
            out_chn,
            kernel_size=k_size,
            padding=pad_size,
            stride=stride,
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x):
        x = F.max_pool2d(self.conv(x), 2, 2)
        x = F.relu(self.conv_bn(x), inplace=True)
        return x


class ConvBlockUpF(nn.Module):
    """Convolutional Block with Upsampling and Coord Feature.

    This block performs convolution followed by upsampling
    and concatenates with coordinate information.
    """

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        ip_height: int,
        k_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
    ):
        super().__init__()

        self.up_scale = up_scale
        self.up_mode = up_mode
        self.coords = nn.Parameter(
            torch.linspace(-1, 1, ip_height * up_scale[0])[
                None, None, ..., None
            ],
            requires_grad=False,
        )
        self.conv = nn.Conv2d(
            in_chn + 1, out_chn, kernel_size=k_size, padding=pad_size
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
            x,
            size=(
                x.shape[-2] * self.up_scale[0],
                x.shape[-1] * self.up_scale[1],
            ),
            mode=self.up_mode,
            align_corners=False,
        )
        freq_info = self.coords.repeat(op.shape[0], 1, 1, op.shape[3])
        op = torch.cat((op, freq_info), 1)
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        return op


class ConvBlockUpStandard(nn.Module):
    """Convolutional Block with Upsampling.

    This block performs convolution followed by upsampling.
    """

    def __init__(
        self,
        in_chn: int,
        out_chn: int,
        k_size: int = 3,
        pad_size: int = 1,
        up_mode: str = "bilinear",
        up_scale: Tuple[int, int] = (2, 2),
    ):
        super(ConvBlockUpStandard, self).__init__()
        self.up_scale = up_scale
        self.up_mode = up_mode
        self.conv = nn.Conv2d(
            in_chn, out_chn, kernel_size=k_size, padding=pad_size
        )
        self.conv_bn = nn.BatchNorm2d(out_chn)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        op = F.interpolate(
            x,
            size=(
                x.shape[-2] * self.up_scale[0],
                x.shape[-1] * self.up_scale[1],
            ),
            mode=self.up_mode,
            align_corners=False,
        )
        op = self.conv(op)
        op = F.relu(self.conv_bn(op), inplace=True)
        return op
148
batdetect2/models/detectors.py
Normal file
@ -0,0 +1,148 @@
import pytorch_lightning as L
import torch
import xarray as xr
from soundevent import data
from torch import nn, optim

from batdetect2.data.preprocessing import preprocess_audio_clip
from batdetect2.models.typing import EncoderModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample
from batdetect2.models.post_process import (
    PostprocessConfig,
    postprocess_model_outputs,
)
from batdetect2.train.preprocess import PreprocessingConfig


class DetectorModel(L.LightningModule):
    def __init__(
        self,
        encoder: EncoderModel,
        num_classes: int,
        learning_rate: float = 1e-3,
        preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
        postprocessing_config: PostprocessConfig = PostprocessConfig(),
    ):
        super().__init__()

        self.preprocessing_config = preprocessing_config
        self.postprocessing_config = postprocessing_config
        self.num_classes = num_classes
        self.learning_rate = learning_rate

        self.encoder = encoder

        self.classifier = nn.Conv2d(
            self.encoder.num_filts // 4,
            self.num_classes + 1,
            kernel_size=1,
            padding=0,
        )

        self.bbox = nn.Conv2d(
            self.encoder.num_filts // 4,
            2,
            kernel_size=1,
            padding=0,
        )

    def forward(self, spec: torch.Tensor) -> ModelOutput:  # type: ignore
        features = self.encoder(spec)

        classification_logits = self.classifier(features)
        classification_probs = torch.softmax(classification_logits, dim=1)
        detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)

        return ModelOutput(
            detection_probs=detection_probs,
            size_preds=self.bbox(features),
            class_probs=classification_probs,
            features=features,
        )

    def compute_spectrogram(self, clip: data.Clip) -> xr.DataArray:
        config = self.preprocessing_config

        return preprocess_audio_clip(
            clip,
            target_sampling_rate=config.target_samplerate,
            scale_audio=config.scale_audio,
            fft_win_length=config.fft_win_length,
            fft_overlap=config.fft_overlap,
            max_freq=config.max_freq,
            min_freq=config.min_freq,
            spec_scale=config.spec_scale,
            denoise_spec_avg=config.denoise_spec_avg,
            max_scale_spec=config.max_scale_spec,
        )

    def process_clip(self, clip: data.Clip):
        spectrogram = self.compute_spectrogram(clip)
        spec_tensor = (
            torch.tensor(spectrogram.values).unsqueeze(0).unsqueeze(0)
        )

        outputs = self(spec_tensor)

        config = self.postprocessing_config
        return postprocess_model_outputs(
            outputs,
            [clip],
            nms_kernel_size=config.nms_kernel_size,
            detection_threshold=config.detection_threshold,
            min_freq=config.min_freq,
            max_freq=config.max_freq,
            top_k_per_sec=config.top_k_per_sec,
        )

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> torch.Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(  # type: ignore
        self,
        batch: TrainExample,
    ):
        features = self.encoder(batch.spec)

        classification_logits = self.classifier(features)
        classification_probs = torch.softmax(classification_logits, dim=1)
        detection_probs = classification_probs[:, :-1].sum(dim=1, keepdim=True)

        loss = self.compute_loss(
            ModelOutput(
                detection_probs=detection_probs,
                size_preds=self.bbox(features),
                class_probs=classification_probs,
                features=features,
            ),
            batch,
        )
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
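A hedged sketch of wiring this LightningModule to the Net2DFast encoder defined in the next file; the filter and class counts are placeholders:

import torch
from batdetect2.models.encoders import Net2DFast
from batdetect2.models.detectors import DetectorModel

encoder = Net2DFast(num_filts=128, input_height=128)
model = DetectorModel(encoder=encoder, num_classes=17)

spec = torch.randn(1, 1, 128, 512)  # (batch, channel, freq, time)
outputs = model(spec)
print(outputs.detection_probs.shape)  # torch.Size([1, 1, 128, 512])
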
319
batdetect2/models/encoders.py
Normal file
@ -0,0 +1,319 @@
from typing import Tuple

import torch
import torch.fft
import torch.nn.functional as F
from torch import nn

from batdetect2.models.typing import EncoderModel
from batdetect2.models.blocks import (
    ConvBlockDownCoordF,
    ConvBlockDownStandard,
    ConvBlockUpF,
    ConvBlockUpStandard,
    SelfAttention,
)

__all__ = [
    "Net2DFast",
    "Net2DFastNoAttn",
    "Net2DFastNoCoordConv",
]


class Net2DFast(EncoderModel):
    def __init__(
        self,
        num_filts: int,
        input_height: int = 128,
    ):
        super().__init__()
        self.num_filts = num_filts
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        # encoder
        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_filts // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_filts // 4,
            self.num_filts // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_filts // 2,
            self.num_filts,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_filts,
            self.num_filts * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)

        # bottleneck
        self.conv_1d = nn.Conv2d(
            self.num_filts * 2,
            self.num_filts * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)
        self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)

        # decoder
        self.conv_up_2 = ConvBlockUpF(
            self.num_filts * 2,
            self.num_filts // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_filts // 2,
            self.num_filts // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_filts // 4,
            self.num_filts // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_filts // 4,
            self.num_filts // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)

    def pad_adjust(self, spec: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
        h, w = spec.shape[2:]
        h_pad = (32 - h % 32) % 32
        w_pad = (32 - w % 32) % 32
        return F.pad(spec, (0, w_pad, 0, h_pad)), h_pad, w_pad

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        # encoder
        spec, h_pad, w_pad = self.pad_adjust(spec)

        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        # bottleneck
        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        # decoder
        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        # Restore original size
        if h_pad > 0:
            x = x[:, :, :-h_pad, :]

        if w_pad > 0:
            x = x[:, :, :, :-w_pad]

        return F.relu_(self.conv_op_bn(self.conv_op(x)))

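pad_adjust pads the input so both spatial dimensions are divisible by 32, the network's total downsampling factor. The arithmetic, worked for an odd-sized input (values are illustrative):

h, w = 100, 515
h_pad = (32 - h % 32) % 32  # (32 - 4) % 32 = 28, so the padded height is 128
w_pad = (32 - w % 32) % 32  # (32 - 3) % 32 = 29, so the padded width is 544
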

class Net2DFastNoAttn(EncoderModel):
    def __init__(
        self,
        num_filts: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_filts = num_filts
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownCoordF(
            1,
            self.num_filts // 4,
            self.input_height,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownCoordF(
            self.num_filts // 4,
            self.num_filts // 2,
            self.input_height // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownCoordF(
            self.num_filts // 2,
            self.num_filts,
            self.input_height // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_filts,
            self.num_filts * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)

        self.conv_1d = nn.Conv2d(
            self.num_filts * 2,
            self.num_filts * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)

        self.conv_up_2 = ConvBlockUpF(
            self.num_filts * 2,
            self.num_filts // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpF(
            self.num_filts // 2,
            self.num_filts // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpF(
            self.num_filts // 4,
            self.num_filts // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_filts // 4,
            self.num_filts // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))


class Net2DFastNoCoordConv(EncoderModel):
    def __init__(
        self,
        num_filts: int,
        input_height: int = 128,
    ):
        super().__init__()

        self.num_filts = num_filts
        self.input_height = input_height
        self.bottleneck_height = self.input_height // 32

        self.conv_dn_0 = ConvBlockDownStandard(
            1,
            self.num_filts // 4,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_1 = ConvBlockDownStandard(
            self.num_filts // 4,
            self.num_filts // 2,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_2 = ConvBlockDownStandard(
            self.num_filts // 2,
            self.num_filts,
            k_size=3,
            pad_size=1,
            stride=1,
        )
        self.conv_dn_3 = nn.Conv2d(
            self.num_filts,
            self.num_filts * 2,
            3,
            padding=1,
        )
        self.conv_dn_3_bn = nn.BatchNorm2d(self.num_filts * 2)

        self.conv_1d = nn.Conv2d(
            self.num_filts * 2,
            self.num_filts * 2,
            (self.input_height // 8, 1),
            padding=0,
        )
        self.conv_1d_bn = nn.BatchNorm2d(self.num_filts * 2)

        self.att = SelfAttention(self.num_filts * 2, self.num_filts * 2)

        self.conv_up_2 = ConvBlockUpStandard(
            self.num_filts * 2,
            self.num_filts // 2,
            self.input_height // 8,
        )
        self.conv_up_3 = ConvBlockUpStandard(
            self.num_filts // 2,
            self.num_filts // 4,
            self.input_height // 4,
        )
        self.conv_up_4 = ConvBlockUpStandard(
            self.num_filts // 4,
            self.num_filts // 4,
            self.input_height // 2,
        )

        self.conv_op = nn.Conv2d(
            self.num_filts // 4,
            self.num_filts // 4,
            kernel_size=3,
            padding=1,
        )
        self.conv_op_bn = nn.BatchNorm2d(self.num_filts // 4)

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        x1 = self.conv_dn_0(spec)
        x2 = self.conv_dn_1(x1)
        x3 = self.conv_dn_2(x2)
        x3 = F.relu_(self.conv_dn_3_bn(self.conv_dn_3(x3)))

        x = F.relu_(self.conv_1d_bn(self.conv_1d(x3)))
        x = self.att(x)
        x = x.repeat([1, 1, self.bottleneck_height * 4, 1])

        x = self.conv_up_2(x + x3)
        x = self.conv_up_3(x + x2)
        x = self.conv_up_4(x + x1)

        return F.relu_(self.conv_op_bn(self.conv_op(x)))
310
batdetect2/models/post_process.py
Normal file
@ -0,0 +1,310 @@
"""Module for postprocessing model outputs."""
|
||||
|
||||
from typing import Callable, List, Tuple, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
from torch import nn
|
||||
|
||||
from batdetect2.models.typing import ModelOutput
|
||||
|
||||
__all__ = [
|
||||
"postprocess_model_outputs",
|
||||
"PostprocessConfig",
|
||||
]
|
||||
|
||||
NMS_KERNEL_SIZE = 9
|
||||
DETECTION_THRESHOLD = 0.01
|
||||
TOP_K_PER_SEC = 200
|
||||
|
||||
|
||||
class PostprocessConfig(BaseModel):
|
||||
"""Configuration for postprocessing model outputs."""
|
||||
|
||||
nms_kernel_size: int = Field(default=NMS_KERNEL_SIZE, gt=0)
|
||||
detection_threshold: float = Field(default=DETECTION_THRESHOLD, ge=0)
|
||||
min_freq: int = Field(default=10000, gt=0)
|
||||
max_freq: int = Field(default=120000, gt=0)
|
||||
top_k_per_sec: int = Field(default=TOP_K_PER_SEC, gt=0)
|
||||
|
||||
|
||||
TagFunction = Callable[[int], List[data.Tag]]
|
||||
|
||||
|
||||
def postprocess_model_outputs(
|
||||
outputs: ModelOutput,
|
||||
clips: List[data.Clip],
|
||||
nms_kernel_size: int = NMS_KERNEL_SIZE,
|
||||
detection_threshold: float = DETECTION_THRESHOLD,
|
||||
min_freq: int = 10000,
|
||||
max_freq: int = 120000,
|
||||
top_k_per_sec: int = TOP_K_PER_SEC,
|
||||
) -> List[data.ClipPrediction]:
|
||||
"""Postprocesses model outputs to generate clip predictions.
|
||||
|
||||
This function takes the output from the model, applies non-maximum suppression,
|
||||
selects the top-k scores, computes sound events from the outputs, and returns
|
||||
clip predictions based on these processed outputs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
outputs
|
||||
Output from the model containing detection probabilities, size
|
||||
predictions, class logits, and features. All tensors are expected
|
||||
to have a batch dimension.
|
||||
clips
|
||||
List of clips for which predictions are made. The number of clips
|
||||
must match the batch dimension of the model outputs.
|
||||
nms_kernel_size
|
||||
Size of the non-maximum suppression kernel. Default is 9.
|
||||
detection_threshold
|
||||
Detection threshold. Default is 0.01.
|
||||
min_freq
|
||||
Minimum frequency. Default is 10000.
|
||||
max_freq
|
||||
Maximum frequency. Default is 120000.
|
||||
top_k_per_sec
|
||||
Top k per second. Default is 200.
|
||||
|
||||
Returns
|
||||
-------
|
||||
predictions: List[data.ClipPrediction]
|
||||
List of clip predictions containing predicted sound events.
|
||||
|
||||
Raises
|
||||
------
|
||||
ValueError
|
||||
If the number of predictions does not match the number of clips.
|
||||
"""
|
||||
num_predictions = len(outputs.detection_probs)
|
||||
|
||||
if num_predictions == 0:
|
||||
return []
|
||||
|
||||
if num_predictions != len(clips):
|
||||
raise ValueError(
|
||||
"Number of predictions must match the number of clips."
|
||||
)
|
||||
|
||||
detection_probs = non_max_suppression(
|
||||
outputs.detection_probs,
|
||||
kernel_size=nms_kernel_size,
|
||||
)
|
||||
|
||||
duration = clips[0].end_time - clips[0].start_time
|
||||
|
||||
scores_batch, y_pos_batch, x_pos_batch = get_topk_scores(
|
||||
detection_probs,
|
||||
int(top_k_per_sec * duration / 2),
|
||||
)
|
||||
|
||||
predictions: List[data.ClipPrediction] = []
|
||||
for scores, y_pos, x_pos, size_preds, class_probs, features, clip in zip(
|
||||
scores_batch,
|
||||
y_pos_batch,
|
||||
x_pos_batch,
|
||||
outputs.size_preds,
|
||||
outputs.class_probs,
|
||||
outputs.features,
|
||||
clips,
|
||||
):
|
||||
sound_events = compute_sound_events_from_outputs(
|
||||
clip,
|
||||
scores,
|
||||
y_pos,
|
||||
x_pos,
|
||||
size_preds,
|
||||
class_probs,
|
||||
features,
|
||||
min_freq=min_freq,
|
||||
max_freq=max_freq,
|
||||
detection_threshold=detection_threshold,
|
||||
)
|
||||
|
||||
predictions.append(
|
||||
data.ClipPrediction(
|
||||
clip=clip,
|
||||
sound_events=sound_events,
|
||||
)
|
||||
)
|
||||
|
||||
return predictions
|
||||
|
||||
|
||||
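A sketch of calling the entry point above on the outputs for a single clip; `outputs` is assumed to be a ModelOutput with batch size one and `clip` the matching soundevent clip:

predictions = postprocess_model_outputs(
    outputs,
    [clip],
    nms_kernel_size=9,
    detection_threshold=0.3,
)
for sound_event_prediction in predictions[0].sound_events:
    print(sound_event_prediction.score)
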
def compute_sound_events_from_outputs(
    clip: data.Clip,
    scores: torch.Tensor,
    y_pos: torch.Tensor,
    x_pos: torch.Tensor,
    size_preds: torch.Tensor,
    class_probs: torch.Tensor,
    features: torch.Tensor,
    tag_fn: TagFunction = lambda _: [],
    min_freq: int = 10000,
    max_freq: int = 120000,
    detection_threshold: float = DETECTION_THRESHOLD,
) -> List[data.SoundEventPrediction]:
    _, freq_bins, time_bins = size_preds.shape

    sorted_indices = torch.argsort(x_pos)
    valid_indices = sorted_indices[
        scores[sorted_indices] > detection_threshold
    ]

    scores = scores[valid_indices]
    x_pos = x_pos[valid_indices]
    y_pos = y_pos[valid_indices]

    predictions: List[data.SoundEventPrediction] = []
    for score, x, y in zip(scores, x_pos, y_pos):
        width, height = size_preds[:, y, x]
        class_prob = class_probs[:, y, x]
        feature = features[:, y, x]

        start_time = np.interp(
            x.item(),
            [0, time_bins],
            [clip.start_time, clip.end_time],
        )

        end_time = np.interp(
            x.item() + width.item(),
            [0, time_bins],
            [clip.start_time, clip.end_time],
        )

        low_freq = np.interp(
            y.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        high_freq = np.interp(
            y.item() - height.item(),
            [0, freq_bins],
            [max_freq, min_freq],
        )

        predicted_tags: List[data.PredictedTag] = []

        for label_id, class_score in enumerate(class_prob):
            corresponding_tags = tag_fn(label_id)
            predicted_tags.extend(
                [
                    data.PredictedTag(
                        tag=tag,
                        score=class_score.item(),
                    )
                    for tag in corresponding_tags
                ]
            )

        start_time, end_time = sorted([float(start_time), float(end_time)])
        low_freq, high_freq = sorted([float(low_freq), float(high_freq)])

        sound_event = data.SoundEvent(
            recording=clip.recording,
            geometry=data.BoundingBox(
                coordinates=[
                    start_time,
                    low_freq,
                    end_time,
                    high_freq,
                ]
            ),
            features=[
                data.Feature(
                    name=f"batdetect2_{i}",
                    value=value.item(),
                )
                for i, value in enumerate(feature)
            ],
        )

        predictions.append(
            data.SoundEventPrediction(
                sound_event=sound_event,
                score=score.item(),
                tags=predicted_tags,
            )
        )

    return predictions


def non_max_suppression(
    tensor: torch.Tensor,
    kernel_size: Union[int, Tuple[int, int]] = NMS_KERNEL_SIZE,
) -> torch.Tensor:
    """Run non-maximum suppression on a tensor.

    This function removes values from the input tensor that are not local
    maxima in the neighborhood of the given kernel size.

    All non-maximum values are set to zero.

    Parameters
    ----------
    tensor : torch.Tensor
        Input tensor.
    kernel_size : Union[int, Tuple[int, int]], optional
        Size of the neighborhood to consider for non-maximum suppression.
        If an integer is given, the neighborhood will be a square of the
        given size. If a tuple is given, the neighborhood will be a
        rectangle with the given height and width.

    Returns
    -------
    torch.Tensor
        Tensor with non-maximum suppressed values.
    """
    if isinstance(kernel_size, int):
        kernel_size_h = kernel_size
        kernel_size_w = kernel_size
    else:
        kernel_size_h, kernel_size_w = kernel_size

    pad_h = (kernel_size_h - 1) // 2
    pad_w = (kernel_size_w - 1) // 2

    hmax = nn.functional.max_pool2d(
        tensor,
        (kernel_size_h, kernel_size_w),
        stride=1,
        padding=(pad_h, pad_w),
    )
    keep = (hmax == tensor).float()
    return tensor * keep

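The suppression above is the usual max-pool trick: a score survives only where it equals the maximum of its own neighborhood. A small worked example:

import torch

heat = torch.tensor([[[[0.1, 0.9, 0.2],
                       [0.3, 0.5, 0.1],
                       [0.8, 0.2, 0.4]]]])
out = non_max_suppression(heat, kernel_size=3)
# Only the local maxima (0.9 and 0.8) survive; every other cell is zeroed.
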
def get_topk_scores(
    scores: torch.Tensor,
    K: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Get the top-k scores and their indices.

    Parameters
    ----------
    scores : torch.Tensor
        Tensor with scores. Expects input of size: `batch x 1 x height x width`.
    K : int
        Number of top scores to return.

    Returns
    -------
    scores : torch.Tensor
        Top-k scores.
    ys : torch.Tensor
        Y coordinates of the top-k scores.
    xs : torch.Tensor
        X coordinates of the top-k scores.
    """
    batch, _, height, width = scores.size()
    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)
    topk_inds = topk_inds % (height * width)
    topk_ys = torch.div(topk_inds, width, rounding_mode="floor").long()
    topk_xs = (topk_inds % width).long()
    return topk_scores, topk_ys, topk_xs
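The flat top-k indices are unravelled back into (y, x) grid coordinates. A quick sketch with a 1 x 1 x 2 x 3 score map:

import torch

scores = torch.tensor([[[[0.1, 0.7, 0.2],
                         [0.9, 0.3, 0.4]]]])
top, ys, xs = get_topk_scores(scores, 2)
# top -> [[0.9, 0.7]], ys -> [[1, 0]], xs -> [[0, 1]]
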
55
batdetect2/models/typing.py
Normal file
@ -0,0 +1,55 @@
from abc import ABC, abstractmethod
from typing import NamedTuple

import torch
import torch.nn as nn


class ModelOutput(NamedTuple):
    """Output of the detection model.

    Each of the tensors has a shape of

    `(batch_size, num_channels, spec_height, spec_width)`.

    Where `spec_height` and `spec_width` are the height and width of the
    input spectrograms.

    They contain localised information of:

    1. The probability of a bounding box detection at the given location.
    2. The predicted size of the bounding box at the given location.
    3. The predicted class probabilities at the given location.
    4. Features used to make the predictions at the given location.
    """

    detection_probs: torch.Tensor
    """Tensor with predicted detection probabilities."""

    size_preds: torch.Tensor
    """Tensor with predicted bounding box sizes."""

    class_probs: torch.Tensor
    """Tensor with predicted class probabilities."""

    features: torch.Tensor
    """Tensor with intermediate features."""


class EncoderModel(ABC, nn.Module):
    input_height: int
    """Height of the input spectrogram."""

    num_filts: int
    """Dimension of the feature tensor."""

    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Forward pass of the encoder model."""


class DetectionModel(ABC, nn.Module):
    @abstractmethod
    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Forward pass of the detection model."""
@ -102,6 +102,7 @@ def spectrogram(
    return ax


def spectrogram_with_detections(
    spec: Union[torch.Tensor, np.ndarray],
    dets: List[Annotation],

0
batdetect2/plotting/__init__.py
Normal file
22
batdetect2/plotting/common.py
Normal file
@ -0,0 +1,22 @@
"""General plotting utilities."""

from typing import Optional, Tuple

import matplotlib.pyplot as plt
from matplotlib import axes

__all__ = [
    "create_ax",
]


def create_ax(
    ax: Optional[axes.Axes] = None,
    figsize: Tuple[int, int] = (10, 10),
    **kwargs,
) -> axes.Axes:
    """Create a new axis if none is provided."""
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, **kwargs)  # type: ignore

    return ax  # type: ignore

27
batdetect2/plotting/heatmaps.py
Normal file
@ -0,0 +1,27 @@
"""Plot heatmaps."""

from typing import Optional, Tuple

import matplotlib.pyplot as plt
import xarray as xr
from matplotlib import axes

from batdetect2.plotting.common import create_ax


def plot_heatmap(
    heatmap: xr.DataArray,
    ax: Optional[axes.Axes] = None,
    figsize: Tuple[int, int] = (10, 10),
) -> axes.Axes:
    ax = create_ax(ax, figsize=figsize)

    ax.pcolormesh(
        heatmap.time,
        heatmap.frequency,
        heatmap,
        vmax=1,
        vmin=0,
    )

    return ax
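A short usage sketch for plot_heatmap, assuming a DataArray with `frequency` and `time` coordinates like the heatmaps produced by the training preprocessing below:

import numpy as np
import xarray as xr

heatmap = xr.DataArray(
    np.random.rand(128, 512),
    coords={
        "frequency": np.linspace(10_000, 120_000, 128),
        "time": np.linspace(0, 2, 512),
    },
    dims=("frequency", "time"),
)
ax = plot_heatmap(heatmap)
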
@ -1,4 +1,5 @@
"""Functions and dataloaders for training and testing the model."""

import copy
from typing import List, Optional, Tuple

@ -199,8 +200,7 @@ def draw_gaussian(
    x0 = y0 = size // 2
    # g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    g = np.exp(
        -((x - x0) ** 2) / (2 * sigmax**2)
        - ((y - y0) ** 2) / (2 * sigmay**2)
        -((x - x0) ** 2) / (2 * sigmax**2) - ((y - y0) ** 2) / (2 * sigmay**2)
    )
    g_x = max(0, -ul[0]), min(br[0], h) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], w) - ul[1]
@ -399,6 +399,8 @@ def echo_aug(
    sample_offset = (
        int(echo_max_delay * np.random.random() * sampling_rate) + 1
    )
    # NOTE: This seems to be wrong, as the echo should be added to the
    # end of the audio, not the beginning.
    audio[:-sample_offset] += np.random.random() * audio[sample_offset:]
    return audio

@ -820,9 +822,10 @@ class AudioLoader(torch.utils.data.Dataset):
        # )

        # create spectrogram
        spec = au.generate_spectrogram(
        spec, _ = au.generate_spectrogram(
            audio,
            sampling_rate,
            params=dict(
                fft_win_length=self.params["fft_win_length"],
                fft_overlap=self.params["fft_overlap"],
                max_freq=self.params["max_freq"],
@ -830,6 +833,7 @@ class AudioLoader(torch.utils.data.Dataset):
                spec_scale=self.params["spec_scale"],
                denoise_spec_avg=self.params["denoise_spec_avg"],
                max_scale_spec=self.params["max_scale_spec"],
            ),
        )
        rsf = self.params["resize_factor"]
        spec_op_shape = (

86
batdetect2/train/dataset.py
Normal file
@ -0,0 +1,86 @@
import os
from pathlib import Path
from typing import Dict, NamedTuple, Sequence, Union

import torch
import xarray as xr
from soundevent import data
from torch.utils.data import Dataset

from batdetect2.train.preprocess import PreprocessingConfig

__all__ = [
    "TrainExample",
    "LabeledDataset",
]


PathLike = Union[Path, str, os.PathLike]


class TrainExample(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor
    idx: torch.Tensor


def get_files(directory: PathLike, extension: str = ".nc") -> Sequence[Path]:
    return list(Path(directory).glob(f"*{extension}"))


class LabeledDataset(Dataset):
    def __init__(self, filenames: Sequence[PathLike]):
        self.filenames = filenames

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, idx) -> TrainExample:
        # Local renamed from `data` to avoid shadowing the soundevent
        # `data` module imported above.
        arrays = self.load(self.filenames[idx])
        return TrainExample(
            spec=arrays["spectrogram"],
            detection_heatmap=arrays["detection"],
            class_heatmap=arrays["class"],
            size_heatmap=arrays["size"],
            idx=torch.tensor(idx),
        )

    @classmethod
    def from_directory(cls, directory: PathLike, extension: str = ".nc"):
        return cls(get_files(directory, extension))

    def load(self, filename: PathLike) -> Dict[str, torch.Tensor]:
        dataset = xr.open_dataset(filename)
        spectrogram = torch.tensor(dataset["spectrogram"].values).unsqueeze(0)
        return {
            "spectrogram": spectrogram,
            "detection": torch.tensor(dataset["detection"].values),
            "class": torch.tensor(dataset["class"].values),
            "size": torch.tensor(dataset["size"].values),
        }

    def get_spectrogram(self, idx):
        return xr.open_dataset(self.filenames[idx])["spectrogram"]

    def get_detection_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["detection"]

    def get_class_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["class"]

    def get_size_mask(self, idx):
        return xr.open_dataset(self.filenames[idx])["size"]

    def get_clip_annotation(self, idx):
        filename = self.filenames[idx]
        dataset = xr.open_dataset(filename)
        clip_annotation = dataset.attrs["clip_annotation"]
        return data.ClipAnnotation.model_validate_json(clip_annotation)

    def get_preprocessing_configuration(self, idx):
        config = xr.open_dataset(self.filenames[idx]).attrs["configuration"]
        return PreprocessingConfig.model_validate_json(config)
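A sketch of loading a directory of preprocessed .nc examples into a DataLoader; the path is a placeholder, and the default collate function handles the TrainExample named tuple:

from torch.utils.data import DataLoader

dataset = LabeledDataset.from_directory("path/to/preprocessed")
loader = DataLoader(dataset, batch_size=32, shuffle=True)
batch = next(iter(loader))
print(batch.spec.shape, batch.detection_heatmap.shape)
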
56
batdetect2/train/light.py
Normal file
@ -0,0 +1,56 @@
import pytorch_lightning as L
from torch import Tensor, optim

from batdetect2.models.typing import DetectionModel, ModelOutput
from batdetect2.train import losses
from batdetect2.train.dataset import TrainExample

__all__ = [
    "LitDetectorModel",
]


class LitDetectorModel(L.LightningModule):
    model: DetectionModel

    def __init__(self, model: DetectionModel, learning_rate: float = 1e-3):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def compute_loss(
        self,
        outputs: ModelOutput,
        batch: TrainExample,
    ) -> Tensor:
        detection_loss = losses.focal_loss(
            outputs.detection_probs,
            batch.detection_heatmap,
        )

        size_loss = losses.bbox_size_loss(
            outputs.size_preds,
            batch.size_heatmap,
        )

        valid_mask = batch.class_heatmap.any(dim=1, keepdim=True).float()
        classification_loss = losses.focal_loss(
            outputs.class_probs,
            batch.class_heatmap,
            valid_mask=valid_mask,
        )

        return detection_loss + size_loss + classification_loss

    def training_step(self, batch: TrainExample, batch_idx: int):  # type: ignore
        outputs: ModelOutput = self.model(batch.spec)
        loss = self.compute_loss(outputs, batch)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 100)
        return [optimizer], [scheduler]
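A minimal sketch of driving this module with a Lightning Trainer; the wrapped model and dataset here are placeholders:

import pytorch_lightning as pl
from torch.utils.data import DataLoader

module = LitDetectorModel(model=my_detection_model)  # hypothetical DetectionModel instance
trainer = pl.Trainer(max_epochs=100)
trainer.fit(module, DataLoader(train_dataset, batch_size=32, shuffle=True))
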
@ -22,15 +22,15 @@ def focal_loss(
    gt: torch.Tensor,
    weights: Optional[torch.Tensor] = None,
    valid_mask: Optional[torch.Tensor] = None,
    eps: float = 1e-5,
    beta: float = 4,
    alpha: float = 2,
) -> torch.Tensor:
    """
    Focal loss adapted from CornerNet: Detecting Objects as Paired Keypoints
    pred (batch x c x h x w)
    gt (batch x c x h x w)
    """
    eps = 1e-5
    beta = 4
    alpha = 2

    pos_inds = gt.eq(1).float()
    neg_inds = gt.lt(1).float()
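For reference, the CornerNet-style focal loss this hunk parameterizes looks roughly like the sketch below; it is not the repository's exact implementation, just the published formula with the new eps/beta/alpha arguments:

import torch

def focal_loss_sketch(pred, gt, eps=1e-5, beta=4, alpha=2):
    pos = gt.eq(1).float()
    neg = gt.lt(1).float()
    pos_loss = torch.log(pred + eps) * (1 - pred) ** alpha * pos
    neg_loss = torch.log(1 - pred + eps) * pred**alpha * (1 - gt) ** beta * neg
    num_pos = pos.sum().clamp(min=1)
    return -(pos_loss.sum() + neg_loss.sum()) / num_pos
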
209
batdetect2/train/preprocess.py
Normal file
@ -0,0 +1,209 @@
"""Module for preprocessing data for training."""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Sequence, Union
|
||||
from tqdm.auto import tqdm
|
||||
from multiprocessing import Pool
|
||||
|
||||
import xarray as xr
|
||||
from pydantic import BaseModel, Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.labels import TARGET_SIGMA, LabelFn, generate_heatmaps
|
||||
from batdetect2.data.preprocessing import (
|
||||
DENOISE_SPEC_AVG,
|
||||
FFT_OVERLAP,
|
||||
FFT_WIN_LENGTH_S,
|
||||
MAX_FREQ_HZ,
|
||||
MAX_SCALE_SPEC,
|
||||
MIN_FREQ_HZ,
|
||||
SCALE_RAW_AUDIO,
|
||||
SPEC_SCALE,
|
||||
TARGET_SAMPLERATE_HZ,
|
||||
preprocess_audio_clip,
|
||||
)
|
||||
|
||||
PathLike = Union[Path, str, os.PathLike]
|
||||
FilenameFn = Callable[[data.ClipAnnotation], str]
|
||||
|
||||
__all__ = [
|
||||
"preprocess_annotations",
|
||||
]
|
||||
|
||||
|
||||
class PreprocessingConfig(BaseModel):
|
||||
"""Configuration for preprocessing data."""
|
||||
|
||||
target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
|
||||
|
||||
scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
|
||||
|
||||
fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
|
||||
|
||||
fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
|
||||
|
||||
max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
|
||||
|
||||
min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
|
||||
|
||||
spec_scale: str = Field(default=SPEC_SCALE)
|
||||
|
||||
denoise_spec_avg: bool = DENOISE_SPEC_AVG
|
||||
|
||||
max_scale_spec: bool = MAX_SCALE_SPEC
|
||||
|
||||
target_sigma: float = Field(default=TARGET_SIGMA, gt=0)
|
||||
|
||||
class_labels: Sequence[str] = ["bat"]
|
||||
|
||||
|
||||
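Being a Pydantic model, the config round-trips through JSON, which is how it is stored in each training example's attributes below. A quick sketch (the file path is a placeholder):

from pathlib import Path

config = PreprocessingConfig(target_samplerate=256_000)
Path("config.json").write_text(config.model_dump_json())
loaded = PreprocessingConfig.model_validate_json(Path("config.json").read_text())
assert loaded == config
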
def generate_train_example(
    clip_annotation: data.ClipAnnotation,
    label_fn: LabelFn = lambda _: None,
    config: Optional[PreprocessingConfig] = None,
) -> xr.Dataset:
    """Generate a training example."""
    if config is None:
        config = PreprocessingConfig()

    spectrogram = preprocess_audio_clip(
        clip_annotation.clip,
        target_sampling_rate=config.target_samplerate,
        scale_audio=config.scale_audio,
        fft_win_length=config.fft_win_length,
        fft_overlap=config.fft_overlap,
        max_freq=config.max_freq,
        min_freq=config.min_freq,
        spec_scale=config.spec_scale,
        denoise_spec_avg=config.denoise_spec_avg,
        max_scale_spec=config.max_scale_spec,
    )

    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
        clip_annotation,
        spectrogram,
        target_sigma=config.target_sigma,
        num_classes=len(config.class_labels),
        class_labels=list(config.class_labels),
        label_fn=label_fn,
    )

    dataset = xr.Dataset(
        {
            "spectrogram": spectrogram,
            "detection": detection_heatmap,
            "class": class_heatmap,
            "size": size_heatmap,
        }
    )

    return dataset.assign_attrs(
        title=f"Training example for {clip_annotation.uuid}",
        configuration=config.model_dump_json(),
        clip_annotation=clip_annotation.model_dump_json(),
    )


def save_to_file(
    dataset: xr.Dataset,
    path: PathLike,
) -> None:
    dataset.to_netcdf(
        path,
        encoding={
            "spectrogram": {"zlib": True},
            "size": {"zlib": True},
            "class": {"zlib": True},
            "detection": {"zlib": True},
        },
    )


def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
    """Load configuration from file."""
    path = Path(path)

    if not path.is_file():
        warnings.warn(f"Config file not found: {path}. Using default config.")
        return PreprocessingConfig(**kwargs)

    try:
        return PreprocessingConfig.model_validate_json(path.read_text())
    except ValueError as e:
        warnings.warn(
            f"Failed to load config file: {e}. Using default config."
        )
        return PreprocessingConfig(**kwargs)


def _get_filename(clip_annotation: data.ClipAnnotation) -> str:
    return f"{clip_annotation.uuid}.nc"


def preprocess_single_annotation(
    clip_annotation: data.ClipAnnotation,
    output_dir: PathLike,
    config: PreprocessingConfig,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    label_fn: LabelFn = lambda _: None,
) -> None:
    output_dir = Path(output_dir)

    filename = filename_fn(clip_annotation)
    path = output_dir / filename

    if path.is_file() and not replace:
        return

    sample = generate_train_example(
        clip_annotation,
        label_fn=label_fn,
        config=config,
    )

    save_to_file(sample, path)


def preprocess_annotations(
    clip_annotations: Sequence[data.ClipAnnotation],
    output_dir: PathLike,
    filename_fn: FilenameFn = _get_filename,
    replace: bool = False,
    config_file: Optional[PathLike] = None,
    label_fn: LabelFn = lambda _: None,
    max_workers: Optional[int] = None,
    **kwargs,
) -> None:
    """Preprocess annotations and save to disk."""
    output_dir = Path(output_dir)

    if not output_dir.is_dir():
        output_dir.mkdir(parents=True)

    if config_file is not None:
        config = load_config(config_file, **kwargs)
    else:
        config = PreprocessingConfig(**kwargs)

    with Pool(max_workers) as pool:
        list(
            tqdm(
                pool.imap_unordered(
                    partial(
                        preprocess_single_annotation,
                        output_dir=output_dir,
                        config=config,
                        filename_fn=filename_fn,
                        replace=replace,
                        label_fn=label_fn,
                    ),
                    clip_annotations,
                ),
                total=len(clip_annotations),
            )
        )
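And a sketch of the top-level entry point over a list of annotations; the annotations and output directory are placeholders:

preprocess_annotations(
    clip_annotations,
    output_dir="train_examples",
    max_workers=4,
)
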
82
batdetect2/train/train.py
Normal file
@ -0,0 +1,82 @@
from typing import NamedTuple, Optional

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader

from batdetect2.data.datasets import ClipAnnotationDataset
from batdetect2.models.typing import DetectionModel


class TrainInputs(NamedTuple):
    spec: torch.Tensor
    detection_heatmap: torch.Tensor
    class_heatmap: torch.Tensor
    size_heatmap: torch.Tensor


def train_loop(
    model: DetectionModel,
    train_dataset: ClipAnnotationDataset[TrainInputs],
    validation_dataset: ClipAnnotationDataset[TrainInputs],
    device: Optional[torch.device] = None,
    num_epochs: int = 100,
    learning_rate: float = 1e-4,
):
    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    # WIP: the validation loader is created but the validation pass is
    # not implemented in this commit yet.
    validation_loader = DataLoader(validation_dataset, batch_size=32)

    model.to(device)

    optimizer = Adam(model.parameters(), lr=learning_rate)
    scheduler = CosineAnnealingLR(
        optimizer,
        num_epochs * len(train_loader),
    )

    for epoch in range(num_epochs):
        train_loss = train_single_epoch(
            model,
            train_loader,
            optimizer,
            device,
            scheduler,
        )


def train_single_epoch(
    model: DetectionModel,
    train_loader: DataLoader,
    optimizer: Adam,
    device: torch.device,
    scheduler: CosineAnnealingLR,
):
    model.train()
    # WIP: `tu.AverageMeter` is referenced without an import in this
    # commit; the running-average helper still needs to be wired up.
    train_loss = tu.AverageMeter()

    for batch in train_loader:
        optimizer.zero_grad()

        spec = batch.spec.to(device)
        detection_heatmap = batch.detection_heatmap.to(device)
        class_heatmap = batch.class_heatmap.to(device)
        size_heatmap = batch.size_heatmap.to(device)

        outputs = model(spec)

        # WIP: `loss_fun` and several of its arguments (gt_det, gt_size,
        # gt_class, det_criterion, params, class_inv_freq) are not
        # defined in this commit yet.
        loss = loss_fun(
            outputs,
            gt_det,
            gt_size,
            gt_class,
            det_criterion,
            params,
            class_inv_freq,
        )

        train_loss.update(loss.item(), spec.shape[0])
        loss.backward()
        optimizer.step()
        scheduler.step()
@ -1,3 +1,4 @@
import sys
import json
from collections import Counter
from pathlib import Path
@ -7,6 +8,11 @@ import numpy as np

from batdetect2 import types

if sys.version_info >= (3, 9):
    StringCounter = Counter[str]
else:
    from typing import Counter as StringCounter


def write_notes_file(file_name: str, text: str):
    with open(file_name, "a") as da:
@ -148,7 +154,7 @@ def format_annotation(
def get_class_names(
    data: List[types.FileAnnotation],
    classes_to_ignore: Optional[List[str]] = None,
) -> Tuple[Counter[str], List[float]]:
) -> Tuple[StringCounter, List[float]]:
    """Extracts class names and their inverse frequencies.

    Parameters
@ -182,7 +188,7 @@ def get_class_names(
    return counts, [mean_counts / counts[cc] for cc in class_names_list]


def report_class_counts(class_names: Counter[str]):
def report_class_counts(class_names: StringCounter):
    print("Class count:")
    str_len = np.max([len(cc) for cc in class_names]) + 5
    for index, (class_name, count) in enumerate(class_names.most_common()):
@ -1,11 +1,13 @@
|
||||
import warnings
|
||||
from typing import Optional, Tuple, Union, overload
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import librosa
|
||||
import librosa.core.spectrum
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from . import wavfile
|
||||
|
||||
__all__ = [
|
||||
"load_audio",
|
||||
"generate_spectrogram",
|
||||
@ -13,171 +15,113 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
@overload
|
||||
def time_to_x_coords(
|
||||
time_in_file: np.ndarray,
|
||||
sampling_rate: float,
|
||||
fft_win_length: float,
|
||||
fft_overlap: float,
|
||||
) -> np.ndarray:
|
||||
...
|
||||
|
||||
|
||||
@overload
|
||||
def time_to_x_coords(
|
||||
time_in_file: float,
|
||||
sampling_rate: float,
|
||||
fft_win_length: float,
|
||||
fft_overlap: float,
|
||||
) -> float:
|
||||
...
|
||||
|
||||
|
||||
def time_to_x_coords(
|
||||
time_in_file: Union[float, np.ndarray],
|
||||
sampling_rate: float,
|
||||
fft_win_length: float,
|
||||
fft_overlap: float,
|
||||
) -> Union[float, np.ndarray]:
|
||||
nfft = np.floor(fft_win_length * sampling_rate)
|
||||
def time_to_x_coords(time_in_file, sampling_rate, fft_win_length, fft_overlap):
|
||||
nfft = np.floor(fft_win_length * sampling_rate) # int() uses floor
|
||||
noverlap = np.floor(fft_overlap * nfft)
|
||||
return (time_in_file * sampling_rate - noverlap) / (nfft - noverlap)
|
||||
|
||||
|
||||
# NOTE this is also defined in post_process
|
||||
def x_coords_to_time(
|
||||
x_pos: float,
|
||||
sampling_rate: int,
|
||||
fft_win_length: float,
|
||||
fft_overlap: float,
|
||||
) -> float:
|
||||
def x_coords_to_time(x_pos, sampling_rate, fft_win_length, fft_overlap):
|
||||
nfft = np.floor(fft_win_length * sampling_rate)
|
||||
noverlap = np.floor(fft_overlap * nfft)
|
||||
return ((x_pos * (nfft - noverlap)) + noverlap) / sampling_rate
|
||||
|
||||
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for
|
||||
# center of temporal window
|
||||
# return (1.0 - fft_overlap) * fft_win_length * (x_pos + 0.5) # 0.5 is for center of temporal window
|
||||
|
||||
|
||||
def generate_spectrogram(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: float,
|
||||
fft_win_length: float,
|
||||
fft_overlap: float,
|
||||
max_freq: float,
|
||||
min_freq: float,
|
||||
spec_scale: str,
|
||||
denoise_spec_avg: bool = False,
|
||||
max_scale_spec: bool = False,
|
||||
) -> np.ndarray:
|
||||
audio,
|
||||
sampling_rate,
|
||||
params,
|
||||
return_spec_for_viz=False,
|
||||
check_spec_size=True,
|
||||
):
|
||||
# generate spectrogram
|
||||
spec = gen_mag_spectrogram(
|
||||
audio,
|
||||
sampling_rate,
|
||||
window_len=fft_win_length,
|
||||
overlap_perc=fft_overlap,
|
||||
)
|
||||
spec = crop_spectrogram(
|
||||
spec,
|
||||
fft_win_length=fft_win_length,
|
||||
max_freq=max_freq,
|
||||
min_freq=min_freq,
|
||||
)
|
||||
spec = scale_spectrogram(
|
||||
spec,
|
||||
sampling_rate,
|
||||
spec_scale=spec_scale,
|
||||
fft_win_length=fft_win_length,
|
||||
params["fft_win_length"],
|
||||
params["fft_overlap"],
|
||||
)
|
||||
|
||||
if denoise_spec_avg:
|
||||
spec = denoise_spectrogram(spec)
|
||||
|
||||
if max_scale_spec:
|
||||
spec = max_scale_spectrogram(spec)
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
def crop_spectrogram(
|
||||
spec: np.ndarray,
|
||||
fft_win_length: float,
|
||||
max_freq: float,
|
||||
min_freq: float,
|
||||
) -> np.ndarray:
|
||||
# crop to min/max freq
|
||||
max_freq = round(max_freq * fft_win_length)
|
||||
min_freq = round(min_freq * fft_win_length)
|
||||
max_freq = round(params["max_freq"] * params["fft_win_length"])
|
||||
min_freq = round(params["min_freq"] * params["fft_win_length"])
|
||||
if spec.shape[0] < max_freq:
|
||||
freq_pad = max_freq - spec.shape[0]
|
||||
spec = np.vstack(
|
||||
(np.zeros((freq_pad, spec.shape[1]), dtype=spec.dtype), spec)
|
||||
)
|
||||
return spec[-max_freq : spec.shape[0] - min_freq, :]
|
||||
spec = spec[-max_freq : spec.shape[0] - min_freq, :]
|
||||
|
||||
|
||||
def denoise_spectrogram(spec: np.ndarray) -> np.ndarray:
|
||||
spec = spec - np.mean(spec, 1)[:, np.newaxis]
|
||||
return spec.clip(min=0)
|
||||
|
||||
|
||||
def max_scale_spectrogram(spec: np.ndarray) -> np.ndarray:
|
||||
return spec / (spec.max() + 10e-6)
|
||||
|
||||
|
||||
def log_scale(
|
||||
spec: np.ndarray,
|
||||
sampling_rate: float,
|
||||
fft_win_length: float,
|
||||
) -> np.ndarray:
|
||||
if params["spec_scale"] == "log":
|
||||
log_scaling = (
|
||||
2.0
|
||||
* (1.0 / sampling_rate)
|
||||
* (
|
||||
1.0
|
||||
/ (
|
||||
np.abs(np.hanning(int(fft_win_length * sampling_rate))) ** 2
|
||||
np.abs(
|
||||
np.hanning(
|
||||
int(params["fft_win_length"] * sampling_rate)
|
||||
)
|
||||
)
|
||||
** 2
|
||||
).sum()
|
||||
)
|
||||
)
|
||||
return np.log1p(log_scaling * spec)
|
||||
# log_scaling = (1.0 / sampling_rate)*0.1
|
||||
# log_scaling = (1.0 / sampling_rate)*10e4
|
||||
spec = np.log1p(log_scaling * spec)
|
||||
elif params["spec_scale"] == "pcen":
|
||||
spec = pcen(spec , sampling_rate)
|
||||
|
||||
elif params["spec_scale"] == "none":
|
||||
pass
|
||||
|
||||
def scale_spectrogram(
|
||||
spec: np.ndarray,
|
||||
sampling_rate: float,
|
||||
spec_scale: str,
|
||||
fft_win_length: float,
|
||||
) -> np.ndarray:
|
||||
if spec_scale == "log":
|
||||
return log_scale(spec, sampling_rate, fft_win_length)
|
||||
if params["denoise_spec_avg"]:
|
||||
spec = spec - np.mean(spec, 1)[:, np.newaxis]
|
||||
spec.clip(min=0, out=spec)
|
||||
|
||||
if spec_scale == "pcen":
|
||||
return pcen(spec, sampling_rate)
|
||||
if params["max_scale_spec"]:
|
||||
spec = spec / (spec.max() + 10e-6)
|
||||
|
||||
return spec
|
||||
# needs to be divisible by specific factor - if not it should have been padded
|
||||
# if check_spec_size:
|
||||
# assert((int(spec.shape[0]*params['resize_factor']) % params['spec_divide_factor']) == 0)
|
||||
# assert((int(spec.shape[1]*params['resize_factor']) % params['spec_divide_factor']) == 0)
|
||||
|
||||
|
||||
def prepare_spec_for_viz(
|
||||
spec: np.ndarray,
|
||||
sampling_rate: int,
|
||||
fft_win_length: float,
|
||||
) -> np.ndarray:
|
||||
# for visualization purposes - use log scaled spectrogram
|
||||
return log_scale(
|
||||
spec,
|
||||
sampling_rate,
|
||||
fft_win_length=fft_win_length,
|
||||
).astype(np.float32)
|
||||
if return_spec_for_viz:
|
||||
log_scaling = (
|
||||
2.0
|
||||
* (1.0 / sampling_rate)
|
||||
* (
|
||||
1.0
|
||||
/ (
|
||||
np.abs(
|
||||
np.hanning(
|
||||
int(params["fft_win_length"] * sampling_rate)
|
||||
)
|
||||
)
|
||||
** 2
|
||||
).sum()
|
||||
)
|
||||
)
|
||||
spec_for_viz = np.log1p(log_scaling * spec).astype(np.float32)
|
||||
else:
|
||||
spec_for_viz = None
|
||||
|
||||
return spec, spec_for_viz
|
||||
|
||||
|
||||
def load_audio(
|
||||
audio_file: str,
|
||||
time_exp_fact: float,
|
||||
target_sampling_rate: int,
|
||||
target_samp_rate: int,
|
||||
scale: bool = False,
|
||||
max_duration: Optional[float] = None,
|
||||
) -> Tuple[float, np.ndarray]:
|
||||
) -> Tuple[int, np.ndarray]:
|
||||
"""Load an audio file and resample it to the target sampling rate.
|
||||
|
||||
The audio is also scaled to [-1, 1] and clipped to the maximum duration.
|
||||
@ -208,82 +152,63 @@ def load_audio(
|
||||
|
||||
"""
|
||||
with warnings.catch_warnings():
|
||||
audio, sampling_rate = librosa.load(
|
||||
warnings.filterwarnings("ignore", category=wavfile.WavFileWarning)
|
||||
# sampling_rate, audio_raw = wavfile.read(audio_file)
|
||||
audio_raw, sampling_rate = librosa.load(
|
||||
audio_file,
|
||||
sr=None,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
if len(audio.shape) > 1:
|
||||
if len(audio_raw.shape) > 1:
|
||||
raise ValueError("Currently does not handle stereo files")
|
||||
|
||||
sampling_rate = sampling_rate * time_exp_fact
|
||||
|
||||
# resample - need to do this after correcting for time expansion
|
||||
audio = resample_audio(audio, sampling_rate, target_sampling_rate)
|
||||
|
||||
if max_duration is not None:
|
||||
audio = clip_audio(audio, target_sampling_rate, max_duration)
|
||||
|
||||
# scale to [-1, 1]
|
||||
if scale:
|
||||
audio = scale_audio(audio)
|
||||
|
||||
return target_sampling_rate, audio
|
||||
|
||||
|
||||
def resample_audio(
|
||||
audio: np.ndarray,
|
||||
sr_orig: float,
|
||||
sr_target: float,
|
||||
) -> np.ndarray:
|
||||
if sr_orig != sr_target:
|
||||
return librosa.resample(
|
||||
audio,
|
||||
orig_sr=sr_orig,
|
||||
target_sr=sr_target,
|
||||
sampling_rate_old = sampling_rate
|
||||
sampling_rate = target_samp_rate
|
||||
if sampling_rate_old != sampling_rate:
|
||||
audio_raw = librosa.resample(
|
||||
audio_raw,
|
||||
orig_sr=sampling_rate_old,
|
||||
target_sr=sampling_rate,
|
||||
res_type="polyphase",
|
||||
)
|
||||
|
||||
return audio
|
||||
|
||||
|
||||
def clip_audio(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: float,
|
||||
max_duration: float,
|
||||
) -> np.ndarray:
|
||||
# clipping maximum duration
|
||||
if max_duration is not None:
|
||||
max_duration = int(
|
||||
np.minimum(
|
||||
int(sampling_rate * max_duration),
|
||||
audio.shape[0],
|
||||
audio_raw.shape[0],
|
||||
)
|
||||
)
|
||||
return audio[:max_duration]
|
||||
audio_raw = audio_raw[:max_duration]
|
||||
|
||||
# scale to [-1, 1]
|
||||
if scale:
|
||||
audio_raw = audio_raw - audio_raw.mean()
|
||||
audio_raw = audio_raw / (np.abs(audio_raw).max() + 10e-6)
|
||||
|
||||
def scale_audio(
|
||||
audio: np.ndarray,
|
||||
eps: float = 10e-6,
|
||||
) -> np.ndarray:
|
||||
return (audio - audio.mean()) / (np.abs(audio).max() + eps)
|
||||
return sampling_rate, audio_raw
|
||||
|
||||
|
||||
def pad_audio(
|
||||
audio_raw: np.ndarray,
|
||||
sampling_rate: float,
|
||||
window_len: float,
|
||||
overlap_perc: float,
|
||||
resize_factor: float,
|
||||
divide_factor: float,
|
||||
fixed_width: Optional[int] = None,
|
||||
) -> np.ndarray:
|
||||
audio_raw,
|
||||
fs,
|
||||
ms,
|
||||
overlap_perc,
|
||||
resize_factor,
|
||||
divide_factor,
|
||||
fixed_width=None,
|
||||
):
|
||||
# Adds zeros to the end of the raw data so that the generated sepctrogram
|
||||
# will be evenly divisible by `divide_factor`
|
||||
# Also deals with very short audio clips and fixed_width during training
|
||||
|
||||
# This code could be clearer, clean up
|
||||
nfft = int(window_len * sampling_rate)
|
||||
nfft = int(ms * fs)
|
||||
noverlap = int(overlap_perc * nfft)
|
||||
step = nfft - noverlap
|
||||
min_size = int(divide_factor * (1.0 / resize_factor))
|
||||
@ -320,23 +245,22 @@ def pad_audio(
|
||||
return audio_raw
|
||||
|
||||
|
||||
def gen_mag_spectrogram(
|
||||
audio: np.ndarray,
|
||||
sampling_rate: float,
|
||||
window_len: float,
|
||||
overlap_perc: float,
|
||||
) -> np.ndarray:
|
||||
def gen_mag_spectrogram(x, fs, ms, overlap_perc):
|
||||
# Computes magnitude spectrogram by specifying time.
|
||||
audio = audio.astype(np.float32)
|
||||
nfft = int(window_len * sampling_rate)
|
||||
|
||||
x = x.astype(np.float32)
|
||||
nfft = int(ms * fs)
|
||||
noverlap = int(overlap_perc * nfft)
|
||||
|
||||
# window data
|
||||
step = nfft - noverlap
|
||||
|
||||
# compute spec
|
||||
spec, _ = librosa.core.spectrum._spectrogram(
|
||||
y=audio,
|
||||
y=x,
|
||||
power=1,
|
||||
n_fft=nfft,
|
||||
hop_length=nfft - noverlap,
|
||||
hop_length=step,
|
||||
center=False,
|
||||
)
|
||||
|
||||
@ -346,25 +270,24 @@ def gen_mag_spectrogram(
|
||||
return spec.astype(np.float32)
|
||||
|
||||
|
||||
def gen_mag_spectrogram_pt(
    audio: torch.Tensor,
    sampling_rate: float,
    window_len: float,
    overlap_perc: float,
) -> torch.Tensor:
    nfft = int(window_len * sampling_rate)
    nstep = round((1.0 - overlap_perc) * nfft)
    han_win = torch.hann_window(nfft, periodic=False).to(audio.device)

    complex_spec = torch.stft(audio, nfft, nstep, window=han_win, center=False)
    spec = complex_spec.pow(2.0).sum(-1)

    # remove DC component and flip vertically
    return torch.flipud(spec[0, 1:, :])
def pcen(spec: np.ndarray, sampling_rate: float) -> np.ndarray:
    # TODO should be passing hop_length too i.e. step
    return librosa.pcen(spec * (2**31), sr=sampling_rate / 10).astype(
        np.float32
    )
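The TODO above flags that the STFT hop should be forwarded to librosa.pcen so its time constant tracks the real frame rate. One possible resolution, sketched as an assumption rather than the project's chosen fix (pcen_with_hop is a hypothetical name, and hop_length would be the step computed in gen_mag_spectrogram):

def pcen_with_hop(
    spec: np.ndarray, sampling_rate: float, hop_length: int
) -> np.ndarray:
    # Assumption: pass the true hop instead of librosa's 512-sample default.
    return librosa.pcen(
        spec * (2**31), sr=sampling_rate, hop_length=hop_length
    ).astype(np.float32)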
@ -437,7 +437,7 @@ def compute_spectrogram(
    )

    # generate spectrogram
    spec, _ = au.generate_spectrogram(audio, sampling_rate, params)

    # convert to pytorch
    spec = torch.from_numpy(spec).to(device)
@ -746,7 +746,7 @@ def process_file(
    sampling_rate, audio_full = au.load_audio(
        audio_file,
        time_exp_fact=config.get("time_expansion", 1) or 1,
        target_samp_rate=config["target_samp_rate"],
        scale=config["scale_raw_audio"],
        max_duration=config.get("max_duration"),
    )
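For reference, a call that matches the renamed keyword, with illustrative values (the 256 kHz rate mirrors the migration tests below; the file name is hypothetical):

sampling_rate, audio_full = au.load_audio(
    "example.wav",
    time_exp_fact=1,
    target_samp_rate=256_000,
    scale=False,
    max_duration=None,
)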
@ -1,17 +0,0 @@
name: batdetect2
channels:
  - defaults
  - conda-forge
  - pytorch
  - nvidia
dependencies:
  - python==3.10
  - matplotlib
  - pandas
  - scikit-learn
  - numpy
  - pytorch
  - scipy
  - torchvision
  - librosa
  - torchaudio
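This commit drops the conda environment file above; dependency management moves to the rye-generated lockfiles added further down, with the project metadata updated accordingly.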
@ -1,3 +1,9 @@
[tool]
rye = { dev-dependencies = [
    "ipykernel>=6.29.4",
    "setuptools>=69.5.1",
    "pytest>=8.1.1",
] }
[tool.pdm]
[tool.pdm.dev-dependencies]
dev = [
@ -22,12 +28,17 @@ dependencies = [
    "torch>=1.13.1",
    "torchaudio",
    "torchvision",
    "soundevent[audio,geometry,plot]>=1.3.5",
    "click>=8.1.7",
    "netcdf4>=1.6.5",
    "tqdm>=4.66.2",
    "pytorch-lightning>=2.2.2",
    "cf-xarray>=0.9.0",
    "onnx>=1.16.0",
    "lightning[extra]>=2.2.2",
    "tensorboard>=2.16.2",
]
requires-python = ">=3.9"
readme = "README.md"
license = { text = "CC-by-nc-4" }
classifiers = [
@ -65,6 +76,9 @@ line-length = 79
profile = "black"
line_length = 79

[tool.ruff]
line-length = 79

[[tool.mypy.overrides]]
module = [
    "librosa",
518
requirements-dev.lock
Normal file
@ -0,0 +1,518 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
#   pre: false
#   features: []
#   all-features: false
#   with-sources: false

-e file:.
absl-py==2.1.0
    # via tensorboard
aiobotocore==2.12.3
    # via s3fs
aiohttp==3.9.5
    # via aiobotocore
    # via fsspec
    # via lightning
    # via s3fs
aioitertools==0.11.0
    # via aiobotocore
aiosignal==1.3.1
    # via aiohttp
annotated-types==0.6.0
    # via pydantic
antlr4-python3-runtime==4.9.3
    # via hydra-core
    # via omegaconf
anyio==4.3.0
    # via starlette
arrow==1.3.0
    # via lightning
asttokens==2.4.1
    # via stack-data
async-timeout==4.0.3
    # via aiohttp
    # via redis
attrs==23.2.0
    # via aiohttp
audioread==3.0.1
    # via librosa
backcall==0.2.0
    # via ipython
backoff==2.2.1
    # via lightning
beautifulsoup4==4.12.3
    # via lightning
bitsandbytes==0.41.0
    # via lightning
blessed==1.20.0
    # via inquirer
boto3==1.34.69
    # via lightning-cloud
botocore==1.34.69
    # via aiobotocore
    # via boto3
    # via s3transfer
certifi==2024.2.2
    # via netcdf4
    # via requests
cf-xarray==0.9.0
    # via batdetect2
cffi==1.16.0
    # via soundfile
cftime==1.6.3
    # via netcdf4
charset-normalizer==3.3.2
    # via requests
click==8.1.7
    # via batdetect2
    # via lightning
    # via lightning-cloud
    # via uvicorn
comm==0.2.2
    # via ipykernel
contourpy==1.1.1
    # via matplotlib
croniter==1.4.1
    # via lightning
cycler==0.12.1
    # via matplotlib
cython==3.0.10
    # via soundevent
dateutils==0.6.12
    # via lightning
debugpy==1.8.1
    # via ipykernel
decorator==5.1.1
    # via ipython
    # via librosa
deepdiff==6.7.1
    # via lightning
dnspython==2.6.1
    # via email-validator
docker==6.1.3
    # via lightning
docstring-parser==0.16
    # via jsonargparse
editor==1.6.6
    # via inquirer
email-validator==2.1.1
    # via soundevent
exceptiongroup==1.2.1
    # via anyio
    # via pytest
executing==2.0.1
    # via stack-data
fastapi==0.110.2
    # via lightning
    # via lightning-cloud
filelock==3.13.4
    # via torch
    # via triton
fonttools==4.51.0
    # via matplotlib
frozenlist==1.4.1
    # via aiohttp
    # via aiosignal
fsspec==2023.12.2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via s3fs
    # via torch
grpcio==1.62.2
    # via tensorboard
h11==0.14.0
    # via uvicorn
hydra-core==1.3.2
    # via lightning
idna==3.7
    # via anyio
    # via email-validator
    # via requests
    # via yarl
importlib-metadata==7.1.0
    # via jupyter-client
    # via markdown
importlib-resources==6.4.0
    # via matplotlib
    # via typeshed-client
iniconfig==2.0.0
    # via pytest
inquirer==3.2.4
    # via lightning
ipykernel==6.29.4
ipython==8.12.3
    # via ipykernel
jedi==0.19.1
    # via ipython
jinja2==3.1.3
    # via lightning
    # via torch
jmespath==1.0.1
    # via boto3
    # via botocore
joblib==1.4.0
    # via librosa
    # via scikit-learn
jsonargparse==4.28.0
    # via lightning
jupyter-client==8.6.1
    # via ipykernel
jupyter-core==5.7.2
    # via ipykernel
    # via jupyter-client
kiwisolver==1.4.5
    # via matplotlib
lazy-loader==0.4
    # via librosa
librosa==0.10.1
    # via batdetect2
lightning==2.2.2
    # via batdetect2
lightning-api-access==0.0.5
    # via lightning
lightning-cloud==0.5.65
    # via lightning
lightning-fabric==2.2.2
    # via lightning
lightning-utilities==0.11.2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via torchmetrics
llvmlite==0.41.1
    # via numba
markdown==3.6
    # via tensorboard
markdown-it-py==3.0.0
    # via rich
markupsafe==2.1.5
    # via jinja2
    # via werkzeug
matplotlib==3.7.5
    # via batdetect2
    # via lightning
    # via soundevent
matplotlib-inline==0.1.7
    # via ipykernel
    # via ipython
mdurl==0.1.2
    # via markdown-it-py
mpmath==1.3.0
    # via sympy
msgpack==1.0.8
    # via librosa
multidict==6.0.5
    # via aiohttp
    # via yarl
nest-asyncio==1.6.0
    # via ipykernel
netcdf4==1.6.5
    # via batdetect2
networkx==3.1
    # via torch
numba==0.58.1
    # via librosa
numpy==1.24.4
    # via batdetect2
    # via cftime
    # via contourpy
    # via librosa
    # via lightning
    # via lightning-fabric
    # via matplotlib
    # via netcdf4
    # via numba
    # via onnx
    # via pandas
    # via pytorch-lightning
    # via scikit-learn
    # via scipy
    # via shapely
    # via soxr
    # via tensorboard
    # via tensorboardx
    # via torchmetrics
    # via torchvision
    # via xarray
nvidia-cublas-cu12==12.1.3.1
    # via nvidia-cudnn-cu12
    # via nvidia-cusolver-cu12
    # via torch
nvidia-cuda-cupti-cu12==12.1.105
    # via torch
nvidia-cuda-nvrtc-cu12==12.1.105
    # via torch
nvidia-cuda-runtime-cu12==12.1.105
    # via torch
nvidia-cudnn-cu12==8.9.2.26
    # via torch
nvidia-cufft-cu12==11.0.2.54
    # via torch
nvidia-curand-cu12==10.3.2.106
    # via torch
nvidia-cusolver-cu12==11.4.5.107
    # via torch
nvidia-cusparse-cu12==12.1.0.106
    # via nvidia-cusolver-cu12
    # via torch
nvidia-nccl-cu12==2.19.3
    # via torch
nvidia-nvjitlink-cu12==12.4.127
    # via nvidia-cusolver-cu12
    # via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
    # via torch
omegaconf==2.3.0
    # via hydra-core
    # via lightning
onnx==1.16.0
    # via batdetect2
ordered-set==4.1.0
    # via deepdiff
packaging==24.0
    # via docker
    # via hydra-core
    # via ipykernel
    # via lazy-loader
    # via lightning
    # via lightning-fabric
    # via lightning-utilities
    # via matplotlib
    # via pooch
    # via pytest
    # via pytorch-lightning
    # via tensorboardx
    # via torchmetrics
    # via xarray
pandas==2.0.3
    # via batdetect2
    # via xarray
parso==0.8.4
    # via jedi
pexpect==4.9.0
    # via ipython
pickleshare==0.7.5
    # via ipython
pillow==10.3.0
    # via matplotlib
    # via torchvision
platformdirs==4.2.0
    # via jupyter-core
    # via pooch
pluggy==1.4.0
    # via pytest
pooch==1.8.1
    # via librosa
prompt-toolkit==3.0.43
    # via ipython
protobuf==5.26.1
    # via onnx
    # via tensorboard
    # via tensorboardx
psutil==5.9.8
    # via ipykernel
    # via lightning
ptyprocess==0.7.0
    # via pexpect
pure-eval==0.2.2
    # via stack-data
pycparser==2.22
    # via cffi
pydantic==2.7.0
    # via fastapi
    # via lightning
    # via soundevent
pydantic-core==2.18.1
    # via pydantic
pygments==2.17.2
    # via ipython
    # via rich
pyjwt==2.8.0
    # via lightning-cloud
pyparsing==3.1.2
    # via matplotlib
pytest==8.1.1
python-dateutil==2.9.0.post0
    # via arrow
    # via botocore
    # via croniter
    # via dateutils
    # via jupyter-client
    # via matplotlib
    # via pandas
python-multipart==0.0.9
    # via lightning
    # via lightning-cloud
pytorch-lightning==2.2.2
    # via batdetect2
    # via lightning
pytz==2024.1
    # via dateutils
    # via pandas
pyyaml==6.0.1
    # via jsonargparse
    # via lightning
    # via omegaconf
    # via pytorch-lightning
pyzmq==26.0.0
    # via ipykernel
    # via jupyter-client
readchar==4.0.6
    # via inquirer
redis==5.0.4
    # via lightning
requests==2.31.0
    # via docker
    # via fsspec
    # via lightning
    # via lightning-cloud
    # via pooch
rich==13.7.1
    # via lightning
    # via lightning-cloud
runs==1.2.2
    # via editor
s3fs==2023.12.2
    # via lightning
s3transfer==0.10.1
    # via boto3
scikit-learn==1.3.2
    # via batdetect2
    # via librosa
scipy==1.10.1
    # via batdetect2
    # via librosa
    # via scikit-learn
    # via soundevent
setuptools==69.5.1
    # via lightning-utilities
    # via readchar
    # via tensorboard
shapely==2.0.3
    # via soundevent
six==1.16.0
    # via asttokens
    # via blessed
    # via lightning-cloud
    # via python-dateutil
    # via tensorboard
sniffio==1.3.1
    # via anyio
soundevent==1.3.5
    # via batdetect2
soundfile==0.12.1
    # via librosa
    # via soundevent
soupsieve==2.5
    # via beautifulsoup4
soxr==0.3.7
    # via librosa
stack-data==0.6.3
    # via ipython
starlette==0.37.2
    # via fastapi
    # via lightning
sympy==1.12
    # via torch
tensorboard==2.16.2
    # via batdetect2
tensorboard-data-server==0.7.2
    # via tensorboard
tensorboardx==2.6.2.2
    # via lightning
threadpoolctl==3.4.0
    # via scikit-learn
tomli==2.0.1
    # via pytest
torch==2.2.2
    # via batdetect2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via torchaudio
    # via torchmetrics
    # via torchvision
torchaudio==2.2.2
    # via batdetect2
torchmetrics==1.3.2
    # via lightning
    # via pytorch-lightning
torchvision==0.17.2
    # via batdetect2
tornado==6.4
    # via ipykernel
    # via jupyter-client
tqdm==4.66.2
    # via batdetect2
    # via lightning
    # via pytorch-lightning
traitlets==5.14.2
    # via comm
    # via ipykernel
    # via ipython
    # via jupyter-client
    # via jupyter-core
    # via lightning
    # via matplotlib-inline
triton==2.2.0
    # via torch
types-python-dateutil==2.9.0.20240316
    # via arrow
typeshed-client==2.5.1
    # via jsonargparse
typing-extensions==4.11.0
    # via aioitertools
    # via anyio
    # via fastapi
    # via ipython
    # via jsonargparse
    # via librosa
    # via lightning
    # via lightning-fabric
    # via lightning-utilities
    # via pydantic
    # via pydantic-core
    # via pytorch-lightning
    # via starlette
    # via torch
    # via typeshed-client
    # via uvicorn
tzdata==2024.1
    # via pandas
urllib3==1.26.18
    # via botocore
    # via docker
    # via lightning
    # via lightning-cloud
    # via requests
uvicorn==0.29.0
    # via lightning
    # via lightning-cloud
wcwidth==0.2.13
    # via blessed
    # via prompt-toolkit
websocket-client==1.7.0
    # via docker
    # via lightning
    # via lightning-cloud
websockets==11.0.3
    # via lightning
werkzeug==3.0.2
    # via tensorboard
wrapt==1.16.0
    # via aiobotocore
xarray==2023.1.0
    # via cf-xarray
    # via soundevent
xmod==1.8.1
    # via editor
    # via runs
yarl==1.9.4
    # via aiohttp
zipp==3.18.1
    # via importlib-metadata
    # via importlib-resources
448
requirements.lock
Normal file
@ -0,0 +1,448 @@
# generated by rye
# use `rye lock` or `rye sync` to update this lockfile
#
# last locked with the following flags:
#   pre: false
#   features: []
#   all-features: false
#   with-sources: false

-e file:.
absl-py==2.1.0
    # via tensorboard
aiobotocore==2.12.3
    # via s3fs
aiohttp==3.9.5
    # via aiobotocore
    # via fsspec
    # via lightning
    # via s3fs
aioitertools==0.11.0
    # via aiobotocore
aiosignal==1.3.1
    # via aiohttp
annotated-types==0.6.0
    # via pydantic
antlr4-python3-runtime==4.9.3
    # via hydra-core
    # via omegaconf
anyio==4.3.0
    # via starlette
arrow==1.3.0
    # via lightning
async-timeout==4.0.3
    # via aiohttp
    # via redis
attrs==23.2.0
    # via aiohttp
audioread==3.0.1
    # via librosa
backoff==2.2.1
    # via lightning
beautifulsoup4==4.12.3
    # via lightning
bitsandbytes==0.41.0
    # via lightning
blessed==1.20.0
    # via inquirer
boto3==1.34.69
    # via lightning-cloud
botocore==1.34.69
    # via aiobotocore
    # via boto3
    # via s3transfer
certifi==2024.2.2
    # via netcdf4
    # via requests
cf-xarray==0.9.0
    # via batdetect2
cffi==1.16.0
    # via soundfile
cftime==1.6.3
    # via netcdf4
charset-normalizer==3.3.2
    # via requests
click==8.1.7
    # via batdetect2
    # via lightning
    # via lightning-cloud
    # via uvicorn
contourpy==1.1.1
    # via matplotlib
croniter==1.4.1
    # via lightning
cycler==0.12.1
    # via matplotlib
cython==3.0.10
    # via soundevent
dateutils==0.6.12
    # via lightning
decorator==5.1.1
    # via librosa
deepdiff==6.7.1
    # via lightning
dnspython==2.6.1
    # via email-validator
docker==6.1.3
    # via lightning
docstring-parser==0.16
    # via jsonargparse
editor==1.6.6
    # via inquirer
email-validator==2.1.1
    # via soundevent
exceptiongroup==1.2.1
    # via anyio
fastapi==0.110.2
    # via lightning
    # via lightning-cloud
filelock==3.13.4
    # via torch
    # via triton
fonttools==4.51.0
    # via matplotlib
frozenlist==1.4.1
    # via aiohttp
    # via aiosignal
fsspec==2023.12.2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via s3fs
    # via torch
grpcio==1.62.2
    # via tensorboard
h11==0.14.0
    # via uvicorn
hydra-core==1.3.2
    # via lightning
idna==3.7
    # via anyio
    # via email-validator
    # via requests
    # via yarl
importlib-metadata==7.1.0
    # via markdown
importlib-resources==6.4.0
    # via matplotlib
    # via typeshed-client
inquirer==3.2.4
    # via lightning
jinja2==3.1.3
    # via lightning
    # via torch
jmespath==1.0.1
    # via boto3
    # via botocore
joblib==1.4.0
    # via librosa
    # via scikit-learn
jsonargparse==4.28.0
    # via lightning
kiwisolver==1.4.5
    # via matplotlib
lazy-loader==0.4
    # via librosa
librosa==0.10.1
    # via batdetect2
lightning==2.2.2
    # via batdetect2
lightning-api-access==0.0.5
    # via lightning
lightning-cloud==0.5.65
    # via lightning
lightning-fabric==2.2.2
    # via lightning
lightning-utilities==0.11.2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via torchmetrics
llvmlite==0.41.1
    # via numba
markdown==3.6
    # via tensorboard
markdown-it-py==3.0.0
    # via rich
markupsafe==2.1.5
    # via jinja2
    # via werkzeug
matplotlib==3.7.5
    # via batdetect2
    # via lightning
    # via soundevent
mdurl==0.1.2
    # via markdown-it-py
mpmath==1.3.0
    # via sympy
msgpack==1.0.8
    # via librosa
multidict==6.0.5
    # via aiohttp
    # via yarl
netcdf4==1.6.5
    # via batdetect2
networkx==3.1
    # via torch
numba==0.58.1
    # via librosa
numpy==1.24.4
    # via batdetect2
    # via cftime
    # via contourpy
    # via librosa
    # via lightning
    # via lightning-fabric
    # via matplotlib
    # via netcdf4
    # via numba
    # via onnx
    # via pandas
    # via pytorch-lightning
    # via scikit-learn
    # via scipy
    # via shapely
    # via soxr
    # via tensorboard
    # via tensorboardx
    # via torchmetrics
    # via torchvision
    # via xarray
nvidia-cublas-cu12==12.1.3.1
    # via nvidia-cudnn-cu12
    # via nvidia-cusolver-cu12
    # via torch
nvidia-cuda-cupti-cu12==12.1.105
    # via torch
nvidia-cuda-nvrtc-cu12==12.1.105
    # via torch
nvidia-cuda-runtime-cu12==12.1.105
    # via torch
nvidia-cudnn-cu12==8.9.2.26
    # via torch
nvidia-cufft-cu12==11.0.2.54
    # via torch
nvidia-curand-cu12==10.3.2.106
    # via torch
nvidia-cusolver-cu12==11.4.5.107
    # via torch
nvidia-cusparse-cu12==12.1.0.106
    # via nvidia-cusolver-cu12
    # via torch
nvidia-nccl-cu12==2.19.3
    # via torch
nvidia-nvjitlink-cu12==12.4.127
    # via nvidia-cusolver-cu12
    # via nvidia-cusparse-cu12
nvidia-nvtx-cu12==12.1.105
    # via torch
omegaconf==2.3.0
    # via hydra-core
    # via lightning
onnx==1.16.0
    # via batdetect2
ordered-set==4.1.0
    # via deepdiff
packaging==24.0
    # via docker
    # via hydra-core
    # via lazy-loader
    # via lightning
    # via lightning-fabric
    # via lightning-utilities
    # via matplotlib
    # via pooch
    # via pytorch-lightning
    # via tensorboardx
    # via torchmetrics
    # via xarray
pandas==2.0.3
    # via batdetect2
    # via xarray
pillow==10.3.0
    # via matplotlib
    # via torchvision
platformdirs==4.2.0
    # via pooch
pooch==1.8.1
    # via librosa
protobuf==5.26.1
    # via onnx
    # via tensorboard
    # via tensorboardx
psutil==5.9.8
    # via lightning
pycparser==2.22
    # via cffi
pydantic==2.7.0
    # via fastapi
    # via lightning
    # via soundevent
pydantic-core==2.18.1
    # via pydantic
pygments==2.17.2
    # via rich
pyjwt==2.8.0
    # via lightning-cloud
pyparsing==3.1.2
    # via matplotlib
python-dateutil==2.9.0.post0
    # via arrow
    # via botocore
    # via croniter
    # via dateutils
    # via matplotlib
    # via pandas
python-multipart==0.0.9
    # via lightning
    # via lightning-cloud
pytorch-lightning==2.2.2
    # via batdetect2
    # via lightning
pytz==2024.1
    # via dateutils
    # via pandas
pyyaml==6.0.1
    # via jsonargparse
    # via lightning
    # via omegaconf
    # via pytorch-lightning
readchar==4.0.6
    # via inquirer
redis==5.0.4
    # via lightning
requests==2.31.0
    # via docker
    # via fsspec
    # via lightning
    # via lightning-cloud
    # via pooch
rich==13.7.1
    # via lightning
    # via lightning-cloud
runs==1.2.2
    # via editor
s3fs==2023.12.2
    # via lightning
s3transfer==0.10.1
    # via boto3
scikit-learn==1.3.2
    # via batdetect2
    # via librosa
scipy==1.10.1
    # via batdetect2
    # via librosa
    # via scikit-learn
    # via soundevent
setuptools==69.5.1
    # via lightning-utilities
    # via readchar
    # via tensorboard
shapely==2.0.3
    # via soundevent
six==1.16.0
    # via blessed
    # via lightning-cloud
    # via python-dateutil
    # via tensorboard
sniffio==1.3.1
    # via anyio
soundevent==1.3.5
    # via batdetect2
soundfile==0.12.1
    # via librosa
    # via soundevent
soupsieve==2.5
    # via beautifulsoup4
soxr==0.3.7
    # via librosa
starlette==0.37.2
    # via fastapi
    # via lightning
sympy==1.12
    # via torch
tensorboard==2.16.2
    # via batdetect2
tensorboard-data-server==0.7.2
    # via tensorboard
tensorboardx==2.6.2.2
    # via lightning
threadpoolctl==3.4.0
    # via scikit-learn
torch==2.2.2
    # via batdetect2
    # via lightning
    # via lightning-fabric
    # via pytorch-lightning
    # via torchaudio
    # via torchmetrics
    # via torchvision
torchaudio==2.2.2
    # via batdetect2
torchmetrics==1.3.2
    # via lightning
    # via pytorch-lightning
torchvision==0.17.2
    # via batdetect2
tqdm==4.66.2
    # via batdetect2
    # via lightning
    # via pytorch-lightning
traitlets==5.14.3
    # via lightning
triton==2.2.0
    # via torch
types-python-dateutil==2.9.0.20240316
    # via arrow
typeshed-client==2.5.1
    # via jsonargparse
typing-extensions==4.11.0
    # via aioitertools
    # via anyio
    # via fastapi
    # via jsonargparse
    # via librosa
    # via lightning
    # via lightning-fabric
    # via lightning-utilities
    # via pydantic
    # via pydantic-core
    # via pytorch-lightning
    # via starlette
    # via torch
    # via typeshed-client
    # via uvicorn
tzdata==2024.1
    # via pandas
urllib3==1.26.18
    # via botocore
    # via docker
    # via lightning
    # via lightning-cloud
    # via requests
uvicorn==0.29.0
    # via lightning
    # via lightning-cloud
wcwidth==0.2.13
    # via blessed
websocket-client==1.7.0
    # via docker
    # via lightning
    # via lightning-cloud
websockets==11.0.3
    # via lightning
werkzeug==3.0.2
    # via tensorboard
wrapt==1.16.0
    # via aiobotocore
xarray==2023.1.0
    # via cf-xarray
    # via soundevent
xmod==1.8.1
    # via editor
    # via runs
yarl==1.9.4
    # via aiohttp
zipp==3.18.1
    # via importlib-metadata
    # via importlib-resources
@ -1,10 +0,0 @@
librosa==0.9.2
matplotlib==3.6.2
numpy==1.23.4
pandas==1.5.2
scikit_learn==1.2.0
scipy==1.9.3
torch==1.13.0
torchaudio==0.13.0
torchvision==0.14.0
click
0
tests/test_data/__init__.py
Normal file
19
tests/test_data/test_batdetect.py
Normal file
@ -0,0 +1,19 @@
"""Test suite for loading batdetect annotations."""

from pathlib import Path

from soundevent import data

from batdetect2.data.compat import load_annotation_project

ROOT_DIR = Path(__file__).parent.parent.parent


def test_load_example_annotation_project():
    path = ROOT_DIR / "example_data" / "anns"
    audio_dir = ROOT_DIR / "example_data" / "audio"
    project = load_annotation_project(path, audio_dir=audio_dir)
    assert isinstance(project, data.AnnotationProject)
    assert project.name == str(path)
    assert len(project.clip_annotations) == 3
    assert len(project.tasks) == 3
@ -133,17 +133,16 @@ def test_compute_max_power_bb(max_power: int):
    audio = np.zeros((int(duration * samplerate),))

    # Add a signal during the time and frequency range of interest
    audio[int(start_time * samplerate) : int(end_time * samplerate)] = (
        0.5
        * librosa.tone(
            max_power, sr=samplerate, duration=end_time - start_time
        )
    )

    # Add a more powerful signal outside frequency range of interest
    audio[int(start_time * samplerate) : int(end_time * samplerate)] += (
        2 * librosa.tone(80_000, sr=samplerate, duration=end_time - start_time)
    )

    params = api.get_config(
@ -152,7 +151,7 @@ def test_compute_max_power_bb(max_power: int):
        target_samp_rate=samplerate,
    )

    spec, _ = au.generate_spectrogram(
        audio,
        samplerate,
        params,
@ -221,18 +220,18 @@ def test_compute_max_power():
    audio = np.zeros((int(duration * samplerate),))

    # Add a signal during the time and frequency range of interest
    audio[int(start_time * samplerate) : int(end_time * samplerate)] = (
        0.5
        * librosa.tone(3_500, sr=samplerate, duration=end_time - start_time)
    )

    # Add a more powerful signal outside frequency range of interest
    audio[int(start_time * samplerate) : int(end_time * samplerate)] += (
        2
        * librosa.tone(
            max_power, sr=samplerate, duration=end_time - start_time
        )
    )

    params = api.get_config(
        min_freq=min_freq,
@ -240,7 +239,7 @@ def test_compute_max_power():
        target_samp_rate=samplerate,
    )

    spec, _ = au.generate_spectrogram(
        audio,
        samplerate,
        params,
0
tests/test_migration/__init__.py
Normal file
120
tests/test_migration/test_preprocessing.py
Normal file
@ -0,0 +1,120 @@
from pathlib import Path

import numpy as np
import pytest
from soundevent import data

from batdetect2.data import preprocessing
from batdetect2.utils import audio_utils

ROOT_DIR = Path(__file__).parent.parent.parent
EXAMPLE_AUDIO = ROOT_DIR / "example_data" / "audio"
TEST_AUDIO = ROOT_DIR / "tests" / "data"


TEST_FILES = [
    EXAMPLE_AUDIO / "20170701_213954-MYOMYS-LR_0_0.5.wav",
    EXAMPLE_AUDIO / "20180530_213516-EPTSER-LR_0_0.5.wav",
    EXAMPLE_AUDIO / "20180627_215323-RHIFER-LR_0_0.5.wav",
    TEST_AUDIO / "20230322_172000_selec2.wav",
]


@pytest.mark.parametrize("audio_file", TEST_FILES)
@pytest.mark.parametrize("scale", [True, False])
def test_audio_loading_hasnt_changed(
    audio_file,
    scale,
):
    time_expansion = 1
    target_sampling_rate = 256_000
    recording = data.Recording.from_file(
        audio_file,
        time_expansion=time_expansion,
    )
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )

    _, audio_original = audio_utils.load_audio(
        audio_file,
        time_expansion,
        target_samp_rate=target_sampling_rate,
        scale=scale,
    )
    audio_new = preprocessing.load_clip_audio(
        clip,
        target_sampling_rate=target_sampling_rate,
        scale=scale,
        dtype=np.float32,
    )

    assert audio_original.shape == audio_new.shape
    assert audio_original.dtype == audio_new.dtype
    assert np.isclose(audio_original, audio_new.data).all()


@pytest.mark.parametrize("audio_file", TEST_FILES)
@pytest.mark.parametrize("spec_scale", ["log", "pcen", "amplitude"])
@pytest.mark.parametrize("denoise_spec_avg", [True, False])
@pytest.mark.parametrize("max_scale_spec", [True, False])
@pytest.mark.parametrize("fft_win_length", [512 / 256_000, 1024 / 256_000])
def test_spectrogram_generation_hasnt_changed(
    audio_file,
    spec_scale,
    denoise_spec_avg,
    max_scale_spec,
    fft_win_length,
):
    time_expansion = 1
    target_sampling_rate = 256_000
    min_freq = 10_000
    max_freq = 120_000
    fft_overlap = 0.75
    recording = data.Recording.from_file(
        audio_file,
        time_expansion=time_expansion,
    )
    clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    audio = preprocessing.load_clip_audio(
        clip,
        target_sampling_rate=target_sampling_rate,
    )

    spec_original, _ = audio_utils.generate_spectrogram(
        audio.data,
        sampling_rate=target_sampling_rate,
        params=dict(
            fft_win_length=fft_win_length,
            fft_overlap=fft_overlap,
            max_freq=max_freq,
            min_freq=min_freq,
            spec_scale=spec_scale,
            denoise_spec_avg=denoise_spec_avg,
            max_scale_spec=max_scale_spec,
        ),
    )

    new_spec = preprocessing.compute_spectrogram(
        audio,
        fft_win_length=fft_win_length,
        fft_overlap=fft_overlap,
        max_freq=max_freq,
        min_freq=min_freq,
        spec_scale=spec_scale,
        denoise_spec_avg=denoise_spec_avg,
        max_scale_spec=max_scale_spec,
        dtype=np.float32,
    )

    assert spec_original.shape == new_spec.shape
    assert spec_original.dtype == new_spec.dtype

    # NOTE: The original spectrogram is flipped vertically
    assert np.isclose(spec_original, np.flipud(new_spec.data)).all()