Compare commits

...

10 Commits

Author SHA1 Message Date
mbsantiago
4303d4e42d Fix logging imports 2026-04-23 19:26:45 +01:00
mbsantiago
da113eaea8 Add path_in_list condition 2026-04-04 10:44:35 +01:00
mbsantiago
1579bbc6c5 Add csv list format 2026-04-04 10:23:14 +01:00
mbsantiago
c67d9cbba0 Sort summary by class name 2026-04-03 17:13:12 +01:00
mbsantiago
00961132a9 Improve test suite for conditions 2026-04-03 17:07:26 +01:00
mbsantiago
e04d86808d Add clip annotation filtering to data loading 2026-04-03 16:40:23 +01:00
mbsantiago
c8dd4155bf Add conditions for clips and recordings 2026-04-03 16:40:11 +01:00
mbsantiago
e80fe8675d Mark some tests as slow for quicker feedback 2026-03-29 15:55:24 +01:00
mbsantiago
c24056214c Make sure api_v2 loads fast 2026-03-29 15:10:18 +01:00
mbsantiago
6d09133dca Minor fixes 2026-03-29 14:30:14 +01:00
33 changed files with 2742 additions and 971 deletions

View File

@ -20,7 +20,15 @@ install:
# Testing & Coverage
# Run tests using pytest.
test:
uv run pytest -n auto {{TESTS_DIR}}
uv run pytest {{TESTS_DIR}}
# Run the fast subset of tests (excludes @pytest.mark.slow).
test-quick:
uv run pytest --durations=10 -m "not slow" {{TESTS_DIR}}
# Run only long-running tests marked with @pytest.mark.slow.
test-slow:
uv run pytest -m "slow" {{TESTS_DIR}}
# Run tests and generate coverage data.
coverage:

View File

@ -95,6 +95,9 @@ mlflow = ["mlflow>=3.1.1"]
gradio = [
"gradio>=6.9.0",
]
dvc = [
"dvclive>=3.49.0",
]
[tool.ruff]
line-length = 79
@ -126,3 +129,8 @@ exclude = [
"src/batdetect2/finetune",
"src/batdetect2/utils",
]
[tool.pytest.ini_options]
markers = [
"slow: marks long-running tests that are skipped in quick runs",
]

View File

@ -1,68 +1,46 @@
from __future__ import annotations
from pathlib import Path
from typing import Literal, Sequence, cast
from typing import TYPE_CHECKING, Literal
import numpy as np
import torch
from soundevent import data
from soundevent.audio.files import get_audio_files
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
from batdetect2.config import BatDetect2Config
from batdetect2.data import Dataset, load_dataset_from_config
from batdetect2.evaluate import (
DEFAULT_EVAL_DIR,
EvaluationConfig,
EvaluatorProtocol,
build_evaluator,
run_evaluate,
save_evaluation_results,
)
from batdetect2.inference import (
InferenceConfig,
process_file_list,
run_batch_inference,
)
from batdetect2.logging import (
DEFAULT_LOGS_DIR,
AppLoggingConfig,
LoggerConfig,
)
from batdetect2.models import (
Model,
ModelConfig,
build_model,
build_model_with_new_targets,
)
from batdetect2.models.detectors import Detector
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
OutputsConfig,
OutputTransformProtocol,
build_output_formatter,
build_output_transform,
get_output_formatter,
)
from batdetect2.postprocess import (
ClipDetections,
Detection,
PostprocessorProtocol,
build_postprocessor,
)
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
TargetProtocol,
build_roi_mapping,
build_targets,
)
from batdetect2.train import (
DEFAULT_CHECKPOINT_DIR,
TrainingConfig,
load_model_from_checkpoint,
run_train,
)
if TYPE_CHECKING:
from collections.abc import Sequence
import torch
from batdetect2.audio import AudioConfig, AudioLoader
from batdetect2.config import BatDetect2Config
from batdetect2.data import Dataset
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig, LoggerConfig
from batdetect2.models import Model, ModelConfig
from batdetect2.outputs import (
OutputFormatConfig,
OutputFormatterProtocol,
OutputsConfig,
OutputTransformProtocol,
)
from batdetect2.postprocess import (
ClipDetections,
Detection,
PostprocessorProtocol,
)
from batdetect2.preprocess import PreprocessorProtocol
from batdetect2.targets import (
ROIMapperProtocol,
TargetConfig,
TargetProtocol,
)
from batdetect2.train import TrainingConfig
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
class BatDetect2API:
@ -109,6 +87,8 @@ class BatDetect2API:
path: data.PathLike,
base_dir: data.PathLike | None = None,
) -> Dataset:
from batdetect2.data import load_dataset_from_config
return load_dataset_from_config(path, base_dir=base_dir)
def train(
@ -128,6 +108,8 @@ class BatDetect2API:
train_config: TrainingConfig | None = None,
logger_config: LoggerConfig | None = None,
):
from batdetect2.train import run_train
run_train(
train_annotations=train_annotations,
val_annotations=val_annotations,
@ -172,6 +154,7 @@ class BatDetect2API:
logger_config: LoggerConfig | None = None,
) -> "BatDetect2API":
"""Fine-tune the model with trainable-parameter selection."""
from batdetect2.train import run_train
self._set_trainable_parameters(trainable)
@ -211,6 +194,8 @@ class BatDetect2API:
outputs_config: OutputsConfig | None = None,
logger_config: LoggerConfig | None = None,
) -> tuple[dict[str, float], list[ClipDetections]]:
from batdetect2.evaluate import run_evaluate
return run_evaluate(
self.model,
test_annotations,
@ -235,6 +220,8 @@ class BatDetect2API:
predictions: Sequence[ClipDetections],
output_dir: data.PathLike | None = None,
):
from batdetect2.evaluate import save_evaluation_results
clip_evals = self.evaluator.evaluate(
annotations,
predictions,
@ -307,6 +294,8 @@ class BatDetect2API:
self,
audio: np.ndarray,
) -> torch.Tensor:
import torch
tensor = torch.tensor(audio).unsqueeze(0)
return self.preprocessor(tensor)
@ -316,6 +305,8 @@ class BatDetect2API:
batch_size: int | None = None,
detection_threshold: float | None = None,
) -> ClipDetections:
from batdetect2.postprocess import ClipDetections
recording = data.Recording.from_file(audio_file, compute_hash=False)
predictions = self.process_files(
@ -382,6 +373,8 @@ class BatDetect2API:
audio_dir: data.PathLike,
detection_threshold: float | None = None,
) -> list[ClipDetections]:
from soundevent.audio.files import get_audio_files
files = list(get_audio_files(audio_dir))
return self.process_files(
files,
@ -398,6 +391,8 @@ class BatDetect2API:
output_config: OutputsConfig | None = None,
detection_threshold: float | None = None,
) -> list[ClipDetections]:
from batdetect2.inference import process_file_list
return process_file_list(
self.model,
audio_files,
@ -424,6 +419,8 @@ class BatDetect2API:
output_config: OutputsConfig | None = None,
detection_threshold: float | None = None,
) -> list[ClipDetections]:
from batdetect2.inference import run_batch_inference
return run_batch_inference(
self.model,
clips,
@ -448,6 +445,8 @@ class BatDetect2API:
format: str | None = None,
config: OutputFormatConfig | None = None,
):
from batdetect2.outputs import get_output_formatter
formatter = self.formatter
if format is not None or config is not None:
@ -467,6 +466,8 @@ class BatDetect2API:
format: str | None = None,
config: OutputFormatConfig | None = None,
) -> list[object]:
from batdetect2.outputs import get_output_formatter
formatter = self.formatter
if format is not None or config is not None:
@ -484,6 +485,17 @@ class BatDetect2API:
cls,
config: BatDetect2Config,
) -> "BatDetect2API":
from batdetect2.audio import build_audio_loader
from batdetect2.evaluate import build_evaluator
from batdetect2.models import build_model
from batdetect2.outputs import (
build_output_formatter,
build_output_transform,
)
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
targets = build_targets(config=config.model.targets)
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
@ -563,6 +575,21 @@ class BatDetect2API:
outputs_config: OutputsConfig | None = None,
logging_config: AppLoggingConfig | None = None,
) -> "BatDetect2API":
from batdetect2.audio import AudioConfig, build_audio_loader
from batdetect2.evaluate import EvaluationConfig, build_evaluator
from batdetect2.inference import InferenceConfig
from batdetect2.logging import AppLoggingConfig
from batdetect2.models import build_model_with_new_targets
from batdetect2.outputs import (
OutputsConfig,
build_output_formatter,
build_output_transform,
)
from batdetect2.postprocess import build_postprocessor
from batdetect2.preprocess import build_preprocessor
from batdetect2.targets import build_roi_mapping, build_targets
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
model, model_config = load_model_from_checkpoint(path)
audio_config = audio_config or AudioConfig(
@ -645,7 +672,7 @@ class BatDetect2API:
self,
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
) -> None:
detector = cast(Detector, self.model.detector)
detector = self.model.detector
for parameter in detector.parameters():
parameter.requires_grad = False

View File

@ -2,10 +2,6 @@
import click
from batdetect2.logging import enable_logging
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
__all__ = [
"cli",
]
@ -34,5 +30,7 @@ def cli(verbose: int = 0):
"""
click.echo(INFO_STR)
from batdetect2.logging import enable_logging
enable_logging(verbose)
# click.echo(BATDETECT_ASCII_ART)

View File

@ -73,7 +73,7 @@ def summary(
summary = compute_class_summary(dataset, targets)
print(summary.to_markdown())
print(summary.sort_values("class_name").to_markdown())
@data.command(short_help="Convert dataset config to annotation set.")

View File

@ -4,8 +4,6 @@ from typing import TYPE_CHECKING
import click
from loguru import logger
from soundevent import io
from soundevent.audio.files import get_audio_files
from batdetect2.cli.base import cli
@ -219,6 +217,8 @@ def predict_directory_command(
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
inference, and saves predictions to `output_path`.
"""
from soundevent.audio.files import get_audio_files
audio_files = list(get_audio_files(audio_dir))
_run_prediction(
model_path=model_path,
@ -309,6 +309,8 @@ def predict_dataset_command(
The dataset is read as a soundevent annotation set and unique recording
paths are extracted before inference.
"""
from soundevent import io
dataset_path = Path(dataset_path)
dataset = io.load(dataset_path, type="annotation_set")
audio_files = sorted(

View File

@ -1,312 +0,0 @@
from collections.abc import Callable
from typing import Annotated, List, Literal, Sequence
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
# Predicate type: takes a sound event annotation, returns whether it matches.
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
# Registry the condition builders below register against, keyed by config type.
conditions: Registry[SoundEventCondition, []] = Registry("condition")
@add_import_config(conditions)
class SoundEventConditionImportConfig(ImportConfig):
    """Use any callable as a sound event condition.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    # Literal tag identifying this config variant.
    name: Literal["import"] = "import"
class HasTagConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["has_tag"] = "has_tag"
    # Tag an annotation must carry for the condition to hold.
    tag: data.Tag


class HasTag:
    """Condition that holds when an annotation carries a given tag."""

    def __init__(self, tag: data.Tag):
        self.tag = tag

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if any tag matches on both term name and value."""
        wanted = self.tag
        for candidate in sound_event_annotation.tags:
            same_term = wanted.term.name == candidate.term.name
            if same_term and wanted.value == candidate.value:
                return True
        return False

    @conditions.register(HasTagConfig)
    @staticmethod
    def from_config(config: HasTagConfig):
        """Build the condition from its config."""
        return HasTag(tag=config.tag)
class HasAllTagsConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["has_all_tags"] = "has_all_tags"
    # Tags that must all be present on the annotation.
    tags: List[data.Tag]


class HasAllTags:
    """Condition that holds when an annotation carries every given tag."""

    def __init__(self, tags: List[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")
        # Tags are compared by (term name, value) pairs.
        self.tags = {(tag.term.name, tag.value) for tag in tags}

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if all required tag keys are on the annotation."""
        present = {
            (tag.term.name, tag.value) for tag in sound_event_annotation.tags
        }
        return self.tags <= present

    @conditions.register(HasAllTagsConfig)
    @staticmethod
    def from_config(config: HasAllTagsConfig):
        """Build the condition from its config."""
        return HasAllTags(tags=config.tags)
class HasAnyTagConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["has_any_tag"] = "has_any_tag"
    # Tags of which at least one must be present on the annotation.
    tags: List[data.Tag]


class HasAnyTag:
    """Condition that holds when an annotation carries any given tag."""

    def __init__(self, tags: List[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")
        # Tags are compared by (term name, value) pairs.
        self.tags = {(tag.term.name, tag.value) for tag in tags}

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if the annotation shares a tag key with the set."""
        present = {
            (tag.term.name, tag.value) for tag in sound_event_annotation.tags
        }
        return not self.tags.isdisjoint(present)

    @conditions.register(HasAnyTagConfig)
    @staticmethod
    def from_config(config: HasAnyTagConfig):
        """Build the condition from its config."""
        return HasAnyTag(tags=config.tags)
# Comparison operators accepted by the numeric conditions.
Operator = Literal["gt", "gte", "lt", "lte", "eq"]


class DurationConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["duration"] = "duration"
    # Comparison applied to the measured duration.
    operator: Operator
    # Threshold the duration is compared against.
    seconds: float
def _build_comparator(
operator: Operator, value: float
) -> Callable[[float], bool]:
if operator == "gt":
return lambda x: x > value
if operator == "gte":
return lambda x: x >= value
if operator == "lt":
return lambda x: x < value
if operator == "lte":
return lambda x: x <= value
if operator == "eq":
return lambda x: x == value
raise ValueError(f"Invalid operator {operator}")
class Duration:
    """Condition on the time extent of a sound event's geometry."""

    def __init__(self, operator: Operator, seconds: float):
        self.operator = operator
        self.seconds = seconds
        self._comparator = _build_comparator(self.operator, self.seconds)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Compare the geometry's duration; False when there is none."""
        geometry = sound_event_annotation.sound_event.geometry
        if geometry is not None:
            start_time, _, end_time, _ = compute_bounds(geometry)
            return self._comparator(end_time - start_time)
        return False

    @conditions.register(DurationConfig)
    @staticmethod
    def from_config(config: DurationConfig):
        """Build the condition from its config."""
        return Duration(operator=config.operator, seconds=config.seconds)
class FrequencyConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["frequency"] = "frequency"
    # Which edge of the frequency range is compared.
    boundary: Literal["low", "high"]
    # Comparison applied to the selected edge.
    operator: Operator
    # Threshold the selected edge is compared against.
    hertz: float


class Frequency:
    """Condition on the low or high frequency bound of a sound event."""

    def __init__(
        self,
        operator: Operator,
        boundary: Literal["low", "high"],
        hertz: float,
    ):
        self.operator = operator
        self.hertz = hertz
        self.boundary = boundary
        self._comparator = _build_comparator(self.operator, self.hertz)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Compare the selected bound; False if there is no frequency range."""
        geometry = sound_event_annotation.sound_event.geometry
        # Missing geometries and time-only geometries never match.
        time_only = (data.TimeInterval, data.TimeStamp)
        if geometry is None or isinstance(geometry, time_only):
            return False

        _, low_freq, _, high_freq = compute_bounds(geometry)
        selected = low_freq if self.boundary == "low" else high_freq
        return self._comparator(selected)

    @conditions.register(FrequencyConfig)
    @staticmethod
    def from_config(config: FrequencyConfig):
        """Build the condition from its config."""
        return Frequency(
            operator=config.operator,
            boundary=config.boundary,
            hertz=config.hertz,
        )
class AllOfConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["all_of"] = "all_of"
    # Nested condition configs that must all hold.
    conditions: Sequence["SoundEventConditionConfig"]


class AllOf:
    """Condition that holds when every nested condition holds."""

    def __init__(self, conditions: List[SoundEventCondition]):
        self.conditions = conditions

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return False at the first failing condition, True otherwise."""
        for check in self.conditions:
            if not check(sound_event_annotation):
                return False
        return True

    @conditions.register(AllOfConfig)
    @staticmethod
    def from_config(config: AllOfConfig):
        """Build each nested condition, then combine them."""
        built = list(map(build_sound_event_condition, config.conditions))
        return AllOf(built)
class AnyOfConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["any_of"] = "any_of"
    # Nested condition configs of which at least one must hold.
    conditions: List["SoundEventConditionConfig"]


class AnyOf:
    """Condition that holds when at least one nested condition holds."""

    def __init__(self, conditions: List[SoundEventCondition]):
        self.conditions = conditions

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True at the first passing condition, False otherwise."""
        for check in self.conditions:
            if check(sound_event_annotation):
                return True
        return False

    @conditions.register(AnyOfConfig)
    @staticmethod
    def from_config(config: AnyOfConfig):
        """Build each nested condition, then combine them."""
        built = list(map(build_sound_event_condition, config.conditions))
        return AnyOf(built)
class NotConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["not"] = "not"
    # Nested condition config whose result is inverted.
    condition: "SoundEventConditionConfig"


class Not:
    """Condition that inverts the result of a nested condition."""

    def __init__(self, condition: SoundEventCondition):
        self.condition = condition

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return the negation of the wrapped condition's result."""
        outcome = self.condition(sound_event_annotation)
        return not outcome

    @conditions.register(NotConfig)
    @staticmethod
    def from_config(config: NotConfig):
        """Build the nested condition and wrap it."""
        return Not(build_sound_event_condition(config.condition))
# Union of the built-in sound event condition configs; pydantic selects the
# member from the value of the ``name`` field.
SoundEventConditionConfig = Annotated[
HasTagConfig
| HasAllTagsConfig
| HasAnyTagConfig
| DurationConfig
| FrequencyConfig
| AllOfConfig
| AnyOfConfig
| NotConfig,
Field(discriminator="name"),
]
def build_sound_event_condition(
    config: SoundEventConditionConfig,
) -> SoundEventCondition:
    """Build the sound event condition described by ``config``."""
    condition = conditions.build(config)
    return condition
def filter_clip_annotation(
    clip_annotation: data.ClipAnnotation,
    condition: SoundEventCondition,
) -> data.ClipAnnotation:
    """Return a copy of the clip annotation with only matching sound events.

    The input annotation is left untouched; sound events for which
    ``condition`` is falsy are dropped from the copy.
    """
    kept = list(filter(condition, clip_annotation.sound_events))
    return clip_annotation.model_copy(update={"sound_events": kept})

View File

@ -0,0 +1,81 @@
from batdetect2.data.conditions.clips import (
ClipAllOfConfig,
ClipAnnotationCondition,
ClipAnnotationConditionConfig,
ClipAnnotationConditionImportConfig,
ClipAnyOfConfig,
ClipNotConfig,
RecordingSatisfiesConfig,
build_clip_annotation_condition,
)
from batdetect2.data.conditions.common import (
CsvList,
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
JsonList,
ListFormatConfig,
TxtList,
)
from batdetect2.data.conditions.recordings import (
PathInListConfig,
RecordingAllOfConfig,
RecordingAnyOfConfig,
RecordingCondition,
RecordingConditionConfig,
RecordingConditionImportConfig,
RecordingNotConfig,
build_recording_condition,
)
from batdetect2.data.conditions.sound_events import (
AllOfConfig,
AnyOfConfig,
DurationConfig,
FrequencyConfig,
NotConfig,
Operator,
SoundEventCondition,
SoundEventConditionConfig,
SoundEventConditionImportConfig,
build_sound_event_condition,
filter_clip_annotation,
)
__all__ = [
"AllOfConfig",
"AnyOfConfig",
"ClipAllOfConfig",
"ClipAnnotationCondition",
"ClipAnnotationConditionConfig",
"ClipAnnotationConditionImportConfig",
"ClipAnyOfConfig",
"ClipNotConfig",
"CsvList",
"DurationConfig",
"FrequencyConfig",
"HasAllTagsConfig",
"HasAnyTagConfig",
"HasTagConfig",
"IdInListConfig",
"JsonList",
"ListFormatConfig",
"NotConfig",
"Operator",
"PathInListConfig",
"RecordingCondition",
"RecordingConditionConfig",
"RecordingConditionImportConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingNotConfig",
"RecordingSatisfiesConfig",
"SoundEventCondition",
"SoundEventConditionConfig",
"SoundEventConditionImportConfig",
"TxtList",
"build_clip_annotation_condition",
"build_recording_condition",
"build_sound_event_condition",
"filter_clip_annotation",
]

View File

@ -0,0 +1,138 @@
from collections.abc import Callable, Sequence
from typing import Annotated, Literal
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions.common import (
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
MultiConditionConfigBase,
NotConditionConfigBase,
register_all_of_condition,
register_any_of_condition,
register_has_all_tags_condition,
register_has_any_tag_condition,
register_has_tag_condition,
register_id_in_list_condition,
register_not_condition,
)
from batdetect2.data.conditions.recordings import (
RecordingCondition,
RecordingConditionConfig,
build_recording_condition,
)
__all__ = [
"ClipAllOfConfig",
"ClipAnnotationCondition",
"ClipAnnotationConditionConfig",
"ClipAnnotationConditionImportConfig",
"ClipAnyOfConfig",
"ClipNotConfig",
"RecordingSatisfiesConfig",
"build_clip_annotation_condition",
]
# Predicate type: takes a clip annotation, returns whether it matches.
ClipAnnotationCondition = Callable[[data.ClipAnnotation], bool]
# Registry of clip condition builders; builders receive one extra argument
# (the optional ``base_dir`` passed by ``build_clip_annotation_condition``).
clip_annotation_conditions: Registry[
ClipAnnotationCondition,
[data.PathLike | None],
] = Registry("clip_condition")
@add_import_config(clip_annotation_conditions, arg_names=["base_dir"])
class ClipAnnotationConditionImportConfig(ImportConfig):
    """Use any callable as a clip annotation condition.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    to use it instead of a built-in option.
    """

    # Literal tag identifying this config variant.
    name: Literal["import"] = "import"
class RecordingSatisfiesConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["recording_satisfies"] = "recording_satisfies"
    # Recording condition applied to the clip's recording.
    condition: RecordingConditionConfig


class RecordingSatisfies:
    """Clip condition that delegates to a condition on the clip's recording."""

    def __init__(self, condition: RecordingCondition):
        self.condition = condition

    def __call__(self, clip_annotation: data.ClipAnnotation) -> bool:
        """Evaluate the recording condition on ``clip.recording``."""
        return self.condition(clip_annotation.clip.recording)

    @clip_annotation_conditions.register(RecordingSatisfiesConfig)
    @staticmethod
    def from_config(
        config: RecordingSatisfiesConfig,
        base_dir: data.PathLike | None = None,
    ) -> "RecordingSatisfies":
        """Build the nested recording condition, forwarding ``base_dir``."""
        recording_condition = build_recording_condition(
            config.condition,
            base_dir=base_dir,
        )
        return RecordingSatisfies(recording_condition)
# Register the shared tag and ID-list conditions in the clip annotation
# registry so they can be built from their common config classes.
register_has_tag_condition(clip_annotation_conditions, HasTagConfig)
register_has_all_tags_condition(
clip_annotation_conditions,
HasAllTagsConfig,
)
register_has_any_tag_condition(
clip_annotation_conditions,
HasAnyTagConfig,
)
register_id_in_list_condition(clip_annotation_conditions, IdInListConfig)
@register_all_of_condition(clip_annotation_conditions)
class ClipAllOfConfig(MultiConditionConfigBase):
    # Literal tag identifying this config variant.
    name: Literal["all_of"] = "all_of"
    # Nested clip condition configs combined by the registered builder.
    conditions: Sequence["ClipAnnotationConditionConfig"]


@register_any_of_condition(clip_annotation_conditions)
class ClipAnyOfConfig(MultiConditionConfigBase):
    # Literal tag identifying this config variant.
    name: Literal["any_of"] = "any_of"
    # Nested clip condition configs combined by the registered builder.
    conditions: Sequence["ClipAnnotationConditionConfig"]


@register_not_condition(clip_annotation_conditions)
class ClipNotConfig(NotConditionConfigBase):
    # Literal tag identifying this config variant.
    name: Literal["not"] = "not"
    # Nested clip condition config whose result is inverted.
    condition: "ClipAnnotationConditionConfig"
# Union of all clip annotation condition configs; pydantic selects the
# member from the value of the ``name`` field.
ClipAnnotationConditionConfig = Annotated[
RecordingSatisfiesConfig
| IdInListConfig
| HasTagConfig
| HasAllTagsConfig
| HasAnyTagConfig
| ClipAllOfConfig
| ClipAnyOfConfig
| ClipNotConfig
| ClipAnnotationConditionImportConfig,
Field(discriminator="name"),
]
def build_clip_annotation_condition(
    config: ClipAnnotationConditionConfig,
    base_dir: data.PathLike | None = None,
) -> ClipAnnotationCondition:
    """Build the clip annotation condition described by ``config``.

    ``base_dir`` is forwarded to the registered builder as its extra
    positional argument.
    """
    condition = clip_annotation_conditions.build(config, base_dir)
    return condition

View File

@ -0,0 +1,417 @@
import csv
import json
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
from uuid import UUID
from pydantic import BaseModel, Field, model_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import Registry
__all__ = [
"AllOf",
"AnyOf",
"Condition",
"CsvList",
"HasAllTags",
"HasAllTagsConfig",
"HasAnyTag",
"HasAnyTagConfig",
"HasTag",
"HasTagConfig",
"IdInList",
"IdInListConfig",
"JsonList",
"ListLoader",
"ListFormatConfig",
"MultiConditionConfigBase",
"Not",
"NotConditionConfigBase",
"ObjectWithTags",
"ObjectWithUUID",
"TxtList",
"build_list_loader",
"register_all_of_condition",
"register_any_of_condition",
"register_has_all_tags_condition",
"register_has_any_tag_condition",
"register_has_tag_condition",
"register_id_in_list_condition",
"register_not_condition",
]
class ObjectWithTags(Protocol):
    """Structural type for any object exposing a ``tags`` list."""

    tags: list[data.Tag]


class ObjectWithUUID(Protocol):
    """Structural type for any object exposing a ``uuid`` attribute."""

    uuid: UUID
# Type of the object a condition is evaluated on (unconstrained).
ConditionObject = TypeVar("ConditionObject")
# Object types restricted to those exposing ``tags`` / ``uuid`` respectively.
TaggedObject = TypeVar("TaggedObject", bound="ObjectWithTags")
UUIDObject = TypeVar("UUIDObject", bound="ObjectWithUUID")
# Extra builder parameters of a registry, forwarded unchanged by the
# ``register_*`` helpers in this module.
P = ParamSpec("P")
# Concrete config subclasses accepted by the combinator decorators.
NotConfigType = TypeVar("NotConfigType", bound="NotConditionConfigBase")
MultiConfigType = TypeVar(
"MultiConfigType",
bound="MultiConditionConfigBase",
)
# A condition is any predicate over a ``ConditionObject``.
Condition = Callable[[ConditionObject], bool]
class NotConditionConfigBase(BaseConfig):
    # Base for "not" configs; subclasses narrow ``condition`` to their own
    # condition config union.
    condition: BaseModel


class MultiConditionConfigBase(BaseConfig):
    # Base for "all_of"/"any_of" configs; subclasses narrow ``conditions``
    # to their own condition config union.
    conditions: Sequence[BaseModel]
class Not(Generic[ConditionObject]):
    """Condition that inverts the result of a wrapped condition."""

    def __init__(self, condition: Condition[ConditionObject]):
        self.condition = condition

    def __call__(self, obj: ConditionObject) -> bool:
        """Return the negation of the wrapped condition's result."""
        outcome = self.condition(obj)
        return not outcome
class AllOf(Generic[ConditionObject]):
    """Condition that holds when every wrapped condition holds."""

    def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
        # Copy into a list so the stored conditions can be re-iterated.
        self.conditions = list(conditions)

    def __call__(self, obj: ConditionObject) -> bool:
        """Return False at the first failing condition, True otherwise."""
        for check in self.conditions:
            if not check(obj):
                return False
        return True
class AnyOf(Generic[ConditionObject]):
    """Condition that holds when at least one wrapped condition holds."""

    def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
        # Copy into a list so the stored conditions can be re-iterated.
        self.conditions = list(conditions)

    def __call__(self, obj: ConditionObject) -> bool:
        """Return True at the first passing condition, False otherwise."""
        for check in self.conditions:
            if check(obj):
                return True
        return False
class HasTag(Generic[TaggedObject]):
    """Condition that holds when an object carries a given tag."""

    def __init__(self, tag: data.Tag):
        # Tags are compared by (term name, value) pairs.
        self.tag_key = (tag.term.name, tag.value)

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if any of the object's tags matches the key."""
        for candidate in obj.tags:
            if (candidate.term.name, candidate.value) == self.tag_key:
                return True
        return False
class HasAllTags(Generic[TaggedObject]):
    """Condition that holds when an object carries every given tag."""

    def __init__(self, tags: list[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")
        # Tags are compared by (term name, value) pairs.
        self.required_keys = {(tag.term.name, tag.value) for tag in tags}

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if every required key is among the object's tags."""
        present = {(tag.term.name, tag.value) for tag in obj.tags}
        return self.required_keys <= present
class HasAnyTag(Generic[TaggedObject]):
    """Condition that holds when an object carries any of the given tags."""

    def __init__(self, tags: list[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")
        # Tags are compared by (term name, value) pairs.
        self.required_keys = {(tag.term.name, tag.value) for tag in tags}

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if the object shares a tag key with the set."""
        present = {(tag.term.name, tag.value) for tag in obj.tags}
        return not self.required_keys.isdisjoint(present)
class IdInList(Generic[UUIDObject]):
    """Condition that holds when an object's UUID is in a given set."""

    def __init__(self, ids: set[UUID]):
        self.ids = ids

    def __call__(self, obj: UUIDObject) -> bool:
        """Return True if ``obj.uuid`` is one of the stored IDs."""
        candidate = obj.uuid
        return candidate in self.ids
class HasTagConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["has_tag"] = "has_tag"
    # Tag the object must carry.
    tag: data.Tag


class HasAllTagsConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["has_all_tags"] = "has_all_tags"
    # Tags that must all be present on the object.
    tags: list[data.Tag]


class HasAnyTagConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["has_any_tag"] = "has_any_tag"
    # Tags of which at least one must be present on the object.
    tags: list[data.Tag]
class JsonList(BaseConfig):
    # Literal tag identifying the JSON list format.
    name: Literal["json"] = "json"
    # Optional key under which the list lives when the file holds an object.
    field: str | None = None


class TxtList(BaseConfig):
    # Literal tag identifying the plain-text (one value per line) format.
    name: Literal["txt"] = "txt"


class CsvList(BaseConfig):
    # Literal tag identifying the CSV list format.
    name: Literal["csv"] = "csv"
    # Header name of the column holding the values.
    column: str
# Union of the supported list file formats; pydantic selects the member from
# the value of the ``name`` field.
ListFormatConfig = Annotated[
JsonList | TxtList | CsvList,
Field(discriminator="name"),
]
# A list loader reads a file and returns its entries as strings.
ListLoader = Callable[[Path], list[str]]
# Registry the loader classes below register against, keyed by config type.
list_loaders: Registry[ListLoader, []] = Registry("list_loader")
class IdInListConfig(BaseConfig):
    # Literal tag identifying this config variant.
    name: Literal["id_in_list"] = "id_in_list"
    # File holding the list of IDs.
    path: Path
    # How the list file is parsed; JSON unless stated otherwise.
    format: ListFormatConfig = JsonList()

    @model_validator(mode="before")
    @classmethod
    def _normalize_format(cls, values):
        """Expand a bare format name (e.g. ``"txt"``) into a full config.

        Non-dict input and non-string ``format`` values pass through
        unchanged; the incoming dict itself is never mutated.
        """
        if isinstance(values, dict):
            format_config = values.get("format")
            if isinstance(format_config, str):
                config_class = list_loaders.get_config_type(format_config)
                values = values.copy()
                values["format"] = config_class().model_dump()
        return values
class JsonListLoader:
    """Load a list of string values from a JSON file.

    The file must contain a JSON array or, when ``field`` is set, a JSON
    object whose ``field`` entry is an array.
    """

    def __init__(self, field: str | None):
        self.field = field

    def __call__(self, path: Path) -> list[str]:
        """Read ``path`` and return its entries converted to ``str``.

        Raises:
            TypeError: If the document is not an object when ``field`` is
                set, or the selected content is not a list.
            KeyError: If ``field`` is set but missing from the object.
        """
        # Hand the raw bytes to the parser: ``json.loads`` detects
        # UTF-8/16/32 (with or without BOM) itself, whereas ``read_text()``
        # would decode with the platform's locale encoding.
        content = json.loads(path.read_bytes())

        if self.field is not None:
            if not isinstance(content, dict):
                raise TypeError(
                    "Expected JSON object with field for 'id_in_list'."
                )

            if self.field not in content:
                raise KeyError(f"Field '{self.field}' not found in '{path}'.")

            content = content[self.field]

        if not isinstance(content, list):
            raise TypeError("Expected JSON list with IDs for 'id_in_list'.")

        return [str(value) for value in content]

    @list_loaders.register(JsonList)
    @staticmethod
    def from_config(config: JsonList) -> ListLoader:
        """Build the loader from its config."""
        return JsonListLoader(field=config.field)
class TxtListLoader:
    """Load a list of values from a text file, one value per line."""

    def __call__(self, path: Path) -> list[str]:
        """Return the stripped, non-empty lines of ``path``.

        The file is decoded as UTF-8 regardless of the platform locale;
        ``utf-8-sig`` also drops a leading byte-order mark so it does not
        end up inside the first value.
        """
        text = path.read_text(encoding="utf-8-sig")
        return [
            stripped
            for line in text.splitlines()
            if (stripped := line.strip())
        ]

    @list_loaders.register(TxtList)
    @staticmethod
    def from_config(config: TxtList) -> ListLoader:
        """Build the loader; the txt format has no options."""
        return TxtListLoader()
class CsvListLoader:
    """Load a list of values from one column of a CSV file with a header."""

    def __init__(self, column: str):
        self.column = column

    def __call__(self, path: Path) -> list[str]:
        """Return the stripped, non-empty cells of the configured column.

        Raises:
            ValueError: If the file has no header row or the configured
                column is not part of it.
        """
        # Decode as UTF-8 regardless of the platform locale. ``utf-8-sig``
        # drops a leading byte-order mark, which would otherwise be glued to
        # the first header name and make that column unmatchable.
        with path.open("r", newline="", encoding="utf-8-sig") as csv_file:
            reader = csv.DictReader(csv_file)
            fieldnames = reader.fieldnames

            if fieldnames is None:
                raise ValueError(
                    f"Expected CSV header row for 'id_in_list' in '{path}'."
                )

            if self.column not in fieldnames:
                raise ValueError(
                    f"Column '{self.column}' not found in '{path}'."
                )

            # Short rows yield ``None`` for missing cells; skip those.
            cells = (row.get(self.column) for row in reader)
            stripped = (cell.strip() for cell in cells if cell is not None)
            return [value for value in stripped if value]

    @list_loaders.register(CsvList)
    @staticmethod
    def from_config(config: CsvList) -> ListLoader:
        """Build the loader from its config."""
        return CsvListLoader(column=config.column)
def build_list_loader(config: ListFormatConfig) -> ListLoader:
    """Build the list loader described by ``config``."""
    loader = list_loaders.build(config)
    return loader
def register_id_in_list_condition(
    registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
    config_cls: type[IdInListConfig],
) -> None:
    """Register an ``IdInList`` builder for ``config_cls`` in ``registry``."""

    @registry.register(config_cls)
    def builder(
        config: IdInListConfig,
        base_dir: data.PathLike | None = None,
    ) -> Condition[UUIDObject]:
        # Relative list paths are resolved against ``base_dir`` when given.
        path = config.path
        if not (base_dir is None or path.is_absolute()):
            path = Path(base_dir) / path

        values = build_list_loader(config.format)(path)

        ids = set()
        for index, value in enumerate(values):
            try:
                parsed = UUID(value)
            except ValueError as err:
                raise ValueError(
                    f"Invalid ID at index {index} in '{path}': {value!r}."
                ) from err
            ids.add(parsed)

        return IdInList(ids)
def register_has_tag_condition(
    registry: Registry[Condition[TaggedObject], P],
    config_cls: type[HasTagConfig],
) -> None:
    """Register a ``HasTag`` builder for ``config_cls`` in ``registry``."""

    # Extra registry arguments are accepted but not needed by this builder.
    @registry.register(config_cls)
    def builder(
        config: HasTagConfig,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Condition[TaggedObject]:
        return HasTag(config.tag)
def register_has_all_tags_condition(
    registry: Registry[Condition[TaggedObject], P],
    config_cls: type[HasAllTagsConfig],
) -> None:
    """Register a ``HasAllTags`` builder for ``config_cls`` in ``registry``."""

    # Extra registry arguments are accepted but not needed by this builder.
    @registry.register(config_cls)
    def builder(
        config: HasAllTagsConfig,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Condition[TaggedObject]:
        return HasAllTags(config.tags)
def register_has_any_tag_condition(
    registry: Registry[Condition[TaggedObject], P],
    config_cls: type[HasAnyTagConfig],
) -> None:
    """Register a ``HasAnyTag`` builder for ``config_cls`` in ``registry``."""

    # Extra registry arguments are accepted but not needed by this builder.
    @registry.register(config_cls)
    def builder(
        config: HasAnyTagConfig,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Condition[TaggedObject]:
        return HasAnyTag(config.tags)
def register_not_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[NotConfigType]], type[NotConfigType]]:
    """Return a class decorator registering a ``Not`` builder.

    The decorated config class is returned unchanged.
    """

    def decorator(config_cls: type[NotConfigType]) -> type[NotConfigType]:
        def builder(
            config: NotConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            # Build the nested condition through the same registry,
            # forwarding the extra arguments untouched.
            inner = registry.build(config.condition, *args, **kwargs)
            return Not(inner)

        registry.register(config_cls)(builder)
        return config_cls

    return decorator
def register_all_of_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
    """Return a class decorator registering an ``AllOf`` builder.

    The decorated config class is returned unchanged.
    """

    def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
        def builder(
            config: MultiConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            # Build every nested condition through the same registry,
            # forwarding the extra arguments untouched.
            built = []
            for nested in config.conditions:
                built.append(registry.build(nested, *args, **kwargs))
            return AllOf(built)

        registry.register(config_cls)(builder)
        return config_cls

    return decorator
def register_any_of_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
    """Return a class decorator registering an ``AnyOf`` builder.

    The decorated config class is returned unchanged.
    """

    def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
        def builder(
            config: MultiConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            # Build every nested condition through the same registry,
            # forwarding the extra arguments untouched.
            built = []
            for nested in config.conditions:
                built.append(registry.build(nested, *args, **kwargs))
            return AnyOf(built)

        registry.register(config_cls)(builder)
        return config_cls

    return decorator

View File

@ -0,0 +1,217 @@
from collections.abc import Callable, Sequence
from pathlib import Path
from typing import Annotated, Literal
from loguru import logger
from pydantic import Field, model_validator
from soundevent import data
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions.common import (
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
JsonList,
ListFormatConfig,
MultiConditionConfigBase,
NotConditionConfigBase,
build_list_loader,
list_loaders,
register_all_of_condition,
register_any_of_condition,
register_has_all_tags_condition,
register_has_any_tag_condition,
register_has_tag_condition,
register_id_in_list_condition,
register_not_condition,
)
__all__ = [
"IdInListConfig",
"PathInListConfig",
"RecordingAllOfConfig",
"RecordingAnyOfConfig",
"RecordingCondition",
"RecordingConditionConfig",
"RecordingConditionImportConfig",
"RecordingNotConfig",
"build_recording_condition",
]
# A recording condition is any callable mapping a recording to a bool.
RecordingCondition = Callable[[data.Recording], bool]

# Registry holding the builders for recording conditions. Besides the config,
# builders receive one extra argument typed ``data.PathLike | None`` (passed
# as ``base_dir`` by ``build_recording_condition``).
recording_conditions: Registry[RecordingCondition, [data.PathLike | None]] = (
    Registry("recording_condition")
)
@add_import_config(recording_conditions, arg_names=["base_dir"])
class RecordingConditionImportConfig(ImportConfig):
    """Use any callable as a recording condition.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    to use it instead of a built-in option.

    NOTE(review): ``arg_names=["base_dir"]`` presumably names the extra build
    argument handed to the imported target -- confirm against
    ``add_import_config``.
    """

    # Discriminator value selecting this config in
    # ``RecordingConditionConfig``.
    name: Literal["import"] = "import"
# Register the condition types shared across registries (defined in
# ``batdetect2.data.conditions.common``) for recordings.
register_id_in_list_condition(recording_conditions, IdInListConfig)
register_has_tag_condition(recording_conditions, HasTagConfig)
register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
class PathInListConfig(BaseConfig):
    """Configuration for the ``path_in_list`` recording condition.

    The condition keeps recordings whose path appears in the list stored at
    ``path``. ``format`` selects how that list file is read and may be given
    either as a full format config or as a bare format name.
    """

    name: Literal["path_in_list"] = "path_in_list"
    # Location of the file containing the list of paths.
    path: Path
    # How the list file is parsed; defaults to the JSON list loader config.
    format: ListFormatConfig = JsonList()
    # Directory against which absolute paths are made relative before
    # comparison; ``None`` disables that step.
    base_dir: Path | None = None
    # What to do with recordings located outside ``base_dir``.
    on_outside: Literal["allow", "warn", "error"] = "allow"

    @model_validator(mode="before")
    @classmethod
    def _normalize_format(cls, values):
        """Expand a bare format name into that format's default config."""
        format_name = values.get("format") if isinstance(values, dict) else None
        if not isinstance(format_name, str):
            return values

        # Work on a copy so the caller's input mapping is left untouched.
        normalized = values.copy()
        format_cls = list_loaders.get_config_type(format_name)
        normalized["format"] = format_cls().model_dump()
        return normalized
class PathInList:
    """Recording condition that checks a recording's path against a list.

    Paths are compared as ``pathlib.Path`` objects. When ``base_dir`` is set,
    absolute recording paths are first made relative to it; recordings lying
    outside ``base_dir`` are handled according to ``on_outside``.
    """

    def __init__(
        self,
        paths: set[Path],
        base_dir: Path | None,
        on_outside: Literal["allow", "warn", "error"],
    ):
        self.paths = paths
        self.base_dir = base_dir
        self.on_outside = on_outside

    def __call__(self, recording: data.Recording) -> bool:
        """Return whether ``recording`` passes the path filter."""
        candidate = self._normalize_recording_path(recording.path)
        # ``None`` marks a recording outside ``base_dir`` that is let through.
        return candidate is None or candidate in self.paths

    def _normalize_recording_path(self, path: data.PathLike) -> Path | None:
        """Return the path to look up, or ``None`` if it is outside base_dir.

        Raises
        ------
        ValueError
            If the path is outside ``base_dir`` and ``on_outside`` is neither
            ``"allow"`` nor ``"warn"``.
        """
        recording_path = Path(path)

        # Without a base directory, or for already relative paths, the path
        # is compared as given.
        if self.base_dir is None or not recording_path.is_absolute():
            return recording_path

        try:
            return recording_path.relative_to(self.base_dir)
        except ValueError as err:
            if self.on_outside not in ("allow", "warn"):
                raise ValueError(
                    f"Recording path '{recording_path}' is outside "
                    f"'{self.base_dir}' for 'path_in_list'."
                ) from err

            if self.on_outside == "warn":
                logger.warning(
                    "Recording path '{}' is outside '{}' in path_in_list; "
                    "allowing.",
                    recording_path,
                    self.base_dir,
                )

            return None

    @recording_conditions.register(PathInListConfig)
    @staticmethod
    def from_config(
        config: PathInListConfig,
        base_dir: data.PathLike | None = None,
    ) -> "PathInList":
        """Build the condition from its config.

        Relative ``config.path`` and ``config.base_dir`` values are resolved
        against ``base_dir`` when it is given. Listed entries that are
        absolute and inside the resolved base directory are stored relative
        to it; all other entries are stored as given.
        """
        list_path = config.path
        if base_dir is not None and not list_path.is_absolute():
            list_path = Path(base_dir) / list_path

        match_base_dir = config.base_dir
        if (
            base_dir is not None
            and match_base_dir is not None
            and not match_base_dir.is_absolute()
        ):
            match_base_dir = Path(base_dir) / match_base_dir

        loader = build_list_loader(config.format)

        paths: set[Path] = set()
        for value in loader(list_path):
            entry = Path(value)
            if (
                match_base_dir is not None
                and entry.is_absolute()
                and entry.is_relative_to(match_base_dir)
            ):
                entry = entry.relative_to(match_base_dir)
            paths.add(entry)

        return PathInList(
            paths=paths,
            base_dir=match_base_dir,
            on_outside=config.on_outside,
        )
@register_all_of_condition(recording_conditions)
class RecordingAllOfConfig(MultiConditionConfigBase):
    """Configuration for an ``all_of`` combination of recording conditions.

    Registered so that each entry of ``conditions`` is built through the
    recording condition registry and the results are wrapped in ``AllOf``.
    """

    name: Literal["all_of"] = "all_of"
    # Nested recording condition configs (forward reference to the union
    # defined later in this module).
    conditions: Sequence["RecordingConditionConfig"]
@register_any_of_condition(recording_conditions)
class RecordingAnyOfConfig(MultiConditionConfigBase):
    """Configuration for an ``any_of`` combination of recording conditions.

    Registered so that each entry of ``conditions`` is built through the
    recording condition registry and the results are wrapped in ``AnyOf``.
    """

    name: Literal["any_of"] = "any_of"
    # Nested recording condition configs (forward reference to the union
    # defined later in this module).
    conditions: Sequence["RecordingConditionConfig"]
@register_not_condition(recording_conditions)
class RecordingNotConfig(NotConditionConfigBase):
    """Configuration for the ``not`` wrapper around a recording condition.

    Registered so that ``condition`` is built through the recording
    condition registry and the result is wrapped in ``Not``.
    """

    name: Literal["not"] = "not"
    # Nested recording condition config (forward reference to the union
    # defined later in this module).
    condition: "RecordingConditionConfig"
# Union of every recording condition config, discriminated by the ``name``
# field of each member.
RecordingConditionConfig = Annotated[
    IdInListConfig
    | PathInListConfig
    | HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | RecordingAllOfConfig
    | RecordingAnyOfConfig
    | RecordingNotConfig
    | RecordingConditionImportConfig,
    Field(discriminator="name"),
]
def build_recording_condition(
    config: RecordingConditionConfig,
    base_dir: data.PathLike | None = None,
) -> RecordingCondition:
    """Build a recording condition from its configuration.

    Parameters
    ----------
    config : RecordingConditionConfig
        Configuration of the condition to build.
    base_dir : data.PathLike, optional
        Passed on to the registered builder. ``path_in_list`` uses it to
        resolve a relative list path and a relative ``base_dir`` setting.

    Returns
    -------
    RecordingCondition
        Callable that takes a recording and returns a bool.
    """
    condition = recording_conditions.build(config, base_dir)
    return condition

View File

@ -0,0 +1,236 @@
from collections.abc import Callable, Sequence
from typing import Annotated, Literal
from pydantic import Field
from soundevent import data
from soundevent.geometry import compute_bounds
from batdetect2.core.configs import BaseConfig
from batdetect2.core.registries import (
ImportConfig,
Registry,
add_import_config,
)
from batdetect2.data.conditions.common import (
HasAllTagsConfig,
HasAnyTagConfig,
HasTagConfig,
IdInListConfig,
MultiConditionConfigBase,
NotConditionConfigBase,
register_all_of_condition,
register_any_of_condition,
register_has_all_tags_condition,
register_has_any_tag_condition,
register_has_tag_condition,
register_id_in_list_condition,
register_not_condition,
)
__all__ = [
"AllOfConfig",
"AnyOfConfig",
"DurationConfig",
"FrequencyConfig",
"HasAllTagsConfig",
"HasAnyTagConfig",
"HasTagConfig",
"NotConfig",
"Operator",
"SoundEventCondition",
"SoundEventConditionConfig",
"SoundEventConditionImportConfig",
"build_sound_event_condition",
"filter_clip_annotation",
]
# A sound event condition is any callable mapping a sound event annotation
# to a bool.
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

# Registry holding the builders for sound event conditions. Besides the
# config, builders receive one extra argument typed ``data.PathLike | None``
# (passed as ``base_dir`` by ``build_sound_event_condition``).
sound_event_conditions: Registry[
    SoundEventCondition,
    [data.PathLike | None],
] = Registry("sound_event_condition")
@add_import_config(sound_event_conditions, arg_names=["base_dir"])
class SoundEventConditionImportConfig(ImportConfig):
    """Use any callable as a sound event condition.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.

    NOTE(review): ``arg_names=["base_dir"]`` presumably names the extra build
    argument handed to the imported target -- confirm against
    ``add_import_config``.
    """

    # Discriminator value selecting this config in
    # ``SoundEventConditionConfig``.
    name: Literal["import"] = "import"
# Register the condition types shared across registries (defined in
# ``batdetect2.data.conditions.common``) for sound events.
register_has_tag_condition(sound_event_conditions, HasTagConfig)
register_has_all_tags_condition(sound_event_conditions, HasAllTagsConfig)
register_has_any_tag_condition(sound_event_conditions, HasAnyTagConfig)
register_id_in_list_condition(sound_event_conditions, IdInListConfig)

# Comparison operators accepted by the duration and frequency conditions:
# greater than, greater or equal, less than, less or equal, equal.
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
class DurationConfig(BaseConfig):
    """Configuration for the ``duration`` sound event condition.

    The built ``Duration`` condition compares a sound event's duration with
    ``seconds`` using ``operator``.
    """

    name: Literal["duration"] = "duration"
    # Comparison applied as ``duration <operator> seconds``.
    operator: Operator
    # Threshold the duration is compared against.
    seconds: float
def _build_comparator(
operator: Operator, value: float
) -> Callable[[float], bool]:
if operator == "gt":
return lambda x: x > value
if operator == "gte":
return lambda x: x >= value
if operator == "lt":
return lambda x: x < value
if operator == "lte":
return lambda x: x <= value
if operator == "eq":
return lambda x: x == value
raise ValueError(f"Invalid operator {operator}")
class Duration:
    """Sound event condition comparing an annotation's duration.

    The duration is the difference between the end and start time of the
    sound event geometry bounds. Annotations without a geometry never
    satisfy the condition.
    """

    def __init__(self, operator: Operator, seconds: float):
        self.operator = operator
        self.seconds = seconds
        self._comparator = _build_comparator(operator, seconds)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return whether the annotation's duration satisfies the comparison."""
        geometry = sound_event_annotation.sound_event.geometry
        if geometry is None:
            return False

        onset, _low, offset, _high = compute_bounds(geometry)
        return self._comparator(offset - onset)

    @sound_event_conditions.register(DurationConfig)
    @staticmethod
    def from_config(
        config: DurationConfig,
        base_dir: data.PathLike | None = None,
    ):
        """Build the condition from its config; ``base_dir`` is not used."""
        # Accepted only to satisfy the registry's builder signature.
        del base_dir
        return Duration(operator=config.operator, seconds=config.seconds)
class FrequencyConfig(BaseConfig):
    """Configuration for the ``frequency`` sound event condition.

    The built ``Frequency`` condition compares one frequency bound of a
    sound event's geometry with ``hertz`` using ``operator``.
    """

    name: Literal["frequency"] = "frequency"
    # Which bound of the geometry is compared: the low or the high frequency.
    boundary: Literal["low", "high"]
    # Comparison applied as ``bound <operator> hertz``.
    operator: Operator
    # Threshold the selected frequency bound is compared against.
    hertz: float
class Frequency:
    """Sound event condition comparing a frequency bound of an annotation.

    Depending on ``boundary``, either the low or the high frequency of the
    geometry bounds is compared with ``hertz``. Annotations without a
    geometry, or whose geometry is a ``TimeInterval`` or ``TimeStamp``,
    never satisfy the condition.
    """

    def __init__(
        self,
        operator: Operator,
        boundary: Literal["low", "high"],
        hertz: float,
    ):
        self.operator = operator
        self.hertz = hertz
        self.boundary = boundary
        self._comparator = _build_comparator(operator, hertz)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return whether the selected frequency bound satisfies the comparison."""
        geometry = sound_event_annotation.sound_event.geometry
        if geometry is None or isinstance(
            geometry, (data.TimeInterval, data.TimeStamp)
        ):
            return False

        _, lowest, _, highest = compute_bounds(geometry)
        bound = lowest if self.boundary == "low" else highest
        return self._comparator(bound)

    @sound_event_conditions.register(FrequencyConfig)
    @staticmethod
    def from_config(
        config: FrequencyConfig,
        base_dir: data.PathLike | None = None,
    ):
        """Build the condition from its config; ``base_dir`` is not used."""
        # Accepted only to satisfy the registry's builder signature.
        del base_dir
        return Frequency(
            operator=config.operator,
            boundary=config.boundary,
            hertz=config.hertz,
        )
@register_all_of_condition(sound_event_conditions)
class AllOfConfig(MultiConditionConfigBase):
    """Configuration for an ``all_of`` combination of sound event conditions.

    Registered so that each entry of ``conditions`` is built through the
    sound event condition registry and the results are wrapped in ``AllOf``.
    """

    name: Literal["all_of"] = "all_of"
    # Nested sound event condition configs (forward reference to the union
    # defined later in this module).
    conditions: Sequence["SoundEventConditionConfig"]
@register_any_of_condition(sound_event_conditions)
class AnyOfConfig(MultiConditionConfigBase):
    """Configuration for an ``any_of`` combination of sound event conditions.

    Registered so that each entry of ``conditions`` is built through the
    sound event condition registry and the results are wrapped in ``AnyOf``.
    """

    name: Literal["any_of"] = "any_of"
    # Typed as ``Sequence`` (not ``list``) so this field matches
    # ``AllOfConfig`` and the recording-level ``all_of``/``any_of`` configs;
    # the builder only iterates over it.
    conditions: Sequence["SoundEventConditionConfig"]
@register_not_condition(sound_event_conditions)
class NotConfig(NotConditionConfigBase):
    """Configuration for the ``not`` wrapper around a sound event condition.

    Registered so that ``condition`` is built through the sound event
    condition registry and the result is wrapped in ``Not``.
    """

    name: Literal["not"] = "not"
    # Nested sound event condition config (forward reference to the union
    # defined later in this module).
    condition: "SoundEventConditionConfig"
# Union of every sound event condition config, discriminated by the ``name``
# field of each member.
SoundEventConditionConfig = Annotated[
    IdInListConfig
    | HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | DurationConfig
    | FrequencyConfig
    | AllOfConfig
    | AnyOfConfig
    | NotConfig
    | SoundEventConditionImportConfig,
    Field(discriminator="name"),
]
def build_sound_event_condition(
    config: SoundEventConditionConfig,
    base_dir: data.PathLike | None = None,
) -> SoundEventCondition:
    """Build a sound event condition from its configuration.

    Parameters
    ----------
    config : SoundEventConditionConfig
        Configuration of the condition to build.
    base_dir : data.PathLike, optional
        Passed on to the registered builder. The ``duration`` and
        ``frequency`` builders ignore it.

    Returns
    -------
    SoundEventCondition
        Callable that takes a sound event annotation and returns a bool.
    """
    condition = sound_event_conditions.build(config, base_dir)
    return condition
def filter_clip_annotation(
    clip_annotation: data.ClipAnnotation,
    condition: SoundEventCondition,
) -> data.ClipAnnotation:
    """Return a copy of ``clip_annotation`` keeping only matching sound events.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        Clip annotation whose sound events are filtered.
    condition : SoundEventCondition
        Predicate applied to each sound event annotation; those for which it
        returns a truthy value are kept, in their original order.

    Returns
    -------
    data.ClipAnnotation
        Copy created with ``model_copy`` whose ``sound_events`` field is the
        filtered list.
    """
    kept = list(filter(condition, clip_annotation.sound_events))
    return clip_annotation.model_copy(update={"sound_events": kept})

View File

@ -32,7 +32,9 @@ from batdetect2.data.annotations import (
load_annotated_dataset,
)
from batdetect2.data.conditions import (
ClipAnnotationConditionConfig,
SoundEventConditionConfig,
build_clip_annotation_condition,
build_sound_event_condition,
filter_clip_annotation,
)
@ -69,6 +71,7 @@ class DatasetConfig(BaseConfig):
description: str
sources: list[AnnotationFormats]
clip_filter: ClipAnnotationConditionConfig | None = None
sound_event_filter: SoundEventConditionConfig | None = None
sound_event_transforms: list[SoundEventTransformConfig] = Field(
default_factory=list
@ -84,11 +87,58 @@ def load_dataset(
apply_transforms: bool = True,
apply_filters: bool = True,
) -> Dataset:
"""Load all clip annotations from the sources defined in a DatasetConfig."""
"""Load and merge clip annotations from configured dataset sources.
Loads each source listed in ``config.sources`` and returns a flat
collection of ``soundevent.data.ClipAnnotation`` objects. Source tags,
dataset-level filters, and dataset-level transforms can be enabled or
disabled with flags.
Parameters
----------
config : DatasetConfig
Dataset definition containing source configurations, optional
clip-level filter, sound-event filter, and optional sound-event
transform pipeline.
base_dir : data.PathLike, optional
Base directory used to resolve relative paths in source
configurations.
add_source_tag : bool, default=True
If True, append a ``data_source`` tag to each clip annotation with
the source name.
include_sources : list[str], optional
Source names to include. If None, all sources are eligible.
exclude_sources : list[str], optional
Source names to skip after include filtering. If a source appears in
both include and exclude lists, it is skipped.
apply_transforms : bool, default=True
If True, apply transforms defined in
``config.sound_event_transforms``.
apply_filters : bool, default=True
If True, apply filters defined in ``config.clip_filter`` and
``config.sound_event_filter``.
Returns
-------
Dataset
Flat collection of clip annotations loaded from the selected sources.
"""
clip_annotations = []
condition = (
build_sound_event_condition(config.sound_event_filter)
clip_condition = (
build_clip_annotation_condition(
config.clip_filter,
base_dir=base_dir,
)
if config.clip_filter is not None
else None
)
sound_event_condition = (
build_sound_event_condition(
config.sound_event_filter,
base_dir=base_dir,
)
if config.sound_event_filter is not None
else None
)
@ -123,10 +173,17 @@ def load_dataset(
if add_source_tag:
clip_annotation = insert_source_tag(clip_annotation, source)
if condition is not None and apply_filters:
if (
clip_condition is not None
and apply_filters
and not clip_condition(clip_annotation)
):
continue
if sound_event_condition is not None and apply_filters:
clip_annotation = filter_clip_annotation(
clip_annotation,
condition,
sound_event_condition,
)
if transform is not None and apply_transforms:
@ -181,47 +238,58 @@ def load_dataset_from_config(
path: data.PathLike,
field: str | None = None,
base_dir: data.PathLike | None = None,
add_source_tag: bool = True,
include_sources: list[str] | None = None,
exclude_sources: list[str] | None = None,
apply_transforms: bool = True,
apply_filters: bool = True,
) -> Dataset:
"""Load dataset annotation metadata from a configuration file.
"""Load a dataset by reading a ``DatasetConfig`` from disk.
This is a convenience function that first loads the `DatasetConfig` from
the specified file path and optional nested field, and then calls
`load_dataset` to load all corresponding `ClipAnnotation` objects.
This convenience wrapper first loads a ``DatasetConfig`` from ``path``
and optional ``field``, then delegates to :func:`load_dataset`.
Parameters
----------
path : data.PathLike
Path to the configuration file (e.g., YAML).
Path to a configuration file containing a ``DatasetConfig``.
field : str, optional
Dot-separated path to a nested section within the file containing the
dataset configuration (e.g., "data.training_set"). If None, the
entire file content is assumed to be the `DatasetConfig`.
base_dir : Path, optional
An optional base directory path to resolve relative paths within the
configuration sources. Passed to `load_dataset`. Defaults to None.
Dot-separated field path to a nested config section. If None, the
full file is parsed as ``DatasetConfig``.
base_dir : data.PathLike, optional
Base directory used to resolve relative paths in source
configurations.
add_source_tag : bool, default=True
If True, append a ``data_source`` tag to each clip annotation.
include_sources : list[str], optional
Source names to include. If None, all sources are eligible.
exclude_sources : list[str], optional
Source names to skip after include filtering.
apply_transforms : bool, default=True
If True, apply transforms defined in the loaded config.
apply_filters : bool, default=True
If True, apply clip and sound-event filters defined in the loaded
config.
Returns
-------
Dataset (List[data.ClipAnnotation])
A flat list containing all loaded `ClipAnnotation` metadata objects.
Raises
------
FileNotFoundError
If the config file `path` does not exist.
yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
If the configuration file is invalid, cannot be parsed, or does not
match the `DatasetConfig` schema.
Exception
Can raise exceptions from `load_dataset` if loading data from sources
fails.
Dataset
Flat collection of clip annotations loaded from the selected sources.
"""
config = load_config(
path=path,
schema=DatasetConfig,
field=field,
)
return load_dataset(config, base_dir=base_dir)
return load_dataset(
config,
base_dir=base_dir,
add_source_tag=add_source_tag,
include_sources=include_sources,
exclude_sources=exclude_sources,
apply_transforms=apply_transforms,
apply_filters=apply_filters,
)
def save_dataset(

View File

@ -1,9 +1,12 @@
from __future__ import annotations
import io
import sys
from collections.abc import Callable
from functools import partial
from pathlib import Path
from typing import (
TYPE_CHECKING,
Annotated,
Any,
Dict,
@ -13,21 +16,23 @@ from typing import (
TypeVar,
)
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from loguru import logger
from matplotlib.figure import Figure
from pydantic import Field
from soundevent import data
from batdetect2.core.configs import BaseConfig
if TYPE_CHECKING:
import numpy as np
import pandas as pd
from lightning.pytorch.loggers import (
CSVLogger,
Logger,
MLFlowLogger,
TensorBoardLogger,
)
from matplotlib.figure import Figure
from soundevent import data
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
__all__ = [
@ -271,10 +276,16 @@ def build_logger(
)
PlotLogger = Callable[[str, Figure, int], None]
PlotLogger = Callable[[str, "Figure", int], None]
def get_image_logger(logger: Logger) -> PlotLogger | None:
from lightning.pytorch.loggers import (
CSVLogger,
MLFlowLogger,
TensorBoardLogger,
)
if isinstance(logger, TensorBoardLogger):
return logger.experiment.add_figure
@ -296,10 +307,16 @@ def get_image_logger(logger: Logger) -> PlotLogger | None:
return partial(save_figure, dir=Path(logger.log_dir))
TableLogger = Callable[[str, pd.DataFrame, int], None]
TableLogger = Callable[[str, "pd.DataFrame", int], None]
def get_table_logger(logger: Logger) -> TableLogger | None:
from lightning.pytorch.loggers import (
CSVLogger,
MLFlowLogger,
TensorBoardLogger,
)
if isinstance(logger, TensorBoardLogger):
return partial(save_table, dir=Path(logger.log_dir))
@ -337,6 +354,8 @@ def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
import numpy as np
with io.BytesIO() as buff:
figure.savefig(buff, format="raw")
buff.seek(0)

View File

@ -15,11 +15,11 @@ GENERIC_CLASS_KEY = "class"
data_source = data.Term(
name="soundevent:data_source",
label="Data Source",
name="dcterms:source",
label="Source",
uri="http://purl.org/dc/terms/source",
definition=(
"A unique identifier for the source of the data, typically "
"representing the project, site, or deployment context."
"A related resource from which the described resource is derived."
),
)
@ -45,6 +45,17 @@ individual = data.Term(
)
"""Term used for tags identifying a specific individual animal."""
dataset_split = data.Term(
name="batdetect2:split",
label="Dataset Split",
definition=(
"Identifies the specific data partition (e.g., 'train', 'test') "
"that the item belongs to within an experimental setup. "
"The expected value is a literal text string."
),
)
"""Custom metadata term defining the machine learning partition of an item."""
generic_class = data.Term(
name="soundevent:class",
label="Class",

View File

@ -8,12 +8,10 @@ import torch
from soundevent.geometry import compute_bounds
from batdetect2.api_v2 import BatDetect2API
from batdetect2.audio import AudioConfig
from batdetect2.config import BatDetect2Config
from batdetect2.inference import InferenceConfig
from batdetect2.models.detectors import Detector
from batdetect2.models.heads import ClassifierHead
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
from batdetect2.train import load_model_from_checkpoint
from batdetect2.train.lightning import build_training_module
@ -48,6 +46,7 @@ def test_process_file_returns_recording_level_predictions(
)
@pytest.mark.slow
def test_process_files_is_batch_size_invariant(
api_v2: BatDetect2API,
example_audio_files: list[Path],
@ -182,6 +181,7 @@ def test_user_can_read_extracted_features_per_detection(
assert all(vec.size > 0 for vec in feature_vectors)
@pytest.mark.slow
def test_user_can_load_checkpoint_and_finetune(
tmp_path: Path,
example_annotations,
@ -295,6 +295,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
)
@pytest.mark.slow
def test_user_can_finetune_only_heads(
tmp_path: Path,
example_annotations,
@ -330,6 +331,7 @@ def test_user_can_finetune_only_heads(
assert list(finetune_dir.rglob("*.ckpt"))
@pytest.mark.slow
def test_user_can_evaluate_small_dataset_and_get_metrics(
api_v2: BatDetect2API,
example_annotations,
@ -416,6 +418,7 @@ def test_detection_threshold_override_changes_process_file_results(
)
@pytest.mark.slow
def test_detection_threshold_override_is_ephemeral_in_process_file(
api_v2: BatDetect2API,
example_audio_files: list[Path],
@ -452,51 +455,3 @@ def test_detection_threshold_override_changes_spectrogram_results(
)
assert len(strict_detections) <= len(default_detections)
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
"""User story: call-level overrides do not mutate resolved defaults."""
api = BatDetect2API.from_config(BatDetect2Config())
override_inference = InferenceConfig.model_validate(
{"loader": {"batch_size": 7}}
)
override_audio = AudioConfig.model_validate({"samplerate": 384000})
override_train = TrainingConfig.model_validate(
{"trainer": {"max_epochs": 2}}
)
captured_process: dict[str, object] = {}
captured_train: dict[str, object] = {}
def fake_process_file_list(*args, **kwargs):
captured_process["inference_config"] = kwargs["inference_config"]
captured_process["audio_config"] = kwargs["audio_config"]
return []
def fake_run_train(*args, **kwargs):
captured_train["train_config"] = kwargs["train_config"]
captured_train["audio_config"] = kwargs["audio_config"]
captured_train["model_config"] = kwargs["model_config"]
return None
monkeypatch.setattr(
"batdetect2.api_v2.process_file_list", fake_process_file_list
)
monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train)
api.process_files(
[], inference_config=override_inference, audio_config=override_audio
)
api.train([], train_config=override_train, audio_config=override_audio)
assert captured_process["inference_config"] is override_inference
assert captured_process["audio_config"] is override_audio
assert captured_train["train_config"] is override_train
assert captured_train["audio_config"] is override_audio
assert captured_train["model_config"] is api.model_config
assert api.inference_config.loader.batch_size != 7
assert api.audio_config.samplerate != 384000
assert api.train_config.trainer.max_epochs != 2

View File

@ -1,4 +1,5 @@
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from hypothesis import given, settings
@ -10,6 +11,7 @@ from batdetect2.utils import audio_utils, detector_utils
@given(duration=st.floats(min_value=0.1, max_value=1))
@settings(deadline=None)
@pytest.mark.slow
def test_can_compute_correct_spectrogram_width(duration: float):
samplerate = parameters.TARGET_SAMPLERATE_HZ
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
@ -89,6 +91,7 @@ def test_pad_audio_without_fixed_size(duration: float):
@given(duration=st.floats(min_value=0.1, max_value=2))
@settings(deadline=None)
@pytest.mark.slow
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
duration: float,
):

View File

@ -3,11 +3,13 @@
from pathlib import Path
import pandas as pd
import pytest
from click.testing import CliRunner
from batdetect2.cli import cli
@pytest.mark.slow
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
"""User story: run legacy detect on example audio directory."""
@ -29,6 +31,7 @@ def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
assert len(list(results_dir.glob("*.json"))) == 3
@pytest.mark.slow
def test_cli_detect_command_with_non_trivial_time_expansion(
tmp_path: Path,
) -> None:
@ -52,6 +55,7 @@ def test_cli_detect_command_with_non_trivial_time_expansion(
assert "Time Expansion Factor: 10" in result.stdout
@pytest.mark.slow
def test_cli_detect_command_with_spec_feature_flag(tmp_path: Path) -> None:
"""User story: request extra spectral features in output CSV."""

View File

@ -20,6 +20,7 @@ def test_cli_predict_help() -> None:
assert "dataset" in result.output
@pytest.mark.slow
def test_cli_predict_directory_runs_on_real_audio(
tmp_path: Path,
tiny_checkpoint_path: Path,

View File

@ -2,6 +2,7 @@
from pathlib import Path
import pytest
from click.testing import CliRunner
from batdetect2.cli import cli
@ -19,6 +20,7 @@ def test_cli_train_help() -> None:
assert "--model" in result.output
@pytest.mark.slow
def test_cli_train_from_checkpoint_runs_on_small_dataset(
tmp_path: Path,
tiny_checkpoint_path: Path,

View File

@ -2,11 +2,13 @@
from pathlib import Path
import pytest
from click.testing import CliRunner
from batdetect2.cli import cli
runner = CliRunner()
pytestmark = pytest.mark.slow
def test_can_process_jeff37_files(

View File

@ -0,0 +1,303 @@
import json
import textwrap
import uuid
from pathlib import Path
from pydantic import TypeAdapter
from soundevent import data
from batdetect2.core import load_config
from batdetect2.data.conditions import (
ClipAnnotationConditionConfig,
build_clip_annotation_condition,
)
def load_clip_condition_config(
    tmp_path: Path,
    yaml_string: str,
) -> ClipAnnotationConditionConfig:
    """Write ``yaml_string`` to a uniquely named file and load it as a config.

    The text is dedented and stripped before being written into ``tmp_path``,
    then parsed with ``load_config`` against the clip annotation condition
    schema.
    """
    yaml_text = textwrap.dedent(yaml_string).strip()
    config_file = tmp_path / f"{uuid.uuid4().hex}.yaml"
    config_file.write_text(yaml_text)
    schema = TypeAdapter(ClipAnnotationConditionConfig)
    return load_config(config_file, schema=schema)
def build_clip_condition_from_yaml(
    tmp_path: Path,
    yaml_string: str,
    base_dir: Path | None = None,
):
    """Build a clip annotation condition directly from a YAML snippet."""
    return build_clip_annotation_condition(
        load_clip_condition_config(tmp_path, yaml_string),
        base_dir=base_dir,
    )
def test_recording_satisfies_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``recording_satisfies`` filters clips by a condition on their recording."""
    listed_recording = create_recording(path=tmp_path / "a.wav")
    other_recording = create_recording(path=tmp_path / "b.wav")
    listed = create_clip_annotation(create_clip(listed_recording))
    other = create_clip_annotation(create_clip(other_recording))

    ids_file = tmp_path / "recording_ids.json"
    ids_file.write_text(json.dumps([str(listed_recording.uuid)]))

    condition = build_clip_condition_from_yaml(
        tmp_path,
        f"""
        name: recording_satisfies
        condition:
          name: id_in_list
          path: {ids_file}
        """,
    )

    assert condition(listed)
    assert not condition(other)
def test_clip_id_in_list_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``id_in_list`` matches clip annotations by their own UUID."""
    listed = create_clip_annotation(
        create_clip(create_recording(path=tmp_path / "a.wav"))
    )
    other = create_clip_annotation(
        create_clip(create_recording(path=tmp_path / "b.wav"))
    )

    ids_file = tmp_path / "clip_annotation_ids.json"
    ids_file.write_text(json.dumps([str(listed.uuid)]))

    condition = build_clip_condition_from_yaml(
        tmp_path,
        f"""
        name: id_in_list
        path: {ids_file}
        """,
    )

    assert condition(listed)
    assert not condition(other)
def test_clip_has_tag_conditions(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``has_tag`` passes only clip annotations carrying the given tag."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    split_train = data.Tag(key="split", value="train")

    recording = create_recording(path=tmp_path / "rec.wav")
    tagged = create_clip_annotation(
        create_clip(recording),
        clip_tags=[status_reviewed, split_train],
    )
    untagged = create_clip_annotation(
        create_clip(recording),
        clip_tags=[split_train],
    )

    condition = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: has_tag
        tag:
          key: status
          value: reviewed
        """,
    )

    assert condition(tagged)
    assert not condition(untagged)
def test_clip_has_all_tags_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``has_all_tags`` requires every listed tag to be present."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    split_train = data.Tag(key="split", value="train")

    recording = create_recording(path=tmp_path / "rec.wav")
    complete = create_clip_annotation(
        create_clip(recording),
        clip_tags=[status_reviewed, split_train],
    )
    partial = create_clip_annotation(
        create_clip(recording),
        clip_tags=[status_reviewed],
    )

    condition = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: has_all_tags
        tags:
          - key: status
            value: reviewed
          - key: split
            value: train
        """,
    )

    assert condition(complete)
    assert not condition(partial)
def test_clip_has_any_tag_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``has_any_tag`` passes when at least one listed tag is present."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    split_train = data.Tag(key="split", value="train")
    split_test = data.Tag(key="split", value="test")

    recording = create_recording(path=tmp_path / "rec.wav")
    matching = create_clip_annotation(
        create_clip(recording),
        clip_tags=[status_reviewed, split_train],
    )
    non_matching = create_clip_annotation(
        create_clip(recording),
        clip_tags=[split_test],
    )

    condition = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: has_any_tag
        tags:
          - key: split
            value: val
          - key: split
            value: train
        """,
    )

    assert condition(matching)
    assert not condition(non_matching)
def test_clip_all_of_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``all_of`` passes only when every nested condition passes."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    split_train = data.Tag(key="split", value="train")

    recording = create_recording(path=tmp_path / "rec.wav")
    satisfies_both = create_clip_annotation(
        create_clip(recording),
        clip_tags=[status_reviewed, split_train],
    )
    satisfies_one = create_clip_annotation(
        create_clip(recording),
        clip_tags=[status_reviewed],
    )

    condition = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: all_of
        conditions:
          - name: has_tag
            tag:
              key: status
              value: reviewed
          - name: has_any_tag
            tags:
              - key: split
                value: train
              - key: split
                value: val
        """,
    )

    assert condition(satisfies_both)
    assert not condition(satisfies_one)
def test_clip_any_of_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``any_of`` passes when at least one nested condition passes."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    status_unchecked = data.Tag(key="status", value="unchecked")

    recording = create_recording(path=tmp_path / "rec.wav")
    matching = create_clip_annotation(
        create_clip(recording),
        clip_tags=[status_reviewed],
    )
    non_matching = create_clip_annotation(
        create_clip(recording),
        clip_tags=[status_unchecked],
    )

    condition = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: any_of
        conditions:
          - name: has_tag
            tag:
              key: split
              value: val
          - name: has_tag
            tag:
              key: status
              value: reviewed
        """,
    )

    assert condition(matching)
    assert not condition(non_matching)
def test_clip_not_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``not`` inverts the result of the nested condition."""
    split_train = data.Tag(key="split", value="train")
    split_val = data.Tag(key="split", value="val")

    recording = create_recording(path=tmp_path / "rec.wav")
    train_annotation = create_clip_annotation(
        create_clip(recording),
        clip_tags=[split_train],
    )
    val_annotation = create_clip_annotation(
        create_clip(recording),
        clip_tags=[split_val],
    )

    condition = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: not
        condition:
          name: has_tag
          tag:
            key: split
            value: val
        """,
    )

    assert condition(train_annotation)
    assert not condition(val_annotation)

View File

@ -0,0 +1,564 @@
import json
import textwrap
import uuid
from pathlib import Path
import pytest
from pydantic import TypeAdapter
from soundevent import data
from batdetect2.core import load_config
from batdetect2.data.conditions import (
RecordingConditionConfig,
build_recording_condition,
)
def load_recording_condition_config(
    tmp_path: Path,
    yaml_string: str,
) -> RecordingConditionConfig:
    """Write ``yaml_string`` to a uniquely named file and load it as a config.

    The text is dedented and stripped before being written into ``tmp_path``,
    then parsed with ``load_config`` against the recording condition schema.
    """
    yaml_text = textwrap.dedent(yaml_string).strip()
    config_file = tmp_path / f"{uuid.uuid4().hex}.yaml"
    config_file.write_text(yaml_text)
    schema = TypeAdapter(RecordingConditionConfig)
    return load_config(config_file, schema=schema)
def build_recording_condition_from_yaml(
    tmp_path: Path,
    yaml_string: str,
    base_dir: Path | None = None,
):
    """Build a recording condition directly from a YAML snippet."""
    return build_recording_condition(
        load_recording_condition_config(tmp_path, yaml_string),
        base_dir=base_dir,
    )
def test_id_in_list_condition(tmp_path: Path, create_recording) -> None:
    """Only the recording whose UUID is listed satisfies ``id_in_list``."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    ids_path = tmp_path / "recording_ids.json"
    ids_path.write_text(json.dumps([str(listed.uuid)]))

    yaml_config = f"""
        name: id_in_list
        path: {ids_path}
        """
    in_list = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_list(listed)
    assert not in_list(unlisted)
def test_id_in_list_condition_uses_base_dir(
    tmp_path: Path,
    create_recording,
) -> None:
    """A relative list path works when ``base_dir`` is supplied."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    splits = tmp_path / "splits"
    splits.mkdir()
    ids_file = splits / "train_ids.json"
    ids_file.write_text(json.dumps([str(listed.uuid)]))

    # The config only names the file relative to ``tmp_path``.
    yaml_config = """
        name: id_in_list
        path: splits/train_ids.json
        """
    in_list = build_recording_condition_from_yaml(
        tmp_path,
        yaml_config,
        base_dir=tmp_path,
    )

    assert in_list(listed)
    assert not in_list(unlisted)
def test_id_in_list_condition_raises_for_non_list_json(
    tmp_path: Path,
) -> None:
    """A JSON file holding an object instead of a list raises ``TypeError``."""
    ids_path = tmp_path / "recording_ids.json"
    ids_path.write_text(json.dumps({"id": "foo"}))

    yaml_config = f"""
        name: id_in_list
        path: {ids_path}
        """
    with pytest.raises(TypeError, match="Expected JSON list"):
        build_recording_condition_from_yaml(tmp_path, yaml_config)
def test_id_in_list_condition_raises_for_invalid_id(tmp_path: Path) -> None:
    """A list entry that is not a UUID raises ``ValueError``."""
    ids_path = tmp_path / "recording_ids.json"
    ids_path.write_text(json.dumps(["not-a-uuid"]))

    yaml_config = f"""
        name: id_in_list
        path: {ids_path}
        """
    with pytest.raises(ValueError, match="Invalid ID"):
        build_recording_condition_from_yaml(tmp_path, yaml_config)
def test_id_in_list_condition_supports_txt_format(
    tmp_path: Path,
    create_recording,
) -> None:
    """IDs can be read from a plain text file, one per line."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    ids_path = tmp_path / "recording_ids.txt"
    ids_path.write_text(f"{listed.uuid}\n")

    yaml_config = f"""
        name: id_in_list
        path: {ids_path}
        format: txt
        """
    in_list = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_list(listed)
    assert not in_list(unlisted)
def test_id_in_list_condition_supports_json_field(
    tmp_path: Path,
    create_recording,
) -> None:
    """IDs can be taken from one named field of a JSON object."""
    train_recording = create_recording(path=tmp_path / "a.wav")
    val_recording = create_recording(path=tmp_path / "b.wav")

    splits = {
        "train": [str(train_recording.uuid)],
        "val": [str(val_recording.uuid)],
    }
    ids_path = tmp_path / "recording_ids.json"
    ids_path.write_text(json.dumps(splits))

    # Only the ``train`` field is selected, so the val recording is excluded.
    yaml_config = f"""
        name: id_in_list
        path: {ids_path}
        format:
          name: json
          field: train
        """
    in_train = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_train(train_recording)
    assert not in_train(val_recording)
def test_id_in_list_condition_supports_csv_column(
    tmp_path: Path,
    create_recording,
) -> None:
    """IDs can be taken from a named column of a CSV file."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    ids_path = tmp_path / "recording_ids.csv"
    # Header row followed by a single data row.
    ids_path.write_text(f"recording_uuid\n{listed.uuid}\n")

    yaml_config = f"""
        name: id_in_list
        path: {ids_path}
        format:
          name: csv
          column: recording_uuid
        """
    in_list = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_list(listed)
    assert not in_list(unlisted)
def test_path_in_list_condition_supports_txt_format(
    tmp_path: Path,
    create_recording,
) -> None:
    """Recording paths can be read from a plain text file."""
    audio_root = tmp_path / "audio"
    audio_root.mkdir()
    listed = create_recording(path=audio_root / "a.wav")
    unlisted = create_recording(path=audio_root / "b.wav")

    paths_file = tmp_path / "recording_paths.txt"
    paths_file.write_text(f"{listed.path}\n")

    yaml_config = f"""
        name: path_in_list
        path: {paths_file}
        format: txt
        """
    in_list = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_list(listed)
    assert not in_list(unlisted)
def test_path_in_list_condition_supports_json_field(
    tmp_path: Path,
    create_recording,
) -> None:
    """Recording paths can be taken from one field of a JSON object."""
    audio_root = tmp_path / "audio"
    audio_root.mkdir()
    train_recording = create_recording(path=audio_root / "a.wav")
    val_recording = create_recording(path=audio_root / "b.wav")

    splits = {
        "train": [str(train_recording.path)],
        "val": [str(val_recording.path)],
    }
    paths_file = tmp_path / "recording_paths.json"
    paths_file.write_text(json.dumps(splits))

    # Only the ``train`` field is selected, so the val recording is excluded.
    yaml_config = f"""
        name: path_in_list
        path: {paths_file}
        format:
          name: json
          field: train
        """
    in_train = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_train(train_recording)
    assert not in_train(val_recording)
def test_path_in_list_condition_supports_csv_column(
    tmp_path: Path,
    create_recording,
) -> None:
    """Recording paths can be taken from a named column of a CSV file."""
    audio_root = tmp_path / "audio"
    audio_root.mkdir()
    listed = create_recording(path=audio_root / "a.wav")
    unlisted = create_recording(path=audio_root / "b.wav")

    paths_file = tmp_path / "recording_paths.csv"
    # Header row followed by a single data row.
    paths_file.write_text(f"recording_path\n{listed.path}\n")

    yaml_config = f"""
        name: path_in_list
        path: {paths_file}
        format:
          name: csv
          column: recording_path
        """
    in_list = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_list(listed)
    assert not in_list(unlisted)
def test_path_in_list_condition_uses_base_dir(
    tmp_path: Path,
    create_recording,
) -> None:
    """``path_in_list`` still matches when a ``base_dir`` is configured."""
    data_dir = tmp_path / "dataset"
    audio_root = data_dir / "audio"
    audio_root.mkdir(parents=True)
    listed = create_recording(path=audio_root / "a.wav")
    unlisted = create_recording(path=audio_root / "b.wav")

    paths_file = tmp_path / "recording_paths.txt"
    paths_file.write_text(f"{listed.path}\n")

    yaml_config = f"""
        name: path_in_list
        path: {paths_file}
        format: txt
        base_dir: {data_dir}
        """
    in_list = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_list(listed)
    assert not in_list(unlisted)
def test_path_in_list_condition_outside_allow(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: allow`` a recording outside ``base_dir`` passes."""
    data_dir = tmp_path / "dataset"
    audio_inside = data_dir / "audio"
    audio_inside.mkdir(parents=True)
    audio_outside = tmp_path / "other"
    audio_outside.mkdir()

    recording_inside = create_recording(path=audio_inside / "a.wav")
    recording_outside = create_recording(path=audio_outside / "x.wav")

    # The listed path matches neither recording.
    paths_file = tmp_path / "recording_paths.txt"
    paths_file.write_text("dataset/audio/unknown.wav\n")

    yaml_config = f"""
        name: path_in_list
        path: {paths_file}
        format: txt
        base_dir: {data_dir}
        on_outside: allow
        """
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(recording_outside)
    assert not condition(recording_inside)
def test_path_in_list_condition_outside_warn(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: warn`` a recording outside ``base_dir`` passes."""
    # NOTE(review): only the resulting truth values are checked here; the
    # warning itself is not asserted.
    data_dir = tmp_path / "dataset"
    audio_inside = data_dir / "audio"
    audio_inside.mkdir(parents=True)
    audio_outside = tmp_path / "other"
    audio_outside.mkdir()

    recording_inside = create_recording(path=audio_inside / "a.wav")
    recording_outside = create_recording(path=audio_outside / "x.wav")

    # The listed path matches neither recording.
    paths_file = tmp_path / "recording_paths.txt"
    paths_file.write_text("dataset/audio/unknown.wav\n")

    yaml_config = f"""
        name: path_in_list
        path: {paths_file}
        format: txt
        base_dir: {data_dir}
        on_outside: warn
        """
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(recording_outside)
    assert not condition(recording_inside)
def test_path_in_list_condition_outside_error(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: error`` a recording outside ``base_dir`` raises."""
    data_dir = tmp_path / "dataset"
    audio_inside = data_dir / "audio"
    audio_inside.mkdir(parents=True)
    audio_outside = tmp_path / "other"
    audio_outside.mkdir()

    recording_inside = create_recording(path=audio_inside / "a.wav")
    recording_outside = create_recording(path=audio_outside / "x.wav")

    paths_file = tmp_path / "recording_paths.txt"
    paths_file.write_text(f"{recording_inside.path}\n")

    yaml_config = f"""
        name: path_in_list
        path: {paths_file}
        format: txt
        base_dir: {data_dir}
        on_outside: error
        """
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    # The listed recording inside ``base_dir`` is evaluated normally.
    assert condition(recording_inside)
    with pytest.raises(ValueError, match="outside"):
        condition(recording_outside)
def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
    """``has_tag`` matches only recordings carrying the exact tag."""
    train_recording = create_recording(
        path=tmp_path / "a.wav",
        tags=[data.Tag(key="split", value="train")],
    )
    val_recording = create_recording(
        path=tmp_path / "b.wav",
        tags=[data.Tag(key="split", value="val")],
    )

    yaml_config = """
        name: has_tag
        tag:
          key: split
          value: train
        """
    is_train = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert is_train(train_recording)
    assert not is_train(val_recording)
def test_has_all_tags_condition(tmp_path: Path, create_recording) -> None:
    """``has_all_tags`` requires every configured tag to be present."""
    split_train = data.Tag(key="split", value="train")
    region_uk = data.Tag(key="region", value="uk")

    fully_tagged = create_recording(
        path=tmp_path / "a.wav",
        tags=[split_train, region_uk],
    )
    # Carries only one of the two required tags.
    partly_tagged = create_recording(
        path=tmp_path / "b.wav",
        tags=[split_train],
    )

    yaml_config = """
        name: has_all_tags
        tags:
          - key: split
            value: train
          - key: region
            value: uk
        """
    has_both = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert has_both(fully_tagged)
    assert not has_both(partly_tagged)
def test_has_any_tag_condition(tmp_path: Path, create_recording) -> None:
    """``has_any_tag`` matches when at least one configured tag is present."""
    uk_recording = create_recording(
        path=tmp_path / "a.wav",
        tags=[data.Tag(key="region", value="uk")],
    )
    us_recording = create_recording(
        path=tmp_path / "b.wav",
        tags=[data.Tag(key="region", value="us")],
    )

    yaml_config = """
        name: has_any_tag
        tags:
          - key: region
            value: eu
          - key: region
            value: uk
        """
    in_eu_or_uk = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert in_eu_or_uk(uk_recording)
    assert not in_eu_or_uk(us_recording)
def test_all_of_condition(tmp_path: Path, create_recording) -> None:
    """``all_of`` passes only when every nested condition passes."""
    split_train = data.Tag(key="split", value="train")

    train_uk = create_recording(
        path=tmp_path / "a.wav",
        tags=[split_train, data.Tag(key="region", value="uk")],
    )
    # Satisfies the split condition but not the region condition.
    train_us = create_recording(
        path=tmp_path / "b.wav",
        tags=[split_train, data.Tag(key="region", value="us")],
    )

    yaml_config = """
        name: all_of
        conditions:
          - name: has_tag
            tag:
              key: split
              value: train
          - name: has_any_tag
            tags:
              - key: region
                value: eu
              - key: region
                value: uk
        """
    combined = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert combined(train_uk)
    assert not combined(train_us)
def test_any_of_condition(tmp_path: Path, create_recording) -> None:
    """``any_of`` passes when at least one nested condition passes."""
    train_recording = create_recording(
        path=tmp_path / "a.wav",
        tags=[data.Tag(key="split", value="train")],
    )
    # Matches neither nested condition.
    us_recording = create_recording(
        path=tmp_path / "b.wav",
        tags=[data.Tag(key="region", value="us")],
    )

    yaml_config = """
        name: any_of
        conditions:
          - name: has_tag
            tag:
              key: region
              value: eu
          - name: has_tag
            tag:
              key: split
              value: train
        """
    combined = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert combined(train_recording)
    assert not combined(us_recording)
def test_not_condition(tmp_path: Path, create_recording) -> None:
    """``not`` inverts the result of the wrapped recording condition."""
    uk_recording = create_recording(
        path=tmp_path / "a.wav",
        tags=[data.Tag(key="region", value="uk")],
    )
    us_recording = create_recording(
        path=tmp_path / "b.wav",
        tags=[data.Tag(key="region", value="us")],
    )

    yaml_config = """
        name: not
        condition:
          name: has_tag
          tag:
            key: region
            value: us
        """
    not_us = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert not_us(uk_recording)
    assert not not_us(us_recording)

View File

@ -0,0 +1,400 @@
import json
import textwrap
import uuid
from pathlib import Path
import pytest
from pydantic import TypeAdapter
from soundevent import data
from batdetect2.core import load_config
from batdetect2.data.conditions import (
SoundEventConditionConfig,
build_sound_event_condition,
)
def load_sound_event_condition_config(
    tmp_path: Path,
    yaml_string: str,
) -> SoundEventConditionConfig:
    """Write a YAML snippet to a fresh file and load it as a config.

    The snippet is dedented and stripped, written to a uniquely named
    ``.yaml`` file inside ``tmp_path``, and then parsed with ``load_config``
    against the ``SoundEventConditionConfig`` schema.
    """
    yaml_text = textwrap.dedent(yaml_string).strip()
    # A random file name lets one test write several configs without clashes.
    target = tmp_path / f"{uuid.uuid4().hex}.yaml"
    target.write_text(yaml_text)
    adapter = TypeAdapter(SoundEventConditionConfig)
    return load_config(target, schema=adapter)
def build_condition_from_str(
    tmp_path: Path,
    yaml_string: str,
    base_dir: Path | None = None,
):
    """Build a sound event condition from an inline YAML snippet.

    The snippet is loaded through ``load_sound_event_condition_config`` and
    the resulting config is handed to ``build_sound_event_condition``
    together with ``base_dir``.
    """
    return build_sound_event_condition(
        load_sound_event_condition_config(tmp_path, yaml_string),
        base_dir=base_dir,
    )
def create_sound_event_annotation(
    recording: data.Recording,
    geometry: data.Geometry,
    tags: list[data.Tag] | None = None,
) -> data.SoundEventAnnotation:
    """Wrap a geometry in a sound event annotation on ``recording``.

    ``tags`` defaults to an empty list when omitted.
    """
    event = data.SoundEvent(recording=recording, geometry=geometry)
    annotation_tags = tags if tags else []
    return data.SoundEventAnnotation(sound_event=event, tags=annotation_tags)
def test_has_tag_condition(
    sound_event: data.SoundEvent, tmp_path: Path
) -> None:
    """``has_tag`` matches only annotations carrying the exact tag."""
    yaml_config = """
        name: has_tag
        tag:
          key: species
          value: Myotis myotis
        """
    is_myotis = build_condition_from_str(tmp_path, yaml_config)

    def annotate(species: str) -> data.SoundEventAnnotation:
        # Annotation of the shared sound event with a single species tag.
        return data.SoundEventAnnotation(
            sound_event=sound_event,
            tags=[data.Tag(key="species", value=species)],
        )

    matching = annotate("Myotis myotis")
    other = annotate("Eptesicus fuscus")

    assert is_myotis(matching)
    assert not is_myotis(other)
def test_has_all_tags_condition(
    sound_event: data.SoundEvent,
    tmp_path: Path,
) -> None:
    """``has_all_tags`` requires every configured tag to be present."""
    yaml_config = """
        name: has_all_tags
        tags:
          - key: species
            value: Myotis myotis
          - key: event
            value: Echolocation
        """
    has_both = build_condition_from_str(tmp_path, yaml_config)

    species_tag = data.Tag(key="species", value="Myotis myotis")
    event_tag = data.Tag(key="event", value="Echolocation")
    fully_tagged = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[species_tag, event_tag],
    )
    # Carries only one of the two required tags.
    partly_tagged = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[data.Tag(key="species", value="Myotis myotis")],
    )

    assert has_both(fully_tagged)
    assert not has_both(partly_tagged)
def test_has_any_tag_condition(
    sound_event: data.SoundEvent,
    tmp_path: Path,
) -> None:
    """``has_any_tag`` matches when at least one configured tag is present."""
    yaml_config = """
        name: has_any_tag
        tags:
          - key: species
            value: Myotis myotis
          - key: event
            value: Echolocation
        """
    has_either = build_condition_from_str(tmp_path, yaml_config)

    one_match = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[data.Tag(key="event", value="Echolocation")],
    )
    # Same keys as the configured tags, but neither value matches.
    no_match_tags = [
        data.Tag(key="species", value="Eptesicus fuscus"),
        data.Tag(key="event", value="Social"),
    ]
    no_match = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=no_match_tags,
    )

    assert has_either(one_match)
    assert not has_either(no_match)
def test_not_condition(sound_event: data.SoundEvent, tmp_path: Path) -> None:
    """``not`` inverts the result of the wrapped sound event condition."""
    yaml_config = """
        name: not
        condition:
          name: has_tag
          tag:
            key: species
            value: Myotis myotis
        """
    not_myotis = build_condition_from_str(tmp_path, yaml_config)

    def annotate(species: str) -> data.SoundEventAnnotation:
        # Annotation of the shared sound event with a single species tag.
        return data.SoundEventAnnotation(
            sound_event=sound_event,
            tags=[data.Tag(key="species", value=species)],
        )

    other_species = annotate("Eptesicus fuscus")
    myotis = annotate("Myotis myotis")

    assert not_myotis(other_species)
    assert not not_myotis(myotis)
def test_id_in_list_condition(
    sound_event: data.SoundEvent, tmp_path: Path
) -> None:
    """Only the annotation whose UUID is listed satisfies ``id_in_list``."""
    # Both annotations wrap the same sound event; the test relies on their
    # ``uuid`` values being different.
    listed = data.SoundEventAnnotation(sound_event=sound_event)
    unlisted = data.SoundEventAnnotation(sound_event=sound_event)

    ids_path = tmp_path / "sound_event_ids.json"
    ids_path.write_text(json.dumps([str(listed.uuid)]))

    yaml_config = f"""
        name: id_in_list
        path: {ids_path}
        """
    in_list = build_condition_from_str(tmp_path, yaml_config)

    assert in_list(listed)
    assert not in_list(unlisted)
def test_id_in_list_condition_uses_base_dir(
    sound_event: data.SoundEvent,
    tmp_path: Path,
) -> None:
    """A relative list path works when ``base_dir`` is supplied."""
    listed = data.SoundEventAnnotation(sound_event=sound_event)
    unlisted = data.SoundEventAnnotation(sound_event=sound_event)

    splits = tmp_path / "splits"
    splits.mkdir()
    ids_file = splits / "sound_event_ids.json"
    ids_file.write_text(json.dumps([str(listed.uuid)]))

    # The config only names the file relative to ``tmp_path``.
    yaml_config = """
        name: id_in_list
        path: splits/sound_event_ids.json
        """
    in_list = build_condition_from_str(
        tmp_path,
        yaml_config,
        base_dir=tmp_path,
    )

    assert in_list(listed)
    assert not in_list(unlisted)
@pytest.mark.parametrize(
    ("operator", "seconds", "passing_duration", "failing_duration"),
    [
        ("lt", 2, 1, 2),
        ("lte", 2, 2, 3),
        ("gt", 2, 3, 2),
        ("gte", 2, 2, 1),
        ("eq", 2, 2, 3),
    ],
)
def test_duration_condition(
    tmp_path: Path,
    recording: data.Recording,
    operator: str,
    seconds: int,
    passing_duration: int,
    failing_duration: int,
) -> None:
    """Each duration operator accepts one duration and rejects another."""
    yaml_config = f"""
        name: duration
        operator: {operator}
        seconds: {seconds}
        """
    condition = build_condition_from_str(tmp_path, yaml_config)

    def lasting(duration: int) -> data.SoundEventAnnotation:
        # Time interval starting at 0 and ending at ``duration``.
        return create_sound_event_annotation(
            recording=recording,
            geometry=data.TimeInterval(coordinates=[0, duration]),
        )

    passing = lasting(passing_duration)
    failing = lasting(failing_duration)

    assert condition(passing)
    assert not condition(failing)
@pytest.mark.parametrize(
    ("boundary", "operator", "hertz", "passing_bbox", "failing_bbox"),
    [
        ("high", "lt", 300, [0, 100, 1, 200], [0, 100, 1, 300]),
        ("high", "lte", 300, [0, 100, 1, 300], [0, 100, 1, 400]),
        ("high", "gt", 300, [0, 100, 1, 400], [0, 100, 1, 300]),
        ("high", "gte", 300, [0, 100, 1, 300], [0, 100, 1, 200]),
        ("high", "eq", 300, [0, 100, 1, 300], [0, 100, 1, 400]),
        ("low", "lt", 200, [0, 100, 1, 400], [0, 200, 1, 400]),
        ("low", "lte", 200, [0, 200, 1, 400], [0, 300, 1, 400]),
        ("low", "gt", 200, [0, 300, 1, 400], [0, 200, 1, 400]),
        ("low", "gte", 200, [0, 200, 1, 400], [0, 100, 1, 400]),
        ("low", "eq", 200, [0, 200, 1, 400], [0, 300, 1, 400]),
    ],
)
def test_frequency_condition(
    tmp_path: Path,
    recording: data.Recording,
    boundary: str,
    operator: str,
    hertz: int,
    passing_bbox: list[int],
    failing_bbox: list[int],
) -> None:
    """Each boundary/operator pair accepts one box and rejects another."""
    yaml_config = f"""
        name: frequency
        boundary: {boundary}
        operator: {operator}
        hertz: {hertz}
        """
    condition = build_condition_from_str(tmp_path, yaml_config)

    def boxed(bbox: list[int]) -> data.SoundEventAnnotation:
        # The integer corners are converted to floats before building the box.
        return create_sound_event_annotation(
            recording=recording,
            geometry=data.BoundingBox(coordinates=list(map(float, bbox))),
        )

    passing = boxed(passing_bbox)
    failing = boxed(failing_bbox)

    assert condition(passing)
    assert not condition(failing)
def test_frequency_condition_is_false_for_temporal_geometries(
    tmp_path: Path,
    recording: data.Recording,
) -> None:
    """Frequency conditions reject purely temporal geometries.

    A bounding box whose low edge equals the configured value passes, while
    both a ``TimeInterval`` and a ``TimeStamp`` fail.
    """
    condition = build_condition_from_str(
        tmp_path,
        """
        name: frequency
        boundary: low
        operator: eq
        hertz: 200
        """,
    )

    passing = create_sound_event_annotation(
        recording=recording,
        geometry=data.BoundingBox(coordinates=[0, 200, 1, 400]),
    )
    assert condition(passing)

    # Cover every temporal geometry kind, not just intervals.
    temporal_geometries = [
        data.TimeInterval(coordinates=[0, 3]),
        data.TimeStamp(coordinates=3),
    ]
    for geometry in temporal_geometries:
        failing = create_sound_event_annotation(
            recording=recording,
            geometry=geometry,
        )
        assert not condition(failing)
def test_has_all_tags_fails_if_empty(tmp_path: Path) -> None:
    """An empty ``tags`` list is rejected when building ``has_all_tags``."""
    yaml_config = """
        name: has_all_tags
        tags: []
        """
    with pytest.raises(ValueError, match="at least one tag"):
        build_condition_from_str(tmp_path, yaml_config)
def test_all_of_condition(tmp_path: Path, recording: data.Recording) -> None:
    """``all_of`` passes only when every nested condition passes."""
    yaml_config = """
        name: all_of
        conditions:
          - name: has_tag
            tag:
              key: species
              value: Myotis myotis
          - name: duration
            operator: lt
            seconds: 1
        """
    combined = build_condition_from_str(tmp_path, yaml_config)

    def myotis_lasting(end: float) -> data.SoundEventAnnotation:
        # Myotis-tagged event spanning the interval ``[0, end]``.
        return create_sound_event_annotation(
            recording=recording,
            geometry=data.TimeInterval(coordinates=[0, end]),
            tags=[data.Tag(key="species", value="Myotis myotis")],
        )

    short_call = myotis_lasting(0.5)
    # Right species, but too long for the duration condition.
    long_call = myotis_lasting(2)

    assert combined(short_call)
    assert not combined(long_call)
def test_any_of_condition(tmp_path: Path, recording: data.Recording) -> None:
    """``any_of`` passes when at least one nested condition passes."""
    yaml_config = """
        name: any_of
        conditions:
          - name: has_tag
            tag:
              key: species
              value: Myotis myotis
          - name: duration
            operator: lt
            seconds: 1
        """
    combined = build_condition_from_str(tmp_path, yaml_config)

    def long_call(species: str) -> data.SoundEventAnnotation:
        # Every event spans [0, 2], so the duration branch never matches.
        return create_sound_event_annotation(
            recording=recording,
            geometry=data.TimeInterval(coordinates=[0, 2]),
            tags=[data.Tag(key="species", value=species)],
        )

    matching_species = long_call("Myotis myotis")
    other_species = long_call("Eptesicus fuscus")

    assert combined(matching_species)
    assert not combined(other_species)

View File

@ -0,0 +1,100 @@
import json
from pathlib import Path
from soundevent import data
from batdetect2.data import DatasetConfig, load_dataset
from batdetect2.data.conditions import (
HasTagConfig,
IdInListConfig,
RecordingSatisfiesConfig,
)
def test_load_dataset_applies_clip_filter(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """A configured ``clip_filter`` restricts the loaded clip annotations."""
    unfiltered = list(load_dataset(example_dataset))
    keep_recording_id = str(unfiltered[0].clip.recording.uuid)

    ids_path = tmp_path / "train_ids.json"
    ids_path.write_text(json.dumps([keep_recording_id]))

    # Keep only clips whose recording ID appears in the JSON list.
    clip_filter = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=ids_path),
    )
    config = example_dataset.model_copy(update={"clip_filter": clip_filter})

    filtered = list(load_dataset(config))

    assert len(filtered) == 1
    assert str(filtered[0].clip.recording.uuid) == keep_recording_id
def test_load_dataset_clip_filter_is_skipped_when_filters_disabled(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """``apply_filters=False`` bypasses the configured ``clip_filter``."""
    unfiltered = list(load_dataset(example_dataset))
    keep_recording_id = str(unfiltered[0].clip.recording.uuid)

    ids_path = tmp_path / "train_ids.json"
    ids_path.write_text(json.dumps([keep_recording_id]))

    clip_filter = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=ids_path),
    )
    config = example_dataset.model_copy(update={"clip_filter": clip_filter})

    loaded = list(load_dataset(config, apply_filters=False))

    # Nothing is dropped even though the filter lists only one recording.
    assert len(loaded) == len(unfiltered)
def test_load_dataset_resolves_clip_filter_paths_from_base_dir(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """A relative filter path works when ``base_dir`` is passed."""
    unfiltered = list(load_dataset(example_dataset))
    keep_recording_id = str(unfiltered[0].clip.recording.uuid)

    splits = tmp_path / "splits"
    splits.mkdir()
    ids_file = splits / "train_ids.json"
    ids_file.write_text(json.dumps([keep_recording_id]))

    # The filter names the ID file relative to ``tmp_path``.
    clip_filter = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=Path("splits/train_ids.json")),
    )
    config = example_dataset.model_copy(update={"clip_filter": clip_filter})

    filtered = list(load_dataset(config, base_dir=tmp_path))

    assert len(filtered) == 1
    assert str(filtered[0].clip.recording.uuid) == keep_recording_id
def test_sound_event_filter_keeps_empty_clips(
    example_dataset: DatasetConfig,
) -> None:
    """Clips survive a sound event filter that removes all their events.

    The filter requires a species tag that no annotation carries, so every
    clip annotation should still be returned, with no sound events left.
    """
    # Compare against the unfiltered dataset rather than a hard-coded count,
    # so the test does not depend on the size of the example fixture.
    baseline = list(load_dataset(example_dataset))
    assert baseline

    config = example_dataset.model_copy(
        update={
            "sound_event_filter": HasTagConfig(
                tag=data.Tag(key="species", value="__missing_species__")
            )
        }
    )
    filtered = list(load_dataset(config))

    assert len(filtered) == len(baseline)
    assert all(
        len(clip_annotation.sound_events) == 0 for clip_annotation in filtered
    )

View File

@ -1,491 +0,0 @@
import textwrap
import pytest
import yaml
from pydantic import TypeAdapter
from soundevent import data
from batdetect2.data.conditions import (
SoundEventConditionConfig,
build_sound_event_condition,
)
def build_condition_from_str(content):
    """Parse a YAML snippet and build the sound event condition it describes.

    The text is dedented, parsed with ``yaml.safe_load``, validated against
    ``SoundEventConditionConfig`` and passed to
    ``build_sound_event_condition``.
    """
    raw_config = yaml.safe_load(textwrap.dedent(content))
    adapter = TypeAdapter(SoundEventConditionConfig)
    return build_sound_event_condition(adapter.validate_python(raw_config))
def test_has_tag(sound_event: data.SoundEvent):
    """``has_tag`` matches only annotations carrying the exact tag."""
    is_myotis = build_condition_from_str("""
        name: has_tag
        tag:
          key: species
          value: Myotis myotis
        """)

    def annotate(species):
        # Annotation of the shared sound event with a single species tag.
        return data.SoundEventAnnotation(
            sound_event=sound_event,
            tags=[data.Tag(key="species", value=species)],
        )

    assert is_myotis(annotate("Myotis myotis"))
    assert not is_myotis(annotate("Eptesicus fuscus"))
def test_has_all_tags(sound_event: data.SoundEvent):
    """``has_all_tags`` requires every configured tag; extras are allowed."""
    has_both = build_condition_from_str("""
        name: has_all_tags
        tags:
          - key: species
            value: Myotis myotis
          - key: event
            value: Echolocation
        """)

    def annotate(*pairs):
        # Annotation of the shared sound event with the given (key, value) tags.
        return data.SoundEventAnnotation(
            sound_event=sound_event,
            tags=[data.Tag(key=key, value=value) for key, value in pairs],
        )

    # One required tag is missing.
    assert not has_both(annotate(("species", "Myotis myotis")))

    # Right event, wrong species.
    assert not has_both(
        annotate(("species", "Eptesicus fuscus"), ("event", "Echolocation"))
    )

    # Exactly the required tags.
    assert has_both(
        annotate(("species", "Myotis myotis"), ("event", "Echolocation"))
    )

    # Required tags plus an unrelated one.
    assert has_both(
        annotate(
            ("species", "Myotis myotis"),
            ("event", "Echolocation"),
            ("sex", "Female"),
        )
    )
def test_has_any_tags(sound_event: data.SoundEvent):
    """``has_any_tag`` matches when at least one configured tag is present."""
    has_either = build_condition_from_str("""
        name: has_any_tag
        tags:
          - key: species
            value: Myotis myotis
          - key: event
            value: Echolocation
        """)

    def annotate(*pairs):
        # Annotation of the shared sound event with the given (key, value) tags.
        return data.SoundEventAnnotation(
            sound_event=sound_event,
            tags=[data.Tag(key=key, value=value) for key, value in pairs],
        )

    # Species matches.
    assert has_either(annotate(("species", "Myotis myotis")))

    # Only the event matches.
    assert has_either(
        annotate(("species", "Eptesicus fuscus"), ("event", "Echolocation"))
    )

    # Both match.
    assert has_either(
        annotate(("species", "Myotis myotis"), ("event", "Echolocation"))
    )

    # Neither matches.
    assert not has_either(
        annotate(("species", "Eptesicus fuscus"), ("event", "Social"))
    )
def test_not(sound_event: data.SoundEvent):
    """``not`` inverts the result of the wrapped ``has_tag`` condition."""
    not_myotis = build_condition_from_str("""
        name: not
        condition:
          name: has_tag
          tag:
            key: species
            value: Myotis myotis
        """)

    def annotate(*pairs):
        # Annotation of the shared sound event with the given (key, value) tags.
        return data.SoundEventAnnotation(
            sound_event=sound_event,
            tags=[data.Tag(key=key, value=value) for key, value in pairs],
        )

    assert not not_myotis(annotate(("species", "Myotis myotis")))
    assert not_myotis(annotate(("species", "Eptesicus fuscus")))
    # An extra tag does not hide the excluded species tag.
    assert not not_myotis(
        annotate(("species", "Myotis myotis"), ("event", "Echolocation"))
    )
def test_duration(recording: data.Recording):
    """Check every duration operator against intervals ending at 1, 2 and 3."""

    def lasting(end):
        # Time interval annotation starting at 0 and ending at ``end``.
        return data.SoundEventAnnotation(
            sound_event=data.SoundEvent(
                recording=recording,
                geometry=data.TimeInterval(coordinates=[0, end]),
            ),
        )

    annotations = [lasting(end) for end in (1, 2, 3)]

    # Each case pairs a config with the expected outcome for the three
    # annotations above, in order.
    cases = [
        (
            """
            name: duration
            operator: lt
            seconds: 2
            """,
            (True, False, False),
        ),
        (
            """
            name: duration
            operator: lte
            seconds: 2
            """,
            (True, True, False),
        ),
        (
            """
            name: duration
            operator: gt
            seconds: 2
            """,
            (False, False, True),
        ),
        (
            """
            name: duration
            operator: gte
            seconds: 2
            """,
            (False, True, True),
        ),
        (
            """
            name: duration
            operator: eq
            seconds: 2
            """,
            (False, True, False),
        ),
    ]

    for yaml_config, expected in cases:
        condition = build_condition_from_str(yaml_config)
        for annotation, should_pass in zip(annotations, expected):
            assert bool(condition(annotation)) is should_pass
def test_frequency(recording: data.Recording):
    """Check every frequency operator on both the high and low boundary."""

    def boxed(coordinates):
        # Bounding box annotation with the given corner coordinates.
        return data.SoundEventAnnotation(
            sound_event=data.SoundEvent(
                recording=recording,
                geometry=data.BoundingBox(coordinates=coordinates),
            ),
        )

    se12 = boxed([0, 100, 1, 200])
    se13 = boxed([0, 100, 2, 300])
    se14 = boxed([0, 100, 3, 400])
    se24 = boxed([0, 200, 3, 400])
    se34 = boxed([0, 300, 3, 400])

    # The ``high`` cases vary the upper edge, the ``low`` cases the lower one.
    varying_high = (se12, se13, se14)
    varying_low = (se14, se24, se34)

    # Each case: config, annotations to evaluate, expected outcome per
    # annotation (in order).
    cases = [
        (
            """
            name: frequency
            boundary: high
            operator: lt
            hertz: 300
            """,
            varying_high,
            (True, False, False),
        ),
        (
            """
            name: frequency
            boundary: high
            operator: lte
            hertz: 300
            """,
            varying_high,
            (True, True, False),
        ),
        (
            """
            name: frequency
            boundary: high
            operator: gt
            hertz: 300
            """,
            varying_high,
            (False, False, True),
        ),
        (
            """
            name: frequency
            boundary: high
            operator: gte
            hertz: 300
            """,
            varying_high,
            (False, True, True),
        ),
        (
            """
            name: frequency
            boundary: high
            operator: eq
            hertz: 300
            """,
            varying_high,
            (False, True, False),
        ),
        # LOW
        (
            """
            name: frequency
            boundary: low
            operator: lt
            hertz: 200
            """,
            varying_low,
            (True, False, False),
        ),
        (
            """
            name: frequency
            boundary: low
            operator: lte
            hertz: 200
            """,
            varying_low,
            (True, True, False),
        ),
        (
            """
            name: frequency
            boundary: low
            operator: gt
            hertz: 200
            """,
            varying_low,
            (False, False, True),
        ),
        (
            """
            name: frequency
            boundary: low
            operator: gte
            hertz: 200
            """,
            varying_low,
            (False, True, True),
        ),
        (
            """
            name: frequency
            boundary: low
            operator: eq
            hertz: 200
            """,
            varying_low,
            (False, True, False),
        ),
    ]

    for yaml_config, annotations, expected in cases:
        condition = build_condition_from_str(yaml_config)
        for annotation, should_pass in zip(annotations, expected):
            assert bool(condition(annotation)) is should_pass
def test_frequency_is_false_for_temporal_geometries(recording: data.Recording):
    """Frequency conditions reject time intervals and time stamps."""
    condition = build_condition_from_str("""
        name: frequency
        boundary: low
        operator: eq
        hertz: 200
        """)

    temporal_geometries = (
        data.TimeInterval(coordinates=[0, 3]),
        data.TimeStamp(coordinates=3),
    )
    for geometry in temporal_geometries:
        annotation = data.SoundEventAnnotation(
            sound_event=data.SoundEvent(
                recording=recording,
                geometry=geometry,
            )
        )
        assert not condition(annotation)
def test_has_tags_fails_if_empty():
    """Building ``has_all_tags`` with an empty tag list raises ``ValueError``."""
    # The condition name must be one that exists (``has_all_tags``, as used by
    # ``test_has_all_tags``). NOTE(review): an unknown name such as
    # ``has_tags`` presumably also fails validation with a ``ValueError``
    # subclass, which would let this test pass without ever reaching the
    # empty-list check - confirm against the config schema.
    with pytest.raises(ValueError):
        build_condition_from_str("""
            name: has_all_tags
            tags: []
            """)
def test_all_of(recording: data.Recording):
    """``all_of`` passes only when both the tag and the duration match."""
    combined = build_condition_from_str("""
        name: all_of
        conditions:
          - name: has_tag
            tag:
              key: species
              value: Myotis myotis
          - name: duration
            operator: lt
            seconds: 1
        """)

    def annotate(end, species):
        # Interval [0, end] tagged with the given species.
        return data.SoundEventAnnotation(
            sound_event=data.SoundEvent(
                recording=recording,
                geometry=data.TimeInterval(coordinates=[0, end]),
            ),
            tags=[data.Tag(key="species", value=species)],
        )

    # Both nested conditions hold.
    assert combined(annotate(0.5, "Myotis myotis"))
    # Too long.
    assert not combined(annotate(2, "Myotis myotis"))
    # Wrong species.
    assert not combined(annotate(0.5, "Eptesicus fuscus"))
def test_any_of(recording: data.Recording):
    """``any_of`` passes when the tag or the duration condition matches."""
    combined = build_condition_from_str("""
        name: any_of
        conditions:
          - name: has_tag
            tag:
              key: species
              value: Myotis myotis
          - name: duration
            operator: lt
            seconds: 1
        """)

    def annotate(end, species):
        # Interval [0, end] tagged with the given species.
        return data.SoundEventAnnotation(
            sound_event=data.SoundEvent(
                recording=recording,
                geometry=data.TimeInterval(coordinates=[0, end]),
            ),
            tags=[data.Tag(key="species", value=species)],
        )

    # Neither nested condition holds.
    assert not combined(annotate(2, "Eptesicus fuscus"))
    # Both hold.
    assert combined(annotate(0.5, "Myotis myotis"))
    # Only the species matches.
    assert combined(annotate(2, "Myotis myotis"))
    # Only the duration matches.
    assert combined(annotate(0.5, "Eptesicus fuscus"))

View File

@ -2,11 +2,14 @@
import os
import pytest
from batdetect2 import api
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
@pytest.mark.slow
def test_no_detections_above_nyquist():
"""Test that no detections are made above the nyquist frequency."""
# Recording donated by @@kdarras

View File

@ -4,6 +4,7 @@ from pathlib import Path
from typing import List
import numpy as np
import pytest
from hypothesis import given, settings
from hypothesis import strategies as st
@ -13,6 +14,7 @@ from batdetect2.detector import parameters
@settings(deadline=None, max_examples=5)
@given(duration=st.floats(min_value=0.1, max_value=2))
@pytest.mark.slow
def test_can_import_model_without_pickle(duration: float):
# NOTE: remove this test once no other issues are found. This is a temporary
# test to check that change in model loading did not impact model behaviour
@ -42,6 +44,7 @@ def test_can_import_model_without_pickle(duration: float):
assert predictions_without_pickle == predictions_with_pickle
@pytest.mark.slow
def test_can_import_model_without_pickle_on_test_data(
example_audio_files: List[Path],
):

View File

@ -96,7 +96,7 @@ def test_registry_build_unknown_name_raises():
name = "NonExistentBackbone"
with pytest.raises(NotImplementedError):
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
backbone_registry.build(FakeConfig()) # ty: ignore[invalid-argument-type]
def test_backbone_config_validates_unet_from_dict():

View File

@ -1,7 +1,6 @@
from collections.abc import Callable
from pathlib import Path
import pytest
from soundevent import data, terms
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets

View File

@ -1,10 +1,13 @@
from pathlib import Path
import pytest
from soundevent import data
from batdetect2.config import BatDetect2Config
from batdetect2.train import run_train
pytestmark = pytest.mark.slow
def _build_fast_train_config() -> BatDetect2Config:
config = BatDetect2Config()

View File

@ -37,6 +37,7 @@ def test_can_initialize_default_module():
assert isinstance(module, L.LightningModule)
@pytest.mark.slow
def test_can_save_checkpoint(
tmp_path: Path,
clip: data.Clip,
@ -182,6 +183,7 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
assert api.audio_config.samplerate == module.model_config.samplerate
@pytest.mark.slow
def test_train_smoke_produces_loadable_checkpoint(
tmp_path: Path,
example_annotations: list[data.ClipAnnotation],