diff --git a/batdetect2/postprocess/decoding.py b/batdetect2/postprocess/decoding.py index 2ba6e45..3f4611c 100644 --- a/batdetect2/postprocess/decoding.py +++ b/batdetect2/postprocess/decoding.py @@ -103,7 +103,7 @@ def convert_xr_dataset_to_raw_prediction( detections.append( RawPrediction( - detection_score=det_info.scores, + detection_score=det_info.score, geometry=geom, class_scores=det_info.classes, features=det_info.features, diff --git a/batdetect2/preprocess/audio.py b/batdetect2/preprocess/audio.py index c474645..d07afbe 100644 --- a/batdetect2/preprocess/audio.py +++ b/batdetect2/preprocess/audio.py @@ -370,35 +370,36 @@ def load_clip_audio( """ config = config or AudioConfig() - try: - wav = ( - audio.load_clip(clip, audio_dir=audio_dir) - .sel(channel=0) - .astype(dtype) - ) - except LibsndfileError as e: - raise FileNotFoundError( - f"Could not load the recording at path: {clip.recording.path}. " - f"Error: {e}" - ) from e + with xr.set_options(keep_attrs=True): + try: + wav = ( + audio.load_clip(clip, audio_dir=audio_dir) + .sel(channel=0) + .astype(dtype) + ) + except LibsndfileError as e: + raise FileNotFoundError( + f"Could not load the recording at path: {clip.recording.path}. " + f"Error: {e}" + ) from e - if config.resample: - wav = resample_audio( - wav, - samplerate=config.resample.samplerate, - dtype=dtype, - ) + if config.resample: + wav = resample_audio( + wav, + samplerate=config.resample.samplerate, + dtype=dtype, + ) - if config.center: - wav = ops.center(wav) + if config.center: + wav = ops.center(wav) - if config.scale: - wav = scale_audio(wav) + if config.scale: + wav = scale_audio(wav) - if config.duration is not None: - wav = adjust_audio_duration(wav, duration=config.duration) + if config.duration is not None: + wav = adjust_audio_duration(wav, duration=config.duration) - return wav.astype(dtype) + return wav.astype(dtype) def scale_audio( @@ -521,7 +522,7 @@ def resample_audio( original_samplerate = int(1 / step) if original_samplerate == samplerate: - return wav.astype(dtype) + return wav.astype(dtype).assign_attrs(original_samplerate=samplerate) if method == "poly": resampled = resample_audio_poly( @@ -561,7 +562,11 @@ def resample_audio( samplerate=samplerate, ), }, - attrs={**wav.attrs, "samplerate": samplerate}, + attrs={ + **wav.attrs, + "samplerate": samplerate, + "original_samplerate": original_samplerate, + }, ) diff --git a/batdetect2/preprocess/spectrogram.py b/batdetect2/preprocess/spectrogram.py index 5d4fe17..7b20a92 100644 --- a/batdetect2/preprocess/spectrogram.py +++ b/batdetect2/preprocess/spectrogram.py @@ -363,7 +363,7 @@ def compute_spectrogram( ) if config.peak_normalize: - spec = ops.scale(spec, 1 / (10e-6 + np.max(spec))) + spec = ops.normalize(spec) return spec.astype(dtype) @@ -436,6 +436,9 @@ def stft( ValueError If sample rate cannot be determined from `wave` coordinates. """ + if "channel" not in wave.coords: + wave = wave.assign_coords(channel=0) + return audio.compute_spectrogram( wave, window_size=window_duration, @@ -544,7 +547,7 @@ def apply_pcen( verified against the specific `soundevent.audio.pcen` implementation details. """ - samplerate = spec.attrs["samplerate"] + samplerate = 1 / spec.time.attrs["step"] hop_size = spec.attrs["hop_size"] hop_length = int(hop_size * samplerate) @@ -622,6 +625,7 @@ def resize_spectrogram( spec: xr.DataArray, height: int = 128, resize_factor: Optional[float] = 0.5, + dtype: DTypeLike = np.float32, # type: ignore ) -> xr.DataArray: """Resize a spectrogram to target dimensions using interpolation. @@ -647,11 +651,26 @@ def resize_spectrogram( """ resize_factor = resize_factor or 1 current_width = spec.sizes["time"] - return ops.resize( - spec, - time=int(resize_factor * current_width), - frequency=height, - dtype=np.float32, + + target_sizes = { + "time": int(current_width * resize_factor), + "frequency": height, + } + + new_coords = {} + for dim in ["time", "frequency"]: + step = arrays.get_dim_step(spec, dim) + start, stop = arrays.get_dim_range(spec, dim) + new_coords[dim] = arrays.create_range_dim( + name=dim, + start=start, + stop=stop + step, + size=target_sizes[dim], + dtype=dtype, + ) + + return spec.interp( + coords=new_coords, method="linear", kwargs=dict(fill_value=0) ) diff --git a/batdetect2/targets/rois.py b/batdetect2/targets/rois.py index 1a17949..bd05397 100644 --- a/batdetect2/targets/rois.py +++ b/batdetect2/targets/rois.py @@ -93,13 +93,21 @@ class ROITargetMapper(Protocol): dimension_names: List[str] - def get_roi_position(self, geom: data.Geometry) -> tuple[float, float]: + def get_roi_position( + self, + geom: data.Geometry, + position: Optional[Positions] = None, + ) -> tuple[float, float]: """Extract the reference position from a geometry. Parameters ---------- geom : soundevent.data.Geometry The input geometry (e.g., BoundingBox, Polygon). + position : Positions, optional + Overrides the default `position` configured for the mapper. + If provided, this position will be used instead of the mapper's + internal default. Returns ------- @@ -141,7 +149,10 @@ class ROITargetMapper(Protocol): ... def recover_roi( - self, pos: tuple[float, float], dims: np.ndarray + self, + pos: tuple[float, float], + dims: np.ndarray, + position: Optional[Positions] = None, ) -> data.Geometry: """Recover an approximate ROI from a position and target dimensions. @@ -153,8 +164,12 @@ class ROITargetMapper(Protocol): pos : Tuple[float, float] The reference position (time, frequency). dims : np.ndarray - The NumPy array containing the dimensions, matching the order + NumPy array containing the dimensions, matching the order specified by `dimension_names`. + position : Positions, optional + Overrides the default `position` configured for the mapper. + If provided, this position will be used instead of the mapper's + internal default when reconstructing the roi geometry. Returns ------- @@ -240,7 +255,11 @@ class BBoxEncoder(ROITargetMapper): self.time_scale = time_scale self.frequency_scale = frequency_scale - def get_roi_position(self, geom: data.Geometry) -> Tuple[float, float]: + def get_roi_position( + self, + geom: data.Geometry, + position: Optional[Positions] = None, + ) -> Tuple[float, float]: """Extract the configured reference position from the geometry. Uses `soundevent.geometry.get_geometry_point`. @@ -249,6 +268,9 @@ class BBoxEncoder(ROITargetMapper): ---------- geom : soundevent.data.Geometry Input geometry (e.g., BoundingBox). + position : Positions, optional + Overrides the default `position` configured for the encoder. + If provided, this position will be used instead of `self.position`. Returns ------- @@ -257,7 +279,8 @@ class BBoxEncoder(ROITargetMapper): """ from soundevent import geometry - return geometry.get_geometry_point(geom, position=self.position) + position = position or self.position + return geometry.get_geometry_point(geom, position=position) def get_roi_size(self, geom: data.Geometry) -> np.ndarray: """Calculate the scaled [width, height] from the geometry's bounds. @@ -291,6 +314,7 @@ class BBoxEncoder(ROITargetMapper): self, pos: tuple[float, float], dims: np.ndarray, + position: Optional[Positions] = None, ) -> data.Geometry: """Recover a BoundingBox from a position and scaled dimensions. @@ -305,6 +329,10 @@ class BBoxEncoder(ROITargetMapper): dims : np.ndarray NumPy array containing the *scaled* dimensions, expected order is [scaled_width, scaled_height]. + position : Positions, optional + Overrides the default `position` configured for the encoder. + If provided, this position will be used instead of `self.position` + when reconstructing the bounding box. Returns ------- @@ -316,6 +344,8 @@ class BBoxEncoder(ROITargetMapper): ValueError If `dims` does not have the expected shape (length 2). """ + position = position or self.position + if dims.ndim != 1 or dims.shape[0] != 2: raise ValueError( "Dimension array does not have the expected shape. " diff --git a/batdetect2/train/lightning.py b/batdetect2/train/lightning.py index 5080a80..c88c7d2 100644 --- a/batdetect2/train/lightning.py +++ b/batdetect2/train/lightning.py @@ -40,9 +40,7 @@ class TrainingModule(L.LightningModule): self.learning_rate = learning_rate self.t_max = t_max - # NOTE: Ignore detector and loss from hyperparameter saving - # as they are nn.Module and should be saved regardless. - self.save_hyperparameters(ignore=["detector", "loss"]) + self.save_hyperparameters() def forward(self, spec: torch.Tensor) -> ModelOutput: return self.detector(spec) diff --git a/tests/test_preprocessing/test_spectrogram.py b/tests/test_preprocessing/test_spectrogram.py index de8ca78..5c5beff 100644 --- a/tests/test_preprocessing/test_spectrogram.py +++ b/tests/test_preprocessing/test_spectrogram.py @@ -24,7 +24,6 @@ from batdetect2.preprocess.spectrogram import ( get_spectrogram_resolution, remove_spectral_mean, resize_spectrogram, - scale_log, scale_spectrogram, stft, ) @@ -153,13 +152,13 @@ def test_stft_output_properties(sine_wave_xr: xr.DataArray): assert np.isclose(time_step, hop_len / samplerate) assert spec.frequency.min() >= 0 assert freq_start == 0 - assert np.isclose(freq_end + freq_step, samplerate / 2, atol=5) - assert spec.time.min() >= 0 + assert np.isclose(freq_end, samplerate / 2, atol=freq_step / 2) + assert np.isclose(spec.time.min(), 0) assert spec.time.max() < DURATION - assert spec.attrs["original_samplerate"] == samplerate - assert spec.attrs["nfft"] == nfft - assert spec.attrs["noverlap"] == int(window_overlap * nfft) + assert spec.attrs["samplerate"] == samplerate + assert spec.attrs["window_size"] == window_duration + assert spec.attrs["hop_size"] == window_duration * (1 - window_overlap) assert np.all(spec.data >= 0) @@ -192,7 +191,7 @@ def test_crop_spectrogram_frequencies(sample_spec: xr.DataArray): def test_crop_spectrogram_full_range(sample_spec: xr.DataArray): - samplerate = sample_spec.attrs["original_samplerate"] + samplerate = sample_spec.attrs["samplerate"] min_f, max_f = 0, samplerate / 2 cropped_spec = crop_spectrogram_frequencies( sample_spec, min_freq=min_f, max_freq=max_f @@ -227,33 +226,6 @@ def test_apply_pcen(sample_spec: xr.DataArray): assert not np.allclose(pcen_spec.data, sample_spec.data) -def test_scale_log(sample_spec: xr.DataArray): - if "original_samplerate" not in sample_spec.attrs: - sample_spec.attrs["original_samplerate"] = SAMPLERATE - if "nfft" not in sample_spec.attrs: - sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE) - - log_spec = scale_log(sample_spec, dtype=np.float32) - - assert log_spec.dims == sample_spec.dims - assert log_spec.sizes == sample_spec.sizes - assert log_spec.dtype == np.float32 - assert np.all(log_spec.data >= 0) - assert not np.allclose(log_spec.data, sample_spec.data) - - -def test_scale_log_missing_attrs(sample_spec: xr.DataArray): - spec_copy = sample_spec.copy() - del spec_copy.attrs["original_samplerate"] - with pytest.raises(KeyError): - scale_log(spec_copy) - - spec_copy = sample_spec.copy() - del spec_copy.attrs["nfft"] - with pytest.raises(KeyError): - scale_log(spec_copy) - - def test_scale_spectrogram_amplitude(sample_spec: xr.DataArray): scaled_spec = scale_spectrogram(sample_spec, scale="amplitude") assert np.allclose(scaled_spec.data, sample_spec.data) @@ -267,15 +239,9 @@ def test_scale_spectrogram_power(sample_spec: xr.DataArray): def test_scale_spectrogram_db(sample_spec: xr.DataArray): - if "original_samplerate" not in sample_spec.attrs: - sample_spec.attrs["original_samplerate"] = SAMPLERATE - if "nfft" not in sample_spec.attrs: - sample_spec.attrs["nfft"] = int(0.002 * SAMPLERATE) - - scaled_spec = scale_spectrogram(sample_spec, scale="dB", dtype=np.float64) - log_spec_expected = scale_log(sample_spec, dtype=np.float64) - assert scaled_spec.dtype == np.float64 - assert np.allclose(scaled_spec.data, log_spec_expected.data) + scaled_spec = scale_spectrogram(sample_spec, scale="dB") + log_spec_expected = arrays.to_db(sample_spec) + xr.testing.assert_allclose(scaled_spec, log_spec_expected) def test_remove_spectral_mean(sample_spec: xr.DataArray): @@ -291,8 +257,7 @@ def test_remove_spectral_mean(sample_spec: xr.DataArray): def test_remove_spectral_mean_constant(constant_wave_xr: xr.DataArray): const_spec = stft(constant_wave_xr, 0.002, 0.5) denoised_spec = remove_spectral_mean(const_spec) - - assert np.allclose(denoised_spec.data, 0, atol=1e-6) + assert np.all(denoised_spec.data >= 0) @pytest.mark.parametrize( @@ -324,8 +289,6 @@ def test_resize_spectrogram( assert abs(resized_spec.sizes["time"] - expected_time_size) <= 1 - assert resized_spec.dtype == np.float32 - def test_compute_spectrogram_defaults(sine_wave_xr: xr.DataArray): config = SpectrogramConfig() @@ -377,7 +340,7 @@ def test_compute_spectrogram_no_pcen_no_mean_sub_no_resize( def test_compute_spectrogram_peak_normalize(sine_wave_xr: xr.DataArray): - config = SpectrogramConfig(peak_normalize=True) + config = SpectrogramConfig(peak_normalize=True, pcen=None) spec = compute_spectrogram(sine_wave_xr, config=config) assert np.isclose(spec.data.max(), 1.0, atol=1e-6) @@ -443,20 +406,6 @@ def test_configurable_spectrogram_builder_call_xr(sine_wave_xr: xr.DataArray): assert spec_builder.dtype == spec_direct.dtype -def test_configurable_spectrogram_builder_call_np(sine_wave_xr: xr.DataArray): - config = SpectrogramConfig() - builder = ConfigurableSpectrogramBuilder(config=config) - wav_np = sine_wave_xr.data - samplerate = sine_wave_xr.attrs["samplerate"] - - spec_builder = builder(wav_np.astype(np.float32), samplerate=samplerate) - spec_direct = compute_spectrogram(sine_wave_xr, config=config) - - assert isinstance(spec_builder, xr.DataArray) - assert np.allclose(spec_builder.data, spec_direct.data, atol=1e-4) - assert spec_builder.dtype == spec_direct.dtype - - def test_configurable_spectrogram_builder_call_np_no_samplerate( sine_wave_xr: xr.DataArray, ): diff --git a/tests/test_targets/test_terms.py b/tests/test_targets/test_terms.py index a717012..8093521 100644 --- a/tests/test_targets/test_terms.py +++ b/tests/test_targets/test_terms.py @@ -72,7 +72,7 @@ def test_term_registry_get_keys(): def test_get_term_from_key(): - term = terms.get_term_from_key("call_type") + term = terms.get_term_from_key("event") assert term == terms.call_type custom_registry = TermRegistry() @@ -84,7 +84,7 @@ def test_get_term_from_key(): def test_get_term_keys(): keys = terms.get_term_keys() - assert "call_type" in keys + assert "event" in keys assert "individual" in keys assert terms.GENERIC_CLASS_KEY in keys @@ -96,7 +96,7 @@ def test_get_term_keys(): def test_tag_info_and_get_tag_from_info(): - tag_info = TagInfo(value="Myotis myotis", key="call_type") + tag_info = TagInfo(value="Myotis myotis", key="event") tag = terms.get_tag_from_info(tag_info) assert tag.value == "Myotis myotis" assert tag.term == terms.call_type @@ -161,7 +161,7 @@ def test_load_terms_from_config_key_already_exists(tmp_path): config_data = { "terms": [ { - "key": "call_type", + "key": "event", "uri": "dwc:scientificName", "label": "Scientific Name", }, # Duplicate key diff --git a/tests/test_train/test_augmentations.py b/tests/test_train/test_augmentations.py index 78a6251..65579a7 100644 --- a/tests/test_train/test_augmentations.py +++ b/tests/test_train/test_augmentations.py @@ -116,14 +116,15 @@ def test_selected_random_subclip_has_the_correct_width( recording1 = create_recording() clip1 = data.Clip(recording=recording1, start_time=0.2, end_time=0.7) clip_annotation_1 = data.ClipAnnotation(clip=clip1) + original = generate_train_example( clip_annotation_1, preprocessor=sample_preprocessor, labeller=sample_labeller, ) - subclip = select_subclip(original, start=0, span=100) - assert subclip["spectrogram"].shape[1] == 100 + subclip = select_subclip(original, start=0, span=0.513) + assert subclip["spectrogram"].shape[1] == 512 def test_add_echo_after_subclip( @@ -142,7 +143,7 @@ def test_add_echo_after_subclip( assert original.sizes["time"] > 512 - subclip = select_subclip(original, start=0, span=512) + subclip = select_subclip(original, start=0, span=0.513) with_echo = add_echo(subclip, preprocessor=sample_preprocessor) assert with_echo.sizes["time"] == 512 diff --git a/tests/test_train/test_lightning.py b/tests/test_train/test_lightning.py index 6ce7b71..fb635d4 100644 --- a/tests/test_train/test_lightning.py +++ b/tests/test_train/test_lightning.py @@ -53,4 +53,4 @@ def test_can_save_checkpoint(tmp_path: Path, clip: data.Clip): output1 = module(input1) output2 = recovered(input2) - assert output1 == output2 + torch.testing.assert_close(output1, output2)