import uuid
from pathlib import Path
from typing import Callable, List, Optional
from uuid import uuid4

import lightning as L
import numpy as np
import pytest
import soundfile as sf
from scipy import signal
from soundevent import data, terms

from batdetect2.audio import build_audio_loader
from batdetect2.audio.clips import build_clipper
from batdetect2.audio.types import AudioLoader, ClipperProtocol
from batdetect2.config import BatDetect2Config
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.annotations.batdetect2 import BatDetect2FilesAnnotations
from batdetect2.preprocess import build_preprocessor
from batdetect2.preprocess.types import PreprocessorProtocol
from batdetect2.targets import (
    TargetConfig,
    build_targets,
    call_type,
)
from batdetect2.targets.classes import TargetClassConfig
from batdetect2.targets.types import TargetProtocol
from batdetect2.train.labels import build_clip_labeler
from batdetect2.train.lightning import build_training_module
from batdetect2.train.types import ClipLabeller


@pytest.fixture
def example_data_dir() -> Path:
    """Return the ``example_data`` directory one level above this folder.

    The fixture asserts that the directory exists.
    """
    package_root = Path(__file__).parent.parent
    location = package_root / "example_data"
    assert location.exists()
    return location


@pytest.fixture
def example_audio_dir(example_data_dir: Path) -> Path:
    """Return the ``audio`` sub-directory of the example data directory."""
    location = example_data_dir / "audio"
    assert location.exists()
    return location


@pytest.fixture
def example_anns_dir(example_data_dir: Path) -> Path:
    """Return the ``anns`` sub-directory of the example data directory."""
    location = example_data_dir / "anns"
    assert location.exists()
    return location


@pytest.fixture
def example_audio_files(example_audio_dir: Path) -> List[Path]:
    """Return the WAV files (any letter case) in the example audio folder.

    The fixture asserts that exactly three files are found.
    """
    wav_paths = [*example_audio_dir.glob("*.[wW][aA][vV]")]
    assert len(wav_paths) == 3
    return wav_paths


@pytest.fixture
def data_dir() -> Path:
    """Return the ``data`` directory that sits next to this file."""
    location = Path(__file__).parent / "data"
    assert location.exists()
    return location


@pytest.fixture
def contrib_dir(data_dir) -> Path:
    """Return the ``contrib`` sub-directory of the test data directory."""
    location = data_dir / "contrib"
    assert location.exists()
    return location
@pytest.fixture
def wav_factory(tmp_path: Path):
    """Provide a factory that writes WAV files filled with random samples.

    When no ``path`` is given the file is created inside ``tmp_path`` under
    a random UUID name. Only 16 and 32 bit PCM are supported; any other
    ``bit_depth`` raises ``ValueError``.
    """
    sample_types = {16: np.int16, 32: np.int32}

    def _wav_factory(
        path: Optional[Path] = None,
        duration: float = 0.3,
        channels: int = 1,
        samplerate: int = 441_000,
        bit_depth: int = 16,
    ) -> Path:
        if not path:
            path = tmp_path / f"{uuid.uuid4()}.wav"

        n_frames = int(samplerate * duration)

        sample_type = sample_types.get(bit_depth)
        if sample_type is None:
            raise ValueError(f"Unsupported bit depth: {bit_depth}")

        # Draw samples across the full range of the integer type.
        limits = np.iinfo(sample_type)
        noise = np.random.uniform(
            low=limits.min,
            high=limits.max,
            size=(n_frames, channels),
        ).astype(sample_type)

        sf.write(str(path), noise, samplerate, subtype=f"PCM_{bit_depth}")
        return path

    return _wav_factory


@pytest.fixture
def create_recording(wav_factory: Callable[..., Path]):
    """Provide a factory that writes a WAV file and wraps it in a Recording.

    The audio file is produced by ``wav_factory``; a random UUID is used
    when ``recording_id`` is not supplied.
    """

    def factory(
        tags: Optional[list[data.Tag]] = None,
        path: Optional[Path] = None,
        recording_id: Optional[uuid.UUID] = None,
        duration: float = 1,
        channels: int = 1,
        samplerate: int = 256_000,
        time_expansion: float = 1,
    ) -> data.Recording:
        wav_path = wav_factory(
            path=path,
            duration=duration,
            channels=channels,
            samplerate=samplerate,
        )
        return data.Recording.from_file(
            path=wav_path,
            uuid=recording_id or uuid.uuid4(),
            time_expansion=time_expansion,
            tags=tags or [],
        )

    return factory


@pytest.fixture
def generate_whistle(tmp_path: Path):
    """Provide a factory that writes a WAV file containing one "whistle".

    The whistle is a Gaussian-modulated pulse at the requested frequency.
    It is synthesised on a time axis centred on zero and then circularly
    shifted so that it sits at ``time`` seconds from the start of the file.
    The result is written as 16 bit PCM.
    """

    def factory(
        time: float,
        frequency: int,
        path: Optional[Path] = None,
        duration: float = 0.3,
        samplerate: int = 441_000,
        whistle_duration: float = 0.1,
    ) -> Path:
        if not path:
            path = tmp_path / f"{uuid.uuid4()}.wav"

        half = duration / 2
        n_frames = int(samplerate * duration)
        # Number of samples to roll the zero-centred pulse by.
        shift = int((time - half) * samplerate)

        axis = np.linspace(-half, half, n_frames, endpoint=False)
        pulse = signal.gausspulse(
            axis,
            fc=frequency,
            bw=2 / (frequency * whistle_duration),
        )

        # Scale to the int16 range before casting.
        scaled = np.roll(pulse, shift) * np.iinfo(np.int16).max
        samples = scaled.astype(np.int16)

        sf.write(str(path), samples, samplerate, subtype="PCM_16")
        return path

    return factory


@pytest.fixture
def recording(
    create_recording: Callable[..., data.Recording],
) -> data.Recording:
    """Return a recording created with the factory's default arguments."""
    return create_recording()


@pytest.fixture
def create_clip():
    """Provide a factory that builds a clip over a recording."""

    def factory(
        recording: data.Recording,
        start_time: float = 0,
        end_time: float = 0.5,
    ) -> data.Clip:
        return data.Clip(
            recording=recording,
            start_time=start_time,
            end_time=end_time,
        )

    return factory


@pytest.fixture
def clip(recording: data.Recording) -> data.Clip:
    """Return a clip covering 0 to 0.5 of the default recording."""
    return data.Clip(recording=recording, start_time=0, end_time=0.5)


@pytest.fixture
def create_sound_event():
    """Provide a factory that builds a bounding-box sound event.

    A default box of ``[0.2, 60_000, 0.3, 70_000]`` is used when no
    coordinates are supplied.
    """

    def factory(
        recording: data.Recording,
        coords: Optional[List[float]] = None,
    ) -> data.SoundEvent:
        if not coords:
            coords = [0.2, 60_000, 0.3, 70_000]

        box = data.BoundingBox(coordinates=coords)
        return data.SoundEvent(geometry=box, recording=recording)

    return factory


@pytest.fixture
def sound_event(recording: data.Recording) -> data.SoundEvent:
    """Return a sound event with a fixed bounding box on the recording."""
    box = data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000])
    return data.SoundEvent(geometry=box, recording=recording)


@pytest.fixture
def create_sound_event_annotation():
    """Provide a factory that wraps a sound event in an annotation."""

    def factory(
        sound_event: data.SoundEvent,
        tags: Optional[List[data.Tag]] = None,
    ) -> data.SoundEventAnnotation:
        if not tags:
            tags = []
        return data.SoundEventAnnotation(sound_event=sound_event, tags=tags)

    return factory


@pytest.fixture
def echolocation_call(recording: data.Recording) -> data.SoundEventAnnotation:
    """Return an annotation tagged as a Myotis myotis echolocation call."""
    event = data.SoundEvent(
        geometry=data.BoundingBox(coordinates=[0.1, 67_000, 0.11, 73_000]),
        recording=recording,
    )
    call_tags = [
        data.Tag(term=terms.scientific_name, value="Myotis myotis"),
        data.Tag(term=call_type, value="Echolocation"),
    ]
    return data.SoundEventAnnotation(sound_event=event, tags=call_tags)


@pytest.fixture
def generic_call(recording: data.Recording) -> data.SoundEventAnnotation:
    """Return an echolocation annotation tagged only with order Chiroptera."""
    event = data.SoundEvent(
        geometry=data.BoundingBox(
            coordinates=[0.34, 35_000, 0.348, 62_000],
        ),
        recording=recording,
    )
    call_tags = [
        data.Tag(term=terms.order, value="Chiroptera"),
        data.Tag(term=call_type, value="Echolocation"),
    ]
    return data.SoundEventAnnotation(sound_event=event, tags=call_tags)


@pytest.fixture
def non_relevant_sound_event(
    recording: data.Recording,
) -> data.SoundEventAnnotation:
    """Return an annotation tagged as Muscardinus avellanarius."""
    event = data.SoundEvent(
        geometry=data.BoundingBox(
            coordinates=[0.22, 50_000, 0.24, 58_000],
        ),
        recording=recording,
    )
    species = data.Tag(
        term=terms.scientific_name,
        value="Muscardinus avellanarius",
    )
    return data.SoundEventAnnotation(sound_event=event, tags=[species])


@pytest.fixture
def create_clip_annotation():
    """Provide a factory that builds a clip annotation."""

    def factory(
        clip: data.Clip,
        clip_tags: Optional[List[data.Tag]] = None,
        sound_events: Optional[List[data.SoundEventAnnotation]] = None,
    ) -> data.ClipAnnotation:
        if not clip_tags:
            clip_tags = []
        if not sound_events:
            sound_events = []
        return data.ClipAnnotation(
            clip=clip,
            tags=clip_tags,
            sound_events=sound_events,
        )

    return factory


@pytest.fixture
def clip_annotation(
    clip: data.Clip,
    echolocation_call: data.SoundEventAnnotation,
    generic_call: data.SoundEventAnnotation,
    non_relevant_sound_event: data.SoundEventAnnotation,
) -> data.ClipAnnotation:
    """Return a clip annotation holding the three sample annotations."""
    events = [echolocation_call, generic_call, non_relevant_sound_event]
    return data.ClipAnnotation(clip=clip, sound_events=events)


@pytest.fixture
def create_annotation_set():
    """Provide a factory that builds an annotation set."""

    def factory(
        name: str = "test",
        description: str = "Test annotation set",
        annotations: Optional[List[data.ClipAnnotation]] = None,
    ) -> data.AnnotationSet:
        if not annotations:
            annotations = []
        return data.AnnotationSet(
            name=name,
            description=description,
            clip_annotations=annotations,
        )

    return factory


@pytest.fixture
def create_annotation_project():
    """Provide a factory that builds an annotation project."""

    def factory(
        name: str = "test_project",
        description: str = "Test Annotation Project",
        tasks: Optional[List[data.AnnotationTask]] = None,
        annotations: Optional[List[data.ClipAnnotation]] = None,
    ) -> data.AnnotationProject:
        if not tasks:
            tasks = []
        if not annotations:
            annotations = []
        return data.AnnotationProject(
            name=name,
            description=description,
            tasks=tasks,
            clip_annotations=annotations,
        )

    return factory


@pytest.fixture
def sample_preprocessor() -> PreprocessorProtocol:
    """Return a preprocessor built with no arguments."""
    return build_preprocessor()


@pytest.fixture
def sample_audio_loader() -> AudioLoader:
    """Return an audio loader built with no arguments."""
    return build_audio_loader()


@pytest.fixture
def bat_tag() -> data.Tag:
    """Return the ``class=bat`` tag."""
    return data.Tag(key="class", value="bat")


@pytest.fixture
def noise_tag() -> data.Tag:
    """Return the ``class=noise`` tag."""
    return data.Tag(key="class", value="noise")


@pytest.fixture
def myomyo_tag() -> data.Tag:
    """Return the species tag for Myotis myotis."""
    return data.Tag(key="species", value="Myotis myotis")


@pytest.fixture
def pippip_tag() -> data.Tag:
    """Return the species tag for Pipistrellus pipistrellus."""
    return data.Tag(key="species", value="Pipistrellus pipistrellus")


@pytest.fixture
def sample_target_config(
    bat_tag: data.Tag,
    myomyo_tag: data.Tag,
    pippip_tag: data.Tag,
) -> TargetConfig:
    """Return a target config with a bat detector and two species classes."""
    detection = TargetClassConfig(name="bat", tags=[bat_tag])
    species_classes = [
        TargetClassConfig(name="pippip", tags=[pippip_tag]),
        TargetClassConfig(name="myomyo", tags=[myomyo_tag]),
    ]
    return TargetConfig(
        detection_target=detection,
        classification_targets=species_classes,
    )


@pytest.fixture
def sample_targets(
    sample_target_config: TargetConfig,
) -> TargetProtocol:
    """Return the targets built from the sample target config."""
    return build_targets(sample_target_config)


@pytest.fixture
def sample_labeller(
    sample_targets: TargetProtocol,
    sample_preprocessor: PreprocessorProtocol,
) -> ClipLabeller:
    """Return a clip labeller bounded by the preprocessor's frequencies."""
    low = sample_preprocessor.min_freq
    high = sample_preprocessor.max_freq
    return build_clip_labeler(sample_targets, min_freq=low, max_freq=high)


@pytest.fixture
def sample_clipper() -> ClipperProtocol:
    """Return a clipper built with no arguments."""
    return build_clipper()


@pytest.fixture
def example_dataset(example_data_dir: Path) -> DatasetConfig:
    """Return a dataset config pointing at the example audio and anns."""
    source = BatDetect2FilesAnnotations(
        name="example annotations",
        audio_dir=example_data_dir / "audio",
        annotations_dir=example_data_dir / "anns",
    )
    return DatasetConfig(
        name="test dataset",
        description="test dataset",
        sources=[source],
    )


@pytest.fixture
def example_annotations(
    example_dataset: DatasetConfig,
) -> List[data.ClipAnnotation]:
    """Return the clip annotations loaded from the example dataset.

    The fixture asserts that exactly three annotations are loaded.
    """
    loaded = load_dataset(example_dataset)
    assert len(loaded) == 3
    return list(loaded)


@pytest.fixture
def create_temp_yaml(tmp_path: Path) -> Callable[[str], Path]:
    """Provide a factory that writes text to a uniquely named YAML file."""

    def factory(content: str) -> Path:
        yaml_path = tmp_path / f"{uuid4()}.yaml"
        yaml_path.write_text(content)
        return yaml_path

    return factory


@pytest.fixture
def tiny_checkpoint_path(tmp_path: Path) -> Path:
    """Save a checkpoint of a default-config training module.

    The module is attached to the trainer's strategy directly, and the
    checkpoint is written to ``model.ckpt`` inside ``tmp_path``.
    """
    model_config = BatDetect2Config().model
    module = build_training_module(model_config=model_config)
    trainer = L.Trainer(enable_checkpointing=False, logger=False)
    destination = tmp_path / "model.ckpt"
    trainer.strategy.connect(module)
    trainer.save_checkpoint(destination)
    return destination


@pytest.fixture
def single_audio_dir(tmp_path: Path, example_audio_files: List[Path]) -> Path:
    """Return a fresh directory holding a copy of the first example file."""
    folder = tmp_path / "audio"
    folder.mkdir()
    original = example_audio_files[0]
    copy = folder / original.name
    copy.write_bytes(original.read_bytes())
    return folder