mirror of https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00

commit ee884da8b0 (parent 6a9e33c729)
Make sure labels are working in the notebook
@@ -195,18 +195,29 @@ def annotation_to_sound_event(
         uuid=uuid.uuid5(NAMESPACE, f"{sound_event.uuid}_annotation"),
         sound_event=sound_event,
         tags=[
-            data.Tag(key=label_key, value=annotation.label),
-            data.Tag(key=event_key, value=annotation.event),
-            data.Tag(key=individual_key, value=str(annotation.individual)),
+            data.Tag(
+                term=data.term_from_key(label_key),
+                value=annotation.label,
+            ),
+            data.Tag(
+                term=data.term_from_key(event_key),
+                value=annotation.event,
+            ),
+            data.Tag(
+                term=data.term_from_key(individual_key),
+                value=str(annotation.individual),
+            ),
         ],
     )


 def file_annotation_to_clip(
     file_annotation: FileAnnotation,
-    audio_dir: PathLike = Path.cwd(),
+    audio_dir: Optional[PathLike] = None,
 ) -> data.Clip:
     """Convert file annotation to recording."""
+    audio_dir = audio_dir or Path.cwd()

     full_path = Path(audio_dir) / file_annotation.id

     if not full_path.exists():
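Note: this hunk moves tag construction from the key-based to the term-based soundevent API (the same commit bumps soundevent to >=2.2 in pyproject.toml). A minimal sketch of the migration, assuming soundevent >= 2.2; the species value is illustrative:

from soundevent import data

label_key = "class"

# Before: tags were keyed by a bare string.
# tag = data.Tag(key=label_key, value="Myotis myotis")

# After: the string key is first resolved to a Term object.
tag = data.Tag(term=data.term_from_key(label_key), value="Myotis myotis")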
@@ -241,7 +252,11 @@ def file_annotation_to_clip_annotation(
         uuid=uuid.uuid5(NAMESPACE, f"{file_annotation.id}_clip_annotation"),
         clip=clip,
         notes=notes,
-        tags=[data.Tag(key=label_key, value=file_annotation.label)],
+        tags=[
+            data.Tag(
+                term=data.term_from_key(label_key), value=file_annotation.label
+            )
+        ],
         sound_events=[
             annotation_to_sound_event(
                 annotation,
@@ -286,9 +301,11 @@ def list_file_annotations(path: PathLike) -> List[Path]:
 def load_annotation_project(
     path: PathLike,
     name: Optional[str] = None,
-    audio_dir: PathLike = Path.cwd(),
+    audio_dir: Optional[PathLike] = None,
 ) -> data.AnnotationProject:
     """Convert annotations to annotation project."""
+    audio_dir = audio_dir or Path.cwd()

     paths = list_file_annotations(path)

     if name is None:
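Note: both signatures above swap an eager default (`audio_dir: PathLike = Path.cwd()`) for `Optional[PathLike] = None` plus a fallback in the body. A default expression is evaluated once, at function definition time, so the old form froze whatever the working directory was at import. An illustrative sketch:

from pathlib import Path
from typing import Optional

def load(audio_dir: Optional[Path] = None) -> Path:
    # Resolving the default inside the body picks up the cwd at call time,
    # not the cwd that happened to be active when the module was imported.
    audio_dir = audio_dir or Path.cwd()
    return audio_dir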
@@ -3,7 +3,7 @@ from typing import Tuple
 import numpy as np
 import xarray as xr
 from scipy.ndimage import gaussian_filter
-from soundevent import data, geometry, arrays
+from soundevent import arrays, data, geometry
 from soundevent.geometry.operations import Positions
 from soundevent.types import ClassMapper

@@ -22,6 +22,8 @@ def generate_heatmaps(
     class_mapper: ClassMapper,
     target_sigma: float = TARGET_SIGMA,
     position: Positions = "bottom-left",
+    time_scale: float = 1.0,
+    frequency_scale: float = 1.0,
     dtype=np.float32,
 ) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
     shape = dict(zip(spec.dims, spec.shape))
@@ -31,13 +33,6 @@ def generate_heatmaps(
             "Spectrogram must have time and frequency dimensions."
         )

-    time_duration = arrays.get_dim_width(spec, dim="time")
-    freq_bandwidth = arrays.get_dim_width(spec, dim="frequency")
-
-    # Compute the size factors
-    time_scale = 1 / time_duration
-    frequency_scale = 1 / freq_bandwidth
-
     # Initialize heatmaps
     detection_heatmap = xr.zeros_like(spec, dtype=dtype)
     class_heatmap = xr.DataArray(
@@ -92,7 +87,7 @@ def generate_heatmaps(
         )

         # Get the class name of the sound event
-        class_name = class_mapper.transform(sound_event_annotation)
+        class_name = class_mapper.encode(sound_event_annotation)

         if class_name is None:
             # If the label is None skip the sound event
@@ -1,7 +1,7 @@
 """Module containing functions for preprocessing audio clips."""

-from typing import Optional, Union
 from pathlib import Path
+from typing import Literal, Optional, Union

 import librosa
 import librosa.core.spectrum
@@ -10,7 +10,7 @@ import xarray as xr
 from numpy.typing import DTypeLike
 from pydantic import BaseModel, Field
 from scipy.signal import resample_poly
-from soundevent import audio, data, arrays
+from soundevent import arrays, audio, data
 from soundevent.arrays import operations as ops

 __all__ = [
@@ -34,32 +34,56 @@ DENOISE_SPEC_AVG = True
 MAX_SCALE_SPEC = False


+class ResampleConfig(BaseModel):
+    samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
+    mode: str = "poly"
+
+
+class AudioConfig(BaseModel):
+    resample: Optional[ResampleConfig] = Field(default_factory=ResampleConfig)
+    scale: bool = Field(default=SCALE_RAW_AUDIO)
+    center: bool = True
+    duration: Optional[float] = DEFAULT_DURATION
+
+
+class FFTConfig(BaseModel):
+    window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
+    window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
+    window_fn: str = "hann"
+
+
+class FrequencyConfig(BaseModel):
+    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
+    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
+
+
+class PcenConfig(BaseModel):
+    time_constant: float = 0.4
+    hop_length: int = 512
+    gain: float = 0.98
+    bias: float = 2
+    power: float = 0.5
+
+
+class SpecSizeConfig(BaseModel):
+    height: int = SPEC_HEIGHT
+    time_period: float = SPEC_TIME_PERIOD
+
+
+class SpectrogramConfig(BaseModel):
+    fft: FFTConfig = Field(default_factory=FFTConfig)
+    frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
+    scale: Union[Literal["log"], None, PcenConfig] = "log"
+    denoise: bool = True
+    resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
+    max_scale: bool = MAX_SCALE_SPEC
+
+
 class PreprocessingConfig(BaseModel):
     """Configuration for preprocessing data."""

-    target_samplerate: int = Field(default=TARGET_SAMPLERATE_HZ, gt=0)
-
-    scale_audio: bool = Field(default=SCALE_RAW_AUDIO)
-
-    fft_win_length: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
-
-    fft_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
-
-    max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
-
-    min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
-
-    spec_scale: str = Field(default=SPEC_SCALE)
-
-    denoise_spec_avg: bool = DENOISE_SPEC_AVG
-
-    max_scale_spec: bool = MAX_SCALE_SPEC
-
-    duration: Optional[float] = DEFAULT_DURATION
-
-    spec_height: int = SPEC_HEIGHT
-
-    spec_time_period: float = SPEC_TIME_PERIOD
+    audio: AudioConfig = Field(default_factory=AudioConfig)
+    spectrogram: SpectrogramConfig = Field(default_factory=SpectrogramConfig)

     @classmethod
     def from_file(
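Note: the flat PreprocessingConfig fields are regrouped into nested pydantic models. A hedged usage sketch, built only from the classes added above (the override values are illustrative, not defaults taken from the diff):

config = PreprocessingConfig(
    audio=AudioConfig(
        resample=ResampleConfig(samplerate=256_000),
        scale=False,
    ),
    spectrogram=SpectrogramConfig(
        fft=FFTConfig(window_duration=0.002, window_overlap=0.75),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
        scale=PcenConfig(),  # or "log", or None to keep raw amplitudes
    ),
)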
@@ -104,7 +128,7 @@ class PreprocessingConfig(BaseModel):

 def preprocess_audio_clip(
     clip: data.Clip,
-    config: PreprocessingConfig = PreprocessingConfig(),
+    config: Optional[PreprocessingConfig] = None,
 ) -> xr.DataArray:
     """Preprocesses audio clip to generate spectrogram.

@@ -121,81 +145,117 @@ def preprocess_audio_clip(
         Preprocessed spectrogram.

     """
-    wav = load_clip_audio(
-        clip,
-        target_sampling_rate=config.target_samplerate,
-        scale=config.scale_audio,
-    )
-
-    spec = compute_spectrogram(
-        wav,
-        fft_win_length=config.fft_win_length,
-        fft_overlap=config.fft_overlap,
-        max_freq=config.max_freq,
-        min_freq=config.min_freq,
-        spec_scale=config.spec_scale,
-        denoise_spec_avg=config.denoise_spec_avg,
-        max_scale_spec=config.max_scale_spec,
-    )
-
-    if config.duration is not None:
-        spec = adjust_spec_duration(clip, spec, config.duration)
-
-    duration = arrays.get_dim_width(spec, dim="time")
-    return ops.resize(
-        spec,
-        time=int(np.ceil(duration / config.spec_time_period)),
-        frequency=config.spec_height,
-        dtype=np.float32,
-    )
-
-
-def adjust_spec_duration(
-    clip: data.Clip,
-    spec: xr.DataArray,
-    duration: float,
-) -> xr.DataArray:
-    current_duration = clip.end_time - clip.start_time
-
-    if current_duration == duration:
-        return spec
-
-    if current_duration > duration:
-        return arrays.crop_dim(
-            spec,
-            dim="time",
-            start=clip.start_time,
-            stop=clip.start_time + duration,
-        )
-
-    return arrays.extend_dim(
-        spec,
-        dim="time",
-        start=clip.start_time,
-        stop=clip.start_time + duration,
-    )
+    config = config or PreprocessingConfig()
+    wav = load_clip_audio(clip, config=config.audio)
+    spec = compute_spectrogram(wav, config=config.spectrogram)
+    return spec


 def load_clip_audio(
     clip: data.Clip,
-    target_sampling_rate: int = TARGET_SAMPLERATE_HZ,
-    scale: bool = SCALE_RAW_AUDIO,
+    config: Optional[AudioConfig] = None,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
+    config = config or AudioConfig()
+
     wav = audio.load_clip(clip).sel(channel=0).astype(dtype)

-    wav = resample_audio(wav, target_sampling_rate, dtype=dtype)
+    if config.duration is not None:
+        wav = adjust_audio_duration(wav, duration=config.duration)

-    if scale:
+    if config.resample:
+        wav = resample_audio(
+            wav,
+            samplerate=config.resample.samplerate,
+            dtype=dtype,
+        )
+
+    if config.center:
         wav = ops.center(wav)
+
+    if config.scale:
         wav = ops.scale(wav, 1 / (10e-6 + np.max(np.abs(wav))))

     return wav.astype(dtype)


+def compute_spectrogram(
+    wav: xr.DataArray,
+    config: Optional[SpectrogramConfig] = None,
+    dtype: DTypeLike = np.float32,
+) -> xr.DataArray:
+    config = config or SpectrogramConfig()
+
+    spec = stft(
+        wav,
+        window_duration=config.fft.window_duration,
+        window_overlap=config.fft.window_overlap,
+        window_fn=config.fft.window_fn,
+        dtype=dtype,
+    )
+
+    spec = crop_spectrogram_frequencies(
+        spec,
+        min_freq=config.frequencies.min_freq,
+        max_freq=config.frequencies.max_freq,
+    )
+
+    spec = scale_spectrogram(spec, scale=config.scale)
+
+    if config.denoise:
+        spec = denoise_spectrogram(spec)
+
+    if config.resize:
+        spec = resize_spectrogram(spec, config=config.resize)
+
+    if config.max_scale:
+        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
+
+    return spec.astype(dtype)
+
+
+def crop_spectrogram_frequencies(
+    spec: xr.DataArray,
+    min_freq: int = MIN_FREQ_HZ,
+    max_freq: int = MAX_FREQ_HZ,
+) -> xr.DataArray:
+    return arrays.crop_dim(
+        spec,
+        dim="frequency",
+        start=min_freq,
+        stop=max_freq,
+    ).astype(spec.dtype)
+
+
+def adjust_audio_duration(
+    wave: xr.DataArray,
+    duration: float,
+) -> xr.DataArray:
+    start_time, end_time = arrays.get_dim_range(wave, dim="time")
+    current_duration = end_time - start_time
+
+    if current_duration == duration:
+        return wave
+
+    if current_duration > duration:
+        return arrays.crop_dim(
+            wave,
+            dim="time",
+            start=start_time,
+            stop=start_time + duration,
+        )
+
+    return arrays.extend_dim(
+        wave,
+        dim="time",
+        start=start_time,
+        stop=start_time + duration,
+    )
+
+
 def resample_audio(
     wav: xr.DataArray,
-    target_samplerate: int = TARGET_SAMPLERATE_HZ,
+    samplerate: int = TARGET_SAMPLERATE_HZ,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     if "time" not in wav.dims:
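Note: after this refactor the top-level pipeline is just two calls that each take their own sub-config. A sketch of the equivalence, using only functions from this diff (`clip` is a soundevent data.Clip):

config = PreprocessingConfig()  # passing None falls back to these defaults
spec = preprocess_audio_clip(clip, config=config)

# ... which now expands to:
wav = load_clip_audio(clip, config=config.audio)
spec = compute_spectrogram(wav, config=config.spectrogram)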
@@ -207,13 +267,13 @@ def resample_audio(
     step = arrays.get_dim_step(wav, dim="time")
     original_samplerate = int(1 / step)

-    if original_samplerate == target_samplerate:
+    if original_samplerate == samplerate:
         return wav.astype(dtype)

-    gcd = np.gcd(original_samplerate, target_samplerate)
+    gcd = np.gcd(original_samplerate, samplerate)
     resampled = resample_poly(
         wav.values,
-        target_samplerate // gcd,
+        samplerate // gcd,
         original_samplerate // gcd,
         axis=time_axis,
     )
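Note: resample_audio (only renaming its argument here) reduces the rate pair by their greatest common divisor so resample_poly works with small up/down factors. A worked example with illustrative rates:

import numpy as np
from scipy.signal import resample_poly

original_samplerate = 441_000  # e.g. a 10x time-expanded 44.1 kHz recording
samplerate = 256_000

gcd = np.gcd(original_samplerate, samplerate)             # 1_000
up, down = samplerate // gcd, original_samplerate // gcd  # 256, 441
resampled = resample_poly(np.zeros(original_samplerate), up, down)
assert resampled.shape[0] == samplerate  # one second in, one second out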
@@ -225,7 +285,6 @@ def resample_audio(
         endpoint=False,
         dtype=dtype,
     )
-
     return xr.DataArray(
         data=resampled.astype(dtype),
         dims=wav.dims,
@@ -233,70 +292,35 @@ def resample_audio(
             **wav.coords,
             "time": arrays.create_time_dim_from_array(
                 resampled_times,
-                samplerate=target_samplerate,
+                samplerate=samplerate,
             ),
         },
         attrs=wav.attrs,
     )


-def compute_spectrogram(
-    wav: xr.DataArray,
-    fft_win_length: float = FFT_WIN_LENGTH_S,
-    fft_overlap: float = FFT_OVERLAP,
-    max_freq: int = MAX_FREQ_HZ,
-    min_freq: int = MIN_FREQ_HZ,
-    spec_scale: str = SPEC_SCALE,
-    denoise_spec_avg: bool = True,
-    max_scale_spec: bool = False,
-    dtype: DTypeLike = np.float32,
-) -> xr.DataArray:
-    spec = gen_mag_spectrogram(
-        wav,
-        window_len=fft_win_length,
-        overlap_perc=fft_overlap,
-        dtype=dtype,
-    )
-
-    spec = arrays.crop_dim(
-        spec,
-        dim="frequency",
-        start=min_freq,
-        stop=max_freq,
-    ).astype(dtype)
-
-    spec = scale_spectrogram(spec, scale=spec_scale)
-
-    if denoise_spec_avg:
-        spec = denoise_spectrogram(spec)
-
-    if max_scale_spec:
-        spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
-
-    return spec.astype(dtype)
-
-
-def gen_mag_spectrogram(
+def stft(
     wave: xr.DataArray,
-    window_len: float,
-    overlap_perc: float,
+    window_duration: float,
+    window_overlap: float,
+    window_fn: str = "hann",
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     start_time, end_time = arrays.get_dim_range(wave, dim="time")
     step = arrays.get_dim_step(wave, dim="time")
     sampling_rate = 1 / step

-    hop_len = window_len * (1 - overlap_perc)
-    nfft = int(window_len * sampling_rate)
-    noverlap = int(overlap_perc * nfft)
+    hop_len = window_duration * (1 - window_overlap)
+    nfft = int(window_duration * sampling_rate)
+    noverlap = int(window_overlap * nfft)

-    # compute spec
     spec, _ = librosa.core.spectrum._spectrogram(
-        y=wave.data,
+        y=wave.data.astype(dtype),
         power=1,
         n_fft=nfft,
         hop_length=nfft - noverlap,
         center=False,
+        window=window_fn,
     )

     return xr.DataArray(
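Note: the renamed stft helper keeps deriving the FFT size and hop from durations. A numeric sketch with assumed values (the real FFT_WIN_LENGTH_S / FFT_OVERLAP constants live elsewhere in the module):

sampling_rate = 256_000
window_duration = 0.002  # seconds, assumed
window_overlap = 0.75    # fraction of the window, assumed

nfft = int(window_duration * sampling_rate)       # 512 samples per frame
noverlap = int(window_overlap * nfft)             # 384 overlapping samples
hop_length = nfft - noverlap                      # 128 samples between frames
hop_len = window_duration * (1 - window_overlap)  # 0.0005 s, for the time axis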
@@ -316,7 +340,7 @@ def gen_mag_spectrogram(
             "time": arrays.create_time_dim_from_array(
                 np.linspace(
                     start_time,
-                    end_time - (window_len - hop_len),
+                    end_time - (window_duration - hop_len),
                     spec.shape[1],
                     endpoint=False,
                     dtype=dtype,
@@ -333,9 +357,7 @@ def gen_mag_spectrogram(
     )


-def denoise_spectrogram(
-    spec: xr.DataArray,
-) -> xr.DataArray:
+def denoise_spectrogram(spec: xr.DataArray) -> xr.DataArray:
     return xr.DataArray(
         data=(spec - spec.mean("time")).clip(0),
         dims=spec.dims,
@@ -346,35 +368,53 @@ def denoise_spectrogram(

 def scale_spectrogram(
     spec: xr.DataArray,
-    scale: str = SPEC_SCALE,
+    scale: Union[Literal["log"], None, PcenConfig],
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
-    samplerate = spec.attrs["original_samplerate"]
-
-    if scale == "pcen":
-        smoothing_constant = get_pcen_smoothing_constant(samplerate / 10)
-        return audio.pcen(
-            spec * (2**31),
-            smooth=smoothing_constant,
-        ).astype(dtype)
-
     if scale == "log":
-        return log_scale(spec, dtype=dtype)
+        return scale_log(spec, dtype=dtype)

+    if isinstance(scale, PcenConfig):
+        return scale_pcen(
+            spec,
+            time_constant=scale.time_constant,
+            hop_length=scale.hop_length,
+            gain=scale.gain,
+            power=scale.power,
+            bias=scale.bias,
+        )
+
     return spec


-def log_scale(
+def scale_pcen(
+    spec: xr.DataArray,
+    time_constant: float = 0.4,
+    hop_length: int = 512,
+    gain: float = 0.98,
+    bias: float = 2,
+    power: float = 0.5,
+) -> xr.DataArray:
+    samplerate = spec.attrs["original_samplerate"]
+    # NOTE: Not sure why the 10 is there
+    t_frames = time_constant * samplerate / (float(hop_length) * 10)
+    smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
+    return audio.pcen(
+        spec * (2**31),
+        smooth=smoothing_constant,
+        gain=gain,
+        bias=bias,
+        power=power,
+    ).astype(spec.dtype)
+
+
+def scale_log(
     spec: xr.DataArray,
     dtype: DTypeLike = np.float32,
 ) -> xr.DataArray:
     samplerate = spec.attrs["original_samplerate"]
     nfft = spec.attrs["nfft"]
-    log_scaling = (
-        2.0
-        * (1.0 / samplerate)
-        * (1.0 / (np.abs(np.hanning(nfft)) ** 2).sum())
-    )
+    log_scaling = 2 / (samplerate * (np.abs(np.hanning(nfft)) ** 2).sum())
     return xr.DataArray(
         data=np.log1p(log_scaling * spec).astype(dtype),
         dims=spec.dims,
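Note: the rewritten one-line log_scaling in scale_log is algebraically identical to the old three-factor product; only the presentation changed. A quick check:

import numpy as np

samplerate, nfft = 256_000, 512  # illustrative values
window_energy = (np.abs(np.hanning(nfft)) ** 2).sum()

old = 2.0 * (1.0 / samplerate) * (1.0 / window_energy)
new = 2 / (samplerate * window_energy)
assert np.isclose(old, new)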
@@ -383,10 +423,14 @@ def log_scale(
     )


-def get_pcen_smoothing_constant(
-    sr: int,
-    time_constant: float = 0.4,
-    hop_length: int = 512,
-) -> float:
-    t_frames = time_constant * sr / float(hop_length)
-    return (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
+def resize_spectrogram(
+    spec: xr.DataArray,
+    config: SpecSizeConfig,
+) -> xr.DataArray:
+    duration = arrays.get_dim_width(spec, dim="time")
+    return ops.resize(
+        spec,
+        time=int(np.ceil(duration / config.time_period)),
+        frequency=config.height,
+        dtype=np.float32,
+    )
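Note: the removed get_pcen_smoothing_constant logic was inlined into scale_pcen above; the new resize_spectrogram resizes so each output column covers config.time_period seconds. Illustrative arithmetic (values assumed):

import numpy as np

duration = 0.5       # seconds of spectrogram
time_period = 0.001  # seconds per output column
height = 128         # output frequency bins

time_bins = int(np.ceil(duration / time_period))  # 500 columns
# ops.resize(spec, time=time_bins, frequency=height, dtype=np.float32)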
@@ -1,4 +1,4 @@
-from typing import Type
+from typing import Optional, Type

 import pytorch_lightning as L
 import torch
@@ -6,11 +6,11 @@ import xarray as xr
 from soundevent import data
 from torch import nn, optim

-from batdetect2.data.preprocessing import (
-    preprocess_audio_clip,
-    PreprocessingConfig,
-)
 from batdetect2.data.labels import ClassMapper
+from batdetect2.data.preprocessing import (
+    PreprocessingConfig,
+    preprocess_audio_clip,
+)
 from batdetect2.models.feature_extractors import Net2DFast
 from batdetect2.models.post_process import (
     PostprocessConfig,
@@ -29,11 +29,14 @@ class DetectorModel(L.LightningModule):
         learning_rate: float = 1e-3,
         input_height: int = 128,
         num_features: int = 32,
-        preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
-        postprocessing_config: PostprocessConfig = PostprocessConfig(),
+        preprocessing_config: Optional[PreprocessingConfig] = None,
+        postprocessing_config: Optional[PostprocessConfig] = None,
     ):
         super().__init__()

+        preprocessing_config = preprocessing_config or PreprocessingConfig()
+        postprocessing_config = postprocessing_config or PostprocessConfig()
+
         self.save_hyperparameters()

         self.preprocessing_config = preprocessing_config
@@ -1,10 +1,10 @@
 """Module for postprocessing model outputs."""

 from typing import Callable, List, Tuple, Union
-from pydantic import BaseModel, Field

 import numpy as np
 import torch
+from pydantic import BaseModel, Field
 from soundevent import data
 from torch import nn

@@ -207,7 +207,7 @@ def compute_sound_events_from_outputs(
             ),
             features=[
                 data.Feature(
-                    name=f"batdetect2_{i}",
+                    term=data.term_from_key(f"batdetect2_{i}"),
                     value=value.item(),
                 )
                 for i, value in enumerate(feature)
@@ -3,18 +3,18 @@
 import os
 import warnings
 from functools import partial
+from multiprocessing import Pool
 from pathlib import Path
 from typing import Callable, Optional, Sequence, Union
-from tqdm.auto import tqdm
-from multiprocessing import Pool

 import xarray as xr
 from soundevent import data
+from tqdm.auto import tqdm

 from batdetect2.data.labels import TARGET_SIGMA, ClassMapper, generate_heatmaps
 from batdetect2.data.preprocessing import (
-    preprocess_audio_clip,
     PreprocessingConfig,
+    preprocess_audio_clip,
 )

 PathLike = Union[Path, str, os.PathLike]
@@ -25,14 +25,15 @@ __all__ = [
 ]


 def generate_train_example(
     clip_annotation: data.ClipAnnotation,
     class_mapper: ClassMapper,
-    preprocessing_config: PreprocessingConfig = PreprocessingConfig(),
+    preprocessing_config: Optional[PreprocessingConfig] = None,
     target_sigma: float = TARGET_SIGMA,
 ) -> xr.Dataset:
     """Generate a training example."""
+    preprocessing_config = preprocessing_config or PreprocessingConfig()

     spectrogram = preprocess_audio_clip(
         clip_annotation.clip,
         config=preprocessing_config,
@@ -83,14 +84,18 @@ def load_config(path: PathLike, **kwargs) -> PreprocessingConfig:
     path = Path(path)

     if not path.is_file():
-        warnings.warn(f"Config file not found: {path}. Using default config.")
+        warnings.warn(
+            f"Config file not found: {path}. Using default config.",
+            stacklevel=1,
+        )
         return PreprocessingConfig(**kwargs)

     try:
         return PreprocessingConfig.model_validate_json(path.read_text())
     except ValueError as e:
         warnings.warn(
-            f"Failed to load config file: {e}. Using default config."
+            f"Failed to load config file: {e}. Using default config.",
+            stacklevel=1,
         )
         return PreprocessingConfig(**kwargs)
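Note: load_config now passes stacklevel to its warnings but keeps the same fallback contract. A hedged usage sketch (the path is hypothetical):

config = load_config("configs/preprocessing.json")
# missing file  -> warns, returns PreprocessingConfig(**kwargs)
# invalid JSON  -> warns, returns PreprocessingConfig(**kwargs)
# valid JSON    -> result of PreprocessingConfig.model_validate_json(...)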
@@ -90,7 +90,7 @@ def generate_spectrogram(
             np.abs(
                 np.hanning(
                     int(params["fft_win_length"] * sampling_rate)
-                )
+                ).astype(np.float32)
             )
             ** 2
         ).sum()
@@ -409,7 +409,7 @@ def save_results_to_file(results, op_path: str) -> None:

 def compute_spectrogram(
     audio: np.ndarray,
-    sampling_rate: float,
+    sampling_rate: int,
     params: SpectrogramParameters,
     device: torch.device,
 ) -> Tuple[float, torch.Tensor]:
@@ -627,7 +627,7 @@ def process_spectrogram(

 def _process_audio_array(
     audio: np.ndarray,
-    sampling_rate: float,
+    sampling_rate: int,
     model: DetectionModel,
     config: ProcessingConfiguration,
     device: torch.device,
@@ -17,7 +17,7 @@ dependencies = [
     "torch>=1.13.1,<2.5.0",
     "torchaudio>=1.13.1,<2.5.0",
     "torchvision>=0.14.0",
-    "soundevent[audio,geometry,plot]>=2.0.1",
+    "soundevent[audio,geometry,plot]>=2.2",
     "click>=8.1.7",
     "netcdf4>=1.6.5",
     "tqdm>=4.66.2",
@@ -94,7 +94,7 @@ def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
     params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
     length = int(duration * samplerate)
     audio = np.random.rand(length)
-    _, spectrogram, _ = detector_utils.compute_spectrogram(
+    _, spectrogram = detector_utils.compute_spectrogram(
         audio,
         samplerate,
         params,
tests/test_data/test_labels.py (new file, +120)
@@ -0,0 +1,120 @@
+from pathlib import Path
+
+import numpy as np
+import xarray as xr
+from soundevent import data
+from soundevent.types import ClassMapper
+
+from batdetect2.data.labels import generate_heatmaps
+
+recording = data.Recording(
+    samplerate=256_000,
+    duration=1,
+    channels=1,
+    time_expansion=1,
+    hash="asdf98sdf",
+    path=Path("/path/to/audio.wav"),
+)
+
+clip = data.Clip(
+    recording=recording,
+    start_time=0,
+    end_time=1,
+)
+
+
+class Mapper(ClassMapper):
+    class_labels = ["bat", "cat"]
+
+    def encode(self, sound_event_annotation: data.SoundEventAnnotation) -> str:
+        return "bat"
+
+    def decode(self, label: str) -> list:
+        return [data.Tag(term=data.term_from_key("species"), value="bat")]
+
+
+def test_generated_heatmaps_have_correct_dimensions():
+    spec = xr.DataArray(
+        data=np.random.rand(100, 100),
+        dims=["time", "frequency"],
+        coords={
+            "time": np.linspace(0, 100, 100, endpoint=False),
+            "frequency": np.linspace(0, 100, 100, endpoint=False),
+        },
+    )
+
+    clip_annotation = data.ClipAnnotation(
+        clip=clip,
+        sound_events=[
+            data.SoundEventAnnotation(
+                sound_event=data.SoundEvent(
+                    recording=recording,
+                    geometry=data.BoundingBox(
+                        coordinates=[10, 10, 20, 20],
+                    ),
+                ),
+            )
+        ],
+    )
+
+    class_mapper = Mapper()
+
+    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
+        clip_annotation,
+        spec,
+        class_mapper,
+    )
+
+    assert isinstance(detection_heatmap, xr.DataArray)
+    assert detection_heatmap.shape == (100, 100)
+    assert detection_heatmap.dims == ("time", "frequency")
+
+    assert isinstance(class_heatmap, xr.DataArray)
+    assert class_heatmap.shape == (2, 100, 100)
+    assert class_heatmap.dims == ("category", "time", "frequency")
+    assert class_heatmap.coords["category"].values.tolist() == ["bat", "cat"]
+
+    assert isinstance(size_heatmap, xr.DataArray)
+    assert size_heatmap.shape == (2, 100, 100)
+    assert size_heatmap.dims == ("dimension", "time", "frequency")
+    assert size_heatmap.coords["dimension"].values.tolist() == [
+        "width",
+        "height",
+    ]
+
+
+def test_generated_heatmap_are_non_zero_at_correct_positions():
+    spec = xr.DataArray(
+        data=np.random.rand(100, 100),
+        dims=["time", "frequency"],
+        coords={
+            "time": np.linspace(0, 100, 100, endpoint=False),
+            "frequency": np.linspace(0, 100, 100, endpoint=False),
+        },
+    )
+
+    clip_annotation = data.ClipAnnotation(
+        clip=clip,
+        sound_events=[
+            data.SoundEventAnnotation(
+                sound_event=data.SoundEvent(
+                    recording=recording,
+                    geometry=data.BoundingBox(
+                        coordinates=[10, 10, 20, 20],
+                    ),
+                ),
+            )
+        ],
+    )
+
+    class_mapper = Mapper()
+    detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
+        clip_annotation,
+        spec,
+        class_mapper,
+    )
+    assert size_heatmap.sel(time=10, frequency=10, dimension="width") == 10
+    assert size_heatmap.sel(time=10, frequency=10, dimension="height") == 10
+    assert class_heatmap.sel(time=10, frequency=10, category="bat") == 1.0
+    assert class_heatmap.sel(time=10, frequency=10, category="cat") == 0.0
+    assert detection_heatmap.sel(time=10, frequency=10) == 1.0
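Note: the size assertions expect 10 because the annotated bounding box [start_time, low_freq, end_time, high_freq] = [10, 10, 20, 20] spans 10 units in each dimension, and generate_heatmaps now defaults to time_scale = frequency_scale = 1.0 (the earlier hunk removed the implicit 1/duration scaling):

width = (20 - 10) * 1.0   # time extent * time_scale
height = (20 - 10) * 1.0  # frequency extent * frequency_scale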
@@ -46,8 +46,14 @@ def test_audio_loading_hasnt_changed(
     )
     audio_new = preprocessing.load_clip_audio(
         clip,
-        target_sampling_rate=target_sampling_rate,
-        scale=scale,
+        config=preprocessing.AudioConfig(
+            resample=preprocessing.ResampleConfig(
+                samplerate=target_sampling_rate,
+            ),
+            center=scale,
+            scale=scale,
+            duration=None,
+        ),
         dtype=np.float32,
     )

@@ -73,18 +79,46 @@ def test_spectrogram_generation_hasnt_changed(
     min_freq = 10_000
     max_freq = 120_000
     fft_overlap = 0.75

+    scale = None
+    if spec_scale == "log":
+        scale = "log"
+    elif spec_scale == "pcen":
+        scale = preprocessing.PcenConfig()
+
+    config = preprocessing.SpectrogramConfig(
+        fft=preprocessing.FFTConfig(
+            window_overlap=fft_overlap,
+            window_duration=fft_win_length,
+        ),
+        frequencies=preprocessing.FrequencyConfig(
+            min_freq=min_freq,
+            max_freq=max_freq,
+        ),
+        scale=scale,
+        denoise=denoise_spec_avg,
+        resize=None,
+        max_scale=max_scale_spec,
+    )
+
     recording = data.Recording.from_file(
         audio_file,
         time_expansion=time_expansion,
     )

     clip = data.Clip(
         recording=recording,
         start_time=0,
         end_time=recording.duration,
     )

     audio = preprocessing.load_clip_audio(
         clip,
-        target_sampling_rate=target_sampling_rate,
+        config=preprocessing.AudioConfig(
+            resample=preprocessing.ResampleConfig(
+                samplerate=target_sampling_rate,
+            )
+        ),
     )

     spec_original, _ = audio_utils.generate_spectrogram(
@@ -103,18 +137,19 @@ def test_spectrogram_generation_hasnt_changed(

     new_spec = preprocessing.compute_spectrogram(
         audio,
-        fft_win_length=fft_win_length,
-        fft_overlap=fft_overlap,
-        max_freq=max_freq,
-        min_freq=min_freq,
-        spec_scale=spec_scale,
-        denoise_spec_avg=denoise_spec_avg,
-        max_scale_spec=max_scale_spec,
+        config=config,
         dtype=np.float32,
     )

     assert spec_original.shape == new_spec.shape
     assert spec_original.dtype == new_spec.dtype

+    # Check that the spectrogram content is the same within a tolerance of 1e-5
+    # for each element of the spectrogram at least 99.5% of the time.
+    # NOTE: The pcen function is not the same as the one in the original code
+    # thus the need for a tolerance, but the values are still very similar.
     # NOTE: The original spectrogram is flipped vertically
-    assert np.isclose(spec_original, np.flipud(new_spec.data)).all()
+    assert (
+        np.isclose(spec_original, np.flipud(new_spec.data), atol=1e-5).mean()
+        > 0.995
+    )
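Note: the tightened assertion tolerates small per-element differences (the new scale_pcen is not bit-identical to the old implementation). The mean of a boolean array is the fraction of elements that agree:

import numpy as np

a = np.array([1.0, 2.0, 3.0, 4.0])
b = np.array([1.0, 2.0, 3.0, 4.5])

matches = np.isclose(a, b, atol=1e-5)  # elementwise True/False
assert matches.mean() == 0.75          # 3 of 4 agree within tolerance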
uv.lock (generated, 2 lines changed)
@@ -236,7 +236,7 @@ requires-dist = [
     { name = "pytorch-lightning", specifier = ">=2.2.2" },
     { name = "scikit-learn", specifier = ">=1.2.2" },
     { name = "scipy", specifier = ">=1.10.1" },
-    { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.0.1" },
+    { name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.2" },
     { name = "tensorboard", specifier = ">=2.16.2" },
     { name = "torch", specifier = ">=1.13.1,<2.5.0" },
     { name = "torchaudio", specifier = ">=1.13.1,<2.5.0" },