mirror of
https://github.com/macaodha/batdetect2.git
synced 2026-05-22 22:32:18 +02:00
Compare commits
No commits in common. "4303d4e42d45d67c6a24132f5f5f4605105b8de1" and "591d4f4ae8ef8570f6f9eb5f8808ad3e79b6101c" have entirely different histories.
4303d4e42d
...
591d4f4ae8
10
justfile
10
justfile
@ -20,15 +20,7 @@ install:
|
|||||||
# Testing & Coverage
|
# Testing & Coverage
|
||||||
# Run tests using pytest.
|
# Run tests using pytest.
|
||||||
test:
|
test:
|
||||||
uv run pytest {{TESTS_DIR}}
|
uv run pytest -n auto {{TESTS_DIR}}
|
||||||
|
|
||||||
# Run the fast subset of tests (excludes @pytest.mark.slow).
|
|
||||||
test-quick:
|
|
||||||
uv run pytest --durations=10 -m "not slow" {{TESTS_DIR}}
|
|
||||||
|
|
||||||
# Run only long-running tests marked with @pytest.mark.slow.
|
|
||||||
test-slow:
|
|
||||||
uv run pytest -m "slow" {{TESTS_DIR}}
|
|
||||||
|
|
||||||
# Run tests and generate coverage data.
|
# Run tests and generate coverage data.
|
||||||
coverage:
|
coverage:
|
||||||
|
|||||||
@ -95,9 +95,6 @@ mlflow = ["mlflow>=3.1.1"]
|
|||||||
gradio = [
|
gradio = [
|
||||||
"gradio>=6.9.0",
|
"gradio>=6.9.0",
|
||||||
]
|
]
|
||||||
dvc = [
|
|
||||||
"dvclive>=3.49.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
line-length = 79
|
line-length = 79
|
||||||
@ -129,8 +126,3 @@ exclude = [
|
|||||||
"src/batdetect2/finetune",
|
"src/batdetect2/finetune",
|
||||||
"src/batdetect2/utils",
|
"src/batdetect2/utils",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
markers = [
|
|
||||||
"slow: marks long-running tests that are skipped in quick runs",
|
|
||||||
]
|
|
||||||
|
|||||||
@ -1,46 +1,68 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal
|
from typing import Literal, Sequence, cast
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
from soundevent.audio.files import get_audio_files
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
from batdetect2.audio import AudioConfig, AudioLoader, build_audio_loader
|
||||||
from collections.abc import Sequence
|
from batdetect2.config import BatDetect2Config
|
||||||
|
from batdetect2.data import Dataset, load_dataset_from_config
|
||||||
import torch
|
from batdetect2.evaluate import (
|
||||||
|
DEFAULT_EVAL_DIR,
|
||||||
from batdetect2.audio import AudioConfig, AudioLoader
|
EvaluationConfig,
|
||||||
from batdetect2.config import BatDetect2Config
|
EvaluatorProtocol,
|
||||||
from batdetect2.data import Dataset
|
build_evaluator,
|
||||||
from batdetect2.evaluate import EvaluationConfig, EvaluatorProtocol
|
run_evaluate,
|
||||||
from batdetect2.inference import InferenceConfig
|
save_evaluation_results,
|
||||||
from batdetect2.logging import AppLoggingConfig, LoggerConfig
|
)
|
||||||
from batdetect2.models import Model, ModelConfig
|
from batdetect2.inference import (
|
||||||
from batdetect2.outputs import (
|
InferenceConfig,
|
||||||
OutputFormatConfig,
|
process_file_list,
|
||||||
OutputFormatterProtocol,
|
run_batch_inference,
|
||||||
OutputsConfig,
|
)
|
||||||
OutputTransformProtocol,
|
from batdetect2.logging import (
|
||||||
)
|
DEFAULT_LOGS_DIR,
|
||||||
from batdetect2.postprocess import (
|
AppLoggingConfig,
|
||||||
ClipDetections,
|
LoggerConfig,
|
||||||
Detection,
|
)
|
||||||
PostprocessorProtocol,
|
from batdetect2.models import (
|
||||||
)
|
Model,
|
||||||
from batdetect2.preprocess import PreprocessorProtocol
|
ModelConfig,
|
||||||
from batdetect2.targets import (
|
build_model,
|
||||||
ROIMapperProtocol,
|
build_model_with_new_targets,
|
||||||
TargetConfig,
|
)
|
||||||
TargetProtocol,
|
from batdetect2.models.detectors import Detector
|
||||||
)
|
from batdetect2.outputs import (
|
||||||
from batdetect2.train import TrainingConfig
|
OutputFormatConfig,
|
||||||
|
OutputFormatterProtocol,
|
||||||
|
OutputsConfig,
|
||||||
DEFAULT_CHECKPOINT_DIR: Path = Path("outputs") / "checkpoints"
|
OutputTransformProtocol,
|
||||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
build_output_formatter,
|
||||||
DEFAULT_EVAL_DIR: Path = Path("outputs") / "evaluations"
|
build_output_transform,
|
||||||
|
get_output_formatter,
|
||||||
|
)
|
||||||
|
from batdetect2.postprocess import (
|
||||||
|
ClipDetections,
|
||||||
|
Detection,
|
||||||
|
PostprocessorProtocol,
|
||||||
|
build_postprocessor,
|
||||||
|
)
|
||||||
|
from batdetect2.preprocess import PreprocessorProtocol, build_preprocessor
|
||||||
|
from batdetect2.targets import (
|
||||||
|
ROIMapperProtocol,
|
||||||
|
TargetConfig,
|
||||||
|
TargetProtocol,
|
||||||
|
build_roi_mapping,
|
||||||
|
build_targets,
|
||||||
|
)
|
||||||
|
from batdetect2.train import (
|
||||||
|
DEFAULT_CHECKPOINT_DIR,
|
||||||
|
TrainingConfig,
|
||||||
|
load_model_from_checkpoint,
|
||||||
|
run_train,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BatDetect2API:
|
class BatDetect2API:
|
||||||
@ -87,8 +109,6 @@ class BatDetect2API:
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
base_dir: data.PathLike | None = None,
|
base_dir: data.PathLike | None = None,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
from batdetect2.data import load_dataset_from_config
|
|
||||||
|
|
||||||
return load_dataset_from_config(path, base_dir=base_dir)
|
return load_dataset_from_config(path, base_dir=base_dir)
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
@ -108,8 +128,6 @@ class BatDetect2API:
|
|||||||
train_config: TrainingConfig | None = None,
|
train_config: TrainingConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.train import run_train
|
|
||||||
|
|
||||||
run_train(
|
run_train(
|
||||||
train_annotations=train_annotations,
|
train_annotations=train_annotations,
|
||||||
val_annotations=val_annotations,
|
val_annotations=val_annotations,
|
||||||
@ -154,7 +172,6 @@ class BatDetect2API:
|
|||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
"""Fine-tune the model with trainable-parameter selection."""
|
"""Fine-tune the model with trainable-parameter selection."""
|
||||||
from batdetect2.train import run_train
|
|
||||||
|
|
||||||
self._set_trainable_parameters(trainable)
|
self._set_trainable_parameters(trainable)
|
||||||
|
|
||||||
@ -194,8 +211,6 @@ class BatDetect2API:
|
|||||||
outputs_config: OutputsConfig | None = None,
|
outputs_config: OutputsConfig | None = None,
|
||||||
logger_config: LoggerConfig | None = None,
|
logger_config: LoggerConfig | None = None,
|
||||||
) -> tuple[dict[str, float], list[ClipDetections]]:
|
) -> tuple[dict[str, float], list[ClipDetections]]:
|
||||||
from batdetect2.evaluate import run_evaluate
|
|
||||||
|
|
||||||
return run_evaluate(
|
return run_evaluate(
|
||||||
self.model,
|
self.model,
|
||||||
test_annotations,
|
test_annotations,
|
||||||
@ -220,8 +235,6 @@ class BatDetect2API:
|
|||||||
predictions: Sequence[ClipDetections],
|
predictions: Sequence[ClipDetections],
|
||||||
output_dir: data.PathLike | None = None,
|
output_dir: data.PathLike | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.evaluate import save_evaluation_results
|
|
||||||
|
|
||||||
clip_evals = self.evaluator.evaluate(
|
clip_evals = self.evaluator.evaluate(
|
||||||
annotations,
|
annotations,
|
||||||
predictions,
|
predictions,
|
||||||
@ -294,8 +307,6 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
audio: np.ndarray,
|
audio: np.ndarray,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
import torch
|
|
||||||
|
|
||||||
tensor = torch.tensor(audio).unsqueeze(0)
|
tensor = torch.tensor(audio).unsqueeze(0)
|
||||||
return self.preprocessor(tensor)
|
return self.preprocessor(tensor)
|
||||||
|
|
||||||
@ -305,8 +316,6 @@ class BatDetect2API:
|
|||||||
batch_size: int | None = None,
|
batch_size: int | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> ClipDetections:
|
) -> ClipDetections:
|
||||||
from batdetect2.postprocess import ClipDetections
|
|
||||||
|
|
||||||
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
recording = data.Recording.from_file(audio_file, compute_hash=False)
|
||||||
|
|
||||||
predictions = self.process_files(
|
predictions = self.process_files(
|
||||||
@ -373,8 +382,6 @@ class BatDetect2API:
|
|||||||
audio_dir: data.PathLike,
|
audio_dir: data.PathLike,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
from soundevent.audio.files import get_audio_files
|
|
||||||
|
|
||||||
files = list(get_audio_files(audio_dir))
|
files = list(get_audio_files(audio_dir))
|
||||||
return self.process_files(
|
return self.process_files(
|
||||||
files,
|
files,
|
||||||
@ -391,8 +398,6 @@ class BatDetect2API:
|
|||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
from batdetect2.inference import process_file_list
|
|
||||||
|
|
||||||
return process_file_list(
|
return process_file_list(
|
||||||
self.model,
|
self.model,
|
||||||
audio_files,
|
audio_files,
|
||||||
@ -419,8 +424,6 @@ class BatDetect2API:
|
|||||||
output_config: OutputsConfig | None = None,
|
output_config: OutputsConfig | None = None,
|
||||||
detection_threshold: float | None = None,
|
detection_threshold: float | None = None,
|
||||||
) -> list[ClipDetections]:
|
) -> list[ClipDetections]:
|
||||||
from batdetect2.inference import run_batch_inference
|
|
||||||
|
|
||||||
return run_batch_inference(
|
return run_batch_inference(
|
||||||
self.model,
|
self.model,
|
||||||
clips,
|
clips,
|
||||||
@ -445,8 +448,6 @@ class BatDetect2API:
|
|||||||
format: str | None = None,
|
format: str | None = None,
|
||||||
config: OutputFormatConfig | None = None,
|
config: OutputFormatConfig | None = None,
|
||||||
):
|
):
|
||||||
from batdetect2.outputs import get_output_formatter
|
|
||||||
|
|
||||||
formatter = self.formatter
|
formatter = self.formatter
|
||||||
|
|
||||||
if format is not None or config is not None:
|
if format is not None or config is not None:
|
||||||
@ -466,8 +467,6 @@ class BatDetect2API:
|
|||||||
format: str | None = None,
|
format: str | None = None,
|
||||||
config: OutputFormatConfig | None = None,
|
config: OutputFormatConfig | None = None,
|
||||||
) -> list[object]:
|
) -> list[object]:
|
||||||
from batdetect2.outputs import get_output_formatter
|
|
||||||
|
|
||||||
formatter = self.formatter
|
formatter = self.formatter
|
||||||
|
|
||||||
if format is not None or config is not None:
|
if format is not None or config is not None:
|
||||||
@ -485,17 +484,6 @@ class BatDetect2API:
|
|||||||
cls,
|
cls,
|
||||||
config: BatDetect2Config,
|
config: BatDetect2Config,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
from batdetect2.audio import build_audio_loader
|
|
||||||
from batdetect2.evaluate import build_evaluator
|
|
||||||
from batdetect2.models import build_model
|
|
||||||
from batdetect2.outputs import (
|
|
||||||
build_output_formatter,
|
|
||||||
build_output_transform,
|
|
||||||
)
|
|
||||||
from batdetect2.postprocess import build_postprocessor
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
|
||||||
from batdetect2.targets import build_roi_mapping, build_targets
|
|
||||||
|
|
||||||
targets = build_targets(config=config.model.targets)
|
targets = build_targets(config=config.model.targets)
|
||||||
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
roi_mapper = build_roi_mapping(config=config.model.targets.roi)
|
||||||
|
|
||||||
@ -575,21 +563,6 @@ class BatDetect2API:
|
|||||||
outputs_config: OutputsConfig | None = None,
|
outputs_config: OutputsConfig | None = None,
|
||||||
logging_config: AppLoggingConfig | None = None,
|
logging_config: AppLoggingConfig | None = None,
|
||||||
) -> "BatDetect2API":
|
) -> "BatDetect2API":
|
||||||
from batdetect2.audio import AudioConfig, build_audio_loader
|
|
||||||
from batdetect2.evaluate import EvaluationConfig, build_evaluator
|
|
||||||
from batdetect2.inference import InferenceConfig
|
|
||||||
from batdetect2.logging import AppLoggingConfig
|
|
||||||
from batdetect2.models import build_model_with_new_targets
|
|
||||||
from batdetect2.outputs import (
|
|
||||||
OutputsConfig,
|
|
||||||
build_output_formatter,
|
|
||||||
build_output_transform,
|
|
||||||
)
|
|
||||||
from batdetect2.postprocess import build_postprocessor
|
|
||||||
from batdetect2.preprocess import build_preprocessor
|
|
||||||
from batdetect2.targets import build_roi_mapping, build_targets
|
|
||||||
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
|
||||||
|
|
||||||
model, model_config = load_model_from_checkpoint(path)
|
model, model_config = load_model_from_checkpoint(path)
|
||||||
|
|
||||||
audio_config = audio_config or AudioConfig(
|
audio_config = audio_config or AudioConfig(
|
||||||
@ -672,7 +645,7 @@ class BatDetect2API:
|
|||||||
self,
|
self,
|
||||||
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
trainable: Literal["all", "heads", "classifier_head", "bbox_head"],
|
||||||
) -> None:
|
) -> None:
|
||||||
detector = self.model.detector
|
detector = cast(Detector, self.model.detector)
|
||||||
|
|
||||||
for parameter in detector.parameters():
|
for parameter in detector.parameters():
|
||||||
parameter.requires_grad = False
|
parameter.requires_grad = False
|
||||||
|
|||||||
@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
|
|
||||||
|
from batdetect2.logging import enable_logging
|
||||||
|
|
||||||
|
# from batdetect2.cli.ascii import BATDETECT_ASCII_ART
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"cli",
|
"cli",
|
||||||
]
|
]
|
||||||
@ -30,7 +34,5 @@ def cli(verbose: int = 0):
|
|||||||
"""
|
"""
|
||||||
click.echo(INFO_STR)
|
click.echo(INFO_STR)
|
||||||
|
|
||||||
from batdetect2.logging import enable_logging
|
|
||||||
|
|
||||||
enable_logging(verbose)
|
enable_logging(verbose)
|
||||||
# click.echo(BATDETECT_ASCII_ART)
|
# click.echo(BATDETECT_ASCII_ART)
|
||||||
|
|||||||
@ -73,7 +73,7 @@ def summary(
|
|||||||
|
|
||||||
summary = compute_class_summary(dataset, targets)
|
summary = compute_class_summary(dataset, targets)
|
||||||
|
|
||||||
print(summary.sort_values("class_name").to_markdown())
|
print(summary.to_markdown())
|
||||||
|
|
||||||
|
|
||||||
@data.command(short_help="Convert dataset config to annotation set.")
|
@data.command(short_help="Convert dataset config to annotation set.")
|
||||||
|
|||||||
@ -4,6 +4,8 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from soundevent import io
|
||||||
|
from soundevent.audio.files import get_audio_files
|
||||||
|
|
||||||
from batdetect2.cli.base import cli
|
from batdetect2.cli.base import cli
|
||||||
|
|
||||||
@ -217,8 +219,6 @@ def predict_directory_command(
|
|||||||
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
|
Loads a checkpoint, scans `audio_dir` for supported audio files, runs
|
||||||
inference, and saves predictions to `output_path`.
|
inference, and saves predictions to `output_path`.
|
||||||
"""
|
"""
|
||||||
from soundevent.audio.files import get_audio_files
|
|
||||||
|
|
||||||
audio_files = list(get_audio_files(audio_dir))
|
audio_files = list(get_audio_files(audio_dir))
|
||||||
_run_prediction(
|
_run_prediction(
|
||||||
model_path=model_path,
|
model_path=model_path,
|
||||||
@ -309,8 +309,6 @@ def predict_dataset_command(
|
|||||||
The dataset is read as a soundevent annotation set and unique recording
|
The dataset is read as a soundevent annotation set and unique recording
|
||||||
paths are extracted before inference.
|
paths are extracted before inference.
|
||||||
"""
|
"""
|
||||||
from soundevent import io
|
|
||||||
|
|
||||||
dataset_path = Path(dataset_path)
|
dataset_path = Path(dataset_path)
|
||||||
dataset = io.load(dataset_path, type="annotation_set")
|
dataset = io.load(dataset_path, type="annotation_set")
|
||||||
audio_files = sorted(
|
audio_files = sorted(
|
||||||
|
|||||||
312
src/batdetect2/data/conditions.py
Normal file
312
src/batdetect2/data/conditions.py
Normal file
@ -0,0 +1,312 @@
|
|||||||
|
from collections.abc import Callable
|
||||||
|
from typing import Annotated, List, Literal, Sequence
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
from batdetect2.core.registries import (
|
||||||
|
ImportConfig,
|
||||||
|
Registry,
|
||||||
|
add_import_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
||||||
|
|
||||||
|
conditions: Registry[SoundEventCondition, []] = Registry("condition")
|
||||||
|
|
||||||
|
|
||||||
|
@add_import_config(conditions)
|
||||||
|
class SoundEventConditionImportConfig(ImportConfig):
|
||||||
|
"""Use any callable as a sound event condition.
|
||||||
|
|
||||||
|
Set ``name="import"`` and provide a ``target`` pointing to any
|
||||||
|
callable to use it instead of a built-in option.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: Literal["import"] = "import"
|
||||||
|
|
||||||
|
|
||||||
|
class HasTagConfig(BaseConfig):
|
||||||
|
name: Literal["has_tag"] = "has_tag"
|
||||||
|
tag: data.Tag
|
||||||
|
|
||||||
|
|
||||||
|
class HasTag:
|
||||||
|
def __init__(self, tag: data.Tag):
|
||||||
|
self.tag = tag
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return any(
|
||||||
|
self.tag.term.name == tag.term.name and self.tag.value == tag.value
|
||||||
|
for tag in sound_event_annotation.tags
|
||||||
|
)
|
||||||
|
|
||||||
|
@conditions.register(HasTagConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: HasTagConfig):
|
||||||
|
return HasTag(tag=config.tag)
|
||||||
|
|
||||||
|
|
||||||
|
class HasAllTagsConfig(BaseConfig):
|
||||||
|
name: Literal["has_all_tags"] = "has_all_tags"
|
||||||
|
tags: List[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
|
class HasAllTags:
|
||||||
|
def __init__(self, tags: List[data.Tag]):
|
||||||
|
if not tags:
|
||||||
|
raise ValueError("Need to specify at least one tag")
|
||||||
|
|
||||||
|
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return self.tags.issubset(
|
||||||
|
{(tag.term.name, tag.value) for tag in sound_event_annotation.tags}
|
||||||
|
)
|
||||||
|
|
||||||
|
@conditions.register(HasAllTagsConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: HasAllTagsConfig):
|
||||||
|
return HasAllTags(tags=config.tags)
|
||||||
|
|
||||||
|
|
||||||
|
class HasAnyTagConfig(BaseConfig):
|
||||||
|
name: Literal["has_any_tag"] = "has_any_tag"
|
||||||
|
tags: List[data.Tag]
|
||||||
|
|
||||||
|
|
||||||
|
class HasAnyTag:
|
||||||
|
def __init__(self, tags: List[data.Tag]):
|
||||||
|
if not tags:
|
||||||
|
raise ValueError("Need to specify at least one tag")
|
||||||
|
|
||||||
|
self.tags = {(tag.term.name, tag.value) for tag in tags}
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return bool(
|
||||||
|
self.tags.intersection(
|
||||||
|
{
|
||||||
|
(tag.term.name, tag.value)
|
||||||
|
for tag in sound_event_annotation.tags
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@conditions.register(HasAnyTagConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: HasAnyTagConfig):
|
||||||
|
return HasAnyTag(tags=config.tags)
|
||||||
|
|
||||||
|
|
||||||
|
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
||||||
|
|
||||||
|
|
||||||
|
class DurationConfig(BaseConfig):
|
||||||
|
name: Literal["duration"] = "duration"
|
||||||
|
operator: Operator
|
||||||
|
seconds: float
|
||||||
|
|
||||||
|
|
||||||
|
def _build_comparator(
|
||||||
|
operator: Operator, value: float
|
||||||
|
) -> Callable[[float], bool]:
|
||||||
|
if operator == "gt":
|
||||||
|
return lambda x: x > value
|
||||||
|
|
||||||
|
if operator == "gte":
|
||||||
|
return lambda x: x >= value
|
||||||
|
|
||||||
|
if operator == "lt":
|
||||||
|
return lambda x: x < value
|
||||||
|
|
||||||
|
if operator == "lte":
|
||||||
|
return lambda x: x <= value
|
||||||
|
|
||||||
|
if operator == "eq":
|
||||||
|
return lambda x: x == value
|
||||||
|
|
||||||
|
raise ValueError(f"Invalid operator {operator}")
|
||||||
|
|
||||||
|
|
||||||
|
class Duration:
|
||||||
|
def __init__(self, operator: Operator, seconds: float):
|
||||||
|
self.operator = operator
|
||||||
|
self.seconds = seconds
|
||||||
|
self._comparator = _build_comparator(self.operator, self.seconds)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> bool:
|
||||||
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
start_time, _, end_time, _ = compute_bounds(geometry)
|
||||||
|
duration = end_time - start_time
|
||||||
|
|
||||||
|
return self._comparator(duration)
|
||||||
|
|
||||||
|
@conditions.register(DurationConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: DurationConfig):
|
||||||
|
return Duration(operator=config.operator, seconds=config.seconds)
|
||||||
|
|
||||||
|
|
||||||
|
class FrequencyConfig(BaseConfig):
|
||||||
|
name: Literal["frequency"] = "frequency"
|
||||||
|
boundary: Literal["low", "high"]
|
||||||
|
operator: Operator
|
||||||
|
hertz: float
|
||||||
|
|
||||||
|
|
||||||
|
class Frequency:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
operator: Operator,
|
||||||
|
boundary: Literal["low", "high"],
|
||||||
|
hertz: float,
|
||||||
|
):
|
||||||
|
self.operator = operator
|
||||||
|
self.hertz = hertz
|
||||||
|
self.boundary = boundary
|
||||||
|
self._comparator = _build_comparator(self.operator, self.hertz)
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self,
|
||||||
|
sound_event_annotation: data.SoundEventAnnotation,
|
||||||
|
) -> bool:
|
||||||
|
geometry = sound_event_annotation.sound_event.geometry
|
||||||
|
|
||||||
|
if geometry is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Automatically false if geometry does not have a frequency range
|
||||||
|
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
||||||
|
return False
|
||||||
|
|
||||||
|
_, low_freq, _, high_freq = compute_bounds(geometry)
|
||||||
|
|
||||||
|
if self.boundary == "low":
|
||||||
|
return self._comparator(low_freq)
|
||||||
|
|
||||||
|
return self._comparator(high_freq)
|
||||||
|
|
||||||
|
@conditions.register(FrequencyConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: FrequencyConfig):
|
||||||
|
return Frequency(
|
||||||
|
operator=config.operator,
|
||||||
|
boundary=config.boundary,
|
||||||
|
hertz=config.hertz,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AllOfConfig(BaseConfig):
|
||||||
|
name: Literal["all_of"] = "all_of"
|
||||||
|
conditions: Sequence["SoundEventConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
class AllOf:
|
||||||
|
def __init__(self, conditions: List[SoundEventCondition]):
|
||||||
|
self.conditions = conditions
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return all(c(sound_event_annotation) for c in self.conditions)
|
||||||
|
|
||||||
|
@conditions.register(AllOfConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: AllOfConfig):
|
||||||
|
conditions = [
|
||||||
|
build_sound_event_condition(cond) for cond in config.conditions
|
||||||
|
]
|
||||||
|
return AllOf(conditions)
|
||||||
|
|
||||||
|
|
||||||
|
class AnyOfConfig(BaseConfig):
|
||||||
|
name: Literal["any_of"] = "any_of"
|
||||||
|
conditions: List["SoundEventConditionConfig"]
|
||||||
|
|
||||||
|
|
||||||
|
class AnyOf:
|
||||||
|
def __init__(self, conditions: List[SoundEventCondition]):
|
||||||
|
self.conditions = conditions
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return any(c(sound_event_annotation) for c in self.conditions)
|
||||||
|
|
||||||
|
@conditions.register(AnyOfConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: AnyOfConfig):
|
||||||
|
conditions = [
|
||||||
|
build_sound_event_condition(cond) for cond in config.conditions
|
||||||
|
]
|
||||||
|
return AnyOf(conditions)
|
||||||
|
|
||||||
|
|
||||||
|
class NotConfig(BaseConfig):
|
||||||
|
name: Literal["not"] = "not"
|
||||||
|
condition: "SoundEventConditionConfig"
|
||||||
|
|
||||||
|
|
||||||
|
class Not:
|
||||||
|
def __init__(self, condition: SoundEventCondition):
|
||||||
|
self.condition = condition
|
||||||
|
|
||||||
|
def __call__(
|
||||||
|
self, sound_event_annotation: data.SoundEventAnnotation
|
||||||
|
) -> bool:
|
||||||
|
return not self.condition(sound_event_annotation)
|
||||||
|
|
||||||
|
@conditions.register(NotConfig)
|
||||||
|
@staticmethod
|
||||||
|
def from_config(config: NotConfig):
|
||||||
|
condition = build_sound_event_condition(config.condition)
|
||||||
|
return Not(condition)
|
||||||
|
|
||||||
|
|
||||||
|
SoundEventConditionConfig = Annotated[
|
||||||
|
HasTagConfig
|
||||||
|
| HasAllTagsConfig
|
||||||
|
| HasAnyTagConfig
|
||||||
|
| DurationConfig
|
||||||
|
| FrequencyConfig
|
||||||
|
| AllOfConfig
|
||||||
|
| AnyOfConfig
|
||||||
|
| NotConfig,
|
||||||
|
Field(discriminator="name"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def build_sound_event_condition(
|
||||||
|
config: SoundEventConditionConfig,
|
||||||
|
) -> SoundEventCondition:
|
||||||
|
return conditions.build(config)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_clip_annotation(
|
||||||
|
clip_annotation: data.ClipAnnotation,
|
||||||
|
condition: SoundEventCondition,
|
||||||
|
) -> data.ClipAnnotation:
|
||||||
|
return clip_annotation.model_copy(
|
||||||
|
update=dict(
|
||||||
|
sound_events=[
|
||||||
|
sound_event
|
||||||
|
for sound_event in clip_annotation.sound_events
|
||||||
|
if condition(sound_event)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
)
|
||||||
@ -1,81 +0,0 @@
|
|||||||
from batdetect2.data.conditions.clips import (
|
|
||||||
ClipAllOfConfig,
|
|
||||||
ClipAnnotationCondition,
|
|
||||||
ClipAnnotationConditionConfig,
|
|
||||||
ClipAnnotationConditionImportConfig,
|
|
||||||
ClipAnyOfConfig,
|
|
||||||
ClipNotConfig,
|
|
||||||
RecordingSatisfiesConfig,
|
|
||||||
build_clip_annotation_condition,
|
|
||||||
)
|
|
||||||
from batdetect2.data.conditions.common import (
|
|
||||||
CsvList,
|
|
||||||
HasAllTagsConfig,
|
|
||||||
HasAnyTagConfig,
|
|
||||||
HasTagConfig,
|
|
||||||
IdInListConfig,
|
|
||||||
JsonList,
|
|
||||||
ListFormatConfig,
|
|
||||||
TxtList,
|
|
||||||
)
|
|
||||||
from batdetect2.data.conditions.recordings import (
|
|
||||||
PathInListConfig,
|
|
||||||
RecordingAllOfConfig,
|
|
||||||
RecordingAnyOfConfig,
|
|
||||||
RecordingCondition,
|
|
||||||
RecordingConditionConfig,
|
|
||||||
RecordingConditionImportConfig,
|
|
||||||
RecordingNotConfig,
|
|
||||||
build_recording_condition,
|
|
||||||
)
|
|
||||||
from batdetect2.data.conditions.sound_events import (
|
|
||||||
AllOfConfig,
|
|
||||||
AnyOfConfig,
|
|
||||||
DurationConfig,
|
|
||||||
FrequencyConfig,
|
|
||||||
NotConfig,
|
|
||||||
Operator,
|
|
||||||
SoundEventCondition,
|
|
||||||
SoundEventConditionConfig,
|
|
||||||
SoundEventConditionImportConfig,
|
|
||||||
build_sound_event_condition,
|
|
||||||
filter_clip_annotation,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AllOfConfig",
|
|
||||||
"AnyOfConfig",
|
|
||||||
"ClipAllOfConfig",
|
|
||||||
"ClipAnnotationCondition",
|
|
||||||
"ClipAnnotationConditionConfig",
|
|
||||||
"ClipAnnotationConditionImportConfig",
|
|
||||||
"ClipAnyOfConfig",
|
|
||||||
"ClipNotConfig",
|
|
||||||
"CsvList",
|
|
||||||
"DurationConfig",
|
|
||||||
"FrequencyConfig",
|
|
||||||
"HasAllTagsConfig",
|
|
||||||
"HasAnyTagConfig",
|
|
||||||
"HasTagConfig",
|
|
||||||
"IdInListConfig",
|
|
||||||
"JsonList",
|
|
||||||
"ListFormatConfig",
|
|
||||||
"NotConfig",
|
|
||||||
"Operator",
|
|
||||||
"PathInListConfig",
|
|
||||||
"RecordingCondition",
|
|
||||||
"RecordingConditionConfig",
|
|
||||||
"RecordingConditionImportConfig",
|
|
||||||
"RecordingAllOfConfig",
|
|
||||||
"RecordingAnyOfConfig",
|
|
||||||
"RecordingNotConfig",
|
|
||||||
"RecordingSatisfiesConfig",
|
|
||||||
"SoundEventCondition",
|
|
||||||
"SoundEventConditionConfig",
|
|
||||||
"SoundEventConditionImportConfig",
|
|
||||||
"TxtList",
|
|
||||||
"build_clip_annotation_condition",
|
|
||||||
"build_recording_condition",
|
|
||||||
"build_sound_event_condition",
|
|
||||||
"filter_clip_annotation",
|
|
||||||
]
|
|
||||||
@ -1,138 +0,0 @@
|
|||||||
from collections.abc import Callable, Sequence
|
|
||||||
from typing import Annotated, Literal
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.core.registries import (
|
|
||||||
ImportConfig,
|
|
||||||
Registry,
|
|
||||||
add_import_config,
|
|
||||||
)
|
|
||||||
from batdetect2.data.conditions.common import (
|
|
||||||
HasAllTagsConfig,
|
|
||||||
HasAnyTagConfig,
|
|
||||||
HasTagConfig,
|
|
||||||
IdInListConfig,
|
|
||||||
MultiConditionConfigBase,
|
|
||||||
NotConditionConfigBase,
|
|
||||||
register_all_of_condition,
|
|
||||||
register_any_of_condition,
|
|
||||||
register_has_all_tags_condition,
|
|
||||||
register_has_any_tag_condition,
|
|
||||||
register_has_tag_condition,
|
|
||||||
register_id_in_list_condition,
|
|
||||||
register_not_condition,
|
|
||||||
)
|
|
||||||
from batdetect2.data.conditions.recordings import (
|
|
||||||
RecordingCondition,
|
|
||||||
RecordingConditionConfig,
|
|
||||||
build_recording_condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ClipAllOfConfig",
|
|
||||||
"ClipAnnotationCondition",
|
|
||||||
"ClipAnnotationConditionConfig",
|
|
||||||
"ClipAnnotationConditionImportConfig",
|
|
||||||
"ClipAnyOfConfig",
|
|
||||||
"ClipNotConfig",
|
|
||||||
"RecordingSatisfiesConfig",
|
|
||||||
"build_clip_annotation_condition",
|
|
||||||
]
|
|
||||||
|
|
||||||
ClipAnnotationCondition = Callable[[data.ClipAnnotation], bool]
|
|
||||||
|
|
||||||
clip_annotation_conditions: Registry[
|
|
||||||
ClipAnnotationCondition,
|
|
||||||
[data.PathLike | None],
|
|
||||||
] = Registry("clip_condition")
|
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(clip_annotation_conditions, arg_names=["base_dir"])
|
|
||||||
class ClipAnnotationConditionImportConfig(ImportConfig):
|
|
||||||
"""Use any callable as a clip annotation condition.
|
|
||||||
|
|
||||||
Set ``name="import"`` and provide a ``target`` pointing to any callable
|
|
||||||
to use it instead of a built-in option.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Literal["import"] = "import"
|
|
||||||
|
|
||||||
|
|
||||||
class RecordingSatisfiesConfig(BaseConfig):
|
|
||||||
name: Literal["recording_satisfies"] = "recording_satisfies"
|
|
||||||
condition: RecordingConditionConfig
|
|
||||||
|
|
||||||
|
|
||||||
class RecordingSatisfies:
|
|
||||||
def __init__(self, condition: RecordingCondition):
|
|
||||||
self.condition = condition
|
|
||||||
|
|
||||||
def __call__(self, clip_annotation: data.ClipAnnotation) -> bool:
|
|
||||||
recording = clip_annotation.clip.recording
|
|
||||||
return self.condition(recording)
|
|
||||||
|
|
||||||
@clip_annotation_conditions.register(RecordingSatisfiesConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(
|
|
||||||
config: RecordingSatisfiesConfig,
|
|
||||||
base_dir: data.PathLike | None = None,
|
|
||||||
) -> "RecordingSatisfies":
|
|
||||||
condition = build_recording_condition(
|
|
||||||
config.condition,
|
|
||||||
base_dir=base_dir,
|
|
||||||
)
|
|
||||||
return RecordingSatisfies(condition)
|
|
||||||
|
|
||||||
|
|
||||||
register_has_tag_condition(clip_annotation_conditions, HasTagConfig)
|
|
||||||
register_has_all_tags_condition(
|
|
||||||
clip_annotation_conditions,
|
|
||||||
HasAllTagsConfig,
|
|
||||||
)
|
|
||||||
register_has_any_tag_condition(
|
|
||||||
clip_annotation_conditions,
|
|
||||||
HasAnyTagConfig,
|
|
||||||
)
|
|
||||||
register_id_in_list_condition(clip_annotation_conditions, IdInListConfig)
|
|
||||||
|
|
||||||
|
|
||||||
@register_all_of_condition(clip_annotation_conditions)
|
|
||||||
class ClipAllOfConfig(MultiConditionConfigBase):
|
|
||||||
name: Literal["all_of"] = "all_of"
|
|
||||||
conditions: Sequence["ClipAnnotationConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
@register_any_of_condition(clip_annotation_conditions)
|
|
||||||
class ClipAnyOfConfig(MultiConditionConfigBase):
|
|
||||||
name: Literal["any_of"] = "any_of"
|
|
||||||
conditions: Sequence["ClipAnnotationConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
@register_not_condition(clip_annotation_conditions)
|
|
||||||
class ClipNotConfig(NotConditionConfigBase):
|
|
||||||
name: Literal["not"] = "not"
|
|
||||||
condition: "ClipAnnotationConditionConfig"
|
|
||||||
|
|
||||||
|
|
||||||
ClipAnnotationConditionConfig = Annotated[
|
|
||||||
RecordingSatisfiesConfig
|
|
||||||
| IdInListConfig
|
|
||||||
| HasTagConfig
|
|
||||||
| HasAllTagsConfig
|
|
||||||
| HasAnyTagConfig
|
|
||||||
| ClipAllOfConfig
|
|
||||||
| ClipAnyOfConfig
|
|
||||||
| ClipNotConfig
|
|
||||||
| ClipAnnotationConditionImportConfig,
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_clip_annotation_condition(
|
|
||||||
config: ClipAnnotationConditionConfig,
|
|
||||||
base_dir: data.PathLike | None = None,
|
|
||||||
) -> ClipAnnotationCondition:
|
|
||||||
return clip_annotation_conditions.build(config, base_dir)
|
|
||||||
@ -1,417 +0,0 @@
|
|||||||
import csv
|
|
||||||
import json
|
|
||||||
from collections.abc import Callable, Sequence
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated, Generic, Literal, ParamSpec, Protocol, TypeVar
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, model_validator
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.core.registries import Registry
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AllOf",
|
|
||||||
"AnyOf",
|
|
||||||
"Condition",
|
|
||||||
"CsvList",
|
|
||||||
"HasAllTags",
|
|
||||||
"HasAllTagsConfig",
|
|
||||||
"HasAnyTag",
|
|
||||||
"HasAnyTagConfig",
|
|
||||||
"HasTag",
|
|
||||||
"HasTagConfig",
|
|
||||||
"IdInList",
|
|
||||||
"IdInListConfig",
|
|
||||||
"JsonList",
|
|
||||||
"ListLoader",
|
|
||||||
"ListFormatConfig",
|
|
||||||
"MultiConditionConfigBase",
|
|
||||||
"Not",
|
|
||||||
"NotConditionConfigBase",
|
|
||||||
"ObjectWithTags",
|
|
||||||
"ObjectWithUUID",
|
|
||||||
"TxtList",
|
|
||||||
"build_list_loader",
|
|
||||||
"register_all_of_condition",
|
|
||||||
"register_any_of_condition",
|
|
||||||
"register_has_all_tags_condition",
|
|
||||||
"register_has_any_tag_condition",
|
|
||||||
"register_has_tag_condition",
|
|
||||||
"register_id_in_list_condition",
|
|
||||||
"register_not_condition",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectWithTags(Protocol):
|
|
||||||
tags: list[data.Tag]
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectWithUUID(Protocol):
|
|
||||||
uuid: UUID
|
|
||||||
|
|
||||||
|
|
||||||
ConditionObject = TypeVar("ConditionObject")
|
|
||||||
TaggedObject = TypeVar("TaggedObject", bound="ObjectWithTags")
|
|
||||||
UUIDObject = TypeVar("UUIDObject", bound="ObjectWithUUID")
|
|
||||||
P = ParamSpec("P")
|
|
||||||
NotConfigType = TypeVar("NotConfigType", bound="NotConditionConfigBase")
|
|
||||||
MultiConfigType = TypeVar(
|
|
||||||
"MultiConfigType",
|
|
||||||
bound="MultiConditionConfigBase",
|
|
||||||
)
|
|
||||||
Condition = Callable[[ConditionObject], bool]
|
|
||||||
|
|
||||||
|
|
||||||
class NotConditionConfigBase(BaseConfig):
|
|
||||||
condition: BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class MultiConditionConfigBase(BaseConfig):
|
|
||||||
conditions: Sequence[BaseModel]
|
|
||||||
|
|
||||||
|
|
||||||
class Not(Generic[ConditionObject]):
|
|
||||||
def __init__(self, condition: Condition[ConditionObject]):
|
|
||||||
self.condition = condition
|
|
||||||
|
|
||||||
def __call__(self, obj: ConditionObject) -> bool:
|
|
||||||
return not self.condition(obj)
|
|
||||||
|
|
||||||
|
|
||||||
class AllOf(Generic[ConditionObject]):
|
|
||||||
def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
|
|
||||||
self.conditions = list(conditions)
|
|
||||||
|
|
||||||
def __call__(self, obj: ConditionObject) -> bool:
|
|
||||||
return all(condition(obj) for condition in self.conditions)
|
|
||||||
|
|
||||||
|
|
||||||
class AnyOf(Generic[ConditionObject]):
|
|
||||||
def __init__(self, conditions: Sequence[Condition[ConditionObject]]):
|
|
||||||
self.conditions = list(conditions)
|
|
||||||
|
|
||||||
def __call__(self, obj: ConditionObject) -> bool:
|
|
||||||
return any(condition(obj) for condition in self.conditions)
|
|
||||||
|
|
||||||
|
|
||||||
class HasTag(Generic[TaggedObject]):
|
|
||||||
def __init__(self, tag: data.Tag):
|
|
||||||
self.tag_key = (tag.term.name, tag.value)
|
|
||||||
|
|
||||||
def __call__(self, obj: TaggedObject) -> bool:
|
|
||||||
return any(
|
|
||||||
(tag.term.name, tag.value) == self.tag_key for tag in obj.tags
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HasAllTags(Generic[TaggedObject]):
|
|
||||||
def __init__(self, tags: list[data.Tag]):
|
|
||||||
if not tags:
|
|
||||||
raise ValueError("Need to specify at least one tag")
|
|
||||||
|
|
||||||
self.required_keys = {(tag.term.name, tag.value) for tag in tags}
|
|
||||||
|
|
||||||
def __call__(self, obj: TaggedObject) -> bool:
|
|
||||||
tag_keys = {(tag.term.name, tag.value) for tag in obj.tags}
|
|
||||||
return self.required_keys.issubset(tag_keys)
|
|
||||||
|
|
||||||
|
|
||||||
class HasAnyTag(Generic[TaggedObject]):
|
|
||||||
def __init__(self, tags: list[data.Tag]):
|
|
||||||
if not tags:
|
|
||||||
raise ValueError("Need to specify at least one tag")
|
|
||||||
|
|
||||||
self.required_keys = {(tag.term.name, tag.value) for tag in tags}
|
|
||||||
|
|
||||||
def __call__(self, obj: TaggedObject) -> bool:
|
|
||||||
tag_keys = {(tag.term.name, tag.value) for tag in obj.tags}
|
|
||||||
return bool(self.required_keys.intersection(tag_keys))
|
|
||||||
|
|
||||||
|
|
||||||
class IdInList(Generic[UUIDObject]):
|
|
||||||
def __init__(self, ids: set[UUID]):
|
|
||||||
self.ids = ids
|
|
||||||
|
|
||||||
def __call__(self, obj: UUIDObject) -> bool:
|
|
||||||
return obj.uuid in self.ids
|
|
||||||
|
|
||||||
|
|
||||||
class HasTagConfig(BaseConfig):
|
|
||||||
name: Literal["has_tag"] = "has_tag"
|
|
||||||
tag: data.Tag
|
|
||||||
|
|
||||||
|
|
||||||
class HasAllTagsConfig(BaseConfig):
|
|
||||||
name: Literal["has_all_tags"] = "has_all_tags"
|
|
||||||
tags: list[data.Tag]
|
|
||||||
|
|
||||||
|
|
||||||
class HasAnyTagConfig(BaseConfig):
|
|
||||||
name: Literal["has_any_tag"] = "has_any_tag"
|
|
||||||
tags: list[data.Tag]
|
|
||||||
|
|
||||||
|
|
||||||
class JsonList(BaseConfig):
|
|
||||||
name: Literal["json"] = "json"
|
|
||||||
field: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class TxtList(BaseConfig):
|
|
||||||
name: Literal["txt"] = "txt"
|
|
||||||
|
|
||||||
|
|
||||||
class CsvList(BaseConfig):
|
|
||||||
name: Literal["csv"] = "csv"
|
|
||||||
column: str
|
|
||||||
|
|
||||||
|
|
||||||
ListFormatConfig = Annotated[
|
|
||||||
JsonList | TxtList | CsvList,
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
ListLoader = Callable[[Path], list[str]]
|
|
||||||
|
|
||||||
list_loaders: Registry[ListLoader, []] = Registry("list_loader")
|
|
||||||
|
|
||||||
|
|
||||||
class IdInListConfig(BaseConfig):
|
|
||||||
name: Literal["id_in_list"] = "id_in_list"
|
|
||||||
path: Path
|
|
||||||
format: ListFormatConfig = JsonList()
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _normalize_format(cls, values):
|
|
||||||
if not isinstance(values, dict):
|
|
||||||
return values
|
|
||||||
|
|
||||||
format_config = values.get("format")
|
|
||||||
|
|
||||||
if isinstance(format_config, str):
|
|
||||||
values = values.copy()
|
|
||||||
config_class = list_loaders.get_config_type(format_config)
|
|
||||||
values["format"] = config_class().model_dump()
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
class JsonListLoader:
|
|
||||||
def __init__(self, field: str | None):
|
|
||||||
self.field = field
|
|
||||||
|
|
||||||
def __call__(self, path: Path) -> list[str]:
|
|
||||||
content = json.loads(path.read_text())
|
|
||||||
|
|
||||||
if self.field is not None:
|
|
||||||
if not isinstance(content, dict):
|
|
||||||
raise TypeError(
|
|
||||||
"Expected JSON object with field for 'id_in_list'."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.field not in content:
|
|
||||||
raise KeyError(f"Field '{self.field}' not found in '{path}'.")
|
|
||||||
|
|
||||||
content = content[self.field]
|
|
||||||
|
|
||||||
if not isinstance(content, list):
|
|
||||||
raise TypeError("Expected JSON list with IDs for 'id_in_list'.")
|
|
||||||
|
|
||||||
return [str(value) for value in content]
|
|
||||||
|
|
||||||
@list_loaders.register(JsonList)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: JsonList) -> ListLoader:
|
|
||||||
return JsonListLoader(field=config.field)
|
|
||||||
|
|
||||||
|
|
||||||
class TxtListLoader:
|
|
||||||
def __call__(self, path: Path) -> list[str]:
|
|
||||||
return [
|
|
||||||
line.strip()
|
|
||||||
for line in path.read_text().splitlines()
|
|
||||||
if line.strip()
|
|
||||||
]
|
|
||||||
|
|
||||||
@list_loaders.register(TxtList)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: TxtList) -> ListLoader:
|
|
||||||
return TxtListLoader()
|
|
||||||
|
|
||||||
|
|
||||||
class CsvListLoader:
|
|
||||||
def __init__(self, column: str):
|
|
||||||
self.column = column
|
|
||||||
|
|
||||||
def __call__(self, path: Path) -> list[str]:
|
|
||||||
with path.open("r", newline="") as csv_file:
|
|
||||||
reader = csv.DictReader(csv_file)
|
|
||||||
|
|
||||||
if reader.fieldnames is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"Expected CSV header row for 'id_in_list' in '{path}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.column not in reader.fieldnames:
|
|
||||||
raise ValueError(
|
|
||||||
f"Column '{self.column}' not found in '{path}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
values = []
|
|
||||||
for row in reader:
|
|
||||||
value = row.get(self.column)
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
value = value.strip()
|
|
||||||
|
|
||||||
if not value:
|
|
||||||
continue
|
|
||||||
|
|
||||||
values.append(value)
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
@list_loaders.register(CsvList)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(config: CsvList) -> ListLoader:
|
|
||||||
return CsvListLoader(column=config.column)
|
|
||||||
|
|
||||||
|
|
||||||
def build_list_loader(config: ListFormatConfig) -> ListLoader:
|
|
||||||
return list_loaders.build(config)
|
|
||||||
|
|
||||||
|
|
||||||
def register_id_in_list_condition(
|
|
||||||
registry: Registry[Condition[UUIDObject], [data.PathLike | None]],
|
|
||||||
config_cls: type[IdInListConfig],
|
|
||||||
) -> None:
|
|
||||||
def builder(
|
|
||||||
config: IdInListConfig,
|
|
||||||
base_dir: data.PathLike | None = None,
|
|
||||||
) -> Condition[UUIDObject]:
|
|
||||||
path = config.path
|
|
||||||
|
|
||||||
if base_dir is not None and not path.is_absolute():
|
|
||||||
path = Path(base_dir) / path
|
|
||||||
|
|
||||||
ids = set()
|
|
||||||
loader = build_list_loader(config.format)
|
|
||||||
values = loader(path)
|
|
||||||
for index, value in enumerate(values):
|
|
||||||
try:
|
|
||||||
ids.add(UUID(value))
|
|
||||||
except ValueError as err:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid ID at index {index} in '{path}': {value!r}."
|
|
||||||
) from err
|
|
||||||
|
|
||||||
return IdInList(ids)
|
|
||||||
|
|
||||||
registry.register(config_cls)(builder)
|
|
||||||
|
|
||||||
|
|
||||||
def register_has_tag_condition(
|
|
||||||
registry: Registry[Condition[TaggedObject], P],
|
|
||||||
config_cls: type[HasTagConfig],
|
|
||||||
) -> None:
|
|
||||||
def builder(
|
|
||||||
config: HasTagConfig,
|
|
||||||
*args: P.args,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> Condition[TaggedObject]:
|
|
||||||
return HasTag(config.tag)
|
|
||||||
|
|
||||||
registry.register(config_cls)(builder)
|
|
||||||
|
|
||||||
|
|
||||||
def register_has_all_tags_condition(
|
|
||||||
registry: Registry[Condition[TaggedObject], P],
|
|
||||||
config_cls: type[HasAllTagsConfig],
|
|
||||||
) -> None:
|
|
||||||
def builder(
|
|
||||||
config: HasAllTagsConfig,
|
|
||||||
*args: P.args,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> Condition[TaggedObject]:
|
|
||||||
return HasAllTags(config.tags)
|
|
||||||
|
|
||||||
registry.register(config_cls)(builder)
|
|
||||||
|
|
||||||
|
|
||||||
def register_has_any_tag_condition(
|
|
||||||
registry: Registry[Condition[TaggedObject], P],
|
|
||||||
config_cls: type[HasAnyTagConfig],
|
|
||||||
) -> None:
|
|
||||||
def builder(
|
|
||||||
config: HasAnyTagConfig,
|
|
||||||
*args: P.args,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> Condition[TaggedObject]:
|
|
||||||
return HasAnyTag(config.tags)
|
|
||||||
|
|
||||||
registry.register(config_cls)(builder)
|
|
||||||
|
|
||||||
|
|
||||||
def register_not_condition(
|
|
||||||
registry: Registry[Condition[ConditionObject], P],
|
|
||||||
) -> Callable[[type[NotConfigType]], type[NotConfigType]]:
|
|
||||||
def decorator(config_cls: type[NotConfigType]) -> type[NotConfigType]:
|
|
||||||
@registry.register(config_cls)
|
|
||||||
def builder(
|
|
||||||
config: NotConfigType,
|
|
||||||
*args: P.args,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> Condition[ConditionObject]:
|
|
||||||
condition = registry.build(config.condition, *args, **kwargs)
|
|
||||||
return Not(condition)
|
|
||||||
|
|
||||||
return config_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def register_all_of_condition(
|
|
||||||
registry: Registry[Condition[ConditionObject], P],
|
|
||||||
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
|
|
||||||
def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
|
|
||||||
@registry.register(config_cls)
|
|
||||||
def builder(
|
|
||||||
config: MultiConfigType,
|
|
||||||
*args: P.args,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> Condition[ConditionObject]:
|
|
||||||
conditions = [
|
|
||||||
registry.build(condition, *args, **kwargs)
|
|
||||||
for condition in config.conditions
|
|
||||||
]
|
|
||||||
return AllOf(conditions)
|
|
||||||
|
|
||||||
return config_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
|
|
||||||
def register_any_of_condition(
|
|
||||||
registry: Registry[Condition[ConditionObject], P],
|
|
||||||
) -> Callable[[type[MultiConfigType]], type[MultiConfigType]]:
|
|
||||||
def decorator(config_cls: type[MultiConfigType]) -> type[MultiConfigType]:
|
|
||||||
@registry.register(config_cls)
|
|
||||||
def builder(
|
|
||||||
config: MultiConfigType,
|
|
||||||
*args: P.args,
|
|
||||||
**kwargs: P.kwargs,
|
|
||||||
) -> Condition[ConditionObject]:
|
|
||||||
conditions = [
|
|
||||||
registry.build(condition, *args, **kwargs)
|
|
||||||
for condition in config.conditions
|
|
||||||
]
|
|
||||||
return AnyOf(conditions)
|
|
||||||
|
|
||||||
return config_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
@ -1,217 +0,0 @@
|
|||||||
from collections.abc import Callable, Sequence
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Annotated, Literal
|
|
||||||
|
|
||||||
from loguru import logger
|
|
||||||
from pydantic import Field, model_validator
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.core.registries import (
|
|
||||||
ImportConfig,
|
|
||||||
Registry,
|
|
||||||
add_import_config,
|
|
||||||
)
|
|
||||||
from batdetect2.data.conditions.common import (
|
|
||||||
HasAllTagsConfig,
|
|
||||||
HasAnyTagConfig,
|
|
||||||
HasTagConfig,
|
|
||||||
IdInListConfig,
|
|
||||||
JsonList,
|
|
||||||
ListFormatConfig,
|
|
||||||
MultiConditionConfigBase,
|
|
||||||
NotConditionConfigBase,
|
|
||||||
build_list_loader,
|
|
||||||
list_loaders,
|
|
||||||
register_all_of_condition,
|
|
||||||
register_any_of_condition,
|
|
||||||
register_has_all_tags_condition,
|
|
||||||
register_has_any_tag_condition,
|
|
||||||
register_has_tag_condition,
|
|
||||||
register_id_in_list_condition,
|
|
||||||
register_not_condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"IdInListConfig",
|
|
||||||
"PathInListConfig",
|
|
||||||
"RecordingAllOfConfig",
|
|
||||||
"RecordingAnyOfConfig",
|
|
||||||
"RecordingCondition",
|
|
||||||
"RecordingConditionConfig",
|
|
||||||
"RecordingConditionImportConfig",
|
|
||||||
"RecordingNotConfig",
|
|
||||||
"build_recording_condition",
|
|
||||||
]
|
|
||||||
|
|
||||||
RecordingCondition = Callable[[data.Recording], bool]
|
|
||||||
|
|
||||||
recording_conditions: Registry[RecordingCondition, [data.PathLike | None]] = (
|
|
||||||
Registry("recording_condition")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(recording_conditions, arg_names=["base_dir"])
|
|
||||||
class RecordingConditionImportConfig(ImportConfig):
|
|
||||||
"""Use any callable as a recording condition.
|
|
||||||
|
|
||||||
Set ``name="import"`` and provide a ``target`` pointing to any callable
|
|
||||||
to use it instead of a built-in option.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Literal["import"] = "import"
|
|
||||||
|
|
||||||
|
|
||||||
register_id_in_list_condition(recording_conditions, IdInListConfig)
|
|
||||||
register_has_tag_condition(recording_conditions, HasTagConfig)
|
|
||||||
register_has_all_tags_condition(recording_conditions, HasAllTagsConfig)
|
|
||||||
register_has_any_tag_condition(recording_conditions, HasAnyTagConfig)
|
|
||||||
|
|
||||||
|
|
||||||
class PathInListConfig(BaseConfig):
|
|
||||||
name: Literal["path_in_list"] = "path_in_list"
|
|
||||||
path: Path
|
|
||||||
format: ListFormatConfig = JsonList()
|
|
||||||
base_dir: Path | None = None
|
|
||||||
on_outside: Literal["allow", "warn", "error"] = "allow"
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _normalize_format(cls, values):
|
|
||||||
if not isinstance(values, dict):
|
|
||||||
return values
|
|
||||||
|
|
||||||
format_config = values.get("format")
|
|
||||||
|
|
||||||
if isinstance(format_config, str):
|
|
||||||
values = values.copy()
|
|
||||||
config_class = list_loaders.get_config_type(format_config)
|
|
||||||
values["format"] = config_class().model_dump()
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
class PathInList:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
paths: set[Path],
|
|
||||||
base_dir: Path | None,
|
|
||||||
on_outside: Literal["allow", "warn", "error"],
|
|
||||||
):
|
|
||||||
self.paths = paths
|
|
||||||
self.base_dir = base_dir
|
|
||||||
self.on_outside = on_outside
|
|
||||||
|
|
||||||
def __call__(self, recording: data.Recording) -> bool:
|
|
||||||
normalized_path = self._normalize_recording_path(recording.path)
|
|
||||||
|
|
||||||
if normalized_path is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
return normalized_path in self.paths
|
|
||||||
|
|
||||||
def _normalize_recording_path(self, path: data.PathLike) -> Path | None:
|
|
||||||
recording_path = Path(path)
|
|
||||||
|
|
||||||
if self.base_dir is None:
|
|
||||||
return recording_path
|
|
||||||
|
|
||||||
if not recording_path.is_absolute():
|
|
||||||
return recording_path
|
|
||||||
|
|
||||||
try:
|
|
||||||
return recording_path.relative_to(self.base_dir)
|
|
||||||
except ValueError as err:
|
|
||||||
if self.on_outside == "allow":
|
|
||||||
return None
|
|
||||||
|
|
||||||
if self.on_outside == "warn":
|
|
||||||
logger.warning(
|
|
||||||
"Recording path '{}' is outside '{}' in path_in_list; "
|
|
||||||
"allowing.",
|
|
||||||
recording_path,
|
|
||||||
self.base_dir,
|
|
||||||
)
|
|
||||||
return None
|
|
||||||
|
|
||||||
raise ValueError(
|
|
||||||
f"Recording path '{recording_path}' is outside "
|
|
||||||
f"'{self.base_dir}' for 'path_in_list'."
|
|
||||||
) from err
|
|
||||||
|
|
||||||
@recording_conditions.register(PathInListConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(
|
|
||||||
config: PathInListConfig,
|
|
||||||
base_dir: data.PathLike | None = None,
|
|
||||||
) -> "PathInList":
|
|
||||||
list_path = config.path
|
|
||||||
|
|
||||||
if base_dir is not None and not list_path.is_absolute():
|
|
||||||
list_path = Path(base_dir) / list_path
|
|
||||||
|
|
||||||
match_base_dir = config.base_dir
|
|
||||||
if (
|
|
||||||
match_base_dir is not None
|
|
||||||
and base_dir is not None
|
|
||||||
and not match_base_dir.is_absolute()
|
|
||||||
):
|
|
||||||
match_base_dir = Path(base_dir) / match_base_dir
|
|
||||||
|
|
||||||
loader = build_list_loader(config.format)
|
|
||||||
|
|
||||||
paths = {
|
|
||||||
Path(value).relative_to(match_base_dir)
|
|
||||||
if (
|
|
||||||
match_base_dir is not None
|
|
||||||
and Path(value).is_absolute()
|
|
||||||
and Path(value).is_relative_to(match_base_dir)
|
|
||||||
)
|
|
||||||
else Path(value)
|
|
||||||
for value in loader(list_path)
|
|
||||||
}
|
|
||||||
|
|
||||||
return PathInList(
|
|
||||||
paths=paths,
|
|
||||||
base_dir=match_base_dir,
|
|
||||||
on_outside=config.on_outside,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@register_all_of_condition(recording_conditions)
|
|
||||||
class RecordingAllOfConfig(MultiConditionConfigBase):
|
|
||||||
name: Literal["all_of"] = "all_of"
|
|
||||||
conditions: Sequence["RecordingConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
@register_any_of_condition(recording_conditions)
|
|
||||||
class RecordingAnyOfConfig(MultiConditionConfigBase):
|
|
||||||
name: Literal["any_of"] = "any_of"
|
|
||||||
conditions: Sequence["RecordingConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
@register_not_condition(recording_conditions)
|
|
||||||
class RecordingNotConfig(NotConditionConfigBase):
|
|
||||||
name: Literal["not"] = "not"
|
|
||||||
condition: "RecordingConditionConfig"
|
|
||||||
|
|
||||||
|
|
||||||
RecordingConditionConfig = Annotated[
|
|
||||||
IdInListConfig
|
|
||||||
| PathInListConfig
|
|
||||||
| HasTagConfig
|
|
||||||
| HasAllTagsConfig
|
|
||||||
| HasAnyTagConfig
|
|
||||||
| RecordingAllOfConfig
|
|
||||||
| RecordingAnyOfConfig
|
|
||||||
| RecordingNotConfig
|
|
||||||
| RecordingConditionImportConfig,
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_recording_condition(
|
|
||||||
config: RecordingConditionConfig,
|
|
||||||
base_dir: data.PathLike | None = None,
|
|
||||||
) -> RecordingCondition:
|
|
||||||
return recording_conditions.build(config, base_dir)
|
|
||||||
@ -1,236 +0,0 @@
|
|||||||
from collections.abc import Callable, Sequence
|
|
||||||
from typing import Annotated, Literal
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from soundevent import data
|
|
||||||
from soundevent.geometry import compute_bounds
|
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
|
||||||
from batdetect2.core.registries import (
|
|
||||||
ImportConfig,
|
|
||||||
Registry,
|
|
||||||
add_import_config,
|
|
||||||
)
|
|
||||||
from batdetect2.data.conditions.common import (
|
|
||||||
HasAllTagsConfig,
|
|
||||||
HasAnyTagConfig,
|
|
||||||
HasTagConfig,
|
|
||||||
IdInListConfig,
|
|
||||||
MultiConditionConfigBase,
|
|
||||||
NotConditionConfigBase,
|
|
||||||
register_all_of_condition,
|
|
||||||
register_any_of_condition,
|
|
||||||
register_has_all_tags_condition,
|
|
||||||
register_has_any_tag_condition,
|
|
||||||
register_has_tag_condition,
|
|
||||||
register_id_in_list_condition,
|
|
||||||
register_not_condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AllOfConfig",
|
|
||||||
"AnyOfConfig",
|
|
||||||
"DurationConfig",
|
|
||||||
"FrequencyConfig",
|
|
||||||
"HasAllTagsConfig",
|
|
||||||
"HasAnyTagConfig",
|
|
||||||
"HasTagConfig",
|
|
||||||
"NotConfig",
|
|
||||||
"Operator",
|
|
||||||
"SoundEventCondition",
|
|
||||||
"SoundEventConditionConfig",
|
|
||||||
"SoundEventConditionImportConfig",
|
|
||||||
"build_sound_event_condition",
|
|
||||||
"filter_clip_annotation",
|
|
||||||
]
|
|
||||||
|
|
||||||
SoundEventCondition = Callable[[data.SoundEventAnnotation], bool]
|
|
||||||
|
|
||||||
sound_event_conditions: Registry[
|
|
||||||
SoundEventCondition,
|
|
||||||
[data.PathLike | None],
|
|
||||||
] = Registry("sound_event_condition")
|
|
||||||
|
|
||||||
|
|
||||||
@add_import_config(sound_event_conditions, arg_names=["base_dir"])
|
|
||||||
class SoundEventConditionImportConfig(ImportConfig):
|
|
||||||
"""Use any callable as a sound event condition.
|
|
||||||
|
|
||||||
Set ``name="import"`` and provide a ``target`` pointing to any
|
|
||||||
callable to use it instead of a built-in option.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: Literal["import"] = "import"
|
|
||||||
|
|
||||||
|
|
||||||
register_has_tag_condition(sound_event_conditions, HasTagConfig)
|
|
||||||
register_has_all_tags_condition(sound_event_conditions, HasAllTagsConfig)
|
|
||||||
register_has_any_tag_condition(sound_event_conditions, HasAnyTagConfig)
|
|
||||||
register_id_in_list_condition(sound_event_conditions, IdInListConfig)
|
|
||||||
|
|
||||||
|
|
||||||
Operator = Literal["gt", "gte", "lt", "lte", "eq"]
|
|
||||||
|
|
||||||
|
|
||||||
class DurationConfig(BaseConfig):
|
|
||||||
name: Literal["duration"] = "duration"
|
|
||||||
operator: Operator
|
|
||||||
seconds: float
|
|
||||||
|
|
||||||
|
|
||||||
def _build_comparator(
|
|
||||||
operator: Operator, value: float
|
|
||||||
) -> Callable[[float], bool]:
|
|
||||||
if operator == "gt":
|
|
||||||
return lambda x: x > value
|
|
||||||
|
|
||||||
if operator == "gte":
|
|
||||||
return lambda x: x >= value
|
|
||||||
|
|
||||||
if operator == "lt":
|
|
||||||
return lambda x: x < value
|
|
||||||
|
|
||||||
if operator == "lte":
|
|
||||||
return lambda x: x <= value
|
|
||||||
|
|
||||||
if operator == "eq":
|
|
||||||
return lambda x: x == value
|
|
||||||
|
|
||||||
raise ValueError(f"Invalid operator {operator}")
|
|
||||||
|
|
||||||
|
|
||||||
class Duration:
|
|
||||||
def __init__(self, operator: Operator, seconds: float):
|
|
||||||
self.operator = operator
|
|
||||||
self.seconds = seconds
|
|
||||||
self._comparator = _build_comparator(self.operator, self.seconds)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> bool:
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
|
||||||
|
|
||||||
if geometry is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
start_time, _, end_time, _ = compute_bounds(geometry)
|
|
||||||
duration = end_time - start_time
|
|
||||||
|
|
||||||
return self._comparator(duration)
|
|
||||||
|
|
||||||
@sound_event_conditions.register(DurationConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(
|
|
||||||
config: DurationConfig,
|
|
||||||
base_dir: data.PathLike | None = None,
|
|
||||||
):
|
|
||||||
_ = base_dir
|
|
||||||
return Duration(operator=config.operator, seconds=config.seconds)
|
|
||||||
|
|
||||||
|
|
||||||
class FrequencyConfig(BaseConfig):
|
|
||||||
name: Literal["frequency"] = "frequency"
|
|
||||||
boundary: Literal["low", "high"]
|
|
||||||
operator: Operator
|
|
||||||
hertz: float
|
|
||||||
|
|
||||||
|
|
||||||
class Frequency:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
operator: Operator,
|
|
||||||
boundary: Literal["low", "high"],
|
|
||||||
hertz: float,
|
|
||||||
):
|
|
||||||
self.operator = operator
|
|
||||||
self.hertz = hertz
|
|
||||||
self.boundary = boundary
|
|
||||||
self._comparator = _build_comparator(self.operator, self.hertz)
|
|
||||||
|
|
||||||
def __call__(
|
|
||||||
self,
|
|
||||||
sound_event_annotation: data.SoundEventAnnotation,
|
|
||||||
) -> bool:
|
|
||||||
geometry = sound_event_annotation.sound_event.geometry
|
|
||||||
|
|
||||||
if geometry is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if isinstance(geometry, (data.TimeInterval, data.TimeStamp)):
|
|
||||||
return False
|
|
||||||
|
|
||||||
_, low_freq, _, high_freq = compute_bounds(geometry)
|
|
||||||
|
|
||||||
if self.boundary == "low":
|
|
||||||
return self._comparator(low_freq)
|
|
||||||
|
|
||||||
return self._comparator(high_freq)
|
|
||||||
|
|
||||||
@sound_event_conditions.register(FrequencyConfig)
|
|
||||||
@staticmethod
|
|
||||||
def from_config(
|
|
||||||
config: FrequencyConfig,
|
|
||||||
base_dir: data.PathLike | None = None,
|
|
||||||
):
|
|
||||||
_ = base_dir
|
|
||||||
return Frequency(
|
|
||||||
operator=config.operator,
|
|
||||||
boundary=config.boundary,
|
|
||||||
hertz=config.hertz,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@register_all_of_condition(sound_event_conditions)
|
|
||||||
class AllOfConfig(MultiConditionConfigBase):
|
|
||||||
name: Literal["all_of"] = "all_of"
|
|
||||||
conditions: Sequence["SoundEventConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
@register_any_of_condition(sound_event_conditions)
|
|
||||||
class AnyOfConfig(MultiConditionConfigBase):
|
|
||||||
name: Literal["any_of"] = "any_of"
|
|
||||||
conditions: list["SoundEventConditionConfig"]
|
|
||||||
|
|
||||||
|
|
||||||
@register_not_condition(sound_event_conditions)
|
|
||||||
class NotConfig(NotConditionConfigBase):
|
|
||||||
name: Literal["not"] = "not"
|
|
||||||
condition: "SoundEventConditionConfig"
|
|
||||||
|
|
||||||
|
|
||||||
SoundEventConditionConfig = Annotated[
|
|
||||||
IdInListConfig
|
|
||||||
| HasTagConfig
|
|
||||||
| HasAllTagsConfig
|
|
||||||
| HasAnyTagConfig
|
|
||||||
| DurationConfig
|
|
||||||
| FrequencyConfig
|
|
||||||
| AllOfConfig
|
|
||||||
| AnyOfConfig
|
|
||||||
| NotConfig
|
|
||||||
| SoundEventConditionImportConfig,
|
|
||||||
Field(discriminator="name"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def build_sound_event_condition(
|
|
||||||
config: SoundEventConditionConfig,
|
|
||||||
base_dir: data.PathLike | None = None,
|
|
||||||
) -> SoundEventCondition:
|
|
||||||
return sound_event_conditions.build(config, base_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def filter_clip_annotation(
|
|
||||||
clip_annotation: data.ClipAnnotation,
|
|
||||||
condition: SoundEventCondition,
|
|
||||||
) -> data.ClipAnnotation:
|
|
||||||
return clip_annotation.model_copy(
|
|
||||||
update=dict(
|
|
||||||
sound_events=[
|
|
||||||
sound_event
|
|
||||||
for sound_event in clip_annotation.sound_events
|
|
||||||
if condition(sound_event)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
@ -32,9 +32,7 @@ from batdetect2.data.annotations import (
|
|||||||
load_annotated_dataset,
|
load_annotated_dataset,
|
||||||
)
|
)
|
||||||
from batdetect2.data.conditions import (
|
from batdetect2.data.conditions import (
|
||||||
ClipAnnotationConditionConfig,
|
|
||||||
SoundEventConditionConfig,
|
SoundEventConditionConfig,
|
||||||
build_clip_annotation_condition,
|
|
||||||
build_sound_event_condition,
|
build_sound_event_condition,
|
||||||
filter_clip_annotation,
|
filter_clip_annotation,
|
||||||
)
|
)
|
||||||
@ -71,7 +69,6 @@ class DatasetConfig(BaseConfig):
|
|||||||
description: str
|
description: str
|
||||||
sources: list[AnnotationFormats]
|
sources: list[AnnotationFormats]
|
||||||
|
|
||||||
clip_filter: ClipAnnotationConditionConfig | None = None
|
|
||||||
sound_event_filter: SoundEventConditionConfig | None = None
|
sound_event_filter: SoundEventConditionConfig | None = None
|
||||||
sound_event_transforms: list[SoundEventTransformConfig] = Field(
|
sound_event_transforms: list[SoundEventTransformConfig] = Field(
|
||||||
default_factory=list
|
default_factory=list
|
||||||
@ -87,58 +84,11 @@ def load_dataset(
|
|||||||
apply_transforms: bool = True,
|
apply_transforms: bool = True,
|
||||||
apply_filters: bool = True,
|
apply_filters: bool = True,
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load and merge clip annotations from configured dataset sources.
|
"""Load all clip annotations from the sources defined in a DatasetConfig."""
|
||||||
|
|
||||||
Loads each source listed in ``config.sources`` and returns a flat
|
|
||||||
collection of ``soundevent.data.ClipAnnotation`` objects. Source tags,
|
|
||||||
dataset-level filters, and dataset-level transforms can be enabled or
|
|
||||||
disabled with flags.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
config : DatasetConfig
|
|
||||||
Dataset definition containing source configurations, optional
|
|
||||||
clip-level filter, sound-event filter, and optional sound-event
|
|
||||||
transform pipeline.
|
|
||||||
base_dir : data.PathLike, optional
|
|
||||||
Base directory used to resolve relative paths in source
|
|
||||||
configurations.
|
|
||||||
add_source_tag : bool, default=True
|
|
||||||
If True, append a ``data_source`` tag to each clip annotation with
|
|
||||||
the source name.
|
|
||||||
include_sources : list[str], optional
|
|
||||||
Source names to include. If None, all sources are eligible.
|
|
||||||
exclude_sources : list[str], optional
|
|
||||||
Source names to skip after include filtering. If a source appears in
|
|
||||||
both include and exclude lists, it is skipped.
|
|
||||||
apply_transforms : bool, default=True
|
|
||||||
If True, apply transforms defined in
|
|
||||||
``config.sound_event_transforms``.
|
|
||||||
apply_filters : bool, default=True
|
|
||||||
If True, apply filters defined in ``config.clip_filter`` and
|
|
||||||
``config.sound_event_filter``.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
Dataset
|
|
||||||
Flat collection of clip annotations loaded from the selected sources.
|
|
||||||
"""
|
|
||||||
clip_annotations = []
|
clip_annotations = []
|
||||||
|
|
||||||
clip_condition = (
|
condition = (
|
||||||
build_clip_annotation_condition(
|
build_sound_event_condition(config.sound_event_filter)
|
||||||
config.clip_filter,
|
|
||||||
base_dir=base_dir,
|
|
||||||
)
|
|
||||||
if config.clip_filter is not None
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
sound_event_condition = (
|
|
||||||
build_sound_event_condition(
|
|
||||||
config.sound_event_filter,
|
|
||||||
base_dir=base_dir,
|
|
||||||
)
|
|
||||||
if config.sound_event_filter is not None
|
if config.sound_event_filter is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
@ -173,17 +123,10 @@ def load_dataset(
|
|||||||
if add_source_tag:
|
if add_source_tag:
|
||||||
clip_annotation = insert_source_tag(clip_annotation, source)
|
clip_annotation = insert_source_tag(clip_annotation, source)
|
||||||
|
|
||||||
if (
|
if condition is not None and apply_filters:
|
||||||
clip_condition is not None
|
|
||||||
and apply_filters
|
|
||||||
and not clip_condition(clip_annotation)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
|
|
||||||
if sound_event_condition is not None and apply_filters:
|
|
||||||
clip_annotation = filter_clip_annotation(
|
clip_annotation = filter_clip_annotation(
|
||||||
clip_annotation,
|
clip_annotation,
|
||||||
sound_event_condition,
|
condition,
|
||||||
)
|
)
|
||||||
|
|
||||||
if transform is not None and apply_transforms:
|
if transform is not None and apply_transforms:
|
||||||
@ -238,58 +181,47 @@ def load_dataset_from_config(
|
|||||||
path: data.PathLike,
|
path: data.PathLike,
|
||||||
field: str | None = None,
|
field: str | None = None,
|
||||||
base_dir: data.PathLike | None = None,
|
base_dir: data.PathLike | None = None,
|
||||||
add_source_tag: bool = True,
|
|
||||||
include_sources: list[str] | None = None,
|
|
||||||
exclude_sources: list[str] | None = None,
|
|
||||||
apply_transforms: bool = True,
|
|
||||||
apply_filters: bool = True,
|
|
||||||
) -> Dataset:
|
) -> Dataset:
|
||||||
"""Load a dataset by reading a ``DatasetConfig`` from disk.
|
"""Load dataset annotation metadata from a configuration file.
|
||||||
|
|
||||||
This convenience wrapper first loads a ``DatasetConfig`` from ``path``
|
This is a convenience function that first loads the `DatasetConfig` from
|
||||||
and optional ``field``, then delegates to :func:`load_dataset`.
|
the specified file path and optional nested field, and then calls
|
||||||
|
`load_dataset` to load all corresponding `ClipAnnotation` objects.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
path : data.PathLike
|
path : data.PathLike
|
||||||
Path to a configuration file containing a ``DatasetConfig``.
|
Path to the configuration file (e.g., YAML).
|
||||||
field : str, optional
|
field : str, optional
|
||||||
Dot-separated field path to a nested config section. If None, the
|
Dot-separated path to a nested section within the file containing the
|
||||||
full file is parsed as ``DatasetConfig``.
|
dataset configuration (e.g., "data.training_set"). If None, the
|
||||||
base_dir : data.PathLike, optional
|
entire file content is assumed to be the `DatasetConfig`.
|
||||||
Base directory used to resolve relative paths in source
|
base_dir : Path, optional
|
||||||
configurations.
|
An optional base directory path to resolve relative paths within the
|
||||||
add_source_tag : bool, default=True
|
configuration sources. Passed to `load_dataset`. Defaults to None.
|
||||||
If True, append a ``data_source`` tag to each clip annotation.
|
|
||||||
include_sources : list[str], optional
|
|
||||||
Source names to include. If None, all sources are eligible.
|
|
||||||
exclude_sources : list[str], optional
|
|
||||||
Source names to skip after include filtering.
|
|
||||||
apply_transforms : bool, default=True
|
|
||||||
If True, apply transforms defined in the loaded config.
|
|
||||||
apply_filters : bool, default=True
|
|
||||||
If True, apply clip and sound-event filters defined in the loaded
|
|
||||||
config.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
Dataset
|
Dataset (List[data.ClipAnnotation])
|
||||||
Flat collection of clip annotations loaded from the selected sources.
|
A flat list containing all loaded `ClipAnnotation` metadata objects.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
FileNotFoundError
|
||||||
|
If the config file `path` does not exist.
|
||||||
|
yaml.YAMLError, pydantic.ValidationError, KeyError, TypeError
|
||||||
|
If the configuration file is invalid, cannot be parsed, or does not
|
||||||
|
match the `DatasetConfig` schema.
|
||||||
|
Exception
|
||||||
|
Can raise exceptions from `load_dataset` if loading data from sources
|
||||||
|
fails.
|
||||||
"""
|
"""
|
||||||
config = load_config(
|
config = load_config(
|
||||||
path=path,
|
path=path,
|
||||||
schema=DatasetConfig,
|
schema=DatasetConfig,
|
||||||
field=field,
|
field=field,
|
||||||
)
|
)
|
||||||
return load_dataset(
|
return load_dataset(config, base_dir=base_dir)
|
||||||
config,
|
|
||||||
base_dir=base_dir,
|
|
||||||
add_source_tag=add_source_tag,
|
|
||||||
include_sources=include_sources,
|
|
||||||
exclude_sources=exclude_sources,
|
|
||||||
apply_transforms=apply_transforms,
|
|
||||||
apply_filters=apply_filters,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def save_dataset(
|
def save_dataset(
|
||||||
|
|||||||
@ -1,12 +1,9 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import sys
|
import sys
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (
|
from typing import (
|
||||||
TYPE_CHECKING,
|
|
||||||
Annotated,
|
Annotated,
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
@ -16,23 +13,21 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
from lightning.pytorch.loggers import (
|
||||||
|
CSVLogger,
|
||||||
|
Logger,
|
||||||
|
MLFlowLogger,
|
||||||
|
TensorBoardLogger,
|
||||||
|
)
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
from matplotlib.figure import Figure
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.core.configs import BaseConfig
|
from batdetect2.core.configs import BaseConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
import numpy as np
|
|
||||||
import pandas as pd
|
|
||||||
from lightning.pytorch.loggers import (
|
|
||||||
CSVLogger,
|
|
||||||
Logger,
|
|
||||||
MLFlowLogger,
|
|
||||||
TensorBoardLogger,
|
|
||||||
)
|
|
||||||
from matplotlib.figure import Figure
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
DEFAULT_LOGS_DIR: Path = Path("outputs") / "logs"
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
@ -276,16 +271,10 @@ def build_logger(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
PlotLogger = Callable[[str, "Figure", int], None]
|
PlotLogger = Callable[[str, Figure, int], None]
|
||||||
|
|
||||||
|
|
||||||
def get_image_logger(logger: Logger) -> PlotLogger | None:
|
def get_image_logger(logger: Logger) -> PlotLogger | None:
|
||||||
from lightning.pytorch.loggers import (
|
|
||||||
CSVLogger,
|
|
||||||
MLFlowLogger,
|
|
||||||
TensorBoardLogger,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(logger, TensorBoardLogger):
|
if isinstance(logger, TensorBoardLogger):
|
||||||
return logger.experiment.add_figure
|
return logger.experiment.add_figure
|
||||||
|
|
||||||
@ -307,16 +296,10 @@ def get_image_logger(logger: Logger) -> PlotLogger | None:
|
|||||||
return partial(save_figure, dir=Path(logger.log_dir))
|
return partial(save_figure, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
|
|
||||||
TableLogger = Callable[[str, "pd.DataFrame", int], None]
|
TableLogger = Callable[[str, pd.DataFrame, int], None]
|
||||||
|
|
||||||
|
|
||||||
def get_table_logger(logger: Logger) -> TableLogger | None:
|
def get_table_logger(logger: Logger) -> TableLogger | None:
|
||||||
from lightning.pytorch.loggers import (
|
|
||||||
CSVLogger,
|
|
||||||
MLFlowLogger,
|
|
||||||
TensorBoardLogger,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(logger, TensorBoardLogger):
|
if isinstance(logger, TensorBoardLogger):
|
||||||
return partial(save_table, dir=Path(logger.log_dir))
|
return partial(save_table, dir=Path(logger.log_dir))
|
||||||
|
|
||||||
@ -354,8 +337,6 @@ def save_figure(name: str, fig: Figure, step: int, dir: Path) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
|
def _convert_figure_to_array(figure: Figure) -> np.ndarray:
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
with io.BytesIO() as buff:
|
with io.BytesIO() as buff:
|
||||||
figure.savefig(buff, format="raw")
|
figure.savefig(buff, format="raw")
|
||||||
buff.seek(0)
|
buff.seek(0)
|
||||||
|
|||||||
@ -15,11 +15,11 @@ GENERIC_CLASS_KEY = "class"
|
|||||||
|
|
||||||
|
|
||||||
data_source = data.Term(
|
data_source = data.Term(
|
||||||
name="dcterms:source",
|
name="soundevent:data_source",
|
||||||
label="Source",
|
label="Data Source",
|
||||||
uri="http://purl.org/dc/terms/source",
|
|
||||||
definition=(
|
definition=(
|
||||||
"A related resource from which the described resource is derived."
|
"A unique identifier for the source of the data, typically "
|
||||||
|
"representing the project, site, or deployment context."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -45,17 +45,6 @@ individual = data.Term(
|
|||||||
)
|
)
|
||||||
"""Term used for tags identifying a specific individual animal."""
|
"""Term used for tags identifying a specific individual animal."""
|
||||||
|
|
||||||
dataset_split = data.Term(
|
|
||||||
name="batdetect2:split",
|
|
||||||
label="Dataset Split",
|
|
||||||
definition=(
|
|
||||||
"Identifies the specific data partition (e.g., 'train', 'test') "
|
|
||||||
"that the item belongs to within an experimental setup. "
|
|
||||||
"The expected value is a literal text string."
|
|
||||||
),
|
|
||||||
)
|
|
||||||
"""Custom metadata term defining the machine learning partition of an item."""
|
|
||||||
|
|
||||||
generic_class = data.Term(
|
generic_class = data.Term(
|
||||||
name="soundevent:class",
|
name="soundevent:class",
|
||||||
label="Class",
|
label="Class",
|
||||||
|
|||||||
@ -8,10 +8,12 @@ import torch
|
|||||||
from soundevent.geometry import compute_bounds
|
from soundevent.geometry import compute_bounds
|
||||||
|
|
||||||
from batdetect2.api_v2 import BatDetect2API
|
from batdetect2.api_v2 import BatDetect2API
|
||||||
|
from batdetect2.audio import AudioConfig
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
|
from batdetect2.inference import InferenceConfig
|
||||||
from batdetect2.models.detectors import Detector
|
from batdetect2.models.detectors import Detector
|
||||||
from batdetect2.models.heads import ClassifierHead
|
from batdetect2.models.heads import ClassifierHead
|
||||||
from batdetect2.train import load_model_from_checkpoint
|
from batdetect2.train import TrainingConfig, load_model_from_checkpoint
|
||||||
from batdetect2.train.lightning import build_training_module
|
from batdetect2.train.lightning import build_training_module
|
||||||
|
|
||||||
|
|
||||||
@ -46,7 +48,6 @@ def test_process_file_returns_recording_level_predictions(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_process_files_is_batch_size_invariant(
|
def test_process_files_is_batch_size_invariant(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
example_audio_files: list[Path],
|
example_audio_files: list[Path],
|
||||||
@ -181,7 +182,6 @@ def test_user_can_read_extracted_features_per_detection(
|
|||||||
assert all(vec.size > 0 for vec in feature_vectors)
|
assert all(vec.size > 0 for vec in feature_vectors)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_user_can_load_checkpoint_and_finetune(
|
def test_user_can_load_checkpoint_and_finetune(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
example_annotations,
|
example_annotations,
|
||||||
@ -295,7 +295,6 @@ def test_checkpoint_with_same_targets_config_keeps_heads_unchanged(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_user_can_finetune_only_heads(
|
def test_user_can_finetune_only_heads(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
example_annotations,
|
example_annotations,
|
||||||
@ -331,7 +330,6 @@ def test_user_can_finetune_only_heads(
|
|||||||
assert list(finetune_dir.rglob("*.ckpt"))
|
assert list(finetune_dir.rglob("*.ckpt"))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
def test_user_can_evaluate_small_dataset_and_get_metrics(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
example_annotations,
|
example_annotations,
|
||||||
@ -418,7 +416,6 @@ def test_detection_threshold_override_changes_process_file_results(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_detection_threshold_override_is_ephemeral_in_process_file(
|
def test_detection_threshold_override_is_ephemeral_in_process_file(
|
||||||
api_v2: BatDetect2API,
|
api_v2: BatDetect2API,
|
||||||
example_audio_files: list[Path],
|
example_audio_files: list[Path],
|
||||||
@ -455,3 +452,51 @@ def test_detection_threshold_override_changes_spectrogram_results(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert len(strict_detections) <= len(default_detections)
|
assert len(strict_detections) <= len(default_detections)
|
||||||
|
|
||||||
|
|
||||||
|
def test_per_call_overrides_are_ephemeral(monkeypatch) -> None:
|
||||||
|
"""User story: call-level overrides do not mutate resolved defaults."""
|
||||||
|
|
||||||
|
api = BatDetect2API.from_config(BatDetect2Config())
|
||||||
|
|
||||||
|
override_inference = InferenceConfig.model_validate(
|
||||||
|
{"loader": {"batch_size": 7}}
|
||||||
|
)
|
||||||
|
override_audio = AudioConfig.model_validate({"samplerate": 384000})
|
||||||
|
override_train = TrainingConfig.model_validate(
|
||||||
|
{"trainer": {"max_epochs": 2}}
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_process: dict[str, object] = {}
|
||||||
|
captured_train: dict[str, object] = {}
|
||||||
|
|
||||||
|
def fake_process_file_list(*args, **kwargs):
|
||||||
|
captured_process["inference_config"] = kwargs["inference_config"]
|
||||||
|
captured_process["audio_config"] = kwargs["audio_config"]
|
||||||
|
return []
|
||||||
|
|
||||||
|
def fake_run_train(*args, **kwargs):
|
||||||
|
captured_train["train_config"] = kwargs["train_config"]
|
||||||
|
captured_train["audio_config"] = kwargs["audio_config"]
|
||||||
|
captured_train["model_config"] = kwargs["model_config"]
|
||||||
|
return None
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"batdetect2.api_v2.process_file_list", fake_process_file_list
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("batdetect2.api_v2.run_train", fake_run_train)
|
||||||
|
|
||||||
|
api.process_files(
|
||||||
|
[], inference_config=override_inference, audio_config=override_audio
|
||||||
|
)
|
||||||
|
api.train([], train_config=override_train, audio_config=override_audio)
|
||||||
|
|
||||||
|
assert captured_process["inference_config"] is override_inference
|
||||||
|
assert captured_process["audio_config"] is override_audio
|
||||||
|
assert captured_train["train_config"] is override_train
|
||||||
|
assert captured_train["audio_config"] is override_audio
|
||||||
|
assert captured_train["model_config"] is api.model_config
|
||||||
|
|
||||||
|
assert api.inference_config.loader.batch_size != 7
|
||||||
|
assert api.audio_config.samplerate != 384000
|
||||||
|
assert api.train_config.trainer.max_epochs != 2
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from hypothesis import given, settings
|
from hypothesis import given, settings
|
||||||
@ -11,7 +10,6 @@ from batdetect2.utils import audio_utils, detector_utils
|
|||||||
|
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=1))
|
@given(duration=st.floats(min_value=0.1, max_value=1))
|
||||||
@settings(deadline=None)
|
@settings(deadline=None)
|
||||||
@pytest.mark.slow
|
|
||||||
def test_can_compute_correct_spectrogram_width(duration: float):
|
def test_can_compute_correct_spectrogram_width(duration: float):
|
||||||
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
samplerate = parameters.TARGET_SAMPLERATE_HZ
|
||||||
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
params = parameters.DEFAULT_SPECTROGRAM_PARAMETERS
|
||||||
@ -91,7 +89,6 @@ def test_pad_audio_without_fixed_size(duration: float):
|
|||||||
|
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
@settings(deadline=None)
|
@settings(deadline=None)
|
||||||
@pytest.mark.slow
|
|
||||||
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
|
def test_computed_spectrograms_are_actually_divisible_by_the_spec_divide_factor(
|
||||||
duration: float,
|
duration: float,
|
||||||
):
|
):
|
||||||
|
|||||||
@ -3,13 +3,11 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import pytest
|
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
||||||
"""User story: run legacy detect on example audio directory."""
|
"""User story: run legacy detect on example audio directory."""
|
||||||
|
|
||||||
@ -31,7 +29,6 @@ def test_cli_detect_command_on_test_audio(tmp_path: Path) -> None:
|
|||||||
assert len(list(results_dir.glob("*.json"))) == 3
|
assert len(list(results_dir.glob("*.json"))) == 3
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_cli_detect_command_with_non_trivial_time_expansion(
|
def test_cli_detect_command_with_non_trivial_time_expansion(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -55,7 +52,6 @@ def test_cli_detect_command_with_non_trivial_time_expansion(
|
|||||||
assert "Time Expansion Factor: 10" in result.stdout
|
assert "Time Expansion Factor: 10" in result.stdout
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_cli_detect_command_with_spec_feature_flag(tmp_path: Path) -> None:
|
def test_cli_detect_command_with_spec_feature_flag(tmp_path: Path) -> None:
|
||||||
"""User story: request extra spectral features in output CSV."""
|
"""User story: request extra spectral features in output CSV."""
|
||||||
|
|
||||||
|
|||||||
@ -20,7 +20,6 @@ def test_cli_predict_help() -> None:
|
|||||||
assert "dataset" in result.output
|
assert "dataset" in result.output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_cli_predict_directory_runs_on_real_audio(
|
def test_cli_predict_directory_runs_on_real_audio(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
|
|||||||
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
@ -20,7 +19,6 @@ def test_cli_train_help() -> None:
|
|||||||
assert "--model" in result.output
|
assert "--model" in result.output
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_cli_train_from_checkpoint_runs_on_small_dataset(
|
def test_cli_train_from_checkpoint_runs_on_small_dataset(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
tiny_checkpoint_path: Path,
|
tiny_checkpoint_path: Path,
|
||||||
|
|||||||
@ -2,13 +2,11 @@
|
|||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
from batdetect2.cli import cli
|
from batdetect2.cli import cli
|
||||||
|
|
||||||
runner = CliRunner()
|
runner = CliRunner()
|
||||||
pytestmark = pytest.mark.slow
|
|
||||||
|
|
||||||
|
|
||||||
def test_can_process_jeff37_files(
|
def test_can_process_jeff37_files(
|
||||||
|
|||||||
@ -1,303 +0,0 @@
|
|||||||
import json
|
|
||||||
import textwrap
|
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.core import load_config
|
|
||||||
from batdetect2.data.conditions import (
|
|
||||||
ClipAnnotationConditionConfig,
|
|
||||||
build_clip_annotation_condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_clip_condition_config(
|
|
||||||
tmp_path: Path,
|
|
||||||
yaml_string: str,
|
|
||||||
) -> ClipAnnotationConditionConfig:
|
|
||||||
config_path = tmp_path / f"{uuid.uuid4().hex}.yaml"
|
|
||||||
config_path.write_text(textwrap.dedent(yaml_string).strip())
|
|
||||||
return load_config(
|
|
||||||
config_path, schema=TypeAdapter(ClipAnnotationConditionConfig)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_clip_condition_from_yaml(
|
|
||||||
tmp_path: Path,
|
|
||||||
yaml_string: str,
|
|
||||||
base_dir: Path | None = None,
|
|
||||||
):
|
|
||||||
config = load_clip_condition_config(tmp_path, yaml_string)
|
|
||||||
return build_clip_annotation_condition(config, base_dir=base_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def test_recording_satisfies_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
create_clip,
|
|
||||||
create_clip_annotation,
|
|
||||||
) -> None:
|
|
||||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
|
||||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
|
||||||
clip_a = create_clip(recording_a)
|
|
||||||
clip_b = create_clip(recording_b)
|
|
||||||
clip_annotation_a = create_clip_annotation(clip_a)
|
|
||||||
clip_annotation_b = create_clip_annotation(clip_b)
|
|
||||||
ids_path = tmp_path / "recording_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([str(recording_a.uuid)]))
|
|
||||||
|
|
||||||
condition = build_clip_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: recording_satisfies
|
|
||||||
condition:
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(clip_annotation_a)
|
|
||||||
assert not condition(clip_annotation_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_clip_id_in_list_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
create_clip,
|
|
||||||
create_clip_annotation,
|
|
||||||
) -> None:
|
|
||||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
|
||||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
|
||||||
clip_annotation_a = create_clip_annotation(create_clip(recording_a))
|
|
||||||
clip_annotation_b = create_clip_annotation(create_clip(recording_b))
|
|
||||||
ids_path = tmp_path / "clip_annotation_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([str(clip_annotation_a.uuid)]))
|
|
||||||
|
|
||||||
condition = build_clip_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(clip_annotation_a)
|
|
||||||
assert not condition(clip_annotation_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_clip_has_tag_conditions(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
create_clip,
|
|
||||||
create_clip_annotation,
|
|
||||||
) -> None:
|
|
||||||
reviewed = data.Tag(key="status", value="reviewed")
|
|
||||||
train = data.Tag(key="split", value="train")
|
|
||||||
|
|
||||||
recording = create_recording(path=tmp_path / "rec.wav")
|
|
||||||
clip = create_clip(recording)
|
|
||||||
clip_annotation = create_clip_annotation(
|
|
||||||
clip,
|
|
||||||
clip_tags=[reviewed, train],
|
|
||||||
)
|
|
||||||
clip_annotation_missing = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[train],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_clip_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: status
|
|
||||||
value: reviewed
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(clip_annotation)
|
|
||||||
assert not condition(clip_annotation_missing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_clip_has_all_tags_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
create_clip,
|
|
||||||
create_clip_annotation,
|
|
||||||
) -> None:
|
|
||||||
reviewed = data.Tag(key="status", value="reviewed")
|
|
||||||
train = data.Tag(key="split", value="train")
|
|
||||||
|
|
||||||
recording = create_recording(path=tmp_path / "rec.wav")
|
|
||||||
clip_annotation = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[reviewed, train],
|
|
||||||
)
|
|
||||||
clip_annotation_missing = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[reviewed],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_clip_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_all_tags
|
|
||||||
tags:
|
|
||||||
- key: status
|
|
||||||
value: reviewed
|
|
||||||
- key: split
|
|
||||||
value: train
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(clip_annotation)
|
|
||||||
assert not condition(clip_annotation_missing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_clip_has_any_tag_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
create_clip,
|
|
||||||
create_clip_annotation,
|
|
||||||
) -> None:
|
|
||||||
reviewed = data.Tag(key="status", value="reviewed")
|
|
||||||
train = data.Tag(key="split", value="train")
|
|
||||||
|
|
||||||
recording = create_recording(path=tmp_path / "rec.wav")
|
|
||||||
clip_annotation = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[reviewed, train],
|
|
||||||
)
|
|
||||||
clip_annotation_missing = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[data.Tag(key="split", value="test")],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_clip_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_any_tag
|
|
||||||
tags:
|
|
||||||
- key: split
|
|
||||||
value: val
|
|
||||||
- key: split
|
|
||||||
value: train
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(clip_annotation)
|
|
||||||
assert not condition(clip_annotation_missing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_clip_all_of_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
create_clip,
|
|
||||||
create_clip_annotation,
|
|
||||||
) -> None:
|
|
||||||
reviewed = data.Tag(key="status", value="reviewed")
|
|
||||||
train = data.Tag(key="split", value="train")
|
|
||||||
|
|
||||||
recording = create_recording(path=tmp_path / "rec.wav")
|
|
||||||
clip = create_clip(recording)
|
|
||||||
clip_annotation = create_clip_annotation(
|
|
||||||
clip,
|
|
||||||
clip_tags=[reviewed, train],
|
|
||||||
)
|
|
||||||
clip_annotation_missing = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[reviewed],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_clip_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: all_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: status
|
|
||||||
value: reviewed
|
|
||||||
- name: has_any_tag
|
|
||||||
tags:
|
|
||||||
- key: split
|
|
||||||
value: train
|
|
||||||
- key: split
|
|
||||||
value: val
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(clip_annotation)
|
|
||||||
assert not condition(clip_annotation_missing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_clip_any_of_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
create_clip,
|
|
||||||
create_clip_annotation,
|
|
||||||
) -> None:
|
|
||||||
reviewed = data.Tag(key="status", value="reviewed")
|
|
||||||
|
|
||||||
recording = create_recording(path=tmp_path / "rec.wav")
|
|
||||||
clip_annotation = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[reviewed],
|
|
||||||
)
|
|
||||||
clip_annotation_missing = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[data.Tag(key="status", value="unchecked")],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_clip_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: any_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: split
|
|
||||||
value: val
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: status
|
|
||||||
value: reviewed
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(clip_annotation)
|
|
||||||
assert not condition(clip_annotation_missing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_clip_not_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
create_clip,
|
|
||||||
create_clip_annotation,
|
|
||||||
) -> None:
|
|
||||||
recording = create_recording(path=tmp_path / "rec.wav")
|
|
||||||
clip_annotation = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[data.Tag(key="split", value="train")],
|
|
||||||
)
|
|
||||||
clip_annotation_missing = create_clip_annotation(
|
|
||||||
create_clip(recording),
|
|
||||||
clip_tags=[data.Tag(key="split", value="val")],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_clip_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: not
|
|
||||||
condition:
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: split
|
|
||||||
value: val
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(clip_annotation)
|
|
||||||
assert not condition(clip_annotation_missing)
|
|
||||||
@ -1,564 +0,0 @@
|
|||||||
import json
|
|
||||||
import textwrap
|
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import TypeAdapter
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.core import load_config
|
|
||||||
from batdetect2.data.conditions import (
|
|
||||||
RecordingConditionConfig,
|
|
||||||
build_recording_condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_recording_condition_config(
|
|
||||||
tmp_path: Path,
|
|
||||||
yaml_string: str,
|
|
||||||
) -> RecordingConditionConfig:
|
|
||||||
config_path = tmp_path / f"{uuid.uuid4().hex}.yaml"
|
|
||||||
config_path.write_text(textwrap.dedent(yaml_string).strip())
|
|
||||||
return load_config(
|
|
||||||
config_path,
|
|
||||||
schema=TypeAdapter(RecordingConditionConfig),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_recording_condition_from_yaml(
|
|
||||||
tmp_path: Path,
|
|
||||||
yaml_string: str,
|
|
||||||
base_dir: Path | None = None,
|
|
||||||
):
|
|
||||||
config = load_recording_condition_config(tmp_path, yaml_string)
|
|
||||||
return build_recording_condition(config, base_dir=base_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition(tmp_path: Path, create_recording) -> None:
|
|
||||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
|
||||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
|
||||||
ids_path = tmp_path / "recording_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([str(recording_a.uuid)]))
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition_uses_base_dir(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
|
||||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
|
||||||
split_dir = tmp_path / "splits"
|
|
||||||
split_dir.mkdir()
|
|
||||||
ids_path = split_dir / "train_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([str(recording_a.uuid)]))
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: id_in_list
|
|
||||||
path: splits/train_ids.json
|
|
||||||
""",
|
|
||||||
base_dir=tmp_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition_raises_for_non_list_json(
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
ids_path = tmp_path / "recording_ids.json"
|
|
||||||
ids_path.write_text(json.dumps({"id": "foo"}))
|
|
||||||
|
|
||||||
with pytest.raises(TypeError, match="Expected JSON list"):
|
|
||||||
build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition_raises_for_invalid_id(tmp_path: Path) -> None:
|
|
||||||
ids_path = tmp_path / "recording_ids.json"
|
|
||||||
ids_path.write_text(json.dumps(["not-a-uuid"]))
|
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Invalid ID"):
|
|
||||||
build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition_supports_txt_format(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
|
||||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
|
||||||
ids_path = tmp_path / "recording_ids.txt"
|
|
||||||
ids_path.write_text(f"{recording_a.uuid}\n")
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
format: txt
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition_supports_json_field(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
|
||||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
|
||||||
ids_path = tmp_path / "recording_ids.json"
|
|
||||||
ids_path.write_text(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"train": [str(recording_a.uuid)],
|
|
||||||
"val": [str(recording_b.uuid)],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
format:
|
|
||||||
name: json
|
|
||||||
field: train
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition_supports_csv_column(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
recording_a = create_recording(path=tmp_path / "a.wav")
|
|
||||||
recording_b = create_recording(path=tmp_path / "b.wav")
|
|
||||||
ids_path = tmp_path / "recording_ids.csv"
|
|
||||||
ids_path.write_text(f"recording_uuid\n{recording_a.uuid}\n")
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
format:
|
|
||||||
name: csv
|
|
||||||
column: recording_uuid
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_path_in_list_condition_supports_txt_format(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
audio_dir = tmp_path / "audio"
|
|
||||||
audio_dir.mkdir()
|
|
||||||
recording_a = create_recording(path=audio_dir / "a.wav")
|
|
||||||
recording_b = create_recording(path=audio_dir / "b.wav")
|
|
||||||
paths_file = tmp_path / "recording_paths.txt"
|
|
||||||
paths_file.write_text(f"{recording_a.path}\n")
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: path_in_list
|
|
||||||
path: {paths_file}
|
|
||||||
format: txt
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_path_in_list_condition_supports_json_field(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
audio_dir = tmp_path / "audio"
|
|
||||||
audio_dir.mkdir()
|
|
||||||
recording_a = create_recording(path=audio_dir / "a.wav")
|
|
||||||
recording_b = create_recording(path=audio_dir / "b.wav")
|
|
||||||
paths_file = tmp_path / "recording_paths.json"
|
|
||||||
paths_file.write_text(
|
|
||||||
json.dumps(
|
|
||||||
{
|
|
||||||
"train": [str(recording_a.path)],
|
|
||||||
"val": [str(recording_b.path)],
|
|
||||||
}
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: path_in_list
|
|
||||||
path: {paths_file}
|
|
||||||
format:
|
|
||||||
name: json
|
|
||||||
field: train
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_path_in_list_condition_supports_csv_column(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
audio_dir = tmp_path / "audio"
|
|
||||||
audio_dir.mkdir()
|
|
||||||
recording_a = create_recording(path=audio_dir / "a.wav")
|
|
||||||
recording_b = create_recording(path=audio_dir / "b.wav")
|
|
||||||
paths_file = tmp_path / "recording_paths.csv"
|
|
||||||
paths_file.write_text(f"recording_path\n{recording_a.path}\n")
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: path_in_list
|
|
||||||
path: {paths_file}
|
|
||||||
format:
|
|
||||||
name: csv
|
|
||||||
column: recording_path
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_path_in_list_condition_uses_base_dir(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
data_dir = tmp_path / "dataset"
|
|
||||||
audio_dir = data_dir / "audio"
|
|
||||||
audio_dir.mkdir(parents=True)
|
|
||||||
recording_a = create_recording(path=audio_dir / "a.wav")
|
|
||||||
recording_b = create_recording(path=audio_dir / "b.wav")
|
|
||||||
paths_file = tmp_path / "recording_paths.txt"
|
|
||||||
paths_file.write_text(f"{recording_a.path}\n")
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: path_in_list
|
|
||||||
path: {paths_file}
|
|
||||||
format: txt
|
|
||||||
base_dir: {data_dir}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_path_in_list_condition_outside_allow(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
data_dir = tmp_path / "dataset"
|
|
||||||
inside_dir = data_dir / "audio"
|
|
||||||
inside_dir.mkdir(parents=True)
|
|
||||||
outside_dir = tmp_path / "other"
|
|
||||||
outside_dir.mkdir()
|
|
||||||
recording_inside = create_recording(path=inside_dir / "a.wav")
|
|
||||||
recording_outside = create_recording(path=outside_dir / "x.wav")
|
|
||||||
paths_file = tmp_path / "recording_paths.txt"
|
|
||||||
paths_file.write_text("dataset/audio/unknown.wav\n")
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: path_in_list
|
|
||||||
path: {paths_file}
|
|
||||||
format: txt
|
|
||||||
base_dir: {data_dir}
|
|
||||||
on_outside: allow
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_outside)
|
|
||||||
assert not condition(recording_inside)
|
|
||||||
|
|
||||||
|
|
||||||
def test_path_in_list_condition_outside_warn(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
data_dir = tmp_path / "dataset"
|
|
||||||
inside_dir = data_dir / "audio"
|
|
||||||
inside_dir.mkdir(parents=True)
|
|
||||||
outside_dir = tmp_path / "other"
|
|
||||||
outside_dir.mkdir()
|
|
||||||
recording_inside = create_recording(path=inside_dir / "a.wav")
|
|
||||||
recording_outside = create_recording(path=outside_dir / "x.wav")
|
|
||||||
paths_file = tmp_path / "recording_paths.txt"
|
|
||||||
paths_file.write_text("dataset/audio/unknown.wav\n")
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: path_in_list
|
|
||||||
path: {paths_file}
|
|
||||||
format: txt
|
|
||||||
base_dir: {data_dir}
|
|
||||||
on_outside: warn
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_outside)
|
|
||||||
assert not condition(recording_inside)
|
|
||||||
|
|
||||||
|
|
||||||
def test_path_in_list_condition_outside_error(
|
|
||||||
tmp_path: Path,
|
|
||||||
create_recording,
|
|
||||||
) -> None:
|
|
||||||
data_dir = tmp_path / "dataset"
|
|
||||||
inside_dir = data_dir / "audio"
|
|
||||||
inside_dir.mkdir(parents=True)
|
|
||||||
outside_dir = tmp_path / "other"
|
|
||||||
outside_dir.mkdir()
|
|
||||||
recording_inside = create_recording(path=inside_dir / "a.wav")
|
|
||||||
recording_outside = create_recording(path=outside_dir / "x.wav")
|
|
||||||
paths_file = tmp_path / "recording_paths.txt"
|
|
||||||
paths_file.write_text(f"{recording_inside.path}\n")
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: path_in_list
|
|
||||||
path: {paths_file}
|
|
||||||
format: txt
|
|
||||||
base_dir: {data_dir}
|
|
||||||
on_outside: error
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_inside)
|
|
||||||
with pytest.raises(ValueError, match="outside"):
|
|
||||||
condition(recording_outside)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_tag_condition(tmp_path: Path, create_recording) -> None:
|
|
||||||
train = data.Tag(key="split", value="train")
|
|
||||||
val = data.Tag(key="split", value="val")
|
|
||||||
|
|
||||||
recording_a = create_recording(
|
|
||||||
path=tmp_path / "a.wav",
|
|
||||||
tags=[train],
|
|
||||||
)
|
|
||||||
recording_b = create_recording(
|
|
||||||
path=tmp_path / "b.wav",
|
|
||||||
tags=[val],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: split
|
|
||||||
value: train
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_all_tags_condition(tmp_path: Path, create_recording) -> None:
|
|
||||||
train = data.Tag(key="split", value="train")
|
|
||||||
uk = data.Tag(key="region", value="uk")
|
|
||||||
|
|
||||||
recording_a = create_recording(
|
|
||||||
path=tmp_path / "a.wav",
|
|
||||||
tags=[train, uk],
|
|
||||||
)
|
|
||||||
recording_b = create_recording(
|
|
||||||
path=tmp_path / "b.wav",
|
|
||||||
tags=[train],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_all_tags
|
|
||||||
tags:
|
|
||||||
- key: split
|
|
||||||
value: train
|
|
||||||
- key: region
|
|
||||||
value: uk
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_any_tag_condition(tmp_path: Path, create_recording) -> None:
|
|
||||||
uk = data.Tag(key="region", value="uk")
|
|
||||||
us = data.Tag(key="region", value="us")
|
|
||||||
|
|
||||||
recording_a = create_recording(
|
|
||||||
path=tmp_path / "a.wav",
|
|
||||||
tags=[uk],
|
|
||||||
)
|
|
||||||
recording_b = create_recording(
|
|
||||||
path=tmp_path / "b.wav",
|
|
||||||
tags=[us],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_any_tag
|
|
||||||
tags:
|
|
||||||
- key: region
|
|
||||||
value: eu
|
|
||||||
- key: region
|
|
||||||
value: uk
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_all_of_condition(tmp_path: Path, create_recording) -> None:
|
|
||||||
train = data.Tag(key="split", value="train")
|
|
||||||
uk = data.Tag(key="region", value="uk")
|
|
||||||
us = data.Tag(key="region", value="us")
|
|
||||||
|
|
||||||
recording_a = create_recording(
|
|
||||||
path=tmp_path / "a.wav",
|
|
||||||
tags=[train, uk],
|
|
||||||
)
|
|
||||||
recording_b = create_recording(
|
|
||||||
path=tmp_path / "b.wav",
|
|
||||||
tags=[train, us],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: all_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: split
|
|
||||||
value: train
|
|
||||||
- name: has_any_tag
|
|
||||||
tags:
|
|
||||||
- key: region
|
|
||||||
value: eu
|
|
||||||
- key: region
|
|
||||||
value: uk
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_any_of_condition(tmp_path: Path, create_recording) -> None:
|
|
||||||
train = data.Tag(key="split", value="train")
|
|
||||||
us = data.Tag(key="region", value="us")
|
|
||||||
|
|
||||||
recording_a = create_recording(
|
|
||||||
path=tmp_path / "a.wav",
|
|
||||||
tags=[train],
|
|
||||||
)
|
|
||||||
recording_b = create_recording(
|
|
||||||
path=tmp_path / "b.wav",
|
|
||||||
tags=[us],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: any_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: region
|
|
||||||
value: eu
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: split
|
|
||||||
value: train
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
|
|
||||||
|
|
||||||
def test_not_condition(tmp_path: Path, create_recording) -> None:
|
|
||||||
uk = data.Tag(key="region", value="uk")
|
|
||||||
us = data.Tag(key="region", value="us")
|
|
||||||
|
|
||||||
recording_a = create_recording(
|
|
||||||
path=tmp_path / "a.wav",
|
|
||||||
tags=[uk],
|
|
||||||
)
|
|
||||||
recording_b = create_recording(
|
|
||||||
path=tmp_path / "b.wav",
|
|
||||||
tags=[us],
|
|
||||||
)
|
|
||||||
|
|
||||||
condition = build_recording_condition_from_yaml(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: not
|
|
||||||
condition:
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: region
|
|
||||||
value: us
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(recording_a)
|
|
||||||
assert not condition(recording_b)
|
|
||||||
@ -1,400 +0,0 @@
|
|||||||
import json
|
|
||||||
import textwrap
|
|
||||||
import uuid
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from pydantic import TypeAdapter
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.core import load_config
|
|
||||||
from batdetect2.data.conditions import (
|
|
||||||
SoundEventConditionConfig,
|
|
||||||
build_sound_event_condition,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_sound_event_condition_config(
|
|
||||||
tmp_path: Path,
|
|
||||||
yaml_string: str,
|
|
||||||
) -> SoundEventConditionConfig:
|
|
||||||
config_path = tmp_path / f"{uuid.uuid4().hex}.yaml"
|
|
||||||
config_path.write_text(textwrap.dedent(yaml_string).strip())
|
|
||||||
return load_config(
|
|
||||||
config_path,
|
|
||||||
schema=TypeAdapter(SoundEventConditionConfig),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_condition_from_str(
|
|
||||||
tmp_path: Path,
|
|
||||||
yaml_string: str,
|
|
||||||
base_dir: Path | None = None,
|
|
||||||
):
|
|
||||||
config = load_sound_event_condition_config(tmp_path, yaml_string)
|
|
||||||
return build_sound_event_condition(config, base_dir=base_dir)
|
|
||||||
|
|
||||||
|
|
||||||
def create_sound_event_annotation(
|
|
||||||
recording: data.Recording,
|
|
||||||
geometry: data.Geometry,
|
|
||||||
tags: list[data.Tag] | None = None,
|
|
||||||
) -> data.SoundEventAnnotation:
|
|
||||||
return data.SoundEventAnnotation(
|
|
||||||
sound_event=data.SoundEvent(
|
|
||||||
recording=recording,
|
|
||||||
geometry=geometry,
|
|
||||||
),
|
|
||||||
tags=tags or [],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_tag_condition(
|
|
||||||
sound_event: data.SoundEvent, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
|
||||||
)
|
|
||||||
failing = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_all_tags_condition(
|
|
||||||
sound_event: data.SoundEvent,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_all_tags
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
- key: event
|
|
||||||
value: Echolocation
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Myotis myotis"),
|
|
||||||
data.Tag(key="event", value="Echolocation"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
failing = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_any_tag_condition(
|
|
||||||
sound_event: data.SoundEvent,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_any_tag
|
|
||||||
tags:
|
|
||||||
- key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
- key: event
|
|
||||||
value: Echolocation
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="event", value="Echolocation")],
|
|
||||||
)
|
|
||||||
failing = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[
|
|
||||||
data.Tag(key="species", value="Eptesicus fuscus"),
|
|
||||||
data.Tag(key="event", value="Social"),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_not_condition(sound_event: data.SoundEvent, tmp_path: Path) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: not
|
|
||||||
condition:
|
|
||||||
name: has_tag
|
|
||||||
tag:
|
|
||||||
key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
|
||||||
)
|
|
||||||
failing = data.SoundEventAnnotation(
|
|
||||||
sound_event=sound_event,
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition(
|
|
||||||
sound_event: data.SoundEvent, tmp_path: Path
|
|
||||||
) -> None:
|
|
||||||
passing = data.SoundEventAnnotation(sound_event=sound_event)
|
|
||||||
failing = data.SoundEventAnnotation(sound_event=sound_event)
|
|
||||||
ids_path = tmp_path / "sound_event_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([str(passing.uuid)]))
|
|
||||||
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: id_in_list
|
|
||||||
path: {ids_path}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_id_in_list_condition_uses_base_dir(
|
|
||||||
sound_event: data.SoundEvent,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
passing = data.SoundEventAnnotation(sound_event=sound_event)
|
|
||||||
failing = data.SoundEventAnnotation(sound_event=sound_event)
|
|
||||||
split_dir = tmp_path / "splits"
|
|
||||||
split_dir.mkdir()
|
|
||||||
ids_path = split_dir / "sound_event_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([str(passing.uuid)]))
|
|
||||||
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: id_in_list
|
|
||||||
path: splits/sound_event_ids.json
|
|
||||||
""",
|
|
||||||
base_dir=tmp_path,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"operator,seconds,passing_duration,failing_duration",
|
|
||||||
[
|
|
||||||
("lt", 2, 1, 2),
|
|
||||||
("lte", 2, 2, 3),
|
|
||||||
("gt", 2, 3, 2),
|
|
||||||
("gte", 2, 2, 1),
|
|
||||||
("eq", 2, 2, 3),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_duration_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
recording: data.Recording,
|
|
||||||
operator: str,
|
|
||||||
seconds: int,
|
|
||||||
passing_duration: int,
|
|
||||||
failing_duration: int,
|
|
||||||
) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: duration
|
|
||||||
operator: {operator}
|
|
||||||
seconds: {seconds}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, passing_duration]),
|
|
||||||
)
|
|
||||||
failing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, failing_duration]),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"boundary,operator,hertz,passing_bbox,failing_bbox",
|
|
||||||
[
|
|
||||||
("high", "lt", 300, [0, 100, 1, 200], [0, 100, 1, 300]),
|
|
||||||
("high", "lte", 300, [0, 100, 1, 300], [0, 100, 1, 400]),
|
|
||||||
("high", "gt", 300, [0, 100, 1, 400], [0, 100, 1, 300]),
|
|
||||||
("high", "gte", 300, [0, 100, 1, 300], [0, 100, 1, 200]),
|
|
||||||
("high", "eq", 300, [0, 100, 1, 300], [0, 100, 1, 400]),
|
|
||||||
("low", "lt", 200, [0, 100, 1, 400], [0, 200, 1, 400]),
|
|
||||||
("low", "lte", 200, [0, 200, 1, 400], [0, 300, 1, 400]),
|
|
||||||
("low", "gt", 200, [0, 300, 1, 400], [0, 200, 1, 400]),
|
|
||||||
("low", "gte", 200, [0, 200, 1, 400], [0, 100, 1, 400]),
|
|
||||||
("low", "eq", 200, [0, 200, 1, 400], [0, 300, 1, 400]),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_frequency_condition(
|
|
||||||
tmp_path: Path,
|
|
||||||
recording: data.Recording,
|
|
||||||
boundary: str,
|
|
||||||
operator: str,
|
|
||||||
hertz: int,
|
|
||||||
passing_bbox: list[int],
|
|
||||||
failing_bbox: list[int],
|
|
||||||
) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
f"""
|
|
||||||
name: frequency
|
|
||||||
boundary: {boundary}
|
|
||||||
operator: {operator}
|
|
||||||
hertz: {hertz}
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.BoundingBox(
|
|
||||||
coordinates=[float(value) for value in passing_bbox]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
failing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.BoundingBox(
|
|
||||||
coordinates=[float(value) for value in failing_bbox]
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_frequency_condition_is_false_for_temporal_geometries(
|
|
||||||
tmp_path: Path,
|
|
||||||
recording: data.Recording,
|
|
||||||
) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: frequency
|
|
||||||
boundary: low
|
|
||||||
operator: eq
|
|
||||||
hertz: 200
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.BoundingBox(coordinates=[0, 200, 1, 400]),
|
|
||||||
)
|
|
||||||
failing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 3]),
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_has_all_tags_fails_if_empty(tmp_path: Path) -> None:
|
|
||||||
with pytest.raises(ValueError, match="at least one tag"):
|
|
||||||
build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: has_all_tags
|
|
||||||
tags: []
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_all_of_condition(tmp_path: Path, recording: data.Recording) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: all_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
- name: duration
|
|
||||||
operator: lt
|
|
||||||
seconds: 1
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
|
||||||
)
|
|
||||||
failing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
|
|
||||||
|
|
||||||
def test_any_of_condition(tmp_path: Path, recording: data.Recording) -> None:
|
|
||||||
condition = build_condition_from_str(
|
|
||||||
tmp_path,
|
|
||||||
"""
|
|
||||||
name: any_of
|
|
||||||
conditions:
|
|
||||||
- name: has_tag
|
|
||||||
tag:
|
|
||||||
key: species
|
|
||||||
value: Myotis myotis
|
|
||||||
- name: duration
|
|
||||||
operator: lt
|
|
||||||
seconds: 1
|
|
||||||
""",
|
|
||||||
)
|
|
||||||
|
|
||||||
passing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
|
||||||
tags=[data.Tag(key="species", value="Myotis myotis")],
|
|
||||||
)
|
|
||||||
failing = create_sound_event_annotation(
|
|
||||||
recording=recording,
|
|
||||||
geometry=data.TimeInterval(coordinates=[0, 2]),
|
|
||||||
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert condition(passing)
|
|
||||||
assert not condition(failing)
|
|
||||||
@ -1,100 +0,0 @@
|
|||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from soundevent import data
|
|
||||||
|
|
||||||
from batdetect2.data import DatasetConfig, load_dataset
|
|
||||||
from batdetect2.data.conditions import (
|
|
||||||
HasTagConfig,
|
|
||||||
IdInListConfig,
|
|
||||||
RecordingSatisfiesConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_dataset_applies_clip_filter(
|
|
||||||
example_dataset: DatasetConfig,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
baseline = list(load_dataset(example_dataset))
|
|
||||||
keep_recording_id = str(baseline[0].clip.recording.uuid)
|
|
||||||
ids_path = tmp_path / "train_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([keep_recording_id]))
|
|
||||||
|
|
||||||
config = example_dataset.model_copy(
|
|
||||||
update={
|
|
||||||
"clip_filter": RecordingSatisfiesConfig(
|
|
||||||
condition=IdInListConfig(path=ids_path)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
filtered = list(load_dataset(config))
|
|
||||||
|
|
||||||
assert len(filtered) == 1
|
|
||||||
assert str(filtered[0].clip.recording.uuid) == keep_recording_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_dataset_clip_filter_is_skipped_when_filters_disabled(
|
|
||||||
example_dataset: DatasetConfig,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
baseline = list(load_dataset(example_dataset))
|
|
||||||
keep_recording_id = str(baseline[0].clip.recording.uuid)
|
|
||||||
ids_path = tmp_path / "train_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([keep_recording_id]))
|
|
||||||
|
|
||||||
config = example_dataset.model_copy(
|
|
||||||
update={
|
|
||||||
"clip_filter": RecordingSatisfiesConfig(
|
|
||||||
condition=IdInListConfig(path=ids_path)
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
filtered = list(load_dataset(config, apply_filters=False))
|
|
||||||
|
|
||||||
assert len(filtered) == len(baseline)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_dataset_resolves_clip_filter_paths_from_base_dir(
|
|
||||||
example_dataset: DatasetConfig,
|
|
||||||
tmp_path: Path,
|
|
||||||
) -> None:
|
|
||||||
baseline = list(load_dataset(example_dataset))
|
|
||||||
keep_recording_id = str(baseline[0].clip.recording.uuid)
|
|
||||||
split_dir = tmp_path / "splits"
|
|
||||||
split_dir.mkdir()
|
|
||||||
ids_path = split_dir / "train_ids.json"
|
|
||||||
ids_path.write_text(json.dumps([keep_recording_id]))
|
|
||||||
|
|
||||||
config = example_dataset.model_copy(
|
|
||||||
update={
|
|
||||||
"clip_filter": RecordingSatisfiesConfig(
|
|
||||||
condition=IdInListConfig(path=Path("splits/train_ids.json"))
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
filtered = list(load_dataset(config, base_dir=tmp_path))
|
|
||||||
|
|
||||||
assert len(filtered) == 1
|
|
||||||
assert str(filtered[0].clip.recording.uuid) == keep_recording_id
|
|
||||||
|
|
||||||
|
|
||||||
def test_sound_event_filter_keeps_empty_clips(
|
|
||||||
example_dataset: DatasetConfig,
|
|
||||||
) -> None:
|
|
||||||
config = example_dataset.model_copy(
|
|
||||||
update={
|
|
||||||
"sound_event_filter": HasTagConfig(
|
|
||||||
tag=data.Tag(key="species", value="__missing_species__")
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
filtered = list(load_dataset(config))
|
|
||||||
|
|
||||||
assert len(filtered) == 3
|
|
||||||
assert all(
|
|
||||||
len(clip_annotation.sound_events) == 0 for clip_annotation in filtered
|
|
||||||
)
|
|
||||||
491
tests/test_data/test_transforms/test_conditions.py
Normal file
491
tests/test_data/test_transforms/test_conditions.py
Normal file
@ -0,0 +1,491 @@
|
|||||||
|
import textwrap
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
from pydantic import TypeAdapter
|
||||||
|
from soundevent import data
|
||||||
|
|
||||||
|
from batdetect2.data.conditions import (
|
||||||
|
SoundEventConditionConfig,
|
||||||
|
build_sound_event_condition,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_condition_from_str(content):
|
||||||
|
content = textwrap.dedent(content)
|
||||||
|
content = yaml.safe_load(content)
|
||||||
|
config = TypeAdapter(SoundEventConditionConfig).validate_python(content)
|
||||||
|
return build_sound_event_condition(config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_tag(sound_event: data.SoundEvent):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: has_tag
|
||||||
|
tag:
|
||||||
|
key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
""")
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_all_tags(sound_event: data.SoundEvent):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: has_all_tags
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
- key: event
|
||||||
|
value: Echolocation
|
||||||
|
""")
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||||
|
data.Tag(key="event", value="Echolocation"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Myotis myotis"),
|
||||||
|
data.Tag(key="event", value="Echolocation"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Myotis myotis"),
|
||||||
|
data.Tag(key="event", value="Echolocation"),
|
||||||
|
data.Tag(key="sex", value="Female"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_any_tags(sound_event: data.SoundEvent):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: has_any_tag
|
||||||
|
tags:
|
||||||
|
- key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
- key: event
|
||||||
|
value: Echolocation
|
||||||
|
""")
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||||
|
data.Tag(key="event", value="Echolocation"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Myotis myotis"),
|
||||||
|
data.Tag(key="event", value="Echolocation"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Eptesicus fuscus"),
|
||||||
|
data.Tag(key="event", value="Social"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_not(sound_event: data.SoundEvent):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: not
|
||||||
|
condition:
|
||||||
|
name: has_tag
|
||||||
|
tag:
|
||||||
|
key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
""")
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
|
)
|
||||||
|
assert condition(sound_event_annotation)
|
||||||
|
|
||||||
|
sound_event_annotation = data.SoundEventAnnotation(
|
||||||
|
sound_event=sound_event,
|
||||||
|
tags=[
|
||||||
|
data.Tag(key="species", value="Myotis myotis"),
|
||||||
|
data.Tag(key="event", value="Echolocation"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
assert not condition(sound_event_annotation)
|
||||||
|
|
||||||
|
|
||||||
|
def test_duration(recording: data.Recording):
|
||||||
|
se1 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording, geometry=data.TimeInterval(coordinates=[0, 1])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se2 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording, geometry=data.TimeInterval(coordinates=[0, 2])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se3 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording, geometry=data.TimeInterval(coordinates=[0, 3])
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: lt
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
assert condition(se1)
|
||||||
|
assert not condition(se2)
|
||||||
|
assert not condition(se3)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: lte
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert condition(se1)
|
||||||
|
assert condition(se2)
|
||||||
|
assert not condition(se3)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: gt
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se1)
|
||||||
|
assert not condition(se2)
|
||||||
|
assert condition(se3)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: gte
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se1)
|
||||||
|
assert condition(se2)
|
||||||
|
assert condition(se3)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: duration
|
||||||
|
operator: eq
|
||||||
|
seconds: 2
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se1)
|
||||||
|
assert condition(se2)
|
||||||
|
assert not condition(se3)
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency(recording: data.Recording):
|
||||||
|
se12 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 100, 1, 200]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se13 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 100, 2, 300]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se14 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 100, 3, 400]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se24 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 200, 3, 400]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
se34 = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
recording=recording,
|
||||||
|
geometry=data.BoundingBox(coordinates=[0, 300, 3, 400]),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: lt
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
assert condition(se12)
|
||||||
|
assert not condition(se13)
|
||||||
|
assert not condition(se14)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: lte
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert condition(se12)
|
||||||
|
assert condition(se13)
|
||||||
|
assert not condition(se14)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: gt
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se12)
|
||||||
|
assert not condition(se13)
|
||||||
|
assert condition(se14)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: gte
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se12)
|
||||||
|
assert condition(se13)
|
||||||
|
assert condition(se14)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: high
|
||||||
|
operator: eq
|
||||||
|
hertz: 300
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se12)
|
||||||
|
assert condition(se13)
|
||||||
|
assert not condition(se14)
|
||||||
|
|
||||||
|
# LOW
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: lt
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
assert condition(se14)
|
||||||
|
assert not condition(se24)
|
||||||
|
assert not condition(se34)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: lte
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert condition(se14)
|
||||||
|
assert condition(se24)
|
||||||
|
assert not condition(se34)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: gt
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se14)
|
||||||
|
assert not condition(se24)
|
||||||
|
assert condition(se34)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: gte
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se14)
|
||||||
|
assert condition(se24)
|
||||||
|
assert condition(se34)
|
||||||
|
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: eq
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
|
||||||
|
assert not condition(se14)
|
||||||
|
assert condition(se24)
|
||||||
|
assert not condition(se34)
|
||||||
|
|
||||||
|
|
||||||
|
def test_frequency_is_false_for_temporal_geometries(recording: data.Recording):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: frequency
|
||||||
|
boundary: low
|
||||||
|
operator: eq
|
||||||
|
hertz: 200
|
||||||
|
""")
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 3]),
|
||||||
|
recording=recording,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeStamp(coordinates=3),
|
||||||
|
recording=recording,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
|
||||||
|
def test_has_tags_fails_if_empty():
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
build_condition_from_str("""
|
||||||
|
name: has_tags
|
||||||
|
tags: []
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def test_all_of(recording: data.Recording):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: all_of
|
||||||
|
conditions:
|
||||||
|
- name: has_tag
|
||||||
|
tag:
|
||||||
|
key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
- name: duration
|
||||||
|
operator: lt
|
||||||
|
seconds: 1
|
||||||
|
""")
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
assert condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
|
||||||
|
def test_any_of(recording: data.Recording):
|
||||||
|
condition = build_condition_from_str("""
|
||||||
|
name: any_of
|
||||||
|
conditions:
|
||||||
|
- name: has_tag
|
||||||
|
tag:
|
||||||
|
key: species
|
||||||
|
value: Myotis myotis
|
||||||
|
- name: duration
|
||||||
|
operator: lt
|
||||||
|
seconds: 1
|
||||||
|
""")
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
|
)
|
||||||
|
assert not condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
assert condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 2]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Myotis myotis")],
|
||||||
|
)
|
||||||
|
assert condition(se)
|
||||||
|
|
||||||
|
se = data.SoundEventAnnotation(
|
||||||
|
sound_event=data.SoundEvent(
|
||||||
|
geometry=data.TimeInterval(coordinates=[0, 0.5]),
|
||||||
|
recording=recording,
|
||||||
|
),
|
||||||
|
tags=[data.Tag(key="species", value="Eptesicus fuscus")],
|
||||||
|
)
|
||||||
|
assert condition(se)
|
||||||
@ -2,14 +2,11 @@
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from batdetect2 import api
|
from batdetect2 import api
|
||||||
|
|
||||||
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_no_detections_above_nyquist():
|
def test_no_detections_above_nyquist():
|
||||||
"""Test that no detections are made above the nyquist frequency."""
|
"""Test that no detections are made above the nyquist frequency."""
|
||||||
# Recording donated by @@kdarras
|
# Recording donated by @@kdarras
|
||||||
|
|||||||
@ -4,7 +4,6 @@ from pathlib import Path
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
|
||||||
from hypothesis import given, settings
|
from hypothesis import given, settings
|
||||||
from hypothesis import strategies as st
|
from hypothesis import strategies as st
|
||||||
|
|
||||||
@ -14,7 +13,6 @@ from batdetect2.detector import parameters
|
|||||||
|
|
||||||
@settings(deadline=None, max_examples=5)
|
@settings(deadline=None, max_examples=5)
|
||||||
@given(duration=st.floats(min_value=0.1, max_value=2))
|
@given(duration=st.floats(min_value=0.1, max_value=2))
|
||||||
@pytest.mark.slow
|
|
||||||
def test_can_import_model_without_pickle(duration: float):
|
def test_can_import_model_without_pickle(duration: float):
|
||||||
# NOTE: remove this test once no other issues are found This is a temporary
|
# NOTE: remove this test once no other issues are found This is a temporary
|
||||||
# test to check that change in model loading did not impact model behaviour
|
# test to check that change in model loading did not impact model behaviour
|
||||||
@ -44,7 +42,6 @@ def test_can_import_model_without_pickle(duration: float):
|
|||||||
assert predictions_without_pickle == predictions_with_pickle
|
assert predictions_without_pickle == predictions_with_pickle
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_can_import_model_without_pickle_on_test_data(
|
def test_can_import_model_without_pickle_on_test_data(
|
||||||
example_audio_files: List[Path],
|
example_audio_files: List[Path],
|
||||||
):
|
):
|
||||||
|
|||||||
@ -96,7 +96,7 @@ def test_registry_build_unknown_name_raises():
|
|||||||
name = "NonExistentBackbone"
|
name = "NonExistentBackbone"
|
||||||
|
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
backbone_registry.build(FakeConfig()) # ty: ignore[invalid-argument-type]
|
backbone_registry.build(FakeConfig()) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
def test_backbone_config_validates_unet_from_dict():
|
def test_backbone_config_validates_unet_from_dict():
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
from soundevent import data, terms
|
from soundevent import data, terms
|
||||||
|
|
||||||
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
from batdetect2.targets import TargetConfig, build_roi_mapping, build_targets
|
||||||
|
|||||||
@ -1,13 +1,10 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
from soundevent import data
|
from soundevent import data
|
||||||
|
|
||||||
from batdetect2.config import BatDetect2Config
|
from batdetect2.config import BatDetect2Config
|
||||||
from batdetect2.train import run_train
|
from batdetect2.train import run_train
|
||||||
|
|
||||||
pytestmark = pytest.mark.slow
|
|
||||||
|
|
||||||
|
|
||||||
def _build_fast_train_config() -> BatDetect2Config:
|
def _build_fast_train_config() -> BatDetect2Config:
|
||||||
config = BatDetect2Config()
|
config = BatDetect2Config()
|
||||||
|
|||||||
@ -37,7 +37,6 @@ def test_can_initialize_default_module():
|
|||||||
assert isinstance(module, L.LightningModule)
|
assert isinstance(module, L.LightningModule)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_can_save_checkpoint(
|
def test_can_save_checkpoint(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
clip: data.Clip,
|
clip: data.Clip,
|
||||||
@ -183,7 +182,6 @@ def test_api_from_checkpoint_reconstructs_model_config(tmp_path: Path):
|
|||||||
assert api.audio_config.samplerate == module.model_config.samplerate
|
assert api.audio_config.samplerate == module.model_config.samplerate
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
|
||||||
def test_train_smoke_produces_loadable_checkpoint(
|
def test_train_smoke_produces_loadable_checkpoint(
|
||||||
tmp_path: Path,
|
tmp_path: Path,
|
||||||
example_annotations: list[data.ClipAnnotation],
|
example_annotations: list[data.ClipAnnotation],
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user