mirror of
https://github.com/macaodha/batdetect2.git
synced 2025-06-29 14:41:58 +02:00
Ensure train inputs are almost equal
This commit is contained in:
parent
1f0fb14d89
commit
36c90a600f
148
batdetect2/compat/params.py
Normal file
148
batdetect2/compat/params.py
Normal file
@ -0,0 +1,148 @@
|
|||||||
|
from batdetect2.preprocess import (
|
||||||
|
AudioConfig,
|
||||||
|
FFTConfig,
|
||||||
|
FrequencyConfig,
|
||||||
|
PcenConfig,
|
||||||
|
PreprocessingConfig,
|
||||||
|
ResampleConfig,
|
||||||
|
SpecSizeConfig,
|
||||||
|
SpectrogramConfig,
|
||||||
|
)
|
||||||
|
from batdetect2.preprocess.spectrogram import get_spectrogram_resolution
|
||||||
|
from batdetect2.terms import TagInfo
|
||||||
|
from batdetect2.train.preprocess import (
|
||||||
|
HeatmapsConfig,
|
||||||
|
TargetConfig,
|
||||||
|
TrainPreprocessingConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_spectrogram_scale(scale: str):
    """Map a legacy ``spec_scale`` value onto the new scale setting.

    Args:
        scale: Value of the legacy ``spec_scale`` parameter.

    Returns:
        The literal string ``"log"`` when ``scale`` is ``"log"``, a
        default-constructed ``PcenConfig`` when it is ``"pcen"``, and
        ``None`` for every other value.
    """
    if scale == "log":
        return "log"

    # Any name other than the two recognised ones maps to ``None``.
    return PcenConfig() if scale == "pcen" else None
|
||||||
|
|
||||||
|
|
||||||
|
def get_preprocessing_config(params: dict) -> PreprocessingConfig:
    """Build a ``PreprocessingConfig`` from a legacy parameter dictionary.

    Args:
        params: Legacy parameters. The keys read here are
            ``target_samp_rate``, ``scale_raw_audio``, ``fft_win_length``,
            ``fft_overlap``, ``min_freq``, ``max_freq``, ``spec_scale``,
            ``denoise_spec_avg``, ``spec_height``, ``resize_factor`` and
            ``max_scale_spec``.

    Returns:
        The audio and spectrogram configuration assembled from ``params``.

    Raises:
        KeyError: If one of the keys listed above is missing.
    """
    resample = ResampleConfig(
        samplerate=params["target_samp_rate"],
        mode="poly",
    )
    # The single legacy ``scale_raw_audio`` flag feeds both the ``scale``
    # and the ``center`` options of the audio configuration.
    audio_config = AudioConfig(
        resample=resample,
        scale=params["scale_raw_audio"],
        center=params["scale_raw_audio"],
        duration=None,
    )

    fft = FFTConfig(
        window_duration=params["fft_win_length"],
        window_overlap=params["fft_overlap"],
        window_fn="hann",
    )
    frequencies = FrequencyConfig(
        min_freq=params["min_freq"],
        max_freq=params["max_freq"],
    )
    spectrogram_config = SpectrogramConfig(
        fft=fft,
        frequencies=frequencies,
        scale=get_spectrogram_scale(params["spec_scale"]),
        denoise=params["denoise_spec_avg"],
        size=SpecSizeConfig(
            height=params["spec_height"],
            resize_factor=params["resize_factor"],
        ),
        max_scale=params["max_scale_spec"],
    )

    return PreprocessingConfig(
        audio=audio_config,
        spectrogram=spectrogram_config,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_preprocessing_config(
    params: dict,
) -> TrainPreprocessingConfig:
    """Build a ``TrainPreprocessingConfig`` from a legacy parameter dictionary.

    Args:
        params: Legacy parameters. In addition to the keys consumed by
            ``get_preprocessing_config``, this reads ``generic_class``,
            ``class_names``, ``events_of_interest``, ``classes_to_ignore``
            and ``target_sigma``.

    Returns:
        The preprocessing, target and heatmap configuration assembled from
        ``params``.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    # Only the first entry of the legacy ``generic_class`` value is used.
    generic_name = params["generic_class"][0]
    base_config = get_preprocessing_config(params)
    freq_bin_width, time_bin_width = get_spectrogram_resolution(
        base_config.spectrogram
    )

    class_tags = [
        TagInfo(key="class", value=name, label=name)
        for name in params["class_names"]
    ]
    generic_tag = TagInfo(
        key="class",
        value=generic_name,
        label=generic_name,
    )
    event_tags = [
        TagInfo(key="event", value=event_name)
        for event_name in params["events_of_interest"]
    ]
    ignored_tags = [
        TagInfo(key="class", value=ignored)
        for ignored in params["classes_to_ignore"]
    ]
    targets = TargetConfig(
        classes=class_tags,
        generic_class=generic_tag,
        include=event_tags,
        exclude=ignored_tags,
    )

    # The heatmap scales are the reciprocals of the spectrogram bin widths.
    heatmaps = HeatmapsConfig(
        position="bottom-left",
        time_scale=1 / time_bin_width,
        frequency_scale=1 / freq_bin_width,
        sigma=params["target_sigma"],
    )

    return TrainPreprocessingConfig(
        preprocessing=base_config,
        target=targets,
        heatmaps=heatmaps,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# 'standardize_classs_names_ip',
|
||||||
|
# 'convert_to_genus',
|
||||||
|
# 'genus_mapping',
|
||||||
|
# 'standardize_classs_names',
|
||||||
|
# 'genus_names',
|
||||||
|
|
||||||
|
# ['data_dir',
|
||||||
|
# 'ann_dir',
|
||||||
|
# 'train_split',
|
||||||
|
# 'model_name',
|
||||||
|
# 'num_filters',
|
||||||
|
# 'experiment',
|
||||||
|
# 'model_file_name',
|
||||||
|
# 'op_im_dir',
|
||||||
|
# 'op_im_dir_test',
|
||||||
|
# 'notes',
|
||||||
|
# 'spec_divide_factor',
|
||||||
|
# 'detection_overlap',
|
||||||
|
# 'ignore_start_end',
|
||||||
|
# 'detection_threshold',
|
||||||
|
# 'nms_kernel_size',
|
||||||
|
# 'nms_top_k_per_sec',
|
||||||
|
# 'aug_prob',
|
||||||
|
# 'augment_at_train',
|
||||||
|
# 'augment_at_train_combine',
|
||||||
|
# 'echo_max_delay',
|
||||||
|
# 'stretch_squeeze_delta',
|
||||||
|
# 'mask_max_time_perc',
|
||||||
|
# 'mask_max_freq_perc',
|
||||||
|
# 'spec_amp_scaling',
|
||||||
|
# 'aug_sampling_rates',
|
||||||
|
# 'train_loss',
|
||||||
|
# 'det_loss_weight',
|
||||||
|
# 'size_loss_weight',
|
||||||
|
# 'class_loss_weight',
|
||||||
|
# 'individual_loss_weight',
|
||||||
|
# 'emb_dim',
|
||||||
|
# 'lr',
|
||||||
|
# 'batch_size',
|
||||||
|
# 'num_workers',
|
||||||
|
# 'num_epochs',
|
||||||
|
# 'num_eval_epochs',
|
||||||
|
# 'device',
|
||||||
|
# 'save_test_image_during_train',
|
||||||
|
# 'save_test_image_after_train',
|
||||||
|
# 'train_sets',
|
||||||
|
# 'test_sets',
|
||||||
|
# 'class_inv_freq',
|
||||||
|
# 'ip_height']
|
@ -27,6 +27,28 @@ class AudioConfig(BaseConfig):
|
|||||||
duration: Optional[float] = DEFAULT_DURATION
|
duration: Optional[float] = DEFAULT_DURATION
|
||||||
|
|
||||||
|
|
||||||
|
def load_file_audio(
    path: data.PathLike,
    config: Optional[AudioConfig] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    """Load the audio stored in the file at ``path``.

    A ``data.Recording`` is created from the file and handed to
    ``load_recording_audio`` together with ``config`` and ``dtype``.

    Args:
        path: Location of the audio file.
        config: Optional audio configuration, forwarded unchanged.
        dtype: Data type, forwarded unchanged.

    Returns:
        Whatever ``load_recording_audio`` returns for the recording.
    """
    return load_recording_audio(
        data.Recording.from_file(path),
        config=config,
        dtype=dtype,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def load_recording_audio(
    recording: data.Recording,
    config: Optional[AudioConfig] = None,
    dtype: DTypeLike = np.float32,
) -> xr.DataArray:
    """Load the audio of an entire recording.

    The recording is wrapped in a ``data.Clip`` that starts at 0 and ends
    at ``recording.duration``; the clip is then passed to
    ``load_clip_audio``.

    Args:
        recording: Recording whose audio should be loaded.
        config: Optional audio configuration, forwarded unchanged.
        dtype: Data type, forwarded unchanged.

    Returns:
        Whatever ``load_clip_audio`` returns for the full-length clip.
    """
    full_clip = data.Clip(
        recording=recording,
        start_time=0,
        end_time=recording.duration,
    )
    return load_clip_audio(full_clip, config=config, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
def load_clip_audio(
|
def load_clip_audio(
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
config: Optional[AudioConfig] = None,
|
config: Optional[AudioConfig] = None,
|
||||||
|
@ -10,29 +10,17 @@ from soundevent import arrays, audio
|
|||||||
from soundevent.arrays import operations as ops
|
from soundevent.arrays import operations as ops
|
||||||
|
|
||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.preprocess.audio import DEFAULT_DURATION
|
|
||||||
|
|
||||||
FFT_WIN_LENGTH_S = 512 / 256000.0
|
|
||||||
FFT_OVERLAP = 0.75
|
|
||||||
MAX_FREQ_HZ = 120000
|
|
||||||
MIN_FREQ_HZ = 10000
|
|
||||||
SPEC_HEIGHT = 128
|
|
||||||
SPEC_WIDTH = 256
|
|
||||||
SPEC_SCALE = "pcen"
|
|
||||||
SPEC_TIME_PERIOD = DEFAULT_DURATION / SPEC_WIDTH
|
|
||||||
DENOISE_SPEC_AVG = True
|
|
||||||
MAX_SCALE_SPEC = False
|
|
||||||
|
|
||||||
|
|
||||||
class FFTConfig(BaseConfig):
|
class FFTConfig(BaseConfig):
|
||||||
window_duration: float = Field(default=FFT_WIN_LENGTH_S, gt=0)
|
window_duration: float = Field(default=0.002, gt=0)
|
||||||
window_overlap: float = Field(default=FFT_OVERLAP, ge=0, lt=1)
|
window_overlap: float = Field(default=0.75, ge=0, lt=1)
|
||||||
window_fn: str = "hann"
|
window_fn: str = "hann"
|
||||||
|
|
||||||
|
|
||||||
class FrequencyConfig(BaseConfig):
|
class FrequencyConfig(BaseConfig):
|
||||||
max_freq: int = Field(default=MAX_FREQ_HZ, gt=0)
|
max_freq: int = Field(default=120_000, gt=0)
|
||||||
min_freq: int = Field(default=MIN_FREQ_HZ, gt=0)
|
min_freq: int = Field(default=10_000, gt=0)
|
||||||
|
|
||||||
|
|
||||||
class PcenConfig(BaseConfig):
|
class PcenConfig(BaseConfig):
|
||||||
@ -44,17 +32,20 @@ class PcenConfig(BaseConfig):
|
|||||||
|
|
||||||
|
|
||||||
class SpecSizeConfig(BaseConfig):
|
class SpecSizeConfig(BaseConfig):
|
||||||
height: int = SPEC_HEIGHT
|
height: int = 256
|
||||||
time_period: float = SPEC_TIME_PERIOD
|
resize_factor: Optional[float] = 0.5
|
||||||
|
divide_factor: Optional[int] = 32
|
||||||
|
|
||||||
|
|
||||||
class SpectrogramConfig(BaseConfig):
|
class SpectrogramConfig(BaseConfig):
|
||||||
fft: FFTConfig = Field(default_factory=FFTConfig)
|
fft: FFTConfig = Field(default_factory=FFTConfig)
|
||||||
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
frequencies: FrequencyConfig = Field(default_factory=FrequencyConfig)
|
||||||
scale: Union[Literal["log"], None, PcenConfig] = "log"
|
scale: Union[Literal["log"], None, PcenConfig] = Field(
|
||||||
|
default_factory=PcenConfig
|
||||||
|
)
|
||||||
|
size: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
|
||||||
denoise: bool = True
|
denoise: bool = True
|
||||||
resize: Optional[SpecSizeConfig] = Field(default_factory=SpecSizeConfig)
|
max_scale: bool = False
|
||||||
max_scale: bool = MAX_SCALE_SPEC
|
|
||||||
|
|
||||||
|
|
||||||
def compute_spectrogram(
|
def compute_spectrogram(
|
||||||
@ -64,6 +55,16 @@ def compute_spectrogram(
|
|||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
config = config or SpectrogramConfig()
|
config = config or SpectrogramConfig()
|
||||||
|
|
||||||
|
if config.size and config.size.divide_factor:
|
||||||
|
# Need to pad the audio to make sure the spectrogram has a
|
||||||
|
# width compatible with the divide factor
|
||||||
|
wav = pad_audio(
|
||||||
|
wav,
|
||||||
|
window_duration=config.fft.window_duration,
|
||||||
|
window_overlap=config.fft.window_overlap,
|
||||||
|
divide_factor=config.size.divide_factor,
|
||||||
|
)
|
||||||
|
|
||||||
spec = stft(
|
spec = stft(
|
||||||
wav,
|
wav,
|
||||||
window_duration=config.fft.window_duration,
|
window_duration=config.fft.window_duration,
|
||||||
@ -83,8 +84,12 @@ def compute_spectrogram(
|
|||||||
if config.denoise:
|
if config.denoise:
|
||||||
spec = denoise_spectrogram(spec)
|
spec = denoise_spectrogram(spec)
|
||||||
|
|
||||||
if config.resize:
|
if config.size:
|
||||||
spec = resize_spectrogram(spec, config=config.resize)
|
spec = resize_spectrogram(
|
||||||
|
spec,
|
||||||
|
height=config.size.height,
|
||||||
|
resize_factor=config.size.resize_factor,
|
||||||
|
)
|
||||||
|
|
||||||
if config.max_scale:
|
if config.max_scale:
|
||||||
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
|
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
|
||||||
@ -94,8 +99,8 @@ def compute_spectrogram(
|
|||||||
|
|
||||||
def crop_spectrogram_frequencies(
|
def crop_spectrogram_frequencies(
|
||||||
spec: xr.DataArray,
|
spec: xr.DataArray,
|
||||||
min_freq: int = MIN_FREQ_HZ,
|
min_freq: int = 10_000,
|
||||||
max_freq: int = MAX_FREQ_HZ,
|
max_freq: int = 120_000,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
return arrays.crop_dim(
|
return arrays.crop_dim(
|
||||||
spec,
|
spec,
|
||||||
@ -116,9 +121,10 @@ def stft(
|
|||||||
step = arrays.get_dim_step(wave, dim="time")
|
step = arrays.get_dim_step(wave, dim="time")
|
||||||
sampling_rate = 1 / step
|
sampling_rate = 1 / step
|
||||||
|
|
||||||
hop_len = window_duration * (1 - window_overlap)
|
|
||||||
nfft = int(window_duration * sampling_rate)
|
nfft = int(window_duration * sampling_rate)
|
||||||
noverlap = int(window_overlap * nfft)
|
noverlap = int(window_overlap * nfft)
|
||||||
|
hop_len = nfft - noverlap
|
||||||
|
hop_duration = hop_len / sampling_rate
|
||||||
|
|
||||||
spec, _ = librosa.core.spectrum._spectrogram(
|
spec, _ = librosa.core.spectrum._spectrogram(
|
||||||
y=wave.data.astype(dtype),
|
y=wave.data.astype(dtype),
|
||||||
@ -146,12 +152,12 @@ def stft(
|
|||||||
"time": arrays.create_time_dim_from_array(
|
"time": arrays.create_time_dim_from_array(
|
||||||
np.linspace(
|
np.linspace(
|
||||||
start_time,
|
start_time,
|
||||||
end_time - (window_duration - hop_len),
|
end_time - (window_duration - hop_duration),
|
||||||
spec.shape[1],
|
spec.shape[1],
|
||||||
endpoint=False,
|
endpoint=False,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
),
|
),
|
||||||
step=hop_len,
|
step=hop_duration,
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
attrs={
|
attrs={
|
||||||
@ -202,7 +208,6 @@ def scale_pcen(
|
|||||||
power: float = 0.5,
|
power: float = 0.5,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
samplerate = spec.attrs["original_samplerate"]
|
samplerate = spec.attrs["original_samplerate"]
|
||||||
# NOTE: Not sure why the 10 is there
|
|
||||||
t_frames = time_constant * samplerate / (float(hop_length) * 10)
|
t_frames = time_constant * samplerate / (float(hop_length) * 10)
|
||||||
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
smoothing_constant = (np.sqrt(1 + 4 * t_frames**2) - 1) / (2 * t_frames**2)
|
||||||
return audio.pcen(
|
return audio.pcen(
|
||||||
@ -231,12 +236,114 @@ def scale_log(
|
|||||||
|
|
||||||
def resize_spectrogram(
|
def resize_spectrogram(
|
||||||
spec: xr.DataArray,
|
spec: xr.DataArray,
|
||||||
config: SpecSizeConfig,
|
height: int = 128,
|
||||||
|
resize_factor: Optional[float] = 0.5,
|
||||||
) -> xr.DataArray:
|
) -> xr.DataArray:
|
||||||
duration = arrays.get_dim_width(spec, dim="time")
|
resize_factor = resize_factor or 1
|
||||||
|
current_width = spec.sizes["time"]
|
||||||
return ops.resize(
|
return ops.resize(
|
||||||
spec,
|
spec,
|
||||||
time=int(np.ceil(duration / config.time_period)),
|
time=int(resize_factor * current_width),
|
||||||
frequency=config.height,
|
frequency=int(resize_factor * height),
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def adjust_spectrogram_width(
    spec: xr.DataArray,
    divide_factor: int = 32,
    time_period: float = 0.001,
) -> xr.DataArray:
    """Extend a spectrogram so its time size is a multiple of ``divide_factor``.

    Args:
        spec: Spectrogram with a ``time`` dimension.
        divide_factor: The number of time bins must become a multiple of
            this value.
        time_period: Duration attributed to each missing time bin when
            computing how far to extend the ``time`` dimension.

    Returns:
        ``spec`` itself when its time size is already a multiple of
        ``divide_factor``; otherwise the result of ``ops.extend_dim`` with
        the stop of the ``time`` dimension moved forward.
    """
    num_bins = spec.sizes["time"]
    remainder = num_bins % divide_factor
    if remainder == 0:
        return spec

    # Bins still needed to reach the next multiple of ``divide_factor``.
    missing_bins = divide_factor - remainder
    _, current_stop = arrays.get_dim_range(spec, dim="time")
    return ops.extend_dim(
        spec,
        dim="time",
        stop=current_stop + missing_bins * time_period,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def pad_audio(
    wave: xr.DataArray,
    window_duration: float,
    window_overlap: float,
    divide_factor: int = 32,
) -> xr.DataArray:
    """Resize a waveform so its spectrogram width divides evenly.

    The spectrogram width is estimated with ``duration_to_spec_width``.
    When that estimate is not a multiple of ``divide_factor``, the width of
    the ``time`` dimension is adjusted (``position="start"``) to the number
    of samples that ``spec_width_to_samples`` reports for the next multiple.

    Args:
        wave: Waveform with a ``time`` dimension.
        window_duration: STFT window duration.
        window_overlap: STFT window overlap, as a fraction of the window.
        divide_factor: The estimated spectrogram width must become a
            multiple of this value.

    Returns:
        ``wave`` unchanged when the estimated width is already a multiple
        of ``divide_factor``; otherwise the result of
        ``ops.adjust_dim_width``.
    """
    duration = arrays.get_dim_width(wave, dim="time")
    # NOTE(review): ``int`` truncates, so a step whose reciprocal lands just
    # below an integer yields a sample rate one lower than intended --
    # confirm whether rounding would be more appropriate.
    samplerate = int(1 / arrays.get_dim_step(wave, dim="time"))

    stft_params = dict(
        samplerate=samplerate,
        window_duration=window_duration,
        window_overlap=window_overlap,
    )
    estimated_width = duration_to_spec_width(duration, **stft_params)

    remainder = estimated_width % divide_factor
    if remainder == 0:
        return wave

    # Smallest multiple of ``divide_factor`` above the estimated width.
    padded_width = estimated_width + (divide_factor - remainder)
    return ops.adjust_dim_width(
        wave,
        dim="time",
        width=spec_width_to_samples(padded_width, **stft_params),
        position="start",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def duration_to_spec_width(
    duration: float,
    samplerate: int,
    window_duration: float,
    window_overlap: float,
) -> int:
    """Estimate the number of spectrogram frames for a given duration.

    Durations are converted to whole sample counts by truncation. The
    width is then ``(samples - window + hop) / hop`` rounded down.

    Args:
        duration: Duration of the audio.
        samplerate: Number of samples per unit of ``duration``.
        window_duration: STFT window duration.
        window_overlap: STFT window overlap, as a fraction of the window.

    Returns:
        The estimated number of frames. The value is negative when the
        audio is shorter than the window minus one hop.

    Raises:
        ZeroDivisionError: If the resulting hop length is zero samples.
    """
    num_samples = int(duration * samplerate)
    window_size = int(window_duration * samplerate)
    hop_size = window_size - int(window_overlap * window_size)
    # Integer floor division rounds towards negative infinity.
    return (num_samples - window_size + hop_size) // hop_size
|
||||||
|
|
||||||
|
|
||||||
|
def spec_width_to_samples(
    width: int,
    samplerate: int,
    window_duration: float,
    window_overlap: float,
) -> int:
    """Return the sample count that corresponds to a spectrogram width.

    This is the inverse of the frame-count formula
    ``(samples - window + hop) / hop``: one full window plus one hop for
    every frame after the first.

    Args:
        width: Number of spectrogram frames.
        samplerate: Number of samples per unit of ``window_duration``.
        window_duration: STFT window duration.
        window_overlap: STFT window overlap, as a fraction of the window.

    Returns:
        The number of samples for ``width`` frames.
    """
    window_size = int(window_duration * samplerate)
    hop_size = window_size - int(window_overlap * window_size)
    return window_size + (width - 1) * hop_size
|
||||||
|
|
||||||
|
|
||||||
|
def get_spectrogram_resolution(
    config: SpectrogramConfig,
) -> tuple[float, float]:
    """Return the width of one spectrogram bin in frequency and in time.

    Args:
        config: Spectrogram configuration. Its ``size`` must be set, since
            the bin widths depend on the height and resize factor.

    Returns:
        A ``(freq_bin_width, time_bin_width)`` tuple. The frequency width
        is the configured frequency span divided by the resized height; the
        time width is the nominal hop duration divided by the resize
        factor. A missing or zero resize factor is treated as 1.

    Raises:
        ValueError: If ``config.size`` is ``None``.
    """
    size = config.size
    if size is None:
        # Raise explicitly instead of asserting: assertions are stripped
        # under ``python -O``, which would turn this into an opaque
        # ``AttributeError`` on ``None``.
        raise ValueError(
            "Cannot compute the spectrogram resolution without a size "
            "configuration (config.size is None)."
        )

    resize_factor = size.resize_factor or 1
    freq_span = config.frequencies.max_freq - config.frequencies.min_freq
    freq_bin_width = freq_span / (size.height * resize_factor)

    # NOTE(review): this is the nominal hop duration; an STFT that derives
    # its hop from truncated integer sample counts may differ slightly.
    hop_duration = config.fft.window_duration * (1 - config.fft.window_overlap)
    return freq_bin_width, hop_duration / resize_factor
|
||||||
|
@ -22,9 +22,9 @@ class TermInfo(BaseModel):
|
|||||||
|
|
||||||
class TagInfo(BaseModel):
|
class TagInfo(BaseModel):
|
||||||
value: str
|
value: str
|
||||||
label: Optional[str] = None
|
|
||||||
term: Optional[TermInfo] = None
|
term: Optional[TermInfo] = None
|
||||||
key: Optional[str] = None
|
key: Optional[str] = None
|
||||||
|
label: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
call_type = data.Term(
|
call_type = data.Term(
|
||||||
|
@ -7,23 +7,29 @@ from soundevent import arrays, data, geometry
|
|||||||
from soundevent.geometry.operations import Positions
|
from soundevent.geometry.operations import Positions
|
||||||
from soundevent.types import ClassMapper
|
from soundevent.types import ClassMapper
|
||||||
|
|
||||||
|
from batdetect2.configs import BaseConfig
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"ClassMapper",
|
"ClassMapper",
|
||||||
"generate_heatmaps",
|
"generate_heatmaps",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
TARGET_SIGMA = 3.0
|
class HeatmapsConfig(BaseConfig):
    # Settings handed to ``generate_heatmaps`` when training targets are
    # built. The defaults mirror the defaults of that function.

    # Passed as the ``position`` argument of ``generate_heatmaps``.
    position: Positions = "bottom-left"

    # Passed as the ``target_sigma`` argument of ``generate_heatmaps``.
    sigma: float = 3.0

    # NOTE(review): presumably the number of spectrogram time bins per
    # second (1 / 0.001) -- confirm against the spectrogram defaults.
    time_scale: float = 1000.0

    # NOTE(review): 859.375 equals (120_000 - 10_000) / 128, i.e. this looks
    # like the reciprocal of the default frequency bin width -- confirm.
    frequency_scale: float = 1 / 859.375
|
||||||
|
|
||||||
|
|
||||||
def generate_heatmaps(
|
def generate_heatmaps(
|
||||||
sound_events: Sequence[data.SoundEventAnnotation],
|
sound_events: Sequence[data.SoundEventAnnotation],
|
||||||
spec: xr.DataArray,
|
spec: xr.DataArray,
|
||||||
class_mapper: ClassMapper,
|
class_mapper: ClassMapper,
|
||||||
target_sigma: float = TARGET_SIGMA,
|
target_sigma: float = 3.0,
|
||||||
position: Positions = "bottom-left",
|
position: Positions = "bottom-left",
|
||||||
time_scale: float = 1.0,
|
time_scale: float = 1000.0,
|
||||||
frequency_scale: float = 1.0,
|
frequency_scale: float = 1 / 859.375,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
|
) -> Tuple[xr.DataArray, xr.DataArray, xr.DataArray]:
|
||||||
shape = dict(zip(spec.dims, spec.shape))
|
shape = dict(zip(spec.dims, spec.shape))
|
||||||
@ -39,7 +45,7 @@ def generate_heatmaps(
|
|||||||
data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
|
data=np.zeros((class_mapper.num_classes, *spec.shape), dtype=dtype),
|
||||||
dims=["category", *spec.dims],
|
dims=["category", *spec.dims],
|
||||||
coords={
|
coords={
|
||||||
"category": class_mapper.class_labels,
|
"category": [*class_mapper.class_labels],
|
||||||
**spec.coords,
|
**spec.coords,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
@ -14,12 +14,10 @@ from tqdm.auto import tqdm
|
|||||||
from batdetect2.configs import BaseConfig
|
from batdetect2.configs import BaseConfig
|
||||||
from batdetect2.preprocess import (
|
from batdetect2.preprocess import (
|
||||||
PreprocessingConfig,
|
PreprocessingConfig,
|
||||||
preprocess_audio_clip,
|
compute_spectrogram,
|
||||||
)
|
load_clip_audio,
|
||||||
from batdetect2.train.labels import (
|
|
||||||
TARGET_SIGMA,
|
|
||||||
generate_heatmaps,
|
|
||||||
)
|
)
|
||||||
|
from batdetect2.train.labels import HeatmapsConfig, generate_heatmaps
|
||||||
from batdetect2.train.targets import (
|
from batdetect2.train.targets import (
|
||||||
TargetConfig,
|
TargetConfig,
|
||||||
build_class_mapper,
|
build_class_mapper,
|
||||||
@ -34,16 +32,12 @@ __all__ = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class MasksConfig(BaseConfig):
|
|
||||||
sigma: float = TARGET_SIGMA
|
|
||||||
|
|
||||||
|
|
||||||
class TrainPreprocessingConfig(BaseConfig):
|
class TrainPreprocessingConfig(BaseConfig):
|
||||||
preprocessing: PreprocessingConfig = Field(
|
preprocessing: PreprocessingConfig = Field(
|
||||||
default_factory=PreprocessingConfig
|
default_factory=PreprocessingConfig
|
||||||
)
|
)
|
||||||
target: TargetConfig = Field(default_factory=TargetConfig)
|
target: TargetConfig = Field(default_factory=TargetConfig)
|
||||||
masks: MasksConfig = Field(default_factory=MasksConfig)
|
heatmaps: HeatmapsConfig = Field(default_factory=HeatmapsConfig)
|
||||||
|
|
||||||
|
|
||||||
def generate_train_example(
|
def generate_train_example(
|
||||||
@ -53,9 +47,14 @@ def generate_train_example(
|
|||||||
"""Generate a training example."""
|
"""Generate a training example."""
|
||||||
config = config or TrainPreprocessingConfig()
|
config = config or TrainPreprocessingConfig()
|
||||||
|
|
||||||
spectrogram = preprocess_audio_clip(
|
wave = load_clip_audio(
|
||||||
clip_annotation.clip,
|
clip_annotation.clip,
|
||||||
config=config.preprocessing,
|
config=config.preprocessing.audio,
|
||||||
|
)
|
||||||
|
|
||||||
|
spectrogram = compute_spectrogram(
|
||||||
|
wave,
|
||||||
|
config=config.preprocessing.spectrogram,
|
||||||
)
|
)
|
||||||
|
|
||||||
filter_fn = build_sound_event_filter(
|
filter_fn = build_sound_event_filter(
|
||||||
@ -65,17 +64,24 @@ def generate_train_example(
|
|||||||
selected_events = [
|
selected_events = [
|
||||||
event for event in clip_annotation.sound_events if filter_fn(event)
|
event for event in clip_annotation.sound_events if filter_fn(event)
|
||||||
]
|
]
|
||||||
|
|
||||||
class_mapper = build_class_mapper(config.target.classes)
|
class_mapper = build_class_mapper(config.target.classes)
|
||||||
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
detection_heatmap, class_heatmap, size_heatmap = generate_heatmaps(
|
||||||
selected_events,
|
selected_events,
|
||||||
spectrogram,
|
spectrogram,
|
||||||
class_mapper,
|
class_mapper,
|
||||||
target_sigma=config.masks.sigma,
|
target_sigma=config.heatmaps.sigma,
|
||||||
|
position=config.heatmaps.position,
|
||||||
|
time_scale=config.heatmaps.time_scale,
|
||||||
|
frequency_scale=config.heatmaps.frequency_scale,
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = xr.Dataset(
|
dataset = xr.Dataset(
|
||||||
{
|
{
|
||||||
|
# NOTE: Need to rename the time dimension to avoid conflicts with
|
||||||
|
# the spectrogram time dimension, otherwise xarray will interpolate
|
||||||
|
# the spectrogram and the heatmaps to the same temporal resolution
|
||||||
|
# as the waveform.
|
||||||
|
"audio": wave.rename({"time": "audio_time"}),
|
||||||
"spectrogram": spectrogram,
|
"spectrogram": spectrogram,
|
||||||
"detection": detection_heatmap,
|
"detection": detection_heatmap,
|
||||||
"class": class_heatmap,
|
"class": class_heatmap,
|
||||||
|
@ -13,9 +13,9 @@ class TargetConfig(BaseConfig):
|
|||||||
"""Configuration for target generation."""
|
"""Configuration for target generation."""
|
||||||
|
|
||||||
classes: List[TagInfo] = Field(default_factory=list)
|
classes: List[TagInfo] = Field(default_factory=list)
|
||||||
|
generic_class: Optional[TagInfo] = None
|
||||||
|
|
||||||
include: Optional[List[TagInfo]] = None
|
include: Optional[List[TagInfo]] = None
|
||||||
|
|
||||||
exclude: Optional[List[TagInfo]] = None
|
exclude: Optional[List[TagInfo]] = None
|
||||||
|
|
||||||
|
|
||||||
@ -73,7 +73,7 @@ class GenericMapper(ClassMapper):
|
|||||||
raise ValueError("Number of targets and class labels must match.")
|
raise ValueError("Number of targets and class labels must match.")
|
||||||
|
|
||||||
self.targets = set(classes)
|
self.targets = set(classes)
|
||||||
self.class_labels = labels
|
self.class_labels = list(dict.fromkeys(labels))
|
||||||
|
|
||||||
self._mapping = {tag: label for tag, label in zip(classes, labels)}
|
self._mapping = {tag: label for tag, label in zip(classes, labels)}
|
||||||
self._inverse_mapping = {
|
self._inverse_mapping = {
|
||||||
|
@ -17,7 +17,7 @@ dependencies = [
|
|||||||
"torch>=1.13.1,<2.5.0",
|
"torch>=1.13.1,<2.5.0",
|
||||||
"torchaudio>=1.13.1,<2.5.0",
|
"torchaudio>=1.13.1,<2.5.0",
|
||||||
"torchvision>=0.14.0",
|
"torchvision>=0.14.0",
|
||||||
"soundevent[audio,geometry,plot]>=2.2",
|
"soundevent[audio,geometry,plot]>=2.3",
|
||||||
"click>=8.1.7",
|
"click>=8.1.7",
|
||||||
"netcdf4>=1.6.5",
|
"netcdf4>=1.6.5",
|
||||||
"tqdm>=4.66.2",
|
"tqdm>=4.66.2",
|
||||||
|
@ -1,7 +1,11 @@
|
|||||||
|
import uuid
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import Callable, List, Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
import soundfile as sf
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@ -19,6 +23,13 @@ def example_audio_dir(example_data_dir: Path) -> Path:
|
|||||||
return example_audio_dir
|
return example_audio_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def example_anns_dir(example_data_dir: Path) -> Path:
    """Return the ``anns`` directory inside the example data directory.

    The fixture asserts that the directory exists before returning it.
    """
    anns_dir = example_data_dir / "anns"
    assert anns_dir.exists()
    return anns_dir
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def example_audio_files(example_audio_dir: Path) -> List[Path]:
|
def example_audio_files(example_audio_dir: Path) -> List[Path]:
|
||||||
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
|
audio_files = list(example_audio_dir.glob("*.[wW][aA][vV]"))
|
||||||
@ -38,3 +49,61 @@ def contrib_dir(data_dir) -> Path:
|
|||||||
dir = data_dir / "contrib"
|
dir = data_dir / "contrib"
|
||||||
assert dir.exists()
|
assert dir.exists()
|
||||||
return dir
|
return dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def wav_factory(tmp_path: Path):
    """Provide a factory that writes WAV files filled with random samples.

    The returned callable writes a PCM file with ``soundfile`` and returns
    its path. When no path is given, a uniquely named file is created
    inside pytest's ``tmp_path``.
    """
    # Supported PCM bit depths and the integer type used for each.
    sample_dtypes = {16: np.int16, 32: np.int32}

    def _wav_factory(
        path: Optional[Path] = None,
        duration: float = 0.3,
        channels: int = 1,
        # NOTE(review): 441_000 differs from the 44100 default used by
        # ``recording_factory`` -- confirm this is intentional.
        samplerate: int = 441_000,
        bit_depth: int = 16,
    ) -> Path:
        target = path if path else tmp_path / f"{uuid.uuid4()}.wav"
        num_frames = int(samplerate * duration)

        dtype = sample_dtypes.get(bit_depth)
        if dtype is None:
            raise ValueError(f"Unsupported bit depth: {bit_depth}")

        # Draw samples uniformly over the full range of the integer type.
        limits = np.iinfo(dtype)
        samples = np.random.uniform(
            low=limits.min,
            high=limits.max,
            size=(num_frames, channels),
        ).astype(dtype)

        sf.write(
            str(target),
            samples,
            samplerate,
            subtype=f"PCM_{bit_depth}",
        )
        return target

    return _wav_factory
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def recording_factory(wav_factory: Callable[..., Path]):
    """Provide a factory that builds ``data.Recording`` objects from files.

    When no path is supplied, a new WAV file is first created through the
    ``wav_factory`` fixture with the requested duration, channel count and
    sample rate.
    """

    def _recording_factory(
        tags: Optional[list[data.Tag]] = None,
        path: Optional[Path] = None,
        recording_id: Optional[uuid.UUID] = None,
        duration: float = 1,
        channels: int = 1,
        samplerate: int = 44100,
        time_expansion: float = 1,
    ) -> data.Recording:
        if not path:
            path = wav_factory(
                duration=duration,
                channels=channels,
                samplerate=samplerate,
            )

        # Fall back to a fresh identifier and an empty tag list.
        identifier = recording_id or uuid.uuid4()
        return data.Recording.from_file(
            path=path,
            uuid=identifier,
            time_expansion=time_expansion,
            tags=tags or [],
        )

    return _recording_factory
|
||||||
|
BIN
tests/data/regression/20170701_213954-MYOMYS-LR_0_0.5.wav.npz
Normal file
BIN
tests/data/regression/20170701_213954-MYOMYS-LR_0_0.5.wav.npz
Normal file
Binary file not shown.
BIN
tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz
Normal file
BIN
tests/data/regression/20180530_213516-EPTSER-LR_0_0.5.wav.npz
Normal file
Binary file not shown.
BIN
tests/data/regression/20180627_215323-RHIFER-LR_0_0.5.wav.npz
Normal file
BIN
tests/data/regression/20180627_215323-RHIFER-LR_0_0.5.wav.npz
Normal file
Binary file not shown.
@ -4,7 +4,7 @@ import numpy as np
|
|||||||
import pytest
|
import pytest
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.data import preprocessing
|
from batdetect2 import preprocess
|
||||||
from batdetect2.utils import audio_utils
|
from batdetect2.utils import audio_utils
|
||||||
|
|
||||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||||
@ -44,10 +44,10 @@ def test_audio_loading_hasnt_changed(
|
|||||||
target_samp_rate=target_sampling_rate,
|
target_samp_rate=target_sampling_rate,
|
||||||
scale=scale,
|
scale=scale,
|
||||||
)
|
)
|
||||||
audio_new = preprocessing.load_clip_audio(
|
audio_new = preprocess.load_clip_audio(
|
||||||
clip,
|
clip,
|
||||||
config=preprocessing.AudioConfig(
|
config=preprocess.AudioConfig(
|
||||||
resample=preprocessing.ResampleConfig(
|
resample=preprocess.ResampleConfig(
|
||||||
samplerate=target_sampling_rate,
|
samplerate=target_sampling_rate,
|
||||||
),
|
),
|
||||||
center=scale,
|
center=scale,
|
||||||
@ -84,20 +84,20 @@ def test_spectrogram_generation_hasnt_changed(
|
|||||||
if spec_scale == "log":
|
if spec_scale == "log":
|
||||||
scale = "log"
|
scale = "log"
|
||||||
elif spec_scale == "pcen":
|
elif spec_scale == "pcen":
|
||||||
scale = preprocessing.PcenConfig()
|
scale = preprocess.PcenConfig()
|
||||||
|
|
||||||
config = preprocessing.SpectrogramConfig(
|
config = preprocess.SpectrogramConfig(
|
||||||
fft=preprocessing.FFTConfig(
|
fft=preprocess.FFTConfig(
|
||||||
window_overlap=fft_overlap,
|
window_overlap=fft_overlap,
|
||||||
window_duration=fft_win_length,
|
window_duration=fft_win_length,
|
||||||
),
|
),
|
||||||
frequencies=preprocessing.FrequencyConfig(
|
frequencies=preprocess.FrequencyConfig(
|
||||||
min_freq=min_freq,
|
min_freq=min_freq,
|
||||||
max_freq=max_freq,
|
max_freq=max_freq,
|
||||||
),
|
),
|
||||||
scale=scale,
|
scale=scale,
|
||||||
denoise=denoise_spec_avg,
|
denoise=denoise_spec_avg,
|
||||||
resize=None,
|
size=None,
|
||||||
max_scale=max_scale_spec,
|
max_scale=max_scale_spec,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -112,10 +112,10 @@ def test_spectrogram_generation_hasnt_changed(
|
|||||||
end_time=recording.duration,
|
end_time=recording.duration,
|
||||||
)
|
)
|
||||||
|
|
||||||
audio = preprocessing.load_clip_audio(
|
audio = preprocess.load_clip_audio(
|
||||||
clip,
|
clip,
|
||||||
config=preprocessing.AudioConfig(
|
config=preprocess.AudioConfig(
|
||||||
resample=preprocessing.ResampleConfig(
|
resample=preprocess.ResampleConfig(
|
||||||
samplerate=target_sampling_rate,
|
samplerate=target_sampling_rate,
|
||||||
)
|
)
|
||||||
),
|
),
|
||||||
@ -135,7 +135,7 @@ def test_spectrogram_generation_hasnt_changed(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
new_spec = preprocessing.compute_spectrogram(
|
new_spec = preprocess.compute_spectrogram(
|
||||||
audio,
|
audio,
|
||||||
config=config,
|
config=config,
|
||||||
dtype=np.float32,
|
dtype=np.float32,
|
||||||
|
75
tests/test_migration/test_training.py
Normal file
75
tests/test_migration/test_training.py
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from batdetect2.compat.data import load_annotation_project
|
||||||
|
from batdetect2.compat.params import get_training_preprocessing_config
|
||||||
|
from batdetect2.train.preprocess import generate_train_example
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def regression_dir(data_dir: Path) -> Path:
    """Return the ``regression`` folder located inside ``data_dir``.

    The fixture fails with an ``AssertionError`` when the folder does not
    exist, so dependent tests stop early instead of hitting missing files.
    """
    regression_path = data_dir / "regression"
    assert regression_path.exists()
    return regression_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_generate_similar_training_inputs(
    example_audio_dir: Path,
    example_audio_files: List[Path],
    example_anns_dir: Path,
    regression_dir: Path,
):
    """Compare freshly generated train examples with stored regression data.

    For every example audio file, the arrays stored in
    ``<regression_dir>/<audio file name>.npz`` are compared against the
    output of ``generate_train_example`` configured from the legacy
    ``params.json``. Shapes must agree, and the non-zero entries of the
    size mask must agree up to the offsets and rounding asserted below.
    """
    old_parameters = json.loads((regression_dir / "params.json").read_text())
    config = get_training_preprocessing_config(old_parameters)

    # The annotation project is built from arguments that do not change
    # between iterations, so load it once instead of once per audio file.
    project = load_annotation_project(
        example_anns_dir,
        audio_dir=example_audio_dir,
    )

    for audio_file in example_audio_files:
        example_file = regression_dir / f"{audio_file.name}.npz"

        # ``np.load`` on an ``.npz`` archive returns a file-backed object;
        # use it as a context manager so the handle is closed once the
        # arrays have been read into memory.
        with np.load(example_file) as dataset:
            spec = dataset["spec"][0]
            detection_mask = dataset["detection_mask"][0]
            size_mask = dataset["size_mask"]
            class_mask = dataset["class_mask"]

        clip_annotation = next(
            ann
            for ann in project.clip_annotations
            if ann.clip.recording.path == audio_file
        )

        new_dataset = generate_train_example(clip_annotation, config)
        new_spec = new_dataset["spectrogram"].values
        new_detection_mask = new_dataset["detection"].values
        new_size_mask = new_dataset["size"].values
        new_class_mask = new_dataset["class"].values

        assert spec.shape == new_spec.shape
        assert detection_mask.shape == new_detection_mask.shape
        assert size_mask.shape == new_size_mask.shape
        assert class_mask.shape[1:] == new_class_mask.shape[1:]
        # The stored class mask has exactly one more entry along its first
        # axis than the newly generated one.
        assert class_mask.shape[0] == new_class_mask.shape[0] + 1

        # The stored size mask is flipped along its first spatial axis
        # before its non-zero positions are compared with the new ones.
        x_new, y_new = np.nonzero(new_size_mask.max(axis=0))
        x_orig, y_orig = np.nonzero(np.flipud(size_mask.max(axis=0)))

        assert (x_new == x_orig).all()

        # NOTE: a difference of 1 pixel is due to discrepancies on how
        # frequency bins are interpreted. Shouldn't be an issue
        assert (y_new == y_orig + 1).all()

        width_new, height_new = new_size_mask[:, x_new, y_new]
        width_orig, height_orig = np.flip(size_mask, axis=1)[:, x_orig, y_orig]

        # New sizes are only expected to match the stored ones after
        # rounding (floor for the first channel, ceil for the second).
        assert (np.floor(width_new) == width_orig).all()
        assert (np.ceil(height_new) == height_orig).all()
|
0
tests/test_preprocessing/__init__.py
Normal file
0
tests/test_preprocessing/__init__.py
Normal file
0
tests/test_preprocessing/test_audio.py
Normal file
0
tests/test_preprocessing/test_audio.py
Normal file
141
tests/test_preprocessing/test_spectrogram.py
Normal file
141
tests/test_preprocessing/test_spectrogram.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
import math
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
|
from hypothesis import HealthCheck, given, settings
|
||||||
|
from hypothesis import strategies as st
|
||||||
|
from soundevent import arrays
|
||||||
|
|
||||||
|
from batdetect2.preprocess.audio import AudioConfig, load_file_audio
|
||||||
|
from batdetect2.preprocess.spectrogram import (
|
||||||
|
FFTConfig,
|
||||||
|
FrequencyConfig,
|
||||||
|
SpecSizeConfig,
|
||||||
|
SpectrogramConfig,
|
||||||
|
compute_spectrogram,
|
||||||
|
duration_to_spec_width,
|
||||||
|
get_spectrogram_resolution,
|
||||||
|
pad_audio,
|
||||||
|
spec_width_to_samples,
|
||||||
|
stft,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@settings(suppress_health_check=[HealthCheck.function_scoped_fixture])
@given(
    duration=st.floats(min_value=0.1, max_value=1.0),
    window_duration=st.floats(min_value=0.001, max_value=0.01),
    window_overlap=st.floats(min_value=0.2, max_value=0.9),
    samplerate=st.integers(min_value=256_000, max_value=512_000),
)
def test_can_estimate_correctly_spectrogram_width_from_duration(
    duration: float,
    window_duration: float,
    window_overlap: float,
    samplerate: int,
    wav_factory: Callable[..., Path],
):
    """Check the duration <-> spectrogram width helpers against a real STFT.

    The width predicted by ``duration_to_spec_width`` must equal the number
    of time frames produced by ``stft``, and converting that width back to
    a duration must land within ``(1 - window_overlap) * window_duration``
    of the requested duration.
    """
    stft_settings = {
        "samplerate": samplerate,
        "window_duration": window_duration,
        "window_overlap": window_overlap,
    }

    wav_path = wav_factory(duration=duration, samplerate=samplerate)

    # NOTE: Don't resample nor adjust duration, so the width estimation is
    # exercised on the audio exactly as generated.
    untouched = AudioConfig(resample=None, duration=None)
    waveform = load_file_audio(wav_path, config=untouched)

    computed = stft(waveform, window_duration, window_overlap)
    expected_width = duration_to_spec_width(duration, **stft_settings)
    assert computed.sizes["time"] == expected_width

    rebuilt_samples = spec_width_to_samples(expected_width, **stft_settings)
    rebuilt_duration = rebuilt_samples / samplerate

    tolerance = (1 - window_overlap) * window_duration
    assert abs(duration - rebuilt_duration) < tolerance
|
||||||
|
|
||||||
|
|
||||||
|
@settings(
    suppress_health_check=[HealthCheck.function_scoped_fixture],
    deadline=400,
)
@given(
    duration=st.floats(min_value=0.1, max_value=1.0),
    window_duration=st.floats(min_value=0.001, max_value=0.01),
    window_overlap=st.floats(min_value=0.2, max_value=0.9),
    samplerate=st.integers(min_value=256_000, max_value=512_000),
    divide_factor=st.integers(min_value=16, max_value=64),
)
def test_can_pad_audio_to_adjust_spectrogram_width(
    duration: float,
    window_duration: float,
    window_overlap: float,
    samplerate: int,
    divide_factor: int,
    wav_factory: Callable[..., Path],
):
    """Check that padded audio yields a width divisible by ``divide_factor``.

    After ``pad_audio`` is applied, the STFT of the result must have a
    number of time frames that is an exact multiple of ``divide_factor``.
    """
    wav_path = wav_factory(duration=duration, samplerate=samplerate)

    # NOTE: Don't resample nor adjust duration, so the padding is exercised
    # on the audio exactly as generated.
    untouched = AudioConfig(resample=None, duration=None)
    waveform = load_file_audio(wav_path, config=untouched)

    padded = pad_audio(
        waveform,
        window_duration=window_duration,
        window_overlap=window_overlap,
        divide_factor=divide_factor,
    )

    n_frames = stft(padded, window_duration, window_overlap).sizes["time"]
    assert n_frames % divide_factor == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_can_estimate_spectrogram_resolution(
    wav_factory: Callable[..., Path],
):
    """Check the estimated resolution against a computed spectrogram.

    The frequency and time steps measured on the output of
    ``compute_spectrogram`` must match the values returned by
    ``get_spectrogram_resolution`` for the same config, within a relative
    tolerance of 10%.
    """
    wav_path = wav_factory(duration=0.2, samplerate=256_000)

    # NOTE: Don't resample nor adjust duration, so the spectrogram is
    # computed from the audio exactly as generated.
    untouched = AudioConfig(resample=None, duration=None)
    waveform = load_file_audio(wav_path, config=untouched)

    config = SpectrogramConfig(
        fft=FFTConfig(),
        size=SpecSizeConfig(height=256, resize_factor=0.5),
        frequencies=FrequencyConfig(min_freq=10_000, max_freq=120_000),
    )

    spec = compute_spectrogram(waveform, config=config)
    freq_res, time_res = get_spectrogram_resolution(config)

    # Frequency is checked first, then time.
    for dim_name, expected_step in (
        ("frequency", freq_res),
        ("time", time_res),
    ):
        measured_step = arrays.get_dim_step(spec, dim=dim_name)
        assert math.isclose(measured_step, expected_step, rel_tol=0.1)
|
8
uv.lock
generated
8
uv.lock
generated
@ -236,7 +236,7 @@ requires-dist = [
|
|||||||
{ name = "pytorch-lightning", specifier = ">=2.2.2" },
|
{ name = "pytorch-lightning", specifier = ">=2.2.2" },
|
||||||
{ name = "scikit-learn", specifier = ">=1.2.2" },
|
{ name = "scikit-learn", specifier = ">=1.2.2" },
|
||||||
{ name = "scipy", specifier = ">=1.10.1" },
|
{ name = "scipy", specifier = ">=1.10.1" },
|
||||||
{ name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.2" },
|
{ name = "soundevent", extras = ["audio", "geometry", "plot"], specifier = ">=2.3" },
|
||||||
{ name = "tensorboard", specifier = ">=2.16.2" },
|
{ name = "tensorboard", specifier = ">=2.16.2" },
|
||||||
{ name = "torch", specifier = ">=1.13.1,<2.5.0" },
|
{ name = "torch", specifier = ">=1.13.1,<2.5.0" },
|
||||||
{ name = "torchaudio", specifier = ">=1.13.1,<2.5.0" },
|
{ name = "torchaudio", specifier = ">=1.13.1,<2.5.0" },
|
||||||
@ -2679,15 +2679,15 @@ wheels = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "soundevent"
|
name = "soundevent"
|
||||||
version = "2.2.0"
|
version = "2.3.0"
|
||||||
source = { registry = "https://pypi.org/simple" }
|
source = { registry = "https://pypi.org/simple" }
|
||||||
dependencies = [
|
dependencies = [
|
||||||
{ name = "email-validator" },
|
{ name = "email-validator" },
|
||||||
{ name = "pydantic" },
|
{ name = "pydantic" },
|
||||||
]
|
]
|
||||||
sdist = { url = "https://files.pythonhosted.org/packages/5a/f9/77d723df4d8d3a390d32c07325a08bfc669d5fb55e88b98181b5793e7333/soundevent-2.2.0.tar.gz", hash = "sha256:a87a97c8e4bfdadec4b6edc128470919ef9240a344cf924d2ac21f8c6b50acf1", size = 8715229 }
|
sdist = { url = "https://files.pythonhosted.org/packages/ff/51/83093cabe9ada781a0f7a78f82cc04162d005755b2f0ca3fdcb4ecd47a01/soundevent-2.3.0.tar.gz", hash = "sha256:b75d7674578a52bf196619f8b4b3d9170f2ca321d165ceb45916579a549c3e76", size = 8716539 }
|
||||||
wheels = [
|
wheels = [
|
||||||
{ url = "https://files.pythonhosted.org/packages/09/d8/ce6e2830d47fc3d24db264b94384fa7cdcb8069fd7547d55b7a83857e730/soundevent-2.2.0-py3-none-any.whl", hash = "sha256:c40913c15fc697a82a02df5f62d18ad3b77bfae80a9a5d54c47bc1377d3b4d7c", size = 144188 },
|
{ url = "https://files.pythonhosted.org/packages/ee/67/4c2d881f9b4a0b453dbee91e119e1e48df0fc92de2cc3062fcd8ad0a7e6b/soundevent-2.3.0-py3-none-any.whl", hash = "sha256:f7c74b1d73a347ebe843187c93130dc8af3214add95e6bc485f64944bea0d690", size = 145513 },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
|
Loading…
Reference in New Issue
Block a user