mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Compare commits
10 Commits
591d4f4ae8
...
4303d4e42d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4303d4e42d | ||
|
|
da113eaea8 | ||
|
|
1579bbc6c5 | ||
|
|
c67d9cbba0 | ||
|
|
00961132a9 | ||
|
|
e04d86808d | ||
|
|
c8dd4155bf | ||
|
|
e80fe8675d | ||
|
|
c24056214c | ||
|
|
6d09133dca |
10
justfile
10
justfile
@ -20,7 +20,15 @@ install:
|
||||
# Testing & Coverage
|
||||
# Run tests using pytest.
|
||||
test:
|
||||
uv run pytest -n auto {{TESTS_DIR}}
|
||||
uv run pytest {{TESTS_DIR}}
|
||||
|
||||
# Run the fast subset of tests (excludes @pytest.mark.slow).
|
||||
test-quick:
|
||||
uv run pytest --durations=10 -m "not slow" {{TESTS_DIR}}
|
||||
|
||||
# Run only long-running tests marked with @pytest.mark.slow.
|
||||
test-slow:
|
||||
uv run pytest -m "slow" {{TESTS_DIR}}
|
||||
|
||||
# Run tests and generate coverage data.
|
||||
coverage:
|
||||
|
||||
@ -95,6 +95,9 @@ mlflow = ["mlflow>=3.1.1"]
|
||||
gradio = [
|
||||
"gradio>=6.9.0",
|
||||
]
|
||||
dvc = [
|
||||
"dvclive>=3.49.0",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 79
|
||||
@ -126,3 +129,8 @@ exclude = [
|
||||
"src/batdetect2/finetune",
|
||||
"src/batdetect2/utils",
|
||||
]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
markers = [
|
||||
"slow: marks long-running tests that are skipped in quick runs",
|
||||
]
|
||||
|
||||
@ -1,68 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal, Sequence, cast
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from soundevent import data
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import Dataset, load_dataset_from_config
|
||||
from batdetect2.evaluate import (
|
||||
DEFAULT_EVAL_DIR,
|
||||
EvaluationConfig,
|
||||
EvaluatorProtocol,
|
||||
build_evaluator,
|
||||
run_evaluate,
|
||||
save_evaluation_results,
|
||||
)
|
||||
from batdetect2.inference import (
|
||||
InferenceConfig,
|
||||
process_file_list,
|
||||
run_batch_inference,
|
||||
)
|
||||
from batdetect2.logging import (
|
||||
DEFAULT_LOGS_DIR,
|
||||
AppLoggingConfig,
|
||||
LoggerConfig,
|
||||
)
|
||||
from batdetect2.models import (
|
||||
Model,
|
||||
ModelConfig,
|
||||
build_model,
|
||||
build_model_with_new_targets,
|
||||
)
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
get_output_formatter,
|
||||
)
|
||||
from batdetect2.postprocess import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
PostprocessorProtocol,
|
||||
build_postprocessor,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
TargetConfig,
|
||||
TargetProtocol,
|
||||
build_roi_mapping,
|
||||
build_targets,
|
||||
)
|
||||
from batdetect2.train import (
|
||||
DEFAULT_CHECKPOINT_DIR,
|
||||
TrainingConfig,
|
||||
load_model_from_checkpoint,
|
||||
run_train,
|
||||
)
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import torch
|
||||
|
||||
from batdetect2.audio import AudioConfig, AudioLoader
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.data import Dataset
|
||||
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig, LoggerConfig
|
||||
from batdetect2.models import Model, ModelConfig
|
||||
from batdetect2.outputs import (
|
||||
OutputFormatConfig,
|
||||
OutputFormatterProtocol,
|
||||
OutputsConfig,
|
||||
OutputTransformProtocol,
|
||||
)
|
||||
from batdetect2.postprocess import (
|
||||
ClipDetections,
|
||||
Detection,
|
||||
PostprocessorProtocol,
|
||||
)
|
||||
from batdetect2.preprocess import PreprocessorProtocol
|
||||
from batdetect2.targets import (
|
||||
ROIMapperProtocol,
|
||||
TargetConfig,
|
||||
TargetProtocol,
|
||||
)
|
||||
from batdetect2.train import TrainingConfig
|
||||
|
||||
|
||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
||||
|
||||
|
||||
class BatDetect2API:
|
||||
@ -109,6 +87,8 @@ class BatDetect2API:
|
||||
path: data.PathLike,
|
||||
base_dir: data.PathLike | None = None,
|
||||
) -> Dataset:
|
||||
from batdetect2.data import load_dataset_from_config
|
||||
|
||||
return load_dataset_from_config(path, base_dir=base_dir)
|
||||
|
||||
def train(
|
||||
@ -128,6 +108,8 @@ class BatDetect2API:
|
||||
train_config: TrainingConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
):
|
||||
from batdetect2.train import run_train
|
||||
|
||||
run_train(
|
||||
train_annotations=train_annotations,
|
||||
val_annotations=val_annotations,
|
||||
@ -172,6 +154,7 @@ class BatDetect2API:
|
||||
logger_config: LoggerConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
"""Fine-tune the model with trainable-parameter selection."""
|
||||
from batdetect2.train import run_train
|
||||
|
||||
self._set_trainable_parameters(trainable)
|
||||
|
||||
@ -211,6 +194,8 @@ class BatDetect2API:
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logger_config: LoggerConfig | None = None,
|
||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||
from batdetect2.evaluate import run_evaluate
|
||||
|
||||
return run_evaluate(
|
||||
self.model,
|
||||
test_annotations,
|
||||
@ -235,6 +220,8 @@ class BatDetect2API:
|
||||
predictions: Sequence[ClipDetections],
|
||||
output_dir: data.PathLike | None = None,
|
||||
):
|
||||
from batdetect2.evaluate import save_evaluation_results
|
||||
|
||||
clip_evals = self.evaluator.evaluate(
|
||||
annotations,
|
||||
predictions,
|
||||
@ -307,6 +294,8 @@ class BatDetect2API:
|
||||
self,
|
||||
audio: np.ndarray,
|
||||
) -> torch.Tensor:
|
||||
import torch
|
||||
|
||||
tensor = torch.tensor(audio).unsqueeze(0)
|
||||
return self.preprocessor(tensor)
|
||||
|
||||
@ -316,6 +305,8 @@ class BatDetect2API:
|
||||
batch_size: int | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> ClipDetections:
|
||||
from batdetect2.postprocess import ClipDetections
|
||||
|
||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||
|
||||
predictions = self.process_files(
|
||||
@ -382,6 +373,8 @@ class BatDetect2API:
|
||||
audio_dir: data.PathLike,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
files = list(get_audio_files(audio_dir))
|
||||
return self.process_files(
|
||||
files,
|
||||
@ -398,6 +391,8 @@ class BatDetect2API:
|
||||
output_config: OutputsConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
from batdetect2.inference import process_file_list
|
||||
|
||||
return process_file_list(
|
||||
self.model,
|
||||
audio_files,
|
||||
@ -424,6 +419,8 @@ class BatDetect2API:
|
||||
output_config: OutputsConfig | None = None,
|
||||
detection_threshold: float | None = None,
|
||||
) -> list[ClipDetections]:
|
||||
from batdetect2.inference import run_batch_inference
|
||||
|
||||
return run_batch_inference(
|
||||
self.model,
|
||||
clips,
|
||||
@ -448,6 +445,8 @@ class BatDetect2API:
|
||||
format: str | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
):
|
||||
from batdetect2.outputs import get_output_formatter
|
||||
|
||||
formatter = self.formatter
|
||||
|
||||
if format is not None or config is not None:
|
||||
@ -467,6 +466,8 @@ class BatDetect2API:
|
||||
format: str | None = None,
|
||||
config: OutputFormatConfig | None = None,
|
||||
) -> list[object]:
|
||||
from batdetect2.outputs import get_output_formatter
|
||||
|
||||
formatter = self.formatter
|
||||
|
||||
if format is not None or config is not None:
|
||||
@ -484,6 +485,17 @@ class BatDetect2API:
|
||||
cls,
|
||||
config: BatDetect2Config,
|
||||
) -> "BatDetect2API":
|
||||
from batdetect2.audio import build_audio_loader
|
||||
from batdetect2.evaluate import build_evaluator
|
||||
from batdetect2.models import build_model
|
||||
from batdetect2.outputs import (
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
|
||||
targets = build_targets(config=config.model.targets)
|
||||
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
||||
|
||||
@ -563,6 +575,21 @@ class BatDetect2API:
|
||||
outputs_config: OutputsConfig | None = None,
|
||||
logging_config: AppLoggingConfig | None = None,
|
||||
) -> "BatDetect2API":
|
||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.logging import AppLoggingConfig
|
||||
from batdetect2.models import build_model_with_new_targets
|
||||
from batdetect2.outputs import (
|
||||
OutputsConfig,
|
||||
build_output_formatter,
|
||||
build_output_transform,
|
||||
)
|
||||
from batdetect2.postprocess import build_postprocessor
|
||||
from batdetect2.preprocess import build_preprocessor
|
||||
from batdetect2.targets import build_roi_mapping, build_targets
|
||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||
|
||||
model, model_config = load_model_from_checkpoint(path)
|
||||
|
||||
audio_config = audio_config or AudioConfig(
|
||||
@ -645,7 +672,7 @@ class BatDetect2API:
|
||||
self,
|
||||
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||
) -> None:
|
||||
detector = cast(Detector, self.model.detector)
|
||||
detector = self.model.detector
|
||||
|
||||
for parameter in detector.parameters():
|
||||
parameter.requires_grad = False
|
||||
|
||||
@ -2,10 +2,6 @@
|
||||
|
||||
import click
|
||||
|
||||
from batdetect2.logging import enable_logging
|
||||
|
||||
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
||||
|
||||
__all__ = [
|
||||
"cli",
|
||||
]
|
||||
@ -34,5 +30,7 @@ def cli(verbose: int = 0):
|
||||
"""
|
||||
click.echo(INFO_STR)
|
||||
|
||||
from batdetect2.logging import enable_logging
|
||||
|
||||
enable_logging(verbose)
|
||||
# click.echo(BATDETECT_ASCII_ART)
|
||||
|
||||
@ -73,7 +73,7 @@ def summary(
|
||||
|
||||
summary = compute_class_summary(dataset, targets)
|
||||
|
||||
print(summary.to_markdown())
|
||||
print(summary.sort_values("class_name").to_markdown())
|
||||
|
||||
|
||||
@data.command(short_help="Convert dataset config to annotation set.")
|
||||
|
||||
@ -4,8 +4,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import click
|
||||
from loguru import logger
|
||||
from soundevent import io
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
from batdetect2.cli.base import cli
|
||||
|
||||
@ -219,6 +217,8 @@ def predict_directory_command(
|
||||
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
|
||||
inference, and saves predictions to `output_path`.
|
||||
"""
|
||||
from soundevent.audio.files import get_audio_files
|
||||
|
||||
audio_files = list(get_audio_files(audio_dir))
|
||||
_run_prediction(
|
||||
model_path=model_path,
|
||||
@ -309,6 +309,8 @@ def predict_dataset_command(
|
||||
The dataset is read as a soundevent annotation set and unique recording
|
||||
paths are extracted before inference.
|
||||
"""
|
||||
from soundevent import io
|
||||
|
||||
dataset_path = Path(dataset_path)
|
||||
dataset = io.load(dataset_path, type="annotation_set")
|
||||
audio_files = sorted(
|
||||
|
||||
@ -1,312 +0,0 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Annotated, List, Literal, Sequence
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
|
||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||
|
||||
conditions: Registry[SoundEventCondition, []] = Registry("condition")
|
||||
|
||||
|
||||
# NOTE(review): `add_import_config` presumably wires this config into the
# `conditions` registry so an arbitrary callable can act as a condition —
# confirm in batdetect2.core.registries.
@add_import_config(conditions)
class SoundEventConditionImportConfig(ImportConfig):
    """Use any callable as a sound event condition.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    # Fixed literal identifying this config type.
    name: Literal["import"] = "import"
|
||||
|
||||
|
||||
# Settings for the single-tag condition; consumed by `HasTag.from_config`.
class HasTagConfig(BaseConfig):
    # Fixed literal; `SoundEventConditionConfig` discriminates on this field.
    name: Literal["has_tag"] = "has_tag"
    # Tag an annotation must carry for the condition to hold.
    tag: data.Tag
|
||||
|
||||
|
||||
class HasTag:
    """Condition that holds when an annotation carries a specific tag.

    An annotation tag matches when both its term name and its value are
    equal to those of the reference tag.
    """

    def __init__(self, tag: data.Tag):
        self.tag = tag

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if any tag of the annotation matches ``self.tag``."""
        for candidate in sound_event_annotation.tags:
            same_term = self.tag.term.name == candidate.term.name
            if same_term and self.tag.value == candidate.value:
                return True
        return False

    @conditions.register(HasTagConfig)
    @staticmethod
    def from_config(config: HasTagConfig):
        """Build the condition from a ``HasTagConfig``."""
        return HasTag(tag=config.tag)
|
||||
|
||||
|
||||
# Settings for the all-tags condition; consumed by `HasAllTags.from_config`.
class HasAllTagsConfig(BaseConfig):
    # Fixed literal; `SoundEventConditionConfig` discriminates on this field.
    name: Literal["has_all_tags"] = "has_all_tags"
    # Tags that must all be present. `HasAllTags` rejects an empty list.
    tags: List[data.Tag]
|
||||
|
||||
|
||||
class HasAllTags:
    """Condition that holds when an annotation carries every listed tag.

    Tags are compared by their ``(term name, value)`` pair.
    """

    def __init__(self, tags: List[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        required = set()
        for tag in tags:
            required.add((tag.term.name, tag.value))
        self.tags = required

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if all required tags are present on the annotation."""
        present = {
            (tag.term.name, tag.value) for tag in sound_event_annotation.tags
        }
        return self.tags <= present

    @conditions.register(HasAllTagsConfig)
    @staticmethod
    def from_config(config: HasAllTagsConfig):
        """Build the condition from a ``HasAllTagsConfig``."""
        return HasAllTags(tags=config.tags)
|
||||
|
||||
|
||||
# Settings for the any-tag condition; consumed by `HasAnyTag.from_config`.
class HasAnyTagConfig(BaseConfig):
    # Fixed literal; `SoundEventConditionConfig` discriminates on this field.
    name: Literal["has_any_tag"] = "has_any_tag"
    # Candidate tags; one match suffices. `HasAnyTag` rejects an empty list.
    tags: List[data.Tag]
|
||||
|
||||
|
||||
class HasAnyTag:
    """Condition that holds when an annotation carries any listed tag.

    Tags are compared by their ``(term name, value)`` pair.
    """

    def __init__(self, tags: List[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        accepted = set()
        for tag in tags:
            accepted.add((tag.term.name, tag.value))
        self.tags = accepted

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True if at least one accepted tag is on the annotation."""
        present = {
            (tag.term.name, tag.value) for tag in sound_event_annotation.tags
        }
        return not self.tags.isdisjoint(present)

    @conditions.register(HasAnyTagConfig)
    @staticmethod
    def from_config(config: HasAnyTagConfig):
        """Build the condition from a ``HasAnyTagConfig``."""
        return HasAnyTag(tags=config.tags)
|
||||
|
||||
|
||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||
|
||||
|
||||
# Settings for the duration condition; consumed by `Duration.from_config`.
class DurationConfig(BaseConfig):
    # Fixed literal; `SoundEventConditionConfig` discriminates on this field.
    name: Literal["duration"] = "duration"
    # Comparison to apply: one of "gt", "gte", "lt", "lte", "eq".
    operator: Operator
    # Threshold the annotation duration is compared against.
    seconds: float
|
||||
|
||||
|
||||
def _build_comparator(
|
||||
operator: Operator, value: float
|
||||
) -> Callable[[float], bool]:
|
||||
if operator == "gt":
|
||||
return lambda x: x > value
|
||||
|
||||
if operator == "gte":
|
||||
return lambda x: x >= value
|
||||
|
||||
if operator == "lt":
|
||||
return lambda x: x < value
|
||||
|
||||
if operator == "lte":
|
||||
return lambda x: x <= value
|
||||
|
||||
if operator == "eq":
|
||||
return lambda x: x == value
|
||||
|
||||
raise ValueError(f"Invalid operator {operator}")
|
||||
|
||||
|
||||
class Duration:
    """Condition comparing an annotation's time extent with a threshold.

    The extent is the difference between the end and start times of the
    bounds computed from the sound event geometry. Annotations without a
    geometry never satisfy the condition.
    """

    def __init__(self, operator: Operator, seconds: float):
        self.operator = operator
        self.seconds = seconds
        self._comparator = _build_comparator(self.operator, self.seconds)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return the comparison result, or False when geometry is missing."""
        geometry = sound_event_annotation.sound_event.geometry
        if geometry is not None:
            start_time, _low, end_time, _high = compute_bounds(geometry)
            return self._comparator(end_time - start_time)
        return False

    @conditions.register(DurationConfig)
    @staticmethod
    def from_config(config: DurationConfig):
        """Build the condition from a ``DurationConfig``."""
        return Duration(operator=config.operator, seconds=config.seconds)
|
||||
|
||||
|
||||
# Settings for the frequency condition; consumed by `Frequency.from_config`.
class FrequencyConfig(BaseConfig):
    # Fixed literal; `SoundEventConditionConfig` discriminates on this field.
    name: Literal["frequency"] = "frequency"
    # Which edge of the geometry's frequency bounds is compared.
    boundary: Literal["low", "high"]
    # Comparison to apply: one of "gt", "gte", "lt", "lte", "eq".
    operator: Operator
    # Threshold the selected frequency bound is compared against.
    hertz: float
|
||||
|
||||
|
||||
class Frequency:
    """Condition comparing a frequency bound of an annotation to a threshold.

    Either the low or the high bound of the sound event geometry is
    compared, depending on ``boundary``. Annotations with no geometry, or
    with a geometry that has no frequency range, never satisfy it.
    """

    def __init__(
        self,
        operator: Operator,
        boundary: Literal["low", "high"],
        hertz: float,
    ):
        self.operator = operator
        self.hertz = hertz
        self.boundary = boundary
        self._comparator = _build_comparator(self.operator, self.hertz)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return the comparison result for the configured boundary."""
        geometry = sound_event_annotation.sound_event.geometry

        # Time-only geometries carry no frequency information, so they are
        # rejected together with missing geometries.
        if geometry is None or isinstance(
            geometry, (data.TimeInterval, data.TimeStamp)
        ):
            return False

        _start, low_freq, _end, high_freq = compute_bounds(geometry)
        selected = low_freq if self.boundary == "low" else high_freq
        return self._comparator(selected)

    @conditions.register(FrequencyConfig)
    @staticmethod
    def from_config(config: FrequencyConfig):
        """Build the condition from a ``FrequencyConfig``."""
        return Frequency(
            operator=config.operator,
            boundary=config.boundary,
            hertz=config.hertz,
        )
|
||||
|
||||
|
||||
# Settings for the conjunction condition; consumed by `AllOf.from_config`.
class AllOfConfig(BaseConfig):
    # Fixed literal; `SoundEventConditionConfig` discriminates on this field.
    name: Literal["all_of"] = "all_of"
    # Nested condition configs; each is built and all must hold.
    conditions: Sequence["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
class AllOf:
    """Condition that holds only when every sub-condition holds."""

    def __init__(self, conditions: List[SoundEventCondition]):
        self.conditions = conditions

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return False at the first failing sub-condition, else True."""
        for check in self.conditions:
            if not check(sound_event_annotation):
                return False
        return True

    @conditions.register(AllOfConfig)
    @staticmethod
    def from_config(config: AllOfConfig):
        """Build every nested condition and wrap them in an ``AllOf``."""
        built = []
        for nested in config.conditions:
            built.append(build_sound_event_condition(nested))
        return AllOf(built)
|
||||
|
||||
|
||||
# Settings for the disjunction condition; consumed by `AnyOf.from_config`.
class AnyOfConfig(BaseConfig):
    # Fixed literal; `SoundEventConditionConfig` discriminates on this field.
    name: Literal["any_of"] = "any_of"
    # Nested condition configs; each is built and at least one must hold.
    conditions: List["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
class AnyOf:
    """Condition that holds when at least one sub-condition holds."""

    def __init__(self, conditions: List[SoundEventCondition]):
        self.conditions = conditions

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return True at the first passing sub-condition, else False."""
        for check in self.conditions:
            if check(sound_event_annotation):
                return True
        return False

    @conditions.register(AnyOfConfig)
    @staticmethod
    def from_config(config: AnyOfConfig):
        """Build every nested condition and wrap them in an ``AnyOf``."""
        built = []
        for nested in config.conditions:
            built.append(build_sound_event_condition(nested))
        return AnyOf(built)
|
||||
|
||||
|
||||
# Settings for the negation condition; consumed by `Not.from_config`.
class NotConfig(BaseConfig):
    # Fixed literal; `SoundEventConditionConfig` discriminates on this field.
    name: Literal["not"] = "not"
    # Nested condition config whose result is inverted.
    condition: "SoundEventConditionConfig"
|
||||
|
||||
|
||||
class Not:
    """Condition that inverts the result of another condition."""

    def __init__(self, condition: SoundEventCondition):
        self.condition = condition

    def __call__(
        self, sound_event_annotation: data.SoundEventAnnotation
    ) -> bool:
        """Return the logical negation of the wrapped condition."""
        outcome = self.condition(sound_event_annotation)
        return not outcome

    @conditions.register(NotConfig)
    @staticmethod
    def from_config(config: NotConfig):
        """Build the nested condition and wrap it in a ``Not``."""
        return Not(build_sound_event_condition(config.condition))
|
||||
|
||||
|
||||
SoundEventConditionConfig = Annotated[
|
||||
HasTagConfig
|
||||
| HasAllTagsConfig
|
||||
| HasAnyTagConfig
|
||||
| DurationConfig
|
||||
| FrequencyConfig
|
||||
| AllOfConfig
|
||||
| AnyOfConfig
|
||||
| NotConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_sound_event_condition(
    config: SoundEventConditionConfig,
) -> SoundEventCondition:
    """Build a sound event condition from its configuration.

    The work is delegated to the module-level ``conditions`` registry.
    """
    condition = conditions.build(config)
    return condition
|
||||
|
||||
|
||||
def filter_clip_annotation(
    clip_annotation: data.ClipAnnotation,
    condition: SoundEventCondition,
) -> data.ClipAnnotation:
    """Return a copy of the clip annotation with filtered sound events.

    Only the sound event annotations for which ``condition`` returns a
    truthy value are kept; their order is preserved. The copy is produced
    with ``model_copy``, replacing the ``sound_events`` field.
    """
    kept = []
    for sound_event in clip_annotation.sound_events:
        if condition(sound_event):
            kept.append(sound_event)

    return clip_annotation.model_copy(update={"sound_events": kept})
|
||||
81
src/batdetect2/data/conditions/__init__.py
Normal file
81
src/batdetect2/data/conditions/__init__.py
Normal file
@ -0,0 +1,81 @@
|
||||
from batdetect2.data.conditions.clips import (
|
||||
ClipAllOfConfig,
|
||||
ClipAnnotationCondition,
|
||||
ClipAnnotationConditionConfig,
|
||||
ClipAnnotationConditionImportConfig,
|
||||
ClipAnyOfConfig,
|
||||
ClipNotConfig,
|
||||
RecordingSatisfiesConfig,
|
||||
build_clip_annotation_condition,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
CsvList,
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
JsonList,
|
||||
ListFormatConfig,
|
||||
TxtList,
|
||||
)
|
||||
from batdetect2.data.conditions.recordings import (
|
||||
PathInListConfig,
|
||||
RecordingAllOfConfig,
|
||||
RecordingAnyOfConfig,
|
||||
RecordingCondition,
|
||||
RecordingConditionConfig,
|
||||
RecordingConditionImportConfig,
|
||||
RecordingNotConfig,
|
||||
build_recording_condition,
|
||||
)
|
||||
from batdetect2.data.conditions.sound_events import (
|
||||
AllOfConfig,
|
||||
AnyOfConfig,
|
||||
DurationConfig,
|
||||
FrequencyConfig,
|
||||
NotConfig,
|
||||
Operator,
|
||||
SoundEventCondition,
|
||||
SoundEventConditionConfig,
|
||||
SoundEventConditionImportConfig,
|
||||
build_sound_event_condition,
|
||||
filter_clip_annotation,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AllOfConfig",
|
||||
"AnyOfConfig",
|
||||
"ClipAllOfConfig",
|
||||
"ClipAnnotationCondition",
|
||||
"ClipAnnotationConditionConfig",
|
||||
"ClipAnnotationConditionImportConfig",
|
||||
"ClipAnyOfConfig",
|
||||
"ClipNotConfig",
|
||||
"CsvList",
|
||||
"DurationConfig",
|
||||
"FrequencyConfig",
|
||||
"HasAllTagsConfig",
|
||||
"HasAnyTagConfig",
|
||||
"HasTagConfig",
|
||||
"IdInListConfig",
|
||||
"JsonList",
|
||||
"ListFormatConfig",
|
||||
"NotConfig",
|
||||
"Operator",
|
||||
"PathInListConfig",
|
||||
"RecordingCondition",
|
||||
"RecordingConditionConfig",
|
||||
"RecordingConditionImportConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingNotConfig",
|
||||
"RecordingSatisfiesConfig",
|
||||
"SoundEventCondition",
|
||||
"SoundEventConditionConfig",
|
||||
"SoundEventConditionImportConfig",
|
||||
"TxtList",
|
||||
"build_clip_annotation_condition",
|
||||
"build_recording_condition",
|
||||
"build_sound_event_condition",
|
||||
"filter_clip_annotation",
|
||||
]
|
||||
138
src/batdetect2/data/conditions/clips.py
Normal file
138
src/batdetect2/data/conditions/clips.py
Normal file
@ -0,0 +1,138 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
MultiConditionConfigBase,
|
||||
NotConditionConfigBase,
|
||||
register_all_of_condition,
|
||||
register_any_of_condition,
|
||||
register_has_all_tags_condition,
|
||||
register_has_any_tag_condition,
|
||||
register_has_tag_condition,
|
||||
register_id_in_list_condition,
|
||||
register_not_condition,
|
||||
)
|
||||
from batdetect2.data.conditions.recordings import (
|
||||
RecordingCondition,
|
||||
RecordingConditionConfig,
|
||||
build_recording_condition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ClipAllOfConfig",
|
||||
"ClipAnnotationCondition",
|
||||
"ClipAnnotationConditionConfig",
|
||||
"ClipAnnotationConditionImportConfig",
|
||||
"ClipAnyOfConfig",
|
||||
"ClipNotConfig",
|
||||
"RecordingSatisfiesConfig",
|
||||
"build_clip_annotation_condition",
|
||||
]
|
||||
|
||||
ClipAnnotationCondition = Callable[[data.ClipAnnotation], bool]
|
||||
|
||||
clip_annotation_conditions: Registry[
|
||||
ClipAnnotationCondition,
|
||||
[data.PathLike | None],
|
||||
] = Registry("clip_condition")
|
||||
|
||||
|
||||
# NOTE(review): `add_import_config` presumably wires this config into the
# `clip_annotation_conditions` registry and forwards `base_dir` to the
# imported callable — confirm in batdetect2.core.registries.
@add_import_config(clip_annotation_conditions, arg_names=["base_dir"])
class ClipAnnotationConditionImportConfig(ImportConfig):
    """Use any callable as a clip annotation condition.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    to use it instead of a built-in option.
    """

    # Fixed literal; `ClipAnnotationConditionConfig` discriminates on it.
    name: Literal["import"] = "import"
|
||||
|
||||
|
||||
# Settings for a clip condition that defers to a recording condition;
# consumed by `RecordingSatisfies.from_config`.
class RecordingSatisfiesConfig(BaseConfig):
    # Fixed literal; `ClipAnnotationConditionConfig` discriminates on it.
    name: Literal["recording_satisfies"] = "recording_satisfies"
    # Recording-level condition evaluated on the clip's recording.
    condition: RecordingConditionConfig
|
||||
|
||||
|
||||
class RecordingSatisfies:
    """Clip condition that evaluates a condition on the clip's recording."""

    def __init__(self, condition: RecordingCondition):
        self.condition = condition

    def __call__(self, clip_annotation: data.ClipAnnotation) -> bool:
        """Apply the recording condition to ``clip_annotation.clip.recording``."""
        return self.condition(clip_annotation.clip.recording)

    @clip_annotation_conditions.register(RecordingSatisfiesConfig)
    @staticmethod
    def from_config(
        config: RecordingSatisfiesConfig,
        base_dir: data.PathLike | None = None,
    ) -> "RecordingSatisfies":
        """Build the nested recording condition and wrap it.

        ``base_dir`` is passed through unchanged to
        ``build_recording_condition``.
        """
        recording_condition = build_recording_condition(
            config.condition,
            base_dir=base_dir,
        )
        return RecordingSatisfies(recording_condition)
|
||||
|
||||
|
||||
register_has_tag_condition(clip_annotation_conditions, HasTagConfig)
|
||||
register_has_all_tags_condition(
|
||||
clip_annotation_conditions,
|
||||
HasAllTagsConfig,
|
||||
)
|
||||
register_has_any_tag_condition(
|
||||
clip_annotation_conditions,
|
||||
HasAnyTagConfig,
|
||||
)
|
||||
register_id_in_list_condition(clip_annotation_conditions, IdInListConfig)
|
||||
|
||||
|
||||
# Conjunction of clip annotation conditions.
# NOTE(review): `register_all_of_condition` presumably registers a builder
# for this config in `clip_annotation_conditions` — confirm in common.py.
@register_all_of_condition(clip_annotation_conditions)
class ClipAllOfConfig(MultiConditionConfigBase):
    # Fixed literal; `ClipAnnotationConditionConfig` discriminates on it.
    name: Literal["all_of"] = "all_of"
    # Narrows the base-class field to clip annotation condition configs.
    conditions: Sequence["ClipAnnotationConditionConfig"]
|
||||
|
||||
|
||||
# Disjunction of clip annotation conditions.
# NOTE(review): `register_any_of_condition` presumably registers a builder
# for this config in `clip_annotation_conditions` — confirm in common.py.
@register_any_of_condition(clip_annotation_conditions)
class ClipAnyOfConfig(MultiConditionConfigBase):
    # Fixed literal; `ClipAnnotationConditionConfig` discriminates on it.
    name: Literal["any_of"] = "any_of"
    # Narrows the base-class field to clip annotation condition configs.
    conditions: Sequence["ClipAnnotationConditionConfig"]
|
||||
|
||||
|
||||
# Negation of a clip annotation condition.
# NOTE(review): `register_not_condition` presumably registers a builder
# for this config in `clip_annotation_conditions` — confirm in common.py.
@register_not_condition(clip_annotation_conditions)
class ClipNotConfig(NotConditionConfigBase):
    # Fixed literal; `ClipAnnotationConditionConfig` discriminates on it.
    name: Literal["not"] = "not"
    # Narrows the base-class field to a clip annotation condition config.
    condition: "ClipAnnotationConditionConfig"
|
||||
|
||||
|
||||
ClipAnnotationConditionConfig = Annotated[
|
||||
RecordingSatisfiesConfig
|
||||
| IdInListConfig
|
||||
| HasTagConfig
|
||||
| HasAllTagsConfig
|
||||
| HasAnyTagConfig
|
||||
| ClipAllOfConfig
|
||||
| ClipAnyOfConfig
|
||||
| ClipNotConfig
|
||||
| ClipAnnotationConditionImportConfig,
|
||||
Field(discriminator="name"),
|
||||
]
|
||||
|
||||
|
||||
def build_clip_annotation_condition(
    config: ClipAnnotationConditionConfig,
    base_dir: data.PathLike | None = None,
) -> ClipAnnotationCondition:
    """Build a clip annotation condition from its configuration.

    The work is delegated to the ``clip_annotation_conditions`` registry,
    which receives ``base_dir`` as its extra positional argument.
    """
    condition = clip_annotation_conditions.build(config, base_dir)
    return condition
|
||||
417
src/batdetect2/data/conditions/common.py
Normal file
417
src/batdetect2/data/conditions/common.py
Normal file
@ -0,0 +1,417 @@
|
||||
import csv
|
||||
import json
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import Registry
|
||||
|
||||
__all__ = [
|
||||
"AllOf",
|
||||
"AnyOf",
|
||||
"Condition",
|
||||
"CsvList",
|
||||
"HasAllTags",
|
||||
"HasAllTagsConfig",
|
||||
"HasAnyTag",
|
||||
"HasAnyTagConfig",
|
||||
"HasTag",
|
||||
"HasTagConfig",
|
||||
"IdInList",
|
||||
"IdInListConfig",
|
||||
"JsonList",
|
||||
"ListLoader",
|
||||
"ListFormatConfig",
|
||||
"MultiConditionConfigBase",
|
||||
"Not",
|
||||
"NotConditionConfigBase",
|
||||
"ObjectWithTags",
|
||||
"ObjectWithUUID",
|
||||
"TxtList",
|
||||
"build_list_loader",
|
||||
"register_all_of_condition",
|
||||
"register_any_of_condition",
|
||||
"register_has_all_tags_condition",
|
||||
"register_has_any_tag_condition",
|
||||
"register_has_tag_condition",
|
||||
"register_id_in_list_condition",
|
||||
"register_not_condition",
|
||||
]
|
||||
|
||||
|
||||
class ObjectWithTags(Protocol):
    """Structural type for any object that exposes a ``tags`` list."""

    # Tags attached to the object; read by the tag-based conditions.
    tags: list[data.Tag]
|
||||
|
||||
|
||||
class ObjectWithUUID(Protocol):
    """Structural type for any object that exposes a ``uuid`` attribute."""

    # Identifier of the object; read by `IdInList`.
    uuid: UUID
|
||||
|
||||
|
||||
ConditionObject = TypeVar("ConditionObject")
|
||||
TaggedObject = TypeVar("TaggedObject", bound="ObjectWithTags")
|
||||
UUIDObject = TypeVar("UUIDObject", bound="ObjectWithUUID")
|
||||
P = ParamSpec("P")
|
||||
NotConfigType = TypeVar("NotConfigType", bound="NotConditionConfigBase")
|
||||
MultiConfigType = TypeVar(
|
||||
"MultiConfigType",
|
||||
bound="MultiConditionConfigBase",
|
||||
)
|
||||
Condition = Callable[[ConditionObject], bool]
|
||||
|
||||
|
||||
# Base for "not" condition configs. Subclasses are expected to narrow
# `condition` to their own condition-config union.
class NotConditionConfigBase(BaseConfig):
    # Config of the condition to be negated.
    condition: BaseModel
|
||||
|
||||
|
||||
# Base for "all_of" / "any_of" condition configs. Subclasses are expected
# to narrow `conditions` to their own condition-config union.
class MultiConditionConfigBase(BaseConfig):
    # Configs of the conditions to be combined.
    conditions: Sequence[BaseModel]
|
||||
|
||||
|
||||
class Not(Generic[ConditionObject]):
    """Condition that inverts the result of another condition."""

    def __init__(self, condition: Condition[ConditionObject]):
        self.condition = condition

    def __call__(self, obj: ConditionObject) -> bool:
        """Return the logical negation of the wrapped condition on ``obj``."""
        outcome = self.condition(obj)
        return not outcome
|
||||
|
||||
|
||||
class AllOf(Generic[ConditionObject]):
    """Condition that holds only when every sub-condition holds.

    The given conditions are copied into a list on construction.
    """

    def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
        self.conditions = [*conditions]

    def __call__(self, obj: ConditionObject) -> bool:
        """Return False at the first failing sub-condition, else True."""
        for check in self.conditions:
            if not check(obj):
                return False
        return True
|
||||
|
||||
|
||||
class AnyOf(Generic[ConditionObject]):
    """Condition that holds when at least one sub-condition holds.

    The given conditions are copied into a list on construction.
    """

    def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
        self.conditions = [*conditions]

    def __call__(self, obj: ConditionObject) -> bool:
        """Return True at the first passing sub-condition, else False."""
        for check in self.conditions:
            if check(obj):
                return True
        return False
|
||||
|
||||
|
||||
class HasTag(Generic[TaggedObject]):
    """Condition that holds when an object carries a specific tag.

    Tags are compared by their ``(term name, value)`` pair, which is
    computed once for the reference tag at construction time.
    """

    def __init__(self, tag: data.Tag):
        self.tag_key = (tag.term.name, tag.value)

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if any tag of ``obj`` matches the reference tag."""
        for candidate in obj.tags:
            if (candidate.term.name, candidate.value) == self.tag_key:
                return True
        return False
|
||||
|
||||
|
||||
class HasAllTags(Generic[TaggedObject]):
    """Condition that holds when an object carries every listed tag.

    Tags are compared by their ``(term name, value)`` pair.
    """

    def __init__(self, tags: list[data.Tag]):
        if not tags:
            raise ValueError("Need to specify at least one tag")

        required = set()
        for tag in tags:
            required.add((tag.term.name, tag.value))
        self.required_keys = required

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if all required tags are present on ``obj``."""
        present = {(tag.term.name, tag.value) for tag in obj.tags}
        return self.required_keys <= present
|
||||
|
||||
|
||||
class HasAnyTag(Generic[TaggedObject]):
    """Predicate checking that an object carries at least one listed tag.

    Tags are compared by the pair ``(term name, value)``.

    Raises
    ------
    ValueError
        If ``tags`` is empty.
    """

    def __init__(self, tags: list[data.Tag]):
        if len(tags) == 0:
            raise ValueError("Need to specify at least one tag")

        self.required_keys = {(tag.term.name, tag.value) for tag in tags}

    def __call__(self, obj: TaggedObject) -> bool:
        """Return True if ``obj.tags`` shares a key with the required set."""
        present = {(tag.term.name, tag.value) for tag in obj.tags}
        return not self.required_keys.isdisjoint(present)
|
||||
|
||||
|
||||
class IdInList(Generic[UUIDObject]):
    """Predicate checking that an object's ``uuid`` is in a fixed set.

    Parameters
    ----------
    ids : set[UUID]
        Accepted identifiers. The set is copied, so mutating the
        caller's set afterwards does not change this condition (the
        same defensive copy ``AllOf`` and ``AnyOf`` make).
    """

    def __init__(self, ids: set[UUID]):
        self.ids = set(ids)

    def __call__(self, obj: UUIDObject) -> bool:
        """Return True if ``obj.uuid`` is one of the accepted ids."""
        return obj.uuid in self.ids
|
||||
|
||||
|
||||
class HasTagConfig(BaseConfig):
    """Config for a condition requiring a single tag."""

    # Discriminator value identifying this config type.
    name: Literal["has_tag"] = "has_tag"
    # Tag the object must carry.
    tag: data.Tag
|
||||
|
||||
|
||||
class HasAllTagsConfig(BaseConfig):
    """Config for a condition requiring every listed tag."""

    # Discriminator value identifying this config type.
    name: Literal["has_all_tags"] = "has_all_tags"
    # Tags that must all be present; ``HasAllTags`` rejects an empty list.
    tags: list[data.Tag]
|
||||
|
||||
|
||||
class HasAnyTagConfig(BaseConfig):
    """Config for a condition requiring at least one listed tag."""

    # Discriminator value identifying this config type.
    name: Literal["has_any_tag"] = "has_any_tag"
    # Candidate tags; ``HasAnyTag`` rejects an empty list.
    tags: list[data.Tag]
|
||||
|
||||
|
||||
class JsonList(BaseConfig):
    """Config selecting the JSON list-file format."""

    # Discriminator value identifying this format.
    name: Literal["json"] = "json"
    # Key holding the list inside a JSON object; None means the
    # document itself is the list.
    field: str | None = None
|
||||
|
||||
|
||||
class TxtList(BaseConfig):
    """Config selecting the plain-text (one entry per line) format."""

    # Discriminator value identifying this format.
    name: Literal["txt"] = "txt"
|
||||
|
||||
|
||||
class CsvList(BaseConfig):
    """Config selecting the CSV list-file format."""

    # Discriminator value identifying this format.
    name: Literal["csv"] = "csv"
    # Header name of the column holding the entries (required).
    column: str
|
||||
|
||||
|
||||
# Union of list-file format configs, discriminated by ``name``.
ListFormatConfig = Annotated[
    JsonList | TxtList | CsvList,
    Field(discriminator="name"),
]

# A list loader reads a file and returns its entries as strings.
ListLoader = Callable[[Path], list[str]]

# Registry associating format configs with their loader builders;
# builders take no extra arguments beyond the config.
list_loaders: Registry[ListLoader, []] = Registry("list_loader")
|
||||
|
||||
|
||||
class IdInListConfig(BaseConfig):
    """Config for a condition matching objects whose id is in a file.

    ``format`` may be given either as a full format config or as a bare
    format name string, which is expanded before validation.
    """

    name: Literal["id_in_list"] = "id_in_list"
    path: Path
    format: ListFormatConfig = JsonList()

    @model_validator(mode="before")
    @classmethod
    def _normalize_format(cls, values):
        """Replace a string ``format`` with that format's default config.

        Non-dict input, and dicts whose ``format`` is not a string, are
        returned unchanged. The input dict itself is never mutated.
        """
        if isinstance(values, dict):
            requested = values.get("format")

            if isinstance(requested, str):
                # Presumably resolves the config class registered under
                # this name -- see ``Registry.get_config_type``.
                format_cls = list_loaders.get_config_type(requested)
                return {**values, "format": format_cls().model_dump()}

        return values
|
||||
|
||||
|
||||
class JsonListLoader:
    """Load a list of string entries from a JSON file.

    Parameters
    ----------
    field : str or None
        Key of the top-level JSON object that holds the list. If None,
        the document itself must be the list.
    """

    def __init__(self, field: str | None):
        self.field = field

    def __call__(self, path: Path) -> list[str]:
        """Read ``path`` and return its entries converted with ``str``.

        Raises
        ------
        TypeError
            If ``field`` is set but the document is not a JSON object,
            or if the selected content is not a JSON list.
        KeyError
            If ``field`` is set but missing from the document.
        """
        # JSON is UTF-8 by specification; decoding with the platform's
        # locale encoding would corrupt non-ASCII content on some hosts.
        content = json.loads(path.read_text(encoding="utf-8"))

        if self.field is not None:
            if not isinstance(content, dict):
                raise TypeError(
                    "Expected JSON object with field for 'id_in_list'."
                )

            if self.field not in content:
                raise KeyError(f"Field '{self.field}' not found in '{path}'.")

            content = content[self.field]

        if not isinstance(content, list):
            raise TypeError("Expected JSON list with IDs for 'id_in_list'.")

        return [str(value) for value in content]

    @list_loaders.register(JsonList)
    @staticmethod
    def from_config(config: JsonList) -> ListLoader:
        """Build a loader from a ``JsonList`` config."""
        return JsonListLoader(field=config.field)
|
||||
|
||||
|
||||
class TxtListLoader:
    """Load entries from a text file, one entry per non-blank line."""

    def __call__(self, path: Path) -> list[str]:
        """Return the stripped, non-empty lines of ``path`` in order."""
        stripped = (line.strip() for line in path.read_text().splitlines())
        return [entry for entry in stripped if entry]

    @list_loaders.register(TxtList)
    @staticmethod
    def from_config(config: TxtList) -> ListLoader:
        """Build a loader from a ``TxtList`` config (no options used)."""
        return TxtListLoader()
|
||||
|
||||
|
||||
class CsvListLoader:
    """Load entries from one column of a CSV file with a header row.

    Parameters
    ----------
    column : str
        Header name of the column to read.
    """

    def __init__(self, column: str):
        self.column = column

    def __call__(self, path: Path) -> list[str]:
        """Return the stripped, non-empty cells of the chosen column.

        Rows where the column is missing (``None``) or blank are skipped.

        Raises
        ------
        ValueError
            If the file has no header row or lacks the chosen column.
        """
        with path.open("r", newline="") as csv_file:
            reader = csv.DictReader(csv_file)
            header = reader.fieldnames

            if header is None:
                raise ValueError(
                    f"Expected CSV header row for 'id_in_list' in '{path}'."
                )

            if self.column not in header:
                raise ValueError(
                    f"Column '{self.column}' not found in '{path}'."
                )

            # Must be consumed while the file is still open.
            cells = [row.get(self.column) for row in reader]

        trimmed = (cell.strip() for cell in cells if cell is not None)
        return [cell for cell in trimmed if cell]

    @list_loaders.register(CsvList)
    @staticmethod
    def from_config(config: CsvList) -> ListLoader:
        """Build a loader from a ``CsvList`` config."""
        return CsvListLoader(column=config.column)
|
||||
|
||||
|
||||
def build_list_loader(config: ListFormatConfig) -> ListLoader:
    """Build the list loader matching ``config`` via ``list_loaders``."""
    loader = list_loaders.build(config)
    return loader
|
||||
|
||||
|
||||
def register_id_in_list_condition(
    registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
    config_cls: type[IdInListConfig],
) -> None:
    """Register an ``IdInList`` builder for ``config_cls`` on ``registry``.

    The builder loads the list file named by the config, parses every
    entry as a UUID and returns an ``IdInList`` over the parsed ids. A
    relative list path is resolved against ``base_dir`` when one is given.
    An entry that is not a valid UUID raises ``ValueError`` naming its
    index and the file.
    """

    @registry.register(config_cls)
    def builder(
        config: IdInListConfig,
        base_dir: data.PathLike | None = None,
    ) -> Condition[UUIDObject]:
        path = config.path

        if not (base_dir is None or path.is_absolute()):
            path = Path(base_dir) / path

        read_entries = build_list_loader(config.format)
        parsed = set()

        for index, value in enumerate(read_entries(path)):
            try:
                parsed.add(UUID(value))
            except ValueError as err:
                raise ValueError(
                    f"Invalid ID at index {index} in '{path}': {value!r}."
                ) from err

        return IdInList(parsed)
|
||||
|
||||
|
||||
def register_has_tag_condition(
    registry: Registry[Condition[TaggedObject], P],
    config_cls: type[HasTagConfig],
) -> None:
    """Register a ``HasTag`` builder for ``config_cls`` on ``registry``.

    Extra builder arguments supplied by the registry are accepted and
    ignored.
    """

    @registry.register(config_cls)
    def builder(
        config: HasTagConfig,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Condition[TaggedObject]:
        required = config.tag
        return HasTag(required)
|
||||
|
||||
|
||||
def register_has_all_tags_condition(
    registry: Registry[Condition[TaggedObject], P],
    config_cls: type[HasAllTagsConfig],
) -> None:
    """Register a ``HasAllTags`` builder for ``config_cls`` on ``registry``.

    Extra builder arguments supplied by the registry are accepted and
    ignored.
    """

    @registry.register(config_cls)
    def builder(
        config: HasAllTagsConfig,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Condition[TaggedObject]:
        required = config.tags
        return HasAllTags(required)
|
||||
|
||||
|
||||
def register_has_any_tag_condition(
    registry: Registry[Condition[TaggedObject], P],
    config_cls: type[HasAnyTagConfig],
) -> None:
    """Register a ``HasAnyTag`` builder for ``config_cls`` on ``registry``.

    Extra builder arguments supplied by the registry are accepted and
    ignored.
    """

    @registry.register(config_cls)
    def builder(
        config: HasAnyTagConfig,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Condition[TaggedObject]:
        candidates = config.tags
        return HasAnyTag(candidates)
|
||||
|
||||
|
||||
def register_not_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[NotConfigType]], type[NotConfigType]]:
    """Return a class decorator registering a NOT condition config.

    The decorated config class is registered on ``registry`` with a
    builder that builds the nested ``config.condition`` through the same
    registry (forwarding any extra arguments) and wraps it in ``Not``.
    The class itself is returned unchanged.
    """

    def decorator(config_cls: type[NotConfigType]) -> type[NotConfigType]:
        def builder(
            config: NotConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            inner = registry.build(config.condition, *args, **kwargs)
            return Not(inner)

        registry.register(config_cls)(builder)
        return config_cls

    return decorator
|
||||
|
||||
|
||||
def register_all_of_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
    """Return a class decorator registering an ALL-OF condition config.

    The decorated config class is registered on ``registry`` with a
    builder that builds each entry of ``config.conditions`` through the
    same registry (forwarding any extra arguments) and combines them in
    ``AllOf``. The class itself is returned unchanged.
    """

    def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
        def builder(
            config: MultiConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            built = []
            for nested in config.conditions:
                built.append(registry.build(nested, *args, **kwargs))
            return AllOf(built)

        registry.register(config_cls)(builder)
        return config_cls

    return decorator
|
||||
|
||||
|
||||
def register_any_of_condition(
    registry: Registry[Condition[ConditionObject], P],
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
    """Return a class decorator registering an ANY-OF condition config.

    The decorated config class is registered on ``registry`` with a
    builder that builds each entry of ``config.conditions`` through the
    same registry (forwarding any extra arguments) and combines them in
    ``AnyOf``. The class itself is returned unchanged.
    """

    def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
        def builder(
            config: MultiConfigType,
            *args: P.args,
            **kwargs: P.kwargs,
        ) -> Condition[ConditionObject]:
            built = []
            for nested in config.conditions:
                built.append(registry.build(nested, *args, **kwargs))
            return AnyOf(built)

        registry.register(config_cls)(builder)
        return config_cls

    return decorator
|
||||
217
src/batdetect2/data/conditions/recordings.py
Normal file
217
src/batdetect2/data/conditions/recordings.py
Normal file
@ -0,0 +1,217 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from loguru import logger
|
||||
from pydantic import Field, model_validator
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
JsonList,
|
||||
ListFormatConfig,
|
||||
MultiConditionConfigBase,
|
||||
NotConditionConfigBase,
|
||||
build_list_loader,
|
||||
list_loaders,
|
||||
register_all_of_condition,
|
||||
register_any_of_condition,
|
||||
register_has_all_tags_condition,
|
||||
register_has_any_tag_condition,
|
||||
register_has_tag_condition,
|
||||
register_id_in_list_condition,
|
||||
register_not_condition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"IdInListConfig",
|
||||
"PathInListConfig",
|
||||
"RecordingAllOfConfig",
|
||||
"RecordingAnyOfConfig",
|
||||
"RecordingCondition",
|
||||
"RecordingConditionConfig",
|
||||
"RecordingConditionImportConfig",
|
||||
"RecordingNotConfig",
|
||||
"build_recording_condition",
|
||||
]
|
||||
|
||||
# A recording condition is a predicate over a single recording.
RecordingCondition = Callable[[data.Recording], bool]

# Registry of recording-condition builders; each builder receives the
# config plus an optional base directory.
recording_conditions: Registry[RecordingCondition, [data.PathLike | None]] = (
    Registry("recording_condition")
)
|
||||
|
||||
|
||||
@add_import_config(recording_conditions, arg_names=["base_dir"])
class RecordingConditionImportConfig(ImportConfig):
    """Use any callable as a recording condition.

    Set ``name="import"`` and provide a ``target`` pointing to any callable
    to use it instead of a built-in option.
    """

    # Discriminator value identifying this config type.
    name: Literal["import"] = "import"
|
||||
|
||||
|
||||
# Make the shared id/tag conditions available for recordings.
register_id_in_list_condition(recording_conditions, IdInListConfig)
register_has_tag_condition(recording_conditions, HasTagConfig)
register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
|
||||
|
||||
|
||||
class PathInListConfig(BaseConfig):
    """Config for a condition matching recordings whose path is listed.

    ``format`` may be given either as a full format config or as a bare
    format name string, which is expanded before validation.
    """

    name: Literal["path_in_list"] = "path_in_list"
    path: Path
    format: ListFormatConfig = JsonList()
    base_dir: Path | None = None
    on_outside: Literal["allow", "warn", "error"] = "allow"

    @model_validator(mode="before")
    @classmethod
    def _normalize_format(cls, values):
        """Replace a string ``format`` with that format's default config.

        Non-dict input, and dicts whose ``format`` is not a string, are
        returned unchanged. The input dict itself is never mutated.
        """
        if isinstance(values, dict):
            requested = values.get("format")

            if isinstance(requested, str):
                # Presumably resolves the config class registered under
                # this name -- see ``Registry.get_config_type``.
                format_cls = list_loaders.get_config_type(requested)
                return {**values, "format": format_cls().model_dump()}

        return values
|
||||
|
||||
|
||||
class PathInList:
    """Predicate checking that a recording's path is in a set of paths.

    Parameters
    ----------
    paths : set[Path]
        Accepted paths.
    base_dir : Path or None
        If set, absolute recording paths are made relative to this
        directory before the lookup.
    on_outside : {"allow", "warn", "error"}
        What to do with an absolute recording path that is not under
        ``base_dir``: accept it silently, accept it with a warning, or
        raise ``ValueError``.
    """

    def __init__(
        self,
        paths: set[Path],
        base_dir: Path | None,
        on_outside: Literal["allow", "warn", "error"],
    ):
        self.paths = paths
        self.base_dir = base_dir
        self.on_outside = on_outside

    def __call__(self, recording: data.Recording) -> bool:
        """Return True if the recording is listed or allowed as outside."""
        candidate = self._normalize_recording_path(recording.path)
        # ``None`` marks a path outside ``base_dir`` that is let through.
        return candidate is None or candidate in self.paths

    def _normalize_recording_path(self, path: data.PathLike) -> Path | None:
        """Return the path used for lookup, or None if it is outside.

        Relative paths, and any path when no ``base_dir`` is set, are
        returned as given. Absolute paths are made relative to
        ``base_dir``; when that is impossible the ``on_outside`` policy
        decides between returning None and raising ``ValueError``.
        """
        recording_path = Path(path)

        if self.base_dir is None or not recording_path.is_absolute():
            return recording_path

        try:
            return recording_path.relative_to(self.base_dir)
        except ValueError as err:
            if self.on_outside not in ("allow", "warn"):
                raise ValueError(
                    f"Recording path '{recording_path}' is outside "
                    f"'{self.base_dir}' for 'path_in_list'."
                ) from err

            if self.on_outside == "warn":
                logger.warning(
                    "Recording path '{}' is outside '{}' in path_in_list; "
                    "allowing.",
                    recording_path,
                    self.base_dir,
                )

            return None

    @recording_conditions.register(PathInListConfig)
    @staticmethod
    def from_config(
        config: PathInListConfig,
        base_dir: data.PathLike | None = None,
    ) -> "PathInList":
        """Build the condition from its config.

        A relative ``config.path`` and a relative ``config.base_dir`` are
        resolved against ``base_dir`` when one is given. Listed entries
        that are absolute and lie under the matching base directory are
        stored relative to it; all other entries are stored as read.
        """
        list_path = config.path
        match_base_dir = config.base_dir

        if base_dir is not None:
            if not list_path.is_absolute():
                list_path = Path(base_dir) / list_path

            if match_base_dir is not None and not match_base_dir.is_absolute():
                match_base_dir = Path(base_dir) / match_base_dir

        read_entries = build_list_loader(config.format)

        paths = set()
        for value in read_entries(list_path):
            entry = Path(value)

            if (
                match_base_dir is not None
                and entry.is_absolute()
                and entry.is_relative_to(match_base_dir)
            ):
                entry = entry.relative_to(match_base_dir)

            paths.add(entry)

        return PathInList(
            paths=paths,
            base_dir=match_base_dir,
            on_outside=config.on_outside,
        )
|
||||
|
||||
|
||||
@register_all_of_condition(recording_conditions)
class RecordingAllOfConfig(MultiConditionConfigBase):
    """Config for a recording condition requiring all nested conditions."""

    # Discriminator value identifying this config type.
    name: Literal["all_of"] = "all_of"
    # Nested recording condition configs.
    conditions: Sequence["RecordingConditionConfig"]
|
||||
|
||||
|
||||
@register_any_of_condition(recording_conditions)
class RecordingAnyOfConfig(MultiConditionConfigBase):
    """Config for a recording condition requiring any nested condition."""

    # Discriminator value identifying this config type.
    name: Literal["any_of"] = "any_of"
    # Nested recording condition configs.
    conditions: Sequence["RecordingConditionConfig"]
|
||||
|
||||
|
||||
@register_not_condition(recording_conditions)
class RecordingNotConfig(NotConditionConfigBase):
    """Config for a recording condition negating a nested condition."""

    # Discriminator value identifying this config type.
    name: Literal["not"] = "not"
    # Nested recording condition config to negate.
    condition: "RecordingConditionConfig"
|
||||
|
||||
|
||||
# Union of every recording condition config, discriminated by ``name``.
RecordingConditionConfig = Annotated[
    IdInListConfig
    | PathInListConfig
    | HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | RecordingAllOfConfig
    | RecordingAnyOfConfig
    | RecordingNotConfig
    | RecordingConditionImportConfig,
    Field(discriminator="name"),
]
|
||||
|
||||
|
||||
def build_recording_condition(
    config: RecordingConditionConfig,
    base_dir: data.PathLike | None = None,
) -> RecordingCondition:
    """Build a recording condition from its config.

    Parameters
    ----------
    config : RecordingConditionConfig
        Condition config to build.
    base_dir : data.PathLike, optional
        Passed through to the registered builder.
    """
    condition = recording_conditions.build(config, base_dir)
    return condition
|
||||
236
src/batdetect2/data/conditions/sound_events.py
Normal file
236
src/batdetect2/data/conditions/sound_events.py
Normal file
@ -0,0 +1,236 @@
|
||||
from collections.abc import Callable, Sequence
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
from batdetect2.core.registries import (
|
||||
ImportConfig,
|
||||
Registry,
|
||||
add_import_config,
|
||||
)
|
||||
from batdetect2.data.conditions.common import (
|
||||
HasAllTagsConfig,
|
||||
HasAnyTagConfig,
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
MultiConditionConfigBase,
|
||||
NotConditionConfigBase,
|
||||
register_all_of_condition,
|
||||
register_any_of_condition,
|
||||
register_has_all_tags_condition,
|
||||
register_has_any_tag_condition,
|
||||
register_has_tag_condition,
|
||||
register_id_in_list_condition,
|
||||
register_not_condition,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AllOfConfig",
|
||||
"AnyOfConfig",
|
||||
"DurationConfig",
|
||||
"FrequencyConfig",
|
||||
"HasAllTagsConfig",
|
||||
"HasAnyTagConfig",
|
||||
"HasTagConfig",
|
||||
"NotConfig",
|
||||
"Operator",
|
||||
"SoundEventCondition",
|
||||
"SoundEventConditionConfig",
|
||||
"SoundEventConditionImportConfig",
|
||||
"build_sound_event_condition",
|
||||
"filter_clip_annotation",
|
||||
]
|
||||
|
||||
# A sound event condition is a predicate over one annotation.
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]

# Registry of sound-event-condition builders; each builder receives the
# config plus an optional base directory.
sound_event_conditions: Registry[
    SoundEventCondition,
    [data.PathLike | None],
] = Registry("sound_event_condition")
|
||||
|
||||
|
||||
@add_import_config(sound_event_conditions, arg_names=["base_dir"])
class SoundEventConditionImportConfig(ImportConfig):
    """Use any callable as a sound event condition.

    Set ``name="import"`` and provide a ``target`` pointing to any
    callable to use it instead of a built-in option.
    """

    # Discriminator value identifying this config type.
    name: Literal["import"] = "import"
|
||||
|
||||
|
||||
# Make the shared tag/id conditions available for sound events.
register_has_tag_condition(sound_event_conditions, HasTagConfig)
register_has_all_tags_condition(sound_event_conditions, HasAllTagsConfig)
register_has_any_tag_condition(sound_event_conditions, HasAnyTagConfig)
register_id_in_list_condition(sound_event_conditions, IdInListConfig)
|
||||
|
||||
|
||||
# Comparison operators accepted by the numeric conditions below.
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||
|
||||
|
||||
class DurationConfig(BaseConfig):
    """Config for a condition comparing a sound event's duration."""

    # Discriminator value identifying this config type.
    name: Literal["duration"] = "duration"
    # Comparison applied as ``duration <operator> seconds``.
    operator: Operator
    # Threshold the duration is compared against.
    seconds: float
|
||||
|
||||
|
||||
def _build_comparator(
|
||||
operator: Operator, value: float
|
||||
) -> Callable[[float], bool]:
|
||||
if operator == "gt":
|
||||
return lambda x: x > value
|
||||
|
||||
if operator == "gte":
|
||||
return lambda x: x >= value
|
||||
|
||||
if operator == "lt":
|
||||
return lambda x: x < value
|
||||
|
||||
if operator == "lte":
|
||||
return lambda x: x <= value
|
||||
|
||||
if operator == "eq":
|
||||
return lambda x: x == value
|
||||
|
||||
raise ValueError(f"Invalid operator {operator}")
|
||||
|
||||
|
||||
class Duration:
    """Predicate comparing a sound event's duration to a threshold.

    Parameters
    ----------
    operator : Operator
        Comparison to apply.
    seconds : float
        Threshold the duration is compared against.
    """

    def __init__(self, operator: Operator, seconds: float):
        self.operator = operator
        self.seconds = seconds
        self._comparator = _build_comparator(self.operator, self.seconds)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return the comparison result; False when there is no geometry."""
        geometry = sound_event_annotation.sound_event.geometry

        if geometry is None:
            return False

        # Bounds are unpacked as (start, low, end, high).
        bounds = compute_bounds(geometry)
        start_time, _, end_time, _ = bounds
        return self._comparator(end_time - start_time)

    @sound_event_conditions.register(DurationConfig)
    @staticmethod
    def from_config(
        config: DurationConfig,
        base_dir: data.PathLike | None = None,
    ):
        """Build the condition from its config; ``base_dir`` is unused."""
        del base_dir
        return Duration(operator=config.operator, seconds=config.seconds)
|
||||
|
||||
|
||||
class FrequencyConfig(BaseConfig):
    """Config for a condition comparing a sound event's frequency bound."""

    # Discriminator value identifying this config type.
    name: Literal["frequency"] = "frequency"
    # Which frequency bound of the geometry is compared.
    boundary: Literal["low", "high"]
    # Comparison applied as ``bound <operator> hertz``.
    operator: Operator
    # Threshold the chosen bound is compared against.
    hertz: float
|
||||
|
||||
|
||||
class Frequency:
    """Predicate comparing a sound event's frequency bound to a threshold.

    Parameters
    ----------
    operator : Operator
        Comparison to apply.
    boundary : {"low", "high"}
        Which frequency bound of the geometry to compare.
    hertz : float
        Threshold the chosen bound is compared against.
    """

    def __init__(
        self,
        operator: Operator,
        boundary: Literal["low", "high"],
        hertz: float,
    ):
        self.operator = operator
        self.hertz = hertz
        self.boundary = boundary
        self._comparator = _build_comparator(self.operator, self.hertz)

    def __call__(
        self,
        sound_event_annotation: data.SoundEventAnnotation,
    ) -> bool:
        """Return the comparison result.

        Returns False when the sound event has no geometry or its
        geometry is a ``TimeInterval`` or ``TimeStamp``.
        """
        geometry = sound_event_annotation.sound_event.geometry
        time_only = (data.TimeInterval, data.TimeStamp)

        if geometry is None or isinstance(geometry, time_only):
            return False

        # Bounds are unpacked as (start, low, end, high).
        _, low_freq, _, high_freq = compute_bounds(geometry)
        chosen = low_freq if self.boundary == "low" else high_freq
        return self._comparator(chosen)

    @sound_event_conditions.register(FrequencyConfig)
    @staticmethod
    def from_config(
        config: FrequencyConfig,
        base_dir: data.PathLike | None = None,
    ):
        """Build the condition from its config; ``base_dir`` is unused."""
        del base_dir
        return Frequency(
            operator=config.operator,
            boundary=config.boundary,
            hertz=config.hertz,
        )
|
||||
|
||||
|
||||
@register_all_of_condition(sound_event_conditions)
class AllOfConfig(MultiConditionConfigBase):
    """Config for a sound event condition requiring all nested ones."""

    # Discriminator value identifying this config type.
    name: Literal["all_of"] = "all_of"
    # Nested sound event condition configs.
    conditions: Sequence["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
@register_any_of_condition(sound_event_conditions)
class AnyOfConfig(MultiConditionConfigBase):
    """Config for a sound event condition requiring any nested one."""

    # Discriminator value identifying this config type.
    name: Literal["any_of"] = "any_of"
    # Nested sound event condition configs. Typed as ``Sequence`` to
    # match the base class and the sibling all-of / recording configs.
    conditions: Sequence["SoundEventConditionConfig"]
|
||||
|
||||
|
||||
@register_not_condition(sound_event_conditions)
class NotConfig(NotConditionConfigBase):
    """Config for a sound event condition negating a nested condition."""

    # Discriminator value identifying this config type.
    name: Literal["not"] = "not"
    # Nested sound event condition config to negate.
    condition: "SoundEventConditionConfig"
|
||||
|
||||
|
||||
# Union of every sound event condition config, discriminated by ``name``.
SoundEventConditionConfig = Annotated[
    IdInListConfig
    | HasTagConfig
    | HasAllTagsConfig
    | HasAnyTagConfig
    | DurationConfig
    | FrequencyConfig
    | AllOfConfig
    | AnyOfConfig
    | NotConfig
    | SoundEventConditionImportConfig,
    Field(discriminator="name"),
]
|
||||
|
||||
|
||||
def build_sound_event_condition(
    config: SoundEventConditionConfig,
    base_dir: data.PathLike | None = None,
) -> SoundEventCondition:
    """Build a sound event condition from its config.

    Parameters
    ----------
    config : SoundEventConditionConfig
        Condition config to build.
    base_dir : data.PathLike, optional
        Passed through to the registered builder.
    """
    condition = sound_event_conditions.build(config, base_dir)
    return condition
|
||||
|
||||
|
||||
def filter_clip_annotation(
    clip_annotation: data.ClipAnnotation,
    condition: SoundEventCondition,
) -> data.ClipAnnotation:
    """Return a copy of the clip annotation with filtered sound events.

    Parameters
    ----------
    clip_annotation : data.ClipAnnotation
        Annotation whose ``sound_events`` are filtered; not modified.
    condition : SoundEventCondition
        Sound events for which this returns a truthy value are kept.
    """
    kept = list(filter(condition, clip_annotation.sound_events))
    return clip_annotation.model_copy(update={"sound_events": kept})
|
||||
@ -32,7 +32,9 @@ from batdetect2.data.annotations import (
|
||||
load_annotated_dataset,
|
||||
)
|
||||
from batdetect2.data.conditions import (
|
||||
ClipAnnotationConditionConfig,
|
||||
SoundEventConditionConfig,
|
||||
build_clip_annotation_condition,
|
||||
build_sound_event_condition,
|
||||
filter_clip_annotation,
|
||||
)
|
||||
@ -69,6 +71,7 @@ class DatasetConfig(BaseConfig):
|
||||
description: str
|
||||
sources: list[AnnotationFormats]
|
||||
|
||||
clip_filter: ClipAnnotationConditionConfig | None = None
|
||||
sound_event_filter: SoundEventConditionConfig | None = None
|
||||
sound_event_transforms: list[SoundEventTransformConfig] = Field(
|
||||
default_factory=list
|
||||
@ -84,11 +87,58 @@ def load_dataset(
|
||||
apply_transforms: bool = True,
|
||||
apply_filters: bool = True,
|
||||
) -> Dataset:
|
||||
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||
"""Load and merge clip annotations from configured dataset sources.
|
||||
|
||||
Loads each source listed in ``config.sources`` and returns a flat
|
||||
collection of ``soundevent.data.ClipAnnotation`` objects. Source tags,
|
||||
dataset-level filters, and dataset-level transforms can be enabled or
|
||||
disabled with flags.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
config : DatasetConfig
|
||||
Dataset definition containing source configurations, optional
|
||||
clip-level filter, sound-event filter, and optional sound-event
|
||||
transform pipeline.
|
||||
base_dir : data.PathLike, optional
|
||||
Base directory used to resolve relative paths in source
|
||||
configurations.
|
||||
add_source_tag : bool, default=True
|
||||
If True, append a ``data_source`` tag to each clip annotation with
|
||||
the source name.
|
||||
include_sources : list[str], optional
|
||||
Source names to include. If None, all sources are eligible.
|
||||
exclude_sources : list[str], optional
|
||||
Source names to skip after include filtering. If a source appears in
|
||||
both include and exclude lists, it is skipped.
|
||||
apply_transforms : bool, default=True
|
||||
If True, apply transforms defined in
|
||||
``config.sound_event_transforms``.
|
||||
apply_filters : bool, default=True
|
||||
If True, apply filters defined in ``config.clip_filter`` and
|
||||
``config.sound_event_filter``.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataset
|
||||
Flat collection of clip annotations loaded from the selected sources.
|
||||
"""
|
||||
clip_annotations = []
|
||||
|
||||
condition = (
|
||||
build_sound_event_condition(config.sound_event_filter)
|
||||
clip_condition = (
|
||||
build_clip_annotation_condition(
|
||||
config.clip_filter,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
if config.clip_filter is not None
|
||||
else None
|
||||
)
|
||||
|
||||
sound_event_condition = (
|
||||
build_sound_event_condition(
|
||||
config.sound_event_filter,
|
||||
base_dir=base_dir,
|
||||
)
|
||||
if config.sound_event_filter is not None
|
||||
else None
|
||||
)
|
||||
@ -123,10 +173,17 @@ def load_dataset(
|
||||
if add_source_tag:
|
||||
clip_annotation = insert_source_tag(clip_annotation, source)
|
||||
|
||||
if condition is not None and apply_filters:
|
||||
if (
|
||||
clip_condition is not None
|
||||
and apply_filters
|
||||
and not clip_condition(clip_annotation)
|
||||
):
|
||||
continue
|
||||
|
||||
if sound_event_condition is not None and apply_filters:
|
||||
clip_annotation = filter_clip_annotation(
|
||||
clip_annotation,
|
||||
condition,
|
||||
sound_event_condition,
|
||||
)
|
||||
|
||||
if transform is not None and apply_transforms:
|
||||
@ -181,47 +238,58 @@ def load_dataset_from_config(
|
||||
path: data.PathLike,
|
||||
field: str | None = None,
|
||||
base_dir: data.PathLike | None = None,
|
||||
add_source_tag: bool = True,
|
||||
include_sources: list[str] | None = None,
|
||||
exclude_sources: list[str] | None = None,
|
||||
apply_transforms: bool = True,
|
||||
apply_filters: bool = True,
|
||||
) -> Dataset:
|
||||
"""Load dataset annotation metadata from a configuration file.
|
||||
"""Load a dataset by reading a ``DatasetConfig`` from disk.
|
||||
|
||||
This is a convenience function that first loads the `DatasetConfig` from
|
||||
the specified file path and optional nested field, and then calls
|
||||
`load_dataset` to load all corresponding `ClipAnnotation` objects.
|
||||
This convenience wrapper first loads a ``DatasetConfig`` from ``path``
|
||||
and optional ``field``, then delegates to :func:`load_dataset`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
path : data.PathLike
|
||||
Path to the configuration file (e.g., YAML).
|
||||
Path to a configuration file containing a ``DatasetConfig``.
|
||||
field : str, optional
|
||||
Dot-separated path to a nested section within the file containing the
|
||||
dataset configuration (e.g., "data.training_set"). If None, the
|
||||
entire file content is assumed to be the `DatasetConfig`.
|
||||
base_dir : Path, optional
|
||||
An optional base directory path to resolve relative paths within the
|
||||
configuration sources. Passed to `load_dataset`. Defaults to None.
|
||||
Dot-separated field path to a nested config section. If None, the
|
||||
full file is parsed as ``DatasetConfig``.
|
||||
base_dir : data.PathLike, optional
|
||||
Base directory used to resolve relative paths in source
|
||||
configurations.
|
||||
add_source_tag : bool, default=True
|
||||
If True, append a ``data_source`` tag to each clip annotation.
|
||||
include_sources : list[str], optional
|
||||
Source names to include. If None, all sources are eligible.
|
||||
exclude_sources : list[str], optional
|
||||
Source names to skip after include filtering.
|
||||
apply_transforms : bool, default=True
|
||||
If True, apply transforms defined in the loaded config.
|
||||
apply_filters : bool, default=True
|
||||
If True, apply clip and sound-event filters defined in the loaded
|
||||
config.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dataset (List[data.ClipAnnotation])
|
||||
A flat list containing all loaded `ClipAnnotation` metadata objects.
|
||||
|
||||
Raises
|
||||
------
|
||||
FileNotFoundError
|
||||
If the config file `path` does not exist.
|
||||
yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
|
||||
If the configuration file is invalid, cannot be parsed, or does not
|
||||
match the `DatasetConfig` schema.
|
||||
Exception
|
||||
Can raise exceptions from `load_dataset` if loading data from sources
|
||||
fails.
|
||||
Dataset
|
||||
Flat collection of clip annotations loaded from the selected sources.
|
||||
"""
|
||||
config = load_config(
|
||||
path=path,
|
||||
schema=DatasetConfig,
|
||||
field=field,
|
||||
)
|
||||
return load_dataset(config, base_dir=base_dir)
|
||||
return load_dataset(
|
||||
config,
|
||||
base_dir=base_dir,
|
||||
add_source_tag=add_source_tag,
|
||||
include_sources=include_sources,
|
||||
exclude_sources=exclude_sources,
|
||||
apply_transforms=apply_transforms,
|
||||
apply_filters=apply_filters,
|
||||
)
|
||||
|
||||
|
||||
def save_dataset(
|
||||
|
||||
@ -1,9 +1,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import sys
|
||||
from collections.abc import Callable
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Annotated,
|
||||
Any,
|
||||
Dict,
|
||||
@ -13,21 +16,23 @@ from typing import (
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lightning.pytorch.loggers import (
|
||||
CSVLogger,
|
||||
Logger,
|
||||
MLFlowLogger,
|
||||
TensorBoardLogger,
|
||||
)
|
||||
from loguru import logger
|
||||
from matplotlib.figure import Figure
|
||||
from pydantic import Field
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core.configs import BaseConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from lightning.pytorch.loggers import (
|
||||
CSVLogger,
|
||||
Logger,
|
||||
MLFlowLogger,
|
||||
TensorBoardLogger,
|
||||
)
|
||||
from matplotlib.figure import Figure
|
||||
from soundevent import data
|
||||
|
||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||
|
||||
__all__ = [
|
||||
@ -271,10 +276,16 @@ def build_logger(
|
||||
)
|
||||
|
||||
|
||||
PlotLogger = Callable[[str, Figure, int], None]
|
||||
PlotLogger = Callable[[str, "Figure", int], None]
|
||||
|
||||
|
||||
def get_image_logger(logger: Logger) -> PlotLogger | None:
|
||||
from lightning.pytorch.loggers import (
|
||||
CSVLogger,
|
||||
MLFlowLogger,
|
||||
TensorBoardLogger,
|
||||
)
|
||||
|
||||
if isinstance(logger, TensorBoardLogger):
|
||||
return logger.experiment.add_figure
|
||||
|
||||
@ -296,10 +307,16 @@ def get_image_logger(logger: Logger) -> PlotLogger | None:
|
||||
return partial(save_figure, dir=Path(logger.log_dir))
|
||||
|
||||
|
||||
TableLogger = Callable[[str, pd.DataFrame, int], None]
|
||||
TableLogger = Callable[[str, "pd.DataFrame", int], None]
|
||||
|
||||
|
||||
def get_table_logger(logger: Logger) -> TableLogger | None:
|
||||
from lightning.pytorch.loggers import (
|
||||
CSVLogger,
|
||||
MLFlowLogger,
|
||||
TensorBoardLogger,
|
||||
)
|
||||
|
||||
if isinstance(logger, TensorBoardLogger):
|
||||
return partial(save_table, dir=Path(logger.log_dir))
|
||||
|
||||
@ -337,6 +354,8 @@ def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
|
||||
|
||||
|
||||
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
|
||||
import numpy as np
|
||||
|
||||
with io.BytesIO() as buff:
|
||||
figure.savefig(buff, format="raw")
|
||||
buff.seek(0)
|
||||
|
||||
@ -15,11 +15,11 @@ GENERIC_CLASS_KEY = "class"
|
||||
|
||||
|
||||
data_source = data.Term(
|
||||
name="soundevent:data_source",
|
||||
label="Data Source",
|
||||
name="dcterms:source",
|
||||
label="Source",
|
||||
uri="http://purl.org/dc/terms/source",
|
||||
definition=(
|
||||
"A unique identifier for the source of the data, typically "
|
||||
"representing the project, site, or deployment context."
|
||||
"A related resource from which the described resource is derived."
|
||||
),
|
||||
)
|
||||
|
||||
@ -45,6 +45,17 @@ individual = data.Term(
|
||||
)
|
||||
"""Term used for tags identifying a specific individual animal."""
|
||||
|
||||
dataset_split = data.Term(
|
||||
name="batdetect2:split",
|
||||
label="Dataset Split",
|
||||
definition=(
|
||||
"Identifies the specific data partition (e.g., 'train', 'test') "
|
||||
"that the item belongs to within an experimental setup. "
|
||||
"The expected value is a literal text string."
|
||||
),
|
||||
)
|
||||
"""Custom metadata term defining the machine learning partition of an item."""
|
||||
|
||||
generic_class = data.Term(
|
||||
name="soundevent:class",
|
||||
label="Class",
|
||||
|
||||
@ -8,12 +8,10 @@ import torch
|
||||
from soundevent.geometry import compute_bounds
|
||||
|
||||
from batdetect2.api_v2 import BatDetect2API
|
||||
from batdetect2.audio import AudioConfig
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.inference import InferenceConfig
|
||||
from batdetect2.models.detectors import Detector
|
||||
from batdetect2.models.heads import ClassifierHead
|
||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||
from batdetect2.train import load_model_from_checkpoint
|
||||
from batdetect2.train.lightning import build_training_module
|
||||
|
||||
|
||||
@ -48,6 +46,7 @@ def test_process_file_returns_recording_level_predictions(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_process_files_is_batch_size_invariant(
|
||||
api_v2: BatDetect2API,
|
||||
example_audio_files: list[Path],
|
||||
@ -182,6 +181,7 @@ def test_user_can_read_extracted_features_per_detection(
|
||||
assert all(vec.size > 0 for vec in feature_vectors)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_user_can_load_checkpoint_and_finetune(
|
||||
tmp_path: Path,
|
||||
example_annotations,
|
||||
@ -295,6 +295,7 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_user_can_finetune_only_heads(
|
||||
tmp_path: Path,
|
||||
example_annotations,
|
||||
@ -330,6 +331,7 @@ def test_user_can_finetune_only_heads(
|
||||
assert list(finetune_dir.rglob("*.ckpt"))
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||
api_v2: BatDetect2API,
|
||||
example_annotations,
|
||||
@ -416,6 +418,7 @@ def test_detection_threshold_override_changes_process_file_results(
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_detection_threshold_override_is_ephemeral_in_process_file(
|
||||
api_v2: BatDetect2API,
|
||||
example_audio_files: list[Path],
|
||||
@ -452,51 +455,3 @@ def test_detection_threshold_override_changes_spectrogram_results(
|
||||
)
|
||||
|
||||
assert len(strict_detections) <= len(default_detections)
|
||||
|
||||
|
||||
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
|
||||
"""User story: call-level overrides do not mutate resolved defaults."""
|
||||
|
||||
api = BatDetect2API.from_config(BatDetect2Config())
|
||||
|
||||
override_inference = InferenceConfig.model_validate(
|
||||
{"loader": {"batch_size": 7}}
|
||||
)
|
||||
override_audio = AudioConfig.model_validate({"samplerate": 384000})
|
||||
override_train = TrainingConfig.model_validate(
|
||||
{"trainer": {"max_epochs": 2}}
|
||||
)
|
||||
|
||||
captured_process: dict[str, object] = {}
|
||||
captured_train: dict[str, object] = {}
|
||||
|
||||
def fake_process_file_list(*args, **kwargs):
|
||||
captured_process["inference_config"] = kwargs["inference_config"]
|
||||
captured_process["audio_config"] = kwargs["audio_config"]
|
||||
return []
|
||||
|
||||
def fake_run_train(*args, **kwargs):
|
||||
captured_train["train_config"] = kwargs["train_config"]
|
||||
captured_train["audio_config"] = kwargs["audio_config"]
|
||||
captured_train["model_config"] = kwargs["model_config"]
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"batdetect2.api_v2.process_file_list", fake_process_file_list
|
||||
)
|
||||
monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train)
|
||||
|
||||
api.process_files(
|
||||
[], inference_config=override_inference, audio_config=override_audio
|
||||
)
|
||||
api.train([], train_config=override_train, audio_config=override_audio)
|
||||
|
||||
assert captured_process["inference_config"] is override_inference
|
||||
assert captured_process["audio_config"] is override_audio
|
||||
assert captured_train["train_config"] is override_train
|
||||
assert captured_train["audio_config"] is override_audio
|
||||
assert captured_train["model_config"] is api.model_config
|
||||
|
||||
assert api.inference_config.loader.batch_size != 7
|
||||
assert api.audio_config.samplerate != 384000
|
||||
assert api.train_config.trainer.max_epochs != 2
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from hypothesis import given, settings
|
||||
@ -10,6 +11,7 @@ from batdetect2.utils import audio_utils, detector_utils
|
||||
|
||||
@given(duration=st.floats(min_value=0.1, max_value=1))
|
||||
@settings(deadline=None)
|
||||
@pytest.mark.slow
|
||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||
@ -89,6 +91,7 @@ def test_pad_audio_without_fixed_size(duration: float):
|
||||
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
@settings(deadline=None)
|
||||
@pytest.mark.slow
|
||||
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
|
||||
duration: float,
|
||||
):
|
||||
|
||||
@ -3,11 +3,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
||||
"""User story: run legacy detect on example audio directory."""
|
||||
|
||||
@ -29,6 +31,7 @@ def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
||||
assert len(list(results_dir.glob("*.json"))) == 3
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cli_detect_command_with_non_trivial_time_expansion(
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
@ -52,6 +55,7 @@ def test_cli_detect_command_with_non_trivial_time_expansion(
|
||||
assert "Time Expansion Factor: 10" in result.stdout
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cli_detect_command_with_spec_feature_flag(tmp_path: Path) -> None:
|
||||
"""User story: request extra spectral features in output CSV."""
|
||||
|
||||
|
||||
@ -20,6 +20,7 @@ def test_cli_predict_help() -> None:
|
||||
assert "dataset" in result.output
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cli_predict_directory_runs_on_real_audio(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
|
||||
@ -2,6 +2,7 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
@ -19,6 +20,7 @@ def test_cli_train_help() -> None:
|
||||
assert "--model" in result.output
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cli_train_from_checkpoint_runs_on_small_dataset(
|
||||
tmp_path: Path,
|
||||
tiny_checkpoint_path: Path,
|
||||
|
||||
@ -2,11 +2,13 @@
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from batdetect2.cli import cli
|
||||
|
||||
runner = CliRunner()
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
|
||||
def test_can_process_jeff37_files(
|
||||
|
||||
303
tests/test_data/test_conditions/test_clip.py
Normal file
303
tests/test_data/test_conditions/test_clip.py
Normal file
@ -0,0 +1,303 @@
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import load_config
|
||||
from batdetect2.data.conditions import (
|
||||
ClipAnnotationConditionConfig,
|
||||
build_clip_annotation_condition,
|
||||
)
|
||||
|
||||
|
||||
def load_clip_condition_config(
    tmp_path: Path,
    yaml_string: str,
) -> ClipAnnotationConditionConfig:
    """Write a YAML snippet to a fresh file and load it as a clip config.

    The snippet is dedented and stripped, saved under ``tmp_path`` with a
    random hex file name, and parsed by ``load_config`` using a
    ``TypeAdapter`` over ``ClipAnnotationConditionConfig``.
    """
    yaml_text = textwrap.dedent(yaml_string).strip()
    target = tmp_path / f"{uuid.uuid4().hex}.yaml"
    target.write_text(yaml_text)
    adapter = TypeAdapter(ClipAnnotationConditionConfig)
    return load_config(target, schema=adapter)
|
||||
|
||||
|
||||
def build_clip_condition_from_yaml(
    tmp_path: Path,
    yaml_string: str,
    base_dir: Path | None = None,
):
    """Build a clip-annotation condition from an inline YAML snippet.

    The snippet is loaded with ``load_clip_condition_config`` and the
    resulting config is handed to ``build_clip_annotation_condition``
    together with ``base_dir``.
    """
    return build_clip_annotation_condition(
        load_clip_condition_config(tmp_path, yaml_string),
        base_dir=base_dir,
    )
|
||||
|
||||
|
||||
def test_recording_satisfies_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``recording_satisfies`` applies a recording condition to a clip."""
    listed_recording = create_recording(path=tmp_path / "a.wav")
    other_recording = create_recording(path=tmp_path / "b.wav")
    listed_clip = create_clip(listed_recording)
    other_clip = create_clip(other_recording)
    listed_annotation = create_clip_annotation(listed_clip)
    other_annotation = create_clip_annotation(other_clip)

    # Only the first recording's UUID goes into the ID list.
    id_file = tmp_path / "recording_ids.json"
    listed_ids = [str(listed_recording.uuid)]
    id_file.write_text(json.dumps(listed_ids))

    matches = build_clip_condition_from_yaml(
        tmp_path,
        f"""
        name: recording_satisfies
        condition:
            name: id_in_list
            path: {id_file}
        """,
    )

    assert matches(listed_annotation)
    assert not matches(other_annotation)
|
||||
|
||||
|
||||
def test_clip_id_in_list_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``id_in_list`` accepts only clip annotations whose UUID is listed."""
    first_recording = create_recording(path=tmp_path / "a.wav")
    second_recording = create_recording(path=tmp_path / "b.wav")
    first_clip = create_clip(first_recording)
    listed_annotation = create_clip_annotation(first_clip)
    second_clip = create_clip(second_recording)
    unlisted_annotation = create_clip_annotation(second_clip)

    # The list holds the clip annotation's own UUID, not the recording's.
    id_file = tmp_path / "clip_annotation_ids.json"
    listed_ids = [str(listed_annotation.uuid)]
    id_file.write_text(json.dumps(listed_ids))

    matches = build_clip_condition_from_yaml(
        tmp_path,
        f"""
        name: id_in_list
        path: {id_file}
        """,
    )

    assert matches(listed_annotation)
    assert not matches(unlisted_annotation)
|
||||
|
||||
|
||||
def test_clip_has_tag_conditions(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``has_tag`` accepts a clip annotation only if it carries the tag."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    split_train = data.Tag(key="split", value="train")

    source = create_recording(path=tmp_path / "rec.wav")
    tagged = create_clip_annotation(
        create_clip(source),
        clip_tags=[status_reviewed, split_train],
    )
    # Same recording, but without the ``status: reviewed`` tag.
    untagged = create_clip_annotation(
        create_clip(source),
        clip_tags=[split_train],
    )

    matches = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: has_tag
        tag:
            key: status
            value: reviewed
        """,
    )

    assert matches(tagged)
    assert not matches(untagged)
|
||||
|
||||
|
||||
def test_clip_has_all_tags_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``has_all_tags`` requires every listed tag on the clip annotation."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    split_train = data.Tag(key="split", value="train")

    source = create_recording(path=tmp_path / "rec.wav")
    first_clip = create_clip(source)
    with_both = create_clip_annotation(
        first_clip,
        clip_tags=[status_reviewed, split_train],
    )
    second_clip = create_clip(source)
    # Carries only one of the two required tags.
    with_one = create_clip_annotation(
        second_clip,
        clip_tags=[status_reviewed],
    )

    matches = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: has_all_tags
        tags:
            - key: status
              value: reviewed
            - key: split
              value: train
        """,
    )

    assert matches(with_both)
    assert not matches(with_one)
|
||||
|
||||
|
||||
def test_clip_has_any_tag_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``has_any_tag`` accepts a clip annotation with any listed tag."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    split_train = data.Tag(key="split", value="train")

    source = create_recording(path=tmp_path / "rec.wav")
    first_clip = create_clip(source)
    in_train = create_clip_annotation(
        first_clip,
        clip_tags=[status_reviewed, split_train],
    )
    second_clip = create_clip(source)
    # ``split: test`` is not among the tags the condition lists.
    in_test = create_clip_annotation(
        second_clip,
        clip_tags=[data.Tag(key="split", value="test")],
    )

    matches = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: has_any_tag
        tags:
            - key: split
              value: val
            - key: split
              value: train
        """,
    )

    assert matches(in_train)
    assert not matches(in_test)
|
||||
|
||||
|
||||
def test_clip_all_of_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``all_of`` accepts a clip annotation only if every sub-condition does."""
    status_reviewed = data.Tag(key="status", value="reviewed")
    split_train = data.Tag(key="split", value="train")

    source = create_recording(path=tmp_path / "rec.wav")
    passes_both = create_clip_annotation(
        create_clip(source),
        clip_tags=[status_reviewed, split_train],
    )
    # Has the status tag but no split tag, so ``has_any_tag`` fails.
    passes_one = create_clip_annotation(
        create_clip(source),
        clip_tags=[status_reviewed],
    )

    matches = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: all_of
        conditions:
            - name: has_tag
              tag:
                  key: status
                  value: reviewed
            - name: has_any_tag
              tags:
                  - key: split
                    value: train
                  - key: split
                    value: val
        """,
    )

    assert matches(passes_both)
    assert not matches(passes_one)
|
||||
|
||||
|
||||
def test_clip_any_of_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``any_of`` accepts a clip annotation if one sub-condition does."""
    status_reviewed = data.Tag(key="status", value="reviewed")

    source = create_recording(path=tmp_path / "rec.wav")
    first_clip = create_clip(source)
    reviewed_annotation = create_clip_annotation(
        first_clip,
        clip_tags=[status_reviewed],
    )
    second_clip = create_clip(source)
    # Matches neither ``split: val`` nor ``status: reviewed``.
    unchecked_annotation = create_clip_annotation(
        second_clip,
        clip_tags=[data.Tag(key="status", value="unchecked")],
    )

    matches = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: any_of
        conditions:
            - name: has_tag
              tag:
                  key: split
                  value: val
            - name: has_tag
              tag:
                  key: status
                  value: reviewed
        """,
    )

    assert matches(reviewed_annotation)
    assert not matches(unchecked_annotation)
|
||||
|
||||
|
||||
def test_clip_not_condition(
    tmp_path: Path,
    create_recording,
    create_clip,
    create_clip_annotation,
) -> None:
    """``not`` inverts the wrapped clip-annotation condition."""
    source = create_recording(path=tmp_path / "rec.wav")
    first_clip = create_clip(source)
    train_annotation = create_clip_annotation(
        first_clip,
        clip_tags=[data.Tag(key="split", value="train")],
    )
    second_clip = create_clip(source)
    # Carries exactly the tag the inner ``has_tag`` looks for.
    val_annotation = create_clip_annotation(
        second_clip,
        clip_tags=[data.Tag(key="split", value="val")],
    )

    matches = build_clip_condition_from_yaml(
        tmp_path,
        """
        name: not
        condition:
            name: has_tag
            tag:
                key: split
                value: val
        """,
    )

    assert matches(train_annotation)
    assert not matches(val_annotation)
|
||||
564
tests/test_data/test_conditions/test_recording.py
Normal file
564
tests/test_data/test_conditions/test_recording.py
Normal file
@ -0,0 +1,564 @@
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import load_config
|
||||
from batdetect2.data.conditions import (
|
||||
RecordingConditionConfig,
|
||||
build_recording_condition,
|
||||
)
|
||||
|
||||
|
||||
def load_recording_condition_config(
    tmp_path: Path,
    yaml_string: str,
) -> RecordingConditionConfig:
    """Write a YAML snippet to a fresh file and load it as a recording config.

    The snippet is dedented and stripped, saved under ``tmp_path`` with a
    random hex file name, and parsed by ``load_config`` using a
    ``TypeAdapter`` over ``RecordingConditionConfig``.
    """
    yaml_text = textwrap.dedent(yaml_string).strip()
    target = tmp_path / f"{uuid.uuid4().hex}.yaml"
    target.write_text(yaml_text)
    adapter = TypeAdapter(RecordingConditionConfig)
    return load_config(target, schema=adapter)
|
||||
|
||||
|
||||
def build_recording_condition_from_yaml(
    tmp_path: Path,
    yaml_string: str,
    base_dir: Path | None = None,
):
    """Build a recording condition from an inline YAML snippet.

    The snippet is loaded with ``load_recording_condition_config`` and the
    resulting config is handed to ``build_recording_condition`` together
    with ``base_dir``.
    """
    return build_recording_condition(
        load_recording_condition_config(tmp_path, yaml_string),
        base_dir=base_dir,
    )
|
||||
|
||||
|
||||
def test_id_in_list_condition(tmp_path: Path, create_recording) -> None:
    """``id_in_list`` accepts only recordings whose UUID is in the file."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    id_file = tmp_path / "recording_ids.json"
    listed_ids = [str(listed.uuid)]
    id_file.write_text(json.dumps(listed_ids))

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: id_in_list
        path: {id_file}
        """,
    )

    assert matches(listed)
    assert not matches(unlisted)
|
||||
|
||||
|
||||
def test_id_in_list_condition_uses_base_dir(
    tmp_path: Path,
    create_recording,
) -> None:
    """A relative ``path`` in ``id_in_list`` works when ``base_dir`` is given."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    splits_folder = tmp_path / "splits"
    splits_folder.mkdir()
    id_file = splits_folder / "train_ids.json"
    id_file.write_text(json.dumps([str(listed.uuid)]))

    # The YAML names the file relative to ``tmp_path``, passed as base_dir.
    matches = build_recording_condition_from_yaml(
        tmp_path,
        """
        name: id_in_list
        path: splits/train_ids.json
        """,
        base_dir=tmp_path,
    )

    assert matches(listed)
    assert not matches(unlisted)
|
||||
|
||||
|
||||
def test_id_in_list_condition_raises_for_non_list_json(
    tmp_path: Path,
) -> None:
    """Building ``id_in_list`` from a JSON object raises ``TypeError``."""
    id_file = tmp_path / "recording_ids.json"
    # A mapping instead of the list the default JSON format expects.
    not_a_list = {"id": "foo"}
    id_file.write_text(json.dumps(not_a_list))

    with pytest.raises(TypeError, match="Expected JSON list"):
        build_recording_condition_from_yaml(
            tmp_path,
            f"""
            name: id_in_list
            path: {id_file}
            """,
        )
|
||||
|
||||
|
||||
def test_id_in_list_condition_raises_for_invalid_id(tmp_path: Path) -> None:
    """Building ``id_in_list`` from a malformed ID raises ``ValueError``."""
    id_file = tmp_path / "recording_ids.json"
    malformed_ids = ["not-a-uuid"]
    id_file.write_text(json.dumps(malformed_ids))

    with pytest.raises(ValueError, match="Invalid ID"):
        build_recording_condition_from_yaml(
            tmp_path,
            f"""
            name: id_in_list
            path: {id_file}
            """,
        )
|
||||
|
||||
|
||||
def test_id_in_list_condition_supports_txt_format(
    tmp_path: Path,
    create_recording,
) -> None:
    """``id_in_list`` reads IDs from a text file when ``format: txt``."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    # One UUID followed by a newline.
    id_file = tmp_path / "recording_ids.txt"
    id_file.write_text(f"{listed.uuid}\n")

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: id_in_list
        path: {id_file}
        format: txt
        """,
    )

    assert matches(listed)
    assert not matches(unlisted)
|
||||
|
||||
|
||||
def test_id_in_list_condition_supports_json_field(
    tmp_path: Path,
    create_recording,
) -> None:
    """``id_in_list`` reads IDs from the JSON key named by ``field``."""
    train_recording = create_recording(path=tmp_path / "a.wav")
    val_recording = create_recording(path=tmp_path / "b.wav")

    # Each recording's UUID is stored under a different top-level key.
    ids_by_split = {
        "train": [str(train_recording.uuid)],
        "val": [str(val_recording.uuid)],
    }
    id_file = tmp_path / "recording_ids.json"
    id_file.write_text(json.dumps(ids_by_split))

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: id_in_list
        path: {id_file}
        format:
            name: json
            field: train
        """,
    )

    assert matches(train_recording)
    assert not matches(val_recording)
|
||||
|
||||
|
||||
def test_id_in_list_condition_supports_csv_column(
    tmp_path: Path,
    create_recording,
) -> None:
    """``id_in_list`` reads IDs from the CSV column named by ``column``."""
    listed = create_recording(path=tmp_path / "a.wav")
    unlisted = create_recording(path=tmp_path / "b.wav")

    # Header row followed by a single UUID row.
    csv_text = f"recording_uuid\n{listed.uuid}\n"
    id_file = tmp_path / "recording_ids.csv"
    id_file.write_text(csv_text)

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: id_in_list
        path: {id_file}
        format:
            name: csv
            column: recording_uuid
        """,
    )

    assert matches(listed)
    assert not matches(unlisted)
|
||||
|
||||
|
||||
def test_path_in_list_condition_supports_txt_format(
    tmp_path: Path,
    create_recording,
) -> None:
    """``path_in_list`` reads paths from a text file when ``format: txt``."""
    audio_folder = tmp_path / "audio"
    audio_folder.mkdir()
    listed = create_recording(path=audio_folder / "a.wav")
    unlisted = create_recording(path=audio_folder / "b.wav")

    # One recording path followed by a newline.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{listed.path}\n")

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        """,
    )

    assert matches(listed)
    assert not matches(unlisted)
|
||||
|
||||
|
||||
def test_path_in_list_condition_supports_json_field(
    tmp_path: Path,
    create_recording,
) -> None:
    """``path_in_list`` reads paths from the JSON key named by ``field``."""
    audio_folder = tmp_path / "audio"
    audio_folder.mkdir()
    train_recording = create_recording(path=audio_folder / "a.wav")
    val_recording = create_recording(path=audio_folder / "b.wav")

    # Each recording's path is stored under a different top-level key.
    paths_by_split = {
        "train": [str(train_recording.path)],
        "val": [str(val_recording.path)],
    }
    list_file = tmp_path / "recording_paths.json"
    list_file.write_text(json.dumps(paths_by_split))

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format:
            name: json
            field: train
        """,
    )

    assert matches(train_recording)
    assert not matches(val_recording)
|
||||
|
||||
|
||||
def test_path_in_list_condition_supports_csv_column(
    tmp_path: Path,
    create_recording,
) -> None:
    """``path_in_list`` reads paths from the CSV column named by ``column``."""
    audio_folder = tmp_path / "audio"
    audio_folder.mkdir()
    listed = create_recording(path=audio_folder / "a.wav")
    unlisted = create_recording(path=audio_folder / "b.wav")

    # Header row followed by a single path row.
    csv_text = f"recording_path\n{listed.path}\n"
    list_file = tmp_path / "recording_paths.csv"
    list_file.write_text(csv_text)

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format:
            name: csv
            column: recording_path
        """,
    )

    assert matches(listed)
    assert not matches(unlisted)
|
||||
|
||||
|
||||
def test_path_in_list_condition_uses_base_dir(
    tmp_path: Path,
    create_recording,
) -> None:
    """``path_in_list`` matches listed recordings when ``base_dir`` is set."""
    dataset_root = tmp_path / "dataset"
    audio_folder = dataset_root / "audio"
    audio_folder.mkdir(parents=True)
    listed = create_recording(path=audio_folder / "a.wav")
    unlisted = create_recording(path=audio_folder / "b.wav")

    # The list file sits next to the dataset folder, not inside it.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{listed.path}\n")

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        base_dir: {dataset_root}
        """,
    )

    assert matches(listed)
    assert not matches(unlisted)
|
||||
|
||||
|
||||
def test_path_in_list_condition_outside_allow(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: allow`` a recording outside ``base_dir`` passes."""
    dataset_root = tmp_path / "dataset"
    inner_folder = dataset_root / "audio"
    inner_folder.mkdir(parents=True)
    outer_folder = tmp_path / "other"
    outer_folder.mkdir()
    inside = create_recording(path=inner_folder / "a.wav")
    outside = create_recording(path=outer_folder / "x.wav")

    # The single listed path matches neither recording.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text("dataset/audio/unknown.wav\n")

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        base_dir: {dataset_root}
        on_outside: allow
        """,
    )

    assert matches(outside)
    assert not matches(inside)
|
||||
|
||||
|
||||
def test_path_in_list_condition_outside_warn(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: warn`` a recording outside ``base_dir`` passes."""
    dataset_root = tmp_path / "dataset"
    inner_folder = dataset_root / "audio"
    inner_folder.mkdir(parents=True)
    outer_folder = tmp_path / "other"
    outer_folder.mkdir()
    inside = create_recording(path=inner_folder / "a.wav")
    outside = create_recording(path=outer_folder / "x.wav")

    # The single listed path matches neither recording.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text("dataset/audio/unknown.wav\n")

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        base_dir: {dataset_root}
        on_outside: warn
        """,
    )

    # Only the return value is checked; warning output is not asserted.
    assert matches(outside)
    assert not matches(inside)
|
||||
|
||||
|
||||
def test_path_in_list_condition_outside_error(
    tmp_path: Path,
    create_recording,
) -> None:
    """With ``on_outside: error`` a recording outside ``base_dir`` raises."""
    dataset_root = tmp_path / "dataset"
    inner_folder = dataset_root / "audio"
    inner_folder.mkdir(parents=True)
    outer_folder = tmp_path / "other"
    outer_folder.mkdir()
    inside = create_recording(path=inner_folder / "a.wav")
    outside = create_recording(path=outer_folder / "x.wav")

    # Here the inside recording is listed, so it must match.
    list_file = tmp_path / "recording_paths.txt"
    list_file.write_text(f"{inside.path}\n")

    matches = build_recording_condition_from_yaml(
        tmp_path,
        f"""
        name: path_in_list
        path: {list_file}
        format: txt
        base_dir: {dataset_root}
        on_outside: error
        """,
    )

    assert matches(inside)
    # The error surfaces when the condition is evaluated, not when built.
    with pytest.raises(ValueError, match="outside"):
        matches(outside)
|
||||
|
||||
|
||||
def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
    """``has_tag`` accepts only recordings that carry the given tag."""
    split_train = data.Tag(key="split", value="train")
    split_val = data.Tag(key="split", value="val")

    train_recording = create_recording(
        path=tmp_path / "a.wav", tags=[split_train]
    )
    val_recording = create_recording(
        path=tmp_path / "b.wav", tags=[split_val]
    )

    matches = build_recording_condition_from_yaml(
        tmp_path,
        """
        name: has_tag
        tag:
            key: split
            value: train
        """,
    )

    assert matches(train_recording)
    assert not matches(val_recording)
|
||||
|
||||
|
||||
def test_has_all_tags_condition(tmp_path: Path, create_recording) -> None:
    """``has_all_tags`` requires every listed tag on the recording."""
    split_train = data.Tag(key="split", value="train")
    region_uk = data.Tag(key="region", value="uk")

    with_both = create_recording(
        path=tmp_path / "a.wav", tags=[split_train, region_uk]
    )
    # Lacks the ``region: uk`` tag.
    with_one = create_recording(
        path=tmp_path / "b.wav", tags=[split_train]
    )

    matches = build_recording_condition_from_yaml(
        tmp_path,
        """
        name: has_all_tags
        tags:
            - key: split
              value: train
            - key: region
              value: uk
        """,
    )

    assert matches(with_both)
    assert not matches(with_one)
|
||||
|
||||
|
||||
def test_has_any_tag_condition(tmp_path: Path, create_recording) -> None:
    """``has_any_tag`` accepts a recording carrying any listed tag."""
    region_uk = data.Tag(key="region", value="uk")
    region_us = data.Tag(key="region", value="us")

    uk_recording = create_recording(
        path=tmp_path / "a.wav", tags=[region_uk]
    )
    # ``region: us`` is not among the tags the condition lists.
    us_recording = create_recording(
        path=tmp_path / "b.wav", tags=[region_us]
    )

    matches = build_recording_condition_from_yaml(
        tmp_path,
        """
        name: has_any_tag
        tags:
            - key: region
              value: eu
            - key: region
              value: uk
        """,
    )

    assert matches(uk_recording)
    assert not matches(us_recording)
|
||||
|
||||
|
||||
def test_all_of_condition(tmp_path: Path, create_recording) -> None:
    """``all_of`` accepts a recording only if every sub-condition does."""
    split_train = data.Tag(key="split", value="train")
    region_uk = data.Tag(key="region", value="uk")
    region_us = data.Tag(key="region", value="us")

    train_uk = create_recording(
        path=tmp_path / "a.wav", tags=[split_train, region_uk]
    )
    # Passes ``has_tag`` but fails ``has_any_tag`` (region is ``us``).
    train_us = create_recording(
        path=tmp_path / "b.wav", tags=[split_train, region_us]
    )

    matches = build_recording_condition_from_yaml(
        tmp_path,
        """
        name: all_of
        conditions:
            - name: has_tag
              tag:
                  key: split
                  value: train
            - name: has_any_tag
              tags:
                  - key: region
                    value: eu
                  - key: region
                    value: uk
        """,
    )

    assert matches(train_uk)
    assert not matches(train_us)
|
||||
|
||||
|
||||
def test_any_of_condition(tmp_path: Path, create_recording) -> None:
    """``any_of`` accepts a recording if at least one sub-condition does."""
    split_train = data.Tag(key="split", value="train")
    region_us = data.Tag(key="region", value="us")

    train_recording = create_recording(
        path=tmp_path / "a.wav", tags=[split_train]
    )
    # Matches neither ``region: eu`` nor ``split: train``.
    us_recording = create_recording(
        path=tmp_path / "b.wav", tags=[region_us]
    )

    matches = build_recording_condition_from_yaml(
        tmp_path,
        """
        name: any_of
        conditions:
            - name: has_tag
              tag:
                  key: region
                  value: eu
            - name: has_tag
              tag:
                  key: split
                  value: train
        """,
    )

    assert matches(train_recording)
    assert not matches(us_recording)
|
||||
|
||||
|
||||
def test_not_condition(tmp_path: Path, create_recording) -> None:
    """Check that ``not`` inverts the wrapped ``has_tag`` condition.

    The wrapped condition looks for ``region=us``, so the UK recording
    passes and the US recording is rejected.
    """
    region_uk = data.Tag(key="region", value="uk")
    region_us = data.Tag(key="region", value="us")

    matching = create_recording(
        path=tmp_path / "a.wav",
        tags=[region_uk],
    )
    non_matching = create_recording(
        path=tmp_path / "b.wav",
        tags=[region_us],
    )

    yaml_config = """
        name: not
        condition:
          name: has_tag
          tag:
            key: region
            value: us
        """
    condition = build_recording_condition_from_yaml(tmp_path, yaml_config)

    assert condition(matching)
    assert not condition(non_matching)
|
||||
400
tests/test_data/test_conditions/test_sound_events.py
Normal file
400
tests/test_data/test_conditions/test_sound_events.py
Normal file
@ -0,0 +1,400 @@
|
||||
import json
|
||||
import textwrap
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.core import load_config
|
||||
from batdetect2.data.conditions import (
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
|
||||
|
||||
def load_sound_event_condition_config(
    tmp_path: Path,
    yaml_string: str,
) -> SoundEventConditionConfig:
    """Write a YAML snippet to disk and load it as a condition config.

    The snippet is dedented and stripped, written to a freshly named
    ``.yaml`` file under ``tmp_path`` and then parsed with ``load_config``
    using a ``TypeAdapter`` for ``SoundEventConditionConfig``.
    """
    # A random hex name lets one test write several snippets into the same
    # temporary directory without overwriting earlier ones.
    target = tmp_path / f"{uuid.uuid4().hex}.yaml"
    yaml_text = textwrap.dedent(yaml_string).strip()
    target.write_text(yaml_text)

    adapter = TypeAdapter(SoundEventConditionConfig)
    return load_config(target, schema=adapter)
|
||||
|
||||
|
||||
def build_condition_from_str(
    tmp_path: Path,
    yaml_string: str,
    base_dir: Path | None = None,
):
    """Build a sound event condition from an inline YAML snippet.

    The snippet is loaded through ``load_sound_event_condition_config``
    and the resulting config is handed to ``build_sound_event_condition``
    together with ``base_dir``, which is forwarded unchanged.
    """
    return build_sound_event_condition(
        load_sound_event_condition_config(tmp_path, yaml_string),
        base_dir=base_dir,
    )
|
||||
|
||||
|
||||
def create_sound_event_annotation(
    recording: data.Recording,
    geometry: data.Geometry,
    tags: list[data.Tag] | None = None,
) -> data.SoundEventAnnotation:
    """Wrap a geometry on ``recording`` in a sound event annotation.

    When ``tags`` is ``None`` or empty the annotation gets an empty tag
    list.
    """
    sound_event = data.SoundEvent(recording=recording, geometry=geometry)
    annotation_tags = tags if tags else []
    return data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=annotation_tags,
    )
|
||||
|
||||
|
||||
def test_has_tag_condition(
    sound_event: data.SoundEvent, tmp_path: Path
) -> None:
    """Check that ``has_tag`` accepts only the configured species tag."""
    yaml_config = """
        name: has_tag
        tag:
          key: species
          value: Myotis myotis
        """
    condition = build_condition_from_str(tmp_path, yaml_config)

    target_species = data.Tag(key="species", value="Myotis myotis")
    other_species = data.Tag(key="species", value="Eptesicus fuscus")

    matching = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[target_species],
    )
    non_matching = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[other_species],
    )

    assert condition(matching)
    assert not condition(non_matching)
|
||||
|
||||
|
||||
def test_has_all_tags_condition(
    sound_event: data.SoundEvent,
    tmp_path: Path,
) -> None:
    """Check that ``has_all_tags`` needs every listed tag to be present.

    An annotation with both the species and event tags passes; one that
    has only the species tag is rejected.
    """
    yaml_config = """
        name: has_all_tags
        tags:
          - key: species
            value: Myotis myotis
          - key: event
            value: Echolocation
        """
    condition = build_condition_from_str(tmp_path, yaml_config)

    species = data.Tag(key="species", value="Myotis myotis")
    event = data.Tag(key="event", value="Echolocation")

    matching = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[species, event],
    )
    non_matching = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[data.Tag(key="species", value="Myotis myotis")],
    )

    assert condition(matching)
    assert not condition(non_matching)
|
||||
|
||||
|
||||
def test_has_any_tag_condition(
    sound_event: data.SoundEvent,
    tmp_path: Path,
) -> None:
    """Check that ``has_any_tag`` passes when one listed tag is present.

    An annotation carrying only the event tag passes; one carrying a
    different species and a different event is rejected.
    """
    yaml_config = """
        name: has_any_tag
        tags:
          - key: species
            value: Myotis myotis
          - key: event
            value: Echolocation
        """
    condition = build_condition_from_str(tmp_path, yaml_config)

    matching = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[data.Tag(key="event", value="Echolocation")],
    )
    unrelated_tags = [
        data.Tag(key="species", value="Eptesicus fuscus"),
        data.Tag(key="event", value="Social"),
    ]
    non_matching = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=unrelated_tags,
    )

    assert condition(matching)
    assert not condition(non_matching)
|
||||
|
||||
|
||||
def test_not_condition(sound_event: data.SoundEvent, tmp_path: Path) -> None:
    """Check that ``not`` inverts the wrapped ``has_tag`` condition.

    The wrapped condition looks for ``species=Myotis myotis``, so an
    annotation with a different species passes and one with that species
    is rejected.
    """
    yaml_config = """
        name: not
        condition:
          name: has_tag
          tag:
            key: species
            value: Myotis myotis
        """
    condition = build_condition_from_str(tmp_path, yaml_config)

    excluded_species = data.Tag(key="species", value="Myotis myotis")
    other_species = data.Tag(key="species", value="Eptesicus fuscus")

    matching = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[other_species],
    )
    non_matching = data.SoundEventAnnotation(
        sound_event=sound_event,
        tags=[excluded_species],
    )

    assert condition(matching)
    assert not condition(non_matching)
|
||||
|
||||
|
||||
def test_id_in_list_condition(
    sound_event: data.SoundEvent, tmp_path: Path
) -> None:
    """Check that ``id_in_list`` accepts only annotations listed by UUID.

    Two annotations are created from the same sound event; only the UUID
    of the first is written to the JSON id file referenced by the config.
    """
    passing = data.SoundEventAnnotation(sound_event=sound_event)
    failing = data.SoundEventAnnotation(sound_event=sound_event)

    ids_path = tmp_path / "sound_event_ids.json"
    ids_path.write_text(json.dumps([str(passing.uuid)]))

    # The absolute path is emitted as a JSON string, which is also a valid
    # YAML double-quoted scalar. An unquoted path would be misparsed if the
    # temporary directory contained YAML-significant sequences such as
    # " #" or ": ".
    quoted_path = json.dumps(str(ids_path))
    condition = build_condition_from_str(
        tmp_path,
        f"""
        name: id_in_list
        path: {quoted_path}
        """,
    )

    assert condition(passing)
    assert not condition(failing)
|
||||
|
||||
|
||||
def test_id_in_list_condition_uses_base_dir(
    sound_event: data.SoundEvent,
    tmp_path: Path,
) -> None:
    """Check ``id_in_list`` with a relative path and an explicit base_dir.

    The id file lives in ``tmp_path / "splits"`` while the config only
    names ``splits/sound_event_ids.json``; ``tmp_path`` is supplied as
    ``base_dir`` when building the condition.
    """
    listed = data.SoundEventAnnotation(sound_event=sound_event)
    unlisted = data.SoundEventAnnotation(sound_event=sound_event)

    splits_dir = tmp_path / "splits"
    splits_dir.mkdir()
    id_file = splits_dir / "sound_event_ids.json"
    id_file.write_text(json.dumps([str(listed.uuid)]))

    yaml_config = """
        name: id_in_list
        path: splits/sound_event_ids.json
        """
    condition = build_condition_from_str(
        tmp_path,
        yaml_config,
        base_dir=tmp_path,
    )

    assert condition(listed)
    assert not condition(unlisted)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "operator,seconds,passing_duration,failing_duration",
    [
        ("lt", 2, 1, 2),
        ("lte", 2, 2, 3),
        ("gt", 2, 3, 2),
        ("gte", 2, 2, 1),
        ("eq", 2, 2, 3),
    ],
)
def test_duration_condition(
    tmp_path: Path,
    recording: data.Recording,
    operator: str,
    seconds: int,
    passing_duration: int,
    failing_duration: int,
) -> None:
    """Check each ``duration`` operator against the ``seconds`` threshold.

    Every case supplies one duration that must satisfy the comparison and
    one that must not; both are turned into time intervals starting at 0.
    """
    yaml_config = f"""
        name: duration
        operator: {operator}
        seconds: {seconds}
        """
    condition = build_condition_from_str(tmp_path, yaml_config)

    def annotation_lasting(duration: int) -> data.SoundEventAnnotation:
        # Interval [0, duration] on the shared recording fixture.
        return create_sound_event_annotation(
            recording=recording,
            geometry=data.TimeInterval(coordinates=[0, duration]),
        )

    matching = annotation_lasting(passing_duration)
    non_matching = annotation_lasting(failing_duration)

    assert condition(matching)
    assert not condition(non_matching)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "boundary,operator,hertz,passing_bbox,failing_bbox",
    [
        ("high", "lt", 300, [0, 100, 1, 200], [0, 100, 1, 300]),
        ("high", "lte", 300, [0, 100, 1, 300], [0, 100, 1, 400]),
        ("high", "gt", 300, [0, 100, 1, 400], [0, 100, 1, 300]),
        ("high", "gte", 300, [0, 100, 1, 300], [0, 100, 1, 200]),
        ("high", "eq", 300, [0, 100, 1, 300], [0, 100, 1, 400]),
        ("low", "lt", 200, [0, 100, 1, 400], [0, 200, 1, 400]),
        ("low", "lte", 200, [0, 200, 1, 400], [0, 300, 1, 400]),
        ("low", "gt", 200, [0, 300, 1, 400], [0, 200, 1, 400]),
        ("low", "gte", 200, [0, 200, 1, 400], [0, 100, 1, 400]),
        ("low", "eq", 200, [0, 200, 1, 400], [0, 300, 1, 400]),
    ],
)
def test_frequency_condition(
    tmp_path: Path,
    recording: data.Recording,
    boundary: str,
    operator: str,
    hertz: int,
    passing_bbox: list[int],
    failing_bbox: list[int],
) -> None:
    """Check each ``frequency`` operator on both bounding box boundaries.

    Every case supplies one bounding box that must satisfy the comparison
    against ``hertz`` on the chosen boundary and one that must not.
    """
    yaml_config = f"""
        name: frequency
        boundary: {boundary}
        operator: {operator}
        hertz: {hertz}
        """
    condition = build_condition_from_str(tmp_path, yaml_config)

    def annotation_for(bbox: list[int]) -> data.SoundEventAnnotation:
        # Coordinates are converted to floats before building the box.
        return create_sound_event_annotation(
            recording=recording,
            geometry=data.BoundingBox(coordinates=list(map(float, bbox))),
        )

    matching = annotation_for(passing_bbox)
    non_matching = annotation_for(failing_bbox)

    assert condition(matching)
    assert not condition(non_matching)
|
||||
|
||||
|
||||
def test_frequency_condition_is_false_for_temporal_geometries(
    tmp_path: Path,
    recording: data.Recording,
) -> None:
    """Check that ``frequency`` rejects geometries without a frequency axis.

    A bounding box whose low edge equals the configured value passes,
    while purely temporal geometries (a time interval and a time stamp)
    are rejected.
    """
    condition = build_condition_from_str(
        tmp_path,
        """
        name: frequency
        boundary: low
        operator: eq
        hertz: 200
        """,
    )

    passing = create_sound_event_annotation(
        recording=recording,
        geometry=data.BoundingBox(coordinates=[0, 200, 1, 400]),
    )
    failing_interval = create_sound_event_annotation(
        recording=recording,
        geometry=data.TimeInterval(coordinates=[0, 3]),
    )
    # A time stamp is the other temporal-only geometry; covering it keeps
    # the test in line with its plural name.
    failing_stamp = create_sound_event_annotation(
        recording=recording,
        geometry=data.TimeStamp(coordinates=3),
    )

    assert condition(passing)
    assert not condition(failing_interval)
    assert not condition(failing_stamp)
|
||||
|
||||
|
||||
def test_has_all_tags_fails_if_empty(tmp_path: Path) -> None:
    """Check that an empty ``has_all_tags`` tag list is rejected.

    Building the condition must raise ``ValueError`` with a message
    matching "at least one tag".
    """
    yaml_config = """
        name: has_all_tags
        tags: []
        """
    with pytest.raises(ValueError, match="at least one tag"):
        build_condition_from_str(tmp_path, yaml_config)
|
||||
|
||||
|
||||
def test_all_of_condition(tmp_path: Path, recording: data.Recording) -> None:
    """Check that ``all_of`` needs every nested condition to be satisfied.

    The config combines a species tag check with a "shorter than one
    second" duration check. Each nested condition is made to fail in
    turn, so a regression in either branch is detected.
    """
    condition = build_condition_from_str(
        tmp_path,
        """
        name: all_of
        conditions:
          - name: has_tag
            tag:
              key: species
              value: Myotis myotis
          - name: duration
            operator: lt
            seconds: 1
        """,
    )

    passing = create_sound_event_annotation(
        recording=recording,
        geometry=data.TimeInterval(coordinates=[0, 0.5]),
        tags=[data.Tag(key="species", value="Myotis myotis")],
    )
    # Right species, but too long: fails the duration branch.
    failing_duration = create_sound_event_annotation(
        recording=recording,
        geometry=data.TimeInterval(coordinates=[0, 2]),
        tags=[data.Tag(key="species", value="Myotis myotis")],
    )
    # Short enough, but wrong species: fails the tag branch.
    failing_tag = create_sound_event_annotation(
        recording=recording,
        geometry=data.TimeInterval(coordinates=[0, 0.5]),
        tags=[data.Tag(key="species", value="Eptesicus fuscus")],
    )

    assert condition(passing)
    assert not condition(failing_duration)
    assert not condition(failing_tag)
|
||||
|
||||
|
||||
def test_any_of_condition(tmp_path: Path, recording: data.Recording) -> None:
    """Check that ``any_of`` passes when at least one nested condition holds.

    The config combines a species tag check with a "shorter than one
    second" duration check. Each nested condition is satisfied on its own
    once, and one annotation satisfies neither.
    """
    condition = build_condition_from_str(
        tmp_path,
        """
        name: any_of
        conditions:
          - name: has_tag
            tag:
              key: species
              value: Myotis myotis
          - name: duration
            operator: lt
            seconds: 1
        """,
    )

    # Long event with the right species: only the tag branch holds.
    passing_by_tag = create_sound_event_annotation(
        recording=recording,
        geometry=data.TimeInterval(coordinates=[0, 2]),
        tags=[data.Tag(key="species", value="Myotis myotis")],
    )
    # Short event with another species: only the duration branch holds.
    passing_by_duration = create_sound_event_annotation(
        recording=recording,
        geometry=data.TimeInterval(coordinates=[0, 0.5]),
        tags=[data.Tag(key="species", value="Eptesicus fuscus")],
    )
    # Long event with another species: neither branch holds.
    failing = create_sound_event_annotation(
        recording=recording,
        geometry=data.TimeInterval(coordinates=[0, 2]),
        tags=[data.Tag(key="species", value="Eptesicus fuscus")],
    )

    assert condition(passing_by_tag)
    assert condition(passing_by_duration)
    assert not condition(failing)
|
||||
100
tests/test_data/test_datasets.py
Normal file
100
tests/test_data/test_datasets.py
Normal file
@ -0,0 +1,100 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data import DatasetConfig, load_dataset
|
||||
from batdetect2.data.conditions import (
|
||||
HasTagConfig,
|
||||
IdInListConfig,
|
||||
RecordingSatisfiesConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_load_dataset_applies_clip_filter(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """Check that a configured ``clip_filter`` restricts the loaded clips.

    The recording UUID of the first unfiltered clip is written to an id
    list; with that list as the clip filter, exactly one clip from that
    recording must be loaded.
    """
    unfiltered = list(load_dataset(example_dataset))
    kept_recording_id = str(unfiltered[0].clip.recording.uuid)

    id_file = tmp_path / "train_ids.json"
    id_file.write_text(json.dumps([kept_recording_id]))

    clip_filter = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=id_file)
    )
    config = example_dataset.model_copy(update={"clip_filter": clip_filter})

    remaining = list(load_dataset(config))

    assert len(remaining) == 1
    assert str(remaining[0].clip.recording.uuid) == kept_recording_id
|
||||
|
||||
|
||||
def test_load_dataset_clip_filter_is_skipped_when_filters_disabled(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """Check that ``apply_filters=False`` ignores the ``clip_filter``.

    The config carries a filter that lists a single recording id, yet
    loading with filters disabled must return as many clips as the
    unfiltered dataset.
    """
    unfiltered = list(load_dataset(example_dataset))
    kept_recording_id = str(unfiltered[0].clip.recording.uuid)

    id_file = tmp_path / "train_ids.json"
    id_file.write_text(json.dumps([kept_recording_id]))

    clip_filter = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=id_file)
    )
    config = example_dataset.model_copy(update={"clip_filter": clip_filter})

    loaded = list(load_dataset(config, apply_filters=False))

    assert len(loaded) == len(unfiltered)
|
||||
|
||||
|
||||
def test_load_dataset_resolves_clip_filter_paths_from_base_dir(
    example_dataset: DatasetConfig,
    tmp_path: Path,
) -> None:
    """Check a relative ``clip_filter`` path together with ``base_dir``.

    The id list is stored in ``tmp_path / "splits"`` but the filter only
    names ``splits/train_ids.json``; ``tmp_path`` is passed as
    ``base_dir`` and exactly the listed recording must be loaded.
    """
    unfiltered = list(load_dataset(example_dataset))
    kept_recording_id = str(unfiltered[0].clip.recording.uuid)

    splits_dir = tmp_path / "splits"
    splits_dir.mkdir()
    id_file = splits_dir / "train_ids.json"
    id_file.write_text(json.dumps([kept_recording_id]))

    relative_ids = Path("splits/train_ids.json")
    clip_filter = RecordingSatisfiesConfig(
        condition=IdInListConfig(path=relative_ids)
    )
    config = example_dataset.model_copy(update={"clip_filter": clip_filter})

    remaining = list(load_dataset(config, base_dir=tmp_path))

    assert len(remaining) == 1
    assert str(remaining[0].clip.recording.uuid) == kept_recording_id
|
||||
|
||||
|
||||
def test_sound_event_filter_keeps_empty_clips(
    example_dataset: DatasetConfig,
) -> None:
    """Check that ``sound_event_filter`` empties clips without dropping them.

    The filter asks for a species tag that no annotation carries. Every
    clip must still be returned, each with no sound events left.
    """
    # Derive the expected clip count from the unfiltered dataset instead of
    # hard-coding the fixture size, as the other tests in this module do.
    baseline = list(load_dataset(example_dataset))

    config = example_dataset.model_copy(
        update={
            "sound_event_filter": HasTagConfig(
                tag=data.Tag(key="species", value="__missing_species__")
            )
        }
    )

    filtered = list(load_dataset(config))

    # Guard against a vacuous pass on an empty fixture.
    assert baseline
    assert len(filtered) == len(baseline)
    assert all(
        len(clip_annotation.sound_events) == 0 for clip_annotation in filtered
    )
|
||||
@ -1,491 +0,0 @@
|
||||
import textwrap
|
||||
|
||||
import pytest
|
||||
import yaml
|
||||
from pydantic import TypeAdapter
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.data.conditions import (
|
||||
SoundEventConditionConfig,
|
||||
build_sound_event_condition,
|
||||
)
|
||||
|
||||
|
||||
def build_condition_from_str(content):
|
||||
content = textwrap.dedent(content)
|
||||
content = yaml.safe_load(content)
|
||||
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
|
||||
return build_sound_event_condition(config)
|
||||
|
||||
|
||||
def test_has_tag(sound_event: data.SoundEvent):
|
||||
condition = build_condition_from_str("""
|
||||
name: has_tag
|
||||
tag:
|
||||
key: species
|
||||
value: Myotis myotis
|
||||
""")
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_has_all_tags(sound_event: data.SoundEvent):
|
||||
condition = build_condition_from_str("""
|
||||
name: has_all_tags
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
- key: event
|
||||
value: Echolocation
|
||||
""")
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
data.Tag(key="sex", value="Female"),
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_has_any_tags(sound_event: data.SoundEvent):
|
||||
condition = build_condition_from_str("""
|
||||
name: has_any_tag
|
||||
tags:
|
||||
- key: species
|
||||
value: Myotis myotis
|
||||
- key: event
|
||||
value: Echolocation
|
||||
""")
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||
data.Tag(key="event", value="Social"),
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_not(sound_event: data.SoundEvent):
|
||||
condition = build_condition_from_str("""
|
||||
name: not
|
||||
condition:
|
||||
name: has_tag
|
||||
tag:
|
||||
key: species
|
||||
value: Myotis myotis
|
||||
""")
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert condition(sound_event_annotation)
|
||||
|
||||
sound_event_annotation = data.SoundEventAnnotation(
|
||||
sound_event=sound_event,
|
||||
tags=[
|
||||
data.Tag(key="species", value="Myotis myotis"),
|
||||
data.Tag(key="event", value="Echolocation"),
|
||||
],
|
||||
)
|
||||
assert not condition(sound_event_annotation)
|
||||
|
||||
|
||||
def test_duration(recording: data.Recording):
|
||||
se1 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 1])
|
||||
),
|
||||
)
|
||||
se2 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 2])
|
||||
),
|
||||
)
|
||||
se3 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording, geometry=data.TimeInterval(coordinates=[0, 3])
|
||||
),
|
||||
)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: lt
|
||||
seconds: 2
|
||||
""")
|
||||
assert condition(se1)
|
||||
assert not condition(se2)
|
||||
assert not condition(se3)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: lte
|
||||
seconds: 2
|
||||
""")
|
||||
|
||||
assert condition(se1)
|
||||
assert condition(se2)
|
||||
assert not condition(se3)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: gt
|
||||
seconds: 2
|
||||
""")
|
||||
|
||||
assert not condition(se1)
|
||||
assert not condition(se2)
|
||||
assert condition(se3)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: gte
|
||||
seconds: 2
|
||||
""")
|
||||
|
||||
assert not condition(se1)
|
||||
assert condition(se2)
|
||||
assert condition(se3)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: duration
|
||||
operator: eq
|
||||
seconds: 2
|
||||
""")
|
||||
|
||||
assert not condition(se1)
|
||||
assert condition(se2)
|
||||
assert not condition(se3)
|
||||
|
||||
|
||||
def test_frequency(recording: data.Recording):
|
||||
se12 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 100, 1, 200]),
|
||||
),
|
||||
)
|
||||
se13 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 100, 2, 300]),
|
||||
),
|
||||
)
|
||||
se14 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 100, 3, 400]),
|
||||
),
|
||||
)
|
||||
se24 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 200, 3, 400]),
|
||||
),
|
||||
)
|
||||
se34 = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
recording=recording,
|
||||
geometry=data.BoundingBox(coordinates=[0, 300, 3, 400]),
|
||||
),
|
||||
)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: lt
|
||||
hertz: 300
|
||||
""")
|
||||
assert condition(se12)
|
||||
assert not condition(se13)
|
||||
assert not condition(se14)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: lte
|
||||
hertz: 300
|
||||
""")
|
||||
|
||||
assert condition(se12)
|
||||
assert condition(se13)
|
||||
assert not condition(se14)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: gt
|
||||
hertz: 300
|
||||
""")
|
||||
|
||||
assert not condition(se12)
|
||||
assert not condition(se13)
|
||||
assert condition(se14)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: gte
|
||||
hertz: 300
|
||||
""")
|
||||
|
||||
assert not condition(se12)
|
||||
assert condition(se13)
|
||||
assert condition(se14)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: high
|
||||
operator: eq
|
||||
hertz: 300
|
||||
""")
|
||||
|
||||
assert not condition(se12)
|
||||
assert condition(se13)
|
||||
assert not condition(se14)
|
||||
|
||||
# LOW
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: lt
|
||||
hertz: 200
|
||||
""")
|
||||
assert condition(se14)
|
||||
assert not condition(se24)
|
||||
assert not condition(se34)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: lte
|
||||
hertz: 200
|
||||
""")
|
||||
|
||||
assert condition(se14)
|
||||
assert condition(se24)
|
||||
assert not condition(se34)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: gt
|
||||
hertz: 200
|
||||
""")
|
||||
|
||||
assert not condition(se14)
|
||||
assert not condition(se24)
|
||||
assert condition(se34)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: gte
|
||||
hertz: 200
|
||||
""")
|
||||
|
||||
assert not condition(se14)
|
||||
assert condition(se24)
|
||||
assert condition(se34)
|
||||
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: eq
|
||||
hertz: 200
|
||||
""")
|
||||
|
||||
assert not condition(se14)
|
||||
assert condition(se24)
|
||||
assert not condition(se34)
|
||||
|
||||
|
||||
def test_frequency_is_false_for_temporal_geometries(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: frequency
|
||||
boundary: low
|
||||
operator: eq
|
||||
hertz: 200
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 3]),
|
||||
recording=recording,
|
||||
)
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeStamp(coordinates=3),
|
||||
recording=recording,
|
||||
)
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
|
||||
def test_has_tags_fails_if_empty():
|
||||
with pytest.raises(ValueError):
|
||||
build_condition_from_str("""
|
||||
name: has_tags
|
||||
tags: []
|
||||
""")
|
||||
|
||||
|
||||
def test_all_of(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: all_of
|
||||
conditions:
|
||||
- name: has_tag
|
||||
tag:
|
||||
key: species
|
||||
value: Myotis myotis
|
||||
- name: duration
|
||||
operator: lt
|
||||
seconds: 1
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
|
||||
def test_any_of(recording: data.Recording):
|
||||
condition = build_condition_from_str("""
|
||||
name: any_of
|
||||
conditions:
|
||||
- name: has_tag
|
||||
tag:
|
||||
key: species
|
||||
value: Myotis myotis
|
||||
- name: duration
|
||||
operator: lt
|
||||
seconds: 1
|
||||
""")
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert not condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||
)
|
||||
assert condition(se)
|
||||
|
||||
se = data.SoundEventAnnotation(
|
||||
sound_event=data.SoundEvent(
|
||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||
recording=recording,
|
||||
),
|
||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||
)
|
||||
assert condition(se)
|
||||
@ -2,11 +2,14 @@
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from batdetect2 import api
|
||||
|
||||
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_no_detections_above_nyquist():
|
||||
"""Test that no detections are made above the nyquist frequency."""
|
||||
# Recording donated by @@kdarras
|
||||
|
||||
@ -4,6 +4,7 @@ from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from hypothesis import given, settings
|
||||
from hypothesis import strategies as st
|
||||
|
||||
@ -13,6 +14,7 @@ from batdetect2.detector import parameters
|
||||
|
||||
@settings(deadline=None, max_examples=5)
|
||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||
@pytest.mark.slow
|
||||
def test_can_import_model_without_pickle(duration: float):
|
||||
# NOTE: remove this test once no other issues are found. This is a temporary
|
||||
# test to check that change in model loading did not impact model behaviour
|
||||
@ -42,6 +44,7 @@ def test_can_import_model_without_pickle(duration: float):
|
||||
assert predictions_without_pickle == predictions_with_pickle
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_can_import_model_without_pickle_on_test_data(
|
||||
example_audio_files: List[Path],
|
||||
):
|
||||
|
||||
@ -96,7 +96,7 @@ def test_registry_build_unknown_name_raises():
|
||||
name = "NonExistentBackbone"
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
|
||||
backbone_registry.build(FakeConfig()) # ty: ignore[invalid-argument-type]
|
||||
|
||||
|
||||
def test_backbone_config_validates_unet_from_dict():
|
||||
|
||||
@ -1,7 +1,6 @@
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from soundevent import data, terms
|
||||
|
||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||
|
||||
@ -1,10 +1,13 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from soundevent import data
|
||||
|
||||
from batdetect2.config import BatDetect2Config
|
||||
from batdetect2.train import run_train
|
||||
|
||||
pytestmark = pytest.mark.slow
|
||||
|
||||
|
||||
def _build_fast_train_config() -> BatDetect2Config:
|
||||
config = BatDetect2Config()
|
||||
|
||||
@ -37,6 +37,7 @@ def test_can_initialize_default_module():
|
||||
assert isinstance(module, L.LightningModule)
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_can_save_checkpoint(
|
||||
tmp_path: Path,
|
||||
clip: data.Clip,
|
||||
@ -182,6 +183,7 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
|
||||
assert api.audio_config.samplerate == module.model_config.samplerate
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_train_smoke_produces_loadable_checkpoint(
|
||||
tmp_path: Path,
|
||||
example_annotations: list[data.ClipAnnotation],
|
||||
|
||||
Loading…
Reference in New Issue
Block a user