Fix testing issues

This commit is contained in:
mbsantiago 2025-06-20 15:57:11 +01:00
parent 84a13c65a7
commit 3c9e5aca2f
9 changed files with 114 additions and 112 deletions

View File

@ -103,7 +103,7 @@ def convert_xr_dataset_to_raw_prediction(
detections.append(
RawPrediction(
detection_score=det_info.scores,
detection_score=det_info.score,
geometry=geom,
class_scores=det_info.classes,
features=det_info.features,

View File

@ -370,6 +370,7 @@ def load_clip_audio(
"""
config = config or AudioConfig()
with xr.set_options(keep_attrs=True):
try:
wav = (
audio.load_clip(clip, audio_dir=audio_dir)
@ -521,7 +522,7 @@ def resample_audio(
original_samplerate = int(1 / step)
if original_samplerate == samplerate:
return wav.astype(dtype)
return wav.astype(dtype).assign_attrs(original_samplerate=samplerate)
if method == "poly":
resampled = resample_audio_poly(
@ -561,7 +562,11 @@ def resample_audio(
samplerate=samplerate,
),
},
attrs={**wav.attrs, "samplerate": samplerate},
attrs={
**wav.attrs,
"samplerate": samplerate,
"original_samplerate": original_samplerate,
},
)

View File

@ -363,7 +363,7 @@ def compute_spectrogram(
)
if config.peak_normalize:
spec = ops.scale(spec, 1 / (10e-6 + np.max(spec)))
spec = ops.normalize(spec)
return spec.astype(dtype)
@ -436,6 +436,9 @@ def stft(
ValueError
If sample rate cannot be determined from `wave` coordinates.
"""
if "channel" not in wave.coords:
wave = wave.assign_coords(channel=0)
return audio.compute_spectrogram(
wave,
window_size=window_duration,
@ -544,7 +547,7 @@ def apply_pcen(
verified against the specific `soundevent.audio.pcen` implementation
details.
"""
samplerate = spec.attrs["samplerate"]
samplerate = 1 / spec.time.attrs["step"]
hop_size = spec.attrs["hop_size"]
hop_length = int(hop_size * samplerate)
@ -622,6 +625,7 @@ def resize_spectrogram(
spec: xr.DataArray,
height: int = 128,
resize_factor: Optional[float] = 0.5,
dtype: DTypeLike = np.float32, # type: ignore
) -> xr.DataArray:
"""Resize a spectrogram to target dimensions using interpolation.
@ -647,11 +651,26 @@ def resize_spectrogram(
"""
resize_factor = resize_factor or 1
current_width = spec.sizes["time"]
return ops.resize(
spec,
time=int(resize_factor * current_width),
frequency=height,
dtype=np.float32,
target_sizes = {
"time": int(current_width * resize_factor),
"frequency": height,
}
new_coords = {}
for dim in ["time", "frequency"]:
step = arrays.get_dim_step(spec, dim)
start, stop = arrays.get_dim_range(spec, dim)
new_coords[dim] = arrays.create_range_dim(
name=dim,
start=start,
stop=stop + step,
size=target_sizes[dim],
dtype=dtype,
)
return spec.interp(
coords=new_coords, method="linear", kwargs=dict(fill_value=0)
)

View File

@ -93,13 +93,21 @@ class ROITargetMapper(Protocol):
dimension_names: List[str]
def get_roi_position(self, geom: data.Geometry) -> tuple[float, float]:
def get_roi_position(
self,
geom: data.Geometry,
position: Optional[Positions] = None,
) -> tuple[float, float]:
"""Extract the reference position from a geometry.
Parameters
----------
geom : soundevent.data.Geometry
The input geometry (e.g., BoundingBox, Polygon).
position : Positions, optional
Overrides the default `position` configured for the mapper.
If provided, this position will be used instead of the mapper's
internal default.
Returns
-------
@ -141,7 +149,10 @@ class ROITargetMapper(Protocol):
...
def recover_roi(
self, pos: tuple[float, float], dims: np.ndarray
self,
pos: tuple[float, float],
dims: np.ndarray,
position: Optional[Positions] = None,
) -> data.Geometry:
"""Recover an approximate ROI from a position and target dimensions.
@ -153,8 +164,12 @@ class ROITargetMapper(Protocol):
pos : Tuple[float, float]
The reference position (time, frequency).
dims : np.ndarray
The NumPy array containing the dimensions, matching the order
NumPy array containing the dimensions, matching the order
specified by `dimension_names`.
position : Positions, optional
Overrides the default `position` configured for the mapper.
If provided, this position will be used instead of the mapper's
internal default when reconstructing the roi geometry.
Returns
-------
@ -240,7 +255,11 @@ class BBoxEncoder(ROITargetMapper):
self.time_scale = time_scale
self.frequency_scale = frequency_scale
def get_roi_position(self, geom: data.Geometry) -> Tuple[float, float]:
def get_roi_position(
self,
geom: data.Geometry,
position: Optional[Positions] = None,
) -> Tuple[float, float]:
"""Extract the configured reference position from the geometry.
Uses `soundevent.geometry.get_geometry_point`.
@ -249,6 +268,9 @@ class BBoxEncoder(ROITargetMapper):
----------
geom : soundevent.data.Geometry
Input geometry (e.g., BoundingBox).
position : Positions, optional
Overrides the default `position` configured for the encoder.
If provided, this position will be used instead of `self.position`.
Returns
-------
@ -257,7 +279,8 @@ class BBoxEncoder(ROITargetMapper):
"""
from soundevent import geometry
return geometry.get_geometry_point(geom, position=self.position)
position = position or self.position
return geometry.get_geometry_point(geom, position=position)
def get_roi_size(self, geom: data.Geometry) -> np.ndarray:
"""Calculate the scaled [width, height] from the geometry's bounds.
@ -291,6 +314,7 @@ class BBoxEncoder(ROITargetMapper):
self,
pos: tuple[float, float],
dims: np.ndarray,
position: Optional[Positions] = None,
) -> data.Geometry:
"""Recover a BoundingBox from a position and scaled dimensions.
@ -305,6 +329,10 @@ class BBoxEncoder(ROITargetMapper):
dims : np.ndarray
NumPy array containing the *scaled* dimensions, expected order is
[scaled_width, scaled_height].
position : Positions, optional
Overrides the default `position` configured for the encoder.
If provided, this position will be used instead of `self.position`
when reconstructing the bounding box.
Returns
-------
@ -316,6 +344,8 @@ class BBoxEncoder(ROITargetMapper):
ValueError
If `dims` does not have the expected shape (length 2).
"""
position = position or self.position
if dims.ndim != 1 or dims.shape[0] != 2:
raise ValueError(
"Dimension array does not have the expected shape. "

View File

@ -40,9 +40,7 @@ class TrainingModule(L.LightningModule):
self.learning_rate = learning_rate
self.t_max = t_max
# NOTE: Ignore detector and loss from hyperparameter saving
# as they are nn.Module and should be saved regardless.
self.save_hyperparameters(ignore=["detector", "loss"])
self.save_hyperparameters()
def forward(self, spec: torch.Tensor) -> ModelOutput:
return self.detector(spec)

View File

@ -24,7 +24,6 @@ from batdetect2.preprocess.spectrogram import (
get_spectrogram_resolution,
remove_spectral_mean,
resize_spectrogram,
scale_log,
scale_spectrogram,
stft,
)
@ -153,13 +152,13 @@ def test_stft_output_properties(sine_wave_xr: xr.DataArray):
assert np.isclose(time_step, hop_len / samplerate)
assert spec.frequency.min() >= 0
assert freq_start == 0
assert np.isclose(freq_end + freq_step, samplerate / 2, atol=5)
assert spec.time.min() >= 0
assert np.isclose(freq_end, samplerate / 2, atol=freq_step / 2)
assert np.isclose(spec.time.min(), 0)
assert spec.time.max() < DURATION
assert spec.attrs["original_samplerate"] == samplerate
assert spec.attrs["nfft"] == nfft
assert spec.attrs["noverlap"] == int(window_overlap * nfft)
assert spec.attrs["samplerate"] == samplerate
assert spec.attrs["window_size"] == window_duration
assert spec.attrs["hop_size"] == window_duration * (1 - window_overlap)
assert np.all(spec.data >= 0)
@ -192,7 +191,7 @@ def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray):
def test_crop_spectrogram_full_range(sample_spec: xr.DataArray):
samplerate = sample_spec.attrs["original_samplerate"]
samplerate = sample_spec.attrs["samplerate"]
min_f, max_f = 0, samplerate / 2
cropped_spec = crop_spectrogram_frequencies(
sample_spec, min_freq=min_f, max_freq=max_f
@ -227,33 +226,6 @@ def test_apply_pcen(sample_spec: xr.DataArray):
assert not np.allclose(pcen_spec.data, sample_spec.data)
def test_scale_log(sample_spec: xr.DataArray):
if "original_samplerate" not in sample_spec.attrs:
sample_spec.attrs["original_samplerate"] = SAMPLERATE
if "nfft" not in sample_spec.attrs:
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
log_spec = scale_log(sample_spec, dtype=np.float32)
assert log_spec.dims == sample_spec.dims
assert log_spec.sizes == sample_spec.sizes
assert log_spec.dtype == np.float32
assert np.all(log_spec.data >= 0)
assert not np.allclose(log_spec.data, sample_spec.data)
def test_scale_log_missing_attrs(sample_spec: xr.DataArray):
spec_copy = sample_spec.copy()
del spec_copy.attrs["original_samplerate"]
with pytest.raises(KeyError):
scale_log(spec_copy)
spec_copy = sample_spec.copy()
del spec_copy.attrs["nfft"]
with pytest.raises(KeyError):
scale_log(spec_copy)
def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray):
scaled_spec = scale_spectrogram(sample_spec, scale="amplitude")
assert np.allclose(scaled_spec.data, sample_spec.data)
@ -267,15 +239,9 @@ def test_scale_spectrogram_power(sample_spec: xr.DataArray):
def test_scale_spectrogram_db(sample_spec: xr.DataArray):
if "original_samplerate" not in sample_spec.attrs:
sample_spec.attrs["original_samplerate"] = SAMPLERATE
if "nfft" not in sample_spec.attrs:
sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE)
scaled_spec = scale_spectrogram(sample_spec, scale="dB", dtype=np.float64)
log_spec_expected = scale_log(sample_spec, dtype=np.float64)
assert scaled_spec.dtype == np.float64
assert np.allclose(scaled_spec.data, log_spec_expected.data)
scaled_spec = scale_spectrogram(sample_spec, scale="dB")
log_spec_expected = arrays.to_db(sample_spec)
xr.testing.assert_allclose(scaled_spec, log_spec_expected)
def test_remove_spectral_mean(sample_spec: xr.DataArray):
@ -291,8 +257,7 @@ def test_remove_spectral_mean(sample_spec: xr.DataArray):
def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray):
const_spec = stft(constant_wave_xr, 0.002, 0.5)
denoised_spec = remove_spectral_mean(const_spec)
assert np.allclose(denoised_spec.data, 0, atol=1e-6)
assert np.all(denoised_spec.data >= 0)
@pytest.mark.parametrize(
@ -324,8 +289,6 @@ def test_resize_spectrogram(
assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1
assert resized_spec.dtype == np.float32
def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray):
config = SpectrogramConfig()
@ -377,7 +340,7 @@ def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize(
def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray):
config = SpectrogramConfig(peak_normalize=True)
config = SpectrogramConfig(peak_normalize=True, pcen=None)
spec = compute_spectrogram(sine_wave_xr, config=config)
assert np.isclose(spec.data.max(), 1.0, atol=1e-6)
@ -443,20 +406,6 @@ def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray):
assert spec_builder.dtype == spec_direct.dtype
def test_configurable_spectrogram_builder_call_np(sine_wave_xr: xr.DataArray):
config = SpectrogramConfig()
builder = ConfigurableSpectrogramBuilder(config=config)
wav_np = sine_wave_xr.data
samplerate = sine_wave_xr.attrs["samplerate"]
spec_builder = builder(wav_np.astype(np.float32), samplerate=samplerate)
spec_direct = compute_spectrogram(sine_wave_xr, config=config)
assert isinstance(spec_builder, xr.DataArray)
assert np.allclose(spec_builder.data, spec_direct.data, atol=1e-4)
assert spec_builder.dtype == spec_direct.dtype
def test_configurable_spectrogram_builder_call_np_no_samplerate(
sine_wave_xr: xr.DataArray,
):

View File

@ -72,7 +72,7 @@ def test_term_registry_get_keys():
def test_get_term_from_key():
term = terms.get_term_from_key("call_type")
term = terms.get_term_from_key("event")
assert term == terms.call_type
custom_registry = TermRegistry()
@ -84,7 +84,7 @@ def test_get_term_from_key():
def test_get_term_keys():
keys = terms.get_term_keys()
assert "call_type" in keys
assert "event" in keys
assert "individual" in keys
assert terms.GENERIC_CLASS_KEY in keys
@ -96,7 +96,7 @@ def test_get_term_keys():
def test_tag_info_and_get_tag_from_info():
tag_info = TagInfo(value="Myotis myotis", key="call_type")
tag_info = TagInfo(value="Myotis myotis", key="event")
tag = terms.get_tag_from_info(tag_info)
assert tag.value == "Myotis myotis"
assert tag.term == terms.call_type
@ -161,7 +161,7 @@ def test_load_terms_from_config_key_already_exists(tmp_path):
config_data = {
"terms": [
{
"key": "call_type",
"key": "event",
"uri": "dwc:scientificName",
"label": "Scientific Name",
}, # Duplicate key

View File

@ -116,14 +116,15 @@ def test_selected_random_subclip_has_the_correct_width(
recording1 = create_recording()
clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7)
clip_annotation_1 = data.ClipAnnotation(clip=clip1)
original = generate_train_example(
clip_annotation_1,
preprocessor=sample_preprocessor,
labeller=sample_labeller,
)
subclip = select_subclip(original, start=0, span=100)
assert subclip["spectrogram"].shape[1] == 100
subclip = select_subclip(original, start=0, span=0.513)
assert subclip["spectrogram"].shape[1] == 512
def test_add_echo_after_subclip(
@ -142,7 +143,7 @@ def test_add_echo_after_subclip(
assert original.sizes["time"] > 512
subclip = select_subclip(original, start=0, span=512)
subclip = select_subclip(original, start=0, span=0.513)
with_echo = add_echo(subclip, preprocessor=sample_preprocessor)
assert with_echo.sizes["time"] == 512

View File

@ -53,4 +53,4 @@ def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip):
output1 = module(input1)
output2 = recovered(input2)
assert output1 == output2
torch.testing.assert_close(output1, output2)